mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-02 22:41:01 +02:00
Compare commits
2 commits
6b1dd16f9f
...
60f861bac4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60f861bac4 | ||
|
|
00bb964e93 |
7 changed files with 71 additions and 78 deletions
|
|
@ -35,9 +35,9 @@ def _make_store():
|
|||
class TestGetGraphEmbeddings:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
||||
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||
async def test_row_converts_to_entity_embeddings_with_singular_vector(
|
||||
self, mock_async_execute
|
||||
self, mock_async_execute_paged
|
||||
):
|
||||
"""
|
||||
Cassandra rows return entities as a list of [entity_tuple, vector]
|
||||
|
|
@ -57,7 +57,7 @@ class TestGetGraphEmbeddings:
|
|||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.get_graph_embeddings_stmt = Mock()
|
||||
mock_async_execute.return_value = [fake_row]
|
||||
mock_async_execute_paged.return_value = [[fake_row]]
|
||||
|
||||
received = []
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ class TestGetGraphEmbeddings:
|
|||
|
||||
await store.get_graph_embeddings("alice", "doc-1", receiver)
|
||||
|
||||
mock_async_execute.assert_called_once_with(
|
||||
mock_async_execute_paged.assert_called_once_with(
|
||||
store.cassandra,
|
||||
store.get_graph_embeddings_stmt,
|
||||
("alice", "doc-1"),
|
||||
|
|
@ -96,8 +96,8 @@ class TestGetGraphEmbeddings:
|
|||
assert ge.entities[2].entity.value == "a literal entity"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
||||
async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute):
|
||||
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||
async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged):
|
||||
"""row[3] being None / empty must produce a GraphEmbeddings with
|
||||
no entities, not raise."""
|
||||
fake_row = (None, None, None, None)
|
||||
|
|
@ -105,7 +105,7 @@ class TestGetGraphEmbeddings:
|
|||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.get_graph_embeddings_stmt = Mock()
|
||||
mock_async_execute.return_value = [fake_row]
|
||||
mock_async_execute_paged.return_value = [[fake_row]]
|
||||
|
||||
received = []
|
||||
|
||||
|
|
@ -118,8 +118,8 @@ class TestGetGraphEmbeddings:
|
|||
assert received[0].entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
||||
async def test_multiple_rows_each_emit_one_message(self, mock_async_execute):
|
||||
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||
async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged):
|
||||
fake_rows = [
|
||||
(None, None, None, [
|
||||
(("http://example.org/a", True), [1.0]),
|
||||
|
|
@ -132,7 +132,7 @@ class TestGetGraphEmbeddings:
|
|||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.get_graph_embeddings_stmt = Mock()
|
||||
mock_async_execute.return_value = fake_rows
|
||||
mock_async_execute_paged.return_value = [fake_rows]
|
||||
|
||||
received = []
|
||||
|
||||
|
|
@ -153,8 +153,8 @@ class TestGetTriples:
|
|||
the same Metadata construction. Cover it for parity."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
|
||||
async def test_row_converts_to_triples(self, mock_async_execute):
|
||||
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
|
||||
async def test_row_converts_to_triples(self, mock_async_execute_paged):
|
||||
# row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
|
||||
fake_row = (
|
||||
None, None, None,
|
||||
|
|
@ -170,7 +170,7 @@ class TestGetTriples:
|
|||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.get_triples_stmt = Mock()
|
||||
mock_async_execute.return_value = [fake_row]
|
||||
mock_async_execute_paged.return_value = [[fake_row]]
|
||||
|
||||
received = []
|
||||
|
||||
|
|
|
|||
|
|
@ -337,7 +337,7 @@ class Api:
|
|||
from . bulk_client import BulkClient
|
||||
# Extract base URL (remove api/v1/ suffix)
|
||||
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
|
||||
self._bulk_client = BulkClient(base_url, self.timeout, self.token)
|
||||
self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
|
||||
return self._bulk_client
|
||||
|
||||
def metrics(self):
|
||||
|
|
@ -462,7 +462,7 @@ class Api:
|
|||
from . async_bulk_client import AsyncBulkClient
|
||||
# Extract base URL (remove api/v1/ suffix)
|
||||
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
|
||||
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
|
||||
self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
|
||||
return self._async_bulk_client
|
||||
|
||||
def async_metrics(self):
|
||||
|
|
|
|||
|
|
@ -9,10 +9,11 @@ from . types import Triple
|
|||
class AsyncBulkClient:
|
||||
"""Asynchronous bulk operations client"""
|
||||
|
||||
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
|
||||
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
|
||||
self.url: str = self._convert_to_ws_url(url)
|
||||
self.timeout: int = timeout
|
||||
self.token: Optional[str] = token
|
||||
self.workspace: str = workspace
|
||||
|
||||
def _convert_to_ws_url(self, url: str) -> str:
|
||||
"""Convert HTTP URL to WebSocket URL"""
|
||||
|
|
@ -25,11 +26,21 @@ class AsyncBulkClient:
|
|||
else:
|
||||
return f"ws://{url}"
|
||||
|
||||
def _build_ws_url(self, path: str) -> str:
|
||||
"""Build a WebSocket URL with token and workspace query params."""
|
||||
ws_url = f"{self.url}{path}"
|
||||
params = []
|
||||
if self.token:
|
||||
params.append(f"token={self.token}")
|
||||
if self.workspace:
|
||||
params.append(f"workspace={self.workspace}")
|
||||
if params:
|
||||
ws_url = f"{ws_url}?{'&'.join(params)}"
|
||||
return ws_url
|
||||
|
||||
async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
|
||||
"""Bulk import triples via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for triple in triples:
|
||||
|
|
@ -42,9 +53,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
|
||||
"""Bulk export triples via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -57,9 +66,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Bulk import graph embeddings via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for embedding in embeddings:
|
||||
|
|
@ -67,9 +74,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Bulk export graph embeddings via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -77,9 +82,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Bulk import document embeddings via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for embedding in embeddings:
|
||||
|
|
@ -87,9 +90,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Bulk export document embeddings via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -97,9 +98,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Bulk import entity contexts via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for context in contexts:
|
||||
|
|
@ -107,9 +106,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Bulk export entity contexts via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -117,9 +114,7 @@ class AsyncBulkClient:
|
|||
|
||||
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
|
||||
"""Bulk import rows via WebSocket"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for row in rows:
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class BulkClient:
|
|||
Note: For true async support, use AsyncBulkClient instead.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
|
||||
def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
|
||||
"""
|
||||
Initialize synchronous bulk client.
|
||||
|
||||
|
|
@ -42,10 +42,12 @@ class BulkClient:
|
|||
url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS)
|
||||
timeout: WebSocket timeout in seconds
|
||||
token: Optional bearer token for authentication
|
||||
workspace: Workspace for data isolation
|
||||
"""
|
||||
self.url: str = self._convert_to_ws_url(url)
|
||||
self.timeout: int = timeout
|
||||
self.token: Optional[str] = token
|
||||
self.workspace: str = workspace
|
||||
|
||||
def _convert_to_ws_url(self, url: str) -> str:
|
||||
"""Convert HTTP URL to WebSocket URL"""
|
||||
|
|
@ -58,6 +60,18 @@ class BulkClient:
|
|||
else:
|
||||
return f"ws://{url}"
|
||||
|
||||
def _build_ws_url(self, path: str) -> str:
|
||||
"""Build a WebSocket URL with token and workspace query params."""
|
||||
ws_url = f"{self.url}{path}"
|
||||
params = []
|
||||
if self.token:
|
||||
params.append(f"token={self.token}")
|
||||
if self.workspace:
|
||||
params.append(f"workspace={self.workspace}")
|
||||
if params:
|
||||
ws_url = f"{ws_url}?{'&'.join(params)}"
|
||||
return ws_url
|
||||
|
||||
def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
|
||||
"""Run async coroutine synchronously"""
|
||||
try:
|
||||
|
|
@ -116,9 +130,7 @@ class BulkClient:
|
|||
metadata: Optional[Dict[str, Any]], batch_size: int
|
||||
) -> None:
|
||||
"""Async implementation of triple import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
|
||||
|
||||
if metadata is None:
|
||||
metadata = {"id": "", "metadata": [], "collection": "default"}
|
||||
|
|
@ -194,9 +206,7 @@ class BulkClient:
|
|||
|
||||
async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
|
||||
"""Async implementation of triple export"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -238,9 +248,7 @@ class BulkClient:
|
|||
|
||||
async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
|
||||
"""Async implementation of graph embeddings import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
for embedding in embeddings:
|
||||
|
|
@ -296,9 +304,7 @@ class BulkClient:
|
|||
|
||||
async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
||||
"""Async implementation of graph embeddings export"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -336,9 +342,7 @@ class BulkClient:
|
|||
|
||||
async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
|
||||
"""Async implementation of document embeddings import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
for embedding in embeddings:
|
||||
|
|
@ -394,9 +398,7 @@ class BulkClient:
|
|||
|
||||
async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
||||
"""Async implementation of document embeddings export"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -446,9 +448,7 @@ class BulkClient:
|
|||
metadata: Optional[Dict[str, Any]], batch_size: int
|
||||
) -> None:
|
||||
"""Async implementation of entity contexts import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
|
||||
|
||||
if metadata is None:
|
||||
metadata = {"id": "", "metadata": [], "collection": "default"}
|
||||
|
|
@ -522,9 +522,7 @@ class BulkClient:
|
|||
|
||||
async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
|
||||
"""Async implementation of entity contexts export"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
async for raw_message in websocket:
|
||||
|
|
@ -562,9 +560,7 @@ class BulkClient:
|
|||
|
||||
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
|
||||
"""Async implementation of rows import"""
|
||||
ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
|
||||
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
for row in rows:
|
||||
|
|
|
|||
|
|
@ -167,7 +167,8 @@ class SocketClient:
|
|||
)
|
||||
|
||||
if resp.get("type") == "auth-ok":
|
||||
self.workspace = resp.get("workspace", self.workspace)
|
||||
if self.workspace == "default":
|
||||
self.workspace = resp.get("workspace", self.workspace)
|
||||
elif resp.get("type") == "auth-failed":
|
||||
await self._socket.close()
|
||||
raise ProtocolException(
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ Supports dual output to console and Loki for centralized log aggregation.
|
|||
import contextvars
|
||||
import logging
|
||||
import logging.handlers
|
||||
import uuid
|
||||
from argparse import ArgumentParser
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
|
@ -132,14 +133,12 @@ def setup_logging(args: dict[str, Any]) -> None:
|
|||
try:
|
||||
from logging_loki import LokiHandler
|
||||
|
||||
# Create Loki handler with optional authentication. The
|
||||
# processor label is NOT baked in here — it's stamped onto
|
||||
# each record by _ProcessorIdFilter reading the task-local
|
||||
# contextvar, and logging_loki's emitter reads record.tags
|
||||
# to build per-record Loki labels.
|
||||
instance_id = str(uuid.uuid4())[:8]
|
||||
|
||||
loki_handler_kwargs = {
|
||||
'url': loki_url,
|
||||
'version': "1",
|
||||
'tags': {'instance': instance_id},
|
||||
}
|
||||
|
||||
if loki_username and loki_password:
|
||||
|
|
|
|||
|
|
@ -117,8 +117,10 @@ class SocketEndpoint:
|
|||
|
||||
running = Running()
|
||||
|
||||
params = dict(request.query)
|
||||
params.update(request.match_info)
|
||||
dispatcher = await self.dispatcher(
|
||||
ws, running, request.match_info
|
||||
ws, running, params
|
||||
)
|
||||
|
||||
worker_task = tg.create_task(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue