release/v2.4 -> master (#924)

* CLI auth migration, document embeddings core lifecycle (#913)

Migrate get_kg_core and put_kg_core CLI tools to use Api/SocketClient
with first-frame auth (fixes broken raw websocket path). Fix wire
format field names (root/vector). Remove ~600 lines of dead raw
websocket code from invoke_graph_rag.py.
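
A minimal sketch of the migrated client path, using the Api/SocketClient
interface added in this change; auth rides in the first WebSocket frame:

  from trustgraph.api import Api

  api = Api(url="http://localhost:8088/", token=None, workspace="default")
  socket = api.socket()
  try:
      for response in socket.get_kg_core("my-core"):
          ...  # each response is a decoded message dict; stream ends at eos
  finally:
      socket.close()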

Add document embeddings core lifecycle to the knowledge service:
list/get/put/delete/load operations across schema, translator,
Cassandra table store, knowledge manager, gateway registry, REST API,
socket client, and CLI (tg-get-de-core, tg-put-de-core).

Fix delete_kg_core to also clean up document embeddings rows.
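
A hedged usage sketch of the new lifecycle calls on the Knowledge API
client added here; how the client instance is obtained from Api is an
assumption:

  # 'knowledge' is the Knowledge client instance (accessor assumed)
  ids = knowledge.list_de_cores()           # enumerate stored DE cores
  knowledge.load_de_core(
      "my-doc-embeddings",                  # replay a core into a flow
      flow="default", collection="default",
  )
  knowledge.delete_de_core("stale-core")    # remove a DE core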

* Remove spurious workspace parameter from SPARQL algebra evaluator (#915)

Fix the threading of the workspace parameter:
- The SPARQL algebra evaluator was threading a workspace parameter
  through every function and passing it to TriplesClient.query(),
  which doesn't accept it. Workspace isolation is handled by pub/sub
  topic routing — the TriplesClient is already scoped to a
  workspace-specific flow, same as GraphRAG. Passing workspace
  explicitly was both incorrect and unnecessary (see the sketch below).
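
The corrected leaf call, mirroring the new unit tests (variable names
are illustrative): only the pattern terms, limit, and collection are
passed, never workspace.

  results = await tc.query(
      s=subject, p=predicate, o=obj,
      limit=limit, collection=collection,
  )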

Update tests:
- tests/unit/test_query/test_sparql_algebra.py (new) — Tests
  _query_pattern, _eval_bgp, and evaluate() with various algebra
  nodes. Key tests assert workspace is never in tc.query() kwargs,
  plus correctness tests for BGP, JOIN, UNION, SLICE, DISTINCT, and
  edge cases.
- tests/unit/test_retrieval/test_graph_rag.py — Added
  test_triples_query_never_passes_workspace (checks query()) and
  test_follow_edges_never_passes_workspace (checks query_stream()).

* Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916)

Cassandra triples services were using synchronous EntityCentricKnowledgeGraph
methods from async contexts, and connection state was managed with
threading.local which is wrong for asyncio coroutines sharing a single
thread. Qdrant services had no async wrapping at all, blocking the event
loop on every network call. Rows services had unprotected shared state
mutations across concurrent coroutines.

- Add async methods to EntityCentricKnowledgeGraph (async_insert,
  async_get_s/p/o/sp/po/os/spo/all, async_collection_exists,
  async_create_collection, async_delete_collection) using the existing
  cassandra_async.async_execute bridge
- Rewrite triples write + query services: replace threading.local with
  asyncio.Lock + dict cache for per-workspace connections, use async
  ECKG methods for all data operations, keep asyncio.to_thread only for
  one-time blocking ECKG construction (see the sketch after this list)
- Wrap all Qdrant calls in asyncio.to_thread across all 6 services
  (doc/graph/row embeddings write + query), add asyncio.Lock + set cache
  for collection existence checks
- Add asyncio.Lock to rows write + query services to protect shared
  state (schemas, sessions, config caches) from concurrent mutation
- Update all affected tests to match new async patterns
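
A minimal sketch of the connection-cache pattern, assuming the attribute
names the updated tests check (_connections, _conn_lock); not the
verbatim service code:

  import asyncio

  async def get_connection(self, workspace):
      # the lock guards the cache against concurrent coroutines
      async with self._conn_lock:
          if workspace not in self._connections:
              # one-time blocking ECKG construction runs in a worker thread;
              # EntityCentricKnowledgeGraph comes from trustgraph.direct
              self._connections[workspace] = await asyncio.to_thread(
                  EntityCentricKnowledgeGraph,
                  hosts=self.cassandra_host, keyspace=workspace,
              )
      return self._connections[workspace]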

* Fix bug where only one page of results was returned (#921)

The root cause: async_execute only materialises the first result
page (by design — it says so in its docstring). The streaming query
set fetch_size=20 and expected to iterate all results, but only got
the first 20 rows back.

The fix uses
  asyncio.to_thread(lambda: list(tg.session.execute(...)))
which lets the sync driver iterate all pages in a worker thread — exactly
what the pre-async code did.
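
A sketch of the full pattern; tg.session and stmt stand in for the
service's Cassandra session and statement:

  import asyncio

  # materialise every result page in a worker thread, then hand the
  # complete list back to the event loop
  rows = await asyncio.to_thread(lambda: list(tg.session.execute(stmt)))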

* Optional test warning suppression (#923)

* Fix test collection module errors & silence upstream Pytest warnings (#823)

* chore: add virtual environment and .env directories to gitignore

* test: filter upstream DeprecationWarning and UserWarning messages

* fix(namespace): remove empty __init__.py files to fix PEP 420 implicit namespace resolution for trustgraph sub-packages

* Revert __init__.py deletions

* Add commented-out .ini changes; they will be useful at times

---------

Co-authored-by: Salil M <d2kyt@protonmail.com>
cybermaggedon 2026-05-15 13:02:51 +01:00 committed by GitHub
parent 159b1e2824
commit 142dd0231c
42 changed files with 1910 additions and 1492 deletions

.gitignore

@ -17,3 +17,6 @@ trustgraph-unstructured/trustgraph/unstructured_version.py
trustgraph-mcp/trustgraph/mcp_version.py
trustgraph/trustgraph/trustgraph_version.py
vertexai/
venv/
.venv/
.env


@ -188,13 +188,14 @@ class TestConfigurationPriorityEndToEnd:
)
@pytest.mark.asyncio
@patch('trustgraph.direct.cassandra_kg.Cluster')
async def test_no_config_defaults_end_to_end(self, mock_cluster):
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
async def test_no_config_defaults_end_to_end(self, mock_kg_class):
"""Test that defaults are used when no configuration provided end-to-end."""
mock_cluster_instance = MagicMock()
mock_session = MagicMock()
mock_cluster_instance.connect.return_value = mock_session
mock_cluster.return_value = mock_cluster_instance
from unittest.mock import AsyncMock
mock_tg_instance = MagicMock()
mock_tg_instance.async_get_all = AsyncMock(return_value=[])
mock_kg_class.return_value = mock_tg_instance
with patch.dict(os.environ, {}, clear=True):
processor = TriplesQuery(taskgroup=MagicMock())
@ -205,20 +206,16 @@ class TestConfigurationPriorityEndToEnd:
mock_query.s = None
mock_query.p = None
mock_query.o = None
mock_query.g = None
mock_query.limit = 100
# Mock the get_all method to return empty list
mock_tg_instance = MagicMock()
mock_tg_instance.get_all.return_value = []
processor.tg = mock_tg_instance
await processor.query_triples('default_user', mock_query)
# Should use defaults
mock_cluster.assert_called_once()
call_args = mock_cluster.call_args
assert call_args.args[0] == ['cassandra'] # Default host
assert 'auth_provider' not in call_args.kwargs # No auth with default config
mock_kg_class.assert_called_once_with(
hosts=['cassandra'],
keyspace='default_user'
)
class TestNoBackwardCompatibilityEndToEnd:


@ -101,6 +101,8 @@ class TestRowsCassandraIntegration:
processor.session = None
# Bind actual methods from the new unified table implementation
import asyncio
processor._setup_lock = asyncio.Lock()
processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)
@ -108,6 +110,7 @@ class TestRowsCassandraIntegration:
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
processor.on_object = Processor.on_object.__get__(processor, Processor)
processor.collection_exists = MagicMock(return_value=True)


@ -184,7 +184,7 @@ class TestObjectsGraphQLQueryIntegration:
await processor.on_schema_config("default", sample_schema_config, version=1)
# Connect to Cassandra
processor.connect_cassandra()
await processor.connect_cassandra()
assert processor.session is not None
# Create test keyspace and table
@ -219,7 +219,7 @@ class TestObjectsGraphQLQueryIntegration:
"""Test inserting data and querying via GraphQL"""
# Load schema and connect
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
await processor.connect_cassandra()
# Setup test data
keyspace = "test_user"
@ -293,7 +293,7 @@ class TestObjectsGraphQLQueryIntegration:
"""Test GraphQL queries with filtering on indexed fields"""
# Setup (reuse previous setup)
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
await processor.connect_cassandra()
keyspace = "test_user"
collection = "filter_test"
@ -387,7 +387,7 @@ class TestObjectsGraphQLQueryIntegration:
"""Test full message processing workflow"""
# Setup
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
await processor.connect_cassandra()
# Create mock message
request = RowsQueryRequest(
@ -433,7 +433,7 @@ class TestObjectsGraphQLQueryIntegration:
"""Test handling multiple concurrent GraphQL queries"""
# Setup
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
await processor.connect_cassandra()
# Create multiple query tasks
queries = [
@ -519,7 +519,7 @@ class TestObjectsGraphQLQueryIntegration:
"""Test handling of large query result sets"""
# Setup
await processor.on_schema_config("default", sample_schema_config, version=1)
processor.connect_cassandra()
await processor.connect_cassandra()
keyspace = "large_test_user"
collection = "large_collection"


@ -17,3 +17,12 @@ markers =
contract: marks tests as contract tests (service interface validation)
vertexai: marks tests as vertex ai specific tests
asyncio: marks tests that use asyncio
# This is helpful if you're tired of deprecation warnings. I prefer to
# keep the warnings for now; it avoids masking problems.
#
# filterwarnings =
# ignore:Core Pydantic V1 functionality isn't compatible with Python 3.14.*:UserWarning
# ignore:builtin type SwigPyPacked has no __module__ attribute:DeprecationWarning
# ignore:builtin type SwigPyObject has no __module__ attribute:DeprecationWarning
# ignore:builtin type swigvarlink has no __module__ attribute:DeprecationWarning
# ignore:.*_UnionGenericAlias.*is deprecated and slated for removal in Python 3.17:DeprecationWarning


@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic:
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configuration"""
import asyncio
processor = MagicMock()
processor.schemas = {}
processor.schema_builders = {}
processor.graphql_schemas = {}
processor.config_key = "schema"
processor.query_cassandra = MagicMock()
processor._setup_lock = asyncio.Lock()
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test config
@ -335,7 +338,7 @@ class TestUnifiedTableQueries:
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.connect_cassandra = AsyncMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
@ -396,7 +399,7 @@ class TestUnifiedTableQueries:
"""Test query execution without matching index (scan mode)"""
processor = MagicMock()
processor.session = MagicMock()
processor.connect_cassandra = MagicMock()
processor.connect_cassandra = AsyncMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)


@ -0,0 +1,302 @@
"""
Tests for the SPARQL algebra evaluator.
Verifies that evaluate() and _query_pattern() call TriplesClient.query()
with the correct arguments, and in particular that workspace is never
passed; workspace isolation is handled by pub/sub topic routing.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, call
from rdflib.term import Variable, URIRef, Literal
from rdflib.plugins.sparql.parserutils import CompValue
from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.query.sparql.algebra import (
evaluate, _query_pattern, _eval_bgp,
)
# --- Helpers ---
def iri(v):
return Term(type=IRI, iri=v)
def lit(v):
return Term(type=LITERAL, value=v)
def make_triple(s, p, o):
t = MagicMock()
t.s = s
t.p = p
t.o = o
return t
def make_bgp(*patterns):
"""Build a CompValue BGP node from (s, p, o) tuples of rdflib terms."""
node = CompValue("BGP")
node.triples = list(patterns)
return node
def make_project(inner, variables):
node = CompValue("Project")
node.p = inner
node.PV = [Variable(v) for v in variables]
return node
def make_select(inner):
node = CompValue("SelectQuery")
node.p = inner
return node
def make_join(left, right):
node = CompValue("Join")
node.p1 = left
node.p2 = right
return node
def make_union(left, right):
node = CompValue("Union")
node.p1 = left
node.p2 = right
return node
def make_slice(inner, start, length):
node = CompValue("Slice")
node.p = inner
node.start = start
node.length = length
return node
def make_distinct(inner):
node = CompValue("Distinct")
node.p = inner
return node
class TestQueryPattern:
"""Tests for _query_pattern — the leaf that calls TriplesClient."""
@pytest.mark.asyncio
async def test_passes_correct_args(self):
tc = AsyncMock()
tc.query.return_value = []
await _query_pattern(
tc,
s=iri("http://example.com/s"),
p=iri("http://example.com/p"),
o=None,
collection="my-collection",
limit=100,
)
tc.query.assert_called_once_with(
s=iri("http://example.com/s"),
p=iri("http://example.com/p"),
o=None,
limit=100,
collection="my-collection",
)
@pytest.mark.asyncio
async def test_workspace_not_passed(self):
tc = AsyncMock()
tc.query.return_value = []
await _query_pattern(tc, None, None, None, "default", 10)
kwargs = tc.query.call_args.kwargs
assert "workspace" not in kwargs
@pytest.mark.asyncio
async def test_returns_query_results(self):
tc = AsyncMock()
triple = make_triple(iri("http://a"), iri("http://b"), lit("c"))
tc.query.return_value = [triple]
results = await _query_pattern(tc, None, None, None, "default", 10)
assert len(results) == 1
assert results[0].s.iri == "http://a"
class TestEvalBgp:
"""Tests for BGP evaluation — triple pattern queries."""
@pytest.mark.asyncio
async def test_single_pattern_all_variables(self):
tc = AsyncMock()
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
tc.query.return_value = [triple]
bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")),
)
solutions = await evaluate(bgp, tc, collection="default", limit=100)
assert len(solutions) == 1
assert solutions[0]["s"].iri == "http://s"
assert solutions[0]["p"].iri == "http://p"
assert solutions[0]["o"].value == "o"
@pytest.mark.asyncio
async def test_single_pattern_bound_subject(self):
tc = AsyncMock()
tc.query.return_value = [
make_triple(iri("http://s"), iri("http://p"), lit("val")),
]
bgp = make_bgp(
(URIRef("http://s"), Variable("p"), Variable("o")),
)
solutions = await evaluate(bgp, tc, collection="default")
tc.query.assert_called_once()
kwargs = tc.query.call_args.kwargs
assert "workspace" not in kwargs
assert kwargs["collection"] == "default"
@pytest.mark.asyncio
async def test_empty_bgp_returns_empty_solution(self):
tc = AsyncMock()
bgp = make_bgp()
solutions = await evaluate(bgp, tc, collection="default")
assert solutions == [{}]
tc.query.assert_not_called()
@pytest.mark.asyncio
async def test_no_results_returns_empty(self):
tc = AsyncMock()
tc.query.return_value = []
bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")),
)
solutions = await evaluate(bgp, tc, collection="default")
assert solutions == []
class TestEvaluate:
"""Tests for the top-level evaluate() dispatcher."""
@pytest.mark.asyncio
async def test_select_query_node(self):
tc = AsyncMock()
tc.query.return_value = [
make_triple(iri("http://s"), iri("http://p"), lit("o")),
]
bgp = make_bgp(
(Variable("s"), Variable("p"), Variable("o")),
)
select = make_select(make_project(bgp, ["s", "p"]))
solutions = await evaluate(select, tc, collection="default")
assert len(solutions) == 1
assert "s" in solutions[0]
assert "p" in solutions[0]
assert "o" not in solutions[0]
@pytest.mark.asyncio
async def test_workspace_never_in_query_calls(self):
"""Verify that no matter the algebra structure, workspace is never
passed to TriplesClient.query()."""
tc = AsyncMock()
tc.query.return_value = [
make_triple(iri("http://s"), iri("http://p"), lit("o")),
]
bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o")))
bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c")))
tree = make_select(make_project(
make_union(bgp1, bgp2), ["s", "p", "o"]
))
await evaluate(tree, tc, collection="test-coll")
for c in tc.query.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_join(self):
tc = AsyncMock()
tc.query.side_effect = [
[make_triple(iri("http://a"), iri("http://p"), lit("v"))],
[make_triple(iri("http://a"), iri("http://q"), lit("w"))],
]
bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1")))
bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2")))
tree = make_join(bgp1, bgp2)
solutions = await evaluate(tree, tc, collection="default")
assert len(solutions) == 1
assert solutions[0]["s"].iri == "http://a"
@pytest.mark.asyncio
async def test_slice(self):
tc = AsyncMock()
triples = [
make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}"))
for i in range(5)
]
tc.query.return_value = triples
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
tree = make_slice(bgp, start=1, length=2)
solutions = await evaluate(tree, tc, collection="default")
assert len(solutions) == 2
@pytest.mark.asyncio
async def test_distinct(self):
tc = AsyncMock()
triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
tc.query.return_value = [triple, triple]
bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
tree = make_distinct(bgp)
solutions = await evaluate(tree, tc, collection="default")
assert len(solutions) == 1
@pytest.mark.asyncio
async def test_unsupported_node_returns_empty_solution(self):
tc = AsyncMock()
node = CompValue("SomethingUnknown")
solutions = await evaluate(node, tc, collection="default")
assert solutions == [{}]
tc.query.assert_not_called()
@pytest.mark.asyncio
async def test_non_compvalue_returns_empty_solution(self):
tc = AsyncMock()
solutions = await evaluate("not a node", tc, collection="default")
assert solutions == [{}]


@ -2,8 +2,10 @@
Tests for Cassandra triples query service
"""
import asyncio
import pytest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, AsyncMock
from trustgraph.query.triples.cassandra.service import Processor, create_term
from trustgraph.schema import Term, IRI, LITERAL
@ -18,7 +20,7 @@ class TestCassandraQueryProcessor:
return Processor(
taskgroup=MagicMock(),
id='test-cassandra-query',
graph_host='localhost'
cassandra_host='localhost'
)
def test_create_term_with_http_uri(self, processor):
@ -85,7 +87,7 @@ class TestCassandraQueryProcessor:
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
processor = Processor(
taskgroup=MagicMock(),
@ -110,8 +112,8 @@ class TestCassandraQueryProcessor:
keyspace='test_user'
)
# Verify get_spo was called with correct parameters
mock_tg_instance.get_spo.assert_called_once_with(
# Verify async_get_spo was called with correct parameters
mock_tg_instance.async_get_spo.assert_called_once_with(
'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
)
@ -130,7 +132,8 @@ class TestCassandraQueryProcessor:
assert processor.cassandra_host == ['cassandra'] # Updated default
assert processor.cassandra_username is None
assert processor.cassandra_password is None
assert processor.table is None
assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
def test_processor_initialization_with_custom_params(self):
"""Test processor initialization with custom parameters"""
@ -146,7 +149,8 @@ class TestCassandraQueryProcessor:
assert processor.cassandra_host == ['cassandra.example.com']
assert processor.cassandra_username == 'queryuser'
assert processor.cassandra_password == 'querypass'
assert processor.table is None
assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
@ -164,7 +168,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_sp.return_value = [mock_result]
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -178,7 +182,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
assert len(result) == 1
assert result[0].s.iri == 'test_subject'
assert result[0].p.iri == 'test_predicate'
@ -200,7 +204,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_s.return_value = [mock_result]
mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -214,7 +218,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
assert len(result) == 1
assert result[0].s.iri == 'test_subject'
assert result[0].p.iri == 'result_predicate'
@ -236,7 +240,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_p.return_value = [mock_result]
mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -250,7 +254,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
assert len(result) == 1
assert result[0].s.iri == 'result_subject'
assert result[0].p.iri == 'test_predicate'
@ -272,7 +276,7 @@ class TestCassandraQueryProcessor:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_o.return_value = [mock_result]
mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -286,7 +290,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
assert len(result) == 1
assert result[0].s.iri == 'result_subject'
assert result[0].p.iri == 'result_predicate'
@ -305,11 +309,11 @@ class TestCassandraQueryProcessor:
mock_result.s = 'all_subject'
mock_result.p = 'all_predicate'
mock_result.o = 'all_object'
mock_result.g = ''
mock_result.d = ''
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_all.return_value = [mock_result]
mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -323,7 +327,7 @@ class TestCassandraQueryProcessor:
result = await processor.query_triples('test_user', query)
mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000)
assert len(result) == 1
assert result[0].s.iri == 'all_subject'
assert result[0].p.iri == 'all_predicate'
@ -410,7 +414,7 @@ class TestCassandraQueryProcessor:
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
processor = Processor(
taskgroup=MagicMock(),
@ -451,7 +455,7 @@ class TestCassandraQueryProcessor:
mock_result.dtype = None
mock_result.lang = None
mock_result.o = 'test_object'
mock_tg_instance.get_spo.return_value = [mock_result]
mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -489,8 +493,8 @@ class TestCassandraQueryProcessor:
mock_result.lang = None
mock_result.p = 'p'
mock_result.o = 'o'
mock_tg_instance1.get_s.return_value = [mock_result]
mock_tg_instance2.get_s.return_value = [mock_result]
mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result])
mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -504,7 +508,6 @@ class TestCassandraQueryProcessor:
)
await processor.query_triples('user1', query1)
assert processor.table == 'user1'
# Second query with different table
query2 = TriplesQueryRequest(
@ -516,10 +519,11 @@ class TestCassandraQueryProcessor:
)
await processor.query_triples('user2', query2)
assert processor.table == 'user2'
# Verify TrustGraph was created twice
# Verify TrustGraph was created twice for different workspaces
assert mock_kg_class.call_count == 2
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
@pytest.mark.asyncio
@patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
@ -529,7 +533,7 @@ class TestCassandraQueryProcessor:
mock_tg_instance = MagicMock()
mock_kg_class.return_value = mock_tg_instance
mock_tg_instance.get_spo.side_effect = Exception("Query failed")
mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed"))
processor = Processor(taskgroup=MagicMock())
@ -566,7 +570,7 @@ class TestCassandraQueryProcessor:
mock_result2.otype = None
mock_result2.dtype = None
mock_result2.lang = None
mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2])
processor = Processor(taskgroup=MagicMock())
@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_po.return_value = [mock_result]
mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations:
result = await processor.query_triples('test_user', query)
# Verify get_po was called (should use optimized po_table)
mock_tg_instance.get_po.assert_called_once_with(
# Verify async_get_po was called (should use optimized po_table)
mock_tg_instance.async_get_po.assert_called_once_with(
'test_collection', 'test_predicate', 'test_object', g=None, limit=50
)
@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations:
mock_result.otype = None
mock_result.dtype = None
mock_result.lang = None
mock_tg_instance.get_os.return_value = [mock_result]
mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result])
processor = Processor(taskgroup=MagicMock())
@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations:
result = await processor.query_triples('test_user', query)
# Verify get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.get_os.assert_called_once_with(
# Verify async_get_os was called (should use optimized subject_table with clustering)
mock_tg_instance.async_get_os.assert_called_once_with(
'test_collection', 'test_object', 'test_subject', g=None, limit=25
)
@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations:
mock_kg_class.return_value = mock_tg_instance
# Mock empty results for all queries
mock_tg_instance.get_all.return_value = []
mock_tg_instance.get_s.return_value = []
mock_tg_instance.get_p.return_value = []
mock_tg_instance.get_o.return_value = []
mock_tg_instance.get_sp.return_value = []
mock_tg_instance.get_po.return_value = []
mock_tg_instance.get_os.return_value = []
mock_tg_instance.get_spo.return_value = []
mock_tg_instance.async_get_all = AsyncMock(return_value=[])
mock_tg_instance.async_get_s = AsyncMock(return_value=[])
mock_tg_instance.async_get_p = AsyncMock(return_value=[])
mock_tg_instance.async_get_o = AsyncMock(return_value=[])
mock_tg_instance.async_get_sp = AsyncMock(return_value=[])
mock_tg_instance.async_get_po = AsyncMock(return_value=[])
mock_tg_instance.async_get_os = AsyncMock(return_value=[])
mock_tg_instance.async_get_spo = AsyncMock(return_value=[])
processor = Processor(taskgroup=MagicMock())
# Test each query pattern
test_patterns = [
# (s, p, o, expected_method)
(None, None, None, 'get_all'), # All triples
('s1', None, None, 'get_s'), # Subject only
(None, 'p1', None, 'get_p'), # Predicate only
(None, None, 'o1', 'get_o'), # Object only
('s1', 'p1', None, 'get_sp'), # Subject + Predicate
(None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
('s1', None, 'o1', 'get_os'), # Object + Subject
('s1', 'p1', 'o1', 'get_spo'), # All three
(None, None, None, 'async_get_all'), # All triples
('s1', None, None, 'async_get_s'), # Subject only
(None, 'p1', None, 'async_get_p'), # Predicate only
(None, None, 'o1', 'async_get_o'), # Object only
('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate
(None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION)
('s1', None, 'o1', 'async_get_os'), # Object + Subject
('s1', 'p1', 'o1', 'async_get_spo'), # All three
]
for s, p, o, expected_method in test_patterns:
@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations:
mock_result.lang = None
mock_results.append(mock_result)
mock_tg_instance.get_po.return_value = mock_results
mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results)
processor = Processor(taskgroup=MagicMock())
@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations:
result = await processor.query_triples('large_dataset_user', query)
# Verify optimized get_po was used (no ALLOW FILTERING needed!)
mock_tg_instance.get_po.assert_called_once_with(
# Verify optimized async_get_po was used (no ALLOW FILTERING needed!)
mock_tg_instance.async_get_po.assert_called_once_with(
'massive_collection',
'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
'http://example.com/Person',


@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection:
@pytest.mark.asyncio
async def test_valid_embedding_upserted(self):
import asyncio
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock()
msg.metadata.collection = "col1"
@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection:
@pytest.mark.asyncio
async def test_dimension_in_collection_name(self):
"""Collection name should include vector dimension."""
import asyncio
from trustgraph.storage.doc_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock()
msg.metadata.collection = "docs"
@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection:
@pytest.mark.asyncio
async def test_valid_entity_and_vector_upserted(self):
import asyncio
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = True
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock()
msg.metadata.collection = "col1"
@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection:
@pytest.mark.asyncio
async def test_lazy_collection_creation_on_new_dimension(self):
import asyncio
from trustgraph.storage.graph_embeddings.qdrant.write import Processor
proc = Processor.__new__(Processor)
proc.qdrant = MagicMock()
proc.qdrant.collection_exists.return_value = False
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
msg = MagicMock()
msg.metadata.collection = "graphs"


@ -337,6 +337,57 @@ class TestQuery:
cache_key = "test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@pytest.mark.asyncio
async def test_triples_query_never_passes_workspace(self):
"""Workspace isolation is handled by pub/sub topic routing, not
by passing workspace to TriplesClient.query(). Verify that
GraphRAG never passes workspace as a keyword argument."""
mock_rag = MagicMock()
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_triple = MagicMock()
mock_triple.o = "Label"
mock_triples_client.query.return_value = [mock_triple]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False
)
await query.maybe_label("http://example.com/entity")
for c in mock_triples_client.query.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_never_passes_workspace(self):
"""Verify follow_edges never passes workspace to query_stream."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_triple = MagicMock()
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
mock_triples_client.query_stream.return_value = [mock_triple]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
)
subgraph = set()
await query.follow_edges("e1", subgraph, path_length=1)
for c in mock_triples_client.query_stream.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
"""Test Query.follow_edges method basic triple discovery"""


@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Assert
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
# Verify collection existence is checked on each write
mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
# Second write uses cached collection state — no collection_exists check
mock_qdrant_instance.collection_exists.assert_not_called()
# But upsert should still be called
mock_qdrant_instance.upsert.assert_called_once()


@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
processor.ensure_collection("test_collection", 384)
await processor.ensure_collection("test_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
mock_qdrant_instance.create_collection.assert_called_once()
# Verify the collection is cached
assert "test_collection" in processor.created_collections
assert "test_collection" in processor._known_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_skips_existing(self, mock_qdrant_client):
@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
processor.ensure_collection("existing_collection", 384)
await processor.ensure_collection("existing_collection", 384)
mock_qdrant_instance.collection_exists.assert_called_once()
mock_qdrant_instance.create_collection.assert_not_called()
@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.created_collections.add("cached_collection")
processor._known_collections.add("cached_collection")
processor.ensure_collection("cached_collection", 384)
await processor.ensure_collection("cached_collection", 384)
# Should not check or create - just return
mock_qdrant_instance.collection_exists.assert_not_called()
@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
processor._known_collections.add('rows_test_workspace_test_collection_schema1_384')
await processor.delete_collection('test_workspace', 'test_collection')
@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
assert mock_qdrant_instance.delete_collection.call_count == 2
# Verify the cached collection was removed
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client):


@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic:
@pytest.mark.asyncio
async def test_schema_config_parsing(self):
"""Test parsing of schema configurations"""
import asyncio
processor = MagicMock()
processor.schemas = {}
processor.config_key = "schema"
processor.registered_partitions = set()
processor._setup_lock = asyncio.Lock()
processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
# Create test configuration


@ -2,6 +2,8 @@
Tests for Cassandra triples storage service
"""
import asyncio
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
@ -24,7 +26,8 @@ class TestCassandraStorageProcessor:
assert processor.cassandra_host == ['cassandra'] # Updated default
assert processor.cassandra_username is None
assert processor.cassandra_password is None
assert processor.table is None
assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
def test_processor_initialization_with_custom_params(self):
"""Test processor initialization with custom parameters (new cassandra_* names)"""
@ -41,7 +44,8 @@ class TestCassandraStorageProcessor:
assert processor.cassandra_host == ['cassandra.example.com']
assert processor.cassandra_username == 'testuser'
assert processor.cassandra_password == 'testpass'
assert processor.table is None
assert processor._connections == {}
assert isinstance(processor._conn_lock, asyncio.Lock)
def test_processor_initialization_with_partial_auth(self):
"""Test processor initialization with only username (no password)"""
@ -92,6 +96,7 @@ class TestCassandraStorageProcessor:
"""Test table switching logic when authentication is provided"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(
@ -114,7 +119,6 @@ class TestCassandraStorageProcessor:
username='testuser',
password='testpass'
)
assert processor.table == 'user1'
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@ -122,6 +126,7 @@ class TestCassandraStorageProcessor:
"""Test table switching logic when no authentication is provided"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -138,7 +143,6 @@ class TestCassandraStorageProcessor:
hosts=['cassandra'], # Updated default
keyspace='user2'
)
assert processor.table == 'user2'
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@ -146,6 +150,7 @@ class TestCassandraStorageProcessor:
"""Test that TrustGraph is not recreated when table hasn't changed"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -169,6 +174,7 @@ class TestCassandraStorageProcessor:
"""Test that triples are properly inserted into Cassandra"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -208,12 +214,12 @@ class TestCassandraStorageProcessor:
await processor.store_triples('user1', mock_message)
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
assert mock_tg_instance.insert.call_count == 2
mock_tg_instance.insert.assert_any_call(
assert mock_tg_instance.async_insert.call_count == 2
mock_tg_instance.async_insert.assert_any_call(
'collection1', 'subject1', 'predicate1', 'object1',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)
mock_tg_instance.insert.assert_any_call(
mock_tg_instance.async_insert.assert_any_call(
'collection1', 'subject2', 'predicate2', 'object2',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)
@ -224,6 +230,7 @@ class TestCassandraStorageProcessor:
"""Test behavior when message has no triples"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -236,19 +243,17 @@ class TestCassandraStorageProcessor:
await processor.store_triples('user1', mock_message)
# Verify no triples were inserted
mock_tg_instance.insert.assert_not_called()
mock_tg_instance.async_insert.assert_not_called()
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@patch('trustgraph.storage.triples.cassandra.write.time.sleep')
async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
async def test_exception_handling_on_connection_failure(self, mock_kg_class):
"""Test exception handling during TrustGraph creation"""
taskgroup_mock = MagicMock()
mock_kg_class.side_effect = Exception("Connection failed")
processor = Processor(taskgroup=taskgroup_mock)
# Create mock message
mock_message = MagicMock()
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
@ -256,9 +261,6 @@ class TestCassandraStorageProcessor:
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples('user1', mock_message)
# Verify sleep was called before re-raising
mock_sleep.assert_called_once_with(1)
def test_add_args_method(self):
"""Test that add_args properly configures argument parser"""
from argparse import ArgumentParser
@ -359,8 +361,6 @@ class TestCassandraStorageProcessor:
mock_message1.triples = []
await processor.store_triples('user1', mock_message1)
assert processor.table == 'user1'
assert processor.tg == mock_tg_instance1
# Second message with different table
mock_message2 = MagicMock()
@ -368,11 +368,11 @@ class TestCassandraStorageProcessor:
mock_message2.triples = []
await processor.store_triples('user2', mock_message2)
assert processor.table == 'user2'
assert processor.tg == mock_tg_instance2
# Verify TrustGraph was created twice for different tables
# Verify TrustGraph was created twice for different workspaces
assert mock_kg_class.call_count == 2
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
@ -380,6 +380,7 @@ class TestCassandraStorageProcessor:
"""Test storing triples with special characters and unicode"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -405,7 +406,7 @@ class TestCassandraStorageProcessor:
await processor.store_triples('test_workspace', mock_message)
# Verify the triple was inserted with special characters preserved
mock_tg_instance.insert.assert_called_once_with(
mock_tg_instance.async_insert.assert_called_once_with(
'test_collection',
'subject with spaces & symbols',
'predicate:with/colons',
@ -418,29 +419,29 @@ class TestCassandraStorageProcessor:
@pytest.mark.asyncio
@patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
"""Test that table remains unchanged when TrustGraph creation fails"""
async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class):
"""Test that a failed connection doesn't leave stale cached state"""
taskgroup_mock = MagicMock()
mock_good_instance = MagicMock()
processor = Processor(taskgroup=taskgroup_mock)
# Set an initial table
processor.table = ('old_user', 'old_collection')
# Mock TrustGraph to raise exception
mock_kg_class.side_effect = Exception("Connection failed")
mock_message = MagicMock()
mock_message.metadata.collection = 'new_collection'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
# First call fails
mock_kg_class.side_effect = Exception("Connection failed")
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples('new_user', mock_message)
await processor.store_triples('user1', mock_message)
# Table should remain unchanged since self.table = table happens after try/except
assert processor.table == ('old_user', 'old_collection')
# TrustGraph should be set to None though
assert processor.tg is None
# Second call succeeds — should retry connection, not use stale state
mock_kg_class.side_effect = None
mock_kg_class.return_value = mock_good_instance
await processor.store_triples('user1', mock_message)
# Connection was attempted twice (failed + succeeded)
assert mock_kg_class.call_count == 2
class TestCassandraPerformanceOptimizations:
@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations:
"""Test that legacy mode still works with single table"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):
@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations:
"""Test that optimized mode uses multi-table schema"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):
@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations:
"""Test that all tables stay consistent during batch writes"""
taskgroup_mock = MagicMock()
mock_tg_instance = MagicMock()
mock_tg_instance.async_insert = AsyncMock()
mock_kg_class.return_value = mock_tg_instance
processor = Processor(taskgroup=taskgroup_mock)
@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations:
await processor.store_triples('user1', mock_message)
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
mock_tg_instance.insert.assert_called_once_with(
mock_tg_instance.async_insert.assert_called_once_with(
'collection1', 'test_subject', 'test_predicate', 'test_object',
g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
)


@ -89,7 +89,8 @@ class TestSanitizeName:
class TestFindCollection:
def test_finds_matching_collection(self):
@pytest.mark.asyncio
async def test_finds_matching_collection(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_test_workspace_test_col_customers_384"
@ -98,11 +99,12 @@ class TestFindCollection:
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-workspace", "test-col", "customers")
result = await proc.find_collection("test-workspace", "test-col", "customers")
assert result == "rows_test_workspace_test_col_customers_384"
def test_returns_none_when_no_match(self):
@pytest.mark.asyncio
async def test_returns_none_when_no_match(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_other_workspace_other_col_schema_768"
@ -111,14 +113,15 @@ class TestFindCollection:
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-workspace", "test-col", "customers")
result = await proc.find_collection("test-workspace", "test-col", "customers")
assert result is None
def test_returns_none_on_error(self):
@pytest.mark.asyncio
async def test_returns_none_on_error(self):
proc = _make_processor()
proc.qdrant.get_collections.side_effect = Exception("connection error")
result = proc.find_collection("workspace", "col", "schema")
result = await proc.find_collection("workspace", "col", "schema")
assert result is None
@ -139,7 +142,7 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio
async def test_no_collection_returns_empty(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value=None)
proc.find_collection = AsyncMock(return_value=None)
request = _make_request()
result = await proc.query_row_embeddings("test-workspace", request)
@ -148,7 +151,7 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio
async def test_successful_query_returns_matches(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
points = [
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
@ -172,7 +175,7 @@ class TestQueryRowEmbeddings:
async def test_index_name_filter_applied(self):
"""When index_name is specified, a Qdrant filter should be used."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
mock_result = MagicMock()
mock_result.points = []
@ -188,7 +191,7 @@ class TestQueryRowEmbeddings:
async def test_no_index_name_no_filter(self):
"""When index_name is empty, no filter should be applied."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
mock_result = MagicMock()
mock_result.points = []
@ -204,7 +207,7 @@ class TestQueryRowEmbeddings:
async def test_missing_payload_fields_default(self):
"""Points with missing payload fields should use defaults."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
point = MagicMock()
point.payload = {} # Empty payload
@ -225,7 +228,7 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio
async def test_qdrant_error_propagates(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
proc.qdrant.query_points.side_effect = Exception("qdrant down")
request = _make_request()


@ -132,3 +132,34 @@ class Knowledge:
self.request(request = input)
def list_de_cores(self):
input = {
"operation": "list-de-cores",
"workspace": self.api.workspace,
}
return self.request(request = input)["ids"]
def delete_de_core(self, id):
input = {
"operation": "delete-de-core",
"workspace": self.api.workspace,
"id": id,
}
self.request(request = input)
def load_de_core(self, id, flow="default", collection="default"):
input = {
"operation": "load-de-core",
"workspace": self.api.workspace,
"id": id,
"flow": flow,
"collection": collection,
}
self.request(request = input)


@ -491,6 +491,58 @@ class SocketClient:
triples=raw_triples,
)
def get_kg_core(self, id: str) -> Iterator[Dict[str, Any]]:
request = {
"operation": "get-kg-core",
"workspace": self.workspace,
"id": id,
}
for response in self._send_request_sync(
"knowledge", None, request, streaming_raw=True,
):
if response.get("eos"):
break
yield response
def put_kg_core(
self, id: str, triples=None, graph_embeddings=None,
) -> Dict[str, Any]:
request = {
"operation": "put-kg-core",
"workspace": self.workspace,
"id": id,
}
if triples is not None:
request["triples"] = triples
if graph_embeddings is not None:
request["graph-embeddings"] = graph_embeddings
return self._send_request_sync("knowledge", None, request)
def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]:
request = {
"operation": "get-de-core",
"workspace": self.workspace,
"id": id,
}
for response in self._send_request_sync(
"knowledge", None, request, streaming_raw=True,
):
if response.get("eos"):
break
yield response
def put_de_core(
self, id: str, document_embeddings=None,
) -> Dict[str, Any]:
request = {
"operation": "put-de-core",
"workspace": self.workspace,
"id": id,
}
if document_embeddings is not None:
request["document-embeddings"] = document_embeddings
return self._send_request_sync("knowledge", None, request)
def close(self) -> None:
"""Close the persistent WebSocket connection."""
if self._loop and not self._loop.is_closed():


@ -1,6 +1,7 @@
from typing import Dict, Any, Tuple, Optional
from ...schema import (
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
DocumentEmbeddings, ChunkEmbeddings,
Metadata, EntityEmbeddings
)
from .base import MessageTranslator
@ -43,6 +44,23 @@ class KnowledgeRequestTranslator(MessageTranslator):
]
)
document_embeddings = None
if "document-embeddings" in data:
document_embeddings = DocumentEmbeddings(
metadata=Metadata(
id=data["document-embeddings"]["metadata"]["id"],
root=data["document-embeddings"]["metadata"].get("root", ""),
collection=data["document-embeddings"]["metadata"]["collection"]
),
chunks=[
ChunkEmbeddings(
chunk_id=ch["chunk_id"],
vector=ch["vector"],
)
for ch in data["document-embeddings"]["chunks"]
]
)
return KnowledgeRequest(
operation=data.get("operation"),
id=data.get("id"),
@ -50,6 +68,7 @@ class KnowledgeRequestTranslator(MessageTranslator):
collection=data.get("collection"),
triples=triples,
graph_embeddings=graph_embeddings,
document_embeddings=document_embeddings,
)
def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]:
@ -90,6 +109,22 @@ class KnowledgeRequestTranslator(MessageTranslator):
],
}
if obj.document_embeddings:
result["document-embeddings"] = {
"metadata": {
"id": obj.document_embeddings.metadata.id,
"root": obj.document_embeddings.metadata.root,
"collection": obj.document_embeddings.metadata.collection,
},
"chunks": [
{
"chunk_id": ch.chunk_id,
"vector": ch.vector,
}
for ch in obj.document_embeddings.chunks
],
}
return result
@ -140,6 +175,25 @@ class KnowledgeResponseTranslator(MessageTranslator):
}
}
# Streaming document embeddings response
if obj.document_embeddings:
return {
"document-embeddings": {
"metadata": {
"id": obj.document_embeddings.metadata.id,
"root": obj.document_embeddings.metadata.root,
"collection": obj.document_embeddings.metadata.collection,
},
"chunks": [
{
"chunk_id": ch.chunk_id,
"vector": ch.vector,
}
for ch in obj.document_embeddings.chunks
],
}
}
# End of stream marker
if obj.eos is True:
return {"eos": True}
@ -155,7 +209,7 @@ class KnowledgeResponseTranslator(MessageTranslator):
is_final = (
obj.ids is not None or # List response
obj.eos is True or # End of stream
(not obj.triples and not obj.graph_embeddings) # Empty response
(not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response
)
return response, is_final


@ -4,7 +4,7 @@ from ..core.topic import queue
from ..core.metadata import Metadata
from .document import Document, TextDocument
from .graph import Triples
from .embeddings import GraphEmbeddings
from .embeddings import GraphEmbeddings, DocumentEmbeddings
# get-kg-core
# -> (???)
@ -41,6 +41,9 @@ class KnowledgeRequest:
triples: Triples | None = None
graph_embeddings: GraphEmbeddings | None = None
# put-de-core
document_embeddings: DocumentEmbeddings | None = None
@dataclass
class KnowledgeResponse:
error: Error | None = None
@ -48,6 +51,7 @@ class KnowledgeResponse:
eos: bool = False # Indicates end of knowledge core stream
triples: Triples | None = None
graph_embeddings: GraphEmbeddings | None = None
document_embeddings: DocumentEmbeddings | None = None
knowledge_request_queue = queue('knowledge', cls='request')
knowledge_response_queue = queue('knowledge', cls='response')


@ -37,6 +37,7 @@ tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main"
tg-dump-queues = "trustgraph.cli.dump_queues:main"
tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main"
tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main"
tg-get-de-core = "trustgraph.cli.get_de_core:main"
tg-get-kg-core = "trustgraph.cli.get_kg_core:main"
tg-get-document-content = "trustgraph.cli.get_document_content:main"
tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main"
@ -77,6 +78,7 @@ tg-load-turtle = "trustgraph.cli.load_turtle:main"
tg-load-knowledge = "trustgraph.cli.load_knowledge:main"
tg-load-structured-data = "trustgraph.cli.load_structured_data:main"
tg-put-flow-blueprint = "trustgraph.cli.put_flow_blueprint:main"
tg-put-de-core = "trustgraph.cli.put_de_core:main"
tg-put-kg-core = "trustgraph.cli.put_kg_core:main"
tg-remove-library-document = "trustgraph.cli.remove_library_document:main"
tg-save-doc-embeds = "trustgraph.cli.save_doc_embeds:main"

View file

@ -0,0 +1,111 @@
"""
Uses the knowledge service to fetch a document embeddings core, which is
saved to a local file in msgpack format.
"""
import argparse
import os
import msgpack
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def write_de(f, data):
msg = (
"de",
{
"m": {
"i": data["metadata"]["id"],
"m": data["metadata"]["root"],
"c": data["metadata"]["collection"],
},
"c": [
{
"i": ch["chunk_id"],
"v": ch["vector"],
}
for ch in data["chunks"]
]
}
)
f.write(msgpack.packb(msg, use_bin_type=True))
def fetch(url, workspace, id, output, token=None):
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
try:
de = 0
with open(output, "wb") as f:
for response in socket.get_de_core(id):
if "document-embeddings" in response:
de += 1
write_de(f, response["document-embeddings"])
print(f"Got: {de} document embeddings messages.")
finally:
socket.close()
def main():
parser = argparse.ArgumentParser(
prog='tg-get-de-core',
description=__doc__,
)
parser.add_argument(
'-u', '--url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'--id', '--identifier',
required=True,
help='Document embeddings core ID',
)
parser.add_argument(
'-o', '--output',
required=True,
help='Output file'
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
args = parser.parse_args()
try:
fetch(
url=args.url,
workspace=args.workspace,
id=args.id,
output=args.output,
token=args.token,
)
except Exception as e:
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()
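
Note: each record write_de() emits is a ("de", {...}) tuple with single-letter keys: under "m", "i"/"m"/"c" carry id, root and collection; under "c", "i"/"v" carry chunk id and vector. A minimal sketch of reading such a file back (the filename is a placeholder):

    import msgpack

    with open("core.msgpack", "rb") as f:
        for kind, body in msgpack.Unpacker(f, raw=False):
            assert kind == "de"
            doc_id = body["m"]["i"]
            root = body["m"]["m"]
            chunks = [(ch["i"], ch["v"]) for ch in body["c"]]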

View file

@ -5,13 +5,11 @@ to a local file in msgpack format.
import argparse
import os
import uuid
import asyncio
import json
from websockets.asyncio.client import connect
import msgpack
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
@ -21,7 +19,7 @@ def write_triple(f, data):
{
"m": {
"i": data["metadata"]["id"],
"m": data["metadata"]["metadata"],
"m": data["metadata"]["root"],
"c": data["metadata"]["collection"],
},
"t": data["triples"],
@ -35,13 +33,13 @@ def write_ge(f, data):
{
"m": {
"i": data["metadata"]["id"],
"m": data["metadata"]["metadata"],
"m": data["metadata"]["root"],
"c": data["metadata"]["collection"],
},
"e": [
{
"e": ent["entity"],
"v": ent["vectors"],
"v": ent["vector"],
}
for ent in data["entities"]
]
@ -49,54 +47,18 @@ def write_ge(f, data):
)
f.write(msgpack.packb(msg, use_bin_type=True))
async def fetch(url, workspace, id, output, token=None):
def fetch(url, workspace, id, output, token=None):
if not url.endswith("/"):
url += "/"
url = url + "api/v1/socket"
if token:
url = f"{url}?token={token}"
mid = str(uuid.uuid4())
async with connect(url) as ws:
req = json.dumps({
"id": mid,
"workspace": workspace,
"service": "knowledge",
"request": {
"operation": "get-kg-core",
"workspace": workspace,
"id": id,
}
})
await ws.send(req)
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
try:
ge = 0
t = 0
with open(output, "wb") as f:
while True:
msg = await ws.recv()
obj = json.loads(msg)
if "response" not in obj:
raise RuntimeError("No response?")
response = obj["response"]
if "error" in response:
raise RuntimeError(obj["error"])
if "eos" in response:
if response["eos"]: break
for response in socket.get_kg_core(id):
if "triples" in response:
t += 1
@ -108,7 +70,8 @@ async def fetch(url, workspace, id, output, token=None):
print(f"Got: {t} triple, {ge} GE messages.")
await ws.close()
finally:
socket.close()
def main():
@ -151,7 +114,6 @@ def main():
try:
asyncio.run(
fetch(
url=args.url,
workspace=args.workspace,
@ -159,7 +121,6 @@ def main():
output=args.output,
token=args.token,
)
)
except Exception as e:
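
Note: with the root/vector renames above, the records this tool writes have these shapes (dummy values; "t" carries the triples list exactly as received from the knowledge service):

    triple_record = ("t", {
        "m": {"i": "my-core", "m": "http://example.com/doc", "c": "default"},
        "t": [],  # triples as received from the knowledge service
    })

    ge_record = ("ge", {
        "m": {"i": "my-core", "m": "http://example.com/doc", "c": "default"},
        "e": [{"e": "http://example.com/entity", "v": [0.1, 0.2]}],
    })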

View file

@ -3,11 +3,8 @@ Uses the GraphRAG service to answer a question
"""
import argparse
import json
import os
import sys
import websockets
import asyncio
from trustgraph.api import (
Api,
ExplainabilityClient,
@ -31,607 +28,6 @@ default_max_path_length = 2
default_edge_score_limit = 30
default_edge_limit = 25
# Provenance predicates
TG = "https://trustgraph.ai/ns/"
TG_QUERY = TG + "query"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_DOCUMENT = TG + "document"
TG_CONTAINS = TG + "contains"
PROV = "http://www.w3.org/ns/prov#"
PROV_STARTED_AT_TIME = PROV + "startedAtTime"
PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
def _get_event_type(prov_id):
"""Extract event type from provenance_id"""
if "question" in prov_id:
return "question"
elif "grounding" in prov_id:
return "grounding"
elif "exploration" in prov_id:
return "exploration"
elif "focus" in prov_id:
return "focus"
elif "synthesis" in prov_id:
return "synthesis"
return "provenance"
def _format_provenance_details(event_type, triples):
"""Format provenance details based on event type and triples"""
lines = []
if event_type == "question":
# Show query and timestamp
for s, p, o in triples:
if p == TG_QUERY:
lines.append(f" Query: {o}")
elif p == PROV_STARTED_AT_TIME:
lines.append(f" Time: {o}")
elif event_type == "grounding":
# Show extracted concepts
concepts = [o for s, p, o in triples if p == TG_CONCEPT]
if concepts:
lines.append(f" Concepts: {len(concepts)}")
for concept in concepts:
lines.append(f" - {concept}")
elif event_type == "exploration":
# Show edge count (seed entities resolved separately with labels)
for s, p, o in triples:
if p == TG_EDGE_COUNT:
lines.append(f" Edges explored: {o}")
elif event_type == "focus":
# For focus, just count edge selection URIs
# The actual edge details are fetched separately via edge_selections parameter
edge_sel_uris = []
for s, p, o in triples:
if p == TG_SELECTED_EDGE:
edge_sel_uris.append(o)
if edge_sel_uris:
lines.append(f" Focused on {len(edge_sel_uris)} edge(s)")
elif event_type == "synthesis":
# Show document reference (content already streamed)
for s, p, o in triples:
if p == TG_DOCUMENT:
lines.append(f" Document: {o}")
return lines
async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False):
"""Query triples for a provenance node (single attempt)"""
request = {
"id": "triples-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": prov_id},
"collection": collection,
"limit": 100
}
}
# Add graph filter if specified (for named graph queries)
if graph is not None:
request["request"]["g"] = graph
if debug:
print(f" [debug] querying triples for s={prov_id}", file=sys.stderr)
triples = []
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if debug:
print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr)
if response.get("id") != "triples-request":
continue
if "error" in response:
if debug:
print(f" [debug] error: {response['error']}", file=sys.stderr)
break
if "response" in response:
resp = response["response"]
# Handle triples response
# Response format: {"response": [triples...]}
# Each triple uses compact keys: "i" for iri, "v" for value, "t" for type
triple_list = resp.get("response", [])
for t in triple_list:
s = t.get("s", {}).get("i", t.get("s", {}).get("v", ""))
p = t.get("p", {}).get("i", t.get("p", {}).get("v", ""))
# Handle quoted triples (type "t") and regular values
o_term = t.get("o", {})
if o_term.get("t") == "t":
# Quoted triple - extract s, p, o from nested structure
tr = o_term.get("tr", {})
o = {
"s": tr.get("s", {}).get("i", ""),
"p": tr.get("p", {}).get("i", ""),
"o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")),
}
else:
o = o_term.get("i", o_term.get("v", ""))
triples.append((s, p, o))
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception: {e}", file=sys.stderr)
if debug:
print(f" [debug] got {len(triples)} triples", file=sys.stderr)
return triples
async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False):
"""Query triples for a provenance node with retries for race condition"""
for attempt in range(max_retries):
triples = await _query_triples_once(ws_url, flow_id, prov_id, collection, graph=graph, debug=debug)
if triples:
return triples
# Wait before retry if empty (triples may not be stored yet)
if attempt < max_retries - 1:
if debug:
print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr)
await asyncio.sleep(retry_delay)
return []
async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False):
"""
Query for provenance of an edge (s, p, o) in the knowledge graph.
Finds subgraphs that contain the edge via tg:contains, then follows
prov:wasDerivedFrom to find source documents.
Returns list of source URIs (chunks, pages, documents).
"""
# Query for subgraphs that contain this edge: ?subgraph tg:contains <<s p o>>
request = {
"id": "edge-prov-request",
"service": "triples",
"flow": flow_id,
"request": {
"p": {"t": "i", "i": TG_CONTAINS},
"o": {
"t": "t", # Quoted triple type
"tr": {
"s": {"t": "i", "i": edge_s},
"p": {"t": "i", "i": edge_p},
"o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o},
}
},
"collection": collection,
"limit": 10
}
}
if debug:
print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr)
stmt_uris = []
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "edge-prov-request":
continue
if "error" in response:
if debug:
print(f" [debug] error: {response['error']}", file=sys.stderr)
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
for t in triple_list:
s = t.get("s", {}).get("i", "")
if s:
stmt_uris.append(s)
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr)
if debug:
print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr)
# For each statement, query wasDerivedFrom to find sources
sources = []
for stmt_uri in stmt_uris:
# Query: stmt_uri prov:wasDerivedFrom ?source
request = {
"id": "derived-from-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": stmt_uri},
"p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
"collection": collection,
"limit": 10
}
}
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "derived-from-request":
continue
if "error" in response:
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
for t in triple_list:
o = t.get("o", {}).get("i", "")
if o:
sources.append(o)
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr)
if debug:
print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr)
return sources
async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False):
"""Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent."""
request = {
"id": "parent-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": uri},
"p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
"collection": collection,
"limit": 1
}
}
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "parent-request":
continue
if "error" in response:
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
if triple_list:
return triple_list[0].get("o", {}).get("i", None)
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying parent: {e}", file=sys.stderr)
return None
async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False):
"""
Trace the full provenance chain from a source URI up to the root document.
Returns a list of (uri, label) tuples from leaf to root.
"""
chain = []
current = source_uri
max_depth = 10 # Prevent infinite loops
for _ in range(max_depth):
if not current:
break
# Get label for current entity
label = await _query_label(ws_url, flow_id, current, collection, label_cache, debug)
chain.append((current, label))
# Get parent
parent = await _query_derived_from(ws_url, flow_id, current, collection, debug)
if not parent or parent == current:
break
current = parent
return chain
def _format_provenance_chain(chain):
"""
Format a provenance chain as a human-readable string.
Chain is [(uri, label), ...] from leaf to root.
"""
if not chain:
return ""
# Show labels, from leaf to root
labels = [label for uri, label in chain]
return "".join(labels)
def _is_iri(value):
"""Check if a value looks like an IRI."""
if not isinstance(value, str):
return False
return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:")
async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False):
"""
Query for the rdfs:label of an IRI.
Uses label_cache to avoid repeated queries.
Returns the label if found, otherwise returns the IRI.
"""
if not _is_iri(iri):
return iri
# Check cache first
if iri in label_cache:
return label_cache[iri]
request = {
"id": "label-request",
"service": "triples",
"flow": flow_id,
"request": {
"s": {"t": "i", "i": iri},
"p": {"t": "i", "i": RDFS_LABEL},
"collection": collection,
"limit": 1
}
}
label = iri # Default to IRI if no label found
try:
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "label-request":
continue
if "error" in response:
break
if "response" in response:
resp = response["response"]
triple_list = resp.get("response", [])
if triple_list:
# Get the label value
o = triple_list[0].get("o", {})
label = o.get("v", o.get("i", iri))
if resp.get("complete") or response.get("complete"):
break
except Exception as e:
if debug:
print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr)
# Cache the result
label_cache[iri] = label
return label
async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False):
"""
Resolve labels for all IRI components of an edge triple.
Returns (s_label, p_label, o_label).
"""
s = edge_triple.get("s", "?")
p = edge_triple.get("p", "?")
o = edge_triple.get("o", "?")
s_label = await _query_label(ws_url, flow_id, s, collection, label_cache, debug)
p_label = await _query_label(ws_url, flow_id, p, collection, label_cache, debug)
o_label = await _query_label(ws_url, flow_id, o, collection, label_cache, debug)
return s_label, p_label, o_label
async def _question_explainable(
url, flow_id, question, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, token=None, debug=False
):
"""Execute graph RAG with explainability - shows provenance events with details"""
# Convert HTTP URL to WebSocket URL
if url.startswith("http://"):
ws_url = url.replace("http://", "ws://", 1)
elif url.startswith("https://"):
ws_url = url.replace("https://", "wss://", 1)
else:
ws_url = f"ws://{url}"
ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"
if token:
ws_url = f"{ws_url}?token={token}"
# Cache for label lookups to avoid repeated queries
label_cache = {}
request = {
"id": "cli-request",
"service": "graph-rag",
"flow": flow_id,
"request": {
"query": question,
"collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-path-length": max_path_length,
"streaming": True
}
}
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket:
await websocket.send(json.dumps(request))
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != "cli-request":
continue
if "error" in response:
print(f"\nError: {response['error']}", file=sys.stderr)
break
if "response" in response:
resp = response["response"]
# Check for errors in response
if "error" in resp and resp["error"]:
err = resp["error"]
print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr)
break
message_type = resp.get("message_type", "")
if debug:
print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr)
if message_type == "explain":
# Display explain event with details
explain_id = resp.get("explain_id", "")
explain_graph = resp.get("explain_graph") # Named graph (e.g., urn:graph:retrieval)
if explain_id:
event_type = _get_event_type(explain_id)
print(f"\n [{event_type}] {explain_id}", file=sys.stderr)
# Query triples for this explain node (using named graph filter)
triples = await _query_triples(
ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug
)
# Format and display details
details = _format_provenance_details(event_type, triples)
for line in details:
print(line, file=sys.stderr)
# For exploration events, resolve entity labels
if event_type == "exploration":
entity_iris = [o for s, p, o in triples if p == TG_ENTITY]
if entity_iris:
print(f" Seed entities: {len(entity_iris)}", file=sys.stderr)
for iri in entity_iris:
label = await _query_label(
ws_url, flow_id, iri, collection,
label_cache, debug=debug
)
print(f" - {label}", file=sys.stderr)
# For focus events, query each edge selection for details
if event_type == "focus":
for s, p, o in triples:
if debug:
print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr)
if p == TG_SELECTED_EDGE and isinstance(o, str):
if debug:
print(f" [debug] querying edge selection: {o}", file=sys.stderr)
# Query the edge selection entity (using named graph filter)
edge_triples = await _query_triples(
ws_url, flow_id, o, collection, graph=explain_graph, debug=debug
)
if debug:
print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr)
# Extract edge and reasoning
edge_triple = None # Store the actual triple for provenance lookup
reasoning = None
for es, ep, eo in edge_triples:
if debug:
print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr)
if ep == TG_EDGE and isinstance(eo, dict):
# eo is a quoted triple dict
edge_triple = eo
elif ep == TG_REASONING:
reasoning = eo
if edge_triple:
# Resolve labels for edge components
s_label, p_label, o_label = await _resolve_edge_labels(
ws_url, flow_id, edge_triple, collection,
label_cache, debug=debug
)
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
if reasoning:
r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
print(f" Reason: {r_short}", file=sys.stderr)
# Trace edge provenance in the workspace collection (not explainability)
if edge_triple:
sources = await _query_edge_provenance(
ws_url, flow_id,
edge_triple.get("s", ""),
edge_triple.get("p", ""),
edge_triple.get("o", ""),
collection, # Use the query collection, not explainability
debug=debug
)
if sources:
for src in sources:
# Trace full chain from source to root document
chain = await _trace_provenance_chain(
ws_url, flow_id, src, collection,
label_cache, debug=debug
)
chain_str = _format_provenance_chain(chain)
print(f" Source: {chain_str}", file=sys.stderr)
elif message_type == "chunk" or not message_type:
# Display response chunk
chunk = resp.get("response", "")
if chunk:
print(chunk, end="", flush=True)
# Check if session is complete
if resp.get("end_of_session"):
break
print() # Final newline
def _question_explainable_api(
url, flow_id, question_text, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, edge_score_limit=30,

View file

@ -0,0 +1,119 @@
"""
Puts a document embeddings core into the knowledge manager via the API
socket.
"""
import argparse
import os
import msgpack
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def read_message(unpacked, id):
if unpacked[0] == "de":
msg = unpacked[1]
return {
"metadata": {
"id": id,
"root": msg["m"]["m"],
"collection": "default",
},
"chunks": [
{
"chunk_id": ch["i"],
"vector": ch["v"],
}
for ch in msg["c"]
],
}
else:
raise RuntimeError("Unexpected message type", unpacked[0])
def put(url, workspace, id, input, token=None):
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
try:
de = 0
with open(input, "rb") as f:
unpacker = msgpack.Unpacker(f, raw=False)
while True:
try:
unpacked = unpacker.unpack()
except msgpack.OutOfData:
break
msg = read_message(unpacked, id)
de += 1
socket.put_de_core(id, document_embeddings=msg)
print(f"Put: {de} document embeddings messages.")
finally:
socket.close()
def main():
parser = argparse.ArgumentParser(
prog='tg-put-de-core',
description=__doc__,
)
parser.add_argument(
'-u', '--url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'--id', '--identifier',
required=True,
help='Document embeddings core ID',
)
parser.add_argument(
'-i', '--input',
required=True,
help='Input file'
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
args = parser.parse_args()
try:
put(
url=args.url,
workspace=args.workspace,
id=args.id,
input=args.input,
token=args.token,
)
except Exception as e:
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()
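
Note: taken together with tg-get-de-core above, the two entry points round-trip a core. A sketch against a running gateway (URL, IDs and filename are placeholders):

    from trustgraph.cli.get_de_core import fetch
    from trustgraph.cli.put_de_core import put

    fetch(url="http://localhost:8088/", workspace="default",
          id="my-core", output="core.msgpack")
    put(url="http://localhost:8088/", workspace="default",
        id="my-core-copy", input="core.msgpack")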

View file

@ -4,13 +4,11 @@ Puts a knowledge core into the knowledge manager via the API socket.
import argparse
import os
import uuid
import asyncio
import json
from websockets.asyncio.client import connect
import msgpack
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
@ -21,13 +19,13 @@ def read_message(unpacked, id):
return "ge", {
"metadata": {
"id": id,
"metadata": msg["m"]["m"],
"collection": "default", # Not used?
"root": msg["m"]["m"],
"collection": "default",
},
"entities": [
{
"entity": ent["e"],
"vectors": ent["v"],
"vector": ent["v"],
}
for ent in msg["e"]
],
@ -37,26 +35,20 @@ def read_message(unpacked, id):
return "t", {
"metadata": {
"id": id,
"metadata": msg["m"]["m"],
"collection": "default", # Not used by receiver?
"root": msg["m"]["m"],
"collection": "default",
},
"triples": msg["t"],
}
else:
raise RuntimeError("Unpacked unexpected messsage type", unpacked[0])
async def put(url, workspace, id, input, token=None):
def put(url, workspace, id, input, token=None):
if not url.endswith("/"):
url += "/"
url = url + "api/v1/socket"
if token:
url = f"{url}?token={token}"
async with connect(url) as ws:
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
try:
ge = 0
t = 0
@ -68,69 +60,26 @@ async def put(url, workspace, id, input, token=None):
try:
unpacked = unpacker.unpack()
except:
except msgpack.OutOfData:
break
kind, msg = read_message(unpacked, id)
mid = str(uuid.uuid4())
if kind == "ge":
ge += 1
req = json.dumps({
"id": mid,
"workspace": workspace,
"service": "knowledge",
"request": {
"operation": "put-kg-core",
"workspace": workspace,
"id": id,
"graph-embeddings": msg
}
})
socket.put_kg_core(id, graph_embeddings=msg)
elif kind == "t":
t += 1
req = json.dumps({
"id": mid,
"workspace": workspace,
"service": "knowledge",
"request": {
"operation": "put-kg-core",
"workspace": workspace,
"id": id,
"triples": msg
}
})
socket.put_kg_core(id, triples=msg)
else:
raise RuntimeError("Unexpected message kind", kind)
await ws.send(req)
# Retry loop, wait for right response to come back
while True:
msg = await ws.recv()
msg = json.loads(msg)
if msg["id"] != mid:
continue
if "response" in msg:
if "error" in msg["response"]:
raise RuntimeError(msg["response"]["error"])
break
print(f"Put: {t} triple, {ge} GE messages.")
await ws.close()
finally:
socket.close()
def main():
@ -173,7 +122,6 @@ def main():
try:
asyncio.run(
put(
url=args.url,
workspace=args.workspace,
@ -181,7 +129,6 @@ def main():
input=args.input,
token=args.token,
)
)
except Exception as e:
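
Note: the migrated client-side pattern is the same across both put tools: open one socket, send each record through the typed call, and close in a finally block. A condensed sketch (the payload dicts are placeholders, shaped by read_message() above):

    api = Api(url="http://localhost:8088/", token=None, workspace="default")
    socket = api.socket()
    try:
        socket.put_kg_core("my-core", triples=triples_msg)
        socket.put_kg_core("my-core", graph_embeddings=ge_msg)
    finally:
        socket.close()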

View file

@ -1,5 +1,6 @@
from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings
from .. schema import DocumentEmbeddings
from .. knowledge import hash
from .. exceptions import RequestError
from .. tables.knowledge import KnowledgeTableStore
@ -157,6 +158,98 @@ class KnowledgeManager:
)
)
async def list_de_cores(self, request, respond, workspace):
ids = await self.table_store.list_de_cores(workspace)
await respond(
KnowledgeResponse(
error = None,
ids = ids,
eos = False,
triples = None,
graph_embeddings = None,
)
)
async def get_de_core(self, request, respond, workspace):
logger.info("Getting document embeddings core...")
async def publish_de(de):
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None,
document_embeddings = de,
)
)
await self.table_store.get_document_embeddings(
workspace,
request.id,
publish_de,
)
logger.debug("Document embeddings core retrieval complete")
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = True,
triples = None,
graph_embeddings = None,
)
)
async def put_de_core(self, request, respond, workspace):
if request.document_embeddings:
await self.table_store.add_document_embeddings(
workspace, request.document_embeddings
)
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None,
)
)
async def delete_de_core(self, request, respond, workspace):
logger.info("Deleting document embeddings core...")
await self.table_store.delete_document_embeddings(
workspace, request.id
)
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None,
)
)
async def load_de_core(self, request, respond, workspace):
if self.background_task is None:
self.background_task = asyncio.create_task(
self.core_loader()
)
await self.loader_queue.put((request, respond, workspace))
async def core_loader(self):
logger.info("Knowledge background processor running...")
@ -165,7 +258,7 @@ class KnowledgeManager:
logger.debug("Waiting for next load...")
request, respond, workspace = await self.loader_queue.get()
logger.info(f"Loading knowledge: {request.id}")
logger.info(f"Loading: {request.operation} {request.id}")
try:
@ -187,24 +280,13 @@ class KnowledgeManager:
if "interfaces" not in flow:
raise RuntimeError("No defined interfaces")
if "triples-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no triples-store")
if "graph-embeddings-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no graph-embeddings-store")
t_q = flow["interfaces"]["triples-store"]["flow"]
ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"]
# Got this far, it should all work
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None
if request.operation == "load-de-core":
await self._load_de_core(
request, respond, workspace, flow,
)
else:
await self._load_kg_core(
request, respond, workspace, flow,
)
except Exception as e:
@ -223,14 +305,36 @@ class KnowledgeManager:
)
)
logger.debug("Knowledge processing done")
logger.debug("Starting knowledge loading process...")
continue
try:
async def _load_kg_core(self, request, respond, workspace, flow):
if "triples-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no triples-store")
if "graph-embeddings-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no graph-embeddings-store")
t_q = flow["interfaces"]["triples-store"]["flow"]
ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"]
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None
)
)
t_pub = None
ge_pub = None
try:
logger.debug(f"Triples queue: {t_q}")
logger.debug(f"Graph embeddings queue: {ge_q}")
@ -249,7 +353,6 @@ class KnowledgeManager:
await ge_pub.start()
async def publish_triples(t):
# Override collection with request collection
if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'):
t.metadata.collection = request.collection or "default"
await t_pub.send(None, t)
@ -263,7 +366,6 @@ class KnowledgeManager:
)
async def publish_ge(g):
# Override collection with request collection
if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'):
g.metadata.collection = request.collection or "default"
await ge_pub.send(None, g)
@ -276,7 +378,7 @@ class KnowledgeManager:
publish_ge,
)
logger.debug("Knowledge loading completed")
logger.debug("Knowledge core loading completed")
except Exception as e:
@ -289,6 +391,59 @@ class KnowledgeManager:
if t_pub: await t_pub.stop()
if ge_pub: await ge_pub.stop()
logger.debug("Knowledge processing done")
continue
async def _load_de_core(self, request, respond, workspace, flow):
if "document-embeddings-store" not in flow["interfaces"]:
raise RuntimeError("Flow has no document-embeddings-store")
de_q = flow["interfaces"]["document-embeddings-store"]["flow"]
await respond(
KnowledgeResponse(
error = None,
ids = None,
eos = False,
triples = None,
graph_embeddings = None
)
)
de_pub = None
try:
logger.debug(f"Document embeddings queue: {de_q}")
de_pub = Publisher(
self.flow_config.pubsub, de_q,
schema=DocumentEmbeddings,
)
logger.debug("Starting publisher...")
await de_pub.start()
async def publish_de(de):
if hasattr(de, 'metadata') and hasattr(de.metadata, 'collection'):
de.metadata.collection = request.collection or "default"
await de_pub.send(None, de)
logger.debug("Publishing document embeddings...")
await self.table_store.get_document_embeddings(
workspace,
request.id,
publish_de,
)
logger.debug("Document embeddings core loading completed")
except Exception as e:
logger.error(f"Knowledge exception: {e}", exc_info=True)
finally:
logger.debug("Stopping publisher...")
if de_pub: await de_pub.stop()
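
Note: the respond() contract used by get_de_core above is one KnowledgeResponse per stored row, then a final message with eos=True. A hypothetical collector (not from this PR) makes the shape explicit:

    async def collect_de_core(manager, request):
        responses = []

        async def respond(r):
            responses.append(r)

        await manager.get_de_core(request, respond, "default")
        assert responses[-1].eos is True
        return [r.document_embeddings for r in responses
                if r.document_embeddings]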

View file

@ -187,6 +187,11 @@ class Processor(WorkspaceProcessor):
"put-kg-core": self.knowledge.put_kg_core,
"load-kg-core": self.knowledge.load_kg_core,
"unload-kg-core": self.knowledge.unload_kg_core,
"list-de-cores": self.knowledge.list_de_cores,
"get-de-core": self.knowledge.get_de_core,
"delete-de-core": self.knowledge.delete_de_core,
"put-de-core": self.knowledge.put_de_core,
"load-de-core": self.knowledge.load_de_core,
}
if v.operation not in impls:

View file

@ -1,10 +1,14 @@
import datetime
import os
import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement
from ssl import SSLContext, PROTOCOL_TLSv1_2
import os
import logging
from ..tables.cassandra_async import async_execute
# Global list to track clusters for cleanup
_active_clusters = []
@ -461,7 +465,6 @@ class KnowledgeGraph:
def create_collection(self, collection):
"""Create collection by inserting metadata row"""
try:
import datetime
self.session.execute(
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
(collection, datetime.datetime.now())
@ -954,7 +957,6 @@ class EntityCentricKnowledgeGraph:
def create_collection(self, collection):
"""Create collection by inserting metadata row"""
try:
import datetime
self.session.execute(
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
(collection, datetime.datetime.now())
@ -1045,6 +1047,222 @@ class EntityCentricKnowledgeGraph:
logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")
# ========================================================================
# Async methods — use cassandra driver's native async API via async_execute
# ========================================================================
async def async_insert(self, collection, s, p, o, g=None, otype=None, dtype="", lang=""):
if g is None:
g = DEFAULT_GRAPH
if otype is None:
if o.startswith("http://") or o.startswith("https://"):
otype = "u"
else:
otype = "l"
batch = BatchStatement()
batch.add(self.insert_entity_stmt, (collection, s, 'S', p, otype, s, o, g, dtype, lang))
batch.add(self.insert_entity_stmt, (collection, p, 'P', p, otype, s, o, g, dtype, lang))
if otype == 'u' or otype == 't':
batch.add(self.insert_entity_stmt, (collection, o, 'O', p, otype, s, o, g, dtype, lang))
if g != DEFAULT_GRAPH:
batch.add(self.insert_entity_stmt, (collection, g, 'G', p, otype, s, o, g, dtype, lang))
batch.add(self.insert_collection_stmt, (collection, g, s, p, o, otype, dtype, lang))
await async_execute(self.session, batch)
async def async_get_all(self, collection, limit=50):
return await async_execute(
self.session, self.get_collection_all_stmt, (collection, limit)
)
async def async_get_s(self, collection, s, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_s_stmt, (collection, s, limit)
)
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=row.s, p=row.p, o=row.o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_p(self, collection, p, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_p_stmt, (collection, p, limit)
)
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=row.s, p=row.p, o=row.o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_o(self, collection, o, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_o_stmt, (collection, o, limit)
)
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=row.s, p=row.p, o=row.o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_sp(self, collection, s, p, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit)
)
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=s, p=p, o=row.o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_po(self, collection, p, o, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_o_p_stmt, (collection, o, p, limit)
)
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=row.s, p=p, o=o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_os(self, collection, o, s, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_s_stmt, (collection, s, limit)
)
results = []
for row in rows:
if row.o != o:
continue
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=s, p=row.p, o=o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_spo(self, collection, s, p, o, g=None, limit=10):
rows = await async_execute(
self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit)
)
results = []
for row in rows:
if row.o != o:
continue
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is not None and d != g:
continue
results.append(QuadResult(
s=s, p=p, o=o, g=d,
otype=row.otype, dtype=row.dtype, lang=row.lang
))
return results
async def async_get_g(self, collection, g, limit=50):
if g is None:
g = DEFAULT_GRAPH
return await async_execute(
self.session, self.get_collection_by_graph_stmt, (collection, g, limit)
)
async def async_collection_exists(self, collection):
try:
result = await async_execute(
self.session,
f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1",
(collection,)
)
return bool(result)
except Exception as e:
logger.error(f"Error checking collection existence: {e}")
return False
async def async_create_collection(self, collection):
await async_execute(
self.session,
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
(collection, datetime.datetime.now())
)
logger.info(f"Created collection metadata for {collection}")
async def async_delete_collection(self, collection):
rows = await async_execute(
self.session,
f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s",
(collection,)
)
entities = set()
quads = []
for row in rows:
d, s, p, o = row.d, row.s, row.p, row.o
otype = row.otype
dtype = row.dtype if hasattr(row, 'dtype') else ''
lang = row.lang if hasattr(row, 'lang') else ''
quads.append((d, s, p, o, otype, dtype, lang))
entities.add(s)
entities.add(p)
if otype == 'u' or otype == 't':
entities.add(o)
if d != DEFAULT_GRAPH:
entities.add(d)
batch = BatchStatement()
count = 0
for entity in entities:
batch.add(self.delete_entity_partition_stmt, (collection, entity))
count += 1
if count % 50 == 0:
await async_execute(self.session, batch)
batch = BatchStatement()
if count % 50 != 0:
await async_execute(self.session, batch)
batch = BatchStatement()
count = 0
for d, s, p, o, otype, dtype, lang in quads:
batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang))
count += 1
if count % 50 == 0:
await async_execute(self.session, batch)
batch = BatchStatement()
if count % 50 != 0:
await async_execute(self.session, batch)
await async_execute(
self.session,
f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s",
(collection,)
)
logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")
def close(self):
"""Close connections"""
if hasattr(self, 'session') and self.session:
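
Note: the async methods above all funnel through cassandra_async.async_execute, which is not shown in this diff. A common way such a bridge is built on the DataStax driver's native execute_async() looks roughly like this; a sketch, not the repo's actual implementation, and note the callback only delivers the current page, so multi-page results need has_more_pages handling:

    import asyncio

    async def async_execute(session, statement, parameters=None):
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        response_future = session.execute_async(statement, parameters)
        response_future.add_callbacks(
            lambda rows: loop.call_soon_threadsafe(future.set_result, rows),
            lambda exc: loop.call_soon_threadsafe(future.set_exception, exc),
        )
        return await future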

View file

@ -457,6 +457,12 @@ for _op in ("put-kg-core", "delete-kg-core",
"load-kg-core", "unload-kg-core"):
_register_kind_op("knowledge", _op, "knowledge:write")
# knowledge: document-embeddings core service.
for _op in ("get-de-core", "list-de-cores"):
_register_kind_op("knowledge", _op, "knowledge:read")
for _op in ("put-de-core", "delete-de-core", "load-de-core"):
_register_kind_op("knowledge", _op, "knowledge:write")
# collection-management: workspace collection lifecycle.
_register_kind_op("collection-management", "list-collections", "collections:read")

View file

Document embeddings query service. Input is a vector, output is an array
of chunk_ids
"""
import asyncio
import logging
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
@ -38,32 +37,6 @@ class Processor(DocumentEmbeddingsQueryService):
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.last_collection = None
def ensure_collection_exists(self, collection, dim):
"""Ensure collection exists, create if it doesn't"""
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
try:
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
logger.info(f"Created collection: {collection}")
except Exception as e:
logger.error(f"Qdrant collection creation failed: {e}")
raise e
self.last_collection = collection
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
async def query_document_embeddings(self, workspace, msg):
@ -73,21 +46,24 @@ class Processor(DocumentEmbeddingsQueryService):
if not vec:
return []
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{workspace}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection
)
if not exists:
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
search_result = self.qdrant.query_points(
result = await asyncio.to_thread(
self.qdrant.query_points,
collection_name=collection,
query=vec,
limit=msg.limit,
with_payload=True,
).points
)
search_result = result.points
chunks = []
for r in search_result:
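
Note: asyncio.to_thread forwards positional and keyword arguments to the blocking callable, so Qdrant calls move off the event loop unchanged; when the blocking work is an expression rather than a single call, a lambda does the job, as the rows index service further down does. A sketch with placeholder names:

    import asyncio

    async def search(client, collection, vec):
        result = await asyncio.to_thread(
            client.query_points,
            collection_name=collection,
            query=vec,
            limit=10,
            with_payload=True,
        )
        return result.points  # attribute access back on the event loop

    async def collection_names(client):
        return await asyncio.to_thread(
            lambda: client.get_collections().collections
        )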

View file

Graph embeddings query service. Input is a vector, output is a list of
entities
"""
import asyncio
import logging
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
@ -38,32 +37,6 @@ class Processor(GraphEmbeddingsQueryService):
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.last_collection = None
def ensure_collection_exists(self, collection, dim):
"""Ensure collection exists, create if it doesn't"""
if collection != self.last_collection:
if not self.qdrant.collection_exists(collection):
try:
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
logger.info(f"Created collection: {collection}")
except Exception as e:
logger.error(f"Qdrant collection creation failed: {e}")
raise e
self.last_collection = collection
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
def collection_exists(self, collection):
"""Check if collection exists (no implicit creation)"""
return self.qdrant.collection_exists(collection)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@ -79,23 +52,26 @@ class Processor(GraphEmbeddingsQueryService):
if not vec:
return []
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{workspace}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection
)
if not exists:
logger.info(f"Collection {collection} does not exist")
return []
# Heuristic hack: get (2*limit), so that we have a better chance
# of getting (limit) unique entities
search_result = self.qdrant.query_points(
result = await asyncio.to_thread(
self.qdrant.query_points,
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
)
search_result = result.points
entity_set = set()
entities = []

View file

@ -6,6 +6,7 @@ Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""
import asyncio
import logging
import re
from typing import Optional
@ -70,7 +71,7 @@ class Processor(FlowProcessor):
safe_name = 'r_' + safe_name
return safe_name.lower()
def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
async def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given workspace/collection/schema"""
prefix = (
f"rows_{self.sanitize_name(workspace)}_"
@ -78,14 +79,15 @@ class Processor(FlowProcessor):
)
try:
all_collections = self.qdrant.get_collections().collections
all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
]
if matching:
# Return first match (there should typically be only one per dimension)
return matching[0]
except Exception as e:
@ -100,8 +102,7 @@ class Processor(FlowProcessor):
if not vec:
return []
# Find the collection for this workspace/collection/schema
qdrant_collection = self.find_collection(
qdrant_collection = await self.find_collection(
workspace, request.collection, request.schema_name
)
@ -113,7 +114,6 @@ class Processor(FlowProcessor):
return []
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
@ -125,16 +125,16 @@ class Processor(FlowProcessor):
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
result = await asyncio.to_thread(
self.qdrant.query_points,
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
)
search_result = result.points
# Convert to RowIndexMatch objects
matches = []
for point in search_result:
payload = point.payload or {}

View file

@ -11,6 +11,7 @@ Queries against the unified 'rows' table with schema:
- source: text
"""
import asyncio
import json
import logging
import re
@ -97,12 +98,14 @@ class Processor(FlowProcessor):
# Cassandra session
self.cluster = None
self.session = None
self._setup_lock = asyncio.Lock()
# Known keyspaces
self.known_keyspaces: Set[str] = set()
def connect_cassandra(self):
async def connect_cassandra(self):
"""Connect to Cassandra cluster"""
async with self._setup_lock:
if self.session:
return
@ -112,14 +115,16 @@ class Processor(FlowProcessor):
username=self.cassandra_username,
password=self.cassandra_password
)
self.cluster = Cluster(
cluster = Cluster(
contact_points=self.cassandra_host,
auth_provider=auth_provider
)
else:
self.cluster = Cluster(contact_points=self.cassandra_host)
cluster = Cluster(contact_points=self.cassandra_host)
self.session = self.cluster.connect()
session = await asyncio.to_thread(cluster.connect)
self.cluster = cluster
self.session = session
logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")
except Exception as e:
@ -140,14 +145,17 @@ class Processor(FlowProcessor):
f"for workspace {workspace}"
)
# Replace existing schemas for this workspace
async with self._setup_lock:
await self._apply_schema_config(workspace, config)
async def _apply_schema_config(self, workspace, config):
ws_schemas: Dict[str, RowSchema] = {}
self.schemas[workspace] = ws_schemas
builder = GraphQLSchemaBuilder()
self.schema_builders[workspace] = builder
# Check if our config type exists
if self.config_key not in config:
logger.warning(
f"No '{self.config_key}' type in configuration "
@ -156,16 +164,12 @@ class Processor(FlowProcessor):
self.graphql_schemas[workspace] = None
return
# Get the schemas dictionary for our type
schemas_config = config[self.config_key]
# Process each schema in the schemas config
for schema_name, schema_json in schemas_config.items():
try:
# Parse the JSON schema definition
schema_def = json.loads(schema_json)
# Create Field objects
fields = []
for field_def in schema_def.get("fields", []):
field = SchemaField(
@ -180,7 +184,6 @@ class Processor(FlowProcessor):
)
fields.append(field)
# Create RowSchema
row_schema = RowSchema(
name=schema_def.get("name", schema_name),
description=schema_def.get("description", ""),
@ -202,7 +205,6 @@ class Processor(FlowProcessor):
f"{len(ws_schemas)} schemas"
)
# Regenerate GraphQL schema for this workspace
self.graphql_schemas[workspace] = builder.build(self.query_cassandra)
def get_index_names(self, schema: RowSchema) -> List[str]:
@ -254,7 +256,7 @@ class Processor(FlowProcessor):
For other queries, we need to scan and post-filter.
"""
# Connect if needed
self.connect_cassandra()
await self.connect_cassandra()
safe_keyspace = self.sanitize_name(workspace)
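
Note: the guard in connect_cassandra() above is a standard double-checked setup under an asyncio.Lock: concurrent coroutines cannot race to build two clusters, and the blocking connect happens in a worker thread. Reduced to its essentials (names are placeholders):

    import asyncio

    class Service:
        def __init__(self):
            self.session = None
            self._setup_lock = asyncio.Lock()

        async def connect(self):
            async with self._setup_lock:
                if self.session:  # another coroutine already connected
                    return
                # _blocking_connect: placeholder for the real
                # Cluster(...).connect() call
                self.session = await asyncio.to_thread(self._blocking_connect)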

View file

@ -30,14 +30,13 @@ class EvaluationError(Exception):
pass
async def evaluate(node, triples_client, workspace, collection, limit=10000):
async def evaluate(node, triples_client, collection, limit=10000):
"""
Evaluate a SPARQL algebra node.
Args:
node: rdflib CompValue algebra node
triples_client: TriplesClient instance for triple pattern queries
workspace: workspace/keyspace identifier
collection: collection identifier
limit: safety limit on results
@ -55,24 +54,24 @@ async def evaluate(node, triples_client, workspace, collection, limit=10000):
logger.warning(f"Unsupported algebra node: {name}")
return [{}]
return await handler(node, triples_client, workspace, collection, limit)
return await handler(node, triples_client, collection, limit)
# --- Node handlers ---
async def _eval_select_query(node, tc, workspace, collection, limit):
async def _eval_select_query(node, tc, collection, limit):
"""Evaluate a SelectQuery node."""
return await evaluate(node.p, tc, workspace, collection, limit)
return await evaluate(node.p, tc, collection, limit)
async def _eval_project(node, tc, workspace, collection, limit):
async def _eval_project(node, tc, collection, limit):
"""Evaluate a Project node (SELECT variable projection)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
variables = [str(v) for v in node.PV]
return project(solutions, variables)
async def _eval_bgp(node, tc, workspace, collection, limit):
async def _eval_bgp(node, tc, collection, limit):
"""
Evaluate a Basic Graph Pattern.
@ -107,7 +106,7 @@ async def _eval_bgp(node, tc, workspace, collection, limit):
# Query the triples store
results = await _query_pattern(
tc, s_val, p_val, o_val, workspace, collection, limit
tc, s_val, p_val, o_val, collection, limit
)
# Map results back to variable bindings,
@ -130,17 +129,17 @@ async def _eval_bgp(node, tc, workspace, collection, limit):
return solutions[:limit]
async def _eval_join(node, tc, workspace, collection, limit):
async def _eval_join(node, tc, collection, limit):
"""Evaluate a Join node."""
left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit)
left = await evaluate(node.p1, tc, collection, limit)
right = await evaluate(node.p2, tc, collection, limit)
return hash_join(left, right)[:limit]
async def _eval_left_join(node, tc, workspace, collection, limit):
async def _eval_left_join(node, tc, collection, limit):
"""Evaluate a LeftJoin node (OPTIONAL)."""
left_sols = await evaluate(node.p1, tc, workspace, collection, limit)
right_sols = await evaluate(node.p2, tc, workspace, collection, limit)
left_sols = await evaluate(node.p1, tc, collection, limit)
right_sols = await evaluate(node.p2, tc, collection, limit)
filter_fn = None
if hasattr(node, "expr") and node.expr is not None:
@ -153,16 +152,16 @@ async def _eval_left_join(node, tc, workspace, collection, limit):
return left_join(left_sols, right_sols, filter_fn)[:limit]
async def _eval_union(node, tc, workspace, collection, limit):
async def _eval_union(node, tc, collection, limit):
"""Evaluate a Union node."""
left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit)
left = await evaluate(node.p1, tc, collection, limit)
right = await evaluate(node.p2, tc, collection, limit)
return union(left, right)[:limit]
async def _eval_filter(node, tc, workspace, collection, limit):
async def _eval_filter(node, tc, collection, limit):
"""Evaluate a Filter node."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
expr = node.expr
return [
sol for sol in solutions
@ -170,22 +169,22 @@ async def _eval_filter(node, tc, workspace, collection, limit):
]
async def _eval_distinct(node, tc, workspace, collection, limit):
async def _eval_distinct(node, tc, collection, limit):
"""Evaluate a Distinct node."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
return distinct(solutions)
async def _eval_reduced(node, tc, workspace, collection, limit):
async def _eval_reduced(node, tc, collection, limit):
"""Evaluate a Reduced node (like Distinct but implementation-defined)."""
# Treat same as Distinct
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
return distinct(solutions)
async def _eval_order_by(node, tc, workspace, collection, limit):
async def _eval_order_by(node, tc, collection, limit):
"""Evaluate an OrderBy node."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
key_fns = []
for cond in node.expr:
@ -206,7 +205,7 @@ async def _eval_order_by(node, tc, workspace, collection, limit):
return order_by(solutions, key_fns)
async def _eval_slice(node, tc, workspace, collection, limit):
async def _eval_slice(node, tc, collection, limit):
"""Evaluate a Slice node (LIMIT/OFFSET)."""
# Pass tighter limit downstream if possible
inner_limit = limit
@ -214,13 +213,13 @@ async def _eval_slice(node, tc, workspace, collection, limit):
offset = node.start or 0
inner_limit = min(limit, offset + node.length)
solutions = await evaluate(node.p, tc, workspace, collection, inner_limit)
solutions = await evaluate(node.p, tc, collection, inner_limit)
return slice_solutions(solutions, node.start or 0, node.length)
async def _eval_extend(node, tc, workspace, collection, limit):
async def _eval_extend(node, tc, collection, limit):
"""Evaluate an Extend node (BIND)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
var_name = str(node.var)
expr = node.expr
@ -246,9 +245,9 @@ async def _eval_extend(node, tc, workspace, collection, limit):
return result
async def _eval_group(node, tc, workspace, collection, limit):
async def _eval_group(node, tc, collection, limit):
"""Evaluate a Group node (GROUP BY with aggregation)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
# Extract grouping expressions
group_exprs = []
@ -289,9 +288,9 @@ async def _eval_group(node, tc, workspace, collection, limit):
return result
async def _eval_aggregate_join(node, tc, workspace, collection, limit):
async def _eval_aggregate_join(node, tc, collection, limit):
"""Evaluate an AggregateJoin (aggregation functions after GROUP BY)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit)
solutions = await evaluate(node.p, tc, collection, limit)
result = []
for sol in solutions:
@ -310,7 +309,7 @@ async def _eval_aggregate_join(node, tc, workspace, collection, limit):
return result
async def _eval_graph(node, tc, workspace, collection, limit):
async def _eval_graph(node, tc, collection, limit):
"""Evaluate a Graph node (GRAPH clause)."""
term = node.term
@ -319,16 +318,16 @@ async def _eval_graph(node, tc, workspace, collection, limit):
# We'd need to pass graph to triples queries
# For now, evaluate inner pattern normally
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
return await evaluate(node.p, tc, workspace, collection, limit)
return await evaluate(node.p, tc, collection, limit)
elif isinstance(term, Variable):
# GRAPH ?g { ... } — variable graph
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
return await evaluate(node.p, tc, workspace, collection, limit)
return await evaluate(node.p, tc, collection, limit)
else:
return await evaluate(node.p, tc, workspace, collection, limit)
return await evaluate(node.p, tc, collection, limit)
async def _eval_values(node, tc, workspace, collection, limit):
async def _eval_values(node, tc, collection, limit):
"""Evaluate a VALUES clause (inline data)."""
variables = [str(v) for v in node.var]
solutions = []
@@ -343,9 +342,9 @@ async def _eval_values(node, tc, workspace, collection, limit):
return solutions
async def _eval_to_multiset(node, tc, workspace, collection, limit):
async def _eval_to_multiset(node, tc, collection, limit):
"""Evaluate a ToMultiSet node (subquery)."""
return await evaluate(node.p, tc, workspace, collection, limit)
return await evaluate(node.p, tc, collection, limit)
# --- Aggregate computation ---
@@ -487,7 +486,7 @@ def _resolve_term(tmpl, solution):
return rdflib_term_to_term(tmpl)
async def _query_pattern(tc, s, p, o, workspace, collection, limit):
async def _query_pattern(tc, s, p, o, collection, limit):
"""
Issue a streaming triple pattern query via TriplesClient.
@@ -496,7 +495,6 @@ async def _query_pattern(tc, s, p, o, workspace, collection, limit):
results = await tc.query(
s=s, p=p, o=o,
limit=limit,
workspace=workspace,
collection=collection,
)
return results
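
A minimal sketch of the evaluator's shape after this change, assuming rdflib-style algebra nodes that expose a `.name` attribute; the `_eval_*` handlers are the functions in this module, not redefined here. Every handler now shares the `(node, tc, collection, limit)` signature, and `workspace` appears nowhere because the `tc` handed in is already flow-scoped:

```python
# Sketch only; dispatch by algebra node name to this module's handlers.
HANDLERS = {
    "BGP": _eval_bgp,
    "Filter": _eval_filter,
    "Distinct": _eval_distinct,
    "Slice": _eval_slice,
    # ... one entry per supported algebra node type
}

async def evaluate(node, tc, collection, limit):
    handler = HANDLERS.get(node.name)
    if handler is None:
        raise NotImplementedError(f"SPARQL algebra node {node.name}")
    return await handler(node, tc, collection, limit)
```
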


@@ -141,7 +141,6 @@ class Processor(FlowProcessor):
solutions = await evaluate(
parsed.algebra,
triples_client,
workspace=flow.workspace,
collection=request.collection or "default",
limit=request.limit or 10000,
)
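
A hypothetical caller for context, assuming rdflib's `prepareQuery` is what produces the `parsed.algebra` tree the processor hands to `evaluate()`; the defaults mirror the hunk above:

```python
from rdflib.plugins.sparql import prepareQuery

async def run_query(triples_client):
    q = prepareQuery("SELECT ?s WHERE { ?s ?p ?o } LIMIT 10")
    # Same call shape as the processor; no workspace argument anywhere.
    return await evaluate(q.algebra, triples_client,
                          collection="default", limit=10000)
```
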


@@ -6,8 +6,8 @@ null. Output is a list of quads.
import asyncio
import logging
import json
from cassandra.query import SimpleStatement
from .... direct.cassandra_kg import (
@@ -176,45 +176,42 @@ class Processor(TriplesQueryService):
self.cassandra_host = hosts
self.cassandra_username = username
self.cassandra_password = password
self.table = None
def ensure_connection(self, workspace):
"""Ensure we have a connection to the correct keyspace."""
if workspace != self.table:
KGClass = EntityCentricKnowledgeGraph
self._connections = {}
self._conn_lock = asyncio.Lock()
async def _get_connection(self, workspace):
async with self._conn_lock:
if workspace not in self._connections:
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password
password=self.cassandra_password,
)
else:
self.tg = KGClass(
tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
)
self.table = workspace
self._connections[workspace] = tg
return self._connections[workspace]
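
The per-workspace cache above is a reusable shape; here it is as a standalone sketch (the names are mine, not the shipped code): one asyncio.Lock guards a dict of per-key resources, and the blocking constructor runs in a worker thread so the event loop never stalls.

```python
import asyncio

class ConnectionCache:
    def __init__(self, factory):
        self._factory = factory     # blocking constructor, e.g. an ECKG
        self._conns = {}
        self._lock = asyncio.Lock()

    async def get(self, key):
        async with self._lock:
            if key not in self._conns:
                # to_thread keeps the loop responsive during sync setup
                self._conns[key] = await asyncio.to_thread(self._factory, key)
            return self._conns[key]
```

Holding the lock across the to_thread call serialises first-time construction across workspaces, which matches the code above; since construction happens once per workspace, that is presumably an acceptable price for a simple invariant.
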
async def query_triples(self, workspace, query):
try:
# ensure_connection may construct a fresh
# EntityCentricKnowledgeGraph which does sync schema
# setup against Cassandra. Push it to a worker thread
# so the event loop doesn't block on first-use per workspace.
await asyncio.to_thread(self.ensure_connection, workspace)
# Extract values from query
s_val = get_term_value(query.s)
p_val = get_term_value(query.p)
o_val = get_term_value(query.o)
g_val = query.g # Already a string or None
g_val = query.g
tg = await self._get_connection(workspace)
def get_object_metadata(row):
"""Extract term type metadata from result row"""
return (
getattr(row, 'otype', None),
getattr(row, 'dtype', None),
@@ -223,33 +220,21 @@ class Processor(TriplesQueryService):
quads = []
# All self.tg.get_* calls below are sync wrappers around
# cassandra session.execute. Materialise inside a worker
# thread so iteration never triggers sync paging back on
# the event loop.
# Route to appropriate query method based on which fields are specified
if s_val is not None:
if p_val is not None:
if o_val is not None:
# SPO specified - find matching graphs
resp = await asyncio.to_thread(
lambda: list(self.tg.get_spo(
resp = await tg.async_get_spo(
query.collection, s_val, p_val, o_val,
g=g_val, limit=query.limit,
))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
else:
# SP specified
resp = await asyncio.to_thread(
lambda: list(self.tg.get_sp(
resp = await tg.async_get_sp(
query.collection, s_val, p_val,
g=g_val, limit=query.limit,
))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@@ -257,24 +242,18 @@ class Processor(TriplesQueryService):
quads.append((s_val, p_val, t.o, g, term_type, datatype, language))
else:
if o_val is not None:
# SO specified
resp = await asyncio.to_thread(
lambda: list(self.tg.get_os(
resp = await tg.async_get_os(
query.collection, o_val, s_val,
g=g_val, limit=query.limit,
))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
else:
# S only
resp = await asyncio.to_thread(
lambda: list(self.tg.get_s(
resp = await tg.async_get_s(
query.collection, s_val,
g=g_val, limit=query.limit,
))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@@ -283,24 +262,18 @@ class Processor(TriplesQueryService):
else:
if p_val is not None:
if o_val is not None:
# PO specified
resp = await asyncio.to_thread(
lambda: list(self.tg.get_po(
resp = await tg.async_get_po(
query.collection, p_val, o_val,
g=g_val, limit=query.limit,
))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
else:
# P only
resp = await asyncio.to_thread(
lambda: list(self.tg.get_p(
resp = await tg.async_get_p(
query.collection, p_val,
g=g_val, limit=query.limit,
))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@@ -308,40 +281,26 @@ class Processor(TriplesQueryService):
quads.append((t.s, p_val, t.o, g, term_type, datatype, language))
else:
if o_val is not None:
# O only
resp = await asyncio.to_thread(
lambda: list(self.tg.get_o(
resp = await tg.async_get_o(
query.collection, o_val,
g=g_val, limit=query.limit,
))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
else:
# Nothing specified - get all
resp = await asyncio.to_thread(
lambda: list(self.tg.get_all(
resp = await tg.async_get_all(
query.collection, limit=query.limit,
))
)
for t in resp:
# Note: quads_by_collection uses 'd' for graph field
g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
# Filter by graph
# g_val=None means all graphs (no filter)
# g_val="" means default graph only
# otherwise filter to specific named graph
if g_val is not None:
if g != g_val:
continue
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, t.p, t.o, g, term_type, datatype, language))
# Convert to Triple objects (with g field)
# s and p are always IRIs in RDF
# Object uses term_type/datatype/language metadata from database
triples = [
Triple(
s=create_term(q[0], term_type='u'),
@@ -365,51 +324,41 @@ class Processor(TriplesQueryService):
Uses Cassandra's paging to fetch results incrementally.
"""
try:
await asyncio.to_thread(self.ensure_connection, workspace)
batch_size = query.batch_size if query.batch_size > 0 else 20
limit = query.limit if query.limit > 0 else 10000
# Extract query pattern
s_val = get_term_value(query.s)
p_val = get_term_value(query.p)
o_val = get_term_value(query.o)
g_val = query.g
def get_object_metadata(row):
"""Extract term type metadata from result row"""
return (
getattr(row, 'otype', None),
getattr(row, 'dtype', None),
getattr(row, 'lang', None),
)
# For streaming, we need to execute with fetch_size
# Use the collection table for get_all queries (most common streaming case)
# Determine which query to use based on pattern
if s_val is None and p_val is None and o_val is None:
# Get all - use collection table with paging
cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"
tg = await self._get_connection(workspace)
cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {tg.collection_table} WHERE collection = %s"
params = [query.collection]
statement = SimpleStatement(cql, fetch_size=batch_size)
# async_execute only materialises the first page;
# this query needs all pages, so use sync execute
# in a worker thread where page iteration can block.
result_set = await asyncio.to_thread(
lambda: list(tg.session.execute(statement, params))
)
else:
# For specific patterns, fall back to non-streaming
# (these typically return small result sets anyway)
async for batch, is_final in self._fallback_stream(workspace, query, batch_size):
yield batch, is_final
return
# Materialise in a worker thread. We lose true streaming
# paging (the driver fetches all pages eagerly inside the
# thread) but the event loop stays responsive, and result
# sets at this layer are typically small enough that this
# is acceptable. If true async paging is needed later,
# revisit using ResponseFuture page callbacks.
statement = SimpleStatement(cql, fetch_size=batch_size)
result_set = await asyncio.to_thread(
lambda: list(self.tg.session.execute(statement, params))
)
batch = []
count = 0
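
A condensed sketch of the paging pitfall this hunk fixes, using the same driver calls; `async_execute` refers to the project bridge named in the comment above, everything else is illustrative:

```python
import asyncio
from cassandra.query import SimpleStatement

async def fetch_all_rows(session, cql: str, params: tuple, batch_size: int = 20):
    """Exhaust every page of a fetch_size-limited Cassandra query."""
    stmt = SimpleStatement(cql, fetch_size=batch_size)
    # Wrong for full scans: the project's async_execute bridge only
    # materialises the first page, so a fetch_size of 20 yields 20 rows.
    # Right: list() inside a worker thread drives the sync driver's
    # paging to exhaustion while the event loop stays free.
    return await asyncio.to_thread(lambda: list(session.execute(stmt, params)))
```
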


@@ -3,11 +3,13 @@
Accepts entity/vector pairs and writes them to a Qdrant store.
"""
import asyncio
import uuid
import logging
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
import logging
from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
@@ -35,13 +37,35 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def ensure_collection(self, collection_name, dim):
async with self._cache_lock:
if collection_name in self._known_collections:
return
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection_name
)
if not exists:
logger.info(
f"Lazily creating Qdrant collection {collection_name} "
f"with dimension {dim}"
)
await asyncio.to_thread(
self.qdrant.create_collection,
collection_name=collection_name,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
self._known_collections.add(collection_name)
async def store_document_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for workspace {workspace} "
@@ -60,24 +84,15 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
if not vec:
continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{workspace}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
await self.ensure_collection(collection, dim)
self.qdrant.upsert(
await asyncio.to_thread(
self.qdrant.upsert,
collection_name=collection,
points=[
PointStruct(
@@ -87,7 +102,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
"chunk_id": chunk_id,
}
)
]
],
)
@staticmethod
@@ -124,8 +139,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
try:
prefix = f"d_{workspace}_{collection}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
@@ -135,7 +151,11 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
logger.info(f"No collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
await asyncio.to_thread(
self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")


@@ -3,11 +3,13 @@
Accepts entity/vector pairs and writes them to a Qdrant store.
"""
import asyncio
import uuid
import logging
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
import logging
from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer
@@ -50,13 +52,35 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def ensure_collection(self, collection_name, dim):
async with self._cache_lock:
if collection_name in self._known_collections:
return
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection_name
)
if not exists:
logger.info(
f"Lazily creating Qdrant collection {collection_name} "
f"with dimension {dim}"
)
await asyncio.to_thread(
self.qdrant.create_collection,
collection_name=collection_name,
vectors_config=VectorParams(
size=dim, distance=Distance.COSINE
),
)
self._known_collections.add(collection_name)
async def store_graph_embeddings(self, workspace, message):
# Validate collection exists in config before processing
if not self.collection_exists(workspace, message.metadata.collection):
logger.warning(
f"Collection {message.metadata.collection} for workspace {workspace} "
@@ -75,22 +99,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if not vec:
continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{workspace}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
await self.ensure_collection(collection, dim)
payload = {
"entity": entity_value,
@@ -98,7 +112,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
await asyncio.to_thread(
self.qdrant.upsert,
collection_name=collection,
points=[
PointStruct(
@@ -106,7 +121,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
vector=vec,
payload=payload,
)
]
],
)
@staticmethod
@@ -143,8 +158,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
try:
prefix = f"t_{workspace}_{collection}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
@@ -154,7 +170,11 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
logger.info(f"No collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
await asyncio.to_thread(
self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")


@@ -16,10 +16,10 @@ Payload structure:
- text: The text that was embedded (for debugging/display)
"""
import asyncio
import logging
import re
import uuid
from typing import Set, Tuple
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Distance, VectorParams
@@ -63,11 +63,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Register config handler for collection management
self.register_config_handler(self.on_collection_config, types=["collection"])
# Cache of created Qdrant collections
self.created_collections: Set[str] = set()
# Qdrant client
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self._cache_lock = asyncio.Lock()
self._known_collections: set[str] = set()
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
@@ -85,25 +83,28 @@ class Processor(CollectionConfigHandler, FlowProcessor):
safe_schema = self.sanitize_name(schema_name)
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
def ensure_collection(self, collection_name: str, dimension: int):
async def ensure_collection(self, collection_name: str, dimension: int):
"""Create Qdrant collection if it doesn't exist"""
if collection_name in self.created_collections:
async with self._cache_lock:
if collection_name in self._known_collections:
return
if not self.qdrant.collection_exists(collection_name):
exists = await asyncio.to_thread(
self.qdrant.collection_exists, collection_name
)
if not exists:
logger.info(
f"Creating Qdrant collection {collection_name} "
f"with dimension {dimension}"
)
self.qdrant.create_collection(
await asyncio.to_thread(
self.qdrant.create_collection,
collection_name=collection_name,
vectors_config=VectorParams(
size=dimension,
distance=Distance.COSINE
),
)
)
self.created_collections.add(collection_name)
self._known_collections.add(collection_name)
async def on_embeddings(self, msg, consumer, flow):
"""Process incoming RowEmbeddings and write to Qdrant"""
@@ -143,15 +144,14 @@ class Processor(CollectionConfigHandler, FlowProcessor):
dimension = len(vector)
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
workspace, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
await self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
await asyncio.to_thread(
self.qdrant.upsert,
collection_name=qdrant_collection,
points=[
PointStruct(
@@ -163,7 +163,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"text": row_emb.text
}
)
]
],
)
embeddings_written += 1
@@ -181,8 +181,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
try:
prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
@@ -192,8 +193,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"No Qdrant collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
self.created_collections.discard(collection_name)
await asyncio.to_thread(
self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(
f"Deleted {len(matching_collections)} collection(s) "
@@ -217,8 +221,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
all_collections = await asyncio.to_thread(
lambda: self.qdrant.get_collections().collections
)
matching_collections = [
coll.name for coll in all_collections
if coll.name.startswith(prefix)
@@ -228,8 +233,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"No Qdrant collections found matching prefix {prefix}")
else:
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
self.created_collections.discard(collection_name)
await asyncio.to_thread(
self.qdrant.delete_collection, collection_name
)
async with self._cache_lock:
self._known_collections.discard(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
except Exception as e:


@@ -82,7 +82,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
# Cache of known keyspaces and whether tables exist
self.known_keyspaces: Set[str] = set()
self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables
self.tables_initialized: Set[str] = set()
# Cache of registered (collection, schema_name) pairs
self.registered_partitions: Set[Tuple[str, str]] = set()
@@ -94,6 +94,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
self.cluster = None
self.session = None
# Protects connection setup and cache mutations
self._setup_lock = asyncio.Lock()
def connect_cassandra(self):
"""Connect to Cassandra cluster"""
if self.session:
@@ -126,6 +129,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"for workspace {workspace}"
)
async with self._setup_lock:
return await self._apply_schema_config(workspace, config, version)
async def _apply_schema_config(self, workspace, config, version):
# Track which schemas changed in this workspace
old_schemas = self.schemas.get(workspace, {})
old_schema_names = set(old_schemas.keys())
@@ -391,12 +399,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
schema_name = obj.schema_name
source = getattr(obj.metadata, 'source', '') or ''
# Ensure tables exist (sync DDL — push to a worker thread
# so the event loop stays responsive when running in a
# processor group sharing the loop with siblings).
async with self._setup_lock:
await asyncio.to_thread(self.ensure_tables, keyspace)
# Register partitions if first time seeing this (collection, schema_name)
await asyncio.to_thread(
self.register_partitions,
keyspace, collection, schema_name, workspace,
@@ -461,35 +465,27 @@ class Processor(CollectionConfigHandler, FlowProcessor):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store"""
# Connect if not already connected (sync, push to thread)
async with self._setup_lock:
await asyncio.to_thread(self.connect_cassandra)
# Ensure tables exist (sync DDL, push to thread)
await asyncio.to_thread(self.ensure_tables, workspace)
logger.info(f"Collection {collection} ready for workspace {workspace}")
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection using partition tracking"""
# Connect if not already connected
async with self._setup_lock:
await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(workspace)
# Check if keyspace exists
if workspace not in self.known_keyspaces:
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
result = await async_execute(
self.session, check_keyspace_cql, (safe_keyspace,)
)
safe_ks = self.sanitize_name(workspace)
check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s"
result = await async_execute(self.session, check_cql, (safe_ks,))
if not result:
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
logger.info(f"Keyspace {safe_ks} does not exist, nothing to delete")
return
self.known_keyspaces.add(workspace)
safe_keyspace = self.sanitize_name(workspace)
# Discover all partitions for this collection
select_partitions_cql = f"""
SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
@@ -540,7 +536,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
raise
# Clear from local cache
async with self._setup_lock:
self.registered_partitions = {
(col, sch) for col, sch in self.registered_partitions
if col != collection
@@ -553,7 +549,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination"""
# Connect if not already connected
async with self._setup_lock:
await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(workspace)
@@ -614,7 +610,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
)
raise
# Clear from local cache
async with self._setup_lock:
self.registered_partitions.discard((collection, schema_name))
logger.info(

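Why asyncio.Lock rather than threading.local or threading.Lock throughout these hunks: all coroutines in a process typically share one event-loop thread, so threading.local gives them a single shared slot rather than isolation, and every await is a point where another coroutine can interleave a check-then-mutate on cached state. A minimal sketch of the guard pattern used in this file:

```python
import asyncio

_lock = asyncio.Lock()
_known_keyspaces: set[str] = set()

async def first_use(keyspace: str) -> bool:
    """True exactly once per keyspace, even under concurrent callers."""
    async with _lock:          # serialise the check-then-mutate
        if keyspace in _known_keyspaces:
            return False
        _known_keyspaces.add(keyspace)
        return True
```
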

@@ -4,12 +4,7 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph.
"""
import asyncio
import base64
import os
import argparse
import time
import logging
import json
from .... direct.cassandra_kg import (
EntityCentricKnowledgeGraph, DEFAULT_GRAPH
@@ -28,6 +23,8 @@ default_ident = "triples-write"
def serialize_triple(triple):
"""Serialize a Triple object to JSON for storage."""
import json
if triple is None:
return None
@@ -141,63 +138,48 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
self.cassandra_host = hosts
self.cassandra_username = username
self.cassandra_password = password
self.table = None
self.tg = None
self._connections = {}
self._conn_lock = asyncio.Lock()
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
async def store_triples(self, workspace, message):
# The cassandra-driver work below — connection, schema
# setup, and per-triple inserts — is all synchronous.
# Wrap the whole batch in a worker thread so the event
# loop stays responsive for sibling processors when
# running in a processor group.
def _do_store():
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
async def _get_connection(self, workspace):
async with self._conn_lock:
if workspace not in self._connections:
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
tg = await asyncio.to_thread(
EntityCentricKnowledgeGraph,
hosts=self.cassandra_host,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Exception: {e}", exc_info=True)
time.sleep(1)
raise e
self._connections[workspace] = tg
return self._connections[workspace]
self.table = workspace
async def store_triples(self, workspace, message):
tg = await self._get_connection(workspace)
for t in message.triples:
# Extract values from Term objects
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
o_val = get_term_value(t.o)
# t.g is None for default graph, or a graph IRI
g_val = t.g if t.g is not None else DEFAULT_GRAPH
# Extract object type metadata for entity-centric storage
otype = get_term_otype(t.o)
dtype = get_term_dtype(t.o)
lang = get_term_lang(t.o)
self.tg.insert(
await tg.async_insert(
message.metadata.collection,
s_val,
p_val,
@@ -208,89 +190,32 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
lang=lang,
)
await asyncio.to_thread(_do_store)
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create a collection in Cassandra triple store via config push"""
def _do_create():
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
tg = await self._get_connection(workspace)
self.table = workspace
# Create collection using the built-in method
logger.info(f"Creating collection {collection} for workspace {workspace}")
if self.tg.collection_exists(collection):
exists = await tg.async_collection_exists(collection)
if exists:
logger.info(f"Collection {collection} already exists")
else:
self.tg.create_collection(collection)
await tg.async_create_collection(collection)
logger.info(f"Created collection {collection}")
try:
await asyncio.to_thread(_do_create)
except Exception as e:
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection from the unified triples table"""
def _do_delete():
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
try:
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
tg = await self._get_connection(workspace)
self.table = workspace
# Delete all triples for this collection using the built-in method
self.tg.delete_collection(collection)
await tg.async_delete_collection(collection)
logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
try:
await asyncio.to_thread(_do_delete)
except Exception as e:
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
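
For reference, a rough sketch of what an async_* method on EntityCentricKnowledgeGraph plausibly looks like, given it rides the cassandra_async.async_execute bridge; the prepared-statement attribute and the bridge's import path are illustrative, not confirmed by this diff:

```python
# Assumes async_execute(session, statement, params) -> awaitable, i.e. the
# project's bridge that adapts a driver ResponseFuture; imported elsewhere.
class EntityCentricKnowledgeGraph:
    async def async_insert(self, collection, s, p, o, g=None,
                           otype=None, dtype=None, lang=None):
        # insert_stmt is a hypothetical prepared INSERT; the real class
        # presumably prepares its statements at construction time.
        await async_execute(
            self.session, self.insert_stmt,
            (collection, s, p, o, g, otype, dtype, lang),
        )
```
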


@@ -1,6 +1,7 @@
from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings
from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
from .. schema import DocumentEmbeddings, ChunkEmbeddings
from cassandra.cluster import Cluster
@@ -217,6 +218,16 @@ class KnowledgeTableStore:
WHERE workspace = ? AND document_id = ?
""")
self.delete_document_embeddings_stmt = self.cassandra.prepare("""
DELETE FROM document_embeddings
WHERE workspace = ? AND document_id = ?
""")
self.list_de_cores_stmt = self.cassandra.prepare("""
SELECT DISTINCT workspace, document_id FROM document_embeddings
WHERE workspace = ?
""")
async def add_triples(self, workspace, m):
when = int(time.time() * 1000)
@@ -338,6 +349,50 @@ class KnowledgeTableStore:
logger.error("Exception occurred", exc_info=True)
raise
try:
await async_execute(
self.cassandra,
self.delete_document_embeddings_stmt,
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
async def delete_document_embeddings(self, workspace, document_id):
logger.debug("Delete document embeddings...")
try:
await async_execute(
self.cassandra,
self.delete_document_embeddings_stmt,
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
async def list_de_cores(self, workspace):
logger.debug("List DE cores...")
try:
rows = await async_execute(
self.cassandra,
self.list_de_cores_stmt,
(workspace,),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
lst = [row[1] for row in rows]
logger.debug("Done")
return lst
async def get_triples(self, workspace, document_id, receiver):
logger.debug("Get triples...")
@@ -417,3 +472,42 @@ class KnowledgeTableStore:
logger.debug("Done")
async def get_document_embeddings(self, workspace, document_id, receiver):
logger.debug("Get DE...")
try:
rows = await async_execute(
self.cassandra,
self.get_document_embeddings_stmt,
(workspace, document_id),
)
except Exception:
logger.error("Exception occurred", exc_info=True)
raise
for row in rows:
if row[3]:
chunks = [
ChunkEmbeddings(
chunk_id=ch[0],
vector=ch[1],
)
for ch in row[3]
]
else:
chunks = []
await receiver(
DocumentEmbeddings(
metadata = Metadata(
id = document_id,
collection = "default",
),
chunks = chunks
)
)
logger.debug("Done")