fix: route workspace through bulk WebSocket clients and merge query params

Bulk clients (sync and async) were not forwarding the workspace parameter,
causing all bulk operations to hit the default workspace regardless of the
Api instance's workspace setting. Also fixes the gateway socket endpoint to
pass query parameters (including workspace) to the dispatcher, and prevents
the auth handshake from overwriting an explicitly set workspace.

Also updates the knowledge table store tests to patch the paged query helper (`async_execute_paged`) instead of `async_execute`.
This commit is contained in:
Cyber MacGeddon 2026-06-02 14:18:06 +01:00
parent 6b1dd16f9f
commit 8d8fb5766d
6 changed files with 67 additions and 73 deletions

View file

@ -35,9 +35,9 @@ def _make_store():
class TestGetGraphEmbeddings:
@pytest.mark.asyncio
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_row_converts_to_entity_embeddings_with_singular_vector(
self, mock_async_execute
self, mock_async_execute_paged
):
"""
Cassandra rows return entities as a list of [entity_tuple, vector]
@ -57,7 +57,7 @@ class TestGetGraphEmbeddings:
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
mock_async_execute.return_value = [fake_row]
mock_async_execute_paged.return_value = [[fake_row]]
received = []
@ -66,7 +66,7 @@ class TestGetGraphEmbeddings:
await store.get_graph_embeddings("alice", "doc-1", receiver)
mock_async_execute.assert_called_once_with(
mock_async_execute_paged.assert_called_once_with(
store.cassandra,
store.get_graph_embeddings_stmt,
("alice", "doc-1"),
@ -96,8 +96,8 @@ class TestGetGraphEmbeddings:
assert ge.entities[2].entity.value == "a literal entity"
@pytest.mark.asyncio
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute):
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged):
"""row[3] being None / empty must produce a GraphEmbeddings with
no entities, not raise."""
fake_row = (None, None, None, None)
@ -105,7 +105,7 @@ class TestGetGraphEmbeddings:
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
mock_async_execute.return_value = [fake_row]
mock_async_execute_paged.return_value = [[fake_row]]
received = []
@ -118,8 +118,8 @@ class TestGetGraphEmbeddings:
assert received[0].entities == []
@pytest.mark.asyncio
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
async def test_multiple_rows_each_emit_one_message(self, mock_async_execute):
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged):
fake_rows = [
(None, None, None, [
(("http://example.org/a", True), [1.0]),
@ -132,7 +132,7 @@ class TestGetGraphEmbeddings:
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
mock_async_execute.return_value = fake_rows
mock_async_execute_paged.return_value = [fake_rows]
received = []
@ -153,8 +153,8 @@ class TestGetTriples:
the same Metadata construction. Cover it for parity."""
@pytest.mark.asyncio
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
async def test_row_converts_to_triples(self, mock_async_execute):
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_row_converts_to_triples(self, mock_async_execute_paged):
# row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
fake_row = (
None, None, None,
@ -170,7 +170,7 @@ class TestGetTriples:
store = _make_store()
store.cassandra = Mock()
store.get_triples_stmt = Mock()
mock_async_execute.return_value = [fake_row]
mock_async_execute_paged.return_value = [[fake_row]]
received = []

View file

@ -337,7 +337,7 @@ class Api:
from . bulk_client import BulkClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
self._bulk_client = BulkClient(base_url, self.timeout, self.token)
self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
return self._bulk_client
def metrics(self):
@ -462,7 +462,7 @@ class Api:
from . async_bulk_client import AsyncBulkClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
return self._async_bulk_client
def async_metrics(self):

View file

@ -9,10 +9,11 @@ from . types import Triple
class AsyncBulkClient:
"""Asynchronous bulk operations client"""
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
self.workspace: str = workspace
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
@ -25,11 +26,21 @@ class AsyncBulkClient:
else:
return f"ws://{url}"
def _build_ws_url(self, path: str) -> str:
"""Build a WebSocket URL with token and workspace query params."""
ws_url = f"{self.url}{path}"
params = []
if self.token:
params.append(f"token={self.token}")
if self.workspace:
params.append(f"workspace={self.workspace}")
if params:
ws_url = f"{ws_url}?{'&'.join(params)}"
return ws_url
async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
"""Bulk import triples via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for triple in triples:
@ -42,9 +53,7 @@ class AsyncBulkClient:
async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
"""Bulk export triples via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -57,9 +66,7 @@ class AsyncBulkClient:
async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import graph embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for embedding in embeddings:
@ -67,9 +74,7 @@ class AsyncBulkClient:
async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export graph embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -77,9 +82,7 @@ class AsyncBulkClient:
async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import document embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for embedding in embeddings:
@ -87,9 +90,7 @@ class AsyncBulkClient:
async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export document embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -97,9 +98,7 @@ class AsyncBulkClient:
async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import entity contexts via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for context in contexts:
@ -107,9 +106,7 @@ class AsyncBulkClient:
async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export entity contexts via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -117,9 +114,7 @@ class AsyncBulkClient:
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import rows via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for row in rows:

View file

@ -34,7 +34,7 @@ class BulkClient:
Note: For true async support, use AsyncBulkClient instead.
"""
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
"""
Initialize synchronous bulk client.
@ -42,10 +42,12 @@ class BulkClient:
url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS)
timeout: WebSocket timeout in seconds
token: Optional bearer token for authentication
workspace: Workspace for data isolation
"""
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
self.workspace: str = workspace
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
@ -58,6 +60,18 @@ class BulkClient:
else:
return f"ws://{url}"
def _build_ws_url(self, path: str) -> str:
"""Build a WebSocket URL with token and workspace query params."""
ws_url = f"{self.url}{path}"
params = []
if self.token:
params.append(f"token={self.token}")
if self.workspace:
params.append(f"workspace={self.workspace}")
if params:
ws_url = f"{ws_url}?{'&'.join(params)}"
return ws_url
def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
"""Run async coroutine synchronously"""
try:
@ -116,9 +130,7 @@ class BulkClient:
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of triple import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
if metadata is None:
metadata = {"id": "", "metadata": [], "collection": "default"}
@ -194,9 +206,7 @@ class BulkClient:
async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
"""Async implementation of triple export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -238,9 +248,7 @@ class BulkClient:
async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of graph embeddings import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
@ -296,9 +304,7 @@ class BulkClient:
async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of graph embeddings export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -336,9 +342,7 @@ class BulkClient:
async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of document embeddings import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
@ -394,9 +398,7 @@ class BulkClient:
async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of document embeddings export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -446,9 +448,7 @@ class BulkClient:
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of entity contexts import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
if metadata is None:
metadata = {"id": "", "metadata": [], "collection": "default"}
@ -522,9 +522,7 @@ class BulkClient:
async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of entity contexts export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -562,9 +560,7 @@ class BulkClient:
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of rows import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for row in rows:

View file

@ -167,7 +167,8 @@ class SocketClient:
)
if resp.get("type") == "auth-ok":
self.workspace = resp.get("workspace", self.workspace)
if self.workspace == "default":
self.workspace = resp.get("workspace", self.workspace)
elif resp.get("type") == "auth-failed":
await self._socket.close()
raise ProtocolException(

View file

@ -117,8 +117,10 @@ class SocketEndpoint:
running = Running()
params = dict(request.query)
params.update(request.match_info)
dispatcher = await self.dispatcher(
ws, running, request.match_info
ws, running, params
)
worker_task = tg.create_task(