From 8d8fb5766d39651f40157b0eae181f1a5466522a Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Tue, 2 Jun 2026 14:18:06 +0100 Subject: [PATCH] fix: route workspace through bulk WebSocket clients and merge query params Bulk clients (sync and async) were not forwarding the workspace parameter, causing all bulk operations to hit the default workspace regardless of the Api instance's workspace setting. Also fixes the gateway socket endpoint to pass query parameters (including workspace) to the dispatcher, and prevents the auth handshake from overwriting an explicitly set workspace. Updates knowledge table store tests for paged query interface. --- .../test_tables/test_knowledge_table_store.py | 26 +++++----- trustgraph-base/trustgraph/api/api.py | 4 +- .../trustgraph/api/async_bulk_client.py | 51 ++++++++---------- trustgraph-base/trustgraph/api/bulk_client.py | 52 +++++++++---------- .../trustgraph/api/socket_client.py | 3 +- .../trustgraph/gateway/endpoint/socket.py | 4 +- 6 files changed, 67 insertions(+), 73 deletions(-) diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py index 59d15b45..9a0b55c4 100644 --- a/tests/unit/test_tables/test_knowledge_table_store.py +++ b/tests/unit/test_tables/test_knowledge_table_store.py @@ -35,9 +35,9 @@ def _make_store(): class TestGetGraphEmbeddings: @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) async def test_row_converts_to_entity_embeddings_with_singular_vector( - self, mock_async_execute + self, mock_async_execute_paged ): """ Cassandra rows return entities as a list of [entity_tuple, vector] @@ -57,7 +57,7 @@ class TestGetGraphEmbeddings: store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] @@ -66,7 +66,7 @@ class TestGetGraphEmbeddings: await store.get_graph_embeddings("alice", "doc-1", receiver) - mock_async_execute.assert_called_once_with( + mock_async_execute_paged.assert_called_once_with( store.cassandra, store.get_graph_embeddings_stmt, ("alice", "doc-1"), @@ -96,8 +96,8 @@ class TestGetGraphEmbeddings: assert ge.entities[2].entity.value == "a literal entity" @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute): + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged): """row[3] being None / empty must produce a GraphEmbeddings with no entities, not raise.""" fake_row = (None, None, None, None) @@ -105,7 +105,7 @@ class TestGetGraphEmbeddings: store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] @@ -118,8 +118,8 @@ class TestGetGraphEmbeddings: assert received[0].entities == [] @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_multiple_rows_each_emit_one_message(self, mock_async_execute): + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged): fake_rows = [ (None, None, None, [ (("http://example.org/a", True), [1.0]), @@ -132,7 +132,7 @@ class TestGetGraphEmbeddings: store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = fake_rows + mock_async_execute_paged.return_value = [fake_rows] received = [] @@ -153,8 +153,8 @@ class TestGetTriples: the same Metadata construction. Cover it for parity.""" @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_row_converts_to_triples(self, mock_async_execute): + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_row_converts_to_triples(self, mock_async_execute_paged): # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri) fake_row = ( None, None, None, @@ -170,7 +170,7 @@ class TestGetTriples: store = _make_store() store.cassandra = Mock() store.get_triples_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index 9074bac1..0190d3f5 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -337,7 +337,7 @@ class Api: from . bulk_client import BulkClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._bulk_client = BulkClient(base_url, self.timeout, self.token) + self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace) return self._bulk_client def metrics(self): @@ -462,7 +462,7 @@ class Api: from . async_bulk_client import AsyncBulkClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token) + self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace) return self._async_bulk_client def async_metrics(self): diff --git a/trustgraph-base/trustgraph/api/async_bulk_client.py b/trustgraph-base/trustgraph/api/async_bulk_client.py index 9a6a49c3..f93ab667 100644 --- a/trustgraph-base/trustgraph/api/async_bulk_client.py +++ b/trustgraph-base/trustgraph/api/async_bulk_client.py @@ -9,10 +9,11 @@ from . types import Triple class AsyncBulkClient: """Asynchronous bulk operations client""" - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None: self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace def _convert_to_ws_url(self, url: str) -> str: """Convert HTTP URL to WebSocket URL""" @@ -25,11 +26,21 @@ class AsyncBulkClient: else: return f"ws://{url}" + def _build_ws_url(self, path: str) -> str: + """Build a WebSocket URL with token and workspace query params.""" + ws_url = f"{self.url}{path}" + params = [] + if self.token: + params.append(f"token={self.token}") + if self.workspace: + params.append(f"workspace={self.workspace}") + if params: + ws_url = f"{ws_url}?{'&'.join(params)}" + return ws_url + async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None: """Bulk import triples via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for triple in triples: @@ -42,9 +53,7 @@ class AsyncBulkClient: async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]: """Bulk export triples via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -57,9 +66,7 @@ class AsyncBulkClient: async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import graph embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for embedding in embeddings: @@ -67,9 +74,7 @@ class AsyncBulkClient: async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export graph embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -77,9 +82,7 @@ class AsyncBulkClient: async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import document embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for embedding in embeddings: @@ -87,9 +90,7 @@ class AsyncBulkClient: async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export document embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -97,9 +98,7 @@ class AsyncBulkClient: async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import entity contexts via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for context in contexts: @@ -107,9 +106,7 @@ class AsyncBulkClient: async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export entity contexts via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -117,9 +114,7 @@ class AsyncBulkClient: async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import rows via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for row in rows: diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py index 0e49fc4e..ae185240 100644 --- a/trustgraph-base/trustgraph/api/bulk_client.py +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -34,7 +34,7 @@ class BulkClient: Note: For true async support, use AsyncBulkClient instead. """ - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None: """ Initialize synchronous bulk client. @@ -42,10 +42,12 @@ class BulkClient: url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS) timeout: WebSocket timeout in seconds token: Optional bearer token for authentication + workspace: Workspace for data isolation """ self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace def _convert_to_ws_url(self, url: str) -> str: """Convert HTTP URL to WebSocket URL""" @@ -58,6 +60,18 @@ class BulkClient: else: return f"ws://{url}" + def _build_ws_url(self, path: str) -> str: + """Build a WebSocket URL with token and workspace query params.""" + ws_url = f"{self.url}{path}" + params = [] + if self.token: + params.append(f"token={self.token}") + if self.workspace: + params.append(f"workspace={self.workspace}") + if params: + ws_url = f"{ws_url}?{'&'.join(params)}" + return ws_url + def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any: """Run async coroutine synchronously""" try: @@ -116,9 +130,7 @@ class BulkClient: metadata: Optional[Dict[str, Any]], batch_size: int ) -> None: """Async implementation of triple import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples") if metadata is None: metadata = {"id": "", "metadata": [], "collection": "default"} @@ -194,9 +206,7 @@ class BulkClient: async def _export_triples_async(self, flow: str) -> Iterator[Triple]: """Async implementation of triple export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -238,9 +248,7 @@ class BulkClient: async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None: """Async implementation of graph embeddings import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for embedding in embeddings: @@ -296,9 +304,7 @@ class BulkClient: async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of graph embeddings export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -336,9 +342,7 @@ class BulkClient: async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None: """Async implementation of document embeddings import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for embedding in embeddings: @@ -394,9 +398,7 @@ class BulkClient: async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of document embeddings export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -446,9 +448,7 @@ class BulkClient: metadata: Optional[Dict[str, Any]], batch_size: int ) -> None: """Async implementation of entity contexts import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts") if metadata is None: metadata = {"id": "", "metadata": [], "collection": "default"} @@ -522,9 +522,7 @@ class BulkClient: async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of entity contexts export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -562,9 +560,7 @@ class BulkClient: async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None: """Async implementation of rows import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for row in rows: diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 6eeb95ff..b88d0c78 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -167,7 +167,8 @@ class SocketClient: ) if resp.get("type") == "auth-ok": - self.workspace = resp.get("workspace", self.workspace) + if self.workspace == "default": + self.workspace = resp.get("workspace", self.workspace) elif resp.get("type") == "auth-failed": await self._socket.close() raise ProtocolException( diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py index f53ad73b..af6183db 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py @@ -117,8 +117,10 @@ class SocketEndpoint: running = Running() + params = dict(request.query) + params.update(request.match_info) dispatcher = await self.dispatcher( - ws, running, request.match_info + ws, running, params ) worker_task = tg.create_task(