From 6dfa47aac8661dc50944f166e687a9c99ca90448 Mon Sep 17 00:00:00 2001
From: Jack Colquitt <126733989+JackColquitt@users.noreply.github.com>
Date: Sat, 30 May 2026 17:07:19 -0700
Subject: [PATCH 01/18] Revise README for semantic infrastructure terminology
(#962)
Updated the README to reflect changes in terminology and improve clarity regarding the platform's features.
---
README.md | 32 +++++++++++++++-----------------
1 file changed, 15 insertions(+), 17 deletions(-)
diff --git a/README.md b/README.md
index c366a3d9..1edccff6 100644
--- a/README.md
+++ b/README.md
@@ -11,11 +11,11 @@
-# The agent runtime platform
+# The semantic infrastructure for agents
-TrustGraph is an agent runtime platform built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for precision-critical agent workloads.
+TrustGraph is a comprehensive semantic infrastructure for agents built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for deterministic agent workloads.
The platform:
- [x] Multi-model and multimodal database system
@@ -99,23 +99,21 @@ For a browser based configuration, try the [Configuration Terminal](https://conf
- [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference)
- [**Deployment Guides**](https://docs.trustgraph.ai/deployment)
-## Workbench
+## Context Graph UI
-The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default.
+
-- **Vector Search**: Search the installed knowledge bases
-- **Agentic, GraphRAG and LLM Chat**: Chat interface for agents, GraphRAG queries, or direct to LLMs
-- **Relationships**: Analyze deep relationships in the installed knowledge bases
-- **Graph Visualizer**: 3D GraphViz of the installed knowledge bases
-- **Library**: Staging area for installing knowledge bases
-- **Flow Classes**: Workflow preset configurations
-- **Flows**: Create custom workflows and adjust LLM parameters during runtime
-- **Knowledge Cores**: Manage resuable knowledge bases
-- **Prompts**: Manage and adjust prompts during runtime
-- **Schemas**: Define custom schemas for structured data knowledge bases
-- **Ontologies**: Define custom ontologies for unstructured data knowledge bases
-- **Agent Tools**: Define tools with collections, knowledge cores, MCP connections, and tool groups
-- **MCP Tools**: Connect to MCP servers
+The UI provides tools for all major features of TrustGraph. The UI deploys on port `8888` by default.
+
+- **Agent Console** — Query your agents directly with streaming responses and live explainability event tracking, so you can watch reasoning unfold in real time
+- **GraphRAG View** — Interactive graph RAG queries with a visual explainability DAG and inline provenance display, making it easy to see exactly where answers came from
+- **Context Explorer** — An interactive 3D context graph explorer with dynamic graph loading, BFS neighborhood extraction, edge pulse animation, and multiple navigation views
+- **Document Ingestion** — A complete upload and submission workflow with page and chunk inspection and document structure browsing
+- **Ontology Workbench** — A full ontology editor with class and property trees, OWL/XML and Turtle import/export with round-trip fidelity, circular dependency detection, and safe-delete confirmation dialogs
+- **Schema Workbench** — Interactive schema management with list, create, edit, and delete operations including field and index management
+- **Flow Management** — Flow creation and detail views with configurable parameters, temperature controls, and grouped storage layout
+- **Workspace UX** — Workspace selection and management surfaced directly in the interface
+- **Prompt Editor** — A dedicated prompt editing workflow
## TypeScript Library for UIs
From d1e6b99e9616f3e45f42275b86372095c94e09e5 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Mon, 1 Jun 2026 09:53:28 +0100
Subject: [PATCH 02/18] fix: CLI tools ignoring -w flag for workspace routing
(#964)
Several CLI commands silently routed requests to the default workspace
regardless of the -w flag: show-flows, show-flow-blueprints,
show-parameter-types, set-prompt --system, and load-structured-data.
The workspace was sent in the inner request body but not on the
WebSocket envelope or API client constructor, so the gateway always
dispatched to the default workspace queue.
---
.../trustgraph/cli/load_structured_data.py | 2 +-
trustgraph-cli/trustgraph/cli/set_prompt.py | 3 ++-
.../trustgraph/cli/show_flow_blueprints.py | 2 +-
trustgraph-cli/trustgraph/cli/show_flows.py | 2 +-
.../trustgraph/cli/show_parameter_types.py | 19 +++++++++++++------
5 files changed, 18 insertions(+), 10 deletions(-)
diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py
index 3cd2a229..dccf548e 100644
--- a/trustgraph-cli/trustgraph/cli/load_structured_data.py
+++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py
@@ -293,7 +293,7 @@ def load_structured_data(
# Send to TrustGraph
print(f"🚀 Importing {len(output_records)} records to TrustGraph...")
- imported_count = _send_to_trustgraph(output_records, api_url, flow, batch_size, token=token)
+ imported_count = _send_to_trustgraph(output_records, api_url, flow, batch_size, token=token, workspace=workspace)
# Get summary info from descriptor
format_info = descriptor.get('format', {})
diff --git a/trustgraph-cli/trustgraph/cli/set_prompt.py b/trustgraph-cli/trustgraph/cli/set_prompt.py
index dbf9c326..2feaba00 100644
--- a/trustgraph-cli/trustgraph/cli/set_prompt.py
+++ b/trustgraph-cli/trustgraph/cli/set_prompt.py
@@ -119,7 +119,8 @@ def main():
raise RuntimeError("Can't use --system with other args")
set_system(
- url=args.api_url, system=args.system, token=args.token
+ url=args.api_url, system=args.system, token=args.token,
+ workspace=args.workspace,
)
else:
diff --git a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py
index 4924c925..c1aea836 100644
--- a/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py
+++ b/trustgraph-cli/trustgraph/cli/show_flow_blueprints.py
@@ -105,7 +105,7 @@ async def fetch_data(client, workspace):
return blueprint_names, blueprints, param_type_defs
async def _show_flow_blueprints_async(url, token=None, workspace="default"):
- async with AsyncSocketClient(url, timeout=60, token=token) as client:
+ async with AsyncSocketClient(url, timeout=60, token=token, workspace=workspace) as client:
return await fetch_data(client, workspace)
def show_flow_blueprints(url, token=None, workspace="default"):
diff --git a/trustgraph-cli/trustgraph/cli/show_flows.py b/trustgraph-cli/trustgraph/cli/show_flows.py
index 6e9479f9..b8a30c44 100644
--- a/trustgraph-cli/trustgraph/cli/show_flows.py
+++ b/trustgraph-cli/trustgraph/cli/show_flows.py
@@ -213,7 +213,7 @@ async def fetch_show_flows(client, workspace):
async def _show_flows_async(url, token=None, workspace="default"):
- async with AsyncSocketClient(url, timeout=60, token=token) as client:
+ async with AsyncSocketClient(url, timeout=60, token=token, workspace=workspace) as client:
return await fetch_show_flows(client, workspace)
def show_flows(url, token=None, workspace="default"):
diff --git a/trustgraph-cli/trustgraph/cli/show_parameter_types.py b/trustgraph-cli/trustgraph/cli/show_parameter_types.py
index 67d6e823..b0b25f3d 100644
--- a/trustgraph-cli/trustgraph/cli/show_parameter_types.py
+++ b/trustgraph-cli/trustgraph/cli/show_parameter_types.py
@@ -15,6 +15,7 @@ import json
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
+default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def format_enum_values(enum_list):
"""
@@ -125,11 +126,11 @@ async def fetch_single_param_type(client, param_type_name):
return json.loads(values[0].get("value", "{}"))
return None
-def show_parameter_types(url, token=None):
+def show_parameter_types(url, token=None, workspace="default"):
"""Show all parameter type definitions."""
async def _fetch():
- async with AsyncSocketClient(url, timeout=60, token=token) as client:
+ async with AsyncSocketClient(url, timeout=60, token=token, workspace=workspace) as client:
return await fetch_all_param_types(client)
param_type_names, param_type_defs = asyncio.run(_fetch())
@@ -153,11 +154,11 @@ def show_parameter_types(url, token=None):
))
print()
-def show_specific_parameter_type(url, param_type_name, token=None):
+def show_specific_parameter_type(url, param_type_name, token=None, workspace="default"):
"""Show a specific parameter type definition."""
async def _fetch():
- async with AsyncSocketClient(url, timeout=60, token=token) as client:
+ async with AsyncSocketClient(url, timeout=60, token=token, workspace=workspace) as client:
return await fetch_single_param_type(client, param_type_name)
param_type_def = asyncio.run(_fetch())
@@ -193,6 +194,12 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
+ parser.add_argument(
+ '-w', '--workspace',
+ default=default_workspace,
+ help=f'Workspace (default: {default_workspace})',
+ )
+
parser.add_argument(
'-t', '--type',
help='Show only the specified parameter type',
@@ -202,9 +209,9 @@ def main():
try:
if args.type:
- show_specific_parameter_type(args.api_url, args.type, args.token)
+ show_specific_parameter_type(args.api_url, args.type, args.token, workspace=args.workspace)
else:
- show_parameter_types(args.api_url, args.token)
+ show_parameter_types(args.api_url, args.token, workspace=args.workspace)
except Exception as e:
print("Exception:", e, flush=True)
From e6dfccc56d8d7e151cffbad657b3fe689e02aad6 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Mon, 1 Jun 2026 12:25:19 +0100
Subject: [PATCH 03/18] fix: WebSocket auth handshake overwriting explicit
workspace (#966)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The auth-ok response includes the token's bound workspace, and
AsyncSocketClient was unconditionally adopting it — clobbering any
workspace the caller explicitly requested via the constructor.
---
trustgraph-base/trustgraph/api/async_socket_client.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py
index d18bee34..7b38a4b1 100644
--- a/trustgraph-base/trustgraph/api/async_socket_client.py
+++ b/trustgraph-base/trustgraph/api/async_socket_client.py
@@ -30,6 +30,7 @@ class AsyncSocketClient:
self.timeout = timeout
self.token = token
self.workspace = workspace
+ self._workspace_explicit = workspace != "default"
self._request_counter = 0
self._socket = None
self._connect_cm = None
@@ -92,7 +93,8 @@ class AsyncSocketClient:
)
if resp.get("type") == "auth-ok":
- self.workspace = resp.get("workspace", self.workspace)
+ if not self._workspace_explicit:
+ self.workspace = resp.get("workspace", self.workspace)
elif resp.get("type") == "auth-failed":
await self._socket.close()
raise ProtocolException(
From 7e1fb76bc9e6fb78af956703d15047acee6d8afe Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Mon, 1 Jun 2026 12:35:09 +0100
Subject: [PATCH 04/18] Fix HF embeddings tests (#967)
The tests were patching
trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings - a module-level
attribute that doesn't exist because HuggingFaceEmbeddings is
imported locally inside _load_model. Changed all 8 occurrences to
patch langchain_huggingface.HuggingFaceEmbeddings, which is the
actual import source the code uses at runtime.
---
.../test_huggingface_dynamic_model.py | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/tests/unit/test_embeddings/test_huggingface_dynamic_model.py b/tests/unit/test_embeddings/test_huggingface_dynamic_model.py
index aef6fc92..65837323 100644
--- a/tests/unit/test_embeddings/test_huggingface_dynamic_model.py
+++ b/tests/unit/test_embeddings/test_huggingface_dynamic_model.py
@@ -18,7 +18,7 @@ from trustgraph.embeddings.hf.hf import Processor
class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
"""Test HuggingFace dynamic model loading and caching"""
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_default_model_loaded_on_init(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -39,7 +39,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
assert processor.cached_model_name == "test-model"
assert processor.embeddings is not None
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_model_caching_avoids_reload(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -63,7 +63,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
mock_hf_class.assert_not_called()
assert processor.cached_model_name == "test-model"
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_model_reload_on_name_change(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -84,7 +84,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
mock_hf_class.assert_called_once_with(model_name="different-model")
assert processor.cached_model_name == "different-model"
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_default_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -107,7 +107,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
assert processor.cached_model_name == "test-model" # Still using default
assert result == [[0.1, 0.2, 0.3, 0.4, 0.5]]
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_on_embeddings_uses_specified_model(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -130,7 +130,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
assert processor.cached_model_name == "custom-model"
mock_hf_instance.embed_documents.assert_called_once_with(["test text"])
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_multiple_model_switches(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -164,7 +164,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
assert call_count_after_b == initial_call_count + 2 # Reload for model-b
assert call_count_after_a_again == initial_call_count + 3 # Reload back to model-a
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_none_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
@@ -187,7 +187,7 @@ class TestHuggingFaceDynamicModelLoading(IsolatedAsyncioTestCase):
assert mock_hf_class.call_count == initial_count
assert processor.cached_model_name == "test-model"
- @patch('trustgraph.embeddings.hf.hf.HuggingFaceEmbeddings')
+ @patch('langchain_huggingface.HuggingFaceEmbeddings')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.embeddings_service.EmbeddingsService.__init__')
async def test_initialization_without_model_uses_default(self, mock_embeddings_init, mock_async_init, mock_hf_class):
From 97453d9b8319910c4f9d1860a6752b5a9c9f53ed Mon Sep 17 00:00:00 2001
From: Jack Colquitt <126733989+JackColquitt@users.noreply.github.com>
Date: Mon, 1 Jun 2026 14:08:30 -0700
Subject: [PATCH 05/18] Change project title to 'The semantic deployment
platform' (#968)
Updated the project title in the README.
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 1edccff6..b66edc70 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@
-# The semantic infrastructure for agents
+# The semantic deployment platform
From 6b1dd16f9fd2d7bfe27e6430d4e2fd2dec8eb790 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Mon, 1 Jun 2026 22:39:30 +0100
Subject: [PATCH 06/18] fix: large document handling and Cassandra query
pagination (#969)
- Paginate heavy Cassandra reads (triples, graph/document embeddings)
using synchronous session.execute() in run_in_executor with fetch_size
paging, preventing materialization hang on large result sets
- Fix document stream endpoint to use workspace-scoped librarian queues
- Add decoder error handling for PDF/OCR/unstructured processors
- Add WebSocket mux guards for missing auth fields
- Add null check in librarian document streaming
- Rewrite get_document_content CLI to stream via librarian
- Add Poppler dependency to unstructured container
---
containers/Containerfile.unstructured | 2 +-
.../trustgraph/cli/get_document_content.py | 24 +++-
.../decoding/mistral_ocr/processor.py | 9 +-
.../trustgraph/decoding/pdf/pdf_decoder.py | 10 +-
.../gateway/dispatch/document_stream.py | 8 +-
.../trustgraph/gateway/dispatch/mux.py | 4 +
.../trustgraph/librarian/librarian.py | 3 +
.../trustgraph/tables/cassandra_async.py | 35 +++++
.../trustgraph/tables/knowledge.py | 127 +++++++++---------
.../trustgraph/decoding/ocr/pdf_decoder.py | 9 +-
.../decoding/universal/processor.py | 9 +-
11 files changed, 166 insertions(+), 74 deletions(-)
diff --git a/containers/Containerfile.unstructured b/containers/Containerfile.unstructured
index 6de8a800..2b9a18f7 100644
--- a/containers/Containerfile.unstructured
+++ b/containers/Containerfile.unstructured
@@ -7,7 +7,7 @@ FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1
-RUN dnf install -y python3.13 libxcb mesa-libGL && \
+RUN dnf install -y python3.13 libxcb mesa-libGL poppler-utils && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
diff --git a/trustgraph-cli/trustgraph/cli/get_document_content.py b/trustgraph-cli/trustgraph/cli/get_document_content.py
index 62fa7ca2..f4d44cca 100644
--- a/trustgraph-cli/trustgraph/cli/get_document_content.py
+++ b/trustgraph-cli/trustgraph/cli/get_document_content.py
@@ -5,7 +5,7 @@ Gets document content from the library by document ID.
import argparse
import os
import sys
-from trustgraph.api import Api
+import requests
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
@@ -13,15 +13,29 @@ default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def get_content(url, document_id, output_file, token=None, workspace="default"):
- api = Api(url, token=token, workspace=workspace).library()
+ stream_url = url.rstrip("/") + "/api/v1/document-stream"
- content = api.get_document_content(id=document_id)
+ params = {
+ "document-id": document_id,
+ "workspace": workspace,
+ }
+
+ headers = {}
+ if token:
+ headers["Authorization"] = f"Bearer {token}"
+
+ resp = requests.get(stream_url, params=params, headers=headers, stream=True)
+ resp.raise_for_status()
if output_file:
+ total = 0
with open(output_file, 'wb') as f:
- f.write(content)
- print(f"Written {len(content)} bytes to {output_file}")
+ for chunk in resp.iter_content(chunk_size=65536):
+ f.write(chunk)
+ total += len(chunk)
+ print(f"Written {total} bytes to {output_file}")
else:
+ content = resp.content
try:
text = content.decode('utf-8')
print(text)
diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
index f214111d..40ecac8a 100755
--- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
+++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py
@@ -219,7 +219,14 @@ class Processor(FlowProcessor):
source_doc_id = v.document_id or v.metadata.id
# Run OCR, get per-page markdown
- pages = self.ocr(blob)
+ try:
+ pages = self.ocr(blob)
+ except Exception as e:
+ logger.error(
+ f"Failed to decode PDF {source_doc_id}: "
+ f"{type(e).__name__}: {e}"
+ )
+ return
for markdown, page_num in pages:
diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
index 209153f6..ca242265 100755
--- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
+++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
@@ -129,7 +129,15 @@ class Processor(FlowProcessor):
)
PyPDFLoader = _cls
loader = PyPDFLoader(temp_path)
- pages = loader.load()
+ try:
+ pages = loader.load()
+ except Exception as e:
+ source_doc_id = v.document_id or v.metadata.id
+ logger.error(
+ f"Failed to decode PDF {source_doc_id}: "
+ f"{type(e).__name__}: {e}"
+ )
+ return
# Get the source document ID
source_doc_id = v.document_id or v.metadata.id
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py
index 2992d99f..74b4d7df 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_stream.py
@@ -3,6 +3,7 @@ import asyncio
import uuid
import logging
from . librarian import LibrarianRequestor
+from ... schema import librarian_request_queue, librarian_response_queue
# Module logger
logger = logging.getLogger(__name__)
@@ -23,10 +24,13 @@ class DocumentStreamExport:
response = await ok()
+ uid = str(uuid.uuid4())
lr = LibrarianRequestor(
backend=self.backend,
- consumer="api-gateway-doc-stream-" + str(uuid.uuid4()),
- subscriber="api-gateway-doc-stream-" + str(uuid.uuid4()),
+ consumer="api-gateway-doc-stream-" + uid,
+ subscriber="api-gateway-doc-stream-" + uid,
+ request_queue=f"{librarian_request_queue}:{workspace}",
+ response_queue=f"{librarian_response_queue}:{workspace}",
)
try:
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
index bdbd18d8..73bbb1f3 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
@@ -288,6 +288,8 @@ class Mux:
await self.maybe_tidy_workers(workers)
async def responder(resp, fin):
+ if self.ws is None:
+ return
await self.ws.send_json({
"id": id,
"response": resp,
@@ -321,6 +323,8 @@ class Mux:
)
except Exception as e:
+ if self.ws is None:
+ return
await self.ws.send_json({
"id": id,
"error": {"message": str(e), "type": "error"},
diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py
index 1c4d010e..cc5f0bdf 100644
--- a/trustgraph-flow/trustgraph/librarian/librarian.py
+++ b/trustgraph-flow/trustgraph/librarian/librarian.py
@@ -162,6 +162,9 @@ class Librarian:
request.document_id
)
+ if object_id is None:
+ raise RequestError(f"Document not found: {request.document_id}")
+
content = await self.blob_store.get(
object_id
)
diff --git a/trustgraph-flow/trustgraph/tables/cassandra_async.py b/trustgraph-flow/trustgraph/tables/cassandra_async.py
index 2f497748..205ed6b9 100644
--- a/trustgraph-flow/trustgraph/tables/cassandra_async.py
+++ b/trustgraph-flow/trustgraph/tables/cassandra_async.py
@@ -27,6 +27,8 @@ Notes:
import asyncio
+from cassandra.query import SimpleStatement
+
async def async_execute(session, query, parameters=None):
"""Execute a CQL statement asynchronously.
@@ -76,3 +78,36 @@ def _set_result_if_pending(fut, result):
def _set_exception_if_pending(fut, exc):
if not fut.done():
fut.set_exception(exc)
+
+
+async def async_execute_paged(session, query, parameters=None, fetch_size=100):
+ """Execute a CQL query with page-by-page iteration.
+
+ Uses synchronous session.execute() inside run_in_executor so that
+ the driver's ResultSet paging works correctly without materialising
+ the entire result set in memory.
+
+ Yields one page of rows at a time (as a list).
+ """
+ loop = asyncio.get_running_loop()
+
+ if isinstance(query, str):
+ stmt = SimpleStatement(query, fetch_size=fetch_size)
+ else:
+ stmt = query
+ stmt.fetch_size = fetch_size
+
+ def _fetch_all_pages():
+ pages = []
+ result_set = session.execute(stmt, parameters)
+ while True:
+ pages.append(list(result_set.current_rows))
+ if result_set.has_more_pages:
+ result_set.fetch_next_page()
+ else:
+ break
+ return pages
+
+ return await loop.run_in_executor(
+ None, _fetch_all_pages
+ )
diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py
index cf085fdd..6a23731b 100644
--- a/trustgraph-flow/trustgraph/tables/knowledge.py
+++ b/trustgraph-flow/trustgraph/tables/knowledge.py
@@ -5,7 +5,7 @@ from .. schema import DocumentEmbeddings, ChunkEmbeddings
from cassandra.cluster import Cluster
-from . cassandra_async import async_execute
+from . cassandra_async import async_execute, async_execute_paged
def term_to_tuple(term):
@@ -398,7 +398,7 @@ class KnowledgeTableStore:
logger.debug("Get triples...")
try:
- rows = await async_execute(
+ pages = await async_execute_paged(
self.cassandra,
self.get_triples_stmt,
(workspace, document_id),
@@ -407,29 +407,30 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
- for row in rows:
+ for page in pages:
+ for row in page:
- if row[3]:
- triples = [
- Triple(
- s = tuple_to_term(elt[0], elt[1]),
- p = tuple_to_term(elt[2], elt[3]),
- o = tuple_to_term(elt[4], elt[5]),
+ if row[3]:
+ triples = [
+ Triple(
+ s = tuple_to_term(elt[0], elt[1]),
+ p = tuple_to_term(elt[2], elt[3]),
+ o = tuple_to_term(elt[4], elt[5]),
+ )
+ for elt in row[3]
+ ]
+ else:
+ triples = []
+
+ await receiver(
+ Triples(
+ metadata = Metadata(
+ id = document_id,
+ collection = "default",
+ ),
+ triples = triples
)
- for elt in row[3]
- ]
- else:
- triples = []
-
- await receiver(
- Triples(
- metadata = Metadata(
- id = document_id,
- collection = "default", # FIXME: What to put here?
- ),
- triples = triples
)
- )
logger.debug("Done")
@@ -438,7 +439,7 @@ class KnowledgeTableStore:
logger.debug("Get GE...")
try:
- rows = await async_execute(
+ pages = await async_execute_paged(
self.cassandra,
self.get_graph_embeddings_stmt,
(workspace, document_id),
@@ -447,28 +448,29 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
- for row in rows:
+ for page in pages:
+ for row in page:
- if row[3]:
- entities = [
- EntityEmbeddings(
- entity = tuple_to_term(ent[0][0], ent[0][1]),
- vector = ent[1]
+ if row[3]:
+ entities = [
+ EntityEmbeddings(
+ entity = tuple_to_term(ent[0][0], ent[0][1]),
+ vector = ent[1]
+ )
+ for ent in row[3]
+ ]
+ else:
+ entities = []
+
+ await receiver(
+ GraphEmbeddings(
+ metadata = Metadata(
+ id = document_id,
+ collection = "default",
+ ),
+ entities = entities
)
- for ent in row[3]
- ]
- else:
- entities = []
-
- await receiver(
- GraphEmbeddings(
- metadata = Metadata(
- id = document_id,
- collection = "default", # FIXME: What to put here?
- ),
- entities = entities
)
- )
logger.debug("Done")
@@ -477,7 +479,7 @@ class KnowledgeTableStore:
logger.debug("Get DE...")
try:
- rows = await async_execute(
+ pages = await async_execute_paged(
self.cassandra,
self.get_document_embeddings_stmt,
(workspace, document_id),
@@ -486,28 +488,29 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
- for row in rows:
+ for page in pages:
+ for row in page:
- if row[3]:
- chunks = [
- ChunkEmbeddings(
- chunk_id=ch[0],
- vector=ch[1],
+ if row[3]:
+ chunks = [
+ ChunkEmbeddings(
+ chunk_id=ch[0],
+ vector=ch[1],
+ )
+ for ch in row[3]
+ ]
+ else:
+ chunks = []
+
+ await receiver(
+ DocumentEmbeddings(
+ metadata = Metadata(
+ id = document_id,
+ collection = "default",
+ ),
+ chunks = chunks
)
- for ch in row[3]
- ]
- else:
- chunks = []
-
- await receiver(
- DocumentEmbeddings(
- metadata = Metadata(
- id = document_id,
- collection = "default",
- ),
- chunks = chunks
)
- )
logger.debug("Done")
diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py
index 1b4815c6..0d5101df 100755
--- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py
+++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py
@@ -107,7 +107,14 @@ class Processor(FlowProcessor):
# Get the source document ID
source_doc_id = v.document_id or v.metadata.id
- pages = convert_from_bytes(blob)
+ try:
+ pages = convert_from_bytes(blob)
+ except Exception as e:
+ logger.error(
+ f"Failed to decode PDF {source_doc_id}: "
+ f"{type(e).__name__}: {e}"
+ )
+ return
for ix, page in enumerate(pages):
diff --git a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py
index b4936786..deedb7b4 100644
--- a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py
+++ b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py
@@ -418,7 +418,14 @@ class Processor(FlowProcessor):
doc_uri_str = document_uri(source_doc_id)
# Extract elements using unstructured
- elements = self.extract_elements(blob, mime_type)
+ try:
+ elements = self.extract_elements(blob, mime_type)
+ except Exception as e:
+ logger.error(
+ f"Failed to extract elements from {source_doc_id}: "
+ f"{type(e).__name__}: {e}"
+ )
+ return
if not elements:
logger.warning("No elements extracted from document")
From 00bb964e93c4a62ac52ac92f2808d5019c84cd78 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Tue, 2 Jun 2026 14:19:15 +0100
Subject: [PATCH 07/18] fix: route workspace through bulk WebSocket clients and
merge query params (#970)
Bulk clients (sync and async) were not forwarding the workspace parameter,
causing all bulk operations to hit the default workspace regardless of the
Api instance's workspace setting. Also fixes the gateway socket endpoint to
pass query parameters (including workspace) to the dispatcher, and prevents
the auth handshake from overwriting an explicitly set workspace.
Updates knowledge table store tests for paged query interface.
---
.../test_tables/test_knowledge_table_store.py | 26 +++++-----
trustgraph-base/trustgraph/api/api.py | 4 +-
.../trustgraph/api/async_bulk_client.py | 51 ++++++++----------
trustgraph-base/trustgraph/api/bulk_client.py | 52 +++++++++----------
.../trustgraph/api/socket_client.py | 3 +-
.../trustgraph/gateway/endpoint/socket.py | 4 +-
6 files changed, 67 insertions(+), 73 deletions(-)
diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py
index 59d15b45..9a0b55c4 100644
--- a/tests/unit/test_tables/test_knowledge_table_store.py
+++ b/tests/unit/test_tables/test_knowledge_table_store.py
@@ -35,9 +35,9 @@ def _make_store():
class TestGetGraphEmbeddings:
@pytest.mark.asyncio
- @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
+ @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_row_converts_to_entity_embeddings_with_singular_vector(
- self, mock_async_execute
+ self, mock_async_execute_paged
):
"""
Cassandra rows return entities as a list of [entity_tuple, vector]
@@ -57,7 +57,7 @@ class TestGetGraphEmbeddings:
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
- mock_async_execute.return_value = [fake_row]
+ mock_async_execute_paged.return_value = [[fake_row]]
received = []
@@ -66,7 +66,7 @@ class TestGetGraphEmbeddings:
await store.get_graph_embeddings("alice", "doc-1", receiver)
- mock_async_execute.assert_called_once_with(
+ mock_async_execute_paged.assert_called_once_with(
store.cassandra,
store.get_graph_embeddings_stmt,
("alice", "doc-1"),
@@ -96,8 +96,8 @@ class TestGetGraphEmbeddings:
assert ge.entities[2].entity.value == "a literal entity"
@pytest.mark.asyncio
- @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
- async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute):
+ @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
+ async def test_empty_entities_blob_yields_empty_list(self, mock_async_execute_paged):
"""row[3] being None / empty must produce a GraphEmbeddings with
no entities, not raise."""
fake_row = (None, None, None, None)
@@ -105,7 +105,7 @@ class TestGetGraphEmbeddings:
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
- mock_async_execute.return_value = [fake_row]
+ mock_async_execute_paged.return_value = [[fake_row]]
received = []
@@ -118,8 +118,8 @@ class TestGetGraphEmbeddings:
assert received[0].entities == []
@pytest.mark.asyncio
- @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
- async def test_multiple_rows_each_emit_one_message(self, mock_async_execute):
+ @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
+ async def test_multiple_rows_each_emit_one_message(self, mock_async_execute_paged):
fake_rows = [
(None, None, None, [
(("http://example.org/a", True), [1.0]),
@@ -132,7 +132,7 @@ class TestGetGraphEmbeddings:
store = _make_store()
store.cassandra = Mock()
store.get_graph_embeddings_stmt = Mock()
- mock_async_execute.return_value = fake_rows
+ mock_async_execute_paged.return_value = [fake_rows]
received = []
@@ -153,8 +153,8 @@ class TestGetTriples:
the same Metadata construction. Cover it for parity."""
@pytest.mark.asyncio
- @patch('trustgraph.tables.knowledge.async_execute', new_callable=AsyncMock)
- async def test_row_converts_to_triples(self, mock_async_execute):
+ @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
+ async def test_row_converts_to_triples(self, mock_async_execute_paged):
# row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
fake_row = (
None, None, None,
@@ -170,7 +170,7 @@ class TestGetTriples:
store = _make_store()
store.cassandra = Mock()
store.get_triples_stmt = Mock()
- mock_async_execute.return_value = [fake_row]
+ mock_async_execute_paged.return_value = [[fake_row]]
received = []
diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py
index 9074bac1..0190d3f5 100644
--- a/trustgraph-base/trustgraph/api/api.py
+++ b/trustgraph-base/trustgraph/api/api.py
@@ -337,7 +337,7 @@ class Api:
from . bulk_client import BulkClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
- self._bulk_client = BulkClient(base_url, self.timeout, self.token)
+ self._bulk_client = BulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
return self._bulk_client
def metrics(self):
@@ -462,7 +462,7 @@ class Api:
from . async_bulk_client import AsyncBulkClient
# Extract base URL (remove api/v1/ suffix)
base_url = self.url.rsplit("api/v1/", 1)[0].rstrip("/")
- self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token)
+ self._async_bulk_client = AsyncBulkClient(base_url, self.timeout, self.token, workspace=self.workspace)
return self._async_bulk_client
def async_metrics(self):
diff --git a/trustgraph-base/trustgraph/api/async_bulk_client.py b/trustgraph-base/trustgraph/api/async_bulk_client.py
index 9a6a49c3..f93ab667 100644
--- a/trustgraph-base/trustgraph/api/async_bulk_client.py
+++ b/trustgraph-base/trustgraph/api/async_bulk_client.py
@@ -9,10 +9,11 @@ from . types import Triple
class AsyncBulkClient:
"""Asynchronous bulk operations client"""
- def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
+ def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
+ self.workspace: str = workspace
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
@@ -25,11 +26,21 @@ class AsyncBulkClient:
else:
return f"ws://{url}"
+ def _build_ws_url(self, path: str) -> str:
+ """Build a WebSocket URL with token and workspace query params."""
+ ws_url = f"{self.url}{path}"
+ params = []
+ if self.token:
+ params.append(f"token={self.token}")
+ if self.workspace:
+ params.append(f"workspace={self.workspace}")
+ if params:
+ ws_url = f"{ws_url}?{'&'.join(params)}"
+ return ws_url
+
async def import_triples(self, flow: str, triples: AsyncIterator[Triple], **kwargs: Any) -> None:
"""Bulk import triples via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for triple in triples:
@@ -42,9 +53,7 @@ class AsyncBulkClient:
async def export_triples(self, flow: str, **kwargs: Any) -> AsyncIterator[Triple]:
"""Bulk export triples via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -57,9 +66,7 @@ class AsyncBulkClient:
async def import_graph_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import graph embeddings via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for embedding in embeddings:
@@ -67,9 +74,7 @@ class AsyncBulkClient:
async def export_graph_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export graph embeddings via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -77,9 +82,7 @@ class AsyncBulkClient:
async def import_document_embeddings(self, flow: str, embeddings: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import document embeddings via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for embedding in embeddings:
@@ -87,9 +90,7 @@ class AsyncBulkClient:
async def export_document_embeddings(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export document embeddings via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -97,9 +98,7 @@ class AsyncBulkClient:
async def import_entity_contexts(self, flow: str, contexts: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import entity contexts via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for context in contexts:
@@ -107,9 +106,7 @@ class AsyncBulkClient:
async def export_entity_contexts(self, flow: str, **kwargs: Any) -> AsyncIterator[Dict[str, Any]]:
"""Bulk export entity contexts via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -117,9 +114,7 @@ class AsyncBulkClient:
async def import_rows(self, flow: str, rows: AsyncIterator[Dict[str, Any]], **kwargs: Any) -> None:
"""Bulk import rows via WebSocket"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for row in rows:
diff --git a/trustgraph-base/trustgraph/api/bulk_client.py b/trustgraph-base/trustgraph/api/bulk_client.py
index 0e49fc4e..ae185240 100644
--- a/trustgraph-base/trustgraph/api/bulk_client.py
+++ b/trustgraph-base/trustgraph/api/bulk_client.py
@@ -34,7 +34,7 @@ class BulkClient:
Note: For true async support, use AsyncBulkClient instead.
"""
- def __init__(self, url: str, timeout: int, token: Optional[str]) -> None:
+ def __init__(self, url: str, timeout: int, token: Optional[str], workspace: str = "default") -> None:
"""
Initialize synchronous bulk client.
@@ -42,10 +42,12 @@ class BulkClient:
url: Base URL for TrustGraph API (HTTP/HTTPS will be converted to WS/WSS)
timeout: WebSocket timeout in seconds
token: Optional bearer token for authentication
+ workspace: Workspace for data isolation
"""
self.url: str = self._convert_to_ws_url(url)
self.timeout: int = timeout
self.token: Optional[str] = token
+ self.workspace: str = workspace
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
@@ -58,6 +60,18 @@ class BulkClient:
else:
return f"ws://{url}"
+ def _build_ws_url(self, path: str) -> str:
+ """Build a WebSocket URL with token and workspace query params."""
+ ws_url = f"{self.url}{path}"
+ params = []
+ if self.token:
+ params.append(f"token={self.token}")
+ if self.workspace:
+ params.append(f"workspace={self.workspace}")
+ if params:
+ ws_url = f"{ws_url}?{'&'.join(params)}"
+ return ws_url
+
def _run_async(self, coro: Coroutine[Any, Any, Any]) -> Any:
"""Run async coroutine synchronously"""
try:
@@ -116,9 +130,7 @@ class BulkClient:
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of triple import"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/triples"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/triples")
if metadata is None:
metadata = {"id": "", "metadata": [], "collection": "default"}
@@ -194,9 +206,7 @@ class BulkClient:
async def _export_triples_async(self, flow: str) -> Iterator[Triple]:
"""Async implementation of triple export"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/triples"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/triples")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -238,9 +248,7 @@ class BulkClient:
async def _import_graph_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of graph embeddings import"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/graph-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
@@ -296,9 +304,7 @@ class BulkClient:
async def _export_graph_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of graph embeddings export"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/graph-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/graph-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -336,9 +342,7 @@ class BulkClient:
async def _import_document_embeddings_async(self, flow: str, embeddings: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of document embeddings import"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/document-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for embedding in embeddings:
@@ -394,9 +398,7 @@ class BulkClient:
async def _export_document_embeddings_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of document embeddings export"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/document-embeddings"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/document-embeddings")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -446,9 +448,7 @@ class BulkClient:
metadata: Optional[Dict[str, Any]], batch_size: int
) -> None:
"""Async implementation of entity contexts import"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/entity-contexts"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/entity-contexts")
if metadata is None:
metadata = {"id": "", "metadata": [], "collection": "default"}
@@ -522,9 +522,7 @@ class BulkClient:
async def _export_entity_contexts_async(self, flow: str) -> Iterator[Dict[str, Any]]:
"""Async implementation of entity contexts export"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/export/entity-contexts"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/export/entity-contexts")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
async for raw_message in websocket:
@@ -562,9 +560,7 @@ class BulkClient:
async def _import_rows_async(self, flow: str, rows: Iterator[Dict[str, Any]]) -> None:
"""Async implementation of rows import"""
- ws_url = f"{self.url}/api/v1/flow/{flow}/import/rows"
- if self.token:
- ws_url = f"{ws_url}?token={self.token}"
+ ws_url = self._build_ws_url(f"/api/v1/flow/{flow}/import/rows")
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
for row in rows:
diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py
index 6eeb95ff..b88d0c78 100644
--- a/trustgraph-base/trustgraph/api/socket_client.py
+++ b/trustgraph-base/trustgraph/api/socket_client.py
@@ -167,7 +167,8 @@ class SocketClient:
)
if resp.get("type") == "auth-ok":
- self.workspace = resp.get("workspace", self.workspace)
+ if self.workspace == "default":
+ self.workspace = resp.get("workspace", self.workspace)
elif resp.get("type") == "auth-failed":
await self._socket.close()
raise ProtocolException(
diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py
index f53ad73b..af6183db 100644
--- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py
+++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py
@@ -117,8 +117,10 @@ class SocketEndpoint:
running = Running()
+ params = dict(request.query)
+ params.update(request.match_info)
dispatcher = await self.dispatcher(
- ws, running, request.match_info
+ ws, running, params
)
worker_task = tg.create_task(
From 60f861bac431e2f0895de496b90c23a46375ab21 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Tue, 2 Jun 2026 14:49:24 +0100
Subject: [PATCH 08/18] Added an instance tag ID (#971)
---
trustgraph-base/trustgraph/base/logging.py | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/trustgraph-base/trustgraph/base/logging.py b/trustgraph-base/trustgraph/base/logging.py
index 9bf599b1..ff10c140 100644
--- a/trustgraph-base/trustgraph/base/logging.py
+++ b/trustgraph-base/trustgraph/base/logging.py
@@ -11,6 +11,7 @@ Supports dual output to console and Loki for centralized log aggregation.
import contextvars
import logging
import logging.handlers
+import uuid
from argparse import ArgumentParser
from queue import Queue
from typing import Any
@@ -132,14 +133,12 @@ def setup_logging(args: dict[str, Any]) -> None:
try:
from logging_loki import LokiHandler
- # Create Loki handler with optional authentication. The
- # processor label is NOT baked in here — it's stamped onto
- # each record by _ProcessorIdFilter reading the task-local
- # contextvar, and logging_loki's emitter reads record.tags
- # to build per-record Loki labels.
+ instance_id = str(uuid.uuid4())[:8]
+
loki_handler_kwargs = {
'url': loki_url,
'version': "1",
+ 'tags': {'instance': instance_id},
}
if loki_username and loki_password:
From aa158e1ba3c886f6fed731da5bb44a498a2f806f Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Wed, 3 Jun 2026 09:45:53 +0100
Subject: [PATCH 09/18] fix: skip authorise() for AUTHENTICATED/PUBLIC
sentinels in WebSocket mux (#972)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The mux unconditionally called auth.authorise() for every operation,
passing capability sentinels like AUTHENTICATED ("__authenticated__")
to the IAM regime. Since no role grants "__authenticated__", the regime
denied the request — breaking whoami (and any future AUTHENTICATED-only
operation) over the WebSocket path while the HTTP endpoints worked fine.
Match the guard pattern used by iam_endpoint.py and registry_endpoint.py:
only call authorise() for real capability strings, not sentinels.
---
.../trustgraph/gateway/dispatch/mux.py | 56 ++++++++++---------
1 file changed, 31 insertions(+), 25 deletions(-)
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
index 73bbb1f3..9b119f8e 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
@@ -4,6 +4,8 @@ import queue
import uuid
import logging
+from ..capabilities import PUBLIC, AUTHENTICATED
+
# Module logger
logger = logging.getLogger(__name__)
@@ -156,37 +158,41 @@ class Mux:
})
return
- # Resolve workspace first (default-fill from the caller's
- # bound workspace), then ask the regime to authorise the
- # service-level capability against the matched
- # operation's resource shape.
+ # Resolve workspace (default-fill from the caller's
+ # bound workspace). Workspace resolution applies to all
+ # operations regardless of capability level.
try:
await enforce_workspace(data, self.identity, self.auth)
if isinstance(inner, dict):
await enforce_workspace(inner, self.identity, self.auth)
- if data.get("flow"):
- resource = {
- "workspace": data.get("workspace", ""),
- "flow": data.get("flow", ""),
- }
- parameters = {}
- else:
- # Build a minimal RequestContext so the matched
- # operation's own extractors decide resource and
- # parameters — same path the HTTP endpoints take.
- from ..registry import RequestContext
- ctx = RequestContext(
- body=inner if isinstance(inner, dict) else {},
- match_info={},
- identity=self.identity,
- )
- resource = op.extract_resource(ctx)
- parameters = op.extract_parameters(ctx)
+ # Authorisation: capability sentinels short-circuit
+ # the regime call; capability strings go through
+ # authorise().
+ if op.capability not in (PUBLIC, AUTHENTICATED):
+ if data.get("flow"):
+ resource = {
+ "workspace": data.get("workspace", ""),
+ "flow": data.get("flow", ""),
+ }
+ parameters = {}
+ else:
+ # Build a minimal RequestContext so the matched
+ # operation's own extractors decide resource
+ # and parameters — same path the HTTP
+ # endpoints take.
+ from ..registry import RequestContext
+ ctx = RequestContext(
+ body=inner if isinstance(inner, dict) else {},
+ match_info={},
+ identity=self.identity,
+ )
+ resource = op.extract_resource(ctx)
+ parameters = op.extract_parameters(ctx)
- await self.auth.authorise(
- self.identity, op.capability, resource, parameters,
- )
+ await self.auth.authorise(
+ self.identity, op.capability, resource, parameters,
+ )
except _web.HTTPNotFound:
await self.ws.send_json({
"id": request_id,
From 6df7471a556df6e37adaedee923c240588755000 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Wed, 3 Jun 2026 10:46:52 +0100
Subject: [PATCH 10/18] =?UTF-8?q?feat:=20complete=20knowledge=20core=20sto?=
=?UTF-8?q?rage=20=E2=80=94=20named=20graphs,=20provenance,=20source=20mat?=
=?UTF-8?q?erial=20(#973)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Implements all three changes from the knowledge-core-completeness tech spec:
1. Named graph field preserved through Cassandra storage (7-element tuple),
enabling provenance triples to retain their graph URIs on round-trip.
2. Provenance triples already arrive on triples-input — no routing change
needed; Change 1 was sufficient.
3. Source material (library documents) streamed alongside triples and
embeddings during core download/upload. The knowledge manager fetches
the document hierarchy from the librarian on download and recreates it
on upload, preserving the full provenance chain across instances.
---
.../tech-specs/knowledge-core-completeness.md | 535 ++++++++++++++++++
.../unit/test_cores/test_knowledge_manager.py | 254 ++++++++-
.../test_tables/test_knowledge_table_store.py | 33 +-
.../test_knowledge_translator_roundtrip.py | 169 +++++-
.../trustgraph/api/socket_client.py | 5 +
.../messaging/translators/knowledge.py | 76 ++-
.../trustgraph/schema/knowledge/knowledge.py | 21 +
trustgraph-cli/trustgraph/cli/get_kg_core.py | 37 +-
trustgraph-cli/trustgraph/cli/put_kg_core.py | 29 +-
trustgraph-flow/trustgraph/cores/knowledge.py | 123 +++-
trustgraph-flow/trustgraph/cores/service.py | 7 +
.../gateway/dispatch/core_export.py | 33 ++
.../gateway/dispatch/core_import.py | 33 ++
.../trustgraph/tables/knowledge.py | 7 +-
14 files changed, 1347 insertions(+), 15 deletions(-)
create mode 100644 docs/tech-specs/knowledge-core-completeness.md
diff --git a/docs/tech-specs/knowledge-core-completeness.md b/docs/tech-specs/knowledge-core-completeness.md
new file mode 100644
index 00000000..3ccb41f0
--- /dev/null
+++ b/docs/tech-specs/knowledge-core-completeness.md
@@ -0,0 +1,535 @@
+---
+layout: default
+title: "Knowledge Core Completeness"
+parent: "Tech Specs"
+---
+
+# Knowledge Core Completeness
+
+## Overview
+
+Knowledge cores are portable snapshots of extracted knowledge: triples, graph
+embeddings, and document embeddings stored in Cassandra's `knowledge` keyspace.
+They can be downloaded as files, transferred between TrustGraph instances, and
+loaded back into vector and graph stores.
+
+Recent additions to TrustGraph — explainability/provenance and named graphs —
+were not carried through to the knowledge core system. This means that
+exporting and re-importing a core loses provenance links, graph assignments,
+and source material, breaking the explainability chain.
+
+This specification addresses three gaps:
+
+1. **Named graphs not stored** — The `g` (graph name) field on triples is
+ silently dropped when writing to the core store and comes back as `None`
+ on read.
+2. **Provenance triples not captured** — Provenance triples (PROV-O) are
+ generated during extraction and flow to graph stores, but never enter
+ the knowledge core store. It is unclear whether they arrive at the store
+ in the correct form.
+3. **Source material not included** — Documents, text pages, and chunks in
+ the librarian's bucket store are not part of the core. After loading a
+ core on a different instance, provenance links to source material point
+ at nothing.
+
+## Goals
+
+- **Self-contained cores**: A downloaded knowledge core file contains
+ everything needed to reconstruct the full knowledge graph including
+ provenance and source attribution on a fresh instance.
+- **Named graph preservation**: Round-tripping a core preserves graph
+ assignments on all triples.
+- **Backward compatibility**: Existing core files (without graph names or
+ source material) can still be uploaded and loaded. New fields are optional
+ on import.
+- **No change to core identity**: A core is still identified by its document
+ ID. The additional data is associated with the same core ID.
+- **Minimal file format changes**: Extend the existing msgpack record format
+ with new record types rather than restructuring existing ones.
+
+## Background
+
+### Current Lifecycle
+
+```
+Extraction pipeline
+ │
+ ├─ triples ──────────────────► knowledge core store (Cassandra)
+ ├─ graph embeddings ─────────► knowledge core store (Cassandra)
+ ├─ document embeddings ──────► knowledge core store (Cassandra)
+ ├─ provenance triples ───────► graph store (only)
+ └─ source documents ─────────► librarian bucket store (only)
+
+Download: Cassandra ──► knowledge manager ──► API gateway ──► client file
+Upload: client file ──► API gateway ──► knowledge manager ──► Cassandra
+Load: Cassandra ──► knowledge manager ──► Pulsar topics ──► graph/vector stores
+```
+
+### Current Core File Format (msgpack)
+
+A core file is a sequence of concatenated msgpack records. Each record is a
+2-element tuple: `(type_tag, payload)`.
+
+| Type tag | Payload | Description |
+|----------|---------|-------------|
+| `"t"` | `{"m": {id, root, collection}, "t": [triple_dicts]}` | Triple batch |
+| `"ge"` | `{"m": {id, root, collection}, "e": [{entity, vector}]}` | Graph embedding batch |
+
+### What's Missing
+
+#### Named Graphs
+
+The `Triple` dataclass has a `g: str | None` field (graph name IRI), used to
+separate provenance graphs (`urn:graph:source`, `urn:graph:retrieval`) from
+the default graph. However:
+
+- **Cassandra schema** (`knowledge.triples` table): stores a 6-tuple per
+ triple `(s_val, s_is_uri, p_val, p_is_uri, o_val, o_is_uri)` — no graph
+ field.
+- **`add_triples()`** (`tables/knowledge.py:231`): destructures only `s`,
+ `p`, `o` — `g` is discarded.
+- **`get_triples()`** (`tables/knowledge.py:396`): reconstructs `Triple`
+ with `g` defaulting to `None`.
+- **Core file format**: triple dicts do not include a graph field.
+
+#### Provenance Triples
+
+Provenance triples are generated in the extraction pipeline
+(`trustgraph-base/trustgraph/provenance/triples.py`) and published to graph
+store topics. They use named graphs (`urn:graph:source`,
+`urn:graph:retrieval`) and PROV-O vocabulary.
+
+The knowledge core store processor (`storage/knowledge/store.py`) listens on
+`triples-input` and `graph-embeddings-input`. Whether provenance triples
+arrive on the same `triples-input` topic or a separate one needs
+verification. Even if they do arrive, the graph name would be lost (per
+above).
+
+#### Source Material
+
+The librarian stores the full document hierarchy in a separate system:
+
+- **Blob store** (S3/MinIO): original documents, text pages, chunks —
+ keyed by object UUID under `doc/{object_id}`.
+- **Cassandra `library` keyspace**: document metadata including `id`,
+ `kind` (MIME type), `title`, `parent_id`, `document_type`
+ (`source`/`extracted`), `object_id` (blob reference).
+
+Provenance triples link extracted facts back to chunk/page/document IDs.
+Those IDs resolve through the librarian. When a core is loaded on a
+different instance, the librarian has no matching documents, so the entire
+provenance chain is broken.
+
+### Key Source Files
+
+| Component | File | Purpose |
+|-----------|------|---------|
+| Core Cassandra schema | `trustgraph-flow/trustgraph/tables/knowledge.py` | Table definitions, read/write |
+| Core manager | `trustgraph-flow/trustgraph/cores/knowledge.py` | API operations, load-to-store |
+| Core store processor | `trustgraph-flow/trustgraph/storage/knowledge/store.py` | Extraction → Cassandra |
+| CLI download | `trustgraph-cli/trustgraph/cli/get_kg_core.py` | Core → msgpack file |
+| CLI upload | `trustgraph-cli/trustgraph/cli/put_kg_core.py` | Msgpack file → core |
+| CLI load | `trustgraph-cli/trustgraph/cli/load_kg_core.py` | Core → graph/vector stores |
+| API client | `trustgraph-base/trustgraph/api/knowledge.py` | Client-side knowledge API |
+| Triple schema | `trustgraph-base/trustgraph/schema/core/primitives.py` | Triple dataclass with `g` field |
+| Provenance generation | `trustgraph-base/trustgraph/provenance/triples.py` | PROV-O triple creation |
+| Librarian | `trustgraph-flow/trustgraph/librarian/librarian.py` | Document storage service |
+| Library tables | `trustgraph-flow/trustgraph/tables/library.py` | Document metadata in Cassandra |
+| Blob store | `trustgraph-flow/trustgraph/librarian/blob_store.py` | S3/MinIO object storage |
+
+## Technical Design
+
+### Change 1: Named Graph Field in Core Storage
+
+#### Cassandra Schema
+
+Extend the `triples` tuple from 6 to 7 elements, adding the graph name:
+
+```
+triples list>
+```
+
+**Migration**: The schema change uses `ALTER TABLE` or is handled by
+creating a new table version. Existing rows with 6-element tuples must be
+handled gracefully on read — if the tuple has 6 elements, treat graph as
+default.
+
+#### Write Path (`add_triples`)
+
+Change `tables/knowledge.py:add_triples()` to include `triple.g`:
+
+```python
+triples = [
+ (
+ *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o),
+ v.g or ""
+ )
+ for v in m.triples
+]
+```
+
+#### Read Path (`get_triples`)
+
+Change `tables/knowledge.py:get_triples()` to restore the graph name:
+
+```python
+Triple(
+ s = tuple_to_term(elt[0], elt[1]),
+ p = tuple_to_term(elt[2], elt[3]),
+ o = tuple_to_term(elt[4], elt[5]),
+ g = elt[6] if len(elt) > 6 and elt[6] else None,
+)
+```
+
+The `len(elt) > 6` guard provides backward compatibility with existing
+6-element rows.
+
+#### Core File Format
+
+Extend triple dicts in the `"t"` record to include the graph name:
+
+```python
+# In get_kg_core.py write_triple — each triple dict gains "g" key
+{"s": ..., "p": ..., "o": ..., "g": "urn:graph:source"}
+```
+
+On read (`put_kg_core.py`), treat missing `"g"` key as default graph for
+backward compatibility with old core files.
+
+### Change 2: Provenance Triples in Cores
+
+#### Investigation Required
+
+Before implementation, verify:
+
+1. Whether provenance triples arrive on the `triples-input` topic that the
+ knowledge core store processor already listens on.
+2. If not, which topic they use, and whether the store processor should
+ subscribe to it.
+
+#### If provenance triples already arrive at the store
+
+The only change needed is Change 1 (named graphs) — the provenance triples
+are already being stored, just without their graph name. Once graph names
+are preserved, provenance triples will round-trip correctly.
+
+#### If provenance triples do NOT arrive at the store
+
+Two options:
+
+**Option A — Route provenance to the existing store topic**: Configure the
+flow so provenance triples are published to the same `triples-input` topic.
+This is the simpler approach and keeps the store processor unchanged.
+
+**Option B — Add a subscription**: Add a new `ConsumerSpec` in the store
+processor for the provenance topic. This keeps provenance routing
+independent but adds complexity.
+
+Recommendation: Option A, unless there is a reason provenance triples are
+intentionally kept off the core store topic.
+
+### Change 3: Source Material in Cores
+
+This is the largest change. The goal is that when a core is loaded on a
+fresh instance, provenance links to source material resolve.
+
+#### Architecture
+
+Source material is **not stored in the knowledge core tables**. It lives in
+the librarian (Cassandra `library` keyspace + S3/MinIO blob store) and is
+fetched on demand via the librarian's existing service API.
+
+The knowledge manager acts as a **client of the librarian service** — it
+calls the librarian's request/response API over pub/sub to retrieve document
+metadata and content. It does not access the library's Cassandra tables or
+blob store directly.
+
+#### Transport
+
+The librarian's pub/sub API already handles chunking of large documents.
+This chunking is designed to be websocket-friendly, so library content
+flowing through the API gateway to external clients does not require
+re-chunking. The API gateway remains a transport layer.
+
+```
+Download:
+ Knowledge manager ──pub/sub──► Librarian (fetch metadata + content)
+ Knowledge manager ──pub/sub──► API gateway ──websocket──► Client
+
+Upload:
+ Client ──websocket──► API gateway ──pub/sub──► Knowledge manager
+ Knowledge manager ──pub/sub──► Librarian (store metadata + content)
+```
+
+#### What to Include
+
+The provenance chain links facts → chunks → pages → documents. For the
+chain to resolve, the core must include:
+
+1. **Document metadata** — the library record for each document in the
+ hierarchy (id, kind, title, parent_id, document_type, etc.)
+2. **Document content** — the blob data for each document (original file,
+ extracted text pages, text chunks)
+
+Including the full hierarchy is necessary because:
+- A user viewing provenance needs to traverse fact → chunk → page → document
+- The chunk text is needed to show what text a fact was extracted from
+- The page text provides broader context
+- The original document is needed for full source attribution
+
+#### Size Implications
+
+Source material will significantly increase core file sizes. A rough model:
+
+| Component | Typical size per document |
+|-----------|-------------------------|
+| Triples + embeddings (current) | 1-10 MB |
+| Chunk text (all chunks) | ~same as original document |
+| Page text (all pages) | ~same as original document |
+| Original document (PDF, etc.) | Varies widely (KB to hundreds of MB) |
+
+For a 10 MB PDF, the core could grow from ~5 MB to ~25 MB (original +
+derived text + existing data). For large document sets, cores could become
+very large.
+
+**Decision needed**: Whether to include original documents or just derived
+text (pages + chunks). Including only derived text still allows provenance
+display but loses the ability to serve the original file.
+
+#### New Core File Record Types
+
+Add new msgpack record types for library content:
+
+| Type tag | Payload | Description |
+|----------|---------|-------------|
+| `"lm"` | `{"id", "kind", "title", "parent_id", "document_type", "comments", "tags", "metadata"}` | Library document metadata |
+| `"lb"` | `{"id", "data"}` | Library document blob content (chunked by pub/sub layer) |
+
+These are emitted after the existing `"t"` and `"ge"` records during
+download and processed during upload.
+
+#### Download Path
+
+Extend `KnowledgeManager.get_kg_core()` to:
+
+1. Stream triples and graph embeddings from the core store (existing
+ behavior).
+2. Use the librarian service API to retrieve documents associated with
+ this core ID:
+ a. Fetch the root document metadata and content.
+ b. Use `list-children` to discover child documents (pages, chunks).
+ c. Recursively fetch metadata and content for each child.
+3. Stream each document as `"lm"` (metadata) and `"lb"` (content) records.
+
+The knowledge manager gains the librarian service as a pub/sub dependency.
+Large document content is chunked by the librarian's existing pub/sub
+transport — the knowledge manager receives and forwards these chunks without
+buffering the full blob in memory.
+
+#### Upload Path
+
+Extend `KnowledgeManager.put_kg_core()` to handle the new record types:
+
+1. For `"lm"` records: call the librarian service API to create/update
+ the document metadata.
+2. For `"lb"` records: call the librarian service API to store the
+ document content.
+
+Parent-child relationships are preserved because `parent_id` is stored in
+the metadata. Documents should be processed in hierarchy order (parent
+before child) to satisfy any ordering constraints.
+
+#### Load Path
+
+The load path (`_load_kg_core`) publishes triples and embeddings to Pulsar
+topics for ingestion into graph/vector stores. Source material does not need
+to flow through the load path — it is already in the librarian after the
+upload step and can be accessed directly by services that need it.
+
+No changes to the load path for source material.
+
+#### CLI Changes
+
+**`tg-get-kg-core`**: Add handling for `"lm"` and `"lb"` record types in
+the file writer.
+
+**`tg-put-kg-core`**: Add handling for `"lm"` and `"lb"` record types in
+the file reader. Send library records to the knowledge manager alongside
+triple/embedding records.
+
+#### Associating Documents with Cores
+
+The core ID is `metadata.root`, which is the root document ID from the
+librarian. This provides a natural join: the core's root document and all
+its children (pages, chunks) are the source material for that core.
+
+The librarian's `list-children` API provides the child documents. A
+recursive traversal from the root document collects the full hierarchy.
+
+### API Changes
+
+#### KnowledgeResponse Schema
+
+Add optional fields to `KnowledgeResponse` for library data:
+
+```python
+@dataclass
+class KnowledgeResponse:
+ error: Error | None = None
+ ids: list | None = None
+ eos: bool = False
+ triples: Triples | None = None
+ graph_embeddings: GraphEmbeddings | None = None
+ document_embeddings: DocumentEmbeddings | None = None
+ library_metadata: LibraryMetadata | None = None # new
+ library_blob: LibraryBlob | None = None # new
+```
+
+#### New Schema Types
+
+```python
+@dataclass
+class LibraryMetadata:
+ id: str
+ kind: str | None = None
+ title: str | None = None
+ parent_id: str | None = None
+ document_type: str | None = None
+ comments: str | None = None
+ tags: list[str] | None = None
+ metadata: list[Triple] | None = None
+
+@dataclass
+class LibraryBlob:
+ id: str
+ data: bytes
+```
+
+#### Socket API
+
+The existing streaming protocol for `get-kg-core` / `put-kg-core` carries
+these new fields naturally — responses already stream multiple record types.
+
+### Dependencies Between Changes
+
+```
+Change 1 (named graphs) ◄── Change 2 depends on this
+ │
+ └── Change 2 (provenance triples)
+ │
+ └── Change 3 (source material) is independent
+```
+
+Change 1 is a prerequisite for Change 2 (provenance triples use named
+graphs). Change 3 is independent and can be implemented in parallel.
+
+## Security Considerations
+
+- **Workspace isolation**: Core download/upload must respect workspace
+ boundaries. Source material from the librarian must only be included if
+ it belongs to the same workspace as the core. This is already enforced
+ by the existing workspace-scoped queries.
+- **Large blob transfer**: Streaming large documents through the API
+ is handled by the librarian's existing pub/sub chunking, which is
+ designed to be websocket-friendly. No additional chunking layer is
+ needed.
+- **Cross-instance trust**: When uploading a core from an external source,
+ the library content should be treated as untrusted input. Document
+ metadata and blob content should be validated before insertion.
+
+## Performance Considerations
+
+- **Core file size**: Including source material will significantly increase
+ core file sizes. Consider adding a flag to download/upload commands to
+ optionally exclude source material for use cases where only the knowledge
+ graph is needed.
+- **Streaming**: All paths already use streaming (paged Cassandra queries,
+ msgpack record-at-a-time). Library content should follow the same pattern.
+- **Cassandra schema migration**: Changing the tuple width in the `triples`
+ table requires careful handling. Cassandra frozen tuples cannot be altered
+ in place — a migration strategy is needed (see Migration Plan).
+
+## Testing Strategy
+
+- **Unit tests**: Triple round-trip with graph name (write → read →
+ verify `g` field preserved). Backward compatibility with 6-element tuples.
+- **Integration tests**: Full lifecycle — extract with provenance → download
+ core → upload to fresh instance → load → verify provenance chain resolves.
+- **File format tests**: Read old-format core files (no graph name, no
+ library records) and verify they load without error.
+- **Library inclusion tests**: Download core with source material → upload →
+ verify documents accessible through librarian.
+
+## Migration Plan
+
+### Cassandra Schema
+
+The `triples` table stores tuples in a `list>` column. Cassandra
+does not support altering the type of an existing column. Options:
+
+**Option A — New table**: Create a `triples_v2` table with the 7-element
+tuple. Migrate data from `triples` to `triples_v2`. The read path checks
+both tables during a transition period, then the old table is dropped.
+
+**Option B — Dual read**: Keep the existing table. The read path handles
+both 6-element and 7-element tuples by checking length. New writes use
+7-element tuples. This works if Cassandra accepts variable-length tuples in
+a list — **needs verification**.
+
+**Option C — Separate graph column**: Instead of extending the tuple, add a
+parallel `graphs list` column where `graphs[i]` corresponds to
+`triples[i]`. This avoids tuple migration entirely but requires keeping the
+two lists in sync.
+
+Recommendation: Verify Option B first (simplest). Fall back to Option A if
+Cassandra rejects mixed tuple lengths.
+
+### Core File Format
+
+Backward compatible by design:
+- Old files lack `"g"` in triple dicts and have no `"lm"`/`"lb"` records →
+ handled by defaults.
+- New files read by old code → old code ignores unknown record types (the
+ existing `read_message` raises on unknown types, so this needs a small
+ fix to skip unknown types gracefully).
+
+## Open Questions
+
+1. **Provenance topic routing**: Do provenance triples currently arrive at
+ the `triples-input` topic consumed by the knowledge core store? If not,
+ what topic are they on?
+
+2. **Include original documents?**: Should cores include the original
+ uploaded document (e.g. PDF), or only derived text (pages + chunks)?
+ Including originals makes cores fully self-contained but potentially
+ very large. Excluding them preserves provenance text display but loses
+ the ability to serve the original file.
+
+3. **Optional source material**: Should there be a flag on download/upload
+ to include or exclude source material? This would let users choose
+ between compact cores (knowledge only) and complete cores (knowledge +
+ sources).
+
+4. **Cassandra tuple migration**: Can Cassandra handle mixed-length tuples
+ in a `list>` column, or is a table migration required?
+
+5. **Document embedding cores**: DE cores are managed alongside KG cores.
+ Do they need the same treatment (source material inclusion)? The
+ document embeddings reference chunk IDs — the same provenance chain
+ applies.
+
+6. **Core versioning**: Should the core file include a version marker so
+ readers can distinguish old-format from new-format files without
+ trial-and-error parsing?
+
+## References
+
+- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md`
+- Query-time explainability: `docs/tech-specs/query-time-explainability.md`
+- Agent explainability: `docs/tech-specs/agent-explainability.md`
+- Data ownership model: `docs/tech-specs/data-ownership-model.md`
diff --git a/tests/unit/test_cores/test_knowledge_manager.py b/tests/unit/test_cores/test_knowledge_manager.py
index 8f73dcc6..7797c9be 100644
--- a/tests/unit/test_cores/test_knowledge_manager.py
+++ b/tests/unit/test_cores/test_knowledge_manager.py
@@ -11,7 +11,12 @@ from unittest.mock import AsyncMock, Mock, patch, MagicMock
from unittest.mock import call
from trustgraph.cores.knowledge import KnowledgeManager
-from trustgraph.schema import KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term, EntityEmbeddings, IRI, LITERAL
+from trustgraph.schema import (
+ KnowledgeResponse, Triples, GraphEmbeddings, Metadata, Triple, Term,
+ EntityEmbeddings, IRI, LITERAL,
+ LibraryMetadata, LibraryBlob,
+ LibrarianResponse, DocumentMetadata,
+)
@pytest.fixture
@@ -373,11 +378,252 @@ class TestKnowledgeManagerOtherMethods:
mock_respond = AsyncMock()
await knowledge_manager.delete_kg_core(mock_request, mock_respond, "test-user")
-
+
# Verify table store was called correctly
knowledge_manager.table_store.delete_kg_core.assert_called_once_with("test-user", "test-doc-id")
-
+
# Verify response
mock_respond.assert_called_once()
response = mock_respond.call_args[0][0]
- assert response.error is None
\ No newline at end of file
+ assert response.error is None
+
+
+class TestKnowledgeManagerLibraryDownload:
+ """Test get_kg_core streaming of library documents."""
+
+ @pytest.fixture
+ def manager_with_librarian(self, mock_flow_config):
+ with patch('trustgraph.cores.knowledge.KnowledgeTableStore'):
+ mock_librarian = AsyncMock()
+ manager = KnowledgeManager(
+ cassandra_host=["localhost"],
+ cassandra_username="test_user",
+ cassandra_password="test_pass",
+ keyspace="test_keyspace",
+ flow_config=mock_flow_config,
+ librarian=mock_librarian,
+ )
+ manager.table_store = AsyncMock()
+ return manager
+
+ @pytest.mark.asyncio
+ async def test_get_kg_core_streams_library_docs(self, manager_with_librarian):
+ mock_request = Mock()
+ mock_request.id = "root-doc"
+ mock_respond = AsyncMock()
+
+ manager_with_librarian.table_store.get_triples = AsyncMock()
+ manager_with_librarian.table_store.get_graph_embeddings = AsyncMock()
+
+ root_meta = DocumentMetadata(
+ id="root-doc", kind="application/pdf", title="Test PDF",
+ document_type="source",
+ )
+ child_meta = DocumentMetadata(
+ id="chunk-1", kind="text/plain", title="Chunk 1",
+ parent_id="root-doc", document_type="chunk",
+ )
+
+ manager_with_librarian.librarian.fetch_document_metadata.return_value = root_meta
+ manager_with_librarian.librarian.request.return_value = LibrarianResponse(
+ document_metadatas=[child_meta],
+ )
+ manager_with_librarian.librarian.fetch_document_content.side_effect = [
+ b"cm9vdCBjb250ZW50",
+ b"Y2h1bmsgY29udGVudA==",
+ ]
+
+ await manager_with_librarian.get_kg_core(
+ mock_request, mock_respond, "test-user"
+ )
+
+ responses = [c[0][0] for c in mock_respond.call_args_list]
+
+ lm_responses = [r for r in responses if r.library_metadata is not None]
+ lb_responses = [r for r in responses if r.library_blob is not None]
+ eos_responses = [r for r in responses if r.eos is True]
+
+ assert len(lm_responses) == 2
+ assert lm_responses[0].library_metadata.id == "root-doc"
+ assert lm_responses[0].library_metadata.document_type == "source"
+ assert lm_responses[1].library_metadata.id == "chunk-1"
+ assert lm_responses[1].library_metadata.parent_id == "root-doc"
+
+ assert len(lb_responses) == 2
+ assert lb_responses[0].library_blob.id == "root-doc"
+ assert lb_responses[0].library_blob.data == b"cm9vdCBjb250ZW50"
+ assert lb_responses[1].library_blob.id == "chunk-1"
+
+ assert len(eos_responses) == 1
+
+ @pytest.mark.asyncio
+ async def test_get_kg_core_no_librarian_skips_library(self, mock_flow_config):
+ with patch('trustgraph.cores.knowledge.KnowledgeTableStore'):
+ manager = KnowledgeManager(
+ cassandra_host=["localhost"],
+ cassandra_username="u", cassandra_password="p",
+ keyspace="ks", flow_config=mock_flow_config,
+ )
+ manager.table_store = AsyncMock()
+ manager.table_store.get_triples = AsyncMock()
+ manager.table_store.get_graph_embeddings = AsyncMock()
+
+ mock_request = Mock()
+ mock_request.id = "doc-1"
+ mock_respond = AsyncMock()
+
+ await manager.get_kg_core(mock_request, mock_respond, "w")
+
+ responses = [c[0][0] for c in mock_respond.call_args_list]
+ assert all(r.library_metadata is None for r in responses)
+ assert all(r.library_blob is None for r in responses)
+
+ @pytest.mark.asyncio
+ async def test_get_kg_core_librarian_metadata_failure_is_graceful(
+ self, manager_with_librarian,
+ ):
+ mock_request = Mock()
+ mock_request.id = "missing-doc"
+ mock_respond = AsyncMock()
+
+ manager_with_librarian.table_store.get_triples = AsyncMock()
+ manager_with_librarian.table_store.get_graph_embeddings = AsyncMock()
+ manager_with_librarian.librarian.fetch_document_metadata.side_effect = (
+ RuntimeError("not found")
+ )
+
+ await manager_with_librarian.get_kg_core(
+ mock_request, mock_respond, "test-user"
+ )
+
+ responses = [c[0][0] for c in mock_respond.call_args_list]
+ assert all(r.library_metadata is None for r in responses)
+ assert any(r.eos for r in responses)
+
+
+class TestKnowledgeManagerLibraryUpload:
+ """Test put_kg_core handling of library metadata and blob records."""
+
+ @pytest.fixture
+ def manager_with_librarian(self, mock_flow_config):
+ with patch('trustgraph.cores.knowledge.KnowledgeTableStore'):
+ mock_librarian = AsyncMock()
+ manager = KnowledgeManager(
+ cassandra_host=["localhost"],
+ cassandra_username="u", cassandra_password="p",
+ keyspace="ks", flow_config=mock_flow_config,
+ librarian=mock_librarian,
+ )
+ manager.table_store = AsyncMock()
+ return manager
+
+ @pytest.mark.asyncio
+ async def test_put_metadata_then_blob_calls_librarian(
+ self, manager_with_librarian,
+ ):
+ mock_respond = AsyncMock()
+ manager_with_librarian.librarian.request.return_value = LibrarianResponse()
+
+ # First call: metadata
+ req_meta = Mock()
+ req_meta.triples = None
+ req_meta.graph_embeddings = None
+ req_meta.library_metadata = LibraryMetadata(
+ id="doc-1", kind="application/pdf", title="Test",
+ document_type="source",
+ )
+ req_meta.library_blob = None
+ await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws")
+
+ # Metadata is buffered, librarian not called yet
+ manager_with_librarian.librarian.request.assert_not_called()
+
+ # Second call: blob
+ req_blob = Mock()
+ req_blob.triples = None
+ req_blob.graph_embeddings = None
+ req_blob.library_metadata = None
+ req_blob.library_blob = LibraryBlob(
+ id="doc-1", data=b"dGVzdA==",
+ )
+ await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
+
+ # Now librarian should have been called with add-document
+ manager_with_librarian.librarian.request.assert_called_once()
+ call_args = manager_with_librarian.librarian.request.call_args[0][0]
+ assert call_args.operation == "add-document"
+ assert call_args.document_metadata.id == "doc-1"
+ assert call_args.document_metadata.kind == "application/pdf"
+ assert call_args.content == b"dGVzdA=="
+
+ @pytest.mark.asyncio
+ async def test_put_child_document_uses_add_child_operation(
+ self, manager_with_librarian,
+ ):
+ mock_respond = AsyncMock()
+ manager_with_librarian.librarian.request.return_value = LibrarianResponse()
+
+ req_meta = Mock()
+ req_meta.triples = None
+ req_meta.graph_embeddings = None
+ req_meta.library_metadata = LibraryMetadata(
+ id="chunk-1", kind="text/plain", title="Chunk",
+ parent_id="doc-1", document_type="chunk",
+ )
+ req_meta.library_blob = None
+ await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws")
+
+ req_blob = Mock()
+ req_blob.triples = None
+ req_blob.graph_embeddings = None
+ req_blob.library_metadata = None
+ req_blob.library_blob = LibraryBlob(id="chunk-1", data=b"Y2h1bms=")
+ await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
+
+ call_args = manager_with_librarian.librarian.request.call_args[0][0]
+ assert call_args.operation == "add-child-document"
+ assert call_args.document_metadata.parent_id == "doc-1"
+
+ @pytest.mark.asyncio
+ async def test_put_blob_without_metadata_logs_warning(
+ self, manager_with_librarian,
+ ):
+ mock_respond = AsyncMock()
+
+ req_blob = Mock()
+ req_blob.triples = None
+ req_blob.graph_embeddings = None
+ req_blob.library_metadata = None
+ req_blob.library_blob = LibraryBlob(id="orphan", data=b"data")
+ await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
+
+ # Librarian should not be called for orphan blob
+ manager_with_librarian.librarian.request.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_put_existing_document_is_graceful(
+ self, manager_with_librarian,
+ ):
+ mock_respond = AsyncMock()
+ manager_with_librarian.librarian.request.side_effect = RuntimeError(
+ "Document already exists"
+ )
+
+ req_meta = Mock()
+ req_meta.triples = None
+ req_meta.graph_embeddings = None
+ req_meta.library_metadata = LibraryMetadata(
+ id="doc-1", kind="application/pdf", title="Test",
+ document_type="source",
+ )
+ req_meta.library_blob = None
+ await manager_with_librarian.put_kg_core(req_meta, mock_respond, "ws")
+
+ req_blob = Mock()
+ req_blob.triples = None
+ req_blob.graph_embeddings = None
+ req_blob.library_metadata = None
+ req_blob.library_blob = LibraryBlob(id="doc-1", data=b"data")
+ await manager_with_librarian.put_kg_core(req_blob, mock_respond, "ws")
+
+ # Should not raise — "already exists" is handled gracefully
\ No newline at end of file
diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py
index 9a0b55c4..2d058733 100644
--- a/tests/unit/test_tables/test_knowledge_table_store.py
+++ b/tests/unit/test_tables/test_knowledge_table_store.py
@@ -155,7 +155,7 @@ class TestGetTriples:
@pytest.mark.asyncio
@patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
async def test_row_converts_to_triples(self, mock_async_execute_paged):
- # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
+ # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri, graph)
fake_row = (
None, None, None,
[
@@ -163,6 +163,7 @@ class TestGetTriples:
"http://example.org/alice", True,
"http://example.org/knows", True,
"http://example.org/bob", True,
+ "urn:graph:source",
),
],
)
@@ -191,3 +192,33 @@ class TestGetTriples:
assert t.s.iri == "http://example.org/alice"
assert t.p.iri == "http://example.org/knows"
assert t.o.iri == "http://example.org/bob"
+ assert t.g == "urn:graph:source"
+
+ @pytest.mark.asyncio
+ @patch('trustgraph.tables.knowledge.async_execute_paged', new_callable=AsyncMock)
+ async def test_empty_graph_name_becomes_none(self, mock_async_execute_paged):
+ fake_row = (
+ None, None, None,
+ [
+ (
+ "http://example.org/alice", True,
+ "http://example.org/knows", True,
+ "http://example.org/bob", True,
+ "",
+ ),
+ ],
+ )
+
+ store = _make_store()
+ store.cassandra = Mock()
+ store.get_triples_stmt = Mock()
+ mock_async_execute_paged.return_value = [[fake_row]]
+
+ received = []
+
+ async def receiver(msg):
+ received.append(msg)
+
+ await store.get_triples("w", "d", receiver)
+
+ assert received[0].triples[0].g is None
diff --git a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
index 437b83c8..af128f23 100644
--- a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
+++ b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
@@ -1,5 +1,6 @@
"""
-Round-trip unit tests for KnowledgeRequestTranslator.
+Round-trip unit tests for KnowledgeRequestTranslator and
+KnowledgeResponseTranslator.
Regression coverage: a previous version of the decode side constructed
EntityEmbeddings(vectors=...) — the schema field is `vector` (singular),
@@ -15,9 +16,13 @@ Triples breaks the test.
import pytest
-from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator
+from trustgraph.messaging.translators.knowledge import (
+ KnowledgeRequestTranslator,
+ KnowledgeResponseTranslator,
+)
from trustgraph.schema import (
KnowledgeRequest,
+ KnowledgeResponse,
GraphEmbeddings,
EntityEmbeddings,
Triples,
@@ -25,6 +30,8 @@ from trustgraph.schema import (
Metadata,
Term,
IRI,
+ LibraryMetadata,
+ LibraryBlob,
)
@@ -145,3 +152,161 @@ class TestKnowledgeRequestTranslatorTriples:
assert t.s.iri == "http://example.org/alice"
assert t.p.iri == "http://example.org/knows"
assert t.o.iri == "http://example.org/bob"
+
+
+class TestKnowledgeRequestTranslatorLibrary:
+
+ def test_roundtrip_preserves_library_metadata(self, translator):
+ request = KnowledgeRequest(
+ operation="put-kg-core",
+ id="doc-1",
+ library_metadata=LibraryMetadata(
+ id="doc-1",
+ kind="application/pdf",
+ title="Test Document",
+ parent_id="",
+ document_type="source",
+ comments="test comments",
+ tags=["tag1", "tag2"],
+ ),
+ )
+
+ encoded = translator.encode(request)
+ assert "library-metadata" in encoded
+ lm = encoded["library-metadata"]
+ assert lm["id"] == "doc-1"
+ assert lm["kind"] == "application/pdf"
+ assert lm["title"] == "Test Document"
+ assert lm["parent-id"] == ""
+ assert lm["document-type"] == "source"
+ assert lm["comments"] == "test comments"
+ assert lm["tags"] == ["tag1", "tag2"]
+
+ decoded = translator.decode(encoded)
+ assert decoded.library_metadata is not None
+ assert decoded.library_metadata.id == "doc-1"
+ assert decoded.library_metadata.kind == "application/pdf"
+ assert decoded.library_metadata.title == "Test Document"
+ assert decoded.library_metadata.parent_id == ""
+ assert decoded.library_metadata.document_type == "source"
+ assert decoded.library_metadata.comments == "test comments"
+ assert decoded.library_metadata.tags == ["tag1", "tag2"]
+
+ def test_roundtrip_preserves_child_document_metadata(self, translator):
+ request = KnowledgeRequest(
+ operation="put-kg-core",
+ id="doc-1",
+ library_metadata=LibraryMetadata(
+ id="chunk-1",
+ kind="text/plain",
+ title="Chunk 1",
+ parent_id="doc-1",
+ document_type="chunk",
+ ),
+ )
+
+ encoded = translator.encode(request)
+ decoded = translator.decode(encoded)
+
+ assert decoded.library_metadata.parent_id == "doc-1"
+ assert decoded.library_metadata.document_type == "chunk"
+
+ def test_roundtrip_preserves_library_blob(self, translator):
+ request = KnowledgeRequest(
+ operation="put-kg-core",
+ id="doc-1",
+ library_blob=LibraryBlob(
+ id="doc-1",
+ data=b"SGVsbG8gV29ybGQ=",
+ ),
+ )
+
+ encoded = translator.encode(request)
+ assert "library-blob" in encoded
+ assert encoded["library-blob"]["id"] == "doc-1"
+ assert encoded["library-blob"]["data"] == "SGVsbG8gV29ybGQ="
+
+ decoded = translator.decode(encoded)
+ assert decoded.library_blob is not None
+ assert decoded.library_blob.id == "doc-1"
+ assert decoded.library_blob.data == "SGVsbG8gV29ybGQ="
+
+ def test_absent_library_fields_decode_as_none(self, translator):
+ decoded = translator.decode({
+ "operation": "get-kg-core",
+ "id": "doc-1",
+ })
+ assert decoded.library_metadata is None
+ assert decoded.library_blob is None
+
+
+class TestKnowledgeResponseTranslatorLibrary:
+
+ @pytest.fixture
+ def response_translator(self):
+ return KnowledgeResponseTranslator()
+
+ def test_encode_library_metadata(self, response_translator):
+ response = KnowledgeResponse(
+ ids=None,
+ library_metadata=LibraryMetadata(
+ id="doc-1",
+ kind="application/pdf",
+ title="Test",
+ parent_id="",
+ document_type="source",
+ comments="",
+ tags=[],
+ ),
+ )
+ encoded = response_translator.encode(response)
+ assert "library-metadata" in encoded
+ assert encoded["library-metadata"]["id"] == "doc-1"
+ assert encoded["library-metadata"]["kind"] == "application/pdf"
+ assert encoded["library-metadata"]["document-type"] == "source"
+
+ def test_encode_library_blob_bytes_to_string(self, response_translator):
+ response = KnowledgeResponse(
+ ids=None,
+ library_blob=LibraryBlob(
+ id="doc-1",
+ data=b"dGVzdCBkYXRh",
+ ),
+ )
+ encoded = response_translator.encode(response)
+ assert "library-blob" in encoded
+ assert encoded["library-blob"]["id"] == "doc-1"
+ assert encoded["library-blob"]["data"] == "dGVzdCBkYXRh"
+ assert isinstance(encoded["library-blob"]["data"], str)
+
+ def test_encode_library_blob_string_passthrough(self, response_translator):
+ response = KnowledgeResponse(
+ ids=None,
+ library_blob=LibraryBlob(
+ id="doc-1",
+ data="already-a-string",
+ ),
+ )
+ encoded = response_translator.encode(response)
+ assert encoded["library-blob"]["data"] == "already-a-string"
+
+ def test_library_metadata_is_not_final(self, response_translator):
+ response = KnowledgeResponse(
+ ids=None,
+ library_metadata=LibraryMetadata(id="doc-1"),
+ )
+ _, is_final = response_translator.encode_with_completion(response)
+ assert is_final is False
+
+ def test_library_blob_is_not_final(self, response_translator):
+ response = KnowledgeResponse(
+ ids=None,
+ library_blob=LibraryBlob(id="doc-1", data=b"data"),
+ )
+ _, is_final = response_translator.encode_with_completion(response)
+ assert is_final is False
+
+ def test_eos_is_final(self, response_translator):
+ response = KnowledgeResponse(eos=True)
+ _, is_final = response_translator.encode_with_completion(response)
+ assert is_final is True
diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py
index b88d0c78..91bc67a1 100644
--- a/trustgraph-base/trustgraph/api/socket_client.py
+++ b/trustgraph-base/trustgraph/api/socket_client.py
@@ -502,6 +502,7 @@ class SocketClient:
def put_kg_core(
self, id: str, triples=None, graph_embeddings=None,
+ library_metadata=None, library_blob=None,
) -> Dict[str, Any]:
request = {
"operation": "put-kg-core",
@@ -512,6 +513,10 @@ class SocketClient:
request["triples"] = triples
if graph_embeddings is not None:
request["graph-embeddings"] = graph_embeddings
+ if library_metadata is not None:
+ request["library-metadata"] = library_metadata
+ if library_blob is not None:
+ request["library-blob"] = library_blob
return self._send_request_sync("knowledge", None, request)
def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]:
diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py
index 3830bf59..3f09b41b 100644
--- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py
+++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py
@@ -2,7 +2,8 @@ from typing import Dict, Any, Tuple, Optional
from ...schema import (
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
DocumentEmbeddings, ChunkEmbeddings,
- Metadata, EntityEmbeddings
+ Metadata, EntityEmbeddings,
+ LibraryMetadata, LibraryBlob,
)
from .base import MessageTranslator
from .primitives import ValueTranslator, SubgraphTranslator
@@ -61,6 +62,27 @@ class KnowledgeRequestTranslator(MessageTranslator):
]
)
+ library_metadata = None
+ if "library-metadata" in data:
+ lm = data["library-metadata"]
+ library_metadata = LibraryMetadata(
+ id=lm.get("id", ""),
+ kind=lm.get("kind", ""),
+ title=lm.get("title", ""),
+ parent_id=lm.get("parent-id", ""),
+ document_type=lm.get("document-type", ""),
+ comments=lm.get("comments", ""),
+ tags=lm.get("tags", []),
+ )
+
+ library_blob = None
+ if "library-blob" in data:
+ lb = data["library-blob"]
+ library_blob = LibraryBlob(
+ id=lb.get("id", ""),
+ data=lb.get("data", b""),
+ )
+
return KnowledgeRequest(
operation=data.get("operation"),
id=data.get("id"),
@@ -69,6 +91,8 @@ class KnowledgeRequestTranslator(MessageTranslator):
triples=triples,
graph_embeddings=graph_embeddings,
document_embeddings=document_embeddings,
+ library_metadata=library_metadata,
+ library_blob=library_blob,
)
def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]:
@@ -125,6 +149,26 @@ class KnowledgeRequestTranslator(MessageTranslator):
],
}
+ if obj.library_metadata:
+ result["library-metadata"] = {
+ "id": obj.library_metadata.id,
+ "kind": obj.library_metadata.kind,
+ "title": obj.library_metadata.title,
+ "parent-id": obj.library_metadata.parent_id,
+ "document-type": obj.library_metadata.document_type,
+ "comments": obj.library_metadata.comments,
+ "tags": obj.library_metadata.tags,
+ }
+
+ if obj.library_blob:
+ data = obj.library_blob.data
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+ result["library-blob"] = {
+ "id": obj.library_blob.id,
+ "data": data,
+ }
+
return result
@@ -194,6 +238,32 @@ class KnowledgeResponseTranslator(MessageTranslator):
}
}
+ # Streaming library metadata response
+ if obj.library_metadata:
+ return {
+ "library-metadata": {
+ "id": obj.library_metadata.id,
+ "kind": obj.library_metadata.kind,
+ "title": obj.library_metadata.title,
+ "parent-id": obj.library_metadata.parent_id,
+ "document-type": obj.library_metadata.document_type,
+ "comments": obj.library_metadata.comments,
+ "tags": obj.library_metadata.tags,
+ }
+ }
+
+ # Streaming library blob response
+ if obj.library_blob:
+ data = obj.library_blob.data
+ if isinstance(data, bytes):
+ data = data.decode("utf-8")
+ return {
+ "library-blob": {
+ "id": obj.library_blob.id,
+ "data": data,
+ }
+ }
+
# End of stream marker
if obj.eos is True:
return {"eos": True}
@@ -209,7 +279,9 @@ class KnowledgeResponseTranslator(MessageTranslator):
is_final = (
obj.ids is not None or # List response
obj.eos is True or # End of stream
- (not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response
+ (not obj.triples and not obj.graph_embeddings
+ and not obj.document_embeddings
+ and not obj.library_metadata and not obj.library_blob) # Empty response
)
return response, is_final
\ No newline at end of file
diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py
index a3879103..4353065b 100644
--- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py
+++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py
@@ -21,6 +21,21 @@ from .embeddings import GraphEmbeddings, DocumentEmbeddings
# <- ()
# <- (error)
+@dataclass
+class LibraryMetadata:
+ id: str = ""
+ kind: str = ""
+ title: str = ""
+ parent_id: str = ""
+ document_type: str = ""
+ comments: str = ""
+ tags: list[str] = field(default_factory=list)
+
+@dataclass
+class LibraryBlob:
+ id: str = ""
+ data: bytes = b""
+
@dataclass
class KnowledgeRequest:
# get-kg-core, delete-kg-core, list-kg-cores, put-kg-core
@@ -44,6 +59,10 @@ class KnowledgeRequest:
# put-de-core
document_embeddings: DocumentEmbeddings | None = None
+ # put-kg-core (source material)
+ library_metadata: LibraryMetadata | None = None
+ library_blob: LibraryBlob | None = None
+
@dataclass
class KnowledgeResponse:
error: Error | None = None
@@ -52,6 +71,8 @@ class KnowledgeResponse:
triples: Triples | None = None
graph_embeddings: GraphEmbeddings | None = None
document_embeddings: DocumentEmbeddings | None = None
+ library_metadata: LibraryMetadata | None = None
+ library_blob: LibraryBlob | None = None
knowledge_request_queue = queue('knowledge', cls='request')
knowledge_response_queue = queue('knowledge', cls='response')
diff --git a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py
index b4f37b81..2ff1a3cc 100644
--- a/trustgraph-cli/trustgraph/cli/get_kg_core.py
+++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py
@@ -47,6 +47,31 @@ def write_ge(f, data):
)
f.write(msgpack.packb(msg, use_bin_type=True))
+def write_library_metadata(f, data):
+ msg = (
+ "lm",
+ {
+ "i": data["id"],
+ "k": data.get("kind", ""),
+ "t": data.get("title", ""),
+ "p": data.get("parent-id", ""),
+ "d": data.get("document-type", ""),
+ "c": data.get("comments", ""),
+ "g": data.get("tags", []),
+ }
+ )
+ f.write(msgpack.packb(msg, use_bin_type=True))
+
+def write_library_blob(f, data):
+ msg = (
+ "lb",
+ {
+ "i": data["id"],
+ "d": data.get("data", b""),
+ }
+ )
+ f.write(msgpack.packb(msg, use_bin_type=True))
+
def fetch(url, workspace, id, output, token=None):
api = Api(url=url, token=token, workspace=workspace)
@@ -55,6 +80,8 @@ def fetch(url, workspace, id, output, token=None):
try:
ge = 0
t = 0
+ lm = 0
+ lb = 0
with open(output, "wb") as f:
@@ -68,7 +95,15 @@ def fetch(url, workspace, id, output, token=None):
ge += 1
write_ge(f, response["graph-embeddings"])
- print(f"Got: {t} triple, {ge} GE messages.")
+ if "library-metadata" in response:
+ lm += 1
+ write_library_metadata(f, response["library-metadata"])
+
+ if "library-blob" in response:
+ lb += 1
+ write_library_blob(f, response["library-blob"])
+
+ print(f"Got: {t} triple, {ge} GE, {lm} library metadata, {lb} library blob messages.")
finally:
socket.close()
diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py
index fe0981a5..f4e0b3dd 100644
--- a/trustgraph-cli/trustgraph/cli/put_kg_core.py
+++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py
@@ -40,6 +40,23 @@ def read_message(unpacked, id):
},
"triples": msg["t"],
}
+ elif unpacked[0] == "lm":
+ msg = unpacked[1]
+ return "lm", {
+ "id": msg["i"],
+ "kind": msg.get("k", ""),
+ "title": msg.get("t", ""),
+ "parent-id": msg.get("p", ""),
+ "document-type": msg.get("d", ""),
+ "comments": msg.get("c", ""),
+ "tags": msg.get("g", []),
+ }
+ elif unpacked[0] == "lb":
+ msg = unpacked[1]
+ return "lb", {
+ "id": msg["i"],
+ "data": msg.get("d", b""),
+ }
else:
raise RuntimeError("Unpacked unexpected messsage type", unpacked[0])
@@ -51,6 +68,8 @@ def put(url, workspace, id, input, token=None):
try:
ge = 0
t = 0
+ lm = 0
+ lb = 0
with open(input, "rb") as f:
@@ -73,10 +92,18 @@ def put(url, workspace, id, input, token=None):
t += 1
socket.put_kg_core(id, triples=msg)
+ elif kind == "lm":
+ lm += 1
+ socket.put_kg_core(id, library_metadata=msg)
+
+ elif kind == "lb":
+ lb += 1
+ socket.put_kg_core(id, library_blob=msg)
+
else:
raise RuntimeError("Unexpected message kind", kind)
- print(f"Put: {t} triple, {ge} GE messages.")
+ print(f"Put: {t} triple, {ge} GE, {lm} library metadata, {lb} library blob messages.")
finally:
socket.close()
diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py
index f1fa53f5..6f017c43 100644
--- a/trustgraph-flow/trustgraph/cores/knowledge.py
+++ b/trustgraph-flow/trustgraph/cores/knowledge.py
@@ -1,6 +1,7 @@
from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings
-from .. schema import DocumentEmbeddings
+from .. schema import DocumentEmbeddings, LibraryMetadata, LibraryBlob
+from .. schema import LibrarianRequest, DocumentMetadata
from .. knowledge import hash
from .. exceptions import RequestError
from .. tables.knowledge import KnowledgeTableStore
@@ -18,7 +19,7 @@ class KnowledgeManager:
def __init__(
self, cassandra_host, cassandra_username, cassandra_password,
- keyspace, flow_config, replication_factor=1,
+ keyspace, flow_config, librarian=None, replication_factor=1,
):
self.table_store = KnowledgeTableStore(
@@ -26,6 +27,9 @@ class KnowledgeManager:
replication_factor
)
+ self.librarian = librarian
+ self._pending_library_metadata = {}
+
self.loader_queue = asyncio.Queue(maxsize=20)
self.background_task = None
self.flow_config = flow_config
@@ -86,6 +90,9 @@ class KnowledgeManager:
publish_ge,
)
+ if self.librarian:
+ await self._stream_library_docs(request.id, respond)
+
logger.debug("Knowledge core retrieval complete")
await respond(
@@ -122,6 +129,12 @@ class KnowledgeManager:
workspace, request.graph_embeddings
)
+ if request.library_metadata and self.librarian:
+ await self._put_library_metadata(request.library_metadata, workspace)
+
+ if request.library_blob and self.librarian:
+ await self._put_library_blob(request.library_blob, workspace)
+
await respond(
KnowledgeResponse(
error = None,
@@ -250,6 +263,112 @@ class KnowledgeManager:
await self.loader_queue.put((request, respond, workspace))
+ async def _stream_library_docs(self, document_id, respond):
+
+ try:
+ root_meta = await self.librarian.fetch_document_metadata(
+ document_id
+ )
+ except Exception as e:
+ logger.warning(f"Could not fetch library metadata for {document_id}: {e}")
+ return
+
+ if root_meta is None:
+ return
+
+ await self._stream_one_doc(root_meta, respond)
+
+ try:
+ resp = await self.librarian.request(
+ LibrarianRequest(
+ operation="list-children",
+ document_id=document_id,
+ )
+ )
+ except Exception as e:
+ logger.warning(f"Could not list children for {document_id}: {e}")
+ return
+
+ for child_meta in resp.document_metadatas:
+ await self._stream_one_doc(child_meta, respond)
+
+ async def _stream_one_doc(self, doc_meta, respond):
+
+ lm = LibraryMetadata(
+ id=doc_meta.id,
+ kind=doc_meta.kind,
+ title=doc_meta.title,
+ parent_id=doc_meta.parent_id,
+ document_type=doc_meta.document_type,
+ comments=doc_meta.comments,
+ tags=doc_meta.tags or [],
+ )
+
+ await respond(
+ KnowledgeResponse(library_metadata=lm)
+ )
+
+ try:
+ content = await self.librarian.fetch_document_content(
+ doc_meta.id
+ )
+ except Exception as e:
+ logger.warning(f"Could not fetch content for {doc_meta.id}: {e}")
+ return
+
+ await respond(
+ KnowledgeResponse(
+ library_blob=LibraryBlob(
+ id=doc_meta.id,
+ data=content,
+ )
+ )
+ )
+
+ async def _put_library_metadata(self, lm, workspace):
+ self._pending_library_metadata[lm.id] = lm
+
+ async def _put_library_blob(self, lb, workspace):
+
+ lm = self._pending_library_metadata.pop(lb.id, None)
+ if lm is None:
+ logger.warning(
+ f"Received library blob for {lb.id} with no preceding metadata"
+ )
+ return
+
+ doc_meta = DocumentMetadata(
+ id=lm.id,
+ kind=lm.kind,
+ title=lm.title,
+ parent_id=lm.parent_id,
+ document_type=lm.document_type,
+ comments=lm.comments,
+ tags=lm.tags or [],
+ )
+
+ if lm.parent_id:
+ operation = "add-child-document"
+ else:
+ operation = "add-document"
+
+ try:
+ await self.librarian.request(
+ LibrarianRequest(
+ operation=operation,
+ document_id=lm.id,
+ document_metadata=doc_meta,
+ content=lb.data,
+ )
+ )
+ except RuntimeError as e:
+ if "already exists" in str(e):
+ logger.debug(f"Library document {lm.id} already exists, skipping")
+ else:
+ logger.warning(f"Could not save library document {lm.id}: {e}")
+ except Exception as e:
+ logger.warning(f"Could not save library document {lm.id}: {e}")
+
async def core_loader(self):
logger.info("Knowledge background processor running...")
diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py
index a04e42ca..a8f52efd 100755
--- a/trustgraph-flow/trustgraph/cores/service.py
+++ b/trustgraph-flow/trustgraph/cores/service.py
@@ -12,6 +12,7 @@ import logging
from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
from .. base import ConsumerMetrics, ProducerMetrics
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
+from .. base import LibrarianClient
from .. schema import KnowledgeRequest, KnowledgeResponse, Error
from .. schema import knowledge_request_queue, knowledge_response_queue
@@ -77,12 +78,17 @@ class Processor(WorkspaceProcessor):
}
)
+ self.librarian_client = LibrarianClient(
+ id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+ )
+
self.knowledge = KnowledgeManager(
cassandra_host = self.cassandra_host,
cassandra_username = self.cassandra_username,
cassandra_password = self.cassandra_password,
keyspace = keyspace,
flow_config = self,
+ librarian = self.librarian_client,
replication_factor = replication_factor,
)
@@ -156,6 +162,7 @@ class Processor(WorkspaceProcessor):
async def start(self):
await super(Processor, self).start()
+ await self.librarian_client.start()
async def on_knowledge_config(self, workspace, config, version):
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
index 6696afbe..90080cc4 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
@@ -73,6 +73,39 @@ class CoreExport:
enc = msgpack.packb(msg)
await response.write(enc)
+ if "library-metadata" in resp:
+
+ data = resp["library-metadata"]
+ msg = (
+ "lm",
+ {
+ "i": data["id"],
+ "k": data.get("kind", ""),
+ "t": data.get("title", ""),
+ "p": data.get("parent-id", ""),
+ "d": data.get("document-type", ""),
+ "c": data.get("comments", ""),
+ "g": data.get("tags", []),
+ }
+ )
+
+ enc = msgpack.packb(msg)
+ await response.write(enc)
+
+ if "library-blob" in resp:
+
+ data = resp["library-blob"]
+ msg = (
+ "lb",
+ {
+ "i": data["id"],
+ "d": data.get("data", b""),
+ }
+ )
+
+ enc = msgpack.packb(msg, use_bin_type=True)
+ await response.write(enc)
+
await kr.process(
{
"operation": "get-kg-core",
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
index d03d4efd..bf660def 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
@@ -79,6 +79,39 @@ class CoreImport:
await kr.process(msg)
+ elif unpacked[0] == "lm":
+ msg = unpacked[1]
+ msg = {
+ "operation": "put-kg-core",
+ "workspace": workspace,
+ "id": id,
+ "library-metadata": {
+ "id": msg["i"],
+ "kind": msg.get("k", ""),
+ "title": msg.get("t", ""),
+ "parent-id": msg.get("p", ""),
+ "document-type": msg.get("d", ""),
+ "comments": msg.get("c", ""),
+ "tags": msg.get("g", []),
+ }
+ }
+
+ await kr.process(msg)
+
+ elif unpacked[0] == "lb":
+ msg = unpacked[1]
+ msg = {
+ "operation": "put-kg-core",
+ "workspace": workspace,
+ "id": id,
+ "library-blob": {
+ "id": msg["i"],
+ "data": msg.get("d", b""),
+ }
+ }
+
+ await kr.process(msg)
+
except Exception as e:
logger.error(f"Core import exception: {e}", exc_info=True)
await error(str(e))
diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py
index 6a23731b..4fcb2dd3 100644
--- a/trustgraph-flow/trustgraph/tables/knowledge.py
+++ b/trustgraph-flow/trustgraph/tables/knowledge.py
@@ -98,7 +98,8 @@ class KnowledgeTableStore:
text, boolean, text, boolean, text, boolean
>>,
triples list>,
PRIMARY KEY ((workspace, document_id), id)
);
@@ -234,7 +235,8 @@ class KnowledgeTableStore:
triples = [
(
- *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
+ *term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o),
+ v.g or ""
)
for v in m.triples
]
@@ -416,6 +418,7 @@ class KnowledgeTableStore:
s = tuple_to_term(elt[0], elt[1]),
p = tuple_to_term(elt[2], elt[3]),
o = tuple_to_term(elt[4], elt[5]),
+ g = elt[6] if elt[6] else None,
)
for elt in row[3]
]
From acf182c26541220793a3ad0923a52e6592d18f84 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Wed, 3 Jun 2026 10:59:58 +0100
Subject: [PATCH 11/18] feat: add env-var fallback for librarian object-store
config (#974)
The librarian now reads OBJECT_STORE_ENDPOINT, OBJECT_STORE_ACCESS_KEY,
OBJECT_STORE_SECRET_KEY, OBJECT_STORE_REGION, and OBJECT_STORE_USE_SSL
from the environment when not set via params. This lets K8s Secrets
supply credentials without them appearing in launch.yaml.
---
.../trustgraph/librarian/service.py | 51 ++++++++++++++-----
1 file changed, 38 insertions(+), 13 deletions(-)
diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py
index cc5efdae..ee5e9c1b 100755
--- a/trustgraph-flow/trustgraph/librarian/service.py
+++ b/trustgraph-flow/trustgraph/librarian/service.py
@@ -8,6 +8,7 @@ import asyncio
import base64
import json
import logging
+import os
from datetime import datetime
from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
@@ -54,6 +55,16 @@ default_object_store_access_key = "object-user"
default_object_store_secret_key = "object-password"
default_object_store_use_ssl = False
default_object_store_region = None
+
+# Environment variables consulted as a fallback when the
+# corresponding params field is not set in the processor-group YAML
+# or via CLI. Intended for K8s Secret / env-var injection so
+# credentials never have to live in the YAML (and thus in git).
+ENV_OBJECT_STORE_ENDPOINT = "OBJECT_STORE_ENDPOINT"
+ENV_OBJECT_STORE_ACCESS_KEY = "OBJECT_STORE_ACCESS_KEY"
+ENV_OBJECT_STORE_SECRET_KEY = "OBJECT_STORE_SECRET_KEY"
+ENV_OBJECT_STORE_USE_SSL = "OBJECT_STORE_USE_SSL"
+ENV_OBJECT_STORE_REGION = "OBJECT_STORE_REGION"
default_cassandra_host = "cassandra"
default_min_chunk_size = 1 # No minimum by default (for Garage)
@@ -89,22 +100,36 @@ class Processor(WorkspaceProcessor):
"config_response_queue", default_config_response_queue
)
- object_store_endpoint = params.get("object_store_endpoint", default_object_store_endpoint)
- object_store_access_key = params.get(
- "object_store_access_key",
- default_object_store_access_key
+ # Resolve object-store config. Precedence: explicit params
+ # (CLI / processor-group YAML) → environment variable →
+ # hardcoded default. The env-var path lets K8s Secrets feed
+ # credentials without them appearing in the YAML.
+ object_store_endpoint = (
+ params.get("object_store_endpoint")
+ or os.environ.get(ENV_OBJECT_STORE_ENDPOINT)
+ or default_object_store_endpoint
)
- object_store_secret_key = params.get(
- "object_store_secret_key",
- default_object_store_secret_key
+ object_store_access_key = (
+ params.get("object_store_access_key")
+ or os.environ.get(ENV_OBJECT_STORE_ACCESS_KEY)
+ or default_object_store_access_key
)
- object_store_use_ssl = params.get(
- "object_store_use_ssl",
- default_object_store_use_ssl
+ object_store_secret_key = (
+ params.get("object_store_secret_key")
+ or os.environ.get(ENV_OBJECT_STORE_SECRET_KEY)
+ or default_object_store_secret_key
)
- object_store_region = params.get(
- "object_store_region",
- default_object_store_region
+ object_store_use_ssl = params.get("object_store_use_ssl")
+ if object_store_use_ssl is None:
+ env_ssl = os.environ.get(ENV_OBJECT_STORE_USE_SSL)
+ if env_ssl is not None:
+ object_store_use_ssl = env_ssl.lower() in ("true", "1", "yes")
+ else:
+ object_store_use_ssl = default_object_store_use_ssl
+ object_store_region = (
+ params.get("object_store_region")
+ or os.environ.get(ENV_OBJECT_STORE_REGION)
+ or default_object_store_region
)
min_chunk_size = params.get(
From 4913f8c2eb4df9edd48c25bf43af6c240d16cd30 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Thu, 4 Jun 2026 11:49:29 +0100
Subject: [PATCH 12/18] feat: data store replication configuration and TLS
upgrade (#975)
- Add centralised qdrant_config.py helper with env-var fallback for
QDRANT_URL, QDRANT_API_KEY, QDRANT_REPLICATION_FACTOR, QDRANT_SHARD_NUMBER
- Update all 6 Qdrant processors to use the helper; writers pass
replication_factor and shard_number to create_collection
- Fix hardcoded Cassandra replication_factor=1 in cassandra_kg.py,
write.py, and sparql_cassandra.py to respect CASSANDRA_REPLICATION_FACTOR
- Upgrade Cassandra TLS from deprecated PROTOCOL_TLSv1_2 to
ssl.create_default_context() across all connectors
---
.../test_null_embedding_protection.py | 2 +
.../trustgraph/base/qdrant_config.py | 87 +++++++++++++++++++
.../trustgraph/direct/cassandra_kg.py | 18 ++--
.../query/doc_embeddings/qdrant/service.py | 28 ++----
.../query/graph_embeddings/qdrant/service.py | 28 ++----
.../query/ontology/sparql_cassandra.py | 2 +-
.../query/row_embeddings/qdrant/service.py | 28 +++---
.../storage/doc_embeddings/qdrant/write.py | 32 +++----
.../storage/graph_embeddings/qdrant/write.py | 33 +++----
.../storage/row_embeddings/qdrant/write.py | 32 +++----
.../storage/rows/cassandra/write.py | 5 +-
trustgraph-flow/trustgraph/tables/config.py | 4 +-
trustgraph-flow/trustgraph/tables/iam.py | 4 +-
.../trustgraph/tables/knowledge.py | 4 +-
trustgraph-flow/trustgraph/tables/library.py | 4 +-
15 files changed, 182 insertions(+), 129 deletions(-)
create mode 100644 trustgraph-base/trustgraph/base/qdrant_config.py
diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py
index dbe06b40..41d0f88b 100644
--- a/tests/unit/test_reliability/test_null_embedding_protection.py
+++ b/tests/unit/test_reliability/test_null_embedding_protection.py
@@ -259,6 +259,8 @@ class TestGraphEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
+ proc.replication_factor = 1
+ proc.shard_number = 1
msg = MagicMock()
msg.metadata.collection = "graphs"
diff --git a/trustgraph-base/trustgraph/base/qdrant_config.py b/trustgraph-base/trustgraph/base/qdrant_config.py
new file mode 100644
index 00000000..f3e015ca
--- /dev/null
+++ b/trustgraph-base/trustgraph/base/qdrant_config.py
@@ -0,0 +1,87 @@
+
+import os
+import argparse
+from typing import Optional, Any, Tuple
+
+
+def get_qdrant_defaults() -> dict:
+ return {
+ 'url': os.getenv('QDRANT_URL', 'http://localhost:6333'),
+ 'api_key': os.getenv('QDRANT_API_KEY'),
+ 'replication_factor': int(os.getenv('QDRANT_REPLICATION_FACTOR', '1')),
+ 'shard_number': int(os.getenv('QDRANT_SHARD_NUMBER', '1')),
+ }
+
+
+def add_qdrant_args(parser: argparse.ArgumentParser) -> None:
+ defaults = get_qdrant_defaults()
+
+ url_help = f"Qdrant URL (default: {defaults['url']})"
+ if 'QDRANT_URL' in os.environ:
+ url_help += " [from QDRANT_URL]"
+
+ api_key_help = "Qdrant API key"
+ if defaults['api_key']:
+ api_key_help += " (default: )"
+ if 'QDRANT_API_KEY' in os.environ:
+ api_key_help += " [from QDRANT_API_KEY]"
+
+ replication_help = f"Qdrant collection replication factor (default: {defaults['replication_factor']})"
+ if 'QDRANT_REPLICATION_FACTOR' in os.environ:
+ replication_help += " [from QDRANT_REPLICATION_FACTOR]"
+
+ shard_help = f"Qdrant collection shard number (default: {defaults['shard_number']})"
+ if 'QDRANT_SHARD_NUMBER' in os.environ:
+ shard_help += " [from QDRANT_SHARD_NUMBER]"
+
+ parser.add_argument(
+ '--store-uri',
+ default=defaults['url'],
+ help=url_help,
+ )
+
+ parser.add_argument(
+ '--api-key',
+ default=defaults['api_key'],
+ help=api_key_help,
+ )
+
+ parser.add_argument(
+ '--qdrant-replication-factor',
+ type=int,
+ default=defaults['replication_factor'],
+ help=replication_help,
+ )
+
+ parser.add_argument(
+ '--qdrant-shard-number',
+ type=int,
+ default=defaults['shard_number'],
+ help=shard_help,
+ )
+
+
+def resolve_qdrant_config(
+ args: Optional[Any] = None,
+ url: Optional[str] = None,
+ api_key: Optional[str] = None,
+ replication_factor: Optional[int] = None,
+ shard_number: Optional[int] = None,
+) -> Tuple[str, Optional[str], int, int]:
+ if args is not None:
+ url = url or getattr(args, 'store_uri', None)
+ api_key = api_key or getattr(args, 'api_key', None)
+ replication_factor = replication_factor or getattr(
+ args, 'qdrant_replication_factor', None
+ )
+ shard_number = shard_number or getattr(
+ args, 'qdrant_shard_number', None
+ )
+
+ defaults = get_qdrant_defaults()
+ url = url or defaults['url']
+ api_key = api_key or defaults['api_key']
+ replication_factor = replication_factor or defaults['replication_factor']
+ shard_number = shard_number or defaults['shard_number']
+
+ return url, api_key, replication_factor, shard_number
diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py
index d7abd1a9..f1e4a577 100644
--- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py
+++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py
@@ -6,7 +6,7 @@ import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement
-from ssl import SSLContext, PROTOCOL_TLSv1_2
+import ssl
from ..tables.cassandra_async import async_execute
@@ -41,13 +41,15 @@ class KnowledgeGraph:
def __init__(
self, hosts=None,
- keyspace="trustgraph", username=None, password=None
+ keyspace="trustgraph", username=None, password=None,
+ replication_factor=1,
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
+ self.replication_factor = replication_factor
self.username = username
# 7-table schema for quads with full query pattern support
@@ -68,7 +70,7 @@ class KnowledgeGraph:
self.collection_metadata_table = "collection_metadata"
if username and password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
@@ -92,7 +94,7 @@ class KnowledgeGraph:
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
- 'replication_factor' : 1
+ 'replication_factor' : {self.replication_factor}
}};
""")
@@ -539,13 +541,15 @@ class EntityCentricKnowledgeGraph:
def __init__(
self, hosts=None,
- keyspace="trustgraph", username=None, password=None
+ keyspace="trustgraph", username=None, password=None,
+ replication_factor=1,
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
+ self.replication_factor = replication_factor
self.username = username
# 2-table entity-centric schema
@@ -556,7 +560,7 @@ class EntityCentricKnowledgeGraph:
self.collection_metadata_table = "collection_metadata"
if username and password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
@@ -580,7 +584,7 @@ class EntityCentricKnowledgeGraph:
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
- 'replication_factor' : 1
+ 'replication_factor' : {self.replication_factor}
}};
""")
diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
index f6770744..b98ab7e5 100755
--- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
+++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
@@ -12,31 +12,32 @@ from qdrant_client import QdrantClient
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "doc-embeddings-query"
-default_store_uri = 'http://localhost:6333'
-
class Processor(DocumentEmbeddingsQueryService):
def __init__(self, **params):
- store_uri = params.get("store_uri", default_store_uri)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
- #optional api key
- api_key = params.get("api_key", None)
+ url, api_key, _, _ = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ )
super(Processor, self).__init__(
**params | {
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
async def query_document_embeddings(self, workspace, msg):
@@ -85,18 +86,7 @@ class Processor(DocumentEmbeddingsQueryService):
def add_args(parser):
DocumentEmbeddingsQueryService.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant store URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help=f'API key for qdrant (default: None)'
- )
+ add_qdrant_args(parser)
def run():
diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
index 167130c9..aa93925d 100755
--- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
+++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
@@ -12,31 +12,32 @@ from qdrant_client import QdrantClient
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "graph-embeddings-query"
-default_store_uri = 'http://localhost:6333'
-
class Processor(GraphEmbeddingsQueryService):
def __init__(self, **params):
- store_uri = params.get("store_uri", default_store_uri)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
- #optional api key
- api_key = params.get("api_key", None)
+ url, api_key, _, _ = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ )
super(Processor, self).__init__(
**params | {
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@@ -104,18 +105,7 @@ class Processor(GraphEmbeddingsQueryService):
def add_args(parser):
GraphEmbeddingsQueryService.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant store URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help=f'API key for qdrant (default: None)'
- )
+ add_qdrant_args(parser)
def run():
diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py
index b7f0f423..a9005ee4 100644
--- a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py
+++ b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py
@@ -116,7 +116,7 @@ class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
# Create keyspace
self.session.execute(f"""
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
- WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
+ WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': {self.cassandra_config.get('replication_factor', 1)}}}
""")
# Create triples table optimized for SPARQL queries
diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py
index 1534c044..7e1a5851 100644
--- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py
+++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py
@@ -19,12 +19,12 @@ from .... schema import (
RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-query"
-default_store_uri = 'http://localhost:6333'
default_concurrency = 10
@@ -35,13 +35,17 @@ class Processor(FlowProcessor):
id = params.get("id", default_ident)
concurrency = params.get("concurrency", default_concurrency)
- store_uri = params.get("store_uri", default_store_uri)
- api_key = params.get("api_key", None)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
+
+ url, api_key, _, _ = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ )
super(Processor, self).__init__(
**params | {
"id": id,
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
@@ -62,7 +66,7 @@ class Processor(FlowProcessor):
)
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
@@ -192,21 +196,9 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
- """Add command-line arguments"""
FlowProcessor.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant store URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help='API key for Qdrant (default: None)'
- )
+ add_qdrant_args(parser)
parser.add_argument(
'-c', '--concurrency',
diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
index 2bfef99c..c212fa86 100644
--- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
@@ -14,29 +14,34 @@ from qdrant_client.models import Distance, VectorParams
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "doc-embeddings-write"
-default_store_uri = 'http://localhost:6333'
-
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
def __init__(self, **params):
- store_uri = params.get("store_uri", default_store_uri)
- api_key = params.get("api_key", None)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
+
+ url, api_key, replication_factor, shard_number = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ )
super(Processor, self).__init__(
**params | {
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
+ self.replication_factor = replication_factor
+ self.shard_number = shard_number
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@@ -61,6 +66,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
+ replication_factor=self.replication_factor,
+ shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@@ -109,18 +116,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
def add_args(parser):
DocumentEmbeddingsStoreService.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help=f'Qdrant API key (default: None)'
- )
+ add_qdrant_args(parser)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
index 13dcdba8..ab04e42e 100755
--- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
@@ -14,6 +14,7 @@ from qdrant_client.models import Distance, VectorParams
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
from .... schema import IRI, LITERAL
# Module logger
@@ -29,29 +30,32 @@ def get_term_value(term):
elif term.type == LITERAL:
return term.value
else:
- # For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "graph-embeddings-write"
-default_store_uri = 'http://localhost:6333'
-
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
def __init__(self, **params):
- store_uri = params.get("store_uri", default_store_uri)
- api_key = params.get("api_key", None)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
+
+ url, api_key, replication_factor, shard_number = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ )
super(Processor, self).__init__(
**params | {
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
+ self.replication_factor = replication_factor
+ self.shard_number = shard_number
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@@ -76,6 +80,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
+ replication_factor=self.replication_factor,
+ shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@@ -128,18 +134,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
def add_args(parser):
GraphEmbeddingsStoreService.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant store URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help=f'Qdrant API key'
- )
+ add_qdrant_args(parser)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py
index a01629c5..9071dbc1 100644
--- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py
@@ -27,12 +27,12 @@ from qdrant_client.models import PointStruct, Distance, VectorParams
from .... schema import RowEmbeddings
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
+from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-write"
-default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, FlowProcessor):
@@ -41,13 +41,17 @@ class Processor(CollectionConfigHandler, FlowProcessor):
id = params.get("id", default_ident)
- store_uri = params.get("store_uri", default_store_uri)
- api_key = params.get("api_key", None)
+ store_uri = params.get("store_uri")
+ api_key = params.get("api_key")
+
+ url, api_key, replication_factor, shard_number = resolve_qdrant_config(
+ url=store_uri, api_key=api_key,
+ )
super(Processor, self).__init__(
**params | {
"id": id,
- "store_uri": store_uri,
+ "store_uri": url,
"api_key": api_key,
}
)
@@ -63,7 +67,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Register config handler for collection management
self.register_config_handler(self.on_collection_config, types=["collection"])
- self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
+ self.qdrant = QdrantClient(url=url, api_key=api_key)
+ self.replication_factor = replication_factor
+ self.shard_number = shard_number
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@@ -103,6 +109,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
size=dimension,
distance=Distance.COSINE
),
+ replication_factor=self.replication_factor,
+ shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@@ -249,21 +257,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
@staticmethod
def add_args(parser):
- """Add command-line arguments"""
FlowProcessor.add_args(parser)
-
- parser.add_argument(
- '-t', '--store-uri',
- default=default_store_uri,
- help=f'Qdrant URI (default: {default_store_uri})'
- )
-
- parser.add_argument(
- '-k', '--api-key',
- default=None,
- help='Qdrant API key (default: None)'
- )
+ add_qdrant_args(parser)
def run():
diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
index 65eeee06..12345e46 100755
--- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
+++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
@@ -47,7 +47,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
- hosts, username, password, keyspace, _ = resolve_cassandra_config(
+ hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password
@@ -57,6 +57,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
+ self.replication_factor = replication_factor
# Config key for schemas
self.config_key = params.get("config_type", "schema")
@@ -232,7 +233,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
WITH REPLICATION = {{
'class': 'SimpleStrategy',
- 'replication_factor': 1
+ 'replication_factor': {self.replication_factor}
}}
"""
diff --git a/trustgraph-flow/trustgraph/tables/config.py b/trustgraph-flow/trustgraph/tables/config.py
index 74ceb6f4..c87cb3b5 100644
--- a/trustgraph-flow/trustgraph/tables/config.py
+++ b/trustgraph-flow/trustgraph/tables/config.py
@@ -4,7 +4,7 @@ from .. schema import Metadata, GraphEmbeddings
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
-from ssl import SSLContext, PROTOCOL_TLSv1_2
+import ssl
import uuid
import time
@@ -33,7 +33,7 @@ class ConfigTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)
diff --git a/trustgraph-flow/trustgraph/tables/iam.py b/trustgraph-flow/trustgraph/tables/iam.py
index d7bf5e3d..b60e9cff 100644
--- a/trustgraph-flow/trustgraph/tables/iam.py
+++ b/trustgraph-flow/trustgraph/tables/iam.py
@@ -15,7 +15,7 @@ import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
-from ssl import SSLContext, PROTOCOL_TLSv1_2
+import ssl
from . cassandra_async import async_execute
@@ -39,7 +39,7 @@ class IamTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(",")]
if cassandra_username and cassandra_password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password,
)
diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py
index 4fcb2dd3..53a12b35 100644
--- a/trustgraph-flow/trustgraph/tables/knowledge.py
+++ b/trustgraph-flow/trustgraph/tables/knowledge.py
@@ -23,7 +23,7 @@ def tuple_to_term(value, is_uri):
else:
return Term(type=LITERAL, value=value)
from cassandra.auth import PlainTextAuthProvider
-from ssl import SSLContext, PROTOCOL_TLSv1_2
+import ssl
import uuid
import time
@@ -50,7 +50,7 @@ class KnowledgeTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)
diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py
index 58486f0e..5094e103 100644
--- a/trustgraph-flow/trustgraph/tables/library.py
+++ b/trustgraph-flow/trustgraph/tables/library.py
@@ -24,7 +24,7 @@ from .. exceptions import RequestError
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement
-from ssl import SSLContext, PROTOCOL_TLSv1_2
+import ssl
import uuid
import time
@@ -53,7 +53,7 @@ class LibraryTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
- ssl_context = SSLContext(PROTOCOL_TLSv1_2)
+ ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)
From 08bfec153992c1e1a24719b410726e887fc16825 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Thu, 4 Jun 2026 12:36:36 +0100
Subject: [PATCH 13/18] fix: wire replication params through YAML/params path
for Cassandra and Qdrant (#976)
resolve_cassandra_config did not accept replication_factor as a kwarg,
so cassandra_replication_factor from YAML params was silently ignored
by all 6 callers. Add the kwarg and pass it from every caller.
Same fix for Qdrant: 3 writers now pass qdrant_replication_factor and
qdrant_shard_number from params.
Add tests covering the params path for both helpers.
---
tests/unit/test_base/test_cassandra_config.py | 55 ++++++-
tests/unit/test_base/test_qdrant_config.py | 136 ++++++++++++++++++
.../trustgraph/base/cassandra_config.py | 26 +---
.../trustgraph/config/service/service.py | 3 +-
trustgraph-flow/trustgraph/cores/service.py | 3 +-
.../trustgraph/iam/service/service.py | 1 +
.../trustgraph/librarian/service.py | 3 +-
.../query/doc_embeddings/qdrant/service.py | 3 +-
.../storage/doc_embeddings/qdrant/write.py | 2 +
.../storage/graph_embeddings/qdrant/write.py | 2 +
.../trustgraph/storage/knowledge/store.py | 3 +-
.../storage/row_embeddings/qdrant/write.py | 2 +
.../storage/rows/cassandra/write.py | 3 +-
13 files changed, 214 insertions(+), 28 deletions(-)
create mode 100644 tests/unit/test_base/test_qdrant_config.py
diff --git a/tests/unit/test_base/test_cassandra_config.py b/tests/unit/test_base/test_cassandra_config.py
index a291434d..fe8a8379 100644
--- a/tests/unit/test_base/test_cassandra_config.py
+++ b/tests/unit/test_base/test_cassandra_config.py
@@ -409,4 +409,57 @@ class TestEdgeCases:
assert hosts == ['mixed-host']
assert username is None # Stays None
- assert password == 'mixed-pass'
\ No newline at end of file
+ assert password == 'mixed-pass'
+
+
+class TestReplicationFactorParamPath:
+
+ def test_explicit_kwarg(self):
+ with patch.dict(os.environ, {}, clear=True):
+ _, _, _, _, rf = resolve_cassandra_config(
+ replication_factor=3,
+ )
+ assert rf == 3
+
+ def test_kwarg_overrides_env(self):
+ with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+ _, _, _, _, rf = resolve_cassandra_config(
+ replication_factor=3,
+ )
+ assert rf == 3
+
+ def test_env_fallback_when_kwarg_none(self):
+ with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+ _, _, _, _, rf = resolve_cassandra_config(
+ replication_factor=None,
+ )
+ assert rf == 5
+
+ def test_default_when_no_kwarg_no_env(self):
+ with patch.dict(os.environ, {}, clear=True):
+ _, _, _, _, rf = resolve_cassandra_config()
+ assert rf == 1
+
+ def test_params_dict_path(self):
+ with patch.dict(os.environ, {}, clear=True):
+ params = {'cassandra_replication_factor': 3}
+ _, _, _, _, rf = resolve_cassandra_config(
+ replication_factor=params.get('cassandra_replication_factor'),
+ )
+ assert rf == 3
+
+ def test_params_dict_overrides_env(self):
+ with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+ params = {'cassandra_replication_factor': 3}
+ _, _, _, _, rf = resolve_cassandra_config(
+ replication_factor=params.get('cassandra_replication_factor'),
+ )
+ assert rf == 3
+
+ def test_params_dict_missing_falls_to_env(self):
+ with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
+ params = {}
+ _, _, _, _, rf = resolve_cassandra_config(
+ replication_factor=params.get('cassandra_replication_factor'),
+ )
+ assert rf == 5
\ No newline at end of file
diff --git a/tests/unit/test_base/test_qdrant_config.py b/tests/unit/test_base/test_qdrant_config.py
new file mode 100644
index 00000000..dbbe4214
--- /dev/null
+++ b/tests/unit/test_base/test_qdrant_config.py
@@ -0,0 +1,136 @@
+
+import os
+import pytest
+from unittest.mock import patch
+
+from trustgraph.base.qdrant_config import (
+ get_qdrant_defaults,
+ resolve_qdrant_config,
+)
+
+
+class TestGetQdrantDefaults:
+
+ def test_defaults_with_no_env_vars(self):
+ with patch.dict(os.environ, {}, clear=True):
+ defaults = get_qdrant_defaults()
+ assert defaults['url'] == 'http://localhost:6333'
+ assert defaults['api_key'] is None
+ assert defaults['replication_factor'] == 1
+ assert defaults['shard_number'] == 1
+
+ def test_defaults_from_env(self):
+ env = {
+ 'QDRANT_URL': 'http://qdrant:6333',
+ 'QDRANT_API_KEY': 'secret',
+ 'QDRANT_REPLICATION_FACTOR': '3',
+ 'QDRANT_SHARD_NUMBER': '5',
+ }
+ with patch.dict(os.environ, env, clear=True):
+ defaults = get_qdrant_defaults()
+ assert defaults['url'] == 'http://qdrant:6333'
+ assert defaults['api_key'] == 'secret'
+ assert defaults['replication_factor'] == 3
+ assert defaults['shard_number'] == 5
+
+
+class TestResolveQdrantConfig:
+
+ def test_defaults(self):
+ with patch.dict(os.environ, {}, clear=True):
+ url, api_key, rf, sn = resolve_qdrant_config()
+ assert url == 'http://localhost:6333'
+ assert api_key is None
+ assert rf == 1
+ assert sn == 1
+
+ def test_explicit_kwargs(self):
+ with patch.dict(os.environ, {}, clear=True):
+ url, api_key, rf, sn = resolve_qdrant_config(
+ url='http://custom:6333',
+ api_key='key',
+ replication_factor=3,
+ shard_number=5,
+ )
+ assert url == 'http://custom:6333'
+ assert api_key == 'key'
+ assert rf == 3
+ assert sn == 5
+
+ def test_kwargs_override_env(self):
+ env = {
+ 'QDRANT_URL': 'http://env:6333',
+ 'QDRANT_REPLICATION_FACTOR': '10',
+ 'QDRANT_SHARD_NUMBER': '10',
+ }
+ with patch.dict(os.environ, env, clear=True):
+ url, _, rf, sn = resolve_qdrant_config(
+ url='http://explicit:6333',
+ replication_factor=3,
+ shard_number=5,
+ )
+ assert url == 'http://explicit:6333'
+ assert rf == 3
+ assert sn == 5
+
+ def test_env_fallback_when_kwargs_none(self):
+ env = {
+ 'QDRANT_URL': 'http://env:6333',
+ 'QDRANT_REPLICATION_FACTOR': '3',
+ 'QDRANT_SHARD_NUMBER': '5',
+ }
+ with patch.dict(os.environ, env, clear=True):
+ url, _, rf, sn = resolve_qdrant_config()
+ assert url == 'http://env:6333'
+ assert rf == 3
+ assert sn == 5
+
+ def test_params_dict_path(self):
+ with patch.dict(os.environ, {}, clear=True):
+ params = {
+ 'store_uri': 'http://params:6333',
+ 'api_key': 'pkey',
+ 'qdrant_replication_factor': 3,
+ 'qdrant_shard_number': 5,
+ }
+ url, api_key, rf, sn = resolve_qdrant_config(
+ url=params.get('store_uri'),
+ api_key=params.get('api_key'),
+ replication_factor=params.get('qdrant_replication_factor'),
+ shard_number=params.get('qdrant_shard_number'),
+ )
+ assert url == 'http://params:6333'
+ assert api_key == 'pkey'
+ assert rf == 3
+ assert sn == 5
+
+ def test_params_dict_overrides_env(self):
+ env = {
+ 'QDRANT_REPLICATION_FACTOR': '10',
+ 'QDRANT_SHARD_NUMBER': '10',
+ }
+ with patch.dict(os.environ, env, clear=True):
+ params = {
+ 'qdrant_replication_factor': 3,
+ 'qdrant_shard_number': 5,
+ }
+ _, _, rf, sn = resolve_qdrant_config(
+ replication_factor=params.get('qdrant_replication_factor'),
+ shard_number=params.get('qdrant_shard_number'),
+ )
+ assert rf == 3
+ assert sn == 5
+
+ def test_params_dict_missing_falls_to_env(self):
+ env = {
+ 'QDRANT_REPLICATION_FACTOR': '3',
+ 'QDRANT_SHARD_NUMBER': '5',
+ }
+ with patch.dict(os.environ, env, clear=True):
+ params = {}
+ _, _, rf, sn = resolve_qdrant_config(
+ replication_factor=params.get('qdrant_replication_factor'),
+ shard_number=params.get('qdrant_shard_number'),
+ )
+ assert rf == 3
+ assert sn == 5
diff --git a/trustgraph-base/trustgraph/base/cassandra_config.py b/trustgraph-base/trustgraph/base/cassandra_config.py
index 78505c68..b2e36fbd 100644
--- a/trustgraph-base/trustgraph/base/cassandra_config.py
+++ b/trustgraph-base/trustgraph/base/cassandra_config.py
@@ -103,35 +103,19 @@ def resolve_cassandra_config(
host: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
- default_keyspace: Optional[str] = None
+ default_keyspace: Optional[str] = None,
+ replication_factor: Optional[int] = None,
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
- """
- Resolve Cassandra configuration from various sources.
-
- Can accept either argparse args object or explicit parameters.
- Converts host string to list format for Cassandra driver.
-
- Args:
- args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace, cassandra_replication_factor
- host: Optional explicit host parameter (overrides args)
- username: Optional explicit username parameter (overrides args)
- password: Optional explicit password parameter (overrides args)
- default_keyspace: Optional default keyspace if not specified elsewhere
-
- Returns:
- tuple: (hosts_list, username, password, keyspace, replication_factor)
- """
- # If args provided, extract values
keyspace = None
- replication_factor = 1
if args is not None:
host = host or getattr(args, 'cassandra_host', None)
username = username or getattr(args, 'cassandra_username', None)
password = password or getattr(args, 'cassandra_password', None)
keyspace = getattr(args, 'cassandra_keyspace', None)
- replication_factor = getattr(args, 'cassandra_replication_factor', 1)
+ replication_factor = replication_factor or getattr(
+ args, 'cassandra_replication_factor', None
+ )
- # Apply defaults if still None
defaults = get_cassandra_defaults()
host = host or defaults['host']
username = username or defaults['username']
diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py
index c5fac198..725f1106 100644
--- a/trustgraph-flow/trustgraph/config/service/service.py
+++ b/trustgraph-flow/trustgraph/config/service/service.py
@@ -83,7 +83,8 @@ class Processor(AsyncProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
- default_keyspace="config"
+ default_keyspace="config",
+ replication_factor=params.get("cassandra_replication_factor"),
)
# Store resolved configuration
diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py
index a8f52efd..5c50c207 100755
--- a/trustgraph-flow/trustgraph/cores/service.py
+++ b/trustgraph-flow/trustgraph/cores/service.py
@@ -61,7 +61,8 @@ class Processor(WorkspaceProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
- default_keyspace="knowledge"
+ default_keyspace="knowledge",
+ replication_factor=params.get("cassandra_replication_factor"),
)
self.cassandra_host = hosts
diff --git a/trustgraph-flow/trustgraph/iam/service/service.py b/trustgraph-flow/trustgraph/iam/service/service.py
index 8ce22757..b2f3976d 100644
--- a/trustgraph-flow/trustgraph/iam/service/service.py
+++ b/trustgraph-flow/trustgraph/iam/service/service.py
@@ -101,6 +101,7 @@ class Processor(AsyncProcessor):
username=cassandra_username,
password=cassandra_password,
default_keyspace="iam",
+ replication_factor=params.get("cassandra_replication_factor"),
)
self.cassandra_host = hosts
diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py
index ee5e9c1b..4d3efbfb 100755
--- a/trustgraph-flow/trustgraph/librarian/service.py
+++ b/trustgraph-flow/trustgraph/librarian/service.py
@@ -146,7 +146,8 @@ class Processor(WorkspaceProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
- default_keyspace="librarian"
+ default_keyspace="librarian",
+ replication_factor=params.get("cassandra_replication_factor"),
)
# Store resolved configuration
diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
index b98ab7e5..de25a139 100755
--- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
+++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
@@ -27,7 +27,8 @@ class Processor(DocumentEmbeddingsQueryService):
api_key = params.get("api_key")
url, api_key, _, _ = resolve_qdrant_config(
- url=store_uri, api_key=api_key,
+ url=store_uri,
+ api_key=api_key,
)
super(Processor, self).__init__(
diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
index c212fa86..08d88849 100644
--- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py
@@ -30,6 +30,8 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
+ replication_factor=params.get("qdrant_replication_factor"),
+ shard_number=params.get("qdrant_shard_number"),
)
super(Processor, self).__init__(
diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
index ab04e42e..b6072bdc 100755
--- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py
@@ -44,6 +44,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
+ replication_factor=params.get("qdrant_replication_factor"),
+ shard_number=params.get("qdrant_shard_number"),
)
super(Processor, self).__init__(
diff --git a/trustgraph-flow/trustgraph/storage/knowledge/store.py b/trustgraph-flow/trustgraph/storage/knowledge/store.py
index 162a4057..f6e12a85 100644
--- a/trustgraph-flow/trustgraph/storage/knowledge/store.py
+++ b/trustgraph-flow/trustgraph/storage/knowledge/store.py
@@ -27,7 +27,8 @@ class Processor(FlowProcessor):
host=params.get("cassandra_host"),
username=params.get("cassandra_username"),
password=params.get("cassandra_password"),
- default_keyspace='knowledge'
+ default_keyspace='knowledge',
+ replication_factor=params.get("cassandra_replication_factor"),
)
super(Processor, self).__init__(
diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py
index 9071dbc1..4c65edb1 100644
--- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py
+++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py
@@ -46,6 +46,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
+ replication_factor=params.get("qdrant_replication_factor"),
+ shard_number=params.get("qdrant_shard_number"),
)
super(Processor, self).__init__(
diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
index 12345e46..e5506723 100755
--- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
+++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
@@ -50,7 +50,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
- password=cassandra_password
+ password=cassandra_password,
+ replication_factor=params.get("cassandra_replication_factor"),
)
# Store resolved configuration with proper names
From dbc21c0bb9dfbc2971c97680e9903e5738c81da6 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Mon, 8 Jun 2026 15:22:11 +0100
Subject: [PATCH 14/18] fix: structured data query and auth fixes (#978)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Pass auth token to schema discovery and descriptor generation in
tg-load-structured-data, fixing 401 errors with IAM enabled
- Fix row query pagination: replace single-page async_execute with
async_scan that streams pages and applies filters without
materialising the full result set (OOM on large datasets)
- Add missing filter operators (not, startsWith, endsWith, not_in)
to row query post-filter matching
- Fall back to scan path when an indexed field is queried with an
empty string value, since empty index values are not stored
- Revert top-level indexes array support — the current table schema
overwrites rows with duplicate index values, so only primary_key
fields are safe to index until the schema is redesigned
---
.../trustgraph/cli/load_structured_data.py | 18 +++----
.../query/rows/cassandra/service.py | 53 ++++++++++++-------
.../storage/rows/cassandra/write.py | 2 +-
.../trustgraph/tables/cassandra_async.py | 51 +++++++++++++++++-
4 files changed, 93 insertions(+), 31 deletions(-)
diff --git a/trustgraph-cli/trustgraph/cli/load_structured_data.py b/trustgraph-cli/trustgraph/cli/load_structured_data.py
index dccf548e..5649a5ae 100644
--- a/trustgraph-cli/trustgraph/cli/load_structured_data.py
+++ b/trustgraph-cli/trustgraph/cli/load_structured_data.py
@@ -78,7 +78,7 @@ def load_structured_data(
logger.info("Step 1: Analyzing data to discover best matching schema...")
# Step 1: Auto-discover schema (reuse discover_schema logic)
- discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
+ discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
if not discovered_schema:
logger.error("Failed to discover suitable schema automatically")
print("❌ Could not automatically determine the best schema for your data.")
@@ -90,7 +90,7 @@ def load_structured_data(
# Step 2: Auto-generate descriptor
logger.info("Step 2: Generating descriptor configuration...")
- auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
+ auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace)
if not auto_descriptor:
logger.error("Failed to generate descriptor automatically")
print("❌ Could not automatically generate descriptor configuration.")
@@ -172,7 +172,7 @@ def load_structured_data(
logger.info(f"Sample chars: {sample_chars} characters")
# Use the helper function to discover schema (get raw response for display)
- response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
+ response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace)
if response:
# Debug: print response type and content
@@ -203,7 +203,7 @@ def load_structured_data(
# If no schema specified, discover it first
if not schema_name:
logger.info("No schema specified, auto-discovering...")
- schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
+ schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
if not schema_name:
print("Error: Could not determine schema automatically.")
print("Please specify a schema using --schema-name or run --discover-schema first.")
@@ -213,7 +213,7 @@ def load_structured_data(
logger.info(f"Target schema: {schema_name}")
# Generate descriptor using helper function
- descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
+ descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace)
if descriptor:
# Output the generated descriptor
@@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp
# Helper functions for auto mode
-def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
+def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"):
"""Auto-discover the best matching schema for the input data
Args:
@@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
- api = Api(api_url, workspace=workspace)
+ api = Api(api_url, token=token, workspace=workspace)
config_api = api.config()
# Get available schemas
@@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
return None
-def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
+def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"):
"""Auto-generate descriptor configuration for the discovered schema"""
try:
# Read sample data
@@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
- api = Api(api_url, workspace=workspace)
+ api = Api(api_url, token=token, workspace=workspace)
config_api = api.config()
# Get schema definition
diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py
index 7157daae..f9868d67 100644
--- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py
+++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py
@@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
-from .... tables.cassandra_async import async_execute
+from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan
from ... graphql import GraphQLSchemaBuilder, SortDirection
@@ -180,7 +180,7 @@ class Processor(FlowProcessor):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
- indexed=field_def.get("indexed", False)
+ indexed=field_def.get("indexed", False),
)
fields.append(field)
@@ -232,6 +232,8 @@ class Processor(FlowProcessor):
for index_name in index_names:
if index_name in filters:
value = filters[index_name]
+ if value == "" or value is None:
+ continue
# Single field index -> single element list
index_value = [str(value)]
return (index_name, index_value)
@@ -282,11 +284,13 @@ class Processor(FlowProcessor):
query += f" LIMIT {limit}"
try:
- rows = await async_execute(self.session, query, params)
- for row in rows:
- # Convert data map to dict with proper field names
- row_dict = dict(row.data) if row.data else {}
- results.append(row_dict)
+ pages = await async_execute_paged(
+ self.session, query, params
+ )
+ for page in pages:
+ for row in page:
+ row_dict = dict(row.data) if row.data else {}
+ results.append(row_dict)
except Exception as e:
logger.error(f"Failed to query rows: {e}", exc_info=True)
raise
@@ -308,8 +312,6 @@ class Processor(FlowProcessor):
# Query using the first index (arbitrary choice for scan)
primary_index = index_names[0]
- # We need to scan all values for this index
- # This requires ALLOW FILTERING or a different approach
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
@@ -320,17 +322,18 @@ class Processor(FlowProcessor):
params = [collection, schema_name, primary_index]
try:
- rows = await async_execute(self.session, query, params)
-
- for row in rows:
+ def row_filter(row):
row_dict = dict(row.data) if row.data else {}
+ return self._matches_filters(row_dict, filters, row_schema)
- # Apply post-filters
- if self._matches_filters(row_dict, filters, row_schema):
- results.append(row_dict)
-
- if limit and len(results) >= limit:
- break
+ matched_rows = await async_scan(
+ self.session, query, params,
+ row_filter=row_filter,
+ limit=limit,
+ )
+ for row in matched_rows:
+ row_dict = dict(row.data) if row.data else {}
+ results.append(row_dict)
except Exception as e:
logger.error(f"Failed to scan rows: {e}", exc_info=True)
@@ -363,7 +366,7 @@ class Processor(FlowProcessor):
# Parse filter key for operator
if '_' in filter_key:
parts = filter_key.rsplit('_', 1)
- if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
+ if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
field_name = parts[0]
operator = parts[1]
else:
@@ -400,6 +403,18 @@ class Processor(FlowProcessor):
elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]:
return False
+ elif operator == 'not':
+ if str(row_value) == str(filter_value):
+ return False
+ elif operator == 'startsWith':
+ if not str(row_value).startswith(str(filter_value)):
+ return False
+ elif operator == 'endsWith':
+ if not str(row_value).endswith(str(filter_value)):
+ return False
+ elif operator == 'not_in':
+ if str(row_value) in [str(v) for v in filter_value]:
+ return False
except (ValueError, TypeError):
return False
diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
index e5506723..31fc41a7 100755
--- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
+++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
@@ -172,7 +172,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
- indexed=field_def.get("indexed", False)
+ indexed=field_def.get("indexed", False),
)
fields.append(field)
diff --git a/trustgraph-flow/trustgraph/tables/cassandra_async.py b/trustgraph-flow/trustgraph/tables/cassandra_async.py
index 205ed6b9..fe410a26 100644
--- a/trustgraph-flow/trustgraph/tables/cassandra_async.py
+++ b/trustgraph-flow/trustgraph/tables/cassandra_async.py
@@ -80,14 +80,14 @@ def _set_exception_if_pending(fut, exc):
fut.set_exception(exc)
-async def async_execute_paged(session, query, parameters=None, fetch_size=100):
+async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
"""Execute a CQL query with page-by-page iteration.
Uses synchronous session.execute() inside run_in_executor so that
the driver's ResultSet paging works correctly without materialising
the entire result set in memory.
- Yields one page of rows at a time (as a list).
+ Returns all pages as a list of lists.
"""
loop = asyncio.get_running_loop()
@@ -111,3 +111,50 @@ async def async_execute_paged(session, query, parameters=None, fetch_size=100):
return await loop.run_in_executor(
None, _fetch_all_pages
)
+
+
+async def async_scan(
+ session, query, parameters=None, row_filter=None,
+ limit=None, fetch_size=5000,
+):
+ """Scan a CQL query page-by-page, applying a filter and limit.
+
+ Only matching rows accumulate in memory. Each page is discarded
+ after processing, so peak memory is bounded by fetch_size plus
+ the number of matching rows (capped by limit).
+
+ Args:
+ session: cassandra.cluster.Session
+ query: CQL statement string
+ parameters: bind params
+ row_filter: callable(row) -> bool, or None to accept all
+ limit: max results to return, or None for unlimited
+ fetch_size: rows per Cassandra page fetch
+
+ Returns:
+ List of matching rows.
+ """
+ loop = asyncio.get_running_loop()
+
+ if isinstance(query, str):
+ stmt = SimpleStatement(query, fetch_size=fetch_size)
+ else:
+ stmt = query
+ stmt.fetch_size = fetch_size
+
+ def _scan():
+ results = []
+ result_set = session.execute(stmt, parameters)
+ while True:
+ for row in result_set.current_rows:
+ if row_filter is None or row_filter(row):
+ results.append(row)
+ if limit and len(results) >= limit:
+ return results
+ if result_set.has_more_pages:
+ result_set.fetch_next_page()
+ else:
+ break
+ return results
+
+ return await loop.run_in_executor(None, _scan)
From e1c93514543c32464ed739813b9ee527b1d06d7e Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Tue, 9 Jun 2026 16:29:32 +0100
Subject: [PATCH 15/18] fix: update row query tests to mock async_execute_paged
and async_scan (#979)
The query service now uses async_execute_paged (indexed path) and
async_scan (scan path) instead of async_execute. Tests were mocking
the old function, causing them to hang indefinitely.
---
.../test_query/test_rows_cassandra_query.py | 31 ++++++++++---------
1 file changed, 16 insertions(+), 15 deletions(-)
diff --git a/tests/unit/test_query/test_rows_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py
index b61500a4..fb385f43 100644
--- a/tests/unit/test_query/test_rows_cassandra_query.py
+++ b/tests/unit/test_query/test_rows_cassandra_query.py
@@ -333,8 +333,8 @@ class TestUnifiedTableQueries:
"""Test queries against the unified rows table"""
@pytest.mark.asyncio
- @patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
- async def test_query_with_index_match(self, mock_async_execute):
+ @patch('trustgraph.query.rows.cassandra.service.async_execute_paged', new_callable=AsyncMock)
+ async def test_query_with_index_match(self, mock_async_execute_paged):
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
@@ -344,10 +344,10 @@ class TestUnifiedTableQueries:
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
- # Mock async_execute to return test data
+ # Mock async_execute_paged to return test data (list of pages)
mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
- mock_async_execute.return_value = [mock_row]
+ mock_async_execute_paged.return_value = [[mock_row]]
schema = RowSchema(
name="products",
@@ -370,10 +370,10 @@ class TestUnifiedTableQueries:
# Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once()
- mock_async_execute.assert_called_once()
+ mock_async_execute_paged.assert_called_once()
# Verify query structure - should query unified rows table
- call_args = mock_async_execute.call_args
+ call_args = mock_async_execute_paged.call_args
query = call_args[0][1]
params = call_args[0][2]
@@ -394,8 +394,8 @@ class TestUnifiedTableQueries:
assert results[0]["category"] == "electronics"
@pytest.mark.asyncio
- @patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
- async def test_query_without_index_match(self, mock_async_execute):
+ @patch('trustgraph.query.rows.cassandra.service.async_scan', new_callable=AsyncMock)
+ async def test_query_without_index_match(self, mock_async_scan):
"""Test query execution without matching index (scan mode)"""
processor = MagicMock()
processor.session = MagicMock()
@@ -406,12 +406,10 @@ class TestUnifiedTableQueries:
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
- # Mock async_execute to return test data
+ # Mock async_scan to return filtered test data
mock_row1 = MagicMock()
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
- mock_row2 = MagicMock()
- mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
- mock_async_execute.return_value = [mock_row1, mock_row2]
+ mock_async_scan.return_value = [mock_row1]
schema = RowSchema(
name="products",
@@ -432,13 +430,16 @@ class TestUnifiedTableQueries:
limit=10
)
- # Query should use ALLOW FILTERING for scan
- call_args = mock_async_execute.call_args
+ # Verify async_scan was called
+ mock_async_scan.assert_called_once()
+
+ # Verify query structure
+ call_args = mock_async_scan.call_args
query = call_args[0][1]
assert "ALLOW FILTERING" in query
- # Should post-filter results
+ # Should return filtered results
assert len(results) == 1
assert results[0]["name"] == "Product A"
From 28a51c244f60c9fe189aa972e1fd58940cd44c64 Mon Sep 17 00:00:00 2001
From: Jacob Molz
Date: Tue, 9 Jun 2026 11:37:10 -0400
Subject: [PATCH 16/18] fix: reject invalid PDF decoder input (#977)
---
tests/unit/test_decoding/test_pdf_decoder.py | 48 ++++++++++++++-
.../trustgraph/decoding/pdf/pdf_decoder.py | 60 +++++++++++--------
2 files changed, 79 insertions(+), 29 deletions(-)
diff --git a/tests/unit/test_decoding/test_pdf_decoder.py b/tests/unit/test_decoding/test_pdf_decoder.py
index 04807b20..641a9d78 100644
--- a/tests/unit/test_decoding/test_pdf_decoder.py
+++ b/tests/unit/test_decoding/test_pdf_decoder.py
@@ -49,7 +49,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test successful PDF processing"""
# Mock PDF content
- pdf_content = b"fake pdf content"
+ pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
# Mock PyPDFLoader
@@ -88,13 +88,55 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
# Verify triples were sent for each page (provenance)
assert mock_triples_flow.send.call_count == 2
+ @patch('trustgraph.base.librarian_client.Consumer')
+ @patch('trustgraph.base.librarian_client.Producer')
+ @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
+ @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
+ async def test_on_message_rejects_librarian_content_that_is_not_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
+ """Test rejecting non-PDF content before invoking the PDF loader"""
+ html_content = b"Not found"
+ html_base64 = base64.b64encode(html_content)
+
+ mock_metadata = Metadata(id="test-doc")
+ mock_document = Document(metadata=mock_metadata, document_id="doc-123")
+ mock_msg = MagicMock()
+ mock_msg.value.return_value = mock_document
+
+ mock_output_flow = AsyncMock()
+ mock_triples_flow = AsyncMock()
+ mock_flow = MagicMock(side_effect=lambda name: {
+ "output": mock_output_flow,
+ "triples": mock_triples_flow,
+ }.get(name))
+ mock_flow.librarian.fetch_document_metadata = AsyncMock(
+ return_value=MagicMock(kind="application/pdf")
+ )
+ mock_flow.librarian.fetch_document_content = AsyncMock(
+ return_value=html_base64
+ )
+ mock_flow.librarian.save_child_document = AsyncMock()
+
+ config = {
+ 'id': 'test-pdf-decoder',
+ 'taskgroup': AsyncMock()
+ }
+
+ processor = Processor(**config)
+
+ await processor.on_message(mock_msg, None, mock_flow)
+
+ mock_pdf_loader_class.assert_not_called()
+ mock_output_flow.send.assert_not_called()
+ mock_triples_flow.send.assert_not_called()
+ mock_flow.librarian.save_child_document.assert_not_called()
+
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of empty PDF"""
- pdf_content = b"fake pdf content"
+ pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock()
@@ -126,7 +168,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of unicode content in PDF"""
- pdf_content = b"fake pdf content"
+ pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock()
diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
index ca242265..ae393028 100755
--- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
+++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder"
+def _looks_like_pdf(content):
+ return content.lstrip().startswith(b"%PDF-")
+
+
class Processor(FlowProcessor):
def __init__(self, **params):
@@ -94,33 +98,37 @@ class Processor(FlowProcessor):
)
return
- with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
+ # Check if we should fetch from librarian or use inline data
+ if v.document_id:
+ # Fetch from librarian via Pulsar
+ logger.info(f"Fetching document {v.document_id} from librarian...")
+
+ content = await flow.librarian.fetch_document_content(
+ document_id=v.document_id,
+
+ )
+
+ # Content is base64 encoded
+ if isinstance(content, str):
+ content = content.encode('utf-8')
+ decoded_content = base64.b64decode(content)
+
+ logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
+ else:
+ # Use inline data (backward compatibility)
+ decoded_content = base64.b64decode(v.data)
+
+ if not _looks_like_pdf(decoded_content):
+ logger.error(
+ f"Document {v.metadata.id} is not valid PDF content. "
+ f"Ignoring document."
+ )
+ return
+
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as fp:
temp_path = fp.name
-
- # Check if we should fetch from librarian or use inline data
- if v.document_id:
- # Fetch from librarian via Pulsar
- logger.info(f"Fetching document {v.document_id} from librarian...")
- fp.close()
-
- content = await flow.librarian.fetch_document_content(
- document_id=v.document_id,
-
- )
-
- # Content is base64 encoded
- if isinstance(content, str):
- content = content.encode('utf-8')
- decoded_content = base64.b64decode(content)
-
- with open(temp_path, 'wb') as f:
- f.write(decoded_content)
-
- logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
- else:
- # Use inline data (backward compatibility)
- fp.write(base64.b64decode(v.data))
- fp.close()
+ fp.write(decoded_content)
+ fp.close()
global PyPDFLoader
if PyPDFLoader is None:
From 79d7ef6a90994c5bab7ea0434db64cdffedbb8ff Mon Sep 17 00:00:00 2001
From: Jacob Molz
Date: Tue, 9 Jun 2026 11:37:10 -0400
Subject: [PATCH 17/18] fix: reject invalid PDF decoder input (#977)
---
tests/unit/test_decoding/test_pdf_decoder.py | 48 ++++++++++++++-
.../trustgraph/decoding/pdf/pdf_decoder.py | 60 +++++++++++--------
2 files changed, 79 insertions(+), 29 deletions(-)
diff --git a/tests/unit/test_decoding/test_pdf_decoder.py b/tests/unit/test_decoding/test_pdf_decoder.py
index 04807b20..641a9d78 100644
--- a/tests/unit/test_decoding/test_pdf_decoder.py
+++ b/tests/unit/test_decoding/test_pdf_decoder.py
@@ -49,7 +49,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test successful PDF processing"""
# Mock PDF content
- pdf_content = b"fake pdf content"
+ pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
# Mock PyPDFLoader
@@ -88,13 +88,55 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
# Verify triples were sent for each page (provenance)
assert mock_triples_flow.send.call_count == 2
+ @patch('trustgraph.base.librarian_client.Consumer')
+ @patch('trustgraph.base.librarian_client.Producer')
+ @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
+ @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
+ async def test_on_message_rejects_librarian_content_that_is_not_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
+ """Test rejecting non-PDF content before invoking the PDF loader"""
+ html_content = b"Not found"
+ html_base64 = base64.b64encode(html_content)
+
+ mock_metadata = Metadata(id="test-doc")
+ mock_document = Document(metadata=mock_metadata, document_id="doc-123")
+ mock_msg = MagicMock()
+ mock_msg.value.return_value = mock_document
+
+ mock_output_flow = AsyncMock()
+ mock_triples_flow = AsyncMock()
+ mock_flow = MagicMock(side_effect=lambda name: {
+ "output": mock_output_flow,
+ "triples": mock_triples_flow,
+ }.get(name))
+ mock_flow.librarian.fetch_document_metadata = AsyncMock(
+ return_value=MagicMock(kind="application/pdf")
+ )
+ mock_flow.librarian.fetch_document_content = AsyncMock(
+ return_value=html_base64
+ )
+ mock_flow.librarian.save_child_document = AsyncMock()
+
+ config = {
+ 'id': 'test-pdf-decoder',
+ 'taskgroup': AsyncMock()
+ }
+
+ processor = Processor(**config)
+
+ await processor.on_message(mock_msg, None, mock_flow)
+
+ mock_pdf_loader_class.assert_not_called()
+ mock_output_flow.send.assert_not_called()
+ mock_triples_flow.send.assert_not_called()
+ mock_flow.librarian.save_child_document.assert_not_called()
+
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of empty PDF"""
- pdf_content = b"fake pdf content"
+ pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock()
@@ -126,7 +168,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of unicode content in PDF"""
- pdf_content = b"fake pdf content"
+ pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock()
diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
index ca242265..ae393028 100755
--- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
+++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder"
+def _looks_like_pdf(content):
+ return content.lstrip().startswith(b"%PDF-")
+
+
class Processor(FlowProcessor):
def __init__(self, **params):
@@ -94,33 +98,37 @@ class Processor(FlowProcessor):
)
return
- with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
+ # Check if we should fetch from librarian or use inline data
+ if v.document_id:
+ # Fetch from librarian via Pulsar
+ logger.info(f"Fetching document {v.document_id} from librarian...")
+
+ content = await flow.librarian.fetch_document_content(
+ document_id=v.document_id,
+
+ )
+
+ # Content is base64 encoded
+ if isinstance(content, str):
+ content = content.encode('utf-8')
+ decoded_content = base64.b64decode(content)
+
+ logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
+ else:
+ # Use inline data (backward compatibility)
+ decoded_content = base64.b64decode(v.data)
+
+ if not _looks_like_pdf(decoded_content):
+ logger.error(
+ f"Document {v.metadata.id} is not valid PDF content. "
+ f"Ignoring document."
+ )
+ return
+
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as fp:
temp_path = fp.name
-
- # Check if we should fetch from librarian or use inline data
- if v.document_id:
- # Fetch from librarian via Pulsar
- logger.info(f"Fetching document {v.document_id} from librarian...")
- fp.close()
-
- content = await flow.librarian.fetch_document_content(
- document_id=v.document_id,
-
- )
-
- # Content is base64 encoded
- if isinstance(content, str):
- content = content.encode('utf-8')
- decoded_content = base64.b64decode(content)
-
- with open(temp_path, 'wb') as f:
- f.write(decoded_content)
-
- logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
- else:
- # Use inline data (backward compatibility)
- fp.write(base64.b64decode(v.data))
- fp.close()
+ fp.write(decoded_content)
+ fp.close()
global PyPDFLoader
if PyPDFLoader is None:
From 627c669097e80b4c94196d93ad62ea9d5dbfcfe3 Mon Sep 17 00:00:00 2001
From: cybermaggedon
Date: Wed, 10 Jun 2026 14:10:43 +0100
Subject: [PATCH 18/18] feat: per-caller Bearer token auth and new query tools
for MCP server (#984)
Replace the broken GATEWAY_SECRET auth (token was sent as a query
parameter, silently ignored by the gateway) with end-to-end Bearer
token forwarding. Each MCP caller gets a dedicated WebSocket
authenticated via the gateway's in-band first-frame protocol, with
whoami verification on first connect.
Also fix and extend the tool surface:
- embeddings: accept list of texts (was single string)
- triples_query: use Term wire format with compact keys (was legacy
Value format), add collection and graph parameters
- sparql_query: new tool for SPARQL SELECT/ASK/CONSTRUCT/DESCRIBE
- graphql_query: new tool for structured data (rows) GraphQL queries
- all tools: add optional workspace parameter
---
trustgraph-mcp/trustgraph/mcp_server/mcp.py | 1872 ++++++++---------
.../trustgraph/mcp_server/tg_socket.py | 168 +-
2 files changed, 1044 insertions(+), 996 deletions(-)
diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py
index 7378db64..11b975b2 100755
--- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py
+++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py
@@ -8,71 +8,180 @@ import logging
import json
import uuid
import argparse
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from collections.abc import AsyncIterator
from functools import partial
from mcp.server.fastmcp import FastMCP, Context
-from mcp.types import TextContent
-from websockets.asyncio.client import connect
+from mcp.server.auth.provider import AccessToken, TokenVerifier
+from mcp.server.auth.middleware.auth_context import get_access_token
from trustgraph.base.logging import add_logging_args, setup_logging
-from . tg_socket import WebSocketManager
+from . tg_socket import WebSocketManager, _token_key
+
+logger = logging.getLogger(__name__)
+
+
+# Wire-format Term type codes (match TermTranslator compact keys)
+_TERM_TYPES = {
+ "iri": "i",
+ "literal": "l",
+ "blank": "b",
+}
+
+
+def _make_term(value: str, term_type: str) -> dict:
+ """Build a compact-key Term dict for the gateway wire format.
+
+ Args:
+ value: The term value (IRI string, literal text, or blank node id).
+ term_type: One of "iri", "literal", "blank".
+ """
+ t = _TERM_TYPES.get(term_type)
+ if t is None:
+ raise ValueError(
+ f"Unknown term type '{term_type}' — "
+ f"expected one of: {', '.join(_TERM_TYPES)}"
+ )
+
+ if t == "i":
+ return {"t": t, "i": value}
+ elif t == "l":
+ return {"t": t, "v": value}
+ elif t == "b":
+ return {"t": t, "d": value}
+ return {"t": t}
+
+# ── Security boundary: MCP client → MCP server ──
+# The MCP client authenticates to this server via a Bearer token in the
+# HTTP Authorization header. The SDK's auth middleware extracts and
+# verifies the token before any tool handler runs.
+#
+# We implement a pass-through TokenVerifier: the gateway is the real
+# authority, so we accept any non-empty Bearer token here and forward
+# it to the gateway for validation. The gateway's in-band auth
+# protocol and IAM regime decide whether the token is valid.
+#
+# This means an invalid token will connect to the MCP server but will
+# fail when the first WebSocket auth frame is sent to the gateway.
+# That is intentional — the gateway is the single source of truth.
+
+
+class PassthroughTokenVerifier(TokenVerifier):
+ """Accept any non-empty Bearer token and forward it downstream.
+
+ The TrustGraph gateway is the authority for token validation, not
+ this MCP server. We store the raw token in the AccessToken so that
+ tool handlers can retrieve it via ``get_access_token().token`` and
+ forward it to the gateway.
+ """
+
+ async def verify_token(self, token: str) -> AccessToken | None:
+ if not token:
+ return None
+ return AccessToken(
+ token=token,
+ client_id="mcp-caller",
+ scopes=[],
+ )
+
@dataclass
class AppContext:
- sockets: dict[str, WebSocketManager]
- websocket_url: str
- gateway_token: str
+ sockets: dict[str, WebSocketManager] = field(default_factory=dict)
+ websocket_url: str = ""
+
@asynccontextmanager
-async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
+async def app_lifespan(
+ server: FastMCP,
+ websocket_url: str = "ws://api-gateway:8088/api/v1/socket",
+) -> AsyncIterator[AppContext]:
+ """Manage per-server state: the pool of per-caller WebSocket
+ connections to the gateway."""
- """
- Manage application lifecycle with type-safe context
- """
-
- # Initialize on startup
- sockets = {}
+ sockets: dict[str, WebSocketManager] = {}
try:
- yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
+ yield AppContext(sockets=sockets, websocket_url=websocket_url)
finally:
- # Cleanup on shutdown
- logging.info("Shutting down context")
+ logger.info("Shutting down — closing %d WebSocket(s)", len(sockets))
- for k, manager in sockets.items():
- logging.info(f"Closing socket for {k}")
- await manager.stop()
+ for key, manager in sockets.items():
+ try:
+ await manager.stop()
+ except Exception as e:
+ logger.warning("Error closing socket %s: %s", key, e)
- logging.info("Shutdown complete")
+ logger.info("Shutdown complete")
-async def get_socket_manager(ctx):
+
+def _require_token() -> str:
+ """Extract the caller's Bearer token from the MCP auth context.
+
+ Raises RuntimeError if no token is present (the caller did not
+ authenticate).
+ """
+ # ── Security boundary: token extraction ──
+ # get_access_token() reads the contextvar set by the SDK's
+ # AuthContextMiddleware. The token was placed there by
+ # PassthroughTokenVerifier.verify_token() and is the raw Bearer
+ # value from the MCP client's Authorization header.
+ access = get_access_token()
+ if access is None or not access.token:
+ raise RuntimeError(
+ "Authentication required — send a Bearer token in the "
+ "Authorization header"
+ )
+ return access.token
+
+
+async def get_socket_manager(ctx, token):
+ """Return (or create) an authenticated WebSocket for this token.
+
+ Each unique token gets its own WebSocket connection so that
+ gateway-side identity, workspace binding, and capability scoping
+ are preserved per caller.
+ """
lifespan_context = ctx.request_context.lifespan_context
sockets = lifespan_context.sockets
websocket_url = lifespan_context.websocket_url
- gateway_token = lifespan_context.gateway_token
- if "default" in sockets:
- logging.info("Return existing socket manager")
- return sockets["default"]
+ key = _token_key(token)
- logging.info(f"Opening socket to {websocket_url}...")
+ if key in sockets:
+ manager = sockets[key]
+ if manager.socket is not None:
+ return manager
+ # Socket was closed (e.g. server-side timeout) — reconnect.
+ del sockets[key]
- # Create manager with empty pending requests
- manager = WebSocketManager(websocket_url, token=gateway_token)
+ logger.info("Opening authenticated WebSocket to %s …", websocket_url)
- # Start reader task with the proper manager
+ manager = WebSocketManager(websocket_url, token=token)
await manager.start()
- sockets["default"] = manager
+ # Verify the token is valid by calling whoami. This confirms the
+ # gateway accepted the token and gives us the caller's identity.
+ try:
+ identity = await manager.whoami()
+ logger.info(
+ "WebSocket ready — caller: %s",
+ identity.get("handle", "unknown"),
+ )
+ except Exception as e:
+ await manager.stop()
+ raise RuntimeError(
+ f"Token rejected by gateway (whoami failed): {e}"
+ ) from e
- logging.info("Return new socket manager")
+ sockets[key] = manager
return manager
+
@dataclass
class EmbeddingsResponse:
vectors: List[List[float]]
@@ -182,10 +291,23 @@ class PutConfigResponse:
class DeleteConfigResponse:
pass
+@dataclass
+class SparqlQueryResponse:
+ query_type: str
+ variables: List[str]
+ bindings: List[Dict[str, Any]]
+ ask_result: bool
+ triples: List[Dict[str, Any]]
+
+@dataclass
+class GraphQLQueryResponse:
+ data: Any
+ errors: List[Dict[str, Any]]
+
@dataclass
class GetPromptsResponse:
prompts: List[str]
-
+
@dataclass
class GetPromptResponse:
prompt: Dict[str, Any]
@@ -194,31 +316,61 @@ class GetPromptResponse:
class GetSystemPromptResponse:
prompt: str
+
class McpServer:
- def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
+ def __init__(
+ self,
+ host: str = "0.0.0.0",
+ port: int = 8000,
+ websocket_url: str = "ws://api-gateway:8088/api/v1/socket",
+ auth_issuer: str = "",
+ auth_resource_url: str = "",
+ ):
self.host = host
self.port = port
self.websocket_url = websocket_url
- self.gateway_token = gateway_token
- # Create a partial function to pass websocket_url to app_lifespan
- lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
-
+ lifespan_with_url = partial(
+ app_lifespan, websocket_url=websocket_url,
+ )
+
+ # ── Security: MCP-level auth configuration ──
+ # The SDK requires AuthSettings whenever a token_verifier is
+ # present. The issuer_url tells MCP clients where to obtain
+ # tokens; resource_server_url identifies this server in OAuth
+ # protected-resource metadata.
+ #
+ # The PassthroughTokenVerifier accepts any non-empty Bearer
+ # token — real validation happens at the gateway. This is
+ # intentional: the gateway is the single source of truth for
+ # identity and capability checks.
+ from mcp.server.auth.settings import AuthSettings
+
+ auth_settings = AuthSettings(
+ issuer_url=auth_issuer or f"http://{host}:{port}",
+ resource_server_url=auth_resource_url or f"http://{host}:{port}",
+ )
+
self.mcp = FastMCP(
- "TrustGraph", dependencies=["trustgraph-base"],
- host=self.host, port=self.port,
+ "TrustGraph",
+ dependencies=["trustgraph-base"],
+ host=self.host,
+ port=self.port,
lifespan=lifespan_with_url,
+ token_verifier=PassthroughTokenVerifier(),
+ auth=auth_settings,
)
self._register_tools()
-
+
def _register_tools(self):
"""Register all MCP tools"""
- # Register all the tools that were previously registered globally
self.mcp.tool()(self.embeddings)
self.mcp.tool()(self.text_completion)
self.mcp.tool()(self.graph_rag)
self.mcp.tool()(self.agent)
self.mcp.tool()(self.triples_query)
+ self.mcp.tool()(self.sparql_query)
+ self.mcp.tool()(self.graphql_query)
self.mcp.tool()(self.graph_embeddings_query)
self.mcp.tool()(self.get_config_all)
self.mcp.tool()(self.get_config)
@@ -243,67 +395,69 @@ class McpServer:
self.mcp.tool()(self.load_document)
self.mcp.tool()(self.remove_document)
self.mcp.tool()(self.add_processing)
-
+
def run(self):
"""Run the MCP server"""
self.mcp.run(transport="streamable-http")
+ async def _get_manager(self, ctx):
+ """Get an authenticated WebSocket manager for the current caller.
+
+ Extracts the Bearer token from the MCP auth context and returns
+ a per-token WebSocket connection to the gateway.
+ """
+ token = _require_token()
+ return await get_socket_manager(ctx, token)
+
async def embeddings(
self,
- text: str,
+ texts: List[str],
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> EmbeddingsResponse:
"""
- Generate vector embeddings for the given text using TrustGraph's embedding models.
-
+ Generate vector embeddings for the given texts using TrustGraph's embedding models.
+
This tool converts text into high-dimensional vectors that capture semantic meaning,
enabling similarity searches, clustering, and other vector-based operations.
-
+
Args:
- text: The input text to convert into embeddings. Can be a sentence, paragraph,
- or document. The text will be processed by the configured embedding model.
+ texts: List of input texts to convert into embeddings. Each text can be a
+ sentence, paragraph, or document.
flow_id: Optional flow identifier to use for processing (default: "default").
Different flows may use different embedding models or configurations.
-
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
- EmbeddingsResponse containing a list of vectors. Each vector is a list of floats
- representing the text's semantic embedding in the model's vector space.
-
- Example usage:
- - Convert a query into embeddings for similarity search
- - Generate embeddings for documents before storing them
- - Create embeddings for comparison with existing knowledge
+ EmbeddingsResponse containing a list of vectors, one per input text.
"""
- logging.info("Embeddings request made")
+ logger.info("Embeddings request")
if flow_id is None: flow_id = "default"
- manager = await get_socket_manager(ctx, "trustgraph")
+ manager = await self._get_manager(ctx)
- if ctx is None:
- raise RuntimeError("No context provided")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Computing embeddings via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- await ctx.session.send_log_message(
- level="info",
- data=f"Computing embeddings via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
+ request_data = {"texts": texts}
+
+ gen = manager.request(
+ "embeddings", request_data, flow_id, workspace=workspace,
)
- # Send websocket request
- request_data = {"text": text}
- logging.info("making request")
-
- gen = manager.request("embeddings", request_data, flow_id)
-
async for response in gen:
-
- # Extract vectors from response
vectors = response.get("vectors", [[]])
break
-
+
return EmbeddingsResponse(vectors=vectors)
async def text_completion(
@@ -311,62 +465,47 @@ class McpServer:
prompt: str,
system: str | None = None,
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> TextCompletionResponse:
"""
Generate text completions using TrustGraph's language models.
-
- This tool sends prompts to configured language models and returns generated text.
- It supports both user prompts and system instructions for controlling generation.
-
+
Args:
prompt: The main prompt or question to send to the language model.
- This is the primary input that guides the model's response.
system: Optional system prompt that sets the context, role, or behavior
- for the AI assistant (e.g., "You are a helpful coding assistant").
- System prompts influence how the model interprets and responds.
- flow_id: Optional flow identifier (default: "default"). Different flows
- may use different models, parameters, or processing pipelines.
-
+ for the AI assistant.
+ flow_id: Optional flow identifier (default: "default").
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
TextCompletionResponse containing the generated text response from the model.
-
- Example usage:
- - Ask questions and get AI-generated answers
- - Generate code, documentation, or creative content
- - Perform text analysis, summarization, or transformation tasks
- - Use system prompts to control tone, style, or domain expertise
"""
if system is None: system = ""
if flow_id is None: flow_id = "default"
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- # Use websocket if context is available
- logging.info("Text completion request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Generating text completion via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Generating text completion via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # Send websocket request
request_data = {"system": system, "prompt": prompt}
- gen = manager.request("text-completion", request_data, flow_id)
+ gen = manager.request(
+ "text-completion", request_data, flow_id, workspace=workspace,
+ )
async for response in gen:
-
- # Extract vectors from response
text = response.get("response", "")
break
-
+
return TextCompletionResponse(response=text)
async def graph_rag(
@@ -378,58 +517,43 @@ class McpServer:
max_subgraph_size: int | None = None,
max_path_length: int | None = None,
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> GraphRagResponse:
"""
Perform Graph-based Retrieval Augmented Generation (GraphRAG) queries.
-
+
GraphRAG combines knowledge graph traversal with language model generation to provide
- contextually rich answers. It explores relationships between entities to build relevant
- context before generating responses.
-
+ contextually rich answers.
+
Args:
question: The question or query to answer using the knowledge graph.
- The system will find relevant entities and relationships to inform the response.
collection: Knowledge collection to query (default: "default").
- Different collections may contain domain-specific knowledge.
entity_limit: Maximum number of entities to retrieve during graph traversal.
- Higher limits provide more context but increase processing time.
triple_limit: Maximum number of relationship triples to consider.
- Controls the depth of relationship exploration.
max_subgraph_size: Maximum size of the subgraph to extract for context.
- Larger subgraphs provide richer context but use more resources.
max_path_length: Maximum path length to traverse in the knowledge graph.
- Longer paths can discover distant but relevant relationships.
flow_id: Processing flow to use (default: "default").
-
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
GraphRagResponse containing the generated answer informed by knowledge graph context.
-
- Example usage:
- - Answer complex questions requiring multi-hop reasoning
- - Explore relationships between entities in your knowledge base
- - Generate responses grounded in structured knowledge
- - Perform research queries across connected information
"""
if collection is None: collection = "default"
if flow_id is None: flow_id = "default"
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("GraphRAG request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing GraphRAG query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Processing GraphRAG query via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # Build request data with all parameters
request_data = {
"query": question
}
@@ -440,20 +564,19 @@ class McpServer:
if max_subgraph_size: request_data["max_subgraph_size"] = max_subgraph_size
if max_path_length: request_data["max_path_length"] = max_path_length
- gen = manager.request("graph-rag", request_data, flow_id)
+ gen = manager.request(
+ "graph-rag", request_data, flow_id, workspace=workspace,
+ )
text_chunks = []
async for response in gen:
- # Handle new message format with message_type
message_type = response.get("message_type", "chunk")
- # Only collect text from chunk messages
if message_type == "chunk":
chunk_text = response.get("response", "")
if chunk_text:
text_chunks.append(chunk_text)
- # Check if session is complete
if response.get("end_of_session"):
break
@@ -464,404 +587,447 @@ class McpServer:
question: str,
collection: str | None = None,
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> AgentResponse:
"""
Execute intelligent agent queries with reasoning and tool usage capabilities.
-
- The agent can perform complex multi-step reasoning, use tools, and provide
- detailed thought processes. It's designed for tasks requiring planning,
- analysis, and iterative problem-solving.
-
+
Args:
- question: The question or task for the agent to solve. Can be complex
- queries requiring multiple steps, analysis, or tool usage.
+ question: The question or task for the agent to solve.
collection: Knowledge collection the agent can access (default: "default").
- Determines what information and tools are available.
- flow_id: Agent workflow to use (default: "default"). Different flows
- may have different capabilities, tools, or reasoning strategies.
-
+ flow_id: Agent workflow to use (default: "default").
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
AgentResponse containing the final answer after the agent's reasoning process.
- During execution, you'll see intermediate thoughts and observations.
-
- Example usage:
- - Solve complex analytical problems requiring multiple steps
- - Perform research tasks across multiple information sources
- - Handle queries that need tool usage and decision-making
- - Get detailed explanations of reasoning processes
-
- Note: This tool provides real-time updates on the agent's thinking process
- through log messages, so you can follow its reasoning steps.
"""
if collection is None: collection = "default"
if flow_id is None: flow_id = "default"
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Agent request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing agent query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Processing agent query via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # Build request data with all parameters
request_data = {
"question": question
}
if collection: request_data["collection"] = collection
- gen = manager.request("agent", request_data, flow_id)
+ gen = manager.request(
+ "agent", request_data, flow_id, workspace=workspace,
+ )
async for response in gen:
- logging.debug(f"Agent response: {response}")
+ logger.debug("Agent response: %s", response)
- if "thought" in response:
- await ctx.session.send_log_message(
- level="info",
- data=f"Thinking: {response['thought']}",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ if "thought" in response:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Thinking: {response['thought']}",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- if "observation" in response:
- await ctx.session.send_log_message(
- level="info",
- data=f"Observation: {response['observation']}",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if "observation" in response:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Observation: {response['observation']}",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- # Extract vectors from response
if "answer" in response:
answer = response.get("answer", "")
return AgentResponse(answer=answer)
async def triples_query(
self,
- s_v: str | None = None,
- s_e: bool | None = None,
- p_v: str | None = None,
- p_e: bool | None = None,
- o_v: str | None = None,
- o_e: bool | None = None,
+ s: str | None = None,
+ s_type: str | None = None,
+ p: str | None = None,
+ p_type: str | None = None,
+ o: str | None = None,
+ o_type: str | None = None,
+ collection: str | None = None,
+ graph: str | None = None,
limit: int | None = None,
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> TriplesQueryResponse:
"""
Query knowledge graph triples using subject-predicate-object patterns.
-
- Knowledge graphs store information as triples (subject, predicate, object).
- This tool allows flexible querying by specifying any combination of these
- components, with wildcards for unspecified parts.
-
+
+ Each of s, p, o is an RDF term value. Use the corresponding _type
+ parameter to specify the term kind:
+ - "iri" (default for s and p): an IRI / entity reference
+ - "literal" (default for o): a plain literal value
+ - "blank": a blank node identifier
+
Args:
- s_v: Subject value to match (e.g., "John", "Apple Inc."). Leave None for wildcard.
- s_e: Whether subject should be treated as an entity (True) or literal (False).
- p_v: Predicate/relationship value (e.g., "works_for", "type_of"). Leave None for wildcard.
- p_e: Whether predicate should be treated as an entity (True) or literal (False).
- o_v: Object value to match (e.g., "Engineer", "Company"). Leave None for wildcard.
- o_e: Whether object should be treated as an entity (True) or literal (False).
+ s: Subject value to match. Leave None for wildcard.
+ s_type: Subject term type: "iri" (default), "literal", or "blank".
+ p: Predicate value to match. Leave None for wildcard.
+ p_type: Predicate term type: "iri" (default), "literal", or "blank".
+ o: Object value to match. Leave None for wildcard.
+ o_type: Object term type: "iri", "literal" (default), or "blank".
+ collection: Knowledge collection to query (default: "default").
+ graph: Named graph IRI to restrict the query. None = default graph,
+ "*" = all graphs.
limit: Maximum number of triples to return (default: 20).
flow_id: Processing flow identifier (default: "default").
-
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
TriplesQueryResponse containing matching triples from the knowledge graph.
-
- Example queries:
- - Find all relationships for an entity: s_v="John", others None
- - Find all instances of a relationship: p_v="works_for", others None
- - Find specific facts: s_v="John", p_v="works_for", o_v=None
- - Explore entity types: p_v="type_of", others None
-
- Use this for:
- - Exploring knowledge graph structure
- - Finding specific facts or relationships
- - Discovering connections between entities
- - Validating or debugging knowledge content
"""
if flow_id is None: flow_id = "default"
if limit is None: limit = 20
+ if collection is None: collection = "default"
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Triples query request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing triples query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Processing triples query via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # Build request data with Value objects
request_data = {
- "limit": limit
+ "limit": limit,
+ "collection": collection,
}
- # Add subject if provided
- if s_v is not None:
- request_data["s"] = {"v": s_v, "e": s_e }
+ if s is not None:
+ request_data["s"] = _make_term(s, s_type or "iri")
- # Add predicate if provided
- if p_v is not None:
- request_data["p"] = {"v": p_v, "e": p_e }
+ if p is not None:
+ request_data["p"] = _make_term(p, p_type or "iri")
- # Add object if provided
- if o_v is not None:
- request_data["o"] = {"v": o_v, "e": o_e }
+ if o is not None:
+ request_data["o"] = _make_term(o, o_type or "literal")
- gen = manager.request("triples", request_data, flow_id)
+ if graph is not None:
+ request_data["g"] = graph
+
+ gen = manager.request(
+ "triples", request_data, flow_id, workspace=workspace,
+ )
async for response in gen:
- # Extract response data
triples = response.get("response", [])
break
-
+
return TriplesQueryResponse(triples=triples)
+ async def sparql_query(
+ self,
+ query: str,
+ collection: str | None = None,
+ limit: int | None = None,
+ flow_id: str | None = None,
+ workspace: str | None = None,
+ ctx: Context = None,
+ ) -> SparqlQueryResponse:
+ """
+ Execute a SPARQL query against the knowledge graph.
+
+ Supports SELECT, ASK, CONSTRUCT, and DESCRIBE query forms.
+
+ Args:
+ query: SPARQL query string (e.g. "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10").
+ collection: Knowledge collection to query (default: "default").
+ limit: Safety limit on number of results (default: 10000).
+ flow_id: Processing flow identifier (default: "default").
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
+ Returns:
+ SparqlQueryResponse containing the query results. The structure depends
+ on query type:
+ - SELECT: variables (column names) and bindings (rows of Term values)
+ - ASK: ask_result (boolean)
+ - CONSTRUCT/DESCRIBE: triples
+ """
+
+ if collection is None: collection = "default"
+ if flow_id is None: flow_id = "default"
+ if limit is None: limit = 10000
+
+ manager = await self._get_manager(ctx)
+
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing SPARQL query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
+
+ request_data = {
+ "query": query,
+ "collection": collection,
+ "limit": limit,
+ }
+
+ gen = manager.request(
+ "sparql", request_data, flow_id, workspace=workspace,
+ )
+
+ async for response in gen:
+ query_type = response.get("query-type", "")
+ return SparqlQueryResponse(
+ query_type=query_type,
+ variables=response.get("variables", []),
+ bindings=response.get("bindings", []),
+ ask_result=response.get("ask-result", False),
+ triples=response.get("triples", []),
+ )
+
+ async def graphql_query(
+ self,
+ query: str,
+ collection: str | None = None,
+ variables: Dict[str, Any] | None = None,
+ operation_name: str | None = None,
+ flow_id: str | None = None,
+ workspace: str | None = None,
+ ctx: Context = None,
+ ) -> GraphQLQueryResponse:
+ """
+ Execute a GraphQL query against structured data (rows).
+
+ Queries structured data schemas that have been loaded into TrustGraph.
+ The available types and fields depend on the schemas configured in the
+ target workspace.
+
+ Args:
+ query: GraphQL query string (e.g. '{ customers(where: {status: {eq: "active"}}) { id name } }').
+ collection: Data collection to query (default: "default").
+ variables: Optional GraphQL variables as a dict.
+ operation_name: Optional operation name for multi-operation documents.
+ flow_id: Processing flow identifier (default: "default").
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
+ Returns:
+ GraphQLQueryResponse containing data (the query result) and errors
+ (any GraphQL field-level errors).
+ """
+
+ if collection is None: collection = "default"
+ if flow_id is None: flow_id = "default"
+
+ manager = await self._get_manager(ctx)
+
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing GraphQL query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
+
+ request_data = {
+ "query": query,
+ "collection": collection,
+ "variables": variables or {},
+ }
+
+ if operation_name is not None:
+ request_data["operation_name"] = operation_name
+
+ gen = manager.request(
+ "rows", request_data, flow_id, workspace=workspace,
+ )
+
+ async for response in gen:
+ return GraphQLQueryResponse(
+ data=response.get("data"),
+ errors=response.get("errors", []),
+ )
+
async def graph_embeddings_query(
self,
vectors: List[List[float]],
limit: int | None = None,
flow_id: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> GraphEmbeddingsQueryResponse:
"""
Find entities in the knowledge graph using vector similarity search.
-
- This tool performs semantic search by comparing embedding vectors to find
- the most similar entities in the knowledge graph. It's useful for finding
- conceptually related information even when exact text matches don't exist.
-
+
Args:
- vectors: List of embedding vectors to search with. Each vector should be
- a list of floats representing semantic embeddings (typically from
- the embeddings tool). Multiple vectors can be provided for batch queries.
+ vectors: List of embedding vectors to search with.
limit: Maximum number of similar entities to return (default: 20).
- Higher limits provide more results but may include less relevant matches.
flow_id: Processing flow identifier (default: "default").
-
+ workspace: Optional workspace to query. If omitted, uses the caller's
+ default workspace.
+
Returns:
- GraphEmbeddingsQueryResponse containing entities ranked by similarity to the
- input vectors, along with similarity scores and entity metadata.
-
- Example workflow:
- 1. Use the 'embeddings' tool to convert text to vectors
- 2. Use this tool to find similar entities in the knowledge graph
- 3. Explore the returned entities for relevant information
-
- Use this for:
- - Semantic search across knowledge entities
- - Finding conceptually similar content
- - Discovering related entities without exact keyword matches
- - Building recommendation systems based on entity similarity
+ GraphEmbeddingsQueryResponse containing entities ranked by similarity.
"""
if flow_id is None: flow_id = "default"
if limit is None: limit = 20
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Graph embeddings query request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Processing graph embeddings query via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Processing graph embeddings query via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # Build request data
request_data = {
"vectors": vectors,
"limit": limit
}
- gen = manager.request("graph-embeddings", request_data, flow_id)
+ gen = manager.request(
+ "graph-embeddings", request_data, flow_id, workspace=workspace,
+ )
async for response in gen:
- # Extract entities from response
entities = response.get("entities", [])
break
-
+
return GraphEmbeddingsQueryResponse(entities=entities)
async def get_config_all(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> ConfigResponse:
"""
Retrieve the complete TrustGraph system configuration.
-
- This tool returns all configuration settings for the TrustGraph system,
- including model configurations, API keys, flow definitions, and system parameters.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- ConfigResponse containing the full configuration as a nested dictionary
- with all system settings, organized by category (e.g., models, flows, storage).
-
- Use this for:
- - Inspecting current system configuration
- - Debugging configuration issues
- - Understanding available models and settings
- - Auditing system setup and parameters
+ ConfigResponse containing the full configuration as a nested dictionary.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get config all request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving all configuration via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving all configuration via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "config"
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
config = response.get("config", {})
break
-
+
return ConfigResponse(config=config)
async def get_config(
self,
keys: List[Dict[str, str]],
+ workspace: str | None = None,
ctx: Context = None,
) -> ConfigGetResponse:
"""
Retrieve specific configuration values by key.
-
- This tool allows you to fetch specific configuration settings without
- retrieving the entire configuration. Useful for checking particular
- settings or API keys.
-
+
Args:
- keys: List of configuration keys to retrieve. Each key should be a dict with:
- - 'type': Configuration category (e.g., 'llm', 'embeddings', 'storage')
- - 'key': Specific setting name within that category
-
+ keys: List of configuration keys to retrieve. Each key should be a dict with
+ 'type' and 'key' fields.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
ConfigGetResponse containing the requested configuration values.
-
- Example keys:
- - {'type': 'llm', 'key': 'openai.model'}
- - {'type': 'embeddings', 'key': 'default.model'}
- - {'type': 'storage', 'key': 'database.url'}
-
- Use this for:
- - Checking specific model configurations
- - Validating API key settings
- - Inspecting individual system parameters
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get config request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving specific configuration via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving specific configuration via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "get",
"keys": keys
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
values = response.get("values", [])
break
-
+
return ConfigGetResponse(values=values)
async def put_config(
self,
values: List[Dict[str, str]],
+ workspace: str | None = None,
ctx: Context = None,
) -> PutConfigResponse:
"""
Update system configuration values.
-
- This tool allows you to modify TrustGraph system settings, such as
- model parameters, API keys, and system behavior configurations.
-
+
Args:
- values: List of configuration updates. Each update should be a dict with:
- - 'type': Configuration category (e.g., 'llm', 'embeddings')
- - 'key': Specific setting name to update
- - 'value': New value for the setting
-
+ values: List of configuration updates. Each should be a dict with
+ 'type', 'key', and 'value' fields.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
PutConfigResponse confirming the configuration update.
-
- Example updates:
- - {'type': 'llm', 'key': 'openai.model', 'value': 'gpt-4'}
- - {'type': 'embeddings', 'key': 'batch_size', 'value': '100'}
-
- Use this for:
- - Switching between different models
- - Updating API credentials
- - Modifying system behavior parameters
- - Configuring processing settings
-
- Note: Configuration changes may require system restart to take effect.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Put config request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Updating configuration via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Updating configuration via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "put",
"values": values
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
return PutConfigResponse()
@@ -869,97 +1035,73 @@ class McpServer:
async def delete_config(
self,
keys: List[Dict[str, str]],
+ workspace: str | None = None,
ctx: Context = None,
) -> DeleteConfigResponse:
"""
Delete specific configuration entries from the system.
-
- This tool removes configuration settings, reverting them to system defaults
- or disabling specific features.
-
+
Args:
- keys: List of configuration keys to delete. Each key should be a dict with:
- - 'type': Configuration category (e.g., 'llm', 'embeddings')
- - 'key': Specific setting name to remove
-
+ keys: List of configuration keys to delete. Each should be a dict with
+ 'type' and 'key' fields.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
DeleteConfigResponse confirming the deletion.
-
- Use this for:
- - Removing custom model configurations
- - Clearing API credentials
- - Resetting settings to defaults
- - Cleaning up obsolete configurations
-
- Warning: Deleting essential configuration may cause system functionality
- to be disabled until properly reconfigured.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Delete config request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Deleting configuration via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Deleting configuration via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "delete",
"keys": keys
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
return DeleteConfigResponse()
async def get_prompts(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> GetPromptsResponse:
"""
List all available prompt templates in the system.
-
- Prompt templates are reusable prompts that can be used with language models
- for consistent behavior across different queries and use cases.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
GetPromptsResponse containing a list of available prompt template IDs.
- Each ID can be used with get_prompt to retrieve the full template.
-
- Use this for:
- - Discovering available prompt templates
- - Exploring pre-configured prompts for different tasks
- - Finding templates for specific use cases
- - Understanding what prompt options are available
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get prompts request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving prompt templates via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving prompt templates via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # First get all config
request_data = {
"operation": "config"
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
config = response.get("config", {})
@@ -971,49 +1113,36 @@ class McpServer:
async def get_prompt(
self,
prompt_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> GetPromptResponse:
"""
Retrieve a specific prompt template by ID.
-
- Prompt templates contain structured prompts with placeholders, instructions,
- and metadata for specific tasks or domains.
-
+
Args:
prompt_id: The unique identifier of the prompt template to retrieve.
- Use get_prompts to see available template IDs.
-
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- GetPromptResponse containing the complete prompt template with its
- structure, placeholders, and usage instructions.
-
- Use this for:
- - Examining prompt template structure
- - Understanding how to use specific templates
- - Copying or modifying existing prompts
- - Learning prompt engineering patterns
+ GetPromptResponse containing the complete prompt template.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get prompt request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Retrieving prompt template '{prompt_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving prompt template '{prompt_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # First get all config
request_data = {
"operation": "config"
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
config = response.get("config", {})
@@ -1025,44 +1154,35 @@ class McpServer:
async def get_system_prompt(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> GetSystemPromptResponse:
"""
Retrieve the current system prompt configuration.
-
- The system prompt defines the default behavior, personality, and instructions
- for language models across the TrustGraph system.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- GetSystemPromptResponse containing the system prompt text and configuration.
-
- Use this for:
- - Understanding default AI behavior settings
- - Checking current system-wide prompt configuration
- - Auditing AI personality and instruction settings
- - Debugging unexpected AI responses
+ GetSystemPromptResponse containing the system prompt text.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get system prompt request made via websocket")
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving system prompt via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving system prompt via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
- # First get all config
request_data = {
"operation": "config"
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
config = response.get("config", {})
@@ -1073,51 +1193,39 @@ class McpServer:
async def get_token_costs(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> ConfigTokenCostsResponse:
"""
Retrieve token pricing information for all configured AI models.
-
- This tool provides cost information for input and output tokens across
- different language models, helping with budget planning and cost optimization.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- ConfigTokenCostsResponse containing pricing data for each model including:
- - Model name/identifier
- - Input token cost (per token)
- - Output token cost (per token)
-
- Use this for:
- - Estimating costs for different models
- - Choosing cost-effective models for tasks
- - Budget planning and cost analysis
- - Monitoring and optimizing AI spending
+ ConfigTokenCostsResponse containing pricing data for each model.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get token costs request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving token costs via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving token costs via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "getvalues",
"type": "token-costs"
}
- gen = manager.request("config", request_data, None)
+ gen = manager.request("config", request_data, None, workspace=workspace)
async for response in gen:
values = response.get("values", [])
- # Transform to match TypeScript API format
costs = []
for item in values:
try:
@@ -1130,106 +1238,89 @@ class McpServer:
except (json.JSONDecodeError, AttributeError):
continue
break
-
+
return ConfigTokenCostsResponse(costs=costs)
async def get_knowledge_cores(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> KnowledgeCoresResponse:
"""
List all available knowledge graph cores in the current workspace.
- Knowledge cores are packaged collections of structured knowledge that can
- be loaded into the system for querying and reasoning. They contain entities,
- relationships, and facts organized as knowledge graphs.
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
KnowledgeCoresResponse containing a list of available knowledge core IDs.
-
- Use this for:
- - Discovering available knowledge collections
- - Understanding what knowledge domains are accessible
- - Planning which cores to load for specific tasks
- - Managing knowledge resources
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get knowledge cores request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving knowledge graph cores via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving knowledge graph cores via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "list-kg-cores",
}
- gen = manager.request("knowledge", request_data, None)
+ gen = manager.request(
+ "knowledge", request_data, None, workspace=workspace,
+ )
async for response in gen:
ids = response.get("ids", [])
break
-
+
return KnowledgeCoresResponse(ids=ids)
async def delete_kg_core(
self,
core_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> DeleteKgCoreResponse:
"""
Permanently delete a knowledge graph core.
- This operation removes a knowledge core from storage. Use with caution
- as this action cannot be undone.
-
Args:
core_id: Unique identifier of the knowledge core to delete.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
DeleteKgCoreResponse confirming the deletion.
-
- Use this for:
- - Cleaning up obsolete knowledge cores
- - Removing test or experimental data
- - Managing storage space
- - Maintaining organized knowledge collections
-
- Warning: This permanently deletes the knowledge core and all its data.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Delete KG core request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Deleting knowledge graph core '{core_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Deleting knowledge graph core '{core_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "delete-kg-core",
"id": core_id,
}
- gen = manager.request("knowledge", request_data, None)
+ gen = manager.request(
+ "knowledge", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return DeleteKgCoreResponse()
async def load_kg_core(
@@ -1237,46 +1328,34 @@ class McpServer:
core_id: str,
flow: str,
collection: str | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> LoadKgCoreResponse:
"""
Load a knowledge graph core into the active system for querying.
- This operation makes a knowledge core available for GraphRAG queries,
- triple searches, and other knowledge-based operations.
-
Args:
core_id: Unique identifier of the knowledge core to load.
- flow: Processing flow to use for loading the core. Different flows
- may apply different processing, indexing, or optimization steps.
- collection: Target collection name (default: "default"). The loaded
- knowledge will be available under this collection name.
+ flow: Processing flow to use for loading the core.
+ collection: Target collection name (default: "default").
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
LoadKgCoreResponse confirming the core has been loaded.
-
- Use this for:
- - Making knowledge cores available for queries
- - Switching between different knowledge domains
- - Loading domain-specific knowledge for tasks
- - Preparing knowledge for GraphRAG operations
"""
if collection is None: collection = "default"
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Load KG core request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Loading knowledge graph core '{core_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Loading knowledge graph core '{core_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "load-kg-core",
@@ -1285,292 +1364,241 @@ class McpServer:
"collection": collection
}
- gen = manager.request("knowledge", request_data, None)
+ gen = manager.request(
+ "knowledge", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return LoadKgCoreResponse()
async def get_kg_core(
self,
core_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> GetKgCoreResponse:
"""
Download and retrieve the complete content of a knowledge graph core.
- This tool streams the entire content of a knowledge core, returning all
- entities, relationships, and metadata. Due to potentially large data sizes,
- the content is streamed in chunks.
-
Args:
core_id: Unique identifier of the knowledge core to retrieve.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
GetKgCoreResponse containing all chunks of the knowledge core data.
- Each chunk contains part of the knowledge graph structure.
-
- Use this for:
- - Examining knowledge core content and structure
- - Debugging knowledge graph data
- - Exporting knowledge for backup or analysis
- - Understanding the scope and quality of knowledge
-
- Note: Large knowledge cores may take significant time to download.
- Progress updates are provided through log messages during streaming.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get KG core request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving knowledge graph core '{core_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Retrieving knowledge graph core '{core_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "get-kg-core",
"id": core_id,
}
- # Collect all streaming responses
chunks = []
- gen = manager.request("knowledge", request_data, None)
+ gen = manager.request(
+ "knowledge", request_data, None, workspace=workspace,
+ )
async for response in gen:
- # Check for end of stream
if response.get("eos", False):
- await ctx.session.send_log_message(
- level="info",
- data=f"Completed streaming KG core data",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Completed streaming KG core data",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
break
else:
chunks.append(response)
- await ctx.session.send_log_message(
- level="info",
- data=f"Received KG core chunk ({len(chunks)} chunks so far)",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
-
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Received KG core chunk ({len(chunks)} chunks so far)",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
+
return GetKgCoreResponse(chunks=chunks)
async def get_flows(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> FlowsResponse:
"""
List all available processing flows in the system.
-
- Flows define processing pipelines for different types of operations
- (e.g., document processing, knowledge extraction, query handling).
- Each flow encapsulates a specific workflow with configured steps.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
FlowsResponse containing a list of available flow identifiers.
-
- Use this for:
- - Discovering available processing workflows
- - Understanding what processing options are available
- - Choosing appropriate flows for specific tasks
- - Planning workflow-based operations
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get flows request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving available flows via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving available flows via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "list-flows"
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
flow_ids = response.get("flow-ids", [])
break
-
+
return FlowsResponse(flow_ids=flow_ids)
async def get_flow(
self,
flow_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> FlowResponse:
"""
Retrieve the complete definition of a specific processing flow.
-
- This tool returns the detailed configuration, steps, and parameters
- of a processing flow, showing how it processes data and what operations it performs.
-
+
Args:
flow_id: Unique identifier of the flow to retrieve.
-
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- FlowResponse containing the complete flow definition including:
- - Flow configuration and parameters
- - Processing steps and their order
- - Input/output specifications
- - Dependencies and requirements
-
- Use this for:
- - Understanding how specific flows work
- - Debugging flow processing issues
- - Learning flow configuration patterns
- - Customizing or duplicating flows
+ FlowResponse containing the complete flow definition.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get flow request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving flow definition for '{flow_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Retrieving flow definition for '{flow_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "get-flow",
"flow-id": flow_id,
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
flow_data = response.get("flow", "{}")
- # Parse JSON flow definition as done in TypeScript
flow = json.loads(flow_data) if isinstance(flow_data, str) else flow_data
break
-
+
return FlowResponse(flow=flow)
async def get_flow_classes(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> FlowClassesResponse:
"""
List all available flow class templates.
-
- Flow classes are templates that define types of processing workflows.
- They serve as blueprints for creating specific flow instances with
- customized parameters.
-
+
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
FlowClassesResponse containing a list of available flow class names.
-
- Use this for:
- - Discovering available flow templates
- - Understanding what types of processing are supported
- - Planning new flow creation
- - Exploring system capabilities
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get flow classes request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving flow classes via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving flow classes via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "list-classes"
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
class_names = response.get("class-names", [])
break
-
+
return FlowClassesResponse(class_names=class_names)
async def get_flow_class(
self,
class_name: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> FlowClassResponse:
"""
Retrieve the definition of a specific flow class template.
-
- Flow classes define the structure, parameters, and capabilities of
- flow types. This tool returns the class specification including
- configurable parameters and processing logic.
-
+
Args:
class_name: Name of the flow class to retrieve.
-
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
- FlowClassResponse containing the flow class definition with:
- - Class parameters and configuration options
- - Processing capabilities and requirements
- - Usage instructions and examples
-
- Use this for:
- - Understanding flow class capabilities
- - Learning how to configure new flows
- - Troubleshooting flow creation issues
- - Exploring advanced flow features
+ FlowClassResponse containing the flow class definition.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get flow class request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving flow class definition for '{class_name}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Retrieving flow class definition for '{class_name}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "get-class",
"class-name": class_name
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
class_def_data = response.get("class-definition", "{}")
- # Parse JSON class definition as done in TypeScript
class_definition = json.loads(class_def_data) if isinstance(class_def_data, str) else class_def_data
break
-
+
return FlowClassResponse(class_definition=class_definition)
async def start_flow(
@@ -1578,43 +1606,32 @@ class McpServer:
flow_id: str,
class_name: str,
description: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> StartFlowResponse:
"""
Create and start a new processing flow instance.
-
- This tool creates a new flow based on a flow class template and starts
- it running. The flow will begin processing according to its configuration.
-
+
Args:
flow_id: Unique identifier for the new flow instance.
class_name: Flow class template to use for creating the flow.
- Use get_flow_classes to see available classes.
description: Human-readable description of the flow's purpose.
-
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
StartFlowResponse confirming the flow has been started.
-
- Use this for:
- - Creating new processing workflows
- - Starting automated processing tasks
- - Launching background operations
- - Initiating data processing pipelines
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Start flow request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Starting flow '{flow_id}' with class '{class_name}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Starting flow '{flow_id}' with class '{class_name}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "start-flow",
@@ -1623,162 +1640,135 @@ class McpServer:
"description": description
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return StartFlowResponse()
async def stop_flow(
self,
flow_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> StopFlowResponse:
"""
Stop a running flow instance.
-
- This tool gracefully stops a running flow, allowing it to complete
- current operations before shutting down.
-
+
Args:
flow_id: Unique identifier of the flow instance to stop.
-
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
+
Returns:
StopFlowResponse confirming the flow has been stopped.
-
- Use this for:
- - Stopping unwanted or completed flows
- - Managing system resources
- - Interrupting long-running processes
- - Maintaining flow lifecycle
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Stop flow request made via websocket")
-
- manager = await get_socket_manager(ctx, "trustgraph")
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Stopping flow '{flow_id}' via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Stopping flow '{flow_id}' via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "stop-flow",
"flow-id": flow_id
}
- gen = manager.request("flow", request_data, None)
+ gen = manager.request(
+ "flow", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return StopFlowResponse()
async def get_documents(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> DocumentsResponse:
"""
List all documents stored in the TrustGraph document library.
- This tool returns metadata for all documents that have been uploaded
- to the system, including their processing status and properties.
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
- DocumentsResponse containing metadata for each document including:
- - Document ID and title
- - Upload timestamp
- - MIME type and size information
- - Tags and custom metadata
- - Processing status
-
- Use this for:
- - Browsing available documents
- - Managing document collections
- - Finding documents by metadata
- - Auditing document storage
+ DocumentsResponse containing metadata for each document.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get documents request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving documents list via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving documents list via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "list-documents",
}
- gen = manager.request("librarian", request_data, None)
+ gen = manager.request(
+ "librarian", request_data, None, workspace=workspace,
+ )
async for response in gen:
document_metadatas = response.get("document-metadatas", [])
break
-
+
return DocumentsResponse(document_metadatas=document_metadatas)
async def get_processing(
self,
+ workspace: str | None = None,
ctx: Context = None,
) -> ProcessingResponse:
"""
List all documents currently in the processing queue.
- This tool shows documents that are being processed or waiting to be
- processed, along with their processing status and configuration.
+ Args:
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
- ProcessingResponse containing processing metadata including:
- - Processing job ID and document ID
- - Processing flow and status
- - Target collection
- - Timestamp and progress information
-
- Use this for:
- - Monitoring document processing progress
- - Debugging processing issues
- - Managing processing queues
- - Understanding system workload
+ ProcessingResponse containing processing metadata.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Get processing request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Retrieving processing list via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Retrieving processing list via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "list-processing",
}
- gen = manager.request("librarian", request_data, None)
+ gen = manager.request(
+ "librarian", request_data, None, workspace=workspace,
+ )
async for response in gen:
processing_metadatas = response.get("processing-metadatas", [])
break
-
+
return ProcessingResponse(processing_metadatas=processing_metadatas)
async def load_document(
@@ -1790,50 +1780,39 @@ class McpServer:
title: str = "",
comments: str = "",
tags: List[str] | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> LoadDocumentResponse:
"""
Upload a document to the TrustGraph document library.
- This tool stores documents with rich metadata for later processing,
- search, and knowledge extraction. Documents can be text files, PDFs,
- or other supported formats.
-
Args:
document: The document content as a string. For binary files,
this should be base64-encoded content.
document_id: Optional unique identifier. If not provided, one will be generated.
metadata: Optional list of custom metadata key-value pairs.
- mime_type: MIME type of the document (e.g., 'text/plain', 'application/pdf').
+ mime_type: MIME type of the document.
title: Human-readable title for the document.
comments: Optional description or notes about the document.
- tags: List of tags for categorizing and finding the document.
+ tags: List of tags for categorizing the document.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
LoadDocumentResponse confirming the document has been stored.
-
- Use this for:
- - Adding new documents to the knowledge base
- - Storing reference materials and data sources
- - Building document collections for processing
- - Importing external content for analysis
"""
if tags is None: tags = []
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Load document request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Loading document to library via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data="Loading document to library via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
import time
timestamp = int(time.time())
@@ -1852,63 +1831,55 @@ class McpServer:
"content": document
}
- gen = manager.request("librarian", request_data, None)
+ gen = manager.request(
+ "librarian", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return LoadDocumentResponse()
async def remove_document(
self,
document_id: str,
+ workspace: str | None = None,
ctx: Context = None,
) -> RemoveDocumentResponse:
"""
Permanently remove a document from the library.
- This operation deletes a document and all its associated metadata.
- Use with caution as this action cannot be undone.
-
Args:
document_id: Unique identifier of the document to remove.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
RemoveDocumentResponse confirming the document has been deleted.
-
- Use this for:
- - Cleaning up obsolete or incorrect documents
- - Managing storage space
- - Removing sensitive or inappropriate content
- - Maintaining organized document collections
-
- Warning: This permanently deletes the document and all its metadata.
"""
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Remove document request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Removing document '{document_id}' from library via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Removing document '{document_id}' from library via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
request_data = {
"operation": "remove-document",
"document-id": document_id,
}
- gen = manager.request("librarian", request_data, None)
+ gen = manager.request(
+ "librarian", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return RemoveDocumentResponse()
async def add_processing(
@@ -1918,53 +1889,37 @@ class McpServer:
flow: str,
collection: str | None = None,
tags: List[str] | None = None,
+ workspace: str | None = None,
ctx: Context = None,
) -> AddProcessingResponse:
"""
Queue a document for processing through a specific workflow.
- This tool adds a document to the processing queue where it will be
- processed by the specified flow to extract knowledge, create embeddings,
- or perform other analysis operations.
-
Args:
processing_id: Unique identifier for this processing job.
document_id: ID of the document to process (must exist in library).
- flow: Processing flow to use. Different flows perform different
- types of analysis (e.g., knowledge extraction, summarization).
+ flow: Processing flow to use.
collection: Target collection for processed knowledge (default: "default").
- Results will be stored under this collection name.
tags: Optional tags for categorizing this processing job.
+ workspace: Optional workspace. If omitted, uses the caller's
+ default workspace.
Returns:
AddProcessingResponse confirming the document has been queued.
-
- Use this for:
- - Processing uploaded documents into knowledge
- - Extracting entities and relationships from text
- - Creating searchable embeddings
- - Converting documents into structured knowledge
-
- Note: Processing may take time depending on document size and flow complexity.
- Use get_processing to monitor progress.
"""
if collection is None: collection = "default"
if tags is None: tags = []
- if ctx is None:
- raise RuntimeError("No context provided")
+ manager = await self._get_manager(ctx)
- logging.info("Add processing request made via websocket")
-
- manager = await get_socket_manager(ctx)
-
- await ctx.session.send_log_message(
- level="info",
- data=f"Adding document '{document_id}' to processing queue via websocket...",
- logger="notification_stream",
- related_request_id=ctx.request_id,
- )
+ if ctx:
+ await ctx.session.send_log_message(
+ level="info",
+ data=f"Adding document '{document_id}' to processing queue via websocket...",
+ logger="notification_stream",
+ related_request_id=ctx.request_id,
+ )
import time
timestamp = int(time.time())
@@ -1981,38 +1936,61 @@ class McpServer:
}
}
- gen = manager.request("librarian", request_data, None)
+ gen = manager.request(
+ "librarian", request_data, None, workspace=workspace,
+ )
async for response in gen:
break
-
+
return AddProcessingResponse()
+
def main():
parser = argparse.ArgumentParser(description='TrustGraph MCP Server')
- parser.add_argument('--host', default='0.0.0.0', help='Host to bind to (default: 0.0.0.0)')
- parser.add_argument('--port', type=int, default=8000, help='Port to bind to (default: 8000)')
- parser.add_argument('--websocket-url', default='ws://api-gateway:8088/api/v1/socket', help='WebSocket URL to connect to (default: ws://api-gateway:8088/api/v1/socket)')
+ parser.add_argument(
+ '--host', default='0.0.0.0',
+ help='Host to bind to (default: 0.0.0.0)',
+ )
+ parser.add_argument(
+ '--port', type=int, default=8000,
+ help='Port to bind to (default: 8000)',
+ )
+ parser.add_argument(
+ '--websocket-url',
+ default='ws://api-gateway:8088/api/v1/socket',
+ help='WebSocket URL for the TrustGraph gateway',
+ )
+ parser.add_argument(
+ '--auth-issuer',
+ default=os.environ.get("AUTH_ISSUER", ""),
+ help='OAuth issuer URL for MCP auth metadata discovery',
+ )
+ parser.add_argument(
+ '--auth-resource-url',
+ default=os.environ.get("AUTH_RESOURCE_URL", ""),
+ help='Resource server URL for OAuth protected resource metadata',
+ )
- # Add logging arguments
add_logging_args(parser)
args = parser.parse_args()
- # Setup logging before creating server
setup_logging(vars(args))
- # Read gateway auth token from environment
- gateway_token = os.environ.get("GATEWAY_SECRET", "")
-
- # Create and run the MCP server
- server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
+ server = McpServer(
+ host=args.host,
+ port=args.port,
+ websocket_url=args.websocket_url,
+ auth_issuer=args.auth_issuer,
+ auth_resource_url=args.auth_resource_url,
+ )
server.run()
+
def run():
- """Legacy function for backward compatibility"""
main()
+
if __name__ == "__main__":
main()
-
diff --git a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py
index bff8ae75..9fbf7459 100644
--- a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py
+++ b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py
@@ -1,49 +1,110 @@
-from dataclasses import dataclass
from websockets.asyncio.client import connect
-from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
import asyncio
import logging
import json
import uuid
-import time
+import hashlib
+
+logger = logging.getLogger(__name__)
+
+
+def _token_key(token):
+ """Derive a dict key from a token without storing the raw secret."""
+ return hashlib.sha256(token.encode()).hexdigest()[:16]
+
class WebSocketManager:
+ """Manages an authenticated WebSocket connection to the TrustGraph
+ gateway on behalf of a single caller.
- def __init__(self, url, token=None):
+ Each caller token gets its own WebSocketManager so that gateway-side
+ identity, workspace, and capability scoping are preserved end-to-end.
+ """
+
+ def __init__(self, url, token):
self.url = url
+ # ── Security boundary: token storage ──
+ # This is the MCP caller's Bearer token, forwarded verbatim to
+ # the gateway. It MUST NOT be logged, persisted, or shared
+ # across callers. It is held only for the lifetime of this
+ # connection so that re-auth (e.g. after a reconnect) is
+ # possible.
self.token = token
self.socket = None
-
- # FIXME: authentication is broken. The /api/v1/socket endpoint uses
- # in-band auth (first-frame protocol via the Mux dispatcher), not
- # query-parameter tokens. This query-string token is silently ignored.
- # Fix: after connect(), send an auth frame with the bearer token as
- # the first message, matching the gateway's in-band auth protocol.
- def _build_url(self):
- if not self.token:
- return self.url
- parsed = urlparse(self.url)
- params = parse_qs(parsed.query)
- params["token"] = [self.token]
- new_query = urlencode(params, doseq=True)
- return urlunparse(parsed._replace(query=new_query))
+ self.identity = None
+ self.last_used = None
async def start(self):
- self.socket = await connect(self._build_url())
+ """Connect and authenticate via the gateway's in-band auth
+ protocol. Raises on auth failure."""
+
+ # ── Security boundary: MCP server → gateway ──
+ # The WebSocket connects to the gateway and authenticates using
+ # the caller's Bearer token via the in-band first-frame auth
+ # protocol. The token belongs to the MCP client — we forward
+ # it as-is and never interpret its contents.
+ self.socket = await connect(self.url)
self.pending_requests = {}
self.running = True
+
+ await self._authenticate()
+
self.reader_task = asyncio.create_task(self.reader())
+ async def _authenticate(self):
+ """Send in-band auth frame and wait for auth-ok / auth-failed.
+
+ The gateway expects ``{"type": "auth", "token": "..."}`` as the
+ first frame on a new WebSocket. Any service frame sent before
+ auth-ok is rejected.
+ """
+ await self.socket.send(json.dumps({
+ "type": "auth",
+ "token": self.token,
+ }))
+
+ response_text = await asyncio.wait_for(self.socket.recv(), 10)
+ response = json.loads(response_text)
+
+ if response.get("type") == "auth-ok":
+ logger.info(
+ "WebSocket authenticated, default workspace: %s",
+ response.get("workspace"),
+ )
+ return
+
+ # Auth failed — close immediately, do not leave an
+ # unauthenticated socket open.
+ await self.socket.close()
+ self.socket = None
+
+ if response.get("type") == "auth-failed":
+ raise RuntimeError(
+ "Gateway rejected the authentication token"
+ )
+
+ raise RuntimeError(
+ f"Unexpected auth response type: {response.get('type')}"
+ )
+
+ async def whoami(self):
+ """Verify the token by calling the gateway's whoami endpoint.
+ Returns the identity dict and caches it on ``self.identity``.
+ """
+ gen = self.request("iam", {"operation": "whoami"}, flow_id=None)
+ async for response in gen:
+ self.identity = response
+ return response
+
async def stop(self):
self.running = False
- await self.reader_task
+ if hasattr(self, "reader_task"):
+ await self.reader_task
async def reader(self):
- """
- Background task to read websocket responses and route to correct
- request
- """
+ """Background task: read WebSocket frames and route them to the
+ correct pending-request queue by ``id``."""
while self.running:
try:
@@ -59,23 +120,21 @@ class WebSocketManager:
request_id = response.get("id")
if request_id and request_id in self.pending_requests:
- # Put the response in the queue
queue = self.pending_requests[request_id]
await queue.put(response)
else:
- logging.warning(
- f"Response for unknown request ID: {request_id}"
+ logger.warning(
+ "Response for unknown request ID: %s", request_id
)
except Exception as e:
- logging.error(f"Error in websocket reader: {e}")
+ logger.error("Error in websocket reader: %s", e)
- # Put error in all pending queues
for queue in self.pending_requests.values():
try:
await queue.put({"error": str(e)})
- except:
+ except Exception:
pass
self.pending_requests.clear()
@@ -86,25 +145,29 @@ class WebSocketManager:
async def request(
self, service, request_data, flow_id="default",
+ workspace=None,
):
- """
- Send a request via websocket and handle single or streaming responses
+ """Send a request via WebSocket and yield responses.
+
+ Args:
+ service: Gateway service name (e.g. "graph-rag", "config").
+ request_data: Inner request payload.
+ flow_id: Optional flow identifier. ``None`` omits the field
+ (workspace-level services don't use flows).
+ workspace: Optional workspace override. When ``None`` the
+ gateway uses the caller's default workspace.
"""
- # Generate unique request ID
+ import time
+ self.last_used = time.monotonic()
+
request_id = f"{uuid.uuid4()}"
- # Determine if this service streams responses
- streaming_services = {"agent"}
- is_streaming = service in streaming_services
-
- # Create a queue for all responses (streaming and single)
response_queue = asyncio.Queue()
self.pending_requests[request_id] = response_queue
try:
- # Build request message
message = {
"id": request_id,
"service": service,
@@ -114,7 +177,16 @@ class WebSocketManager:
if flow_id is not None:
message["flow"] = flow_id
- # Send request
+ # ── Security boundary: workspace scoping ──
+ # When the caller supplies a workspace, we set it on the
+ # message envelope. The gateway's enforce_workspace()
+ # validates that the authenticated identity is permitted
+ # to access the target workspace — we MUST NOT skip or
+ # override that check. When workspace is None, the
+ # gateway default-fills from the identity's bound workspace.
+ if workspace is not None:
+ message["workspace"] = workspace
+
await self.socket.send(json.dumps(message))
while self.running:
@@ -127,19 +199,17 @@ class WebSocketManager:
continue
if "error" in response:
- if "message" in response["error"]:
- raise RuntimeError(response["error"]["text"])
+ if isinstance(response["error"], dict):
+ raise RuntimeError(
+ response["error"].get("message", str(response["error"]))
+ )
else:
raise RuntimeError(str(response["error"]))
yield response["response"]
- if "complete" in response:
- if response["complete"]:
- break
+ if response.get("complete"):
+ break
- except Exception as e:
- # Clean up on error
+ finally:
self.pending_requests.pop(request_id, None)
- raise e
-