mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-11 07:45:13 +02:00
Compare commits
11 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
627c669097 | ||
|
|
81d57826c8 | ||
|
|
79d7ef6a90 | ||
|
|
28a51c244f | ||
|
|
fa5ebe2393 | ||
|
|
e1c9351454 | ||
|
|
dbc21c0bb9 | ||
|
|
97453d9b83 | ||
|
|
6dfa47aac8 | ||
|
|
dcee842455 | ||
|
|
36eadbda3a |
10 changed files with 1247 additions and 1088 deletions
32
README.md
32
README.md
|
|
@ -11,11 +11,11 @@
|
||||||
|
|
||||||
<a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
|
|
||||||
# The agent runtime platform
|
# The semantic deployment platform
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
TrustGraph is an agent runtime platform built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for precision-critical agent workloads.
|
TrustGraph is a comprehensive semantic infrastructure for agents built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for deterministic agent workloads.
|
||||||
|
|
||||||
The platform:
|
The platform:
|
||||||
- [x] Multi-model and multimodal database system
|
- [x] Multi-model and multimodal database system
|
||||||
|
|
@ -99,23 +99,21 @@ For a browser based configuration, try the [Configuration Terminal](https://conf
|
||||||
- [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference)
|
- [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference)
|
||||||
- [**Deployment Guides**](https://docs.trustgraph.ai/deployment)
|
- [**Deployment Guides**](https://docs.trustgraph.ai/deployment)
|
||||||
|
|
||||||
## Workbench
|
## Context Graph UI
|
||||||
|
|
||||||
The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default.
|
<img width="1389" height="961" alt="Image" src="https://github.com/user-attachments/assets/35c9250d-0f01-40cb-9294-1ee8fd9a1b56" />
|
||||||
|
|
||||||
- **Vector Search**: Search the installed knowledge bases
|
The UI provides tools for all major features of TrustGraph. The UI deploys on port `8888` by default.
|
||||||
- **Agentic, GraphRAG and LLM Chat**: Chat interface for agents, GraphRAG queries, or direct to LLMs
|
|
||||||
- **Relationships**: Analyze deep relationships in the installed knowledge bases
|
- **Agent Console** — Query your agents directly with streaming responses and live explainability event tracking, so you can watch reasoning unfold in real time
|
||||||
- **Graph Visualizer**: 3D GraphViz of the installed knowledge bases
|
- **GraphRAG View** — Interactive graph RAG queries with a visual explainability DAG and inline provenance display, making it easy to see exactly where answers came from
|
||||||
- **Library**: Staging area for installing knowledge bases
|
- **Context Explorer** — An interactive 3D context graph explorer with dynamic graph loading, BFS neighborhood extraction, edge pulse animation, and multiple navigation views
|
||||||
- **Flow Classes**: Workflow preset configurations
|
- **Document Ingestion** — A complete upload and submission workflow with page and chunk inspection and document structure browsing
|
||||||
- **Flows**: Create custom workflows and adjust LLM parameters during runtime
|
- **Ontology Workbench** — A full ontology editor with class and property trees, OWL/XML and Turtle import/export with round-trip fidelity, circular dependency detection, and safe-delete confirmation dialogs
|
||||||
- **Knowledge Cores**: Manage resuable knowledge bases
|
- **Schema Workbench** — Interactive schema management with list, create, edit, and delete operations including field and index management
|
||||||
- **Prompts**: Manage and adjust prompts during runtime
|
- **Flow Management** — Flow creation and detail views with configurable parameters, temperature controls, and grouped storage layout
|
||||||
- **Schemas**: Define custom schemas for structured data knowledge bases
|
- **Workspace UX** — Workspace selection and management surfaced directly in the interface
|
||||||
- **Ontologies**: Define custom ontologies for unstructured data knowledge bases
|
- **Prompt Editor** — A dedicated prompt editing workflow
|
||||||
- **Agent Tools**: Define tools with collections, knowledge cores, MCP connections, and tool groups
|
|
||||||
- **MCP Tools**: Connect to MCP servers
|
|
||||||
|
|
||||||
## TypeScript Library for UIs
|
## TypeScript Library for UIs
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
||||||
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||||
"""Test successful PDF processing"""
|
"""Test successful PDF processing"""
|
||||||
# Mock PDF content
|
# Mock PDF content
|
||||||
pdf_content = b"fake pdf content"
|
pdf_content = b"%PDF-1.7\nfake pdf content"
|
||||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||||
|
|
||||||
# Mock PyPDFLoader
|
# Mock PyPDFLoader
|
||||||
|
|
@ -88,13 +88,55 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
||||||
# Verify triples were sent for each page (provenance)
|
# Verify triples were sent for each page (provenance)
|
||||||
assert mock_triples_flow.send.call_count == 2
|
assert mock_triples_flow.send.call_count == 2
|
||||||
|
|
||||||
|
@patch('trustgraph.base.librarian_client.Consumer')
|
||||||
|
@patch('trustgraph.base.librarian_client.Producer')
|
||||||
|
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||||
|
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||||
|
async def test_on_message_rejects_librarian_content_that_is_not_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||||
|
"""Test rejecting non-PDF content before invoking the PDF loader"""
|
||||||
|
html_content = b"<html><body>Not found</body></html>"
|
||||||
|
html_base64 = base64.b64encode(html_content)
|
||||||
|
|
||||||
|
mock_metadata = Metadata(id="test-doc")
|
||||||
|
mock_document = Document(metadata=mock_metadata, document_id="doc-123")
|
||||||
|
mock_msg = MagicMock()
|
||||||
|
mock_msg.value.return_value = mock_document
|
||||||
|
|
||||||
|
mock_output_flow = AsyncMock()
|
||||||
|
mock_triples_flow = AsyncMock()
|
||||||
|
mock_flow = MagicMock(side_effect=lambda name: {
|
||||||
|
"output": mock_output_flow,
|
||||||
|
"triples": mock_triples_flow,
|
||||||
|
}.get(name))
|
||||||
|
mock_flow.librarian.fetch_document_metadata = AsyncMock(
|
||||||
|
return_value=MagicMock(kind="application/pdf")
|
||||||
|
)
|
||||||
|
mock_flow.librarian.fetch_document_content = AsyncMock(
|
||||||
|
return_value=html_base64
|
||||||
|
)
|
||||||
|
mock_flow.librarian.save_child_document = AsyncMock()
|
||||||
|
|
||||||
|
config = {
|
||||||
|
'id': 'test-pdf-decoder',
|
||||||
|
'taskgroup': AsyncMock()
|
||||||
|
}
|
||||||
|
|
||||||
|
processor = Processor(**config)
|
||||||
|
|
||||||
|
await processor.on_message(mock_msg, None, mock_flow)
|
||||||
|
|
||||||
|
mock_pdf_loader_class.assert_not_called()
|
||||||
|
mock_output_flow.send.assert_not_called()
|
||||||
|
mock_triples_flow.send.assert_not_called()
|
||||||
|
mock_flow.librarian.save_child_document.assert_not_called()
|
||||||
|
|
||||||
@patch('trustgraph.base.librarian_client.Consumer')
|
@patch('trustgraph.base.librarian_client.Consumer')
|
||||||
@patch('trustgraph.base.librarian_client.Producer')
|
@patch('trustgraph.base.librarian_client.Producer')
|
||||||
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||||
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||||
"""Test handling of empty PDF"""
|
"""Test handling of empty PDF"""
|
||||||
pdf_content = b"fake pdf content"
|
pdf_content = b"%PDF-1.7\nfake pdf content"
|
||||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||||
|
|
||||||
mock_loader = MagicMock()
|
mock_loader = MagicMock()
|
||||||
|
|
@ -126,7 +168,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
|
||||||
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
|
||||||
"""Test handling of unicode content in PDF"""
|
"""Test handling of unicode content in PDF"""
|
||||||
pdf_content = b"fake pdf content"
|
pdf_content = b"%PDF-1.7\nfake pdf content"
|
||||||
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
|
||||||
|
|
||||||
mock_loader = MagicMock()
|
mock_loader = MagicMock()
|
||||||
|
|
|
||||||
|
|
@ -333,8 +333,8 @@ class TestUnifiedTableQueries:
|
||||||
"""Test queries against the unified rows table"""
|
"""Test queries against the unified rows table"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
|
@patch('trustgraph.query.rows.cassandra.service.async_execute_paged', new_callable=AsyncMock)
|
||||||
async def test_query_with_index_match(self, mock_async_execute):
|
async def test_query_with_index_match(self, mock_async_execute_paged):
|
||||||
"""Test query execution with matching index"""
|
"""Test query execution with matching index"""
|
||||||
processor = MagicMock()
|
processor = MagicMock()
|
||||||
processor.session = MagicMock()
|
processor.session = MagicMock()
|
||||||
|
|
@ -344,10 +344,10 @@ class TestUnifiedTableQueries:
|
||||||
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
|
||||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||||
|
|
||||||
# Mock async_execute to return test data
|
# Mock async_execute_paged to return test data (list of pages)
|
||||||
mock_row = MagicMock()
|
mock_row = MagicMock()
|
||||||
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
|
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
|
||||||
mock_async_execute.return_value = [mock_row]
|
mock_async_execute_paged.return_value = [[mock_row]]
|
||||||
|
|
||||||
schema = RowSchema(
|
schema = RowSchema(
|
||||||
name="products",
|
name="products",
|
||||||
|
|
@ -370,10 +370,10 @@ class TestUnifiedTableQueries:
|
||||||
|
|
||||||
# Verify Cassandra was connected and queried
|
# Verify Cassandra was connected and queried
|
||||||
processor.connect_cassandra.assert_called_once()
|
processor.connect_cassandra.assert_called_once()
|
||||||
mock_async_execute.assert_called_once()
|
mock_async_execute_paged.assert_called_once()
|
||||||
|
|
||||||
# Verify query structure - should query unified rows table
|
# Verify query structure - should query unified rows table
|
||||||
call_args = mock_async_execute.call_args
|
call_args = mock_async_execute_paged.call_args
|
||||||
query = call_args[0][1]
|
query = call_args[0][1]
|
||||||
params = call_args[0][2]
|
params = call_args[0][2]
|
||||||
|
|
||||||
|
|
@ -394,8 +394,8 @@ class TestUnifiedTableQueries:
|
||||||
assert results[0]["category"] == "electronics"
|
assert results[0]["category"] == "electronics"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
|
@patch('trustgraph.query.rows.cassandra.service.async_scan', new_callable=AsyncMock)
|
||||||
async def test_query_without_index_match(self, mock_async_execute):
|
async def test_query_without_index_match(self, mock_async_scan):
|
||||||
"""Test query execution without matching index (scan mode)"""
|
"""Test query execution without matching index (scan mode)"""
|
||||||
processor = MagicMock()
|
processor = MagicMock()
|
||||||
processor.session = MagicMock()
|
processor.session = MagicMock()
|
||||||
|
|
@ -406,12 +406,10 @@ class TestUnifiedTableQueries:
|
||||||
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
|
||||||
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
|
||||||
|
|
||||||
# Mock async_execute to return test data
|
# Mock async_scan to return filtered test data
|
||||||
mock_row1 = MagicMock()
|
mock_row1 = MagicMock()
|
||||||
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
|
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
|
||||||
mock_row2 = MagicMock()
|
mock_async_scan.return_value = [mock_row1]
|
||||||
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
|
|
||||||
mock_async_execute.return_value = [mock_row1, mock_row2]
|
|
||||||
|
|
||||||
schema = RowSchema(
|
schema = RowSchema(
|
||||||
name="products",
|
name="products",
|
||||||
|
|
@ -432,13 +430,16 @@ class TestUnifiedTableQueries:
|
||||||
limit=10
|
limit=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query should use ALLOW FILTERING for scan
|
# Verify async_scan was called
|
||||||
call_args = mock_async_execute.call_args
|
mock_async_scan.assert_called_once()
|
||||||
|
|
||||||
|
# Verify query structure
|
||||||
|
call_args = mock_async_scan.call_args
|
||||||
query = call_args[0][1]
|
query = call_args[0][1]
|
||||||
|
|
||||||
assert "ALLOW FILTERING" in query
|
assert "ALLOW FILTERING" in query
|
||||||
|
|
||||||
# Should post-filter results
|
# Should return filtered results
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0]["name"] == "Product A"
|
assert results[0]["name"] == "Product A"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ def load_structured_data(
|
||||||
logger.info("Step 1: Analyzing data to discover best matching schema...")
|
logger.info("Step 1: Analyzing data to discover best matching schema...")
|
||||||
|
|
||||||
# Step 1: Auto-discover schema (reuse discover_schema logic)
|
# Step 1: Auto-discover schema (reuse discover_schema logic)
|
||||||
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
|
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
if not discovered_schema:
|
if not discovered_schema:
|
||||||
logger.error("Failed to discover suitable schema automatically")
|
logger.error("Failed to discover suitable schema automatically")
|
||||||
print("❌ Could not automatically determine the best schema for your data.")
|
print("❌ Could not automatically determine the best schema for your data.")
|
||||||
|
|
@ -90,7 +90,7 @@ def load_structured_data(
|
||||||
|
|
||||||
# Step 2: Auto-generate descriptor
|
# Step 2: Auto-generate descriptor
|
||||||
logger.info("Step 2: Generating descriptor configuration...")
|
logger.info("Step 2: Generating descriptor configuration...")
|
||||||
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
|
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
if not auto_descriptor:
|
if not auto_descriptor:
|
||||||
logger.error("Failed to generate descriptor automatically")
|
logger.error("Failed to generate descriptor automatically")
|
||||||
print("❌ Could not automatically generate descriptor configuration.")
|
print("❌ Could not automatically generate descriptor configuration.")
|
||||||
|
|
@ -172,7 +172,7 @@ def load_structured_data(
|
||||||
logger.info(f"Sample chars: {sample_chars} characters")
|
logger.info(f"Sample chars: {sample_chars} characters")
|
||||||
|
|
||||||
# Use the helper function to discover schema (get raw response for display)
|
# Use the helper function to discover schema (get raw response for display)
|
||||||
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
|
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
# Debug: print response type and content
|
# Debug: print response type and content
|
||||||
|
|
@ -203,7 +203,7 @@ def load_structured_data(
|
||||||
# If no schema specified, discover it first
|
# If no schema specified, discover it first
|
||||||
if not schema_name:
|
if not schema_name:
|
||||||
logger.info("No schema specified, auto-discovering...")
|
logger.info("No schema specified, auto-discovering...")
|
||||||
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
|
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
if not schema_name:
|
if not schema_name:
|
||||||
print("Error: Could not determine schema automatically.")
|
print("Error: Could not determine schema automatically.")
|
||||||
print("Please specify a schema using --schema-name or run --discover-schema first.")
|
print("Please specify a schema using --schema-name or run --discover-schema first.")
|
||||||
|
|
@ -213,7 +213,7 @@ def load_structured_data(
|
||||||
logger.info(f"Target schema: {schema_name}")
|
logger.info(f"Target schema: {schema_name}")
|
||||||
|
|
||||||
# Generate descriptor using helper function
|
# Generate descriptor using helper function
|
||||||
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
|
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace)
|
||||||
|
|
||||||
if descriptor:
|
if descriptor:
|
||||||
# Output the generated descriptor
|
# Output the generated descriptor
|
||||||
|
|
@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp
|
||||||
|
|
||||||
|
|
||||||
# Helper functions for auto mode
|
# Helper functions for auto mode
|
||||||
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
|
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"):
|
||||||
"""Auto-discover the best matching schema for the input data
|
"""Auto-discover the best matching schema for the input data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
|
||||||
# Import API modules
|
# Import API modules
|
||||||
from trustgraph.api import Api
|
from trustgraph.api import Api
|
||||||
from trustgraph.api.types import ConfigKey
|
from trustgraph.api.types import ConfigKey
|
||||||
api = Api(api_url, workspace=workspace)
|
api = Api(api_url, token=token, workspace=workspace)
|
||||||
config_api = api.config()
|
config_api = api.config()
|
||||||
|
|
||||||
# Get available schemas
|
# Get available schemas
|
||||||
|
|
@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
|
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"):
|
||||||
"""Auto-generate descriptor configuration for the discovered schema"""
|
"""Auto-generate descriptor configuration for the discovered schema"""
|
||||||
try:
|
try:
|
||||||
# Read sample data
|
# Read sample data
|
||||||
|
|
@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
|
||||||
# Import API modules
|
# Import API modules
|
||||||
from trustgraph.api import Api
|
from trustgraph.api import Api
|
||||||
from trustgraph.api.types import ConfigKey
|
from trustgraph.api.types import ConfigKey
|
||||||
api = Api(api_url, workspace=workspace)
|
api = Api(api_url, token=token, workspace=workspace)
|
||||||
config_api = api.config()
|
config_api = api.config()
|
||||||
|
|
||||||
# Get schema definition
|
# Get schema definition
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
|
||||||
default_ident = "document-decoder"
|
default_ident = "document-decoder"
|
||||||
|
|
||||||
|
|
||||||
|
def _looks_like_pdf(content):
|
||||||
|
return content.lstrip().startswith(b"%PDF-")
|
||||||
|
|
||||||
|
|
||||||
class Processor(FlowProcessor):
|
class Processor(FlowProcessor):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
@ -94,33 +98,37 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
|
# Check if we should fetch from librarian or use inline data
|
||||||
|
if v.document_id:
|
||||||
|
# Fetch from librarian via Pulsar
|
||||||
|
logger.info(f"Fetching document {v.document_id} from librarian...")
|
||||||
|
|
||||||
|
content = await flow.librarian.fetch_document_content(
|
||||||
|
document_id=v.document_id,
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
# Content is base64 encoded
|
||||||
|
if isinstance(content, str):
|
||||||
|
content = content.encode('utf-8')
|
||||||
|
decoded_content = base64.b64decode(content)
|
||||||
|
|
||||||
|
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
|
||||||
|
else:
|
||||||
|
# Use inline data (backward compatibility)
|
||||||
|
decoded_content = base64.b64decode(v.data)
|
||||||
|
|
||||||
|
if not _looks_like_pdf(decoded_content):
|
||||||
|
logger.error(
|
||||||
|
f"Document {v.metadata.id} is not valid PDF content. "
|
||||||
|
f"Ignoring document."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as fp:
|
||||||
temp_path = fp.name
|
temp_path = fp.name
|
||||||
|
fp.write(decoded_content)
|
||||||
# Check if we should fetch from librarian or use inline data
|
fp.close()
|
||||||
if v.document_id:
|
|
||||||
# Fetch from librarian via Pulsar
|
|
||||||
logger.info(f"Fetching document {v.document_id} from librarian...")
|
|
||||||
fp.close()
|
|
||||||
|
|
||||||
content = await flow.librarian.fetch_document_content(
|
|
||||||
document_id=v.document_id,
|
|
||||||
|
|
||||||
)
|
|
||||||
|
|
||||||
# Content is base64 encoded
|
|
||||||
if isinstance(content, str):
|
|
||||||
content = content.encode('utf-8')
|
|
||||||
decoded_content = base64.b64decode(content)
|
|
||||||
|
|
||||||
with open(temp_path, 'wb') as f:
|
|
||||||
f.write(decoded_content)
|
|
||||||
|
|
||||||
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
|
|
||||||
else:
|
|
||||||
# Use inline data (backward compatibility)
|
|
||||||
fp.write(base64.b64decode(v.data))
|
|
||||||
fp.close()
|
|
||||||
|
|
||||||
global PyPDFLoader
|
global PyPDFLoader
|
||||||
if PyPDFLoader is None:
|
if PyPDFLoader is None:
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
|
||||||
from .... schema import Error, RowSchema, Field as SchemaField
|
from .... schema import Error, RowSchema, Field as SchemaField
|
||||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||||
from .... tables.cassandra_async import async_execute
|
from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan
|
||||||
|
|
||||||
from ... graphql import GraphQLSchemaBuilder, SortDirection
|
from ... graphql import GraphQLSchemaBuilder, SortDirection
|
||||||
|
|
||||||
|
|
@ -180,7 +180,7 @@ class Processor(FlowProcessor):
|
||||||
description=field_def.get("description", ""),
|
description=field_def.get("description", ""),
|
||||||
required=field_def.get("required", False),
|
required=field_def.get("required", False),
|
||||||
enum_values=field_def.get("enum", []),
|
enum_values=field_def.get("enum", []),
|
||||||
indexed=field_def.get("indexed", False)
|
indexed=field_def.get("indexed", False),
|
||||||
)
|
)
|
||||||
fields.append(field)
|
fields.append(field)
|
||||||
|
|
||||||
|
|
@ -232,6 +232,8 @@ class Processor(FlowProcessor):
|
||||||
for index_name in index_names:
|
for index_name in index_names:
|
||||||
if index_name in filters:
|
if index_name in filters:
|
||||||
value = filters[index_name]
|
value = filters[index_name]
|
||||||
|
if value == "" or value is None:
|
||||||
|
continue
|
||||||
# Single field index -> single element list
|
# Single field index -> single element list
|
||||||
index_value = [str(value)]
|
index_value = [str(value)]
|
||||||
return (index_name, index_value)
|
return (index_name, index_value)
|
||||||
|
|
@ -282,11 +284,13 @@ class Processor(FlowProcessor):
|
||||||
query += f" LIMIT {limit}"
|
query += f" LIMIT {limit}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rows = await async_execute(self.session, query, params)
|
pages = await async_execute_paged(
|
||||||
for row in rows:
|
self.session, query, params
|
||||||
# Convert data map to dict with proper field names
|
)
|
||||||
row_dict = dict(row.data) if row.data else {}
|
for page in pages:
|
||||||
results.append(row_dict)
|
for row in page:
|
||||||
|
row_dict = dict(row.data) if row.data else {}
|
||||||
|
results.append(row_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to query rows: {e}", exc_info=True)
|
logger.error(f"Failed to query rows: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
@ -308,8 +312,6 @@ class Processor(FlowProcessor):
|
||||||
# Query using the first index (arbitrary choice for scan)
|
# Query using the first index (arbitrary choice for scan)
|
||||||
primary_index = index_names[0]
|
primary_index = index_names[0]
|
||||||
|
|
||||||
# We need to scan all values for this index
|
|
||||||
# This requires ALLOW FILTERING or a different approach
|
|
||||||
query = f"""
|
query = f"""
|
||||||
SELECT data, source FROM {safe_keyspace}.rows
|
SELECT data, source FROM {safe_keyspace}.rows
|
||||||
WHERE collection = %s
|
WHERE collection = %s
|
||||||
|
|
@ -320,17 +322,18 @@ class Processor(FlowProcessor):
|
||||||
params = [collection, schema_name, primary_index]
|
params = [collection, schema_name, primary_index]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rows = await async_execute(self.session, query, params)
|
def row_filter(row):
|
||||||
|
|
||||||
for row in rows:
|
|
||||||
row_dict = dict(row.data) if row.data else {}
|
row_dict = dict(row.data) if row.data else {}
|
||||||
|
return self._matches_filters(row_dict, filters, row_schema)
|
||||||
|
|
||||||
# Apply post-filters
|
matched_rows = await async_scan(
|
||||||
if self._matches_filters(row_dict, filters, row_schema):
|
self.session, query, params,
|
||||||
results.append(row_dict)
|
row_filter=row_filter,
|
||||||
|
limit=limit,
|
||||||
if limit and len(results) >= limit:
|
)
|
||||||
break
|
for row in matched_rows:
|
||||||
|
row_dict = dict(row.data) if row.data else {}
|
||||||
|
results.append(row_dict)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to scan rows: {e}", exc_info=True)
|
logger.error(f"Failed to scan rows: {e}", exc_info=True)
|
||||||
|
|
@ -363,7 +366,7 @@ class Processor(FlowProcessor):
|
||||||
# Parse filter key for operator
|
# Parse filter key for operator
|
||||||
if '_' in filter_key:
|
if '_' in filter_key:
|
||||||
parts = filter_key.rsplit('_', 1)
|
parts = filter_key.rsplit('_', 1)
|
||||||
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
|
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
|
||||||
field_name = parts[0]
|
field_name = parts[0]
|
||||||
operator = parts[1]
|
operator = parts[1]
|
||||||
else:
|
else:
|
||||||
|
|
@ -400,6 +403,18 @@ class Processor(FlowProcessor):
|
||||||
elif operator == 'in':
|
elif operator == 'in':
|
||||||
if str(row_value) not in [str(v) for v in filter_value]:
|
if str(row_value) not in [str(v) for v in filter_value]:
|
||||||
return False
|
return False
|
||||||
|
elif operator == 'not':
|
||||||
|
if str(row_value) == str(filter_value):
|
||||||
|
return False
|
||||||
|
elif operator == 'startsWith':
|
||||||
|
if not str(row_value).startswith(str(filter_value)):
|
||||||
|
return False
|
||||||
|
elif operator == 'endsWith':
|
||||||
|
if not str(row_value).endswith(str(filter_value)):
|
||||||
|
return False
|
||||||
|
elif operator == 'not_in':
|
||||||
|
if str(row_value) in [str(v) for v in filter_value]:
|
||||||
|
return False
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -172,7 +172,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
||||||
description=field_def.get("description", ""),
|
description=field_def.get("description", ""),
|
||||||
required=field_def.get("required", False),
|
required=field_def.get("required", False),
|
||||||
enum_values=field_def.get("enum", []),
|
enum_values=field_def.get("enum", []),
|
||||||
indexed=field_def.get("indexed", False)
|
indexed=field_def.get("indexed", False),
|
||||||
)
|
)
|
||||||
fields.append(field)
|
fields.append(field)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -80,14 +80,14 @@ def _set_exception_if_pending(fut, exc):
|
||||||
fut.set_exception(exc)
|
fut.set_exception(exc)
|
||||||
|
|
||||||
|
|
||||||
async def async_execute_paged(session, query, parameters=None, fetch_size=100):
|
async def async_execute_paged(session, query, parameters=None, fetch_size=5000):
|
||||||
"""Execute a CQL query with page-by-page iteration.
|
"""Execute a CQL query with page-by-page iteration.
|
||||||
|
|
||||||
Uses synchronous session.execute() inside run_in_executor so that
|
Uses synchronous session.execute() inside run_in_executor so that
|
||||||
the driver's ResultSet paging works correctly without materialising
|
the driver's ResultSet paging works correctly without materialising
|
||||||
the entire result set in memory.
|
the entire result set in memory.
|
||||||
|
|
||||||
Yields one page of rows at a time (as a list).
|
Returns all pages as a list of lists.
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
|
@ -111,3 +111,50 @@ async def async_execute_paged(session, query, parameters=None, fetch_size=100):
|
||||||
return await loop.run_in_executor(
|
return await loop.run_in_executor(
|
||||||
None, _fetch_all_pages
|
None, _fetch_all_pages
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_scan(
|
||||||
|
session, query, parameters=None, row_filter=None,
|
||||||
|
limit=None, fetch_size=5000,
|
||||||
|
):
|
||||||
|
"""Scan a CQL query page-by-page, applying a filter and limit.
|
||||||
|
|
||||||
|
Only matching rows accumulate in memory. Each page is discarded
|
||||||
|
after processing, so peak memory is bounded by fetch_size plus
|
||||||
|
the number of matching rows (capped by limit).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: cassandra.cluster.Session
|
||||||
|
query: CQL statement string
|
||||||
|
parameters: bind params
|
||||||
|
row_filter: callable(row) -> bool, or None to accept all
|
||||||
|
limit: max results to return, or None for unlimited
|
||||||
|
fetch_size: rows per Cassandra page fetch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching rows.
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
if isinstance(query, str):
|
||||||
|
stmt = SimpleStatement(query, fetch_size=fetch_size)
|
||||||
|
else:
|
||||||
|
stmt = query
|
||||||
|
stmt.fetch_size = fetch_size
|
||||||
|
|
||||||
|
def _scan():
|
||||||
|
results = []
|
||||||
|
result_set = session.execute(stmt, parameters)
|
||||||
|
while True:
|
||||||
|
for row in result_set.current_rows:
|
||||||
|
if row_filter is None or row_filter(row):
|
||||||
|
results.append(row)
|
||||||
|
if limit and len(results) >= limit:
|
||||||
|
return results
|
||||||
|
if result_set.has_more_pages:
|
||||||
|
result_set.fetch_next_page()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return results
|
||||||
|
|
||||||
|
return await loop.run_in_executor(None, _scan)
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,49 +1,110 @@
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from websockets.asyncio.client import connect
|
from websockets.asyncio.client import connect
|
||||||
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import hashlib
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _token_key(token):
|
||||||
|
"""Derive a dict key from a token without storing the raw secret."""
|
||||||
|
return hashlib.sha256(token.encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
class WebSocketManager:
|
class WebSocketManager:
|
||||||
|
"""Manages an authenticated WebSocket connection to the TrustGraph
|
||||||
|
gateway on behalf of a single caller.
|
||||||
|
|
||||||
def __init__(self, url, token=None):
|
Each caller token gets its own WebSocketManager so that gateway-side
|
||||||
|
identity, workspace, and capability scoping are preserved end-to-end.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, url, token):
|
||||||
self.url = url
|
self.url = url
|
||||||
|
# ── Security boundary: token storage ──
|
||||||
|
# This is the MCP caller's Bearer token, forwarded verbatim to
|
||||||
|
# the gateway. It MUST NOT be logged, persisted, or shared
|
||||||
|
# across callers. It is held only for the lifetime of this
|
||||||
|
# connection so that re-auth (e.g. after a reconnect) is
|
||||||
|
# possible.
|
||||||
self.token = token
|
self.token = token
|
||||||
self.socket = None
|
self.socket = None
|
||||||
|
self.identity = None
|
||||||
# FIXME: authentication is broken. The /api/v1/socket endpoint uses
|
self.last_used = None
|
||||||
# in-band auth (first-frame protocol via the Mux dispatcher), not
|
|
||||||
# query-parameter tokens. This query-string token is silently ignored.
|
|
||||||
# Fix: after connect(), send an auth frame with the bearer token as
|
|
||||||
# the first message, matching the gateway's in-band auth protocol.
|
|
||||||
def _build_url(self):
|
|
||||||
if not self.token:
|
|
||||||
return self.url
|
|
||||||
parsed = urlparse(self.url)
|
|
||||||
params = parse_qs(parsed.query)
|
|
||||||
params["token"] = [self.token]
|
|
||||||
new_query = urlencode(params, doseq=True)
|
|
||||||
return urlunparse(parsed._replace(query=new_query))
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.socket = await connect(self._build_url())
|
"""Connect and authenticate via the gateway's in-band auth
|
||||||
|
protocol. Raises on auth failure."""
|
||||||
|
|
||||||
|
# ── Security boundary: MCP server → gateway ──
|
||||||
|
# The WebSocket connects to the gateway and authenticates using
|
||||||
|
# the caller's Bearer token via the in-band first-frame auth
|
||||||
|
# protocol. The token belongs to the MCP client — we forward
|
||||||
|
# it as-is and never interpret its contents.
|
||||||
|
self.socket = await connect(self.url)
|
||||||
self.pending_requests = {}
|
self.pending_requests = {}
|
||||||
self.running = True
|
self.running = True
|
||||||
|
|
||||||
|
await self._authenticate()
|
||||||
|
|
||||||
self.reader_task = asyncio.create_task(self.reader())
|
self.reader_task = asyncio.create_task(self.reader())
|
||||||
|
|
||||||
|
async def _authenticate(self):
|
||||||
|
"""Send in-band auth frame and wait for auth-ok / auth-failed.
|
||||||
|
|
||||||
|
The gateway expects ``{"type": "auth", "token": "..."}`` as the
|
||||||
|
first frame on a new WebSocket. Any service frame sent before
|
||||||
|
auth-ok is rejected.
|
||||||
|
"""
|
||||||
|
await self.socket.send(json.dumps({
|
||||||
|
"type": "auth",
|
||||||
|
"token": self.token,
|
||||||
|
}))
|
||||||
|
|
||||||
|
response_text = await asyncio.wait_for(self.socket.recv(), 10)
|
||||||
|
response = json.loads(response_text)
|
||||||
|
|
||||||
|
if response.get("type") == "auth-ok":
|
||||||
|
logger.info(
|
||||||
|
"WebSocket authenticated, default workspace: %s",
|
||||||
|
response.get("workspace"),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Auth failed — close immediately, do not leave an
|
||||||
|
# unauthenticated socket open.
|
||||||
|
await self.socket.close()
|
||||||
|
self.socket = None
|
||||||
|
|
||||||
|
if response.get("type") == "auth-failed":
|
||||||
|
raise RuntimeError(
|
||||||
|
"Gateway rejected the authentication token"
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unexpected auth response type: {response.get('type')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def whoami(self):
|
||||||
|
"""Verify the token by calling the gateway's whoami endpoint.
|
||||||
|
Returns the identity dict and caches it on ``self.identity``.
|
||||||
|
"""
|
||||||
|
gen = self.request("iam", {"operation": "whoami"}, flow_id=None)
|
||||||
|
async for response in gen:
|
||||||
|
self.identity = response
|
||||||
|
return response
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
await self.reader_task
|
if hasattr(self, "reader_task"):
|
||||||
|
await self.reader_task
|
||||||
|
|
||||||
async def reader(self):
|
async def reader(self):
|
||||||
"""
|
"""Background task: read WebSocket frames and route them to the
|
||||||
Background task to read websocket responses and route to correct
|
correct pending-request queue by ``id``."""
|
||||||
request
|
|
||||||
"""
|
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
|
|
@ -59,23 +120,21 @@ class WebSocketManager:
|
||||||
|
|
||||||
request_id = response.get("id")
|
request_id = response.get("id")
|
||||||
if request_id and request_id in self.pending_requests:
|
if request_id and request_id in self.pending_requests:
|
||||||
# Put the response in the queue
|
|
||||||
queue = self.pending_requests[request_id]
|
queue = self.pending_requests[request_id]
|
||||||
await queue.put(response)
|
await queue.put(response)
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
logger.warning(
|
||||||
f"Response for unknown request ID: {request_id}"
|
"Response for unknown request ID: %s", request_id
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
logging.error(f"Error in websocket reader: {e}")
|
logger.error("Error in websocket reader: %s", e)
|
||||||
|
|
||||||
# Put error in all pending queues
|
|
||||||
for queue in self.pending_requests.values():
|
for queue in self.pending_requests.values():
|
||||||
try:
|
try:
|
||||||
await queue.put({"error": str(e)})
|
await queue.put({"error": str(e)})
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.pending_requests.clear()
|
self.pending_requests.clear()
|
||||||
|
|
@ -86,25 +145,29 @@ class WebSocketManager:
|
||||||
|
|
||||||
async def request(
|
async def request(
|
||||||
self, service, request_data, flow_id="default",
|
self, service, request_data, flow_id="default",
|
||||||
|
workspace=None,
|
||||||
):
|
):
|
||||||
"""
|
"""Send a request via WebSocket and yield responses.
|
||||||
Send a request via websocket and handle single or streaming responses
|
|
||||||
|
Args:
|
||||||
|
service: Gateway service name (e.g. "graph-rag", "config").
|
||||||
|
request_data: Inner request payload.
|
||||||
|
flow_id: Optional flow identifier. ``None`` omits the field
|
||||||
|
(workspace-level services don't use flows).
|
||||||
|
workspace: Optional workspace override. When ``None`` the
|
||||||
|
gateway uses the caller's default workspace.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Generate unique request ID
|
import time
|
||||||
|
self.last_used = time.monotonic()
|
||||||
|
|
||||||
request_id = f"{uuid.uuid4()}"
|
request_id = f"{uuid.uuid4()}"
|
||||||
|
|
||||||
# Determine if this service streams responses
|
|
||||||
streaming_services = {"agent"}
|
|
||||||
is_streaming = service in streaming_services
|
|
||||||
|
|
||||||
# Create a queue for all responses (streaming and single)
|
|
||||||
response_queue = asyncio.Queue()
|
response_queue = asyncio.Queue()
|
||||||
self.pending_requests[request_id] = response_queue
|
self.pending_requests[request_id] = response_queue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# Build request message
|
|
||||||
message = {
|
message = {
|
||||||
"id": request_id,
|
"id": request_id,
|
||||||
"service": service,
|
"service": service,
|
||||||
|
|
@ -114,7 +177,16 @@ class WebSocketManager:
|
||||||
if flow_id is not None:
|
if flow_id is not None:
|
||||||
message["flow"] = flow_id
|
message["flow"] = flow_id
|
||||||
|
|
||||||
# Send request
|
# ── Security boundary: workspace scoping ──
|
||||||
|
# When the caller supplies a workspace, we set it on the
|
||||||
|
# message envelope. The gateway's enforce_workspace()
|
||||||
|
# validates that the authenticated identity is permitted
|
||||||
|
# to access the target workspace — we MUST NOT skip or
|
||||||
|
# override that check. When workspace is None, the
|
||||||
|
# gateway default-fills from the identity's bound workspace.
|
||||||
|
if workspace is not None:
|
||||||
|
message["workspace"] = workspace
|
||||||
|
|
||||||
await self.socket.send(json.dumps(message))
|
await self.socket.send(json.dumps(message))
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
|
|
@ -127,19 +199,17 @@ class WebSocketManager:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "error" in response:
|
if "error" in response:
|
||||||
if "message" in response["error"]:
|
if isinstance(response["error"], dict):
|
||||||
raise RuntimeError(response["error"]["text"])
|
raise RuntimeError(
|
||||||
|
response["error"].get("message", str(response["error"]))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(str(response["error"]))
|
raise RuntimeError(str(response["error"]))
|
||||||
|
|
||||||
yield response["response"]
|
yield response["response"]
|
||||||
|
|
||||||
if "complete" in response:
|
if response.get("complete"):
|
||||||
if response["complete"]:
|
break
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
finally:
|
||||||
# Clean up on error
|
|
||||||
self.pending_requests.pop(request_id, None)
|
self.pending_requests.pop(request_id, None)
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue