fix: route workspace through bulk WebSocket clients and merge query params (#970)

Bulk clients (sync and async) were not forwarding the workspace parameter,
causing all bulk operations to hit the default workspace regardless of the
Api instance's workspace setting. Also fixes the gateway socket endpoint to
pass query parameters (including workspace) to the dispatcher, and prevents
the auth handshake from overwriting an explicitly set workspace.

Updates knowledge table store tests for paged query interface.
This commit is contained in:
cybermaggedon 2026-06-02 14:19:15 +01:00 committed by GitHub
parent 6b1dd16f9f
commit 00bb964e93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 67 additions and 73 deletions

View file

@ -337,7 +337,7 @@ class Api:
from . bulk_client import BulkClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
self._bulk_client = BulkClient(base_url, self.timeout, self.token)
self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
return self._bulk_client
def metrics(self):
@ -462,7 +462,7 @@ class Api:
from . async_bulk_client import AsyncBulkClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
return self._async_bulk_client
def async_metrics(self):

View file

@ -9,10 +9,11 @@ from . types import Triple
class AsyncBulkClient:
"""Asynchronous bulk operations client"""
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
self.workspace: str = workspace
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
@ -25,11 +26,21 @@ class AsyncBulkClient:
else:
return f"ws://{url}"
def _build_ws_url(self, path: str) -> str:
"""Build a WebSocket URL with token and workspace query params."""
ws_url = f"{self.url}{path}"
params = []
if self.token:
params.append(f"token={self.token}")
if self.workspace:
params.append(f"workspace={self.workspace}")
if params:
ws_url = f"{ws_url}?{'&'.join(params)}"
return ws_url
async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
"""Bulk import triples via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for triple in triples:
@ -42,9 +53,7 @@ class AsyncBulkClient:
async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
"""Bulk export triples via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -57,9 +66,7 @@ class AsyncBulkClient:
async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import graph embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for embedding in embeddings:
@ -67,9 +74,7 @@ class AsyncBulkClient:
async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export graph embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -77,9 +82,7 @@ class AsyncBulkClient:
async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import document embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for embedding in embeddings:
@ -87,9 +90,7 @@ class AsyncBulkClient:
async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export document embeddings via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -97,9 +98,7 @@ class AsyncBulkClient:
async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import entity contexts via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for context in contexts:
@ -107,9 +106,7 @@ class AsyncBulkClient:
async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export entity contexts via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -117,9 +114,7 @@ class AsyncBulkClient:
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import rows via WebSocket"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for row in rows:

View file

@ -34,7 +34,7 @@ class BulkClient:
Note: For true async support, use AsyncBulkClient instead.
"""
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
"""
Initialize synchronous bulk client.
@ -42,10 +42,12 @@ class BulkClient:
url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS)
timeout: WebSocket timeout in seconds
token: Optional bearer token for authentication
workspace: Workspace for data isolation
"""
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
self.workspace: str = workspace
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
@ -58,6 +60,18 @@ class BulkClient:
else:
return f"ws://{url}"
def _build_ws_url(self, path: str) -> str:
"""Build a WebSocket URL with token and workspace query params."""
ws_url = f"{self.url}{path}"
params = []
if self.token:
params.append(f"token={self.token}")
if self.workspace:
params.append(f"workspace={self.workspace}")
if params:
ws_url = f"{ws_url}?{'&'.join(params)}"
return ws_url
def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
"""Run async coroutine synchronously"""
try:
@ -116,9 +130,7 @@ class BulkClient:
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of triple import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
if metadata is None:
metadata = {"id": "", "metadata": [], "collection": "default"}
@ -194,9 +206,7 @@ class BulkClient:
async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
"""Async implementation of triple export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -238,9 +248,7 @@ class BulkClient:
async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of graph embeddings import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
@ -296,9 +304,7 @@ class BulkClient:
async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of graph embeddings export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -336,9 +342,7 @@ class BulkClient:
async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of document embeddings import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
@ -394,9 +398,7 @@ class BulkClient:
async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of document embeddings export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -446,9 +448,7 @@ class BulkClient:
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of entity contexts import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
if metadata is None:
metadata = {"id": "", "metadata": [], "collection": "default"}
@ -522,9 +522,7 @@ class BulkClient:
async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of entity contexts export"""
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@ -562,9 +560,7 @@ class BulkClient:
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of rows import"""
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for row in rows:

View file

@ -167,7 +167,8 @@ class SocketClient:
)
if resp.get("type") == "auth-ok":
self.workspace = resp.get("workspace", self.workspace)
if self.workspace == "default":
self.workspace = resp.get("workspace", self.workspace)
elif resp.get("type") == "auth-failed":
await self._socket.close()
raise ProtocolException(