Compare commits

..

No commits in common. "master" and "v2.5.10" have entirely different histories.

32 changed files with 1261 additions and 1684 deletions

View file

@ -11,11 +11,11 @@
<a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
# The semantic deployment platform
# The agent runtime platform
</div>
TrustGraph is a comprehensive semantic infrastructure for agents built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for deterministic agent workloads.
TrustGraph is an agent runtime platform built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for precision-critical agent workloads.
The platform:
- [x] Multi-model and multimodal database system
@ -99,21 +99,23 @@ For a browser based configuration, try the [Configuration Terminal](https://conf
- [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference)
- [**Deployment Guides**](https://docs.trustgraph.ai/deployment)
## Context Graph UI
## Workbench
<img width="1389" height="961" alt="Image" src="https://github.com/user-attachments/assets/35c9250d-0f01-40cb-9294-1ee8fd9a1b56" />
The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default.
The UI provides tools for all major features of TrustGraph. The UI deploys on port `8888` by default.
- **Agent Console** — Query your agents directly with streaming responses and live explainability event tracking, so you can watch reasoning unfold in real time
- **GraphRAG View** — Interactive graph RAG queries with a visual explainability DAG and inline provenance display, making it easy to see exactly where answers came from
- **Context Explorer** — An interactive 3D context graph explorer with dynamic graph loading, BFS neighborhood extraction, edge pulse animation, and multiple navigation views
- **Document Ingestion** — A complete upload and submission workflow with page and chunk inspection and document structure browsing
- **Ontology Workbench** — A full ontology editor with class and property trees, OWL/XML and Turtle import/export with round-trip fidelity, circular dependency detection, and safe-delete confirmation dialogs
- **Schema Workbench** — Interactive schema management with list, create, edit, and delete operations including field and index management
- **Flow Management** — Flow creation and detail views with configurable parameters, temperature controls, and grouped storage layout
- **Workspace UX** — Workspace selection and management surfaced directly in the interface
- **Prompt Editor** — A dedicated prompt editing workflow
- **Vector Search**: Search the installed knowledge bases
- **Agentic, GraphRAG and LLM Chat**: Chat interface for agents, GraphRAG queries, or direct to LLMs
- **Relationships**: Analyze deep relationships in the installed knowledge bases
- **Graph Visualizer**: 3D GraphViz of the installed knowledge bases
- **Library**: Staging area for installing knowledge bases
- **Flow Classes**: Workflow preset configurations
- **Flows**: Create custom workflows and adjust LLM parameters during runtime
- **Knowledge Cores**: Manage resuable knowledge bases
- **Prompts**: Manage and adjust prompts during runtime
- **Schemas**: Define custom schemas for structured data knowledge bases
- **Ontologies**: Define custom ontologies for unstructured data knowledge bases
- **Agent Tools**: Define tools with collections, knowledge cores, MCP connections, and tool groups
- **MCP Tools**: Connect to MCP servers
## TypeScript Library for UIs

View file

@ -409,57 +409,4 @@ class TestEdgeCases:
assert hosts == ['mixed-host']
assert username is None # Stays None
assert password == 'mixed-pass'
class TestReplicationFactorParamPath:
def test_explicit_kwarg(self):
with patch.dict(os.environ, {}, clear=True):
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=3,
)
assert rf == 3
def test_kwarg_overrides_env(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=3,
)
assert rf == 3
def test_env_fallback_when_kwarg_none(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=None,
)
assert rf == 5
def test_default_when_no_kwarg_no_env(self):
with patch.dict(os.environ, {}, clear=True):
_, _, _, _, rf = resolve_cassandra_config()
assert rf == 1
def test_params_dict_path(self):
with patch.dict(os.environ, {}, clear=True):
params = {'cassandra_replication_factor': 3}
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=params.get('cassandra_replication_factor'),
)
assert rf == 3
def test_params_dict_overrides_env(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
params = {'cassandra_replication_factor': 3}
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=params.get('cassandra_replication_factor'),
)
assert rf == 3
def test_params_dict_missing_falls_to_env(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
params = {}
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=params.get('cassandra_replication_factor'),
)
assert rf == 5
assert password == 'mixed-pass'

View file

@ -1,136 +0,0 @@
import os
import pytest
from unittest.mock import patch
from trustgraph.base.qdrant_config import (
get_qdrant_defaults,
resolve_qdrant_config,
)
class TestGetQdrantDefaults:
def test_defaults_with_no_env_vars(self):
with patch.dict(os.environ, {}, clear=True):
defaults = get_qdrant_defaults()
assert defaults['url'] == 'http://localhost:6333'
assert defaults['api_key'] is None
assert defaults['replication_factor'] == 1
assert defaults['shard_number'] == 1
def test_defaults_from_env(self):
env = {
'QDRANT_URL': 'http://qdrant:6333',
'QDRANT_API_KEY': 'secret',
'QDRANT_REPLICATION_FACTOR': '3',
'QDRANT_SHARD_NUMBER': '5',
}
with patch.dict(os.environ, env, clear=True):
defaults = get_qdrant_defaults()
assert defaults['url'] == 'http://qdrant:6333'
assert defaults['api_key'] == 'secret'
assert defaults['replication_factor'] == 3
assert defaults['shard_number'] == 5
class TestResolveQdrantConfig:
def test_defaults(self):
with patch.dict(os.environ, {}, clear=True):
url, api_key, rf, sn = resolve_qdrant_config()
assert url == 'http://localhost:6333'
assert api_key is None
assert rf == 1
assert sn == 1
def test_explicit_kwargs(self):
with patch.dict(os.environ, {}, clear=True):
url, api_key, rf, sn = resolve_qdrant_config(
url='http://custom:6333',
api_key='key',
replication_factor=3,
shard_number=5,
)
assert url == 'http://custom:6333'
assert api_key == 'key'
assert rf == 3
assert sn == 5
def test_kwargs_override_env(self):
env = {
'QDRANT_URL': 'http://env:6333',
'QDRANT_REPLICATION_FACTOR': '10',
'QDRANT_SHARD_NUMBER': '10',
}
with patch.dict(os.environ, env, clear=True):
url, _, rf, sn = resolve_qdrant_config(
url='http://explicit:6333',
replication_factor=3,
shard_number=5,
)
assert url == 'http://explicit:6333'
assert rf == 3
assert sn == 5
def test_env_fallback_when_kwargs_none(self):
env = {
'QDRANT_URL': 'http://env:6333',
'QDRANT_REPLICATION_FACTOR': '3',
'QDRANT_SHARD_NUMBER': '5',
}
with patch.dict(os.environ, env, clear=True):
url, _, rf, sn = resolve_qdrant_config()
assert url == 'http://env:6333'
assert rf == 3
assert sn == 5
def test_params_dict_path(self):
with patch.dict(os.environ, {}, clear=True):
params = {
'store_uri': 'http://params:6333',
'api_key': 'pkey',
'qdrant_replication_factor': 3,
'qdrant_shard_number': 5,
}
url, api_key, rf, sn = resolve_qdrant_config(
url=params.get('store_uri'),
api_key=params.get('api_key'),
replication_factor=params.get('qdrant_replication_factor'),
shard_number=params.get('qdrant_shard_number'),
)
assert url == 'http://params:6333'
assert api_key == 'pkey'
assert rf == 3
assert sn == 5
def test_params_dict_overrides_env(self):
env = {
'QDRANT_REPLICATION_FACTOR': '10',
'QDRANT_SHARD_NUMBER': '10',
}
with patch.dict(os.environ, env, clear=True):
params = {
'qdrant_replication_factor': 3,
'qdrant_shard_number': 5,
}
_, _, rf, sn = resolve_qdrant_config(
replication_factor=params.get('qdrant_replication_factor'),
shard_number=params.get('qdrant_shard_number'),
)
assert rf == 3
assert sn == 5
def test_params_dict_missing_falls_to_env(self):
env = {
'QDRANT_REPLICATION_FACTOR': '3',
'QDRANT_SHARD_NUMBER': '5',
}
with patch.dict(os.environ, env, clear=True):
params = {}
_, _, rf, sn = resolve_qdrant_config(
replication_factor=params.get('qdrant_replication_factor'),
shard_number=params.get('qdrant_shard_number'),
)
assert rf == 3
assert sn == 5

View file

@ -49,7 +49,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test successful PDF processing"""
# Mock PDF content
pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
# Mock PyPDFLoader
@ -88,55 +88,13 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
# Verify triples were sent for each page (provenance)
assert mock_triples_flow.send.call_count == 2
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_rejects_librarian_content_that_is_not_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test rejecting non-PDF content before invoking the PDF loader"""
html_content = b"<html><body>Not found</body></html>"
html_base64 = base64.b64encode(html_content)
mock_metadata = Metadata(id="test-doc")
mock_document = Document(metadata=mock_metadata, document_id="doc-123")
mock_msg = MagicMock()
mock_msg.value.return_value = mock_document
mock_output_flow = AsyncMock()
mock_triples_flow = AsyncMock()
mock_flow = MagicMock(side_effect=lambda name: {
"output": mock_output_flow,
"triples": mock_triples_flow,
}.get(name))
mock_flow.librarian.fetch_document_metadata = AsyncMock(
return_value=MagicMock(kind="application/pdf")
)
mock_flow.librarian.fetch_document_content = AsyncMock(
return_value=html_base64
)
mock_flow.librarian.save_child_document = AsyncMock()
config = {
'id': 'test-pdf-decoder',
'taskgroup': AsyncMock()
}
processor = Processor(**config)
await processor.on_message(mock_msg, None, mock_flow)
mock_pdf_loader_class.assert_not_called()
mock_output_flow.send.assert_not_called()
mock_triples_flow.send.assert_not_called()
mock_flow.librarian.save_child_document.assert_not_called()
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of empty PDF"""
pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock()
@ -168,7 +126,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of unicode content in PDF"""
pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock()

View file

@ -333,8 +333,8 @@ class TestUnifiedTableQueries:
"""Test queries against the unified rows table"""
@pytest.mark.asyncio
@patch('trustgraph.query.rows.cassandra.service.async_execute_paged', new_callable=AsyncMock)
async def test_query_with_index_match(self, mock_async_execute_paged):
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
async def test_query_with_index_match(self, mock_async_execute):
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
@ -344,10 +344,10 @@ class TestUnifiedTableQueries:
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock async_execute_paged to return test data (list of pages)
# Mock async_execute to return test data
mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
mock_async_execute_paged.return_value = [[mock_row]]
mock_async_execute.return_value = [mock_row]
schema = RowSchema(
name="products",
@ -370,10 +370,10 @@ class TestUnifiedTableQueries:
# Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once()
mock_async_execute_paged.assert_called_once()
mock_async_execute.assert_called_once()
# Verify query structure - should query unified rows table
call_args = mock_async_execute_paged.call_args
call_args = mock_async_execute.call_args
query = call_args[0][1]
params = call_args[0][2]
@ -394,8 +394,8 @@ class TestUnifiedTableQueries:
assert results[0]["category"] == "electronics"
@pytest.mark.asyncio
@patch('trustgraph.query.rows.cassandra.service.async_scan', new_callable=AsyncMock)
async def test_query_without_index_match(self, mock_async_scan):
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
async def test_query_without_index_match(self, mock_async_execute):
"""Test query execution without matching index (scan mode)"""
processor = MagicMock()
processor.session = MagicMock()
@ -406,10 +406,12 @@ class TestUnifiedTableQueries:
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock async_scan to return filtered test data
# Mock async_execute to return test data
mock_row1 = MagicMock()
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
mock_async_scan.return_value = [mock_row1]
mock_row2 = MagicMock()
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
mock_async_execute.return_value = [mock_row1, mock_row2]
schema = RowSchema(
name="products",
@ -430,16 +432,13 @@ class TestUnifiedTableQueries:
limit=10
)
# Verify async_scan was called
mock_async_scan.assert_called_once()
# Verify query structure
call_args = mock_async_scan.call_args
# Query should use ALLOW FILTERING for scan
call_args = mock_async_execute.call_args
query = call_args[0][1]
assert "ALLOW FILTERING" in query
# Should return filtered results
# Should post-filter results
assert len(results) == 1
assert results[0]["name"] == "Product A"

View file

@ -259,8 +259,6 @@ class TestGraphEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
proc.replication_factor = 1
proc.shard_number = 1
msg = MagicMock()
msg.metadata.collection = "graphs"

View file

@ -103,19 +103,35 @@ def resolve_cassandra_config(
host: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
default_keyspace: Optional[str] = None,
replication_factor: Optional[int] = None,
default_keyspace: Optional[str] = None
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
"""
Resolve Cassandra configuration from various sources.
Can accept either argparse args object or explicit parameters.
Converts host string to list format for Cassandra driver.
Args:
args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace, cassandra_replication_factor
host: Optional explicit host parameter (overrides args)
username: Optional explicit username parameter (overrides args)
password: Optional explicit password parameter (overrides args)
default_keyspace: Optional default keyspace if not specified elsewhere
Returns:
tuple: (hosts_list, username, password, keyspace, replication_factor)
"""
# If args provided, extract values
keyspace = None
replication_factor = 1
if args is not None:
host = host or getattr(args, 'cassandra_host', None)
username = username or getattr(args, 'cassandra_username', None)
password = password or getattr(args, 'cassandra_password', None)
keyspace = getattr(args, 'cassandra_keyspace', None)
replication_factor = replication_factor or getattr(
args, 'cassandra_replication_factor', None
)
replication_factor = getattr(args, 'cassandra_replication_factor', 1)
# Apply defaults if still None
defaults = get_cassandra_defaults()
host = host or defaults['host']
username = username or defaults['username']

View file

@ -1,87 +0,0 @@
import os
import argparse
from typing import Optional, Any, Tuple
def get_qdrant_defaults() -> dict:
return {
'url': os.getenv('QDRANT_URL', 'http://localhost:6333'),
'api_key': os.getenv('QDRANT_API_KEY'),
'replication_factor': int(os.getenv('QDRANT_REPLICATION_FACTOR', '1')),
'shard_number': int(os.getenv('QDRANT_SHARD_NUMBER', '1')),
}
def add_qdrant_args(parser: argparse.ArgumentParser) -> None:
defaults = get_qdrant_defaults()
url_help = f"Qdrant URL (default: {defaults['url']})"
if 'QDRANT_URL' in os.environ:
url_help += " [from QDRANT_URL]"
api_key_help = "Qdrant API key"
if defaults['api_key']:
api_key_help += " (default: <set>)"
if 'QDRANT_API_KEY' in os.environ:
api_key_help += " [from QDRANT_API_KEY]"
replication_help = f"Qdrant collection replication factor (default: {defaults['replication_factor']})"
if 'QDRANT_REPLICATION_FACTOR' in os.environ:
replication_help += " [from QDRANT_REPLICATION_FACTOR]"
shard_help = f"Qdrant collection shard number (default: {defaults['shard_number']})"
if 'QDRANT_SHARD_NUMBER' in os.environ:
shard_help += " [from QDRANT_SHARD_NUMBER]"
parser.add_argument(
'--store-uri',
default=defaults['url'],
help=url_help,
)
parser.add_argument(
'--api-key',
default=defaults['api_key'],
help=api_key_help,
)
parser.add_argument(
'--qdrant-replication-factor',
type=int,
default=defaults['replication_factor'],
help=replication_help,
)
parser.add_argument(
'--qdrant-shard-number',
type=int,
default=defaults['shard_number'],
help=shard_help,
)
def resolve_qdrant_config(
args: Optional[Any] = None,
url: Optional[str] = None,
api_key: Optional[str] = None,
replication_factor: Optional[int] = None,
shard_number: Optional[int] = None,
) -> Tuple[str, Optional[str], int, int]:
if args is not None:
url = url or getattr(args, 'store_uri', None)
api_key = api_key or getattr(args, 'api_key', None)
replication_factor = replication_factor or getattr(
args, 'qdrant_replication_factor', None
)
shard_number = shard_number or getattr(
args, 'qdrant_shard_number', None
)
defaults = get_qdrant_defaults()
url = url or defaults['url']
api_key = api_key or defaults['api_key']
replication_factor = replication_factor or defaults['replication_factor']
shard_number = shard_number or defaults['shard_number']
return url, api_key, replication_factor, shard_number

View file

@ -78,7 +78,7 @@ def load_structured_data(
logger.info("Step 1: Analyzing data to discover best matching schema...")
# Step 1: Auto-discover schema (reuse discover_schema logic)
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
if not discovered_schema:
logger.error("Failed to discover suitable schema automatically")
print("❌ Could not automatically determine the best schema for your data.")
@ -90,7 +90,7 @@ def load_structured_data(
# Step 2: Auto-generate descriptor
logger.info("Step 2: Generating descriptor configuration...")
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace)
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
if not auto_descriptor:
logger.error("Failed to generate descriptor automatically")
print("❌ Could not automatically generate descriptor configuration.")
@ -172,7 +172,7 @@ def load_structured_data(
logger.info(f"Sample chars: {sample_chars} characters")
# Use the helper function to discover schema (get raw response for display)
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace)
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
if response:
# Debug: print response type and content
@ -203,7 +203,7 @@ def load_structured_data(
# If no schema specified, discover it first
if not schema_name:
logger.info("No schema specified, auto-discovering...")
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
if not schema_name:
print("Error: Could not determine schema automatically.")
print("Please specify a schema using --schema-name or run --discover-schema first.")
@ -213,7 +213,7 @@ def load_structured_data(
logger.info(f"Target schema: {schema_name}")
# Generate descriptor using helper function
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace)
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
if descriptor:
# Output the generated descriptor
@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp
# Helper functions for auto mode
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"):
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
"""Auto-discover the best matching schema for the input data
Args:
@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
api = Api(api_url, token=token, workspace=workspace)
api = Api(api_url, workspace=workspace)
config_api = api.config()
# Get available schemas
@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
return None
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"):
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
"""Auto-generate descriptor configuration for the discovered schema"""
try:
# Read sample data
@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
api = Api(api_url, token=token, workspace=workspace)
api = Api(api_url, workspace=workspace)
config_api = api.config()
# Get schema definition

View file

@ -83,8 +83,7 @@ class Processor(AsyncProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
default_keyspace="config",
replication_factor=params.get("cassandra_replication_factor"),
default_keyspace="config"
)
# Store resolved configuration

View file

@ -61,8 +61,7 @@ class Processor(WorkspaceProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
default_keyspace="knowledge",
replication_factor=params.get("cassandra_replication_factor"),
default_keyspace="knowledge"
)
self.cassandra_host = hosts

View file

@ -32,10 +32,6 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder"
def _looks_like_pdf(content):
return content.lstrip().startswith(b"%PDF-")
class Processor(FlowProcessor):
def __init__(self, **params):
@ -98,37 +94,33 @@ class Processor(FlowProcessor):
)
return
# Check if we should fetch from librarian or use inline data
if v.document_id:
# Fetch from librarian via Pulsar
logger.info(f"Fetching document {v.document_id} from librarian...")
content = await flow.librarian.fetch_document_content(
document_id=v.document_id,
)
# Content is base64 encoded
if isinstance(content, str):
content = content.encode('utf-8')
decoded_content = base64.b64decode(content)
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
else:
# Use inline data (backward compatibility)
decoded_content = base64.b64decode(v.data)
if not _looks_like_pdf(decoded_content):
logger.error(
f"Document {v.metadata.id} is not valid PDF content. "
f"Ignoring document."
)
return
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as fp:
with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
temp_path = fp.name
fp.write(decoded_content)
fp.close()
# Check if we should fetch from librarian or use inline data
if v.document_id:
# Fetch from librarian via Pulsar
logger.info(f"Fetching document {v.document_id} from librarian...")
fp.close()
content = await flow.librarian.fetch_document_content(
document_id=v.document_id,
)
# Content is base64 encoded
if isinstance(content, str):
content = content.encode('utf-8')
decoded_content = base64.b64decode(content)
with open(temp_path, 'wb') as f:
f.write(decoded_content)
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
else:
# Use inline data (backward compatibility)
fp.write(base64.b64decode(v.data))
fp.close()
global PyPDFLoader
if PyPDFLoader is None:

View file

@ -6,7 +6,7 @@ import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement
import ssl
from ssl import SSLContext, PROTOCOL_TLSv1_2
from ..tables.cassandra_async import async_execute
@ -41,15 +41,13 @@ class KnowledgeGraph:
def __init__(
self, hosts=None,
keyspace="trustgraph", username=None, password=None,
replication_factor=1,
keyspace="trustgraph", username=None, password=None
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.replication_factor = replication_factor
self.username = username
# 7-table schema for quads with full query pattern support
@ -70,7 +68,7 @@ class KnowledgeGraph:
self.collection_metadata_table = "collection_metadata"
if username and password:
ssl_context = ssl.create_default_context()
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
@ -94,7 +92,7 @@ class KnowledgeGraph:
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : {self.replication_factor}
'replication_factor' : 1
}};
""")
@ -541,15 +539,13 @@ class EntityCentricKnowledgeGraph:
def __init__(
self, hosts=None,
keyspace="trustgraph", username=None, password=None,
replication_factor=1,
keyspace="trustgraph", username=None, password=None
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.replication_factor = replication_factor
self.username = username
# 2-table entity-centric schema
@ -560,7 +556,7 @@ class EntityCentricKnowledgeGraph:
self.collection_metadata_table = "collection_metadata"
if username and password:
ssl_context = ssl.create_default_context()
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
@ -584,7 +580,7 @@ class EntityCentricKnowledgeGraph:
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : {self.replication_factor}
'replication_factor' : 1
}};
""")

View file

@ -101,7 +101,6 @@ class Processor(AsyncProcessor):
username=cassandra_username,
password=cassandra_password,
default_keyspace="iam",
replication_factor=params.get("cassandra_replication_factor"),
)
self.cassandra_host = hosts

View file

@ -8,7 +8,6 @@ import asyncio
import base64
import json
import logging
import os
from datetime import datetime
from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
@ -55,16 +54,6 @@ default_object_store_access_key = "object-user"
default_object_store_secret_key = "object-password"
default_object_store_use_ssl = False
default_object_store_region = None
# Environment variables consulted as a fallback when the
# corresponding params field is not set in the processor-group YAML
# or via CLI. Intended for K8s Secret / env-var injection so
# credentials never have to live in the YAML (and thus in git).
ENV_OBJECT_STORE_ENDPOINT = "OBJECT_STORE_ENDPOINT"
ENV_OBJECT_STORE_ACCESS_KEY = "OBJECT_STORE_ACCESS_KEY"
ENV_OBJECT_STORE_SECRET_KEY = "OBJECT_STORE_SECRET_KEY"
ENV_OBJECT_STORE_USE_SSL = "OBJECT_STORE_USE_SSL"
ENV_OBJECT_STORE_REGION = "OBJECT_STORE_REGION"
default_cassandra_host = "cassandra"
default_min_chunk_size = 1 # No minimum by default (for Garage)
@ -100,36 +89,22 @@ class Processor(WorkspaceProcessor):
"config_response_queue", default_config_response_queue
)
# Resolve object-store config. Precedence: explicit params
# (CLI / processor-group YAML) → environment variable →
# hardcoded default. The env-var path lets K8s Secrets feed
# credentials without them appearing in the YAML.
object_store_endpoint = (
params.get("object_store_endpoint")
or os.environ.get(ENV_OBJECT_STORE_ENDPOINT)
or default_object_store_endpoint
object_store_endpoint = params.get("object_store_endpoint", default_object_store_endpoint)
object_store_access_key = params.get(
"object_store_access_key",
default_object_store_access_key
)
object_store_access_key = (
params.get("object_store_access_key")
or os.environ.get(ENV_OBJECT_STORE_ACCESS_KEY)
or default_object_store_access_key
object_store_secret_key = params.get(
"object_store_secret_key",
default_object_store_secret_key
)
object_store_secret_key = (
params.get("object_store_secret_key")
or os.environ.get(ENV_OBJECT_STORE_SECRET_KEY)
or default_object_store_secret_key
object_store_use_ssl = params.get(
"object_store_use_ssl",
default_object_store_use_ssl
)
object_store_use_ssl = params.get("object_store_use_ssl")
if object_store_use_ssl is None:
env_ssl = os.environ.get(ENV_OBJECT_STORE_USE_SSL)
if env_ssl is not None:
object_store_use_ssl = env_ssl.lower() in ("true", "1", "yes")
else:
object_store_use_ssl = default_object_store_use_ssl
object_store_region = (
params.get("object_store_region")
or os.environ.get(ENV_OBJECT_STORE_REGION)
or default_object_store_region
object_store_region = params.get(
"object_store_region",
default_object_store_region
)
min_chunk_size = params.get(
@ -146,8 +121,7 @@ class Processor(WorkspaceProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
default_keyspace="librarian",
replication_factor=params.get("cassandra_replication_factor"),
default_keyspace="librarian"
)
# Store resolved configuration

View file

@ -12,33 +12,31 @@ from qdrant_client import QdrantClient
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "doc-embeddings-query"
default_store_uri = 'http://localhost:6333'
class Processor(DocumentEmbeddingsQueryService):
def __init__(self, **params):
store_uri = params.get("store_uri")
api_key = params.get("api_key")
store_uri = params.get("store_uri", default_store_uri)
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri,
api_key=api_key,
)
#optional api key
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"store_uri": url,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=url, api_key=api_key)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
async def query_document_embeddings(self, workspace, msg):
@ -87,7 +85,18 @@ class Processor(DocumentEmbeddingsQueryService):
def add_args(parser):
DocumentEmbeddingsQueryService.add_args(parser)
add_qdrant_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help=f'API key for qdrant (default: None)'
)
def run():

View file

@ -12,32 +12,31 @@ from qdrant_client import QdrantClient
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "graph-embeddings-query"
default_store_uri = 'http://localhost:6333'
class Processor(GraphEmbeddingsQueryService):
def __init__(self, **params):
store_uri = params.get("store_uri")
api_key = params.get("api_key")
store_uri = params.get("store_uri", default_store_uri)
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)
#optional api key
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"store_uri": url,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=url, api_key=api_key)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@ -105,7 +104,18 @@ class Processor(GraphEmbeddingsQueryService):
def add_args(parser):
GraphEmbeddingsQueryService.add_args(parser)
add_qdrant_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help=f'API key for qdrant (default: None)'
)
def run():

View file

@ -116,7 +116,7 @@ class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
# Create keyspace
self.session.execute(f"""
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': {self.cassandra_config.get('replication_factor', 1)}}}
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
""")
# Create triples table optimized for SPARQL queries

View file

@ -19,12 +19,12 @@ from .... schema import (
RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-query"
default_store_uri = 'http://localhost:6333'
default_concurrency = 10
@ -35,17 +35,13 @@ class Processor(FlowProcessor):
id = params.get("id", default_ident)
concurrency = params.get("concurrency", default_concurrency)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": url,
"store_uri": store_uri,
"api_key": api_key,
}
)
@ -66,7 +62,7 @@ class Processor(FlowProcessor):
)
)
self.qdrant = QdrantClient(url=url, api_key=api_key)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
@ -196,9 +192,21 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_qdrant_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='API key for Qdrant (default: None)'
)
parser.add_argument(
'-c', '--concurrency',

View file

@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan
from .... tables.cassandra_async import async_execute
from ... graphql import GraphQLSchemaBuilder, SortDirection
@ -180,7 +180,7 @@ class Processor(FlowProcessor):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False),
indexed=field_def.get("indexed", False)
)
fields.append(field)
@ -232,8 +232,6 @@ class Processor(FlowProcessor):
for index_name in index_names:
if index_name in filters:
value = filters[index_name]
if value == "" or value is None:
continue
# Single field index -> single element list
index_value = [str(value)]
return (index_name, index_value)
@ -284,13 +282,11 @@ class Processor(FlowProcessor):
query += f" LIMIT {limit}"
try:
pages = await async_execute_paged(
self.session, query, params
)
for page in pages:
for row in page:
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
rows = await async_execute(self.session, query, params)
for row in rows:
# Convert data map to dict with proper field names
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
except Exception as e:
logger.error(f"Failed to query rows: {e}", exc_info=True)
raise
@ -312,6 +308,8 @@ class Processor(FlowProcessor):
# Query using the first index (arbitrary choice for scan)
primary_index = index_names[0]
# We need to scan all values for this index
# This requires ALLOW FILTERING or a different approach
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
@ -322,18 +320,17 @@ class Processor(FlowProcessor):
params = [collection, schema_name, primary_index]
try:
def row_filter(row):
row_dict = dict(row.data) if row.data else {}
return self._matches_filters(row_dict, filters, row_schema)
rows = await async_execute(self.session, query, params)
matched_rows = await async_scan(
self.session, query, params,
row_filter=row_filter,
limit=limit,
)
for row in matched_rows:
for row in rows:
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
# Apply post-filters
if self._matches_filters(row_dict, filters, row_schema):
results.append(row_dict)
if limit and len(results) >= limit:
break
except Exception as e:
logger.error(f"Failed to scan rows: {e}", exc_info=True)
@ -366,7 +363,7 @@ class Processor(FlowProcessor):
# Parse filter key for operator
if '_' in filter_key:
parts = filter_key.rsplit('_', 1)
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
field_name = parts[0]
operator = parts[1]
else:
@ -403,18 +400,6 @@ class Processor(FlowProcessor):
elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]:
return False
elif operator == 'not':
if str(row_value) == str(filter_value):
return False
elif operator == 'startsWith':
if not str(row_value).startswith(str(filter_value)):
return False
elif operator == 'endsWith':
if not str(row_value).endswith(str(filter_value)):
return False
elif operator == 'not_in':
if str(row_value) in [str(v) for v in filter_value]:
return False
except (ValueError, TypeError):
return False

View file

@ -14,36 +14,29 @@ from qdrant_client.models import Distance, VectorParams
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "doc-embeddings-write"
default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
def __init__(self, **params):
store_uri = params.get("store_uri")
api_key = params.get("api_key")
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
replication_factor=params.get("qdrant_replication_factor"),
shard_number=params.get("qdrant_shard_number"),
)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"store_uri": url,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=url, api_key=api_key)
self.replication_factor = replication_factor
self.shard_number = shard_number
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@ -68,8 +61,6 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
replication_factor=self.replication_factor,
shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@ -118,7 +109,18 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
def add_args(parser):
DocumentEmbeddingsStoreService.add_args(parser)
add_qdrant_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help=f'Qdrant API key (default: None)'
)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""

View file

@ -14,7 +14,6 @@ from qdrant_client.models import Distance, VectorParams
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
from .... base import ConsumerMetrics, ProducerMetrics
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
from .... schema import IRI, LITERAL
# Module logger
@ -30,34 +29,29 @@ def get_term_value(term):
elif term.type == LITERAL:
return term.value
else:
# For blank nodes or other types, use id or value
return term.id or term.value
default_ident = "graph-embeddings-write"
default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
def __init__(self, **params):
store_uri = params.get("store_uri")
api_key = params.get("api_key")
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
replication_factor=params.get("qdrant_replication_factor"),
shard_number=params.get("qdrant_shard_number"),
)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"store_uri": url,
"store_uri": store_uri,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=url, api_key=api_key)
self.replication_factor = replication_factor
self.shard_number = shard_number
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@ -82,8 +76,6 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
replication_factor=self.replication_factor,
shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@ -136,7 +128,18 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
def add_args(parser):
GraphEmbeddingsStoreService.add_args(parser)
add_qdrant_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help=f'Qdrant API key'
)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""

View file

@ -27,8 +27,7 @@ class Processor(FlowProcessor):
host=params.get("cassandra_host"),
username=params.get("cassandra_username"),
password=params.get("cassandra_password"),
default_keyspace='knowledge',
replication_factor=params.get("cassandra_replication_factor"),
default_keyspace='knowledge'
)
super(Processor, self).__init__(

View file

@ -27,12 +27,12 @@ from qdrant_client.models import PointStruct, Distance, VectorParams
from .... schema import RowEmbeddings
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-write"
default_store_uri = 'http://localhost:6333'
class Processor(CollectionConfigHandler, FlowProcessor):
@ -41,19 +41,13 @@ class Processor(CollectionConfigHandler, FlowProcessor):
id = params.get("id", default_ident)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key,
replication_factor=params.get("qdrant_replication_factor"),
shard_number=params.get("qdrant_shard_number"),
)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": url,
"store_uri": store_uri,
"api_key": api_key,
}
)
@ -69,9 +63,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Register config handler for collection management
self.register_config_handler(self.on_collection_config, types=["collection"])
self.qdrant = QdrantClient(url=url, api_key=api_key)
self.replication_factor = replication_factor
self.shard_number = shard_number
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
@ -111,8 +103,6 @@ class Processor(CollectionConfigHandler, FlowProcessor):
size=dimension,
distance=Distance.COSINE
),
replication_factor=self.replication_factor,
shard_number=self.shard_number,
)
self._known_collections.add(collection_name)
@ -259,9 +249,21 @@ class Processor(CollectionConfigHandler, FlowProcessor):
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
add_qdrant_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='Qdrant API key (default: None)'
)
def run():

View file

@ -47,18 +47,16 @@ class Processor(CollectionConfigHandler, FlowProcessor):
cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback
hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
hosts, username, password, keyspace, _ = resolve_cassandra_config(
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
replication_factor=params.get("cassandra_replication_factor"),
password=cassandra_password
)
# Store resolved configuration with proper names
self.cassandra_host = hosts # Store as list
self.cassandra_username = username
self.cassandra_password = password
self.replication_factor = replication_factor
# Config key for schemas
self.config_key = params.get("config_type", "schema")
@ -172,7 +170,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False),
indexed=field_def.get("indexed", False)
)
fields.append(field)
@ -234,7 +232,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
CREATE KEYSPACE IF NOT EXISTS {safe_keyspace}
WITH REPLICATION = {{
'class': 'SimpleStrategy',
'replication_factor': {self.replication_factor}
'replication_factor': 1
}}
"""

View file

@ -80,14 +80,14 @@ def _set_exception_if_pending(fut, exc):
fut.set_exception(exc)
async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
async def async_execute_paged(session, query, parameters=None, fetch_size=100):
"""Execute a CQL query with page-by-page iteration.
Uses synchronous session.execute() inside run_in_executor so that
the driver's ResultSet paging works correctly without materialising
the entire result set in memory.
Returns all pages as a list of lists.
Yields one page of rows at a time (as a list).
"""
loop = asyncio.get_running_loop()
@ -111,50 +111,3 @@ async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
return await loop.run_in_executor(
None, _fetch_all_pages
)
async def async_scan(
session, query, parameters=None, row_filter=None,
limit=None, fetch_size=5000,
):
"""Scan a CQL query page-by-page, applying a filter and limit.
Only matching rows accumulate in memory. Each page is discarded
after processing, so peak memory is bounded by fetch_size plus
the number of matching rows (capped by limit).
Args:
session: cassandra.cluster.Session
query: CQL statement string
parameters: bind params
row_filter: callable(row) -> bool, or None to accept all
limit: max results to return, or None for unlimited
fetch_size: rows per Cassandra page fetch
Returns:
List of matching rows.
"""
loop = asyncio.get_running_loop()
if isinstance(query, str):
stmt = SimpleStatement(query, fetch_size=fetch_size)
else:
stmt = query
stmt.fetch_size = fetch_size
def _scan():
results = []
result_set = session.execute(stmt, parameters)
while True:
for row in result_set.current_rows:
if row_filter is None or row_filter(row):
results.append(row)
if limit and len(results) >= limit:
return results
if result_set.has_more_pages:
result_set.fetch_next_page()
else:
break
return results
return await loop.run_in_executor(None, _scan)

View file

@ -4,7 +4,7 @@ from .. schema import Metadata, GraphEmbeddings
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
import ssl
from ssl import SSLContext, PROTOCOL_TLSv1_2
import uuid
import time
@ -33,7 +33,7 @@ class ConfigTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
ssl_context = ssl.create_default_context()
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)

View file

@ -15,7 +15,7 @@ import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
import ssl
from ssl import SSLContext, PROTOCOL_TLSv1_2
from . cassandra_async import async_execute
@ -39,7 +39,7 @@ class IamTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(",")]
if cassandra_username and cassandra_password:
ssl_context = ssl.create_default_context()
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password,
)

View file

@ -23,7 +23,7 @@ def tuple_to_term(value, is_uri):
else:
return Term(type=LITERAL, value=value)
from cassandra.auth import PlainTextAuthProvider
import ssl
from ssl import SSLContext, PROTOCOL_TLSv1_2
import uuid
import time
@ -50,7 +50,7 @@ class KnowledgeTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
ssl_context = ssl.create_default_context()
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)

View file

@ -24,7 +24,7 @@ from .. exceptions import RequestError
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement
import ssl
from ssl import SSLContext, PROTOCOL_TLSv1_2
import uuid
import time
@ -53,7 +53,7 @@ class LibraryTableStore:
cassandra_host = [h.strip() for h in cassandra_host.split(',')]
if cassandra_username and cassandra_password:
ssl_context = ssl.create_default_context()
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
auth_provider = PlainTextAuthProvider(
username=cassandra_username, password=cassandra_password
)

File diff suppressed because it is too large Load diff

View file

@ -1,110 +1,49 @@
from dataclasses import dataclass
from websockets.asyncio.client import connect
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
import asyncio
import logging
import json
import uuid
import hashlib
logger = logging.getLogger(__name__)
def _token_key(token):
"""Derive a dict key from a token without storing the raw secret."""
return hashlib.sha256(token.encode()).hexdigest()[:16]
import time
class WebSocketManager:
"""Manages an authenticated WebSocket connection to the TrustGraph
gateway on behalf of a single caller.
Each caller token gets its own WebSocketManager so that gateway-side
identity, workspace, and capability scoping are preserved end-to-end.
"""
def __init__(self, url, token):
def __init__(self, url, token=None):
self.url = url
# ── Security boundary: token storage ──
# This is the MCP caller's Bearer token, forwarded verbatim to
# the gateway. It MUST NOT be logged, persisted, or shared
# across callers. It is held only for the lifetime of this
# connection so that re-auth (e.g. after a reconnect) is
# possible.
self.token = token
self.socket = None
self.identity = None
self.last_used = None
# FIXME: authentication is broken. The /api/v1/socket endpoint uses
# in-band auth (first-frame protocol via the Mux dispatcher), not
# query-parameter tokens. This query-string token is silently ignored.
# Fix: after connect(), send an auth frame with the bearer token as
# the first message, matching the gateway's in-band auth protocol.
def _build_url(self):
if not self.token:
return self.url
parsed = urlparse(self.url)
params = parse_qs(parsed.query)
params["token"] = [self.token]
new_query = urlencode(params, doseq=True)
return urlunparse(parsed._replace(query=new_query))
async def start(self):
"""Connect and authenticate via the gateway's in-band auth
protocol. Raises on auth failure."""
# ── Security boundary: MCP server → gateway ──
# The WebSocket connects to the gateway and authenticates using
# the caller's Bearer token via the in-band first-frame auth
# protocol. The token belongs to the MCP client — we forward
# it as-is and never interpret its contents.
self.socket = await connect(self.url)
self.socket = await connect(self._build_url())
self.pending_requests = {}
self.running = True
await self._authenticate()
self.reader_task = asyncio.create_task(self.reader())
async def _authenticate(self):
"""Send in-band auth frame and wait for auth-ok / auth-failed.
The gateway expects ``{"type": "auth", "token": "..."}`` as the
first frame on a new WebSocket. Any service frame sent before
auth-ok is rejected.
"""
await self.socket.send(json.dumps({
"type": "auth",
"token": self.token,
}))
response_text = await asyncio.wait_for(self.socket.recv(), 10)
response = json.loads(response_text)
if response.get("type") == "auth-ok":
logger.info(
"WebSocket authenticated, default workspace: %s",
response.get("workspace"),
)
return
# Auth failed — close immediately, do not leave an
# unauthenticated socket open.
await self.socket.close()
self.socket = None
if response.get("type") == "auth-failed":
raise RuntimeError(
"Gateway rejected the authentication token"
)
raise RuntimeError(
f"Unexpected auth response type: {response.get('type')}"
)
async def whoami(self):
"""Verify the token by calling the gateway's whoami endpoint.
Returns the identity dict and caches it on ``self.identity``.
"""
gen = self.request("iam", {"operation": "whoami"}, flow_id=None)
async for response in gen:
self.identity = response
return response
async def stop(self):
self.running = False
if hasattr(self, "reader_task"):
await self.reader_task
await self.reader_task
async def reader(self):
"""Background task: read WebSocket frames and route them to the
correct pending-request queue by ``id``."""
"""
Background task to read websocket responses and route to correct
request
"""
while self.running:
try:
@ -120,21 +59,23 @@ class WebSocketManager:
request_id = response.get("id")
if request_id and request_id in self.pending_requests:
# Put the response in the queue
queue = self.pending_requests[request_id]
await queue.put(response)
else:
logger.warning(
"Response for unknown request ID: %s", request_id
logging.warning(
f"Response for unknown request ID: {request_id}"
)
except Exception as e:
logger.error("Error in websocket reader: %s", e)
logging.error(f"Error in websocket reader: {e}")
# Put error in all pending queues
for queue in self.pending_requests.values():
try:
await queue.put({"error": str(e)})
except Exception:
except:
pass
self.pending_requests.clear()
@ -145,29 +86,25 @@ class WebSocketManager:
async def request(
self, service, request_data, flow_id="default",
workspace=None,
):
"""Send a request via WebSocket and yield responses.
Args:
service: Gateway service name (e.g. "graph-rag", "config").
request_data: Inner request payload.
flow_id: Optional flow identifier. ``None`` omits the field
(workspace-level services don't use flows).
workspace: Optional workspace override. When ``None`` the
gateway uses the caller's default workspace.
"""
Send a request via websocket and handle single or streaming responses
"""
import time
self.last_used = time.monotonic()
# Generate unique request ID
request_id = f"{uuid.uuid4()}"
# Determine if this service streams responses
streaming_services = {"agent"}
is_streaming = service in streaming_services
# Create a queue for all responses (streaming and single)
response_queue = asyncio.Queue()
self.pending_requests[request_id] = response_queue
try:
# Build request message
message = {
"id": request_id,
"service": service,
@ -177,16 +114,7 @@ class WebSocketManager:
if flow_id is not None:
message["flow"] = flow_id
# ── Security boundary: workspace scoping ──
# When the caller supplies a workspace, we set it on the
# message envelope. The gateway's enforce_workspace()
# validates that the authenticated identity is permitted
# to access the target workspace — we MUST NOT skip or
# override that check. When workspace is None, the
# gateway default-fills from the identity's bound workspace.
if workspace is not None:
message["workspace"] = workspace
# Send request
await self.socket.send(json.dumps(message))
while self.running:
@ -199,17 +127,19 @@ class WebSocketManager:
continue
if "error" in response:
if isinstance(response["error"], dict):
raise RuntimeError(
response["error"].get("message", str(response["error"]))
)
if "message" in response["error"]:
raise RuntimeError(response["error"]["text"])
else:
raise RuntimeError(str(response["error"]))
yield response["response"]
if response.get("complete"):
break
if "complete" in response:
if response["complete"]:
break
finally:
except Exception as e:
# Clean up on error
self.pending_requests.pop(request_id, None)
raise e