Compare commits

..

No commits in common. "master" and "v2.5.12" have entirely different histories.

22 changed files with 1120 additions and 1465 deletions

View file

@ -11,11 +11,11 @@
<a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
# The semantic deployment platform # The agent runtime platform
</div> </div>
TrustGraph is a comprehensive semantic infrastructure for agents built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for deterministic agent workloads. TrustGraph is an agent runtime platform built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for precision-critical agent workloads.
The platform: The platform:
- [x] Multi-model and multimodal database system - [x] Multi-model and multimodal database system
@ -99,21 +99,23 @@ For a browser based configuration, try the [Configuration Terminal](https://conf
- [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference) - [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference)
- [**Deployment Guides**](https://docs.trustgraph.ai/deployment) - [**Deployment Guides**](https://docs.trustgraph.ai/deployment)
## Context Graph UI ## Workbench
<img width="1389" height="961" alt="Image" src="https://github.com/user-attachments/assets/35c9250d-0f01-40cb-9294-1ee8fd9a1b56" /> The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default.
The UI provides tools for all major features of TrustGraph. The UI deploys on port `8888` by default. - **Vector Search**: Search the installed knowledge bases
- **Agentic, GraphRAG and LLM Chat**: Chat interface for agents, GraphRAG queries, or direct to LLMs
- **Agent Console** — Query your agents directly with streaming responses and live explainability event tracking, so you can watch reasoning unfold in real time - **Relationships**: Analyze deep relationships in the installed knowledge bases
- **GraphRAG View** — Interactive graph RAG queries with a visual explainability DAG and inline provenance display, making it easy to see exactly where answers came from - **Graph Visualizer**: 3D GraphViz of the installed knowledge bases
- **Context Explorer** — An interactive 3D context graph explorer with dynamic graph loading, BFS neighborhood extraction, edge pulse animation, and multiple navigation views - **Library**: Staging area for installing knowledge bases
- **Document Ingestion** — A complete upload and submission workflow with page and chunk inspection and document structure browsing - **Flow Classes**: Workflow preset configurations
- **Ontology Workbench** — A full ontology editor with class and property trees, OWL/XML and Turtle import/export with round-trip fidelity, circular dependency detection, and safe-delete confirmation dialogs - **Flows**: Create custom workflows and adjust LLM parameters during runtime
- **Schema Workbench** — Interactive schema management with list, create, edit, and delete operations including field and index management - **Knowledge Cores**: Manage reusable knowledge bases
- **Flow Management** — Flow creation and detail views with configurable parameters, temperature controls, and grouped storage layout - **Prompts**: Manage and adjust prompts during runtime
- **Workspace UX** — Workspace selection and management surfaced directly in the interface - **Schemas**: Define custom schemas for structured data knowledge bases
- **Prompt Editor** — A dedicated prompt editing workflow - **Ontologies**: Define custom ontologies for unstructured data knowledge bases
- **Agent Tools**: Define tools with collections, knowledge cores, MCP connections, and tool groups
- **MCP Tools**: Connect to MCP servers
## TypeScript Library for UIs ## TypeScript Library for UIs

View file

@ -410,56 +410,3 @@ class TestEdgeCases:
assert hosts == ['mixed-host'] assert hosts == ['mixed-host']
assert username is None # Stays None assert username is None # Stays None
assert password == 'mixed-pass' assert password == 'mixed-pass'
class TestReplicationFactorParamPath:
    """Check how ``resolve_cassandra_config`` picks the replication factor.

    The precedence asserted by these tests is: an explicit
    ``replication_factor`` keyword argument, then the
    ``CASSANDRA_REPLICATION_FACTOR`` environment variable, then 1.
    """

    @staticmethod
    def _resolve_rf(env, **kwargs):
        """Return the fifth element of the resolved config under ``env``.

        Args:
            env: Mapping that replaces ``os.environ`` for the call.
            **kwargs: Forwarded unchanged to ``resolve_cassandra_config``.
        """
        # clear=True removes inherited variables so only ``env`` is visible.
        with patch.dict(os.environ, env, clear=True):
            _, _, _, _, rf = resolve_cassandra_config(**kwargs)
        return rf

    def test_explicit_kwarg(self):
        assert self._resolve_rf({}, replication_factor=3) == 3

    def test_kwarg_overrides_env(self):
        env = {'CASSANDRA_REPLICATION_FACTOR': '5'}
        assert self._resolve_rf(env, replication_factor=3) == 3

    def test_env_fallback_when_kwarg_none(self):
        env = {'CASSANDRA_REPLICATION_FACTOR': '5'}
        assert self._resolve_rf(env, replication_factor=None) == 5

    def test_default_when_no_kwarg_no_env(self):
        assert self._resolve_rf({}) == 1

    def test_params_dict_path(self):
        # Mirrors how processors forward a value taken from their params dict.
        params = {'cassandra_replication_factor': 3}
        rf = self._resolve_rf(
            {},
            replication_factor=params.get('cassandra_replication_factor'),
        )
        assert rf == 3

    def test_params_dict_overrides_env(self):
        params = {'cassandra_replication_factor': 3}
        rf = self._resolve_rf(
            {'CASSANDRA_REPLICATION_FACTOR': '5'},
            replication_factor=params.get('cassandra_replication_factor'),
        )
        assert rf == 3

    def test_params_dict_missing_falls_to_env(self):
        # A missing key yields None from .get(), which must not mask the env.
        params = {}
        rf = self._resolve_rf(
            {'CASSANDRA_REPLICATION_FACTOR': '5'},
            replication_factor=params.get('cassandra_replication_factor'),
        )
        assert rf == 5

View file

@ -1,136 +0,0 @@
import os
import pytest
from unittest.mock import patch
from trustgraph.base.qdrant_config import (
get_qdrant_defaults,
resolve_qdrant_config,
)
class TestGetQdrantDefaults:
    """Check the values reported by ``get_qdrant_defaults``."""

    def test_defaults_with_no_env_vars(self):
        """With an empty environment the built-in fallbacks are returned."""
        with patch.dict(os.environ, {}, clear=True):
            defaults = get_qdrant_defaults()
            assert defaults['url'] == 'http://localhost:6333'
            assert defaults['api_key'] is None
            assert defaults['replication_factor'] == 1
            assert defaults['shard_number'] == 1

    def test_defaults_from_env(self):
        """``QDRANT_*`` variables are picked up; numeric ones come back as ints."""
        env = {
            'QDRANT_URL': 'http://qdrant:6333',
            'QDRANT_API_KEY': 'secret',
            'QDRANT_REPLICATION_FACTOR': '3',
            'QDRANT_SHARD_NUMBER': '5',
        }
        with patch.dict(os.environ, env, clear=True):
            defaults = get_qdrant_defaults()
            assert defaults['url'] == 'http://qdrant:6333'
            assert defaults['api_key'] == 'secret'
            assert defaults['replication_factor'] == 3
            assert defaults['shard_number'] == 5
class TestResolveQdrantConfig:
    """Check the precedence rules of ``resolve_qdrant_config``.

    The call returns ``(url, api_key, replication_factor, shard_number)``.
    These tests assert that explicit keyword arguments win over ``QDRANT_*``
    environment variables, which in turn win over the built-in defaults.
    """

    def test_defaults(self):
        with patch.dict(os.environ, {}, clear=True):
            url, api_key, rf, sn = resolve_qdrant_config()
            assert url == 'http://localhost:6333'
            assert api_key is None
            assert rf == 1
            assert sn == 1

    def test_explicit_kwargs(self):
        with patch.dict(os.environ, {}, clear=True):
            url, api_key, rf, sn = resolve_qdrant_config(
                url='http://custom:6333',
                api_key='key',
                replication_factor=3,
                shard_number=5,
            )
            assert url == 'http://custom:6333'
            assert api_key == 'key'
            assert rf == 3
            assert sn == 5

    def test_kwargs_override_env(self):
        env = {
            'QDRANT_URL': 'http://env:6333',
            'QDRANT_REPLICATION_FACTOR': '10',
            'QDRANT_SHARD_NUMBER': '10',
        }
        with patch.dict(os.environ, env, clear=True):
            url, _, rf, sn = resolve_qdrant_config(
                url='http://explicit:6333',
                replication_factor=3,
                shard_number=5,
            )
            assert url == 'http://explicit:6333'
            assert rf == 3
            assert sn == 5

    def test_env_fallback_when_kwargs_none(self):
        env = {
            'QDRANT_URL': 'http://env:6333',
            'QDRANT_REPLICATION_FACTOR': '3',
            'QDRANT_SHARD_NUMBER': '5',
        }
        with patch.dict(os.environ, env, clear=True):
            url, _, rf, sn = resolve_qdrant_config()
            assert url == 'http://env:6333'
            assert rf == 3
            assert sn == 5

    def test_params_dict_path(self):
        # Mirrors how a processor forwards values taken from its params dict.
        with patch.dict(os.environ, {}, clear=True):
            params = {
                'store_uri': 'http://params:6333',
                'api_key': 'pkey',
                'qdrant_replication_factor': 3,
                'qdrant_shard_number': 5,
            }
            url, api_key, rf, sn = resolve_qdrant_config(
                url=params.get('store_uri'),
                api_key=params.get('api_key'),
                replication_factor=params.get('qdrant_replication_factor'),
                shard_number=params.get('qdrant_shard_number'),
            )
            assert url == 'http://params:6333'
            assert api_key == 'pkey'
            assert rf == 3
            assert sn == 5

    def test_params_dict_overrides_env(self):
        env = {
            'QDRANT_REPLICATION_FACTOR': '10',
            'QDRANT_SHARD_NUMBER': '10',
        }
        with patch.dict(os.environ, env, clear=True):
            params = {
                'qdrant_replication_factor': 3,
                'qdrant_shard_number': 5,
            }
            _, _, rf, sn = resolve_qdrant_config(
                replication_factor=params.get('qdrant_replication_factor'),
                shard_number=params.get('qdrant_shard_number'),
            )
            assert rf == 3
            assert sn == 5

    def test_params_dict_missing_falls_to_env(self):
        env = {
            'QDRANT_REPLICATION_FACTOR': '3',
            'QDRANT_SHARD_NUMBER': '5',
        }
        with patch.dict(os.environ, env, clear=True):
            # Missing keys yield None from .get(), which must not mask the env.
            params = {}
            _, _, rf, sn = resolve_qdrant_config(
                replication_factor=params.get('qdrant_replication_factor'),
                shard_number=params.get('qdrant_shard_number'),
            )
            assert rf == 3
            assert sn == 5

View file

@ -49,7 +49,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer): async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test successful PDF processing""" """Test successful PDF processing"""
# Mock PDF content # Mock PDF content
pdf_content = b"%PDF-1.7\nfake pdf content" pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
# Mock PyPDFLoader # Mock PyPDFLoader
@ -88,55 +88,13 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
# Verify triples were sent for each page (provenance) # Verify triples were sent for each page (provenance)
assert mock_triples_flow.send.call_count == 2 assert mock_triples_flow.send.call_count == 2
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_rejects_librarian_content_that_is_not_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
    """Test rejecting non-PDF content before invoking the PDF loader"""
    # The librarian hands back an HTML body although the metadata claims PDF.
    html_content = b"<html><body>Not found</body></html>"
    html_base64 = base64.b64encode(html_content)

    mock_metadata = Metadata(id="test-doc")
    mock_document = Document(metadata=mock_metadata, document_id="doc-123")
    mock_msg = MagicMock()
    mock_msg.value.return_value = mock_document

    mock_output_flow = AsyncMock()
    mock_triples_flow = AsyncMock()

    # Calling flow("<name>") resolves to the matching mocked output channel.
    mock_flow = MagicMock(side_effect=lambda name: {
        "output": mock_output_flow,
        "triples": mock_triples_flow,
    }.get(name))
    mock_flow.librarian.fetch_document_metadata = AsyncMock(
        return_value=MagicMock(kind="application/pdf")
    )
    mock_flow.librarian.fetch_document_content = AsyncMock(
        return_value=html_base64
    )
    mock_flow.librarian.save_child_document = AsyncMock()

    config = {
        'id': 'test-pdf-decoder',
        'taskgroup': AsyncMock()
    }
    processor = Processor(**config)

    await processor.on_message(mock_msg, None, mock_flow)

    # Nothing downstream may run once the content is recognised as non-PDF.
    mock_pdf_loader_class.assert_not_called()
    mock_output_flow.send.assert_not_called()
    mock_triples_flow.send.assert_not_called()
    mock_flow.librarian.save_child_document.assert_not_called()
@patch('trustgraph.base.librarian_client.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer): async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of empty PDF""" """Test handling of empty PDF"""
pdf_content = b"%PDF-1.7\nfake pdf content" pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock() mock_loader = MagicMock()
@ -168,7 +126,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer): async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of unicode content in PDF""" """Test handling of unicode content in PDF"""
pdf_content = b"%PDF-1.7\nfake pdf content" pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock() mock_loader = MagicMock()

View file

@ -333,8 +333,8 @@ class TestUnifiedTableQueries:
"""Test queries against the unified rows table""" """Test queries against the unified rows table"""
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('trustgraph.query.rows.cassandra.service.async_execute_paged', new_callable=AsyncMock) @patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
async def test_query_with_index_match(self, mock_async_execute_paged): async def test_query_with_index_match(self, mock_async_execute):
"""Test query execution with matching index""" """Test query execution with matching index"""
processor = MagicMock() processor = MagicMock()
processor.session = MagicMock() processor.session = MagicMock()
@ -344,10 +344,10 @@ class TestUnifiedTableQueries:
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor) processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock async_execute_paged to return test data (list of pages) # Mock async_execute to return test data
mock_row = MagicMock() mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"} mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
mock_async_execute_paged.return_value = [[mock_row]] mock_async_execute.return_value = [mock_row]
schema = RowSchema( schema = RowSchema(
name="products", name="products",
@ -370,10 +370,10 @@ class TestUnifiedTableQueries:
# Verify Cassandra was connected and queried # Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once() processor.connect_cassandra.assert_called_once()
mock_async_execute_paged.assert_called_once() mock_async_execute.assert_called_once()
# Verify query structure - should query unified rows table # Verify query structure - should query unified rows table
call_args = mock_async_execute_paged.call_args call_args = mock_async_execute.call_args
query = call_args[0][1] query = call_args[0][1]
params = call_args[0][2] params = call_args[0][2]
@ -394,8 +394,8 @@ class TestUnifiedTableQueries:
assert results[0]["category"] == "electronics" assert results[0]["category"] == "electronics"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('trustgraph.query.rows.cassandra.service.async_scan', new_callable=AsyncMock) @patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
async def test_query_without_index_match(self, mock_async_scan): async def test_query_without_index_match(self, mock_async_execute):
"""Test query execution without matching index (scan mode)""" """Test query execution without matching index (scan mode)"""
processor = MagicMock() processor = MagicMock()
processor.session = MagicMock() processor.session = MagicMock()
@ -406,10 +406,12 @@ class TestUnifiedTableQueries:
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor) processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor) processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock async_scan to return filtered test data # Mock async_execute to return test data
mock_row1 = MagicMock() mock_row1 = MagicMock()
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"} mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
mock_async_scan.return_value = [mock_row1] mock_row2 = MagicMock()
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
mock_async_execute.return_value = [mock_row1, mock_row2]
schema = RowSchema( schema = RowSchema(
name="products", name="products",
@ -430,16 +432,13 @@ class TestUnifiedTableQueries:
limit=10 limit=10
) )
# Verify async_scan was called # Query should use ALLOW FILTERING for scan
mock_async_scan.assert_called_once() call_args = mock_async_execute.call_args
# Verify query structure
call_args = mock_async_scan.call_args
query = call_args[0][1] query = call_args[0][1]
assert "ALLOW FILTERING" in query assert "ALLOW FILTERING" in query
# Should return filtered results # Should post-filter results
assert len(results) == 1 assert len(results) == 1
assert results[0]["name"] == "Product A" assert results[0]["name"] == "Product A"

View file

@ -103,19 +103,35 @@ def resolve_cassandra_config(
host: Optional[str] = None, host: Optional[str] = None,
username: Optional[str] = None, username: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
default_keyspace: Optional[str] = None, default_keyspace: Optional[str] = None
replication_factor: Optional[int] = None,
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]: ) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
"""
Resolve Cassandra configuration from various sources.
Can accept either argparse args object or explicit parameters.
Converts host string to list format for Cassandra driver.
Args:
args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace, cassandra_replication_factor
host: Optional explicit host parameter (overrides args)
username: Optional explicit username parameter (overrides args)
password: Optional explicit password parameter (overrides args)
default_keyspace: Optional default keyspace if not specified elsewhere
Returns:
tuple: (hosts_list, username, password, keyspace, replication_factor)
"""
# If args provided, extract values
keyspace = None keyspace = None
replication_factor = 1
if args is not None: if args is not None:
host = host or getattr(args, 'cassandra_host', None) host = host or getattr(args, 'cassandra_host', None)
username = username or getattr(args, 'cassandra_username', None) username = username or getattr(args, 'cassandra_username', None)
password = password or getattr(args, 'cassandra_password', None) password = password or getattr(args, 'cassandra_password', None)
keyspace = getattr(args, 'cassandra_keyspace', None) keyspace = getattr(args, 'cassandra_keyspace', None)
replication_factor = replication_factor or getattr( replication_factor = getattr(args, 'cassandra_replication_factor', 1)
args, 'cassandra_replication_factor', None
)
# Apply defaults if still None
defaults = get_cassandra_defaults() defaults = get_cassandra_defaults()
host = host or defaults['host'] host = host or defaults['host']
username = username or defaults['username'] username = username or defaults['username']

View file

@ -78,7 +78,7 @@ def load_structured_data(
logger.info("Step 1: Analyzing data to discover best matching schema...") logger.info("Step 1: Analyzing data to discover best matching schema...")
# Step 1: Auto-discover schema (reuse discover_schema logic) # Step 1: Auto-discover schema (reuse discover_schema logic)
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace) discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
if not discovered_schema: if not discovered_schema:
logger.error("Failed to discover suitable schema automatically") logger.error("Failed to discover suitable schema automatically")
print("❌ Could not automatically determine the best schema for your data.") print("❌ Could not automatically determine the best schema for your data.")
@ -90,7 +90,7 @@ def load_structured_data(
# Step 2: Auto-generate descriptor # Step 2: Auto-generate descriptor
logger.info("Step 2: Generating descriptor configuration...") logger.info("Step 2: Generating descriptor configuration...")
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace) auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
if not auto_descriptor: if not auto_descriptor:
logger.error("Failed to generate descriptor automatically") logger.error("Failed to generate descriptor automatically")
print("❌ Could not automatically generate descriptor configuration.") print("❌ Could not automatically generate descriptor configuration.")
@ -172,7 +172,7 @@ def load_structured_data(
logger.info(f"Sample chars: {sample_chars} characters") logger.info(f"Sample chars: {sample_chars} characters")
# Use the helper function to discover schema (get raw response for display) # Use the helper function to discover schema (get raw response for display)
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace) response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
if response: if response:
# Debug: print response type and content # Debug: print response type and content
@ -203,7 +203,7 @@ def load_structured_data(
# If no schema specified, discover it first # If no schema specified, discover it first
if not schema_name: if not schema_name:
logger.info("No schema specified, auto-discovering...") logger.info("No schema specified, auto-discovering...")
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace) schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
if not schema_name: if not schema_name:
print("Error: Could not determine schema automatically.") print("Error: Could not determine schema automatically.")
print("Please specify a schema using --schema-name or run --discover-schema first.") print("Please specify a schema using --schema-name or run --discover-schema first.")
@ -213,7 +213,7 @@ def load_structured_data(
logger.info(f"Target schema: {schema_name}") logger.info(f"Target schema: {schema_name}")
# Generate descriptor using helper function # Generate descriptor using helper function
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace) descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
if descriptor: if descriptor:
# Output the generated descriptor # Output the generated descriptor
@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp
# Helper functions for auto mode # Helper functions for auto mode
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"): def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
"""Auto-discover the best matching schema for the input data """Auto-discover the best matching schema for the input data
Args: Args:
@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
# Import API modules # Import API modules
from trustgraph.api import Api from trustgraph.api import Api
from trustgraph.api.types import ConfigKey from trustgraph.api.types import ConfigKey
api = Api(api_url, token=token, workspace=workspace) api = Api(api_url, workspace=workspace)
config_api = api.config() config_api = api.config()
# Get available schemas # Get available schemas
@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
return None return None
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"): def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
"""Auto-generate descriptor configuration for the discovered schema""" """Auto-generate descriptor configuration for the discovered schema"""
try: try:
# Read sample data # Read sample data
@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
# Import API modules # Import API modules
from trustgraph.api import Api from trustgraph.api import Api
from trustgraph.api.types import ConfigKey from trustgraph.api.types import ConfigKey
api = Api(api_url, token=token, workspace=workspace) api = Api(api_url, workspace=workspace)
config_api = api.config() config_api = api.config()
# Get schema definition # Get schema definition

View file

@ -83,8 +83,7 @@ class Processor(AsyncProcessor):
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password, password=cassandra_password,
default_keyspace="config", default_keyspace="config"
replication_factor=params.get("cassandra_replication_factor"),
) )
# Store resolved configuration # Store resolved configuration

View file

@ -61,8 +61,7 @@ class Processor(WorkspaceProcessor):
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password, password=cassandra_password,
default_keyspace="knowledge", default_keyspace="knowledge"
replication_factor=params.get("cassandra_replication_factor"),
) )
self.cassandra_host = hosts self.cassandra_host = hosts

View file

@ -32,10 +32,6 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder" default_ident = "document-decoder"
def _looks_like_pdf(content):
return content.lstrip().startswith(b"%PDF-")
class Processor(FlowProcessor): class Processor(FlowProcessor):
def __init__(self, **params): def __init__(self, **params):
@ -98,37 +94,33 @@ class Processor(FlowProcessor):
) )
return return
# Check if we should fetch from librarian or use inline data with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
if v.document_id:
# Fetch from librarian via Pulsar
logger.info(f"Fetching document {v.document_id} from librarian...")
content = await flow.librarian.fetch_document_content(
document_id=v.document_id,
)
# Content is base64 encoded
if isinstance(content, str):
content = content.encode('utf-8')
decoded_content = base64.b64decode(content)
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
else:
# Use inline data (backward compatibility)
decoded_content = base64.b64decode(v.data)
if not _looks_like_pdf(decoded_content):
logger.error(
f"Document {v.metadata.id} is not valid PDF content. "
f"Ignoring document."
)
return
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as fp:
temp_path = fp.name temp_path = fp.name
fp.write(decoded_content)
fp.close() # Check if we should fetch from librarian or use inline data
if v.document_id:
# Fetch from librarian via Pulsar
logger.info(f"Fetching document {v.document_id} from librarian...")
fp.close()
content = await flow.librarian.fetch_document_content(
document_id=v.document_id,
)
# Content is base64 encoded
if isinstance(content, str):
content = content.encode('utf-8')
decoded_content = base64.b64decode(content)
with open(temp_path, 'wb') as f:
f.write(decoded_content)
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
else:
# Use inline data (backward compatibility)
fp.write(base64.b64decode(v.data))
fp.close()
global PyPDFLoader global PyPDFLoader
if PyPDFLoader is None: if PyPDFLoader is None:

View file

@ -101,7 +101,6 @@ class Processor(AsyncProcessor):
username=cassandra_username, username=cassandra_username,
password=cassandra_password, password=cassandra_password,
default_keyspace="iam", default_keyspace="iam",
replication_factor=params.get("cassandra_replication_factor"),
) )
self.cassandra_host = hosts self.cassandra_host = hosts

View file

@ -146,8 +146,7 @@ class Processor(WorkspaceProcessor):
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password, password=cassandra_password,
default_keyspace="librarian", default_keyspace="librarian"
replication_factor=params.get("cassandra_replication_factor"),
) )
# Store resolved configuration # Store resolved configuration

View file

@ -27,8 +27,7 @@ class Processor(DocumentEmbeddingsQueryService):
api_key = params.get("api_key") api_key = params.get("api_key")
url, api_key, _, _ = resolve_qdrant_config( url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, url=store_uri, api_key=api_key,
api_key=api_key,
) )
super(Processor, self).__init__( super(Processor, self).__init__(

View file

@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan from .... tables.cassandra_async import async_execute
from ... graphql import GraphQLSchemaBuilder, SortDirection from ... graphql import GraphQLSchemaBuilder, SortDirection
@ -180,7 +180,7 @@ class Processor(FlowProcessor):
description=field_def.get("description", ""), description=field_def.get("description", ""),
required=field_def.get("required", False), required=field_def.get("required", False),
enum_values=field_def.get("enum", []), enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False), indexed=field_def.get("indexed", False)
) )
fields.append(field) fields.append(field)
@ -232,8 +232,6 @@ class Processor(FlowProcessor):
for index_name in index_names: for index_name in index_names:
if index_name in filters: if index_name in filters:
value = filters[index_name] value = filters[index_name]
if value == "" or value is None:
continue
# Single field index -> single element list # Single field index -> single element list
index_value = [str(value)] index_value = [str(value)]
return (index_name, index_value) return (index_name, index_value)
@ -284,13 +282,11 @@ class Processor(FlowProcessor):
query += f" LIMIT {limit}" query += f" LIMIT {limit}"
try: try:
pages = await async_execute_paged( rows = await async_execute(self.session, query, params)
self.session, query, params for row in rows:
) # Convert data map to dict with proper field names
for page in pages: row_dict = dict(row.data) if row.data else {}
for row in page: results.append(row_dict)
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
except Exception as e: except Exception as e:
logger.error(f"Failed to query rows: {e}", exc_info=True) logger.error(f"Failed to query rows: {e}", exc_info=True)
raise raise
@ -312,6 +308,8 @@ class Processor(FlowProcessor):
# Query using the first index (arbitrary choice for scan) # Query using the first index (arbitrary choice for scan)
primary_index = index_names[0] primary_index = index_names[0]
# We need to scan all values for this index
# This requires ALLOW FILTERING or a different approach
query = f""" query = f"""
SELECT data, source FROM {safe_keyspace}.rows SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s WHERE collection = %s
@ -322,18 +320,17 @@ class Processor(FlowProcessor):
params = [collection, schema_name, primary_index] params = [collection, schema_name, primary_index]
try: try:
def row_filter(row): rows = await async_execute(self.session, query, params)
row_dict = dict(row.data) if row.data else {}
return self._matches_filters(row_dict, filters, row_schema)
matched_rows = await async_scan( for row in rows:
self.session, query, params,
row_filter=row_filter,
limit=limit,
)
for row in matched_rows:
row_dict = dict(row.data) if row.data else {} row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
# Apply post-filters
if self._matches_filters(row_dict, filters, row_schema):
results.append(row_dict)
if limit and len(results) >= limit:
break
except Exception as e: except Exception as e:
logger.error(f"Failed to scan rows: {e}", exc_info=True) logger.error(f"Failed to scan rows: {e}", exc_info=True)
@ -366,7 +363,7 @@ class Processor(FlowProcessor):
# Parse filter key for operator # Parse filter key for operator
if '_' in filter_key: if '_' in filter_key:
parts = filter_key.rsplit('_', 1) parts = filter_key.rsplit('_', 1)
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']: if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
field_name = parts[0] field_name = parts[0]
operator = parts[1] operator = parts[1]
else: else:
@ -403,18 +400,6 @@ class Processor(FlowProcessor):
elif operator == 'in': elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]: if str(row_value) not in [str(v) for v in filter_value]:
return False return False
elif operator == 'not':
if str(row_value) == str(filter_value):
return False
elif operator == 'startsWith':
if not str(row_value).startswith(str(filter_value)):
return False
elif operator == 'endsWith':
if not str(row_value).endswith(str(filter_value)):
return False
elif operator == 'not_in':
if str(row_value) in [str(v) for v in filter_value]:
return False
except (ValueError, TypeError): except (ValueError, TypeError):
return False return False

View file

@ -30,8 +30,6 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
url, api_key, replication_factor, shard_number = resolve_qdrant_config( url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key, url=store_uri, api_key=api_key,
replication_factor=params.get("qdrant_replication_factor"),
shard_number=params.get("qdrant_shard_number"),
) )
super(Processor, self).__init__( super(Processor, self).__init__(

View file

@ -44,8 +44,6 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
url, api_key, replication_factor, shard_number = resolve_qdrant_config( url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key, url=store_uri, api_key=api_key,
replication_factor=params.get("qdrant_replication_factor"),
shard_number=params.get("qdrant_shard_number"),
) )
super(Processor, self).__init__( super(Processor, self).__init__(

View file

@ -27,8 +27,7 @@ class Processor(FlowProcessor):
host=params.get("cassandra_host"), host=params.get("cassandra_host"),
username=params.get("cassandra_username"), username=params.get("cassandra_username"),
password=params.get("cassandra_password"), password=params.get("cassandra_password"),
default_keyspace='knowledge', default_keyspace='knowledge'
replication_factor=params.get("cassandra_replication_factor"),
) )
super(Processor, self).__init__( super(Processor, self).__init__(

View file

@ -46,8 +46,6 @@ class Processor(CollectionConfigHandler, FlowProcessor):
url, api_key, replication_factor, shard_number = resolve_qdrant_config( url, api_key, replication_factor, shard_number = resolve_qdrant_config(
url=store_uri, api_key=api_key, url=store_uri, api_key=api_key,
replication_factor=params.get("qdrant_replication_factor"),
shard_number=params.get("qdrant_shard_number"),
) )
super(Processor, self).__init__( super(Processor, self).__init__(

View file

@ -50,8 +50,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
hosts, username, password, keyspace, replication_factor = resolve_cassandra_config( hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password, password=cassandra_password
replication_factor=params.get("cassandra_replication_factor"),
) )
# Store resolved configuration with proper names # Store resolved configuration with proper names
@ -172,7 +171,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
description=field_def.get("description", ""), description=field_def.get("description", ""),
required=field_def.get("required", False), required=field_def.get("required", False),
enum_values=field_def.get("enum", []), enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False), indexed=field_def.get("indexed", False)
) )
fields.append(field) fields.append(field)

View file

@ -80,14 +80,14 @@ def _set_exception_if_pending(fut, exc):
fut.set_exception(exc) fut.set_exception(exc)
async def async_execute_paged(session, query, parameters=None, fetch_size=5000): async def async_execute_paged(session, query, parameters=None, fetch_size=100):
"""Execute a CQL query with page-by-page iteration. """Execute a CQL query with page-by-page iteration.
Uses synchronous session.execute() inside run_in_executor so that Uses synchronous session.execute() inside run_in_executor so that
the driver's ResultSet paging works correctly without materialising the driver's ResultSet paging works correctly without materialising
the entire result set in memory. the entire result set in memory.
Returns all pages as a list of lists. Yields one page of rows at a time (as a list).
""" """
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
@ -111,50 +111,3 @@ async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
return await loop.run_in_executor( return await loop.run_in_executor(
None, _fetch_all_pages None, _fetch_all_pages
) )
async def async_scan(
    session, query, parameters=None, row_filter=None,
    limit=None, fetch_size=5000,
):
    """Scan a CQL query page-by-page, applying a filter and limit.

    Only matching rows accumulate in memory. Each page is discarded
    after processing, so peak memory is bounded by fetch_size plus
    the number of matching rows (capped by limit).

    The blocking driver calls run in the event loop's default executor
    so the loop itself is not blocked while pages are fetched.

    Args:
        session: cassandra.cluster.Session
        query: CQL statement string, or an already-built statement
            object. A statement object has its ``fetch_size`` attribute
            overwritten in place with ``fetch_size``.
        parameters: bind params
        row_filter: callable(row) -> bool, or None to accept all
        limit: max results to return, or None for unlimited. Any falsy
            value (including 0) is treated as unlimited.
        fetch_size: rows per Cassandra page fetch

    Returns:
        List of matching rows.
    """
    loop = asyncio.get_running_loop()

    if isinstance(query, str):
        stmt = SimpleStatement(query, fetch_size=fetch_size)
    else:
        stmt = query
        stmt.fetch_size = fetch_size

    def _scan():
        results = []
        result_set = session.execute(stmt, parameters)

        while True:
            for row in result_set.current_rows:
                if row_filter is None or row_filter(row):
                    results.append(row)
                    # Stop as soon as the cap is hit so no further
                    # pages are fetched from the cluster.
                    if limit and len(results) >= limit:
                        return results

            if result_set.has_more_pages:
                result_set.fetch_next_page()
            else:
                break

        return results

    return await loop.run_in_executor(None, _scan)

File diff suppressed because it is too large Load diff

View file

@ -1,110 +1,49 @@
from dataclasses import dataclass
from websockets.asyncio.client import connect from websockets.asyncio.client import connect
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
import asyncio import asyncio
import logging import logging
import json import json
import uuid import uuid
import hashlib import time
logger = logging.getLogger(__name__)
def _token_key(token: str) -> str:
    """Derive a dict key from a token without storing the raw secret.

    Args:
        token: The caller's token as text.

    Returns:
        The first 16 hex characters of the token's SHA-256 digest.
    """
    return hashlib.sha256(token.encode()).hexdigest()[:16]
class WebSocketManager: class WebSocketManager:
"""Manages an authenticated WebSocket connection to the TrustGraph
gateway on behalf of a single caller.
Each caller token gets its own WebSocketManager so that gateway-side def __init__(self, url, token=None):
identity, workspace, and capability scoping are preserved end-to-end.
"""
def __init__(self, url, token):
self.url = url self.url = url
# ── Security boundary: token storage ──
# This is the MCP caller's Bearer token, forwarded verbatim to
# the gateway. It MUST NOT be logged, persisted, or shared
# across callers. It is held only for the lifetime of this
# connection so that re-auth (e.g. after a reconnect) is
# possible.
self.token = token self.token = token
self.socket = None self.socket = None
self.identity = None
self.last_used = None # FIXME: authentication is broken. The /api/v1/socket endpoint uses
# in-band auth (first-frame protocol via the Mux dispatcher), not
# query-parameter tokens. This query-string token is silently ignored.
# Fix: after connect(), send an auth frame with the bearer token as
# the first message, matching the gateway's in-band auth protocol.
def _build_url(self):
    """Return ``self.url`` with ``self.token`` set as a ``token`` query parameter.

    When no token is configured the URL is returned unchanged. Any
    existing ``token`` parameter is replaced; all other query
    parameters are carried over.

    Returns:
        The URL string to connect to.
    """
    if not self.token:
        return self.url

    parsed = urlparse(self.url)
    # keep_blank_values=True: parse_qs drops parameters with an empty
    # value by default, which would silently strip e.g. "?debug=" from
    # the caller's URL when it is rebuilt.
    params = parse_qs(parsed.query, keep_blank_values=True)
    params["token"] = [self.token]
    new_query = urlencode(params, doseq=True)
    return urlunparse(parsed._replace(query=new_query))
async def start(self): async def start(self):
"""Connect and authenticate via the gateway's in-band auth self.socket = await connect(self._build_url())
protocol. Raises on auth failure."""
# ── Security boundary: MCP server → gateway ──
# The WebSocket connects to the gateway and authenticates using
# the caller's Bearer token via the in-band first-frame auth
# protocol. The token belongs to the MCP client — we forward
# it as-is and never interpret its contents.
self.socket = await connect(self.url)
self.pending_requests = {} self.pending_requests = {}
self.running = True self.running = True
await self._authenticate()
self.reader_task = asyncio.create_task(self.reader()) self.reader_task = asyncio.create_task(self.reader())
async def _authenticate(self):
    """Send in-band auth frame and wait for auth-ok / auth-failed.

    The gateway expects ``{"type": "auth", "token": "..."}`` as the
    first frame on a new WebSocket. Any service frame sent before
    auth-ok is rejected.

    On every path other than a successful ``auth-ok`` the socket is
    closed and ``self.socket`` is reset to ``None`` before the error
    propagates. That includes a reply timeout, a reply that is not
    valid JSON and a reply that is not a JSON object, so an
    unauthenticated socket is never left open.

    Raises:
        RuntimeError: the gateway replied ``auth-failed`` or with an
            unexpected frame.
        asyncio.TimeoutError: no reply arrived within 10 seconds.
        json.JSONDecodeError: the reply was not valid JSON.
    """
    authenticated = False

    try:
        await self.socket.send(json.dumps({
            "type": "auth",
            "token": self.token,
        }))

        response_text = await asyncio.wait_for(self.socket.recv(), 10)
        response = json.loads(response_text)

        # A well-formed reply is a JSON object; anything else is
        # reported as an unexpected response rather than crashing
        # on ``.get``.
        response_type = (
            response.get("type") if isinstance(response, dict) else None
        )

        if response_type == "auth-ok":
            logger.info(
                "WebSocket authenticated, default workspace: %s",
                response.get("workspace"),
            )
            authenticated = True
            return

        if response_type == "auth-failed":
            raise RuntimeError(
                "Gateway rejected the authentication token"
            )

        raise RuntimeError(
            f"Unexpected auth response type: {response_type}"
        )

    finally:
        # Auth did not complete: close immediately, do not leave an
        # unauthenticated socket open.
        if not authenticated:
            sock, self.socket = self.socket, None
            await sock.close()
async def whoami(self):
    """Verify the token by calling the gateway's whoami endpoint.

    Sends an ``iam`` request with ``{"operation": "whoami"}`` and no
    flow id, and takes the first response as the identity.

    Returns:
        The identity dict, which is also cached on ``self.identity``.
        ``None`` if the request produced no response.
    """
    gen = self.request("iam", {"operation": "whoami"}, flow_id=None)
    try:
        async for response in gen:
            self.identity = response
            return response
        return None
    finally:
        # Returning from inside ``async for`` leaves the generator
        # suspended; close it explicitly so its cleanup runs now
        # instead of whenever the event loop finalizes it.
        await gen.aclose()
async def stop(self): async def stop(self):
self.running = False self.running = False
if hasattr(self, "reader_task"): await self.reader_task
await self.reader_task
async def reader(self): async def reader(self):
"""Background task: read WebSocket frames and route them to the """
correct pending-request queue by ``id``.""" Background task to read websocket responses and route to correct
request
"""
while self.running: while self.running:
try: try:
@ -120,21 +59,23 @@ class WebSocketManager:
request_id = response.get("id") request_id = response.get("id")
if request_id and request_id in self.pending_requests: if request_id and request_id in self.pending_requests:
# Put the response in the queue
queue = self.pending_requests[request_id] queue = self.pending_requests[request_id]
await queue.put(response) await queue.put(response)
else: else:
logger.warning( logging.warning(
"Response for unknown request ID: %s", request_id f"Response for unknown request ID: {request_id}"
) )
except Exception as e: except Exception as e:
logger.error("Error in websocket reader: %s", e) logging.error(f"Error in websocket reader: {e}")
# Put error in all pending queues
for queue in self.pending_requests.values(): for queue in self.pending_requests.values():
try: try:
await queue.put({"error": str(e)}) await queue.put({"error": str(e)})
except Exception: except:
pass pass
self.pending_requests.clear() self.pending_requests.clear()
@ -145,29 +86,25 @@ class WebSocketManager:
async def request( async def request(
self, service, request_data, flow_id="default", self, service, request_data, flow_id="default",
workspace=None,
): ):
"""Send a request via WebSocket and yield responses. """
Send a request via websocket and handle single or streaming responses
Args:
service: Gateway service name (e.g. "graph-rag", "config").
request_data: Inner request payload.
flow_id: Optional flow identifier. ``None`` omits the field
(workspace-level services don't use flows).
workspace: Optional workspace override. When ``None`` the
gateway uses the caller's default workspace.
""" """
import time # Generate unique request ID
self.last_used = time.monotonic()
request_id = f"{uuid.uuid4()}" request_id = f"{uuid.uuid4()}"
# Determine if this service streams responses
streaming_services = {"agent"}
is_streaming = service in streaming_services
# Create a queue for all responses (streaming and single)
response_queue = asyncio.Queue() response_queue = asyncio.Queue()
self.pending_requests[request_id] = response_queue self.pending_requests[request_id] = response_queue
try: try:
# Build request message
message = { message = {
"id": request_id, "id": request_id,
"service": service, "service": service,
@ -177,16 +114,7 @@ class WebSocketManager:
if flow_id is not None: if flow_id is not None:
message["flow"] = flow_id message["flow"] = flow_id
# ── Security boundary: workspace scoping ── # Send request
# When the caller supplies a workspace, we set it on the
# message envelope. The gateway's enforce_workspace()
# validates that the authenticated identity is permitted
# to access the target workspace — we MUST NOT skip or
# override that check. When workspace is None, the
# gateway default-fills from the identity's bound workspace.
if workspace is not None:
message["workspace"] = workspace
await self.socket.send(json.dumps(message)) await self.socket.send(json.dumps(message))
while self.running: while self.running:
@ -199,17 +127,19 @@ class WebSocketManager:
continue continue
if "error" in response: if "error" in response:
if isinstance(response["error"], dict): if "message" in response["error"]:
raise RuntimeError( raise RuntimeError(response["error"]["text"])
response["error"].get("message", str(response["error"]))
)
else: else:
raise RuntimeError(str(response["error"])) raise RuntimeError(str(response["error"]))
yield response["response"] yield response["response"]
if response.get("complete"): if "complete" in response:
break if response["complete"]:
break
finally: except Exception as e:
# Clean up on error
self.pending_requests.pop(request_id, None) self.pending_requests.pop(request_id, None)
raise e