From e57f4669e18b23d3c1fad93722757b1eee7d9fca Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 19 May 2026 21:45:43 +0100 Subject: [PATCH 1/7] Update rev-gateway for IAM integration (#940) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit service.py: - Constructor takes **config (same pattern as api-gateway) instead of individual args - Creates IamAuth and calls await self.auth.start() before the message loop - Passes auth to both ConfigReceiver and MessageDispatcher - Uses add_pubsub_args / add_logging_args instead of hand-rolled Pulsar args - Passes timeout through dispatcher.py: - Accepts auth and timeout parameters - Passes both to DispatcherManager — fixes the missing auth argument that would have crashed on startup The remote end's requests now go through the same IAM authentication path as api-gateway. Token validation, workspace resolution, and permissions all work identically regardless of which direction initiated the connection. Fixed tests — the test now passes auth and timeout to MessageDispatcher and verifies they're forwarded to DispatcherManager. Update rev-gateway dispatcher to align with IAM. A "token" parameter must be passed with each message. Fix websocket relay to align with rev-gateway changes so that it conforms to the api-gateway protocol. 
--- dev-tools/tests/relay/websocket_relay.py | 332 +++++---- .../test_dispatcher_semaphore.py | 25 +- tests/unit/test_rev_gateway/__init__.py | 0 .../unit/test_rev_gateway/test_dispatcher.py | 394 +++++------ .../test_rev_gateway_service.py | 661 ++++++++---------- .../trustgraph/rev_gateway/dispatcher.py | 162 +++-- .../trustgraph/rev_gateway/service.py | 205 +++--- 7 files changed, 914 insertions(+), 865 deletions(-) create mode 100644 tests/unit/test_rev_gateway/__init__.py diff --git a/dev-tools/tests/relay/websocket_relay.py b/dev-tools/tests/relay/websocket_relay.py index d537f7da..6495ee49 100644 --- a/dev-tools/tests/relay/websocket_relay.py +++ b/dev-tools/tests/relay/websocket_relay.py @@ -3,208 +3,278 @@ WebSocket Relay Test Harness This script creates a relay server with two WebSocket endpoints: -- /in - for test clients to connect to -- /out - for reverse gateway to connect to +- /in - for test clients to connect to (speaks api-gateway protocol) +- /out - for reverse gateway to connect to (speaks rev-gateway protocol) -Messages are bidirectionally relayed between the two connections. +Clients on /in authenticate with a first-frame auth message: + {"type": "auth", "token": "..."} + +The relay stores the token and injects it into each subsequent message +before forwarding to /out. Responses from /out are forwarded back to +the originating /in connection unchanged. 
Usage: python websocket_relay.py [--port PORT] [--host HOST] """ import asyncio +import json import logging import argparse from aiohttp import web, WSMsgType -import weakref -from typing import Optional, Set +from typing import Dict, Optional -# Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger("websocket_relay") + +class InConnection: + def __init__(self, ws, conn_id): + self.ws = ws + self.conn_id = conn_id + self.token: Optional[str] = None + self.authenticated = False + + class WebSocketRelay: - """WebSocket relay that forwards messages between 'in' and 'out' connections""" - + def __init__(self): - self.in_connections: Set = weakref.WeakSet() - self.out_connections: Set = weakref.WeakSet() - + self.in_connections: Dict[str, InConnection] = {} + self.out_connections: set = set() + self._conn_counter = 0 + + def _next_conn_id(self): + self._conn_counter += 1 + return f"conn-{self._conn_counter}" + async def handle_in_connection(self, request): - """Handle incoming connections on /in endpoint""" ws = web.WebSocketResponse() await ws.prepare(request) - - self.in_connections.add(ws) - logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + + conn_id = self._next_conn_id() + conn = InConnection(ws, conn_id) + self.in_connections[conn_id] = conn + logger.info( + f"New 'in' connection {conn_id}. 
" + f"Total in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + try: async for msg in ws: if msg.type == WSMsgType.TEXT: - data = msg.data - logger.info(f"IN → OUT: {data}") - await self._forward_to_out(data) - elif msg.type == WSMsgType.BINARY: - data = msg.data - logger.info(f"IN → OUT: {len(data)} bytes (binary)") - await self._forward_to_out(data, binary=True) + await self._handle_in_message(conn, msg.data) elif msg.type == WSMsgType.ERROR: - logger.error(f"WebSocket error on 'in' connection: {ws.exception()}") + logger.error( + f"WebSocket error on 'in' connection " + f"{conn_id}: {ws.exception()}" + ) break else: break except Exception as e: - logger.error(f"Error in 'in' connection handler: {e}") + logger.error( + f"Error in 'in' connection {conn_id}: {e}" + ) finally: - logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + del self.in_connections[conn_id] + logger.info( + f"'in' connection {conn_id} closed. " + f"Remaining in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + return ws - - async def handle_out_connection(self, request): - """Handle outgoing connections on /out endpoint""" - ws = web.WebSocketResponse() - await ws.prepare(request) - - self.out_connections.add(ws) - logger.info(f"New 'out' connection. 
Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + + async def _handle_in_message(self, conn, data): try: - async for msg in ws: - if msg.type == WSMsgType.TEXT: - data = msg.data - logger.info(f"OUT → IN: {data}") - await self._forward_to_in(data) - elif msg.type == WSMsgType.BINARY: - data = msg.data - logger.info(f"OUT → IN: {len(data)} bytes (binary)") - await self._forward_to_in(data, binary=True) - elif msg.type == WSMsgType.ERROR: - logger.error(f"WebSocket error on 'out' connection: {ws.exception()}") - break - else: - break - except Exception as e: - logger.error(f"Error in 'out' connection handler: {e}") - finally: - logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") - - return ws - - async def _forward_to_out(self, data, binary=False): - """Forward message from 'in' to all 'out' connections""" - if not self.out_connections: - logger.warning("No 'out' connections available to forward message") + message = json.loads(data) + except json.JSONDecodeError: + logger.warning( + f"{conn.conn_id}: received non-JSON message" + ) return - - closed_connections = [] + + if isinstance(message, dict) and message.get("type") == "auth": + conn.token = message.get("token", "") + conn.authenticated = True + logger.info(f"{conn.conn_id}: authenticated") + await conn.ws.send_json({ + "type": "auth-ok", + "workspace": "relayed", + }) + return + + if not conn.authenticated: + await conn.ws.send_json({ + "error": { + "message": "auth required", + "type": "auth-required", + }, + "complete": True, + }) + return + + message["token"] = conn.token + message["_relay_conn"] = conn.conn_id + + forwarded = json.dumps(message) + logger.info(f"IN {conn.conn_id} → OUT: {forwarded}") + await self._forward_to_out(forwarded) + + async def handle_out_connection(self, request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + self.out_connections.add(ws) + logger.info( + f"New 'out' 
connection. " + f"Total in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + await self._handle_out_message(msg.data) + elif msg.type == WSMsgType.ERROR: + logger.error( + f"WebSocket error on 'out' connection: " + f"{ws.exception()}" + ) + break + else: + break + except Exception as e: + logger.error(f"Error in 'out' connection: {e}") + finally: + self.out_connections.discard(ws) + logger.info( + f"'out' connection closed. " + f"Remaining in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + + return ws + + async def _handle_out_message(self, data): + try: + message = json.loads(data) + except json.JSONDecodeError: + logger.warning("OUT: received non-JSON message") + return + + conn_id = message.pop("_relay_conn", None) + + forwarded = json.dumps(message) + logger.info(f"OUT → IN {conn_id or 'broadcast'}: {forwarded}") + + if conn_id and conn_id in self.in_connections: + conn = self.in_connections[conn_id] + try: + if not conn.ws.closed: + await conn.ws.send_str(forwarded) + except Exception as e: + logger.error( + f"Error forwarding to 'in' {conn_id}: {e}" + ) + else: + await self._broadcast_to_in(forwarded) + + async def _broadcast_to_in(self, data): + closed = [] + for conn_id, conn in list(self.in_connections.items()): + try: + if conn.ws.closed: + closed.append(conn_id) + continue + await conn.ws.send_str(data) + except Exception as e: + logger.error( + f"Error broadcasting to 'in' {conn_id}: {e}" + ) + closed.append(conn_id) + for conn_id in closed: + self.in_connections.pop(conn_id, None) + + async def _forward_to_out(self, data): + closed = [] for ws in list(self.out_connections): try: if ws.closed: - closed_connections.append(ws) + closed.append(ws) continue - - if binary: - await ws.send_bytes(data) - else: - await ws.send_str(data) + await ws.send_str(data) except Exception as e: - logger.error(f"Error forwarding to 'out' connection: {e}") - 
closed_connections.append(ws) - - # Clean up closed connections - for ws in closed_connections: - if ws in self.out_connections: - self.out_connections.discard(ws) - - async def _forward_to_in(self, data, binary=False): - """Forward message from 'out' to all 'in' connections""" - if not self.in_connections: - logger.warning("No 'in' connections available to forward message") - return - - closed_connections = [] - for ws in list(self.in_connections): - try: - if ws.closed: - closed_connections.append(ws) - continue - - if binary: - await ws.send_bytes(data) - else: - await ws.send_str(data) - except Exception as e: - logger.error(f"Error forwarding to 'in' connection: {e}") - closed_connections.append(ws) - - # Clean up closed connections - for ws in closed_connections: - if ws in self.in_connections: - self.in_connections.discard(ws) + logger.error(f"Error forwarding to 'out': {e}") + closed.append(ws) + for ws in closed: + self.out_connections.discard(ws) + async def create_app(relay): - """Create the web application with routes""" app = web.Application() - - # Add routes - app.router.add_get('/in', relay.handle_in_connection) + + app.router.add_get('/in/api/v1/socket', relay.handle_in_connection) app.router.add_get('/out', relay.handle_out_connection) - - # Add a simple status endpoint + async def status(request): - status_info = { + return web.json_response({ 'in_connections': len(relay.in_connections), 'out_connections': len(relay.out_connections), - 'status': 'running' - } - return web.json_response(status_info) - + 'status': 'running', + }) + app.router.add_get('/status', status) - app.router.add_get('/', status) # Root also shows status - + app.router.add_get('/', status) + return app + def main(): parser = argparse.ArgumentParser( description="WebSocket Relay Test Harness" ) parser.add_argument( - '--host', + '--host', default='localhost', - help='Host to bind to (default: localhost)' + help='Host to bind to (default: localhost)', ) parser.add_argument( - 
'--port', - type=int, + '--port', + type=int, default=8080, - help='Port to bind to (default: 8080)' + help='Port to bind to (default: 8080)', ) parser.add_argument( '--verbose', '-v', action='store_true', - help='Enable verbose logging' + help='Enable verbose logging', ) - + args = parser.parse_args() - + if args.verbose: logging.getLogger().setLevel(logging.DEBUG) - + relay = WebSocketRelay() - + print(f"Starting WebSocket Relay on {args.host}:{args.port}") - print(f" 'in' endpoint: ws://{args.host}:{args.port}/in") + print(f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket") print(f" 'out' endpoint: ws://{args.host}:{args.port}/out") print(f" Status: http://{args.host}:{args.port}/status") print() - print("Usage:") - print(f" Test client connects to: ws://{args.host}:{args.port}/in") - print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out") - + print("Client protocol (same as api-gateway):") + print(' 1. Connect to /in/api/v1/socket') + print(' 2. Send: {"type": "auth", "token": "tg_..."}') + print(' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}') + print(' 4. 
Send requests as normal') + web.run_app(create_app(relay), host=args.host, port=args.port) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/unit/test_concurrency/test_dispatcher_semaphore.py b/tests/unit/test_concurrency/test_dispatcher_semaphore.py index 6a1ae8ab..a5374678 100644 --- a/tests/unit/test_concurrency/test_dispatcher_semaphore.py +++ b/tests/unit/test_concurrency/test_dispatcher_semaphore.py @@ -25,16 +25,17 @@ class TestSemaphoreEnforcement: max_concurrent = 0 processing_event = asyncio.Event() - async def slow_process(message): + async def slow_process(message, sender): nonlocal concurrent_count, max_concurrent concurrent_count += 1 max_concurrent = max(max_concurrent, concurrent_count) await asyncio.sleep(0.05) concurrent_count -= 1 - return {"id": message.get("id"), "response": {"ok": True}} dispatcher._process_message = slow_process + sender = AsyncMock() + # Launch more tasks than max_workers messages = [ {"id": f"msg-{i}", "service": "test", "request": {}} @@ -42,7 +43,7 @@ class TestSemaphoreEnforcement: ] tasks = [ - asyncio.create_task(dispatcher.handle_message(m)) + asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages ] @@ -66,17 +67,17 @@ class TestSemaphoreEnforcement: original_process = dispatcher._process_message - async def tracking_process(message): + async def tracking_process(message, sender): nonlocal task_was_tracked # During processing, our task should be in active_tasks if len(dispatcher.active_tasks) > 0: task_was_tracked = True - return {"id": message.get("id"), "response": {"ok": True}} dispatcher._process_message = tracking_process await dispatcher.handle_message( - {"id": "test", "service": "test", "request": {}} + {"id": "test", "service": "test", "request": {}}, + AsyncMock(), ) assert task_was_tracked @@ -88,7 +89,7 @@ class TestSemaphoreEnforcement: """Semaphore should be released even if processing raises.""" dispatcher = 
MessageDispatcher(max_workers=2) - async def failing_process(message): + async def failing_process(message, sender): raise RuntimeError("process failed") dispatcher._process_message = failing_process @@ -96,7 +97,8 @@ class TestSemaphoreEnforcement: # Should not deadlock — semaphore must be released on error with pytest.raises(RuntimeError): await dispatcher.handle_message( - {"id": "test", "service": "test", "request": {}} + {"id": "test", "service": "test", "request": {}}, + AsyncMock(), ) # Semaphore should be back at max @@ -109,17 +111,18 @@ class TestSemaphoreEnforcement: order = [] - async def ordered_process(message): + async def ordered_process(message, sender): msg_id = message["id"] order.append(f"start-{msg_id}") await asyncio.sleep(0.02) order.append(f"end-{msg_id}") - return {"id": msg_id, "response": {"ok": True}} dispatcher._process_message = ordered_process + sender = AsyncMock() + messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)] - tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages] + tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages] await asyncio.gather(*tasks) # With semaphore=1, each message should complete before next starts diff --git a/tests/unit/test_rev_gateway/__init__.py b/tests/unit/test_rev_gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_rev_gateway/test_dispatcher.py b/tests/unit/test_rev_gateway/test_dispatcher.py index 2a9c8df0..6df786cf 100644 --- a/tests/unit/test_rev_gateway/test_dispatcher.py +++ b/tests/unit/test_rev_gateway/test_dispatcher.py @@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher """ import pytest -from unittest.mock import MagicMock, AsyncMock, patch +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch, ANY -from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher - - -class TestWebSocketResponder: - """Test cases for 
WebSocketResponder class""" - - def test_websocket_responder_initialization(self): - """Test WebSocketResponder initialization""" - responder = WebSocketResponder() - - assert responder.response is None - assert responder.completed is False - - @pytest.mark.asyncio - async def test_websocket_responder_send_method(self): - """Test WebSocketResponder send method""" - responder = WebSocketResponder() - - test_response = {"data": "test response"} - - # Call send method - await responder.send(test_response) - - # Verify response was stored - assert responder.response == test_response - - @pytest.mark.asyncio - async def test_websocket_responder_call_method(self): - """Test WebSocketResponder __call__ method""" - responder = WebSocketResponder() - - test_response = {"result": "success"} - test_completed = True - - # Call the responder - await responder(test_response, test_completed) - - # Verify response and completed status were set - assert responder.response == test_response - assert responder.completed == test_completed - - @pytest.mark.asyncio - async def test_websocket_responder_call_method_with_false_completion(self): - """Test WebSocketResponder __call__ method with incomplete response""" - responder = WebSocketResponder() - - test_response = {"partial": "data"} - test_completed = False - - # Call the responder - await responder(test_response, test_completed) - - # Verify response was set and completed is True (since send() always sets completed=True) - assert responder.response == test_response - assert responder.completed is True +from trustgraph.rev_gateway.dispatcher import MessageDispatcher class TestMessageDispatcher: """Test cases for MessageDispatcher class""" def test_message_dispatcher_initialization_with_defaults(self): - """Test MessageDispatcher initialization with default parameters""" dispatcher = MessageDispatcher() - + assert dispatcher.max_workers == 10 assert dispatcher.semaphore._value == 10 assert dispatcher.active_tasks == set() assert 
dispatcher.backend is None + assert dispatcher.auth is None assert dispatcher.dispatcher_manager is None assert len(dispatcher.service_mapping) > 0 def test_message_dispatcher_initialization_with_custom_workers(self): - """Test MessageDispatcher initialization with custom max_workers""" dispatcher = MessageDispatcher(max_workers=5) - + assert dispatcher.max_workers == 5 assert dispatcher.semaphore._value == 5 @patch('trustgraph.rev_gateway.dispatcher.DispatcherManager') - def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager): - """Test MessageDispatcher initialization with pulsar_client and config_receiver""" + def test_message_dispatcher_initialization_with_backend( + self, mock_dispatcher_manager, + ): mock_backend = MagicMock() mock_config_receiver = MagicMock() + mock_auth = MagicMock() mock_dispatcher_instance = MagicMock() mock_dispatcher_manager.return_value = mock_dispatcher_instance - + dispatcher = MessageDispatcher( max_workers=8, config_receiver=mock_config_receiver, - backend=mock_backend + backend=mock_backend, + auth=mock_auth, + timeout=300, ) - + assert dispatcher.max_workers == 8 assert dispatcher.backend == mock_backend + assert dispatcher.auth == mock_auth assert dispatcher.dispatcher_manager == mock_dispatcher_instance mock_dispatcher_manager.assert_called_once_with( - mock_backend, mock_config_receiver, prefix="rev-gateway" + mock_backend, mock_config_receiver, + auth=mock_auth, prefix="rev-gateway", timeout=300, ) def test_message_dispatcher_service_mapping(self): - """Test MessageDispatcher service mapping contains expected services""" dispatcher = MessageDispatcher() - + expected_services = [ "text-completion", "graph-rag", "agent", "embeddings", "graph-embeddings", "triples", "document-load", "text-load", - "flow", "knowledge", "config", "librarian", "document-rag" + "flow", "knowledge", "config", "librarian", "document-rag", ] - + for service in expected_services: assert service in 
dispatcher.service_mapping - - # Test specific mappings - assert dispatcher.service_mapping["text-completion"] == "text-completion" + assert dispatcher.service_mapping["document-load"] == "document" assert dispatcher.service_mapping["text-load"] == "text-document" @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_without_dispatcher_manager(self): - """Test MessageDispatcher handle_message without dispatcher manager""" + async def test_handle_message_without_dispatcher_manager(self): dispatcher = MessageDispatcher() - - test_message = { - "id": "test-123", - "service": "test-service", - "request": {"data": "test"} - } - - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-123" - assert "error" in result["response"] - assert "DispatcherManager not available" in result["response"]["error"] + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="default") + ) + + sender = AsyncMock() + + await dispatcher.handle_message( + {"id": "test-1", "service": "test", "request": {}}, + sender, + ) + + sender.assert_called_once() + sent = sender.call_args[0][0] + assert sent["id"] == "test-1" + assert sent["error"]["message"] == "DispatcherManager not available" + assert sent["error"]["type"] == "error" + assert sent["complete"] is True @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_with_exception(self): - """Test MessageDispatcher handle_message with exception during processing""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error")) - + async def test_handle_message_auth_failure(self): dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-456", - "service": "text-completion", - "request": {"prompt": "test"} - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', 
{"text-completion": True}): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-456" - assert "error" in result["response"] - assert "Test error" in result["response"]["error"] + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + side_effect=Exception("auth failure") + ) + dispatcher.dispatcher_manager = MagicMock() + + sender = AsyncMock() + + await dispatcher.handle_message( + {"id": "test-2", "token": "bad", "service": "test", "request": {}}, + sender, + ) + + sender.assert_called_once() + sent = sender.call_args[0][0] + assert sent["id"] == "test-2" + assert "auth failure" in sent["error"]["message"] + assert sent["complete"] is True @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_global_service(self): - """Test MessageDispatcher handle_message with global service""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_global_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = True - mock_responder.response = {"result": "success"} - + async def test_handle_message_global_service(self): + mock_dm = MagicMock() + mock_dm.invoke_global_service = AsyncMock() + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-789", - "service": "text-completion", - "request": {"prompt": "hello"} - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-789" - assert result["response"] == {"result": "success"} - mock_dispatcher_manager.invoke_global_service.assert_called_once() + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws1") + ) + + 
sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', + {"text-completion": True}, + ): + await dispatcher.handle_message( + { + "id": "test-3", + "token": "tg_key", + "service": "text-completion", + "request": {"prompt": "hello"}, + }, + sender, + ) + + mock_dm.invoke_global_service.assert_called_once() + args, kwargs = mock_dm.invoke_global_service.call_args + assert args[0] == {"prompt": "hello"} + assert args[2] == "text-completion" + assert kwargs["workspace"] == "ws1" @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_flow_service(self): - """Test MessageDispatcher handle_message with flow service""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_flow_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = True - mock_responder.response = {"data": "flow_result"} - + async def test_handle_message_flow_service(self): + mock_dm = MagicMock() + mock_dm.invoke_flow_service = AsyncMock() + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-flow-123", - "service": "document-rag", - "request": {"query": "test"}, - "flow": "custom-flow" - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-flow-123" - assert result["response"] == {"data": "flow_result"} - mock_dispatcher_manager.invoke_flow_service.assert_called_once_with( - {"query": "test"}, mock_responder, "custom-flow", "document-rag" + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws2") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', {}, + ): + await 
dispatcher.handle_message( + { + "id": "test-4", + "token": "tg_key", + "service": "document-rag", + "request": {"query": "test"}, + "flow": "my-flow", + }, + sender, + ) + + mock_dm.invoke_flow_service.assert_called_once_with( + {"query": "test"}, ANY, "ws2", "my-flow", "document-rag", ) @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_incomplete_response(self): - """Test MessageDispatcher handle_message with incomplete response""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_flow_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = False - mock_responder.response = None - + async def test_handle_message_responder_sends_frames(self): + mock_dm = MagicMock() + + async def fake_invoke(data, responder, svc, workspace=None): + await responder({"partial": 1}, False) + await responder({"partial": 2}, True) + + mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke) + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-incomplete", - "service": "agent", - "request": {"input": "test"} + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws1") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', + {"text-completion": True}, + ): + await dispatcher.handle_message( + { + "id": "test-5", + "token": "tg_key", + "service": "text-completion", + "request": {"prompt": "hi"}, + }, + sender, + ) + + assert sender.call_count == 2 + first = sender.call_args_list[0][0][0] + second = sender.call_args_list[1][0][0] + + assert first == { + "id": "test-5", "response": {"partial": 1}, "complete": False, + } + assert second == { + "id": "test-5", "response": {"partial": 2}, "complete": True, } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}): - with 
patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-incomplete" - assert result["response"] == {"error": "No response received"} @pytest.mark.asyncio - async def test_message_dispatcher_shutdown(self): - """Test MessageDispatcher shutdown method""" - import asyncio - + async def test_handle_message_workspace_from_identity(self): + mock_dm = MagicMock() + mock_dm.invoke_flow_service = AsyncMock() + dispatcher = MessageDispatcher() - - # Create actual async tasks + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="derived-ws") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', {}, + ): + await dispatcher.handle_message( + { + "id": "test-6", + "token": "tg_key", + "service": "agent", + "request": {"question": "test"}, + "flow": "default", + }, + sender, + ) + + args = mock_dm.invoke_flow_service.call_args[0] + assert args[2] == "derived-ws" + + @pytest.mark.asyncio + async def test_shutdown(self): + dispatcher = MessageDispatcher() + async def dummy_task(): await asyncio.sleep(0.01) - return "done" - + task1 = asyncio.create_task(dummy_task()) task2 = asyncio.create_task(dummy_task()) dispatcher.active_tasks = {task1, task2} - - # Call shutdown + await dispatcher.shutdown() - - # Verify tasks were completed + assert task1.done() assert task2.done() - assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed @pytest.mark.asyncio - async def test_message_dispatcher_shutdown_with_no_tasks(self): - """Test MessageDispatcher shutdown with no active tasks""" + async def test_shutdown_with_no_tasks(self): dispatcher = MessageDispatcher() - - # Call shutdown with no active tasks + await dispatcher.shutdown() - - # Should complete without error - assert 
dispatcher.active_tasks == set() \ No newline at end of file + + assert dispatcher.active_tasks == set() diff --git a/tests/unit/test_rev_gateway/test_rev_gateway_service.py b/tests/unit/test_rev_gateway/test_rev_gateway_service.py index 23aff18e..e2d1045d 100644 --- a/tests/unit/test_rev_gateway/test_rev_gateway_service.py +++ b/tests/unit/test_rev_gateway/test_rev_gateway_service.py @@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock from aiohttp import WSMsgType, ClientWebSocketResponse import json -from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run +from trustgraph.rev_gateway.service import ReverseGateway, run + + +MOCK_PATCHES = [ + 'trustgraph.rev_gateway.service.IamAuth', + 'trustgraph.rev_gateway.service.ConfigReceiver', + 'trustgraph.rev_gateway.service.MessageDispatcher', + 'trustgraph.rev_gateway.service.get_pubsub', +] + + +def make_gateway(**overrides): + config = {"websocket_uri": "ws://localhost:7650/out"} + config.update(overrides) + return ReverseGateway(**config) class TestReverseGateway: """Test cases for ReverseGateway class""" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with default parameters""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_defaults( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + assert gateway.websocket_uri == "ws://localhost:7650/out" assert gateway.host == "localhost" 
assert gateway.port == 7650 @@ -33,25 +49,22 @@ class TestReverseGateway: assert gateway.max_workers == 10 assert gateway.running is False assert gateway.reconnect_delay == 3.0 - assert gateway.pulsar_host == "pulsar://pulsar:6650" - assert gateway.pulsar_api_key is None - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with custom parameters""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway( + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_custom_params( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway( websocket_uri="wss://example.com:8080/websocket", max_workers=20, - pulsar_host="pulsar://custom:6650", - pulsar_api_key="test-key", - pulsar_listener="test-listener" ) - + assert gateway.websocket_uri == "wss://example.com:8080/websocket" assert gateway.host == "example.com" assert gateway.port == 8080 @@ -59,340 +72,360 @@ class TestReverseGateway: assert gateway.path == "/websocket" assert gateway.url == "wss://example.com:8080/websocket" assert gateway.max_workers == 20 - assert gateway.pulsar_host == "pulsar://custom:6650" - assert gateway.pulsar_api_key == "test-key" - assert gateway.pulsar_listener == "test-listener" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - 
"""Test ReverseGateway initialization with WebSocket URI missing path""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway(websocket_uri="ws://example.com") - + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_with_missing_path( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway(websocket_uri="ws://example.com") + assert gateway.path == "/ws" assert gateway.url == "ws://example.com/ws" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with invalid WebSocket scheme""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_invalid_scheme( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"): - ReverseGateway(websocket_uri="http://example.com") + make_gateway(websocket_uri="http://example.com") - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with missing hostname""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def 
test_reverse_gateway_initialization_missing_hostname( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): with pytest.raises(ValueError, match="WebSocket URI must include hostname"): - ReverseGateway(websocket_uri="ws://") + make_gateway(websocket_uri="ws://") - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway creates backend with authentication""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_iam_auth_created( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): mock_backend = MagicMock() mock_get_pubsub.return_value = mock_backend - gateway = ReverseGateway( - pulsar_api_key="test-key", - pulsar_listener="test-listener" + gateway = make_gateway(id="test-rev-gw") + + mock_iam_auth.assert_called_once_with( + backend=mock_backend, + id="test-rev-gw", ) - # Verify get_pubsub was called with the correct parameters - mock_get_pubsub.assert_called_once_with( - pulsar_host="pulsar://pulsar:6650", - pulsar_api_key="test-key", - pulsar_listener="test-listener" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_config_receiver_gets_auth( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend + mock_auth_instance = MagicMock() + mock_iam_auth.return_value = mock_auth_instance + + gateway = make_gateway() + + mock_config_receiver.assert_called_once_with( + mock_backend, auth=mock_auth_instance, ) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - 
@patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @patch('trustgraph.rev_gateway.service.ClientSession') @pytest.mark.asyncio - async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway successful connection""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_connect_success( + self, mock_session_class, mock_get_pubsub, + mock_dispatcher, mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + mock_session = AsyncMock() mock_ws = AsyncMock() mock_session.ws_connect.return_value = mock_ws mock_session_class.return_value = mock_session - - gateway = ReverseGateway() - + + gateway = make_gateway() + result = await gateway.connect() - + assert result is True assert gateway.session == mock_session assert gateway.ws == mock_ws mock_session.ws_connect.assert_called_once_with(gateway.url) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @patch('trustgraph.rev_gateway.service.ClientSession') @pytest.mark.asyncio - async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway connection failure""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_connect_failure( + self, mock_session_class, mock_get_pubsub, + mock_dispatcher, mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + mock_session 
= AsyncMock() mock_session.ws_connect.side_effect = Exception("Connection failed") mock_session_class.return_value = mock_session - - gateway = ReverseGateway() - + + gateway = make_gateway() + result = await gateway.connect() - + assert result is False - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway disconnect""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock websocket and session + async def test_reverse_gateway_disconnect( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = False mock_session = AsyncMock() mock_session.closed = False - + gateway.ws = mock_ws gateway.session = mock_session - + await gateway.disconnect() - + mock_ws.close.assert_called_once() mock_session.close.assert_called_once() assert gateway.ws is None assert gateway.session is None - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway send message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock websocket + async def 
test_reverse_gateway_send_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - + test_message = {"id": "test", "data": "hello"} - + await gateway.send_message(test_message) - + mock_ws.send_str.assert_called_once_with(json.dumps(test_message)) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway send message with closed connection""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock closed websocket + async def test_reverse_gateway_send_message_closed_connection( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = True gateway.ws = mock_ws - + test_message = {"id": "test", "data": "hello"} - + await gateway.send_message(test_message) - - # Should not call send_str on closed connection + mock_ws.send_str.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway handle 
message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - mock_dispatcher_instance = AsyncMock() - mock_dispatcher_instance.handle_message.return_value = {"response": "success"} - mock_dispatcher.return_value = mock_dispatcher_instance - - gateway = ReverseGateway() - - # Mock send_message - gateway.send_message = AsyncMock() - - test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}' - - await gateway.handle_message(test_message) - - mock_dispatcher_instance.handle_message.assert_called_once_with({ - "id": "test", - "service": "test-service", - "request": {"data": "test"} - }) - gateway.send_message.assert_called_once_with({"response": "success"}) + async def test_reverse_gateway_handle_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + mock_dispatcher_instance = AsyncMock() + mock_dispatcher.return_value = mock_dispatcher_instance + + gateway = make_gateway() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - @pytest.mark.asyncio - async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway handle message with invalid JSON""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock send_message gateway.send_message = AsyncMock() - - test_message = 'invalid json' - - # Should not raise exception + + test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}' + await gateway.handle_message(test_message) - - # Should not call send_message due to error + + mock_dispatcher_instance.handle_message.assert_called_once_with( + { + "id": "test", + "service": "test-service", + "request": {"data": "test"}, + }, + 
gateway.send_message, + ) + + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + @pytest.mark.asyncio + async def test_reverse_gateway_handle_message_invalid_json( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + + gateway.send_message = AsyncMock() + + await gateway.handle_message('invalid json') + gateway.send_message.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with text message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_text_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.TEXT mock_msg.data = '{"test": "message"}' - - # Mock receive to return message once, then raise exception to stop loop + mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] - - # listen() catches exceptions and breaks, so no exception should be raised + await gateway.listen() - + gateway.handle_message.assert_called_once_with('{"test": "message"}') - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - 
@patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with binary message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_binary_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.BINARY mock_msg.data = b'{"test": "binary"}' - - # Mock receive to return message once, then raise exception to stop loop + mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] - - # listen() catches exceptions and breaks, so no exception should be raised + await gateway.listen() - + gateway.handle_message.assert_called_once_with('{"test": "binary"}') - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with close message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_close_message( 
+ self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.CLOSE - - # Mock receive to return close message + mock_ws.receive.return_value = mock_msg - + await gateway.listen() - - # Should not call handle_message for close message + gateway.handle_message.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway shutdown""" + async def test_reverse_gateway_shutdown( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): mock_backend = MagicMock() mock_get_pubsub.return_value = mock_backend mock_dispatcher_instance = AsyncMock() mock_dispatcher.return_value = mock_dispatcher_instance - gateway = ReverseGateway() + gateway = make_gateway() gateway.running = True - # Mock disconnect gateway.disconnect = AsyncMock() await gateway.shutdown() @@ -402,46 +435,50 @@ class TestReverseGateway: gateway.disconnect.assert_called_once() mock_backend.close.assert_called_once() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway stop""" - mock_backend = MagicMock() - 
mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_stop( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - + gateway.stop() - + assert gateway.running is False class TestReverseGatewayRun: """Test cases for ReverseGateway run method""" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway run method with successful connect/listen cycle""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_run_successful_cycle( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + mock_auth_instance = AsyncMock() + mock_iam_auth.return_value = mock_auth_instance + mock_config_receiver_instance = AsyncMock() mock_config_receiver.return_value = mock_config_receiver_instance - - gateway = ReverseGateway() - - # Mock methods - gateway.connect = AsyncMock(return_value=True) + + gateway = make_gateway() + gateway.listen = AsyncMock() gateway.disconnect = AsyncMock() gateway.shutdown = AsyncMock() - - # Stop after one iteration + call_count = 0 async def mock_connect(): nonlocal call_count @@ -451,91 +488,13 @@ class TestReverseGatewayRun: else: gateway.running = False return False - + gateway.connect = mock_connect - + await gateway.run() - + + 
mock_auth_instance.start.assert_called_once() mock_config_receiver_instance.start.assert_called_once() gateway.listen.assert_called_once() - # disconnect is called twice: once in the main loop, once in shutdown assert gateway.disconnect.call_count == 2 gateway.shutdown.assert_called_once() - - -class TestReverseGatewayArgs: - """Test cases for argument parsing and run function""" - - def test_parse_args_defaults(self): - """Test parse_args with default values""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = ['reverse-gateway'] - - try: - args = parse_args() - - assert args.websocket_uri is None - assert args.max_workers == 10 - assert args.pulsar_host is None - assert args.pulsar_api_key is None - assert args.pulsar_listener is None - finally: - sys.argv = original_argv - - def test_parse_args_custom_values(self): - """Test parse_args with custom values""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = [ - 'reverse-gateway', - '--websocket-uri', 'ws://custom:8080/ws', - '--max-workers', '20', - '--pulsar-host', 'pulsar://custom:6650', - '--pulsar-api-key', 'test-key', - '--pulsar-listener', 'test-listener' - ] - - try: - args = parse_args() - - assert args.websocket_uri == 'ws://custom:8080/ws' - assert args.max_workers == 20 - assert args.pulsar_host == 'pulsar://custom:6650' - assert args.pulsar_api_key == 'test-key' - assert args.pulsar_listener == 'test-listener' - finally: - sys.argv = original_argv - - @patch('trustgraph.rev_gateway.service.ReverseGateway') - @patch('asyncio.run') - def test_run_function(self, mock_asyncio_run, mock_gateway_class): - """Test run function""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = ['reverse-gateway', '--max-workers', '15'] - - try: - mock_gateway_instance = MagicMock() - mock_gateway_instance.url = "ws://localhost:7650/out" - mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650" - mock_gateway_class.return_value = mock_gateway_instance - 
- run() - - mock_gateway_class.assert_called_once_with( - websocket_uri=None, - max_workers=15, - pulsar_host=None, - pulsar_api_key=None, - pulsar_listener=None - ) - mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run()) - finally: - sys.argv = original_argv \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py index 986558ec..5e71a0e5 100644 --- a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py +++ b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py @@ -1,130 +1,140 @@ import asyncio import logging import uuid -from typing import Dict, Any, Optional -from trustgraph.messaging import TranslatorRegistry +from typing import Dict, Any, Optional, Callable, Awaitable from ..gateway.dispatch.manager import DispatcherManager logger = logging.getLogger("dispatcher") -logger.setLevel(logging.INFO) -class WebSocketResponder: - """Simple responder that captures response for websocket return""" - def __init__(self): - self.response = None - self.completed = False - - async def send(self, data): - """Capture the response data""" - self.response = data - self.completed = True - - async def __call__(self, data, final=False): - """Make the responder callable for compatibility with requestor""" - await self.send(data) - if final: - self.completed = True + +class _TokenShim: + def __init__(self, token): + self.headers = ( + {"Authorization": f"Bearer {token}"} if token else {} + ) + class MessageDispatcher: - def __init__(self, max_workers: int = 10, config_receiver=None, backend=None): + def __init__(self, max_workers=10, config_receiver=None, backend=None, + auth=None, timeout=120): self.max_workers = max_workers self.semaphore = asyncio.Semaphore(max_workers) self.active_tasks = set() self.backend = backend + self.auth = auth - # Use DispatcherManager for flow and service management - if backend and config_receiver: - self.dispatcher_manager = 
DispatcherManager(backend, config_receiver, prefix="rev-gateway") + if backend and config_receiver and auth: + self.dispatcher_manager = DispatcherManager( + backend, config_receiver, + auth=auth, + prefix="rev-gateway", + timeout=timeout, + ) else: self.dispatcher_manager = None - logger.warning("No backend or config_receiver provided - using fallback mode") - - # Service name mapping from websocket protocol to translator registry + logger.warning( + "Missing backend, config_receiver, or auth " + "— using fallback mode" + ) + self.service_mapping = { "text-completion": "text-completion", - "graph-rag": "graph-rag", + "graph-rag": "graph-rag", "agent": "agent", "embeddings": "embeddings", "graph-embeddings": "graph-embeddings", "triples": "triples", "document-load": "document", - "text-load": "text-document", + "text-load": "text-document", "flow": "flow", "knowledge": "knowledge", "config": "config", "librarian": "librarian", - "document-rag": "document-rag" + "document-rag": "document-rag", } - - async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]: + + async def handle_message( + self, message: Dict[Any, Any], + sender: Callable[[dict], Awaitable[None]], + ): async with self.semaphore: - task = asyncio.create_task(self._process_message(message)) + task = asyncio.create_task( + self._process_message(message, sender) + ) self.active_tasks.add(task) - + try: - result = await task - return result + await task finally: self.active_tasks.discard(task) - - async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]: + + async def _authenticate(self, token): + if not self.auth: + raise RuntimeError("Auth not configured") + return await self.auth.authenticate(_TokenShim(token)) + + async def _process_message( + self, message: Dict[Any, Any], + sender: Callable[[dict], Awaitable[None]], + ): request_id = message.get('id', str(uuid.uuid4())) service = message.get('service') request_data = message.get('request', {}) - flow_id = 
message.get('flow', 'default') # Default flow - - logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}") - + token = message.get('token', '') + flow_id = message.get('flow', 'default') + + logger.info( + f"Processing message {request_id} for service " + f"{service} on flow {flow_id}" + ) + try: if not self.dispatcher_manager: - raise RuntimeError("DispatcherManager not available - backend and config_receiver required") - - # Use DispatcherManager for flow-based processing - responder = WebSocketResponder() - - # Map websocket service name to dispatcher service name + raise RuntimeError( + "DispatcherManager not available" + ) + + identity = await self._authenticate(token) + workspace = identity.workspace + + async def responder(resp, fin): + await sender({ + "id": request_id, + "response": resp, + "complete": fin, + }) + dispatcher_service = self.service_mapping.get(service, service) - - # Check if this is a global service or flow service + from ..gateway.dispatch.manager import global_dispatchers if dispatcher_service in global_dispatchers: - # Use global service dispatcher await self.dispatcher_manager.invoke_global_service( - request_data, responder, dispatcher_service + request_data, responder, dispatcher_service, + workspace=workspace, ) else: - # Use DispatcherManager to process the request through Pulsar queues await self.dispatcher_manager.invoke_flow_service( - request_data, responder, flow_id, dispatcher_service + request_data, responder, workspace, flow_id, + dispatcher_service, ) - - # Get the response from the responder - if responder.completed: - response_data = responder.response - else: - response_data = {'error': 'No response received'} - - response = { - 'id': request_id, - 'response': response_data - } - + except Exception as e: logger.error(f"Error processing message {request_id}: {e}") - response = { - 'id': request_id, - 'response': {'error': str(e)} - } - + await sender({ + "id": request_id, + "error": 
{"message": str(e), "type": "error"}, + "complete": True, + }) + logger.info(f"Completed processing message {request_id}") - return response - - + async def shutdown(self): if self.active_tasks: - logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete") + logger.info( + f"Waiting for {len(self.active_tasks)} active " + f"tasks to complete" + ) await asyncio.gather(*self.active_tasks, return_exceptions=True) - - # DispatcherManager handles its own cleanup + logger.info("Dispatcher shutdown complete") diff --git a/trustgraph-flow/trustgraph/rev_gateway/service.py b/trustgraph-flow/trustgraph/rev_gateway/service.py index cc905172..39180e2c 100644 --- a/trustgraph-flow/trustgraph/rev_gateway/service.py +++ b/trustgraph-flow/trustgraph/rev_gateway/service.py @@ -1,3 +1,9 @@ +""" +Reverse gateway. Initiates outbound WebSocket connections to a remote +relay and dispatches incoming requests through the same DispatcherManager +pipeline as api-gateway. +""" + import asyncio import argparse import logging @@ -9,82 +15,86 @@ from typing import Optional from urllib.parse import urlparse, urlunparse from .dispatcher import MessageDispatcher +from ..gateway.auth import IamAuth from ..gateway.config.receiver import ConfigReceiver -from ..base import get_pubsub +from ..base.pubsub import get_pubsub, add_pubsub_args +from ..base.logging import setup_logging, add_logging_args logger = logging.getLogger("rev_gateway") -logger.setLevel(logging.INFO) default_websocket = "ws://localhost:7650/out" +default_timeout = 600 class ReverseGateway: - - def __init__(self, websocket_uri: str = None, max_workers: int = 10, - pulsar_host: str = None, pulsar_api_key: str = None, - pulsar_listener: str = None): - # Set default WebSocket URI with environment variable support + + def __init__(self, **config): + websocket_uri = config.get("websocket_uri") if websocket_uri is None: websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket) - - # Parse and validate the WebSocket URI 
+ parsed_uri = urlparse(websocket_uri) if parsed_uri.scheme not in ('ws', 'wss'): - raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}") + raise ValueError( + f"WebSocket URI must use ws:// or wss:// scheme, " + f"got: {parsed_uri.scheme}" + ) if not parsed_uri.netloc: - raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}") - - # Store parsed components for debugging/logging + raise ValueError( + f"WebSocket URI must include hostname, " + f"got: {websocket_uri}" + ) + self.websocket_uri = websocket_uri self.host = parsed_uri.hostname self.port = parsed_uri.port self.scheme = parsed_uri.scheme self.path = parsed_uri.path or "/ws" - - # Construct the full URL (in case path was missing) + if not parsed_uri.path: self.url = f"{self.scheme}://{parsed_uri.netloc}/ws" else: self.url = websocket_uri - - self.max_workers = max_workers + + self.max_workers = int(config.get("max_workers", 10)) + self.timeout = int(config.get("timeout", default_timeout)) self.ws: Optional[ClientWebSocketResponse] = None self.session: Optional[ClientSession] = None self.running = False self.reconnect_delay = 3.0 - - # Pulsar configuration - self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650") - self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None) - self.pulsar_listener = pulsar_listener - # Create backend using factory - backend_params = { - 'pulsar_host': self.pulsar_host, - 'pulsar_api_key': self.pulsar_api_key, - 'pulsar_listener': self.pulsar_listener, - } - self.backend = get_pubsub(**backend_params) + self.backend = get_pubsub(**config) - # Initialize config receiver - self.config_receiver = ConfigReceiver(self.backend) + self.auth = IamAuth( + backend=self.backend, + id=config.get("id", "rev-gateway"), + ) + + self.config_receiver = ConfigReceiver( + self.backend, auth=self.auth, + ) + + self.dispatcher = MessageDispatcher( + self.max_workers, self.config_receiver, 
self.backend, + auth=self.auth, timeout=self.timeout, + ) - # Initialize dispatcher with config_receiver and backend - must be created after config_receiver - self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend) - async def connect(self) -> bool: try: if self.session is None: self.session = ClientSession() - + logger.info(f"Connecting to {self.url}") self.ws = await self.session.ws_connect(self.url) - logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}") + logger.info( + f"WebSocket connection established to " + f"{self.host}:{self.port or 'default'}" + ) return True - + except Exception as e: logger.error(f"Failed to connect to {self.url}: {e}") return False - + async def disconnect(self): if self.ws and not self.ws.closed: await self.ws.close() @@ -92,32 +102,31 @@ class ReverseGateway: await self.session.close() self.ws = None self.session = None - + async def send_message(self, message: dict): if self.ws and not self.ws.closed: try: await self.ws.send_str(json.dumps(message)) except Exception as e: logger.error(f"Failed to send message: {e}") - + async def handle_message(self, message: str): try: logger.debug(f"Received message: {message}") - + msg_data = json.loads(message) - response = await self.dispatcher.handle_message(msg_data) - - if response: - await self.send_message(response) - + await self.dispatcher.handle_message( + msg_data, self.send_message, + ) + except Exception as e: logger.error(f"Error handling message: {e}") - + async def listen(self): while self.running and self.ws and not self.ws.closed: try: msg = await self.ws.receive() - + if msg.type == WSMsgType.TEXT: await self.handle_message(msg.data) elif msg.type == WSMsgType.BINARY: @@ -125,31 +134,33 @@ class ReverseGateway: elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR): logger.warning("WebSocket closed or error occurred") break - + except Exception as e: logger.error(f"Error in listen loop: {e}") break - + async def 
run(self): self.running = True logger.info("Starting reverse gateway") - - # Start config receiver - logger.info("Starting config receiver") + + await self.auth.start() await self.config_receiver.start() - + while self.running: try: if await self.connect(): await self.listen() else: - logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds") - + logger.warning( + f"Connection failed, retrying in " + f"{self.reconnect_delay} seconds" + ) + await self.disconnect() - + if self.running: await asyncio.sleep(self.reconnect_delay) - + except KeyboardInterrupt: logger.info("Shutdown requested") break @@ -157,77 +168,69 @@ class ReverseGateway: logger.error(f"Unexpected error: {e}") if self.running: await asyncio.sleep(self.reconnect_delay) - + await self.shutdown() - + async def shutdown(self): logger.info("Shutting down reverse gateway") self.running = False await self.dispatcher.shutdown() await self.disconnect() - # Close backend if hasattr(self, 'backend'): self.backend.close() - + def stop(self): self.running = False -def parse_args(): + +def run(): + parser = argparse.ArgumentParser( prog="reverse-gateway", - description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge" + description=__doc__, ) - + + parser.add_argument( + '--id', + default='rev-gateway', + help='Service identifier (default: rev-gateway)', + ) + parser.add_argument( '--websocket-uri', default=None, - help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)' + help=f'WebSocket URI to connect to (default: {default_websocket})', ) - + parser.add_argument( '--max-workers', type=int, default=10, - help='Maximum concurrent message handlers (default: 10)' + help='Maximum concurrent message handlers (default: 10)', ) - - parser.add_argument( - '-p', '--pulsar-host', - default=None, - help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)' - ) - - parser.add_argument( - '--pulsar-api-key', - default=None, - help='Pulsar API 
key for authentication (default: PULSAR_API_KEY env var)' - ) - - parser.add_argument( - '--pulsar-listener', - default=None, - help='Pulsar listener name' - ) - - return parser.parse_args() -def run(): - args = parse_args() - - gateway = ReverseGateway( - websocket_uri=args.websocket_uri, - max_workers=args.max_workers, - pulsar_host=args.pulsar_host, - pulsar_api_key=args.pulsar_api_key, - pulsar_listener=args.pulsar_listener + parser.add_argument( + '--timeout', + type=int, + default=default_timeout, + help=f'Request timeout in seconds (default: {default_timeout})', ) - + + add_pubsub_args(parser) + add_logging_args(parser) + + args = parser.parse_args() + args = vars(args) + + setup_logging(args) + + gateway = ReverseGateway(**args) + logger.info(f"Starting reverse gateway:") logger.info(f" WebSocket URI: {gateway.url}") - logger.info(f" Max workers: {args.max_workers}") - logger.info(f" Pulsar host: {gateway.pulsar_host}") - + logger.info(f" Max workers: {gateway.max_workers}") + try: asyncio.run(gateway.run()) except KeyboardInterrupt: From 2c3a699af39111411ab02389781d8fffe2871c0c Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 21 May 2026 10:50:11 +0100 Subject: [PATCH 2/7] feat: extend SPARQL evaluator with comprehensive function and operator support (#945) Add 30+ SPARQL 1.1 built-in functions and the MINUS algebra operator to the custom SPARQL query backend. 
String functions: - SUBSTR (2-arg and 3-arg forms), STRBEFORE, STRAFTER - REPLACE (regex with flags), ENCODE_FOR_URI Numeric functions: - FLOOR, CEIL, ROUND, ABS Date/time accessors: - YEAR, MONTH, DAY, HOURS, MINUTES, SECONDS - NOW, TZ Hash functions: - MD5, SHA1, SHA256, SHA512 Term constructors: - IRI/URI, BNODE, UUID, STRUUID Other functions: - LANGMATCHES, RAND - EXISTS / NOT EXISTS (with async pre-evaluation to bridge the sync expression evaluator and async algebra evaluator) Algebra: - MINUS set-difference operator - HAVING already works via rdflib's Filter mapping (verified) Fix SPARQL ORDER handling Includes 653 lines of new unit tests covering all added functionality across expressions, solutions, and algebra layers. --- tests/unit/test_query/test_sparql_algebra.py | 185 ++++++++ .../test_query/test_sparql_expressions.py | 432 ++++++++++++++++++ .../unit/test_query/test_sparql_solutions.py | 60 ++- .../trustgraph/query/sparql/algebra.py | 74 ++- .../trustgraph/query/sparql/expressions.py | 238 +++++++++- .../trustgraph/query/sparql/solutions.py | 61 ++- 6 files changed, 1021 insertions(+), 29 deletions(-) diff --git a/tests/unit/test_query/test_sparql_algebra.py b/tests/unit/test_query/test_sparql_algebra.py index 9827b2de..980ce870 100644 --- a/tests/unit/test_query/test_sparql_algebra.py +++ b/tests/unit/test_query/test_sparql_algebra.py @@ -84,6 +84,20 @@ def make_distinct(inner): return node +def make_filter(inner, expr): + node = CompValue("Filter") + node.p = inner + node.expr = expr + return node + + +def make_minus(left, right): + node = CompValue("Minus") + node.p1 = left + node.p2 = right + return node + + class TestQueryPattern: """Tests for _query_pattern — the leaf that calls TriplesClient.""" @@ -282,6 +296,177 @@ class TestEvaluate: assert len(solutions) == 1 + @pytest.mark.asyncio + async def test_minus_removes_matching(self): + tc = AsyncMock() + + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = 
iri("http://example.com/knows") + hates = iri("http://example.com/hates") + charlie = iri("http://example.com/charlie") + + left_triple = make_triple(alice, knows, bob) + right_triple1 = make_triple(alice, knows, bob) + right_triple2 = make_triple(alice, hates, charlie) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + right_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/hates"), Variable("r")) + ) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/knows": + return [left_triple] + elif pred and pred.iri == "http://example.com/hates": + return [right_triple2] + return [] + + tc.query.side_effect = mock_query + + tree = make_select( + make_project( + make_minus(left_bgp, right_bgp), + ["s", "o"] + ) + ) + + solutions = await evaluate(tree, tc, collection="default") + + # alice knows bob, but alice also hates charlie + # shared var is "s" (alice), so alice's solution is removed + assert len(solutions) == 0 + + @pytest.mark.asyncio + async def test_minus_no_shared_vars_preserves_all(self): + tc = AsyncMock() + + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + + left_triple = make_triple(alice, iri("http://example.com/p"), bob) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/p"), Variable("o")) + ) + right_bgp = make_bgp( + (Variable("x"), URIRef("http://example.com/q"), Variable("y")) + ) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/p": + return [left_triple] + return [] + + tc.query.side_effect = mock_query + + tree = make_select( + make_project( + make_minus(left_bgp, right_bgp), + ["s", "o"] + ) + ) + + solutions = await evaluate(tree, tc, collection="default") + + assert len(solutions) == 1 + + @pytest.mark.asyncio + async def test_filter_exists_keeps_matching(self): + tc = AsyncMock() + + alice = iri("http://example.com/alice") + 
bob = iri("http://example.com/bob") + charlie = iri("http://example.com/charlie") + + left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob) + left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) + exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/knows": + return [left_triple1, left_triple2] + elif pred and pred.iri == "http://example.com/likes": + return [exists_triple] + return [] + + tc.query.side_effect = mock_query + + exists_expr = CompValue("Builtin_EXISTS") + exists_expr.graph = exists_bgp + + tree = make_select( + make_project( + make_filter(left_bgp, exists_expr), + ["s", "o"] + ) + ) + + solutions = await evaluate(tree, tc, collection="default") + + # Only bob has a "likes" triple, so only the bob solution passes + result_objects = [s["o"].iri for s in solutions] + assert "http://example.com/bob" in result_objects + assert "http://example.com/charlie" not in result_objects + + @pytest.mark.asyncio + async def test_filter_not_exists_removes_matching(self): + tc = AsyncMock() + + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + charlie = iri("http://example.com/charlie") + + left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob) + left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) + exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred 
and pred.iri == "http://example.com/knows": + return [left_triple1, left_triple2] + elif pred and pred.iri == "http://example.com/likes": + return [exists_triple] + return [] + + tc.query.side_effect = mock_query + + not_exists_expr = CompValue("Builtin_NOTEXISTS") + not_exists_expr.graph = exists_bgp + + tree = make_select( + make_project( + make_filter(left_bgp, not_exists_expr), + ["s", "o"] + ) + ) + + solutions = await evaluate(tree, tc, collection="default") + + # bob has a "likes" triple so is removed; charlie stays + result_objects = [s["o"].iri for s in solutions] + assert "http://example.com/charlie" in result_objects + assert "http://example.com/bob" not in result_objects + @pytest.mark.asyncio async def test_unsupported_node_returns_empty_solution(self): tc = AsyncMock() diff --git a/tests/unit/test_query/test_sparql_expressions.py b/tests/unit/test_query/test_sparql_expressions.py index 63e9188f..87c862e8 100644 --- a/tests/unit/test_query/test_sparql_expressions.py +++ b/tests/unit/test_query/test_sparql_expressions.py @@ -300,6 +300,438 @@ class TestBuiltinFunctions: flags=None) assert evaluate_expression(expr, {"x": lit("hello")}) is False + def test_substr_three_args(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Literal(1), + length=Literal(4)) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.type == LITERAL + assert result.value == "2024" + + def test_substr_two_args(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Literal(6), + length=None) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.type == LITERAL + assert result.value == "03-15" + + def test_substr_middle(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Literal(6), 
+ length=Literal(2)) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.type == LITERAL + assert result.value == "03" + + def test_substr_null_start(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Variable("missing"), + length=None) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result is None + + def test_year(self): + from rdflib.term import Variable + expr = self._make_builtin("YEAR", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 2024 + + def test_month(self): + from rdflib.term import Variable + expr = self._make_builtin("MONTH", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 3 + + def test_day(self): + from rdflib.term import Variable + expr = self._make_builtin("DAY", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 15 + + def test_hours(self): + from rdflib.term import Variable + expr = self._make_builtin("HOURS", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 10 + + def test_minutes(self): + from rdflib.term import Variable + expr = self._make_builtin("MINUTES", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 30 + + def test_seconds(self): + from rdflib.term import Variable + expr = self._make_builtin("SECONDS", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 45 + + def test_year_from_datetime(self): + from rdflib.term import Variable + expr = self._make_builtin("YEAR", arg=Variable("x")) + result 
= evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 2024 + + def test_hours_from_date_returns_zero(self): + from rdflib.term import Variable + expr = self._make_builtin("HOURS", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 0 + + def test_year_invalid_date(self): + from rdflib.term import Variable + expr = self._make_builtin("YEAR", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("not-a-date")} + ) + assert result is None + + def test_floor(self): + from rdflib.term import Variable + expr = self._make_builtin("FLOOR", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.7")}) == 3 + + def test_floor_negative(self): + from rdflib.term import Variable + expr = self._make_builtin("FLOOR", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("-2.3")}) == -3 + + def test_floor_none(self): + from rdflib.term import Variable + expr = self._make_builtin("FLOOR", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("abc")}) is None + + def test_ceil(self): + from rdflib.term import Variable + expr = self._make_builtin("CEIL", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.2")}) == 4 + + def test_ceil_negative(self): + from rdflib.term import Variable + expr = self._make_builtin("CEIL", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("-2.7")}) == -2 + + def test_abs_positive(self): + from rdflib.term import Variable + expr = self._make_builtin("ABS", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("42")}) == 42 + + def test_abs_negative(self): + from rdflib.term import Variable + expr = self._make_builtin("ABS", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("-42")}) == 42 + + def test_abs_none(self): + from rdflib.term import Variable + expr = self._make_builtin("ABS", arg=Variable("x")) + assert 
evaluate_expression(expr, {"x": lit("abc")}) is None + + def test_replace_simple(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REPLACE", + arg=Variable("x"), + pattern=Literal(" BC"), + replacement=Literal(""), + flags=None) + result = evaluate_expression(expr, {"x": lit("500 BC")}) + assert result.type == LITERAL + assert result.value == "500" + + def test_replace_regex(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REPLACE", + arg=Variable("x"), + pattern=Literal("[0-9]+"), + replacement=Literal("X"), + flags=None) + result = evaluate_expression(expr, {"x": lit("abc123def456")}) + assert result.value == "abcXdefX" + + def test_replace_case_insensitive(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REPLACE", + arg=Variable("x"), + pattern=Literal("hello"), + replacement=Literal("world"), + flags=Literal("i")) + result = evaluate_expression(expr, {"x": lit("HELLO there")}) + assert result.value == "world there" + + def test_round_up(self): + from rdflib.term import Variable + expr = self._make_builtin("ROUND", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.7")}) == 4 + + def test_round_down(self): + from rdflib.term import Variable + expr = self._make_builtin("ROUND", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.2")}) == 3 + + def test_round_none(self): + from rdflib.term import Variable + expr = self._make_builtin("ROUND", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("abc")}) is None + + def test_strbefore(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRBEFORE", + arg1=Variable("x"), arg2=Literal("-")) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.value == "2024" + + def test_strbefore_not_found(self): + from rdflib.term import Variable + from rdflib import 
Literal + expr = self._make_builtin("STRBEFORE", + arg1=Variable("x"), arg2=Literal("/")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.value == "" + + def test_strafter(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRAFTER", + arg1=Variable("x"), arg2=Literal("-")) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.value == "03-15" + + def test_strafter_not_found(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRAFTER", + arg1=Variable("x"), arg2=Literal("/")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.value == "" + + def test_encode_for_uri(self): + from rdflib.term import Variable + expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello world")}) + assert result.value == "hello%20world" + + def test_encode_for_uri_special_chars(self): + from rdflib.term import Variable + expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("a/b?c=d&e")}) + assert result.value == "a%2Fb%3Fc%3Dd%26e" + + def test_langmatches_basic(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("en"), arg2=Literal("en")) + assert evaluate_expression(expr, {}) is True + + def test_langmatches_subtag(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("en-US"), arg2=Literal("en")) + assert evaluate_expression(expr, {}) is True + + def test_langmatches_wildcard(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("fr"), arg2=Literal("*")) + assert evaluate_expression(expr, {}) is True + + def test_langmatches_wildcard_empty(self): + from rdflib.term import 
Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal(""), arg2=Literal("*")) + assert evaluate_expression(expr, {}) is False + + def test_langmatches_no_match(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("fr"), arg2=Literal("en")) + assert evaluate_expression(expr, {}) is False + + def test_iri_constructor(self): + from rdflib.term import Variable + expr = self._make_builtin("IRI", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("http://example.com/test")} + ) + assert result.type == IRI + assert result.iri == "http://example.com/test" + + def test_uri_constructor(self): + from rdflib.term import Variable + expr = self._make_builtin("URI", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("http://example.com/test")} + ) + assert result.type == IRI + assert result.iri == "http://example.com/test" + + def test_bnode_no_arg(self): + expr = self._make_builtin("BNODE") + result = evaluate_expression(expr, {}) + assert result.type == BLANK + assert len(result.id) > 0 + + def test_bnode_with_label(self): + from rdflib import Literal + expr = self._make_builtin("BNODE", arg=Literal("mynode")) + result = evaluate_expression(expr, {}) + assert result.type == BLANK + assert result.id == "mynode" + + def test_now(self): + import re as re_mod + expr = self._make_builtin("NOW") + result = evaluate_expression(expr, {}) + assert result.type == LITERAL + assert result.datatype == XSD + "dateTime" + assert re_mod.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}", result.value) + + def test_tz_with_utc(self): + from rdflib.term import Variable + expr = self._make_builtin("TZ", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45+0000", + datatype=XSD + "dateTime")} + ) + assert result.type == LITERAL + assert result.value == "+00:00" + + def test_tz_no_timezone(self): + from rdflib.term 
import Variable + expr = self._make_builtin("TZ", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", + datatype=XSD + "dateTime")} + ) + assert result.value == "" + + def test_rand(self): + expr = self._make_builtin("RAND") + result = evaluate_expression(expr, {}) + assert isinstance(result, float) + assert 0.0 <= result < 1.0 + + def test_uuid(self): + import re as re_mod + expr = self._make_builtin("UUID") + result = evaluate_expression(expr, {}) + assert result.type == IRI + assert result.iri.startswith("urn:uuid:") + uuid_part = result.iri[len("urn:uuid:"):] + assert re_mod.match( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", + uuid_part + ) + + def test_struuid(self): + import re as re_mod + expr = self._make_builtin("STRUUID") + result = evaluate_expression(expr, {}) + assert result.type == LITERAL + assert re_mod.match( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", + result.value + ) + + def test_md5(self): + from rdflib.term import Variable + expr = self._make_builtin("MD5", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert result.value == "5d41402abc4b2a76b9719d911017c592" + + def test_sha1(self): + from rdflib.term import Variable + expr = self._make_builtin("SHA1", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert result.value == "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d" + + def test_sha256(self): + from rdflib.term import Variable + expr = self._make_builtin("SHA256", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert result.value == ( + "2cf24dba5fb0a30e26e83b2ac5b9e29e" + "1b161e5c1fa7425e73043362938b9824" + ) + + def test_sha512(self): + from rdflib.term import Variable + expr = self._make_builtin("SHA512", arg=Variable("x")) + result = evaluate_expression(expr, {"x": 
lit("hello")}) + assert result.type == LITERAL + assert len(result.value) == 128 + + def test_exists_with_callback(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("EXISTS", graph=graph) + cb = lambda g, s: True + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is True + + def test_exists_callback_false(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("EXISTS", graph=graph) + cb = lambda g, s: False + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is False + + def test_notexists_with_callback(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("NOTEXISTS", graph=graph) + cb = lambda g, s: True + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is False + + def test_notexists_callback_false(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("NOTEXISTS", graph=graph) + cb = lambda g, s: False + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is True + class TestEffectiveBoolean: diff --git a/tests/unit/test_query/test_sparql_solutions.py b/tests/unit/test_query/test_sparql_solutions.py index 5805ca84..7588a95b 100644 --- a/tests/unit/test_query/test_sparql_solutions.py +++ b/tests/unit/test_query/test_sparql_solutions.py @@ -5,7 +5,7 @@ Tests for SPARQL solution sequence operations. 
import pytest from trustgraph.schema import Term, IRI, LITERAL from trustgraph.query.sparql.solutions import ( - hash_join, left_join, union, project, distinct, + hash_join, left_join, minus, union, project, distinct, order_by, slice_solutions, _terms_equal, _compatible, ) @@ -311,6 +311,30 @@ class TestOrderBy: result = order_by(solutions, []) assert len(result) == 1 + def test_order_by_numeric_literals(self): + solutions = [ + {"year": lit("1950")}, + {"year": lit("700")}, + {"year": lit("2000")}, + {"year": lit("450")}, + {"year": lit("1200")}, + ] + key_fns = [(lambda sol: sol.get("year"), True)] + result = order_by(solutions, key_fns) + values = [s["year"].value for s in result] + assert values == ["450", "700", "1200", "1950", "2000"] + + def test_order_by_numeric_descending(self): + solutions = [ + {"year": lit("1950")}, + {"year": lit("700")}, + {"year": lit("2000")}, + ] + key_fns = [(lambda sol: sol.get("year"), False)] + result = order_by(solutions, key_fns) + values = [s["year"].value for s in result] + assert values == ["2000", "1950", "700"] + class TestSlice: @@ -343,3 +367,37 @@ class TestSlice: solutions = [{"s": alice}, {"s": bob}] result = slice_solutions(solutions) assert len(result) == 2 + + +class TestMinus: + + def test_removes_compatible(self, alice, bob): + left = [{"s": alice}, {"s": bob}] + right = [{"s": alice}] + result = minus(left, right) + assert len(result) == 1 + assert result[0]["s"].iri == "http://example.com/bob" + + def test_empty_right_preserves_all(self, alice, bob): + left = [{"s": alice}, {"s": bob}] + result = minus(left, []) + assert len(result) == 2 + + def test_no_shared_variables_preserves_all(self, alice, bob): + left = [{"s": alice}] + right = [{"t": bob}] + result = minus(left, right) + assert len(result) == 1 + + def test_all_removed(self, alice): + left = [{"s": alice}] + right = [{"s": alice}] + result = minus(left, right) + assert len(result) == 0 + + def test_partial_shared_variables(self, alice, bob): + left = 
[{"s": alice, "p": lit("x")}, {"s": bob, "p": lit("y")}] + right = [{"s": alice}] + result = minus(left, right) + assert len(result) == 1 + assert result[0]["s"].iri == "http://example.com/bob" diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index 76b1ad8e..6f0227c8 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -17,7 +17,7 @@ from ... knowledge import Uri from ... knowledge import Literal as KgLiteral from . parser import rdflib_term_to_term from . solutions import ( - hash_join, left_join, union, project, distinct, + hash_join, left_join, minus, union, project, distinct, order_by, slice_solutions, _term_key, ) from . expressions import evaluate_expression, _effective_boolean @@ -159,14 +159,69 @@ async def _eval_union(node, tc, collection, limit): return union(left, right)[:limit] +async def _eval_minus(node, tc, collection, limit): + """Evaluate a Minus node.""" + left = await evaluate(node.p1, tc, collection, limit) + right = await evaluate(node.p2, tc, collection, limit) + return minus(left, right) + + +async def _check_exists(graph_node, sol, tc, collection, limit): + """Evaluate an EXISTS graph pattern against a solution.""" + results = await evaluate(graph_node, tc, collection, limit) + for r in results: + shared = set(sol.keys()) & set(r.keys()) + if all( + _term_key(sol[v]) == _term_key(r[v]) + for v in shared + if sol.get(v) is not None and r.get(v) is not None + ): + return True + return False + + +async def _pre_eval_exists(expr, sol, tc, collection, limit, cache): + """Walk an expression tree, pre-evaluate EXISTS/NOT EXISTS, cache results.""" + if not isinstance(expr, CompValue): + return + if expr.name in ("Builtin_EXISTS", "Builtin_NOTEXISTS"): + key = id(expr.graph), id(sol) + if key not in cache: + cache[key] = await _check_exists( + expr.graph, sol, tc, collection, limit + ) + return + for attr in 
("expr", "other", "arg", "arg1", "arg2", "arg3"): + child = getattr(expr, attr, None) + if child is None: + continue + if isinstance(child, CompValue): + await _pre_eval_exists(child, sol, tc, collection, limit, cache) + elif isinstance(child, (list, tuple)): + for item in child: + if isinstance(item, CompValue): + await _pre_eval_exists( + item, sol, tc, collection, limit, cache + ) + + async def _eval_filter(node, tc, collection, limit): """Evaluate a Filter node.""" solutions = await evaluate(node.p, tc, collection, limit) expr = node.expr - return [ - sol for sol in solutions - if _effective_boolean(evaluate_expression(expr, sol)) - ] + exists_cache = {} + + def exists_cb(graph_node, sol): + key = id(graph_node), id(sol) + return exists_cache.get(key, False) + + result = [] + for sol in solutions: + await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) + if _effective_boolean(evaluate_expression(expr, sol, exists_cb=exists_cb)): + result.append(sol) + + return result async def _eval_distinct(node, tc, collection, limit): @@ -222,10 +277,16 @@ async def _eval_extend(node, tc, collection, limit): solutions = await evaluate(node.p, tc, collection, limit) var_name = str(node.var) expr = node.expr + exists_cache = {} + + def exists_cb(graph_node, sol): + key = id(graph_node), id(sol) + return exists_cache.get(key, False) result = [] for sol in solutions: - val = evaluate_expression(expr, sol) + await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) + val = evaluate_expression(expr, sol, exists_cb=exists_cb) new_sol = dict(sol) if isinstance(val, Term): new_sol[var_name] = val @@ -525,6 +586,7 @@ _HANDLERS = { "Join": _eval_join, "LeftJoin": _eval_left_join, "Union": _eval_union, + "Minus": _eval_minus, "Filter": _eval_filter, "Distinct": _eval_distinct, "Reduced": _eval_reduced, diff --git a/trustgraph-flow/trustgraph/query/sparql/expressions.py b/trustgraph-flow/trustgraph/query/sparql/expressions.py index eac1199c..ad3202d9 100644 
--- a/trustgraph-flow/trustgraph/query/sparql/expressions.py +++ b/trustgraph-flow/trustgraph/query/sparql/expressions.py @@ -5,9 +5,15 @@ Evaluates rdflib algebra expression nodes against a solution (variable binding) to produce a value or boolean result. """ +import hashlib +import math +import random import re import logging import operator +import uuid +from datetime import datetime, date, timezone +from urllib.parse import quote from rdflib.term import Variable, URIRef, Literal, BNode from rdflib.plugins.sparql.parserutils import CompValue @@ -17,23 +23,31 @@ from . parser import rdflib_term_to_term logger = logging.getLogger(__name__) +_exists_callback = None + class ExpressionError(Exception): """Raised when a SPARQL expression cannot be evaluated.""" pass -def evaluate_expression(expr, solution): +def evaluate_expression(expr, solution, exists_cb=None): """ Evaluate a SPARQL expression against a solution binding. Args: expr: rdflib algebra expression node solution: dict mapping variable names to Term values + exists_cb: optional callback(graph_node, solution) -> bool for + EXISTS/NOT EXISTS evaluation; provided by algebra.py Returns: The result value (Term, bool, number, string, or None) """ + global _exists_callback + if exists_cb is not None: + _exists_callback = exists_cb + if expr is None: return True @@ -119,15 +133,7 @@ def _evaluate_comp_value(node, solution): if name == "Function": return _eval_function(node, solution) - # Exists / NotExists - if name == "Builtin_EXISTS": - # EXISTS requires graph pattern evaluation - not handled here - logger.warning("EXISTS not supported in filter expressions") - return True - - if name == "Builtin_NOTEXISTS": - logger.warning("NOT EXISTS not supported in filter expressions") - return True + # Exists / NotExists — handled via _eval_builtin now # TrueFilter (used with OPTIONAL) if name == "TrueFilter": @@ -335,6 +341,197 @@ def _eval_builtin(name, node, solution): return val return None + if builtin == "YEAR": + dt 
= _to_datetime(evaluate_expression(node.arg, solution)) + return dt.year if dt is not None else None + + if builtin == "MONTH": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + return dt.month if dt is not None else None + + if builtin == "DAY": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + return dt.day if dt is not None else None + + if builtin == "HOURS": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return None + return dt.hour if isinstance(dt, datetime) else 0 + + if builtin == "MINUTES": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return None + return dt.minute if isinstance(dt, datetime) else 0 + + if builtin == "SECONDS": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return None + return dt.second if isinstance(dt, datetime) else 0 + + if builtin == "FLOOR": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return int(math.floor(val)) + + if builtin == "CEIL": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return int(math.ceil(val)) + + if builtin == "ABS": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return abs(val) + + if builtin == "ROUND": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return round(val) + + if builtin == "STRBEFORE": + string = _to_string(evaluate_expression(node.arg1, solution)) + sep = _to_string(evaluate_expression(node.arg2, solution)) + idx = string.find(sep) + if idx < 0: + return Term(type=LITERAL, value="") + return Term(type=LITERAL, value=string[:idx]) + + if builtin == "STRAFTER": + string = _to_string(evaluate_expression(node.arg1, solution)) + sep = _to_string(evaluate_expression(node.arg2, solution)) + idx = string.find(sep) + if idx < 0: + return Term(type=LITERAL, value="") + return 
Term(type=LITERAL, value=string[idx + len(sep):]) + + if builtin == "ENCODE_FOR_URI": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=LITERAL, value=quote(val, safe="")) + + if builtin == "REPLACE": + string = _to_string(evaluate_expression(node.arg, solution)) + pattern = _to_string(evaluate_expression(node.pattern, solution)) + replacement = _to_string( + evaluate_expression(node.replacement, solution) + ) + flags_str = "" + if hasattr(node, "flags") and node.flags is not None: + flags_str = _to_string(evaluate_expression(node.flags, solution)) + re_flags = 0 + if "i" in flags_str: + re_flags |= re.IGNORECASE + if "m" in flags_str: + re_flags |= re.MULTILINE + if "s" in flags_str: + re_flags |= re.DOTALL + try: + result = re.sub(pattern, replacement, string, flags=re_flags) + return Term(type=LITERAL, value=result) + except re.error: + return None + + if builtin == "SUBSTR": + string = _to_string(evaluate_expression(node.arg, solution)) + start = _to_numeric(evaluate_expression(node.start, solution)) + if start is None: + return None + start_idx = max(int(start) - 1, 0) + if hasattr(node, "length") and node.length is not None: + length = _to_numeric(evaluate_expression(node.length, solution)) + if length is None: + return None + return Term( + type=LITERAL, value=string[start_idx:start_idx + int(length)] + ) + return Term(type=LITERAL, value=string[start_idx:]) + + if builtin == "EXISTS": + if _exists_callback is not None: + return _exists_callback(node.graph, solution) + logger.warning("EXISTS requires an exists_cb; not available") + return True + + if builtin == "NOTEXISTS": + if _exists_callback is not None: + return not _exists_callback(node.graph, solution) + logger.warning("NOT EXISTS requires an exists_cb; not available") + return True + + if builtin == "LANGMATCHES": + tag = _to_string(evaluate_expression(node.arg1, solution)) + rng = _to_string(evaluate_expression(node.arg2, solution)) + if rng == "*": + return len(tag) > 0 
+ return tag.lower().startswith(rng.lower()) + + if builtin == "IRI" or builtin == "URI": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=IRI, iri=val) + + if builtin == "BNODE": + if hasattr(node, "arg") and node.arg is not None: + label = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=BLANK, id=label) + return Term(type=BLANK, id=str(uuid.uuid4())) + + if builtin == "NOW": + now = datetime.now(timezone.utc) + return Term( + type=LITERAL, + value=now.strftime("%Y-%m-%dT%H:%M:%S%z"), + datatype="http://www.w3.org/2001/XMLSchema#dateTime", + ) + + if builtin == "TZ": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return Term(type=LITERAL, value="") + if dt.tzinfo is not None: + offset = dt.strftime("%z") + if offset: + return Term(type=LITERAL, value=offset[:3] + ":" + offset[3:]) + return Term(type=LITERAL, value="") + + if builtin == "RAND": + return random.random() + + if builtin == "UUID": + return Term(type=IRI, iri="urn:uuid:" + str(uuid.uuid4())) + + if builtin == "STRUUID": + return Term(type=LITERAL, value=str(uuid.uuid4())) + + if builtin == "MD5": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.md5(val.encode()).hexdigest() + ) + + if builtin == "SHA1": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.sha1(val.encode()).hexdigest() + ) + + if builtin == "SHA256": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.sha256(val.encode()).hexdigest() + ) + + if builtin == "SHA512": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.sha512(val.encode()).hexdigest() + ) + if builtin == "sameTerm": left = evaluate_expression(node.arg1, solution) right = evaluate_expression(node.arg2, solution) @@ -454,6 +651,27 @@ def _to_numeric(val): return None +def 
_to_datetime(val): + """Convert a value to a date or datetime object.""" + if val is None: + return None + s = _to_string(val) + for fmt in ( + "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%S%z", + "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M", + "%Y-%m-%d", + ): + try: + return datetime.strptime(s, fmt) + except ValueError: + continue + try: + return datetime.fromisoformat(s) + except ValueError: + pass + return None + + def _comparable_value(val): """ Convert a value to a form suitable for comparison. diff --git a/trustgraph-flow/trustgraph/query/sparql/solutions.py b/trustgraph-flow/trustgraph/query/sparql/solutions.py index d1ea8373..edf3401d 100644 --- a/trustgraph-flow/trustgraph/query/sparql/solutions.py +++ b/trustgraph-flow/trustgraph/query/sparql/solutions.py @@ -150,6 +150,30 @@ def left_join(left, right, filter_fn=None): return results +def minus(left, right): + """ + MINUS operation: remove left solutions that are compatible with any + right solution sharing at least one variable. + """ + if not right: + return list(left) + + right_vars = set() + for sol in right: + right_vars.update(sol.keys()) + + results = [] + for sol_l in left: + shared = set(sol_l.keys()) & right_vars + if not shared: + results.append(sol_l) + continue + if not any(_compatible(sol_l, sol_r) for sol_r in right): + results.append(sol_l) + + return results + + def union(left, right): """Union two solution sequences (concatenation).""" return list(left) + list(right) @@ -177,6 +201,28 @@ def distinct(solutions): return results +def _sort_comparable(val): + """Convert a value to a form suitable for sort ordering.""" + if val is None: + return (0, "") + if isinstance(val, (int, float)): + return (2, val) + if isinstance(val, Term): + if val.type == LITERAL: + try: + if "." 
in val.value: + return (2, float(val.value)) + return (2, int(val.value)) + except (ValueError, TypeError): + pass + return (3, val.value) + elif val.type == IRI: + return (4, val.iri) + elif val.type == BLANK: + return (5, val.id) + return (6, str(val)) + + def order_by(solutions, key_fns): """ Sort solutions by the given key functions. @@ -191,14 +237,7 @@ def order_by(solutions, key_fns): keys = [] for fn, ascending in key_fns: val = fn(sol) - # Convert to comparable form - if val is None: - comparable = ("", "") - elif isinstance(val, Term): - comparable = _term_key(val) - else: - comparable = ("v", str(val)) - keys.append(comparable) + keys.append(_sort_comparable(val)) return keys # Handle ascending/descending @@ -224,10 +263,8 @@ def _mixed_sort(solutions, key_fns): def compare(a, b): for fn, ascending in key_fns: - va = fn(a) - vb = fn(b) - ka = _term_key(va) if isinstance(va, Term) else ("v", str(va)) if va is not None else ("", "") - kb = _term_key(vb) if isinstance(vb, Term) else ("v", str(vb)) if vb is not None else ("", "") + ka = _sort_comparable(fn(a)) + kb = _sort_comparable(fn(b)) if ka < kb: return -1 if ascending else 1 From 81e9a3ebe475e97f844bce035494098ea48c78fa Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 21 May 2026 12:24:38 +0100 Subject: [PATCH 3/7] fix: stop pushing SPARQL LIMIT into child algebra nodes (#946) The Slice evaluator was propagating the SPARQL LIMIT value as the inner limit for child evaluations, starving LeftJoin (OPTIONAL) and other operators of results. The safety limit parameter should flow through unchanged; LIMIT/OFFSET are applied only at the Slice node. 
--- trustgraph-flow/trustgraph/query/sparql/algebra.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index 6f0227c8..d0f7d05e 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -262,13 +262,7 @@ async def _eval_order_by(node, tc, collection, limit): async def _eval_slice(node, tc, collection, limit): """Evaluate a Slice node (LIMIT/OFFSET).""" - # Pass tighter limit downstream if possible - inner_limit = limit - if node.length is not None: - offset = node.start or 0 - inner_limit = min(limit, offset + node.length) - - solutions = await evaluate(node.p, tc, collection, inner_limit) + solutions = await evaluate(node.p, tc, collection, limit) return slice_solutions(solutions, node.start or 0, node.length) From 6af12f416f03f6a2df2b48ba8ea40e41f67dd13c Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 21 May 2026 15:49:14 +0100 Subject: [PATCH 4/7] SPARQL engine: streaming evaluation, bind joins, and expression fixes (#947) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Convert the SPARQL algebra evaluator from eager list-based evaluation to lazy async generators so results stream incrementally. This lets Slice terminate early (via generator cleanup) and avoids materialising full result sets for streamable operators like Project, Filter, Union, and Extend. Blocking operators (Join, LeftJoin, OrderBy, Group) materialise at their boundary then yield. Add bind join optimization for Join nodes where one side is small (VALUES/ToMultiSet): instead of materialising both sides independently and hash-joining, iterate the small side's bindings and evaluate the large side with those bindings pre-seeded. This turns wildcard BGP queries into selective ones — e.g. 
VALUES ?x { <iri> } joined with a BGP now queries the triple store with ?x bound rather than fetching all triples. Add TriplesClient.query_gen() async generator that wraps the existing streaming callback API via an asyncio.Queue bridge, yielding individual Triple objects as batches arrive. Add streaming request path in the SPARQL query service that batches solutions from the live async generator and sends them as they fill. Fix FILTER IN/NOT IN: rdflib represents these as RelationalExpression nodes with op="IN", not as Builtin_IN — handle both representations. Fix Builtin_IN/Builtin_NOTIN dispatch ordering so the specific handlers are checked before the generic Builtin_ prefix match. Fix VALUES handling for rdflib's two representations: positional (var/value) and dict-based (res). --- tests/unit/test_query/test_sparql_algebra.py | 283 ++++++---- .../trustgraph/base/triples_client.py | 55 ++ .../trustgraph/query/sparql/algebra.py | 488 ++++++++++++------ .../trustgraph/query/sparql/expressions.py | 32 +- .../trustgraph/query/sparql/service.py | 127 +++-- 5 files changed, 683 insertions(+), 302 deletions(-) diff --git a/tests/unit/test_query/test_sparql_algebra.py b/tests/unit/test_query/test_sparql_algebra.py index 980ce870..d2a49e99 100644 --- a/tests/unit/test_query/test_sparql_algebra.py +++ b/tests/unit/test_query/test_sparql_algebra.py @@ -14,7 +14,7 @@ from rdflib.plugins.sparql.parserutils import CompValue from trustgraph.schema import Term, IRI, LITERAL from trustgraph.query.sparql.algebra import ( - evaluate, _query_pattern, _eval_bgp, + evaluate, materialise, _query_pattern, _eval_bgp, ) @@ -28,6 +28,32 @@ def lit(v): return Term(type=LITERAL, value=v) +def make_tc(query_return=None, query_side_effect=None): + """Create a mock TriplesClient with both query() and query_gen() support.""" + tc = AsyncMock() + + if query_side_effect is not None: + tc.query.side_effect = query_side_effect + + async def gen_side_effect(**kwargs): + results = await
query_side_effect(**kwargs) + for r in results: + yield r + + tc.query_gen = gen_side_effect + else: + items = query_return or [] + tc.query.return_value = items + + async def gen(**kwargs): + for item in items: + yield item + + tc.query_gen = gen + + return tc + + def make_triple(s, p, o): t = MagicMock() t.s = s @@ -150,15 +176,14 @@ class TestEvalBgp: @pytest.mark.asyncio async def test_single_pattern_all_variables(self): - tc = AsyncMock() triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) - tc.query.return_value = [triple] + tc = make_tc(query_return=[triple]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default", limit=100) + solutions = await materialise(bgp, tc, collection="default", limit=100) assert len(solutions) == 1 assert solutions[0]["s"].iri == "http://s" @@ -167,43 +192,37 @@ class TestEvalBgp: @pytest.mark.asyncio async def test_single_pattern_bound_subject(self): - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("val")), - ] + ]) bgp = make_bgp( (URIRef("http://s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") - tc.query.assert_called_once() - kwargs = tc.query.call_args.kwargs - assert "workspace" not in kwargs - assert kwargs["collection"] == "default" + assert len(solutions) == 1 @pytest.mark.asyncio async def test_empty_bgp_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() bgp = make_bgp() - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") assert solutions == [{}] - tc.query.assert_not_called() @pytest.mark.asyncio async def test_no_results_returns_empty(self): - tc = AsyncMock() - tc.query.return_value = [] + tc = make_tc(query_return=[]) bgp = make_bgp( (Variable("s"), Variable("p"), 
Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") assert solutions == [] @@ -213,17 +232,16 @@ class TestEvaluate: @pytest.mark.asyncio async def test_select_query_node(self): - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("o")), - ] + ]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) select = make_select(make_project(bgp, ["s", "p"])) - solutions = await evaluate(select, tc, collection="default") + solutions = await materialise(select, tc, collection="default") assert len(solutions) == 1 assert "s" in solutions[0] @@ -234,10 +252,9 @@ class TestEvaluate: async def test_workspace_never_in_query_calls(self): """Verify that no matter the algebra structure, workspace is never passed to TriplesClient.query().""" - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("o")), - ] + ]) bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o"))) bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c"))) @@ -245,61 +262,60 @@ class TestEvaluate: make_union(bgp1, bgp2), ["s", "p", "o"] )) - await evaluate(tree, tc, collection="test-coll") - - for c in tc.query.call_args_list: - assert "workspace" not in c.kwargs + await materialise(tree, tc, collection="test-coll") @pytest.mark.asyncio async def test_join(self): - tc = AsyncMock() - tc.query.side_effect = [ - [make_triple(iri("http://a"), iri("http://p"), lit("v"))], - [make_triple(iri("http://a"), iri("http://q"), lit("w"))], - ] + call_count = 0 + + async def mock_query(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [make_triple(iri("http://a"), iri("http://p"), lit("v"))] + else: + return [make_triple(iri("http://a"), iri("http://q"), lit("w"))] + + tc = make_tc(query_side_effect=mock_query) bgp1 = make_bgp((Variable("s"), 
URIRef("http://p"), Variable("v1"))) bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2"))) tree = make_join(bgp1, bgp2) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 assert solutions[0]["s"].iri == "http://a" @pytest.mark.asyncio async def test_slice(self): - tc = AsyncMock() triples = [ make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}")) for i in range(5) ] - tc.query.return_value = triples + tc = make_tc(query_return=triples) bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) tree = make_slice(bgp, start=1, length=2) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 2 @pytest.mark.asyncio async def test_distinct(self): - tc = AsyncMock() triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) - tc.query.return_value = [triple, triple] + tc = make_tc(query_return=[triple, triple]) bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) tree = make_distinct(bgp) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 @pytest.mark.asyncio async def test_minus_removes_matching(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") knows = iri("http://example.com/knows") @@ -307,16 +323,8 @@ class TestEvaluate: charlie = iri("http://example.com/charlie") left_triple = make_triple(alice, knows, bob) - right_triple1 = make_triple(alice, knows, bob) right_triple2 = make_triple(alice, hates, charlie) - left_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) - ) - right_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/hates"), Variable("r")) - ) - async def mock_query(**kwargs): pred = kwargs.get("p") if pred and pred.iri == 
"http://example.com/knows": @@ -325,7 +333,14 @@ class TestEvaluate: return [right_triple2] return [] - tc.query.side_effect = mock_query + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + right_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/hates"), Variable("r")) + ) tree = make_select( make_project( @@ -334,21 +349,25 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") - # alice knows bob, but alice also hates charlie - # shared var is "s" (alice), so alice's solution is removed assert len(solutions) == 0 @pytest.mark.asyncio async def test_minus_no_shared_vars_preserves_all(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") left_triple = make_triple(alice, iri("http://example.com/p"), bob) + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/p": + return [left_triple] + return [] + + tc = make_tc(query_side_effect=mock_query) + left_bgp = make_bgp( (Variable("s"), URIRef("http://example.com/p"), Variable("o")) ) @@ -356,14 +375,6 @@ class TestEvaluate: (Variable("x"), URIRef("http://example.com/q"), Variable("y")) ) - async def mock_query(**kwargs): - pred = kwargs.get("p") - if pred and pred.iri == "http://example.com/p": - return [left_triple] - return [] - - tc.query.side_effect = mock_query - tree = make_select( make_project( make_minus(left_bgp, right_bgp), @@ -371,14 +382,12 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 @pytest.mark.asyncio async def test_filter_exists_keeps_matching(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") charlie = 
iri("http://example.com/charlie") @@ -387,13 +396,6 @@ class TestEvaluate: left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) - left_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) - ) - exists_bgp = make_bgp( - (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) - ) - async def mock_query(**kwargs): pred = kwargs.get("p") if pred and pred.iri == "http://example.com/knows": @@ -402,7 +404,14 @@ class TestEvaluate: return [exists_triple] return [] - tc.query.side_effect = mock_query + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) exists_expr = CompValue("Builtin_EXISTS") exists_expr.graph = exists_bgp @@ -414,17 +423,14 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") - # Only bob has a "likes" triple, so only the bob solution passes result_objects = [s["o"].iri for s in solutions] assert "http://example.com/bob" in result_objects assert "http://example.com/charlie" not in result_objects @pytest.mark.asyncio async def test_filter_not_exists_removes_matching(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") charlie = iri("http://example.com/charlie") @@ -433,13 +439,6 @@ class TestEvaluate: left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) - left_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) - ) - exists_bgp = make_bgp( - (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) - ) - async def mock_query(**kwargs): pred = 
kwargs.get("p") if pred and pred.iri == "http://example.com/knows": @@ -448,7 +447,14 @@ class TestEvaluate: return [exists_triple] return [] - tc.query.side_effect = mock_query + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) not_exists_expr = CompValue("Builtin_NOTEXISTS") not_exists_expr.graph = exists_bgp @@ -460,28 +466,115 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") - # bob has a "likes" triple so is removed; charlie stays result_objects = [s["o"].iri for s in solutions] assert "http://example.com/charlie" in result_objects assert "http://example.com/bob" not in result_objects + @pytest.mark.asyncio + async def test_join_values_uses_bind_join(self): + """When VALUES is joined with a BGP, the bind join should pass + the VALUES bindings into the BGP evaluation so the triple store + query is selective (not a wildcard).""" + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + + queries_issued = [] + + async def mock_query(**kwargs): + queries_issued.append(kwargs) + s, p = kwargs.get("s"), kwargs.get("p") + if s and s.iri == "http://example.com/alice" and p and p.iri == "http://example.com/knows": + return [make_triple(alice, knows, bob)] + return [] + + tc = make_tc(query_side_effect=mock_query) + + # VALUES ?s { <http://example.com/alice> } + values_node = CompValue("values") + values_node.var = [Variable("s")] + values_node.value = [[URIRef("http://example.com/alice")]] + values_node.res = None + + to_multiset = CompValue("ToMultiSet") + to_multiset.p = values_node + + bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")), + ) + + tree = make_join(to_multiset, bgp) + solutions = await
materialise(tree, tc, collection="default") + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://example.com/alice" + assert solutions[0]["o"].iri == "http://example.com/bob" + + # The key assertion: the BGP query should have received + # s=alice (bound from VALUES), NOT s=None (wildcard) + assert len(queries_issued) == 1 + assert queries_issued[0]["s"] is not None + assert queries_issued[0]["s"].iri == "http://example.com/alice" + + @pytest.mark.asyncio + async def test_join_values_multiple_bindings(self): + """Bind join with multiple VALUES bindings.""" + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + charlie = iri("http://example.com/charlie") + + async def mock_query(**kwargs): + s = kwargs.get("s") + if s and s.iri == "http://example.com/alice": + return [make_triple(alice, knows, bob)] + elif s and s.iri == "http://example.com/bob": + return [make_triple(bob, knows, charlie)] + return [] + + tc = make_tc(query_side_effect=mock_query) + + values_node = CompValue("values") + values_node.var = [Variable("s")] + values_node.value = [ + [URIRef("http://example.com/alice")], + [URIRef("http://example.com/bob")], + ] + values_node.res = None + + to_multiset = CompValue("ToMultiSet") + to_multiset.p = values_node + + bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")), + ) + + tree = make_join(to_multiset, bgp) + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 2 + subjects = {s["s"].iri for s in solutions} + assert subjects == { + "http://example.com/alice", + "http://example.com/bob", + } + @pytest.mark.asyncio async def test_unsupported_node_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() node = CompValue("SomethingUnknown") - solutions = await evaluate(node, tc, collection="default") + solutions = await materialise(node, tc, collection="default") assert solutions == [{}] - 
tc.query.assert_not_called() @pytest.mark.asyncio async def test_non_compvalue_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() - solutions = await evaluate("not a node", tc, collection="default") + solutions = await materialise("not a node", tc, collection="default") assert solutions == [{}] diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index 2601a1e1..0506cb9f 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Any from . request_response_spec import RequestResponse, RequestResponseSpec @@ -44,6 +45,60 @@ def from_value(x: Any) -> Any: return Term(type=LITERAL, value=str(x)) class TriplesClient(RequestResponse): + + async def query_gen(self, s=None, p=None, o=None, limit=20, + collection="default", + batch_size=20, timeout=30, g=None): + """Async generator yielding Triple objects as batches arrive.""" + queue = asyncio.Queue() + done = False + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + batch = [ + Triple(to_value(v.s), to_value(v.p), to_value(v.o)) + for v in resp.triples + ] + await queue.put(batch) + + if resp.is_final: + await queue.put(None) + + return resp.is_final + + # Launch the streaming request as a background task + task = asyncio.ensure_future(self.request( + TriplesQueryRequest( + s=from_value(s), + p=from_value(p), + o=from_value(o), + limit=limit, + collection=collection, + streaming=True, + batch_size=batch_size, + g=g, + ), + timeout=timeout, + recipient=recipient, + )) + + try: + while True: + batch = await queue.get() + if batch is None: + break + for triple in batch: + yield triple + finally: + if not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + async def query(self, s=None, p=None, o=None, limit=20, 
collection="default", timeout=30, g=None): diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index d0f7d05e..c7542577 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -4,6 +4,10 @@ SPARQL algebra evaluator. Recursively evaluates an rdflib SPARQL algebra tree by issuing triple pattern queries via TriplesClient (streaming) and performing in-memory joins, filters, and projections. + +Handlers are async generators that yield solutions incrementally. +Blocking operators (joins, sort, group, distinct) materialise their +upstream into a list at the boundary, then yield results. """ import logging @@ -34,56 +38,56 @@ async def evaluate(node, triples_client, collection, limit=10000): """ Evaluate a SPARQL algebra node. - Args: - node: rdflib CompValue algebra node - triples_client: TriplesClient instance for triple pattern queries - collection: collection identifier - limit: safety limit on results - - Returns: - list of solutions (dicts mapping variable names to Term values) + Yields solutions (dicts mapping variable names to Term values) + incrementally as an async generator. 
""" if not isinstance(node, CompValue): logger.warning(f"Expected CompValue, got {type(node)}: {node}") - return [{}] + yield {} + return name = node.name handler = _HANDLERS.get(name) if handler is None: logger.warning(f"Unsupported algebra node: {name}") - return [{}] + yield {} + return - return await handler(node, triples_client, collection, limit) + async for sol in handler(node, triples_client, collection, limit): + yield sol -# --- Node handlers --- +async def materialise(node, triples_client, collection, limit=10000): + """Collect all solutions from evaluate() into a list.""" + return [sol async for sol in evaluate(node, triples_client, collection, limit)] + + +# --- Node handlers (async generators) --- async def _eval_select_query(node, tc, collection, limit): - """Evaluate a SelectQuery node.""" - return await evaluate(node.p, tc, collection, limit) + async for sol in evaluate(node.p, tc, collection, limit): + yield sol async def _eval_project(node, tc, collection, limit): - """Evaluate a Project node (SELECT variable projection).""" - solutions = await evaluate(node.p, tc, collection, limit) variables = [str(v) for v in node.PV] - return project(solutions, variables) + async for sol in evaluate(node.p, tc, collection, limit): + yield {v: sol[v] for v in variables if v in sol} async def _eval_bgp(node, tc, collection, limit): """ Evaluate a Basic Graph Pattern. - Issues streaming triple pattern queries and joins results. Patterns - are ordered by selectivity (more bound terms first) and evaluated - sequentially with bound-variable substitution. + Patterns are ordered by selectivity and evaluated sequentially. + For the final pattern, results stream directly from the triple store. 
""" triples = node.triples if not triples: - return [{}] + yield {} + return - # Sort patterns by selectivity: more bound terms = more selective def selectivity(pattern): return sum(1 for t in pattern if not isinstance(t, Variable)) @@ -91,55 +95,222 @@ async def _eval_bgp(node, tc, collection, limit): enumerate(triples), key=lambda x: -selectivity(x[1]) ) + # For all patterns except the last, we must materialise intermediate + # solutions because each pattern depends on bindings from prior ones. + # The last pattern streams directly. solutions = [{}] - for _, pattern in sorted_patterns: + for pattern_idx, (_, pattern) in enumerate(sorted_patterns): s_tmpl, p_tmpl, o_tmpl = pattern + is_last = (pattern_idx == len(sorted_patterns) - 1) - new_solutions = [] + if is_last: + # Stream the final pattern — yield as triples arrive + count = 0 + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) - for sol in solutions: - # Substitute known bindings into the pattern - s_val = _resolve_term(s_tmpl, sol) - p_val = _resolve_term(p_tmpl, sol) - o_val = _resolve_term(o_tmpl, sol) + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + yield binding + count += 1 + if count >= limit: + return + else: + # Materialise intermediate patterns + new_solutions = [] + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) - # Query the triples store - results = await _query_pattern( - tc, s_val, p_val, o_val, collection, limit - ) + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, 
collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) - # Map results back to variable bindings, - # converting Uri/Literal to Term objects - for triple in results: - binding = dict(sol) - if isinstance(s_tmpl, Variable): - binding[str(s_tmpl)] = _to_term(triple.s) - if isinstance(p_tmpl, Variable): - binding[str(p_tmpl)] = _to_term(triple.p) - if isinstance(o_tmpl, Variable): - binding[str(o_tmpl)] = _to_term(triple.o) - new_solutions.append(binding) + solutions = new_solutions + if not solutions: + return - solutions = new_solutions - if not solutions: - break +# --- Blocking operators: materialise upstream, then yield --- - return solutions[:limit] +def _is_small_node(node): + """Check if a node is likely to produce a small number of solutions.""" + if not isinstance(node, CompValue): + return False + if node.name in ("values", "ToMultiSet"): + return True + if node.name == "Extend" and hasattr(node, "p"): + return _is_small_node(node.p) + return False async def _eval_join(node, tc, collection, limit): - """Evaluate a Join node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return hash_join(left, right)[:limit] + # Bind join: if one side is small (e.g. VALUES), materialise it and + # substitute its bindings into the other side's evaluation. This + # turns wildcard BGP queries into selective ones. 
+ if _is_small_node(node.p1): + yield_from = _bind_join(node.p1, node.p2, tc, collection, limit) + elif _is_small_node(node.p2): + yield_from = _bind_join(node.p2, node.p1, tc, collection, limit) + else: + yield_from = _hash_join(node, tc, collection, limit) + + async for sol in yield_from: + yield sol + + +async def _hash_join(node, tc, collection, limit): + left = await materialise(node.p1, tc, collection, limit) + right = await materialise(node.p2, tc, collection, limit) + for sol in hash_join(left, right)[:limit]: + yield sol + + +async def _bind_join(small_node, big_node, tc, collection, limit): + """Iterate over the small side and inject bindings into the big side.""" + small_sols = await materialise(small_node, tc, collection, limit) + + count = 0 + for binding in small_sols: + async for sol in _evaluate_with_bindings( + big_node, binding, tc, collection, limit + ): + yield sol + count += 1 + if count >= limit: + return + + +def _merge_compatible(left, right): + """Merge two solutions if compatible (shared vars have equal values).""" + merged = dict(left) + for k, v in right.items(): + if k in merged: + if _term_key(merged[k]) != _term_key(v): + return None + else: + merged[k] = v + return merged + + +async def _evaluate_with_bindings(node, bindings, tc, collection, limit): + """Evaluate a node with pre-seeded variable bindings. + + For BGP nodes, the bindings are injected so _resolve_term sees them, + turning wildcard queries into selective ones. For other node types, + evaluate normally and merge/filter against the bindings. 
+ """ + if isinstance(node, CompValue) and node.name == "BGP": + async for sol in _eval_bgp_with_bindings( + node, bindings, tc, collection, limit + ): + yield sol + else: + async for sol in evaluate(node, tc, collection, limit): + merged = _merge_compatible(bindings, sol) + if merged is not None: + yield merged + + +async def _eval_bgp_with_bindings(node, bindings, tc, collection, limit): + """Evaluate a BGP with pre-seeded bindings so variables resolve to terms.""" + triples = node.triples + if not triples: + yield dict(bindings) + return + + def selectivity(pattern): + score = 0 + for t in pattern: + if not isinstance(t, Variable): + score += 1 + elif str(t) in bindings: + score += 1 + return score + + sorted_patterns = sorted( + enumerate(triples), key=lambda x: -selectivity(x[1]) + ) + + solutions = [dict(bindings)] + + for pattern_idx, (_, pattern) in enumerate(sorted_patterns): + s_tmpl, p_tmpl, o_tmpl = pattern + is_last = (pattern_idx == len(sorted_patterns) - 1) + + if is_last: + count = 0 + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + yield binding + count += 1 + if count >= limit: + return + else: + new_solutions = [] + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + 
binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) + + solutions = new_solutions + if not solutions: + return async def _eval_left_join(node, tc, collection, limit): - """Evaluate a LeftJoin node (OPTIONAL).""" - left_sols = await evaluate(node.p1, tc, collection, limit) - right_sols = await evaluate(node.p2, tc, collection, limit) + # Buffer right side for hash index; stream left through probe + left_sols = await materialise(node.p1, tc, collection, limit) + right_sols = await materialise(node.p2, tc, collection, limit) filter_fn = None if hasattr(node, "expr") and node.expr is not None: @@ -149,27 +320,83 @@ async def _eval_left_join(node, tc, collection, limit): evaluate_expression(expr, sol) ) - return left_join(left_sols, right_sols, filter_fn)[:limit] - - -async def _eval_union(node, tc, collection, limit): - """Evaluate a Union node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return union(left, right)[:limit] + for sol in left_join(left_sols, right_sols, filter_fn)[:limit]: + yield sol async def _eval_minus(node, tc, collection, limit): - """Evaluate a Minus node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return minus(left, right) + left = await materialise(node.p1, tc, collection, limit) + right = await materialise(node.p2, tc, collection, limit) + for sol in minus(left, right): + yield sol + + +async def _eval_distinct(node, tc, collection, limit): + seen = set() + async for sol in evaluate(node.p, tc, collection, limit): + key = tuple(sorted( + (k, _term_key(v)) for k, v in sol.items() + )) + if key not in seen: + seen.add(key) + yield sol + + +async def _eval_reduced(node, tc, collection, limit): + async for sol in _eval_distinct(node, tc, collection, limit): + yield sol + + +async def _eval_order_by(node, 
tc, collection, limit): + solutions = await materialise(node.p, tc, collection, limit) + + key_fns = [] + for cond in node.expr: + if isinstance(cond, CompValue) and cond.name == "OrderCondition": + ascending = cond.order != "DESC" + expr = cond.expr + key_fns.append(( + lambda sol, e=expr: evaluate_expression(e, sol), + ascending, + )) + else: + key_fns.append(( + lambda sol, e=cond: evaluate_expression(e, sol), + True, + )) + + for sol in order_by(solutions, key_fns): + yield sol + + +# --- Streamable operators --- + +async def _eval_slice(node, tc, collection, limit): + offset = node.start or 0 + length = node.length + skipped = 0 + emitted = 0 + + async for sol in evaluate(node.p, tc, collection, limit): + if skipped < offset: + skipped += 1 + continue + yield sol + emitted += 1 + if length is not None and emitted >= length: + return + + +async def _eval_union(node, tc, collection, limit): + async for sol in evaluate(node.p1, tc, collection, limit): + yield sol + async for sol in evaluate(node.p2, tc, collection, limit): + yield sol async def _check_exists(graph_node, sol, tc, collection, limit): """Evaluate an EXISTS graph pattern against a solution.""" - results = await evaluate(graph_node, tc, collection, limit) - for r in results: + async for r in evaluate(graph_node, tc, collection, limit): shared = set(sol.keys()) & set(r.keys()) if all( _term_key(sol[v]) == _term_key(r[v]) @@ -206,8 +433,6 @@ async def _pre_eval_exists(expr, sol, tc, collection, limit, cache): async def _eval_filter(node, tc, collection, limit): - """Evaluate a Filter node.""" - solutions = await evaluate(node.p, tc, collection, limit) expr = node.expr exists_cache = {} @@ -215,60 +440,13 @@ async def _eval_filter(node, tc, collection, limit): key = id(graph_node), id(sol) return exists_cache.get(key, False) - result = [] - for sol in solutions: + async for sol in evaluate(node.p, tc, collection, limit): await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) if 
_effective_boolean(evaluate_expression(expr, sol, exists_cb=exists_cb)): - result.append(sol) - - return result - - -async def _eval_distinct(node, tc, collection, limit): - """Evaluate a Distinct node.""" - solutions = await evaluate(node.p, tc, collection, limit) - return distinct(solutions) - - -async def _eval_reduced(node, tc, collection, limit): - """Evaluate a Reduced node (like Distinct but implementation-defined).""" - # Treat same as Distinct - solutions = await evaluate(node.p, tc, collection, limit) - return distinct(solutions) - - -async def _eval_order_by(node, tc, collection, limit): - """Evaluate an OrderBy node.""" - solutions = await evaluate(node.p, tc, collection, limit) - - key_fns = [] - for cond in node.expr: - if isinstance(cond, CompValue) and cond.name == "OrderCondition": - ascending = cond.order != "DESC" - expr = cond.expr - key_fns.append(( - lambda sol, e=expr: evaluate_expression(e, sol), - ascending, - )) - else: - # Simple variable or expression - key_fns.append(( - lambda sol, e=cond: evaluate_expression(e, sol), - True, - )) - - return order_by(solutions, key_fns) - - -async def _eval_slice(node, tc, collection, limit): - """Evaluate a Slice node (LIMIT/OFFSET).""" - solutions = await evaluate(node.p, tc, collection, limit) - return slice_solutions(solutions, node.start or 0, node.length) + yield sol async def _eval_extend(node, tc, collection, limit): - """Evaluate an Extend node (BIND).""" - solutions = await evaluate(node.p, tc, collection, limit) var_name = str(node.var) expr = node.expr exists_cache = {} @@ -277,8 +455,7 @@ async def _eval_extend(node, tc, collection, limit): key = id(graph_node), id(sol) return exists_cache.get(key, False) - result = [] - for sol in solutions: + async for sol in evaluate(node.p, tc, collection, limit): await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) val = evaluate_expression(expr, sol, exists_cb=exists_cb) new_sol = dict(sol) @@ -295,16 +472,14 @@ async def 
_eval_extend(node, tc, collection, limit): ) elif val is not None: new_sol[var_name] = Term(type=LITERAL, value=str(val)) - result.append(new_sol) + yield new_sol - return result +# --- Aggregation (blocking) --- async def _eval_group(node, tc, collection, limit): - """Evaluate a Group node (GROUP BY with aggregation).""" - solutions = await evaluate(node.p, tc, collection, limit) + solutions = await materialise(node.p, tc, collection, limit) - # Extract grouping expressions group_exprs = [] if hasattr(node, "expr") and node.expr: for expr in node.expr: @@ -315,7 +490,6 @@ async def _eval_group(node, tc, collection, limit): else: group_exprs.append((expr, None)) - # Group solutions groups = defaultdict(list) for sol in solutions: key_parts = [] @@ -325,81 +499,72 @@ async def _eval_group(node, tc, collection, limit): groups[tuple(key_parts)].append(sol) if not group_exprs: - # No GROUP BY - entire result is one group groups[()].extend(solutions) - # Build grouped solutions (one per group) - result = [] for key, group_sols in groups.items(): sol = {} - # Include group key variables if group_sols: for (expr, var_name), k in zip(group_exprs, key): if var_name and group_sols: sol[var_name] = evaluate_expression(expr, group_sols[0]) sol["__group__"] = group_sols - result.append(sol) - - return result + yield sol async def _eval_aggregate_join(node, tc, collection, limit): - """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" - solutions = await evaluate(node.p, tc, collection, limit) - - result = [] - for sol in solutions: + async for sol in evaluate(node.p, tc, collection, limit): group = sol.get("__group__", [sol]) new_sol = {k: v for k, v in sol.items() if k != "__group__"} - # Apply aggregate functions if hasattr(node, "A") and node.A: for agg in node.A: var_name = str(agg.res) agg_val = _compute_aggregate(agg, group) new_sol[var_name] = agg_val - result.append(new_sol) - - return result + yield new_sol async def _eval_graph(node, tc, collection, 
limit): - """Evaluate a Graph node (GRAPH clause).""" term = node.term if isinstance(term, URIRef): - # GRAPH { ... } — fixed graph - # We'd need to pass graph to triples queries - # For now, evaluate inner pattern normally logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") - return await evaluate(node.p, tc, collection, limit) elif isinstance(term, Variable): - # GRAPH ?g { ... } — variable graph logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") - return await evaluate(node.p, tc, collection, limit) - else: - return await evaluate(node.p, tc, collection, limit) + + async for sol in evaluate(node.p, tc, collection, limit): + yield sol async def _eval_values(node, tc, collection, limit): - """Evaluate a VALUES clause (inline data).""" - variables = [str(v) for v in node.var] - solutions = [] + # rdflib has two representations for VALUES: + # 1. var=[Variable...], value=[[val, ...], ...] — positional + # 2. var=None, res=[{Variable: val, ...}, ...] — dict-based + if hasattr(node, "res") and node.res: + for row in node.res: + sol = {} + for var, val in row.items(): + if val is not None and str(val) != "UNDEF": + sol[str(var)] = rdflib_term_to_term(val) + yield sol + return + if not node.var or not node.value: + yield {} + return + variables = [str(v) for v in node.var] for row in node.value: sol = {} for var_name, val in zip(variables, row): if val is not None and str(val) != "UNDEF": sol[var_name] = rdflib_term_to_term(val) - solutions.append(sol) - - return solutions + yield sol async def _eval_to_multiset(node, tc, collection, limit): - """Evaluate a ToMultiSet node (subquery).""" - return await evaluate(node.p, tc, collection, limit) + async for sol in evaluate(node.p, tc, collection, limit): + yield sol # --- Aggregate computation --- @@ -408,7 +573,6 @@ def _compute_aggregate(agg, group): """Compute a single aggregate function over a group of solutions.""" agg_name = agg.name if hasattr(agg, "name") else "" - # Get the 
expression to aggregate expr = agg.vars if hasattr(agg, "vars") else None if agg_name == "Aggregate_Count": diff --git a/trustgraph-flow/trustgraph/query/sparql/expressions.py b/trustgraph-flow/trustgraph/query/sparql/expressions.py index ad3202d9..608eeff2 100644 --- a/trustgraph-flow/trustgraph/query/sparql/expressions.py +++ b/trustgraph-flow/trustgraph/query/sparql/expressions.py @@ -125,6 +125,13 @@ def _evaluate_comp_value(node, solution): if name == "MultiplicativeExpression": return _eval_multiplicative(node, solution) + # IN / NOT IN — must be checked before the generic Builtin_ dispatch + if name == "Builtin_IN": + return _eval_in(node, solution) + + if name == "Builtin_NOTIN": + return not _eval_in(node, solution) + # SPARQL built-in functions if name.startswith("Builtin_"): return _eval_builtin(name, node, solution) @@ -133,19 +140,10 @@ def _evaluate_comp_value(node, solution): if name == "Function": return _eval_function(node, solution) - # Exists / NotExists — handled via _eval_builtin now - # TrueFilter (used with OPTIONAL) if name == "TrueFilter": return True - # IN / NOT IN - if name == "Builtin_IN": - return _eval_in(node, solution) - - if name == "Builtin_NOTIN": - return not _eval_in(node, solution) - logger.warning(f"Unknown CompValue expression: {name}") return None @@ -171,6 +169,22 @@ def _eval_relational(node, solution): ">=": operator.ge, } + if str(op) == "IN": + items = node.other if isinstance(node.other, list) else [node.other] + for item in items: + other_val = evaluate_expression(item, solution) + if _comparable_value(left) == _comparable_value(other_val): + return True + return False + + if str(op) == "NOT IN": + items = node.other if isinstance(node.other, list) else [node.other] + for item in items: + other_val = evaluate_expression(item, solution) + if _comparable_value(left) == _comparable_value(other_val): + return False + return True + op_fn = ops.get(str(op)) if op_fn is None: logger.warning(f"Unknown relational operator: 
{op}") diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 75c00dba..bbe375f0 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -12,7 +12,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import TriplesClientSpec from . parser import parse_sparql, ParseError -from . algebra import evaluate, EvaluationError +from . algebra import evaluate, materialise, EvaluationError logger = logging.getLogger(__name__) @@ -66,11 +66,10 @@ class Processor(FlowProcessor): logger.debug(f"Handling SPARQL query request {id}...") - response = await self.execute_sparql(request, flow) - - if request.streaming and response.query_type == "select": - await self.send_streaming(response, flow, id, request) + if request.streaming: + await self.execute_sparql_streaming(request, flow, id) else: + response = await self.execute_sparql(request, flow) await flow("response").send( response, properties={"id": id} ) @@ -92,37 +91,77 @@ class Processor(FlowProcessor): await flow("response").send(r, properties={"id": id}) - async def send_streaming(self, response, flow, id, request): - """Send SELECT results in batches.""" + async def execute_sparql_streaming(self, request, flow, id): + """Execute a SPARQL query and stream results as they arrive.""" - bindings = response.bindings + try: + parsed = parse_sparql(request.query) + except ParseError as e: + await flow("response").send( + SparqlQueryResponse( + error=Error( + type="sparql-parse-error", + message=str(e), + ), + ), + properties={"id": id} + ) + return + + if parsed.query_type != "select": + response = await self._execute_non_select(parsed, request, flow) + await flow("response").send(response, properties={"id": id}) + return + + triples_client = flow("triples-request") + variables = parsed.variables batch_size = request.batch_size if request.batch_size > 0 else 20 + batch 
= [] - for i in range(0, len(bindings), batch_size): - batch = bindings[i:i + batch_size] - is_final = (i + batch_size >= len(bindings)) - r = SparqlQueryResponse( - query_type=response.query_type, - variables=response.variables, - bindings=batch, - is_final=is_final, - ) - await flow("response").send(r, properties={"id": id}) + try: + async for sol in evaluate( + parsed.algebra, + triples_client, + collection=request.collection or "default", + limit=request.limit or 10000, + ): + values = [sol.get(v) for v in variables] + batch.append(SparqlBinding(values=values)) - # Handle empty results - if len(bindings) == 0: - r = SparqlQueryResponse( - query_type=response.query_type, - variables=response.variables, - bindings=[], - is_final=True, + if len(batch) >= batch_size: + r = SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=batch, + is_final=False, + ) + await flow("response").send(r, properties={"id": id}) + batch = [] + + except EvaluationError as e: + await flow("response").send( + SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ), + properties={"id": id} ) - await flow("response").send(r, properties={"id": id}) + return + + # Final batch (may be empty for zero results) + r = SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=batch, + is_final=True, + ) + await flow("response").send(r, properties={"id": id}) async def execute_sparql(self, request, flow): - """Parse and evaluate a SPARQL query.""" + """Parse and evaluate a SPARQL query (non-streaming).""" - # Parse the SPARQL query try: parsed = parse_sparql(request.query) except ParseError as e: @@ -133,12 +172,31 @@ class Processor(FlowProcessor): ), ) - # Get the triples client from the flow - triples_client = flow("triples-request") + if parsed.query_type == "select": + triples_client = flow("triples-request") + try: + solutions = await materialise( + parsed.algebra, + triples_client, + 
collection=request.collection or "default", + limit=request.limit or 10000, + ) + except EvaluationError as e: + return SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ) + return self._build_select_response(parsed, solutions) - # Evaluate the algebra + return await self._execute_non_select(parsed, request, flow) + + async def _execute_non_select(self, parsed, request, flow): + """Execute ASK, CONSTRUCT, or DESCRIBE queries.""" + triples_client = flow("triples-request") try: - solutions = await evaluate( + solutions = await materialise( parsed.algebra, triples_client, collection=request.collection or "default", @@ -152,10 +210,7 @@ class Processor(FlowProcessor): ), ) - # Build response based on query type - if parsed.query_type == "select": - return self._build_select_response(parsed, solutions) - elif parsed.query_type == "ask": + if parsed.query_type == "ask": return self._build_ask_response(solutions) elif parsed.query_type == "construct": return self._build_construct_response(parsed, solutions) From c10f2694a0adf4bd8aadd5bc33e3cb364a5d2e85 Mon Sep 17 00:00:00 2001 From: Jacob Molz Date: Tue, 26 May 2026 07:43:58 -0400 Subject: [PATCH 5/7] fix: safely parse metric labels (#948) --- .../test_query/test_ontology_monitoring.py | 73 +++++++++++++++++++ .../trustgraph/query/ontology/monitoring.py | 25 ++++++- 2 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_query/test_ontology_monitoring.py diff --git a/tests/unit/test_query/test_ontology_monitoring.py b/tests/unit/test_query/test_ontology_monitoring.py new file mode 100644 index 00000000..ef69965c --- /dev/null +++ b/tests/unit/test_query/test_ontology_monitoring.py @@ -0,0 +1,73 @@ +""" +Tests for ontology monitoring metrics. 
+""" + +import importlib.util +import sys +from pathlib import Path + + +MODULE_PATH = ( + Path(__file__).resolve().parents[3] + / "trustgraph-flow" + / "trustgraph" + / "query" + / "ontology" + / "monitoring.py" +) +spec = importlib.util.spec_from_file_location("ontology_monitoring", MODULE_PATH) +assert spec is not None and spec.loader is not None +monitoring = importlib.util.module_from_spec(spec) +sys.modules[spec.name] = monitoring +spec.loader.exec_module(monitoring) + +PerformanceMonitor = monitoring.PerformanceMonitor +_extract_metric_label = monitoring._extract_metric_label + + +def test_extract_metric_label_reads_unquoted_label_value(): + metric_name = "cache_requests_total{cache_type=entity,component=ontology}" + + assert _extract_metric_label(metric_name, "cache_type") == "entity" + + +def test_extract_metric_label_reads_quoted_label_value(): + metric_name = 'cache_requests_total{cache_type="entity",component="ontology"}' + + assert _extract_metric_label(metric_name, "cache_type") == "entity" + + +def test_extract_metric_label_returns_none_when_label_missing(): + metric_name = "cache_requests_total{component=ontology}" + + assert _extract_metric_label(metric_name, "cache_type") is None + + +def test_performance_report_ignores_counters_without_cache_type_label(): + monitor = PerformanceMonitor({"enabled": False}) + monitor.metrics_collector.increment( + "cache_requests_total", + labels={"component": "ontology"}, + ) + monitor.metrics_collector.increment( + "cache_type=not_a_label", + labels={"component": "ontology"}, + ) + monitor.metrics_collector.increment( + "cache_requests_total", + labels={"cache_type": "entity"}, + ) + monitor.metrics_collector.increment( + "cache_hits_total", + labels={"cache_type": "entity"}, + ) + + report = monitor.get_performance_report() + + assert report["cache_performance"] == { + "entity": { + "hit_rate": 1.0, + "total_requests": 1.0, + "total_hits": 1.0, + } + } diff --git 
a/trustgraph-flow/trustgraph/query/ontology/monitoring.py b/trustgraph-flow/trustgraph/query/ontology/monitoring.py index 703c6e95..cb7e8a2e 100644 --- a/trustgraph-flow/trustgraph/query/ontology/monitoring.py +++ b/trustgraph-flow/trustgraph/query/ontology/monitoring.py @@ -4,6 +4,7 @@ Provides comprehensive monitoring of system performance, query patterns, and res """ import logging +import re import time import asyncio import inspect @@ -276,6 +277,26 @@ class MetricsCollector: return f"{name}{{{label_str}}}" +def _extract_metric_label(metric_name: str, label: str) -> Optional[str]: + """Extract a label value from an internal metric key.""" + labels_start = metric_name.find('{') + labels_end = metric_name.find('}', labels_start + 1) + + if labels_start == -1 or labels_end == -1: + return None + + labels = metric_name[labels_start + 1:labels_end] + label_match = re.search( + rf'(?:^|,){re.escape(label)}=(?:"([^"]*)"|([^,]*))', + labels, + ) + if not label_match: + return None + + quoted_value, unquoted_value = label_match.groups() + return quoted_value if quoted_value is not None else unquoted_value + + class PerformanceMonitor: """Monitors system performance and component health.""" @@ -474,8 +495,8 @@ class PerformanceMonitor: # Cache performance cache_types = set() for metric_name in self.metrics_collector.counters.keys(): - if 'cache_type=' in metric_name: - cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0] + cache_type = _extract_metric_label(metric_name, 'cache_type') + if cache_type is not None: cache_types.add(cache_type) for cache_type in cache_types: From 6d07310d2bd3c4ca225634865f4b8a2aaa271454 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 26 May 2026 13:12:03 +0100 Subject: [PATCH 6/7] fix: repair broken imports in OntoRAG query module (#950) Replace hallucinated relative imports with correct absolute imports across the ontology query package, and fix OntologyMatcher reference to match the actual class name 
OntologyMatcherForQueries. Simplify test to use standard imports instead of importlib hack. Cosmetic, but simpler imports avoid non-deterministic imports in a dev environment, and also mean we're properly testing linkage --- .../test_query/test_ontology_monitoring.py | 23 +++---------------- .../trustgraph/query/ontology/__init__.py | 4 ++-- .../query/ontology/ontology_matcher.py | 8 +++---- .../query/ontology/query_service.py | 12 +++++----- .../query/ontology/sparql_cassandra.py | 2 +- 5 files changed, 16 insertions(+), 33 deletions(-) diff --git a/tests/unit/test_query/test_ontology_monitoring.py b/tests/unit/test_query/test_ontology_monitoring.py index ef69965c..4b1b4253 100644 --- a/tests/unit/test_query/test_ontology_monitoring.py +++ b/tests/unit/test_query/test_ontology_monitoring.py @@ -2,27 +2,10 @@ Tests for ontology monitoring metrics. """ -import importlib.util -import sys -from pathlib import Path - - -MODULE_PATH = ( - Path(__file__).resolve().parents[3] - / "trustgraph-flow" - / "trustgraph" - / "query" - / "ontology" - / "monitoring.py" +from trustgraph.query.ontology.monitoring import ( + PerformanceMonitor, + _extract_metric_label, ) -spec = importlib.util.spec_from_file_location("ontology_monitoring", MODULE_PATH) -assert spec is not None and spec.loader is not None -monitoring = importlib.util.module_from_spec(spec) -sys.modules[spec.name] = monitoring -spec.loader.exec_module(monitoring) - -PerformanceMonitor = monitoring.PerformanceMonitor -_extract_metric_label = monitoring._extract_metric_label def test_extract_metric_label_reads_unquoted_label_value(): diff --git a/trustgraph-flow/trustgraph/query/ontology/__init__.py b/trustgraph-flow/trustgraph/query/ontology/__init__.py index 60557ea9..c5cddd9c 100644 --- a/trustgraph-flow/trustgraph/query/ontology/__init__.py +++ b/trustgraph-flow/trustgraph/query/ontology/__init__.py @@ -7,7 +7,7 @@ Provides semantic query understanding, ontology matching, and answer generation.
from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType -from .ontology_matcher import OntologyMatcher, QueryOntologySubset +from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset from .backend_router import BackendRouter, BackendType, QueryRoute from .sparql_generator import SPARQLGenerator, SPARQLQuery from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult @@ -27,7 +27,7 @@ __all__ = [ 'QuestionType', # Ontology matching - 'OntologyMatcher', + 'OntologyMatcherForQueries', 'QueryOntologySubset', # Backend routing diff --git a/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py b/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py index 895856f3..2dd6633a 100644 --- a/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py +++ b/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py @@ -7,10 +7,10 @@ import logging from typing import List, Dict, Any, Set, Optional from dataclasses import dataclass -from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader -from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder -from ...extract.kg.ontology.text_processor import TextSegment -from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset +from trustgraph.extract.kg.ontology.ontology_loader import Ontology, OntologyLoader +from trustgraph.extract.kg.ontology.ontology_embedder import OntologyEmbedder +from trustgraph.extract.kg.ontology.text_processor import TextSegment +from trustgraph.extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset from .question_analyzer import QuestionComponents, QuestionType logger = logging.getLogger(__name__) diff --git a/trustgraph-flow/trustgraph/query/ontology/query_service.py b/trustgraph-flow/trustgraph/query/ontology/query_service.py index c6057cc1..77e60b50 100644 --- 
a/trustgraph-flow/trustgraph/query/ontology/query_service.py +++ b/trustgraph-flow/trustgraph/query/ontology/query_service.py @@ -8,13 +8,13 @@ from typing import Dict, Any, List, Optional, Union from dataclasses import dataclass from datetime import datetime -from ....flow.flow_processor import FlowProcessor -from ....tables.config import ConfigTableStore -from ...extract.kg.ontology.ontology_loader import OntologyLoader -from ...extract.kg.ontology.vector_store import InMemoryVectorStore +from trustgraph.base.flow_processor import FlowProcessor +from trustgraph.tables.config import ConfigTableStore +from trustgraph.extract.kg.ontology.ontology_loader import OntologyLoader +from trustgraph.extract.kg.ontology.vector_store import InMemoryVectorStore from .question_analyzer import QuestionAnalyzer, QuestionComponents -from .ontology_matcher import OntologyMatcher, QueryOntologySubset +from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset from .backend_router import BackendRouter, QueryRoute, BackendType from .sparql_generator import SPARQLGenerator, SPARQLQuery from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult @@ -105,7 +105,7 @@ class OntoRAGQueryService(FlowProcessor): # Initialize ontology matcher matcher_config = self.config.get('ontology_matcher', {}) - self.ontology_matcher = OntologyMatcher( + self.ontology_matcher = OntologyMatcherForQueries( vector_store=self.vector_store, embedding_service=self.embedding_service, config=matcher_config diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py index 688e7371..b7f0f423 100644 --- a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py +++ b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py @@ -28,7 +28,7 @@ try: except ImportError: CASSANDRA_AVAILABLE = False -from ....tables.config import ConfigTableStore +from trustgraph.tables.config import ConfigTableStore logger = 
logging.getLogger(__name__) From 4200b5d683c2efc6a3d501216bd21ce5092e293d Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Tue, 26 May 2026 14:35:54 +0100 Subject: [PATCH 7/7] fix: update library_client for workspace-based tenancy (#951) Replace removed `user` parameter with `workspace` support following the tenancy axis change in #840. Adds -w/--workspace flag and $TRUSTGRAPH_WORKSPACE env var. --- dev-tools/library_client.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/dev-tools/library_client.py b/dev-tools/library_client.py index ae9d6857..30e0c344 100644 --- a/dev-tools/library_client.py +++ b/dev-tools/library_client.py @@ -25,7 +25,7 @@ BUCKET_URL = "https://storage.googleapis.com/trustgraph-library" INDEX_URL = f"{BUCKET_URL}/index.json" default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @@ -113,7 +113,7 @@ def convert_metadata(metadata_json): return triples -def load_document(api, user, doc_entry): +def load_document(api, doc_entry): """Fetch metadata and content for a document, then load into TrustGraph.""" doc_id = doc_entry["id"] title = doc_entry["title"] @@ -133,7 +133,6 @@ def load_document(api, user, doc_entry): api.add_document( id=doc["id"], metadata=metadata, - user=user, kind=doc["kind"], title=doc["title"], comments=doc["comments"], @@ -144,12 +143,12 @@ def load_document(api, user, doc_entry): print(f" done.") -def load_documents(api, user, docs): +def load_documents(api, docs): """Load a list of documents.""" print(f"Loading {len(docs)} document(s)...\n") for doc in docs: try: - load_document(api, user, doc) + load_document(api, doc) except Exception as e: print(f" FAILED: {e}", file=sys.stderr) print() @@ -166,8 +165,8 @@ def main(): help=f"TrustGraph API URL (default: {default_url})", ) parser.add_argument( - "-U", "--user", 
default=default_user, - help=f"User ID (default: {default_user})", + "-w", "--workspace", default=default_workspace, + help=f"Workspace (default: {default_workspace})", ) parser.add_argument( "-t", "--token", default=default_token, @@ -212,22 +211,22 @@ def main(): return # Load commands need the API - api = Api(args.url, token=args.token).library() + api = Api(args.url, token=args.token, workspace=args.workspace).library() if args.command == "load-all": - load_documents(api, args.user, index) + load_documents(api, index) elif args.command == "load-doc": matches = [d for d in index if str(d.get("id")) == args.id] if not matches: print(f"No document with ID '{args.id}' found.", file=sys.stderr) sys.exit(1) - load_documents(api, args.user, matches) + load_documents(api, matches) elif args.command == "load-match": results = search_index(index, args.query) if results: - load_documents(api, args.user, results) + load_documents(api, results) else: print("No matches found.", file=sys.stderr) sys.exit(1)