mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-10 15:25:14 +02:00
fix: route workspace through bulk WebSocket clients and merge query params
Bulk clients (sync and async) were not forwarding the workspace parameter, causing all bulk operations to hit the default workspace regardless of the Api instance's workspace setting. Also fixes the gateway socket endpoint to pass query parameters (including workspace) to the dispatcher, and prevents the auth handshake from overwriting an explicitly set workspace. Updates the knowledge table store tests for the paged query interface.
This commit is contained in:
parent
6b1dd16f9f
commit
8d8fb5766d
6 changed files with 67 additions and 73 deletions
|
|
@ -35,9 +35,9 @@ def _make_store():
|
|||
class TestGetGraphEmbeddings:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
||||
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||
async def test_row_converts_to_entity_embeddings_with_singular_vector(
|
||||
self, mock_async_execute
|
||||
self, mock_async_execute_paged
|
||||
):
|
||||
"""
|
||||
Cassandra rows return entities as a list of [entity_tuple, vector]
|
||||
|
|
@ -57,7 +57,7 @@ class TestGetGraphEmbeddings:
|
|||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.get_graph_embeddings_stmt = Mock()
|
||||
mock_async_execute.return_value = [fake_row]
|
||||
mock_async_execute_paged.return_value = [[fake_row]]
|
||||
|
||||
received = []
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ class TestGetGraphEmbeddings:
|
|||
|
||||
await store.get_graph_embeddings("alice", "doc-1", receiver)
|
||||
|
||||
mock_async_execute.assert_called_once_with(
|
||||
mock_async_execute_paged.assert_called_once_with(
|
||||
store.cassandra,
|
||||
store.get_graph_embeddings_stmt,
|
||||
("alice", "doc-1"),
|
||||
|
|
@ -96,8 +96,8 @@ class TestGetGraphEmbeddings:
|
|||
assert ge.entities[2].entity.value == "a literal entity"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
||||
async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute):
|
||||
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||
async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged):
|
||||
"""row[3] being None / empty must produce a GraphEmbeddings with
|
||||
no entities, not raise."""
|
||||
fake_row = (None, None, None, None)
|
||||
|
|
@ -105,7 +105,7 @@ class TestGetGraphEmbeddings:
|
|||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.get_graph_embeddings_stmt = Mock()
|
||||
mock_async_execute.return_value = [fake_row]
|
||||
mock_async_execute_paged.return_value = [[fake_row]]
|
||||
|
||||
received = []
|
||||
|
||||
|
|
@ -118,8 +118,8 @@ class TestGetGraphEmbeddings:
|
|||
assert received[0].entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
||||
async def test_multiple_rows_each_emit_one_message(self, mock_async_execute):
|
||||
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||
async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged):
|
||||
fake_rows = [
|
||||
(None, None, None, [
|
||||
(("http://example.org/a", True), [1.0]),
|
||||
|
|
@ -132,7 +132,7 @@ class TestGetGraphEmbeddings:
|
|||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.get_graph_embeddings_stmt = Mock()
|
||||
mock_async_execute.return_value = fake_rows
|
||||
mock_async_execute_paged.return_value = [fake_rows]
|
||||
|
||||
received = []
|
||||
|
||||
|
|
@ -153,8 +153,8 @@ class TestGetTriples:
|
|||
the same Metadata construction. Cover it for parity."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
||||
async def test_row_converts_to_triples(self, mock_async_execute):
|
||||
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||
async def test_row_converts_to_triples(self, mock_async_execute_paged):
|
||||
# row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
|
||||
fake_row = (
|
||||
None, None, None,
|
||||
|
|
@ -170,7 +170,7 @@ class TestGetTriples:
|
|||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.get_triples_stmt = Mock()
|
||||
mock_async_execute.return_value = [fake_row]
|
||||
mock_async_execute_paged.return_value = [[fake_row]]
|
||||
|
||||
received = []
|
||||
|
||||
|
|
|
|||
|
|
@ -337,7 +337,7 @@ class Api:
|
|||
from . bulk_client import BulkClient
|
||||
# Extract base URL (remove api/v1/ suffix)
|
||||
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
|
||||
self._bulk_client = BulkClient(base_url, self.timeout, self.token)
|
||||
self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
|
||||
return self._bulk_client
|
||||
|
||||
def metrics(self):
|
||||
|
|
@ -462,7 +462,7 @@ class Api:
|
|||
from . async_bulk_client import AsyncBulkClient
|
||||
# Extract base URL (remove api/v1/ suffix)
|
||||
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
|
||||
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
|
||||
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
|
||||
return self._async_bulk_client
|
||||
|
||||
def async_metrics(self):
|
||||
|
|
|
|||
|
|
@ -9,10 +9,11 @@ from . types import Triple
|
|||
class AsyncBulkClient:
|
||||
"""Asynchronous bulk operations client"""
|
||||
|
||||
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
|
||||
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
|
||||
self.url: str = self._convert_to_ws_url(url)
|
||||
self.timeout: int = timeout
|
||||
self.token: Optional[str] = token
|
||||
self.workspace: str = workspace
|
||||
|
||||
def _convert_to_ws_url(self, url: str) -> str:
|
||||
"""Convert HTTP URL to WebSocket URL"""
|
||||
|
|
@ -25,11 +26,21 @@ class AsyncBulkClient:
|
|||
else:
|
||||
return f"ws://{url}"
|
||||
|
||||
def _build_ws_url(self, path: str) -> str:
    """Build a WebSocket URL with token and workspace query params.

    Args:
        path: Endpoint path, appended verbatim to ``self.url``.

    Returns:
        ``self.url`` followed by ``path``, plus a query string carrying
        ``token`` and/or ``workspace`` when those attributes are truthy.
        When neither is set the URL is returned without a ``?``.
    """
    # Imported locally so this method does not rely on the module's
    # import header.
    from urllib.parse import urlencode

    ws_url = f"{self.url}{path}"

    # Falsy values (None, "") are left out; token stays ahead of workspace.
    params = [
        (key, value)
        for key, value in (("token", self.token), ("workspace", self.workspace))
        if value
    ]
    if not params:
        return ws_url

    # Percent-encode the values: a raw '+', '&', '=', '#' or space in a
    # token or workspace name would otherwise corrupt the query string.
    return f"{ws_url}?{urlencode(params)}"
|
||||
|
||||
async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
|
||||
"""Bulk import triples via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for triple in triples:
|
||||
|
|
@ -42,9 +53,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
|
||||
"""Bulk export triples via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -57,9 +66,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Bulk import graph embeddings via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for embedding in embeddings:
|
||||
|
|
@ -67,9 +74,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Bulk export graph embeddings via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -77,9 +82,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Bulk import document embeddings via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for embedding in embeddings:
|
||||
|
|
@ -87,9 +90,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Bulk export document embeddings via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -97,9 +98,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Bulk import entity contexts via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for context in contexts:
|
||||
|
|
@ -107,9 +106,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Bulk export entity contexts via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -117,9 +114,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Bulk import rows via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for row in rows:
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class BulkClient:
|
|||
Note: For true async support, use AsyncBulkClient instead.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
|
||||
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
|
||||
"""
|
||||
Initialize synchronous bulk client.
|
||||
|
||||
|
|
@ -42,10 +42,12 @@ class BulkClient:
|
|||
url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS)
|
||||
timeout: WebSocket timeout in seconds
|
||||
token: Optional bearer token for authentication
|
||||
workspace: Workspace for data isolation
|
||||
"""
|
||||
self.url: str = self._convert_to_ws_url(url)
|
||||
self.timeout: int = timeout
|
||||
self.token: Optional[str] = token
|
||||
self.workspace: str = workspace
|
||||
|
||||
def _convert_to_ws_url(self, url: str) -> str:
|
||||
"""Convert HTTP URL to WebSocket URL"""
|
||||
|
|
@ -58,6 +60,18 @@ class BulkClient:
|
|||
else:
|
||||
return f"ws://{url}"
|
||||
|
||||
def _build_ws_url(self, path: str) -> str:
    """Build a WebSocket URL with token and workspace query params.

    Args:
        path: Endpoint path, appended verbatim to ``self.url``.

    Returns:
        ``self.url`` followed by ``path``, plus a query string carrying
        ``token`` and/or ``workspace`` when those attributes are truthy.
        When neither is set the URL is returned without a ``?``.
    """
    # Imported locally so this method does not rely on the module's
    # import header.
    from urllib.parse import urlencode

    ws_url = f"{self.url}{path}"

    # Collect only the parameters that are set, token first.
    params = []
    if self.token:
        params.append(("token", self.token))
    if self.workspace:
        params.append(("workspace", self.workspace))

    if params:
        # urlencode percent-escapes each value, so characters such as
        # '+', '&', '=' and '#' cannot break or truncate the query string.
        ws_url = f"{ws_url}?{urlencode(params)}"
    return ws_url
|
||||
|
||||
def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
|
||||
"""Run async coroutine synchronously"""
|
||||
try:
|
||||
|
|
@ -116,9 +130,7 @@ class BulkClient:
|
|||
metadata: Optional[Dict[str, Any]], batch_size: int
|
||||
) -> None:
|
||||
"""Async implementation of triple import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
|
||||
|
||||
if metadata is None:
|
||||
metadata = {"id": "", "metadata": [], "collection": "default"}
|
||||
|
|
@ -194,9 +206,7 @@ class BulkClient:
|
|||
|
||||
async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
|
||||
"""Async implementation of triple export"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -238,9 +248,7 @@ class BulkClient:
|
|||
|
||||
async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
|
||||
"""Async implementation of graph embeddings import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
for embedding in embeddings:
|
||||
|
|
@ -296,9 +304,7 @@ class BulkClient:
|
|||
|
||||
async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
||||
"""Async implementation of graph embeddings export"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -336,9 +342,7 @@ class BulkClient:
|
|||
|
||||
async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
|
||||
"""Async implementation of document embeddings import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
for embedding in embeddings:
|
||||
|
|
@ -394,9 +398,7 @@ class BulkClient:
|
|||
|
||||
async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
||||
"""Async implementation of document embeddings export"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -446,9 +448,7 @@ class BulkClient:
|
|||
metadata: Optional[Dict[str, Any]], batch_size: int
|
||||
) -> None:
|
||||
"""Async implementation of entity contexts import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
|
||||
|
||||
if metadata is None:
|
||||
metadata = {"id": "", "metadata": [], "collection": "default"}
|
||||
|
|
@ -522,9 +522,7 @@ class BulkClient:
|
|||
|
||||
async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
||||
"""Async implementation of entity contexts export"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -562,9 +560,7 @@ class BulkClient:
|
|||
|
||||
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
|
||||
"""Async implementation of rows import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
for row in rows:
|
||||
|
|
|
|||
|
|
@ -167,7 +167,8 @@ class SocketClient:
|
|||
)
|
||||
|
||||
if resp.get("type") == "auth-ok":
|
||||
self.workspace = resp.get("workspace", self.workspace)
|
||||
if self.workspace == "default":
|
||||
self.workspace = resp.get("workspace", self.workspace)
|
||||
elif resp.get("type") == "auth-failed":
|
||||
await self._socket.close()
|
||||
raise ProtocolException(
|
||||
|
|
|
|||
|
|
@ -117,8 +117,10 @@ class SocketEndpoint:
|
|||
|
||||
running = Running()
|
||||
|
||||
params = dict(request.query)
|
||||
params.update(request.match_info)
|
||||
dispatcher = await self.dispatcher(
|
||||
ws, running, request.match_info
|
||||
ws, running, params
|
||||
)
|
||||
|
||||
worker_task = tg.create_task(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue