diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py index 59d15b45..9a0b55c4 100644 --- a/tests/unit/test_tables/test_knowledge_table_store.py +++ b/tests/unit/test_tables/test_knowledge_table_store.py @@ -35,9 +35,9 @@ def _make_store(): class TestGetGraphEmbeddings: @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) async def test_row_converts_to_entity_embeddings_with_singular_vector( - self, mock_async_execute + self, mock_async_execute_paged ): """ Cassandra rows return entities as a list of [entity_tuple, vector] @@ -57,7 +57,7 @@ class TestGetGraphEmbeddings: store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] @@ -66,7 +66,7 @@ class TestGetGraphEmbeddings: await store.get_graph_embeddings("alice", "doc-1", receiver) - mock_async_execute.assert_called_once_with( + mock_async_execute_paged.assert_called_once_with( store.cassandra, store.get_graph_embeddings_stmt, ("alice", "doc-1"), @@ -96,8 +96,8 @@ class TestGetGraphEmbeddings: assert ge.entities[2].entity.value == "a literal entity" @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute): + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged): """row[3] being None / empty must produce a GraphEmbeddings with no entities, not raise.""" fake_row = (None, None, None, None) @@ -105,7 +105,7 @@ class TestGetGraphEmbeddings: store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] @@ -118,8 +118,8 @@ class TestGetGraphEmbeddings: assert received[0].entities == [] @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_multiple_rows_each_emit_one_message(self, mock_async_execute): + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged): fake_rows = [ (None, None, None, [ (("http://example.org/a", True), [1.0]), @@ -132,7 +132,7 @@ class TestGetGraphEmbeddings: store = _make_store() store.cassandra = Mock() store.get_graph_embeddings_stmt = Mock() - mock_async_execute.return_value = fake_rows + mock_async_execute_paged.return_value = [fake_rows] received = [] @@ -153,8 +153,8 @@ class TestGetTriples: the same Metadata construction. Cover it for parity.""" @pytest.mark.asyncio - @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock) - async def test_row_converts_to_triples(self, mock_async_execute): + @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock) + async def test_row_converts_to_triples(self, mock_async_execute_paged): # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri) fake_row = ( None, None, None, @@ -170,7 +170,7 @@ class TestGetTriples: store = _make_store() store.cassandra = Mock() store.get_triples_stmt = Mock() - mock_async_execute.return_value = [fake_row] + mock_async_execute_paged.return_value = [[fake_row]] received = [] diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index 9074bac1..0190d3f5 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -337,7 +337,7 @@ class Api: from . bulk_client import BulkClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._bulk_client = BulkClient(base_url, self.timeout, self.token) + self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace) return self._bulk_client def metrics(self): @@ -462,7 +462,7 @@ class Api: from . async_bulk_client import AsyncBulkClient # Extract base URL (remove api/v1/ suffix) base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/") - self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token) + self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace) return self._async_bulk_client def async_metrics(self): diff --git a/trustgraph-base/trustgraph/api/async_bulk_client.py b/trustgraph-base/trustgraph/api/async_bulk_client.py index 9a6a49c3..f93ab667 100644 --- a/trustgraph-base/trustgraph/api/async_bulk_client.py +++ b/trustgraph-base/trustgraph/api/async_bulk_client.py @@ -9,10 +9,11 @@ from . types import Triple class AsyncBulkClient: """Asynchronous bulk operations client""" - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None: self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace def _convert_to_ws_url(self, url: str) -> str: """Convert HTTP URL to WebSocket URL""" @@ -25,11 +26,21 @@ class AsyncBulkClient: else: return f"ws://{url}" + def _build_ws_url(self, path: str) -> str: + """Build a WebSocket URL with token and workspace query params.""" + ws_url = f"{self.url}{path}" + params = [] + if self.token: + params.append(f"token={self.token}") + if self.workspace: + params.append(f"workspace={self.workspace}") + if params: + ws_url = f"{ws_url}?{'&'.join(params)}" + return ws_url + async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None: """Bulk import triples via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for triple in triples: @@ -42,9 +53,7 @@ class AsyncBulkClient: async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]: """Bulk export triples via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -57,9 +66,7 @@ class AsyncBulkClient: async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import graph embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for embedding in embeddings: @@ -67,9 +74,7 @@ class AsyncBulkClient: async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export graph embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -77,9 +82,7 @@ class AsyncBulkClient: async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import document embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for embedding in embeddings: @@ -87,9 +90,7 @@ class AsyncBulkClient: async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export document embeddings via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -97,9 +98,7 @@ class AsyncBulkClient: async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import entity contexts via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for context in contexts: @@ -107,9 +106,7 @@ class AsyncBulkClient: async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]: """Bulk export entity contexts via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -117,9 +114,7 @@ class AsyncBulkClient: async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None: """Bulk import rows via WebSocket""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for row in rows: diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py index 0e49fc4e..ae185240 100644 --- a/trustgraph-base/trustgraph/api/bulk_client.py +++ b/trustgraph-base/trustgraph/api/bulk_client.py @@ -34,7 +34,7 @@ class BulkClient: Note: For true async support, use AsyncBulkClient instead. """ - def __init__(self, url: str, timeout: int, token: Optional[str]) -> None: + def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None: """ Initialize synchronous bulk client. @@ -42,10 +42,12 @@ class BulkClient: url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS) timeout: WebSocket timeout in seconds token: Optional bearer token for authentication + workspace: Workspace for data isolation """ self.url: str = self._convert_to_ws_url(url) self.timeout: int = timeout self.token: Optional[str] = token + self.workspace: str = workspace def _convert_to_ws_url(self, url: str) -> str: """Convert HTTP URL to WebSocket URL""" @@ -58,6 +60,18 @@ class BulkClient: else: return f"ws://{url}" + def _build_ws_url(self, path: str) -> str: + """Build a WebSocket URL with token and workspace query params.""" + ws_url = f"{self.url}{path}" + params = [] + if self.token: + params.append(f"token={self.token}") + if self.workspace: + params.append(f"workspace={self.workspace}") + if params: + ws_url = f"{ws_url}?{'&'.join(params)}" + return ws_url + def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any: """Run async coroutine synchronously""" try: @@ -116,9 +130,7 @@ class BulkClient: metadata: Optional[Dict[str, Any]], batch_size: int ) -> None: """Async implementation of triple import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples") if metadata is None: metadata = {"id": "", "metadata": [], "collection": "default"} @@ -194,9 +206,7 @@ class BulkClient: async def _export_triples_async(self, flow: str) -> Iterator[Triple]: """Async implementation of triple export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -238,9 +248,7 @@ class BulkClient: async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None: """Async implementation of graph embeddings import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for embedding in embeddings: @@ -296,9 +304,7 @@ class BulkClient: async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of graph embeddings export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -336,9 +342,7 @@ class BulkClient: async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None: """Async implementation of document embeddings import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for embedding in embeddings: @@ -394,9 +398,7 @@ class BulkClient: async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of document embeddings export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -446,9 +448,7 @@ class BulkClient: metadata: Optional[Dict[str, Any]], batch_size: int ) -> None: """Async implementation of entity contexts import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts") if metadata is None: metadata = {"id": "", "metadata": [], "collection": "default"} @@ -522,9 +522,7 @@ class BulkClient: async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]: """Async implementation of entity contexts export""" - ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: async for raw_message in websocket: @@ -562,9 +560,7 @@ class BulkClient: async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None: """Async implementation of rows import""" - ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows" - if self.token: - ws_url = f"{ws_url}?token={self.token}" + ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows") async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: for row in rows: diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 6eeb95ff..b88d0c78 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -167,7 +167,8 @@ class SocketClient: ) if resp.get("type") == "auth-ok": - self.workspace = resp.get("workspace", self.workspace) + if self.workspace == "default": + self.workspace = resp.get("workspace", self.workspace) elif resp.get("type") == "auth-failed": await self._socket.close() raise ProtocolException( diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py index f53ad73b..af6183db 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py @@ -117,8 +117,10 @@ class SocketEndpoint: running = Running() + params = dict(request.query) + params.update(request.match_info) dispatcher = await self.dispatcher( - ws, running, request.match_info + ws, running, params ) worker_task = tg.create_task(