From 142dd0231c7673b1f9700aaea07f80fbdc2a4236 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Fri, 15 May 2026 13:02:51 +0100 Subject: [PATCH] release/v2.4 -> master (#924) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CLI auth migration, document embeddings core lifecycle (#913) Migrate get_kg_core and put_kg_core CLI tools to use Api/SocketClient with first-frame auth (fixes broken raw websocket path). Fix wire format field names (root/vector). Remove ~600 lines of dead raw websocket code from invoke_graph_rag.py. Add document embeddings core lifecycle to the knowledge service: list/get/put/delete/load operations across schema, translator, Cassandra table store, knowledge manager, gateway registry, REST API, socket client, and CLI (tg-get-de-core, tg-put-de-core). Fix delete_kg_core to also clean up document embeddings rows. * Remove spurious workspace parameter from SPARQL algebra evaluator (#915) Fix threading of workspace paramater: - The SPARQL algebra evaluator was threading a workspace parameter through every function and passing it to TriplesClient.query(), which doesn't accept it. Workspace isolation is handled by pub/sub topic routing — the TriplesClient is already scoped to a workspace-specific flow, same as GraphRAG. Passing workspace explicitly was both incorrect and unnecessary. Update tests: - tests/unit/test_query/test_sparql_algebra.py (new) — Tests _query_pattern, _eval_bgp, and evaluate() with various algebra nodes. Key tests assert workspace is never in tc.query() kwargs, plus correctness tests for BGP, JOIN, UNION, SLICE, DISTINCT, and edge cases. - tests/unit/test_retrieval/test_graph_rag.py — Added test_triples_query_never_passes_workspace (checks query()) and test_follow_edges_never_passes_workspace (checks query_stream()). * Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916) Cassandra triples services were using syncronous EntityCentricKnowledgeGraph methods from async contexts, and connection state was managed with threading.local which is wrong for asyncio coroutines sharing a single thread. Qdrant services had no async wrapping at all, blocking the event loop on every network call. Rows services had unprotected shared state mutations across concurrent coroutines. - Add async methods to EntityCentricKnowledgeGraph (async_insert, async_get_s/p/o/sp/po/os/spo/all, async_collection_exists, async_create_collection, async_delete_collection) using the existing cassandra_async.async_execute bridge - Rewrite triples write + query services: replace threading.local with asyncio.Lock + dict cache for per-workspace connections, use async ECKG methods for all data operations, keep asyncio.to_thread only for one-time blocking ECKG construction - Wrap all Qdrant calls in asyncio.to_thread across all 6 services (doc/graph/row embeddings write + query), add asyncio.Lock + set cache for collection existence checks - Add asyncio.Lock to rows write + query services to protect shared state (schemas, sessions, config caches) from concurrent mutation - Update all affected tests to match new async patterns * Fixed error only returning a page of results (#921) The root cause: async_execute only materialises the first result page (by design — it says so in its docstring). The streaming query set fetch_size=20 and expected to iterate all results, but only got the first 20 rows back. The fix uses asyncio.to_thread(lambda: list(tg.session.execute(...))) which lets the sync driver iterate all pages in a worker thread — exactly what the pre-async code did. * Optional test warning suppression (#923) * Fix test collection module errors & silence upstream Pytest warnings (#823) * chore: add virtual environment and .env directories to gitignore * test: filter upstream DeprecationWarning and UserWarning messages * fix(namespace): remove empty __init__.py files to fix PEP 420 implicit namespace routing for trustgraph sub-packages * Revert __init__.py deletions * Add .ini changes but commented out, will be useful at times --------- Co-authored-by: Salil M --- .gitignore | 5 +- .../test_cassandra_config_end_to_end.py | 79 ++- .../test_rows_cassandra_integration.py | 3 + .../test_rows_graphql_query_integration.py | 12 +- tests/pytest.ini | 11 +- .../test_query/test_rows_cassandra_query.py | 7 +- tests/unit/test_query/test_sparql_algebra.py | 302 +++++++++ .../test_triples_cassandra_query.py | 112 ++-- .../test_null_embedding_protection.py | 12 + tests/unit/test_retrieval/test_graph_rag.py | 51 ++ .../test_doc_embeddings_qdrant_storage.py | 4 +- .../test_row_embeddings_qdrant_storage.py | 14 +- .../test_rows_cassandra_storage.py | 3 + .../test_triples_cassandra_storage.py | 78 +-- .../test_row_embeddings_query.py | 27 +- trustgraph-base/trustgraph/api/knowledge.py | 31 + .../trustgraph/api/socket_client.py | 52 ++ .../messaging/translators/knowledge.py | 56 +- .../trustgraph/schema/knowledge/knowledge.py | 6 +- trustgraph-cli/pyproject.toml | 2 + trustgraph-cli/trustgraph/cli/get_de_core.py | 111 ++++ trustgraph-cli/trustgraph/cli/get_kg_core.py | 77 +-- .../trustgraph/cli/invoke_graph_rag.py | 604 ------------------ trustgraph-cli/trustgraph/cli/put_de_core.py | 119 ++++ trustgraph-cli/trustgraph/cli/put_kg_core.py | 99 +-- trustgraph-flow/trustgraph/cores/knowledge.py | 325 +++++++--- trustgraph-flow/trustgraph/cores/service.py | 5 + .../trustgraph/direct/cassandra_kg.py | 226 ++++++- .../trustgraph/gateway/registry.py | 6 + .../query/doc_embeddings/qdrant/service.py | 42 +- .../query/graph_embeddings/qdrant/service.py | 42 +- .../query/row_embeddings/qdrant/service.py | 20 +- .../query/rows/cassandra/service.py | 60 +- .../trustgraph/query/sparql/algebra.py | 84 ++- .../trustgraph/query/sparql/service.py | 1 - .../query/triples/cassandra/service.py | 171 ++--- .../storage/doc_embeddings/qdrant/write.py | 58 +- .../storage/graph_embeddings/qdrant/write.py | 58 +- .../storage/row_embeddings/qdrant/write.py | 76 ++- .../storage/rows/cassandra/write.py | 78 ++- .../storage/triples/cassandra/write.py | 179 ++---- .../trustgraph/tables/knowledge.py | 94 +++ 42 files changed, 1910 insertions(+), 1492 deletions(-) create mode 100644 tests/unit/test_query/test_sparql_algebra.py create mode 100644 trustgraph-cli/trustgraph/cli/get_de_core.py create mode 100644 trustgraph-cli/trustgraph/cli/put_de_core.py diff --git a/.gitignore b/.gitignore index 32942156..366edb4a 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,7 @@ trustgraph-vertexai/trustgraph/vertexai_version.py trustgraph-unstructured/trustgraph/unstructured_version.py trustgraph-mcp/trustgraph/mcp_version.py trustgraph/trustgraph/trustgraph_version.py -vertexai/ \ No newline at end of file +vertexai/ +venv/ +.venv/ +.env diff --git a/tests/integration/test_cassandra_config_end_to_end.py b/tests/integration/test_cassandra_config_end_to_end.py index 514a5dbf..290e1348 100644 --- a/tests/integration/test_cassandra_config_end_to_end.py +++ b/tests/integration/test_cassandra_config_end_to_end.py @@ -63,26 +63,26 @@ class TestEndToEndConfigurationFlow: 'CASSANDRA_USERNAME': 'obj-user', 'CASSANDRA_PASSWORD': 'obj-pass' } - + mock_auth_instance = MagicMock() mock_auth_provider.return_value = mock_auth_instance mock_cluster_instance = MagicMock() mock_session = MagicMock() mock_cluster_instance.connect.return_value = mock_session mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) - + # Trigger Cassandra connection processor.connect_cassandra() - + # Verify auth provider was created with env vars mock_auth_provider.assert_called_once_with( username='obj-user', password='obj-pass' ) - + # Verify cluster was created with hosts from env and auth mock_cluster.assert_called_once() call_args = mock_cluster.call_args @@ -188,37 +188,34 @@ class TestConfigurationPriorityEndToEnd: ) @pytest.mark.asyncio - @patch('trustgraph.direct.cassandra_kg.Cluster') - async def test_no_config_defaults_end_to_end(self, mock_cluster): + @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') + async def test_no_config_defaults_end_to_end(self, mock_kg_class): """Test that defaults are used when no configuration provided end-to-end.""" - mock_cluster_instance = MagicMock() - mock_session = MagicMock() - mock_cluster_instance.connect.return_value = mock_session - mock_cluster.return_value = mock_cluster_instance - + from unittest.mock import AsyncMock + + mock_tg_instance = MagicMock() + mock_tg_instance.async_get_all = AsyncMock(return_value=[]) + mock_kg_class.return_value = mock_tg_instance + with patch.dict(os.environ, {}, clear=True): processor = TriplesQuery(taskgroup=MagicMock()) - + # Mock query to trigger TrustGraph creation mock_query = MagicMock() mock_query.collection = 'default_collection' mock_query.s = None mock_query.p = None mock_query.o = None + mock_query.g = None mock_query.limit = 100 - - # Mock the get_all method to return empty list - mock_tg_instance = MagicMock() - mock_tg_instance.get_all.return_value = [] - processor.tg = mock_tg_instance - + await processor.query_triples('default_user', mock_query) - + # Should use defaults - mock_cluster.assert_called_once() - call_args = mock_cluster.call_args - assert call_args.args[0] == ['cassandra'] # Default host - assert 'auth_provider' not in call_args.kwargs # No auth with default config + mock_kg_class.assert_called_once_with( + hosts=['cassandra'], + keyspace='default_user' + ) class TestNoBackwardCompatibilityEndToEnd: @@ -324,16 +321,16 @@ class TestMultipleHostsHandling: env_vars = { 'CASSANDRA_HOST': 'host1,host2,host3,host4,host5' } - + mock_cluster_instance = MagicMock() mock_session = MagicMock() mock_cluster_instance.connect.return_value = mock_session mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Verify all hosts were passed to Cluster mock_cluster.assert_called_once() call_args = mock_cluster.call_args @@ -392,27 +389,27 @@ class TestAuthenticationFlow: 'CASSANDRA_USERNAME': 'auth-user', 'CASSANDRA_PASSWORD': 'auth-secret' } - + mock_auth_instance = MagicMock() mock_auth_provider.return_value = mock_auth_instance mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Auth provider should be created mock_auth_provider.assert_called_once_with( username='auth-user', password='auth-secret' ) - + # Cluster should be created with auth provider call_args = mock_cluster.call_args assert 'auth_provider' in call_args.kwargs assert call_args.kwargs['auth_provider'] == mock_auth_instance - + @patch('trustgraph.storage.rows.cassandra.write.Cluster') @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_credentials_missing(self, mock_auth_provider, mock_cluster): @@ -421,21 +418,21 @@ class TestAuthenticationFlow: 'CASSANDRA_HOST': 'no-auth-host' # No username/password } - + mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + with patch.dict(os.environ, env_vars, clear=True): processor = RowsWriter(taskgroup=MagicMock()) processor.connect_cassandra() - + # Auth provider should not be created mock_auth_provider.assert_not_called() - + # Cluster should be created without auth provider call_args = mock_cluster.call_args assert 'auth_provider' not in call_args.kwargs - + @patch('trustgraph.storage.rows.cassandra.write.Cluster') @patch('trustgraph.storage.rows.cassandra.write.PlainTextAuthProvider') def test_no_authentication_when_only_username_provided(self, mock_auth_provider, mock_cluster): @@ -446,15 +443,15 @@ class TestAuthenticationFlow: cassandra_username='partial-user' # No password ) - + mock_cluster_instance = MagicMock() mock_cluster.return_value = mock_cluster_instance - + processor.connect_cassandra() - + # Auth provider should not be created (needs both username AND password) mock_auth_provider.assert_not_called() - + # Cluster should be created without auth provider call_args = mock_cluster.call_args assert 'auth_provider' not in call_args.kwargs \ No newline at end of file diff --git a/tests/integration/test_rows_cassandra_integration.py b/tests/integration/test_rows_cassandra_integration.py index 1358d420..d668600c 100644 --- a/tests/integration/test_rows_cassandra_integration.py +++ b/tests/integration/test_rows_cassandra_integration.py @@ -101,6 +101,8 @@ class TestRowsCassandraIntegration: processor.session = None # Bind actual methods from the new unified table implementation + import asyncio + processor._setup_lock = asyncio.Lock() processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor) processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor) processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) @@ -108,6 +110,7 @@ class TestRowsCassandraIntegration: processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.build_index_value = Processor.build_index_value.__get__(processor, Processor) processor.register_partitions = Processor.register_partitions.__get__(processor, Processor) + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) processor.on_object = Processor.on_object.__get__(processor, Processor) processor.collection_exists = MagicMock(return_value=True) diff --git a/tests/integration/test_rows_graphql_query_integration.py b/tests/integration/test_rows_graphql_query_integration.py index 29b4464d..a455accd 100644 --- a/tests/integration/test_rows_graphql_query_integration.py +++ b/tests/integration/test_rows_graphql_query_integration.py @@ -184,7 +184,7 @@ class TestObjectsGraphQLQueryIntegration: await processor.on_schema_config("default", sample_schema_config, version=1) # Connect to Cassandra - processor.connect_cassandra() + await processor.connect_cassandra() assert processor.session is not None # Create test keyspace and table @@ -219,7 +219,7 @@ class TestObjectsGraphQLQueryIntegration: """Test inserting data and querying via GraphQL""" # Load schema and connect await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Setup test data keyspace = "test_user" @@ -293,7 +293,7 @@ class TestObjectsGraphQLQueryIntegration: """Test GraphQL queries with filtering on indexed fields""" # Setup (reuse previous setup) await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() keyspace = "test_user" collection = "filter_test" @@ -387,7 +387,7 @@ class TestObjectsGraphQLQueryIntegration: """Test full message processing workflow""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Create mock message request = RowsQueryRequest( @@ -433,7 +433,7 @@ class TestObjectsGraphQLQueryIntegration: """Test handling multiple concurrent GraphQL queries""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() # Create multiple query tasks queries = [ @@ -519,7 +519,7 @@ class TestObjectsGraphQLQueryIntegration: """Test handling of large query result sets""" # Setup await processor.on_schema_config("default", sample_schema_config, version=1) - processor.connect_cassandra() + await processor.connect_cassandra() keyspace = "large_test_user" collection = "large_collection" diff --git a/tests/pytest.ini b/tests/pytest.ini index 5dcc095c..a89759ab 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -16,4 +16,13 @@ markers = unit: marks tests as unit tests contract: marks tests as contract tests (service interface validation) vertexai: marks tests as vertex ai specific tests - asyncio: marks tests that use asyncio \ No newline at end of file + asyncio: marks tests that use asyncio +# This is helpful if you're bored with deprecationwarnings. I prefer to +# keep the warnings for now, it avoids masking problems. +# +# filterwarnings = +# ignore:Core Pydantic V1 functionality isn't compatible with Python 3.14.*:UserWarning +# ignore:builtin type SwigPyPacked has no __module__ attribute:DeprecationWarning +# ignore:builtin type SwigPyObject has no __module__ attribute:DeprecationWarning +# ignore:builtin type swigvarlink has no __module__ attribute:DeprecationWarning +# ignore:.*_UnionGenericAlias.*is deprecated and slated for removal in Python 3.17:DeprecationWarning diff --git a/tests/unit/test_query/test_rows_cassandra_query.py b/tests/unit/test_query/test_rows_cassandra_query.py index bb6bbe84..b61500a4 100644 --- a/tests/unit/test_query/test_rows_cassandra_query.py +++ b/tests/unit/test_query/test_rows_cassandra_query.py @@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic: @pytest.mark.asyncio async def test_schema_config_parsing(self): """Test parsing of schema configuration""" + import asyncio processor = MagicMock() processor.schemas = {} processor.schema_builders = {} processor.graphql_schemas = {} processor.config_key = "schema" processor.query_cassandra = MagicMock() + processor._setup_lock = asyncio.Lock() + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test config @@ -335,7 +338,7 @@ class TestUnifiedTableQueries: """Test query execution with matching index""" processor = MagicMock() processor.session = MagicMock() - processor.connect_cassandra = MagicMock() + processor.connect_cassandra = AsyncMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) @@ -396,7 +399,7 @@ class TestUnifiedTableQueries: """Test query execution without matching index (scan mode)""" processor = MagicMock() processor.session = MagicMock() - processor.connect_cassandra = MagicMock() + processor.connect_cassandra = AsyncMock() processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.get_index_names = Processor.get_index_names.__get__(processor, Processor) processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor) diff --git a/tests/unit/test_query/test_sparql_algebra.py b/tests/unit/test_query/test_sparql_algebra.py new file mode 100644 index 00000000..9827b2de --- /dev/null +++ b/tests/unit/test_query/test_sparql_algebra.py @@ -0,0 +1,302 @@ +""" +Tests for the SPARQL algebra evaluator. + +Verifies that evaluate() and _query_pattern() call TriplesClient.query() +with the correct arguments, and in particular that workspace is never +passed — workspace isolation is handled by pub/sub topic routing. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, call + +from rdflib.term import Variable, URIRef, Literal +from rdflib.plugins.sparql.parserutils import CompValue + +from trustgraph.schema import Term, IRI, LITERAL +from trustgraph.query.sparql.algebra import ( + evaluate, _query_pattern, _eval_bgp, +) + + +# --- Helpers --- + +def iri(v): + return Term(type=IRI, iri=v) + + +def lit(v): + return Term(type=LITERAL, value=v) + + +def make_triple(s, p, o): + t = MagicMock() + t.s = s + t.p = p + t.o = o + return t + + +def make_bgp(*patterns): + """Build a CompValue BGP node from (s, p, o) tuples of rdflib terms.""" + node = CompValue("BGP") + node.triples = list(patterns) + return node + + +def make_project(inner, variables): + node = CompValue("Project") + node.p = inner + node.PV = [Variable(v) for v in variables] + return node + + +def make_select(inner): + node = CompValue("SelectQuery") + node.p = inner + return node + + +def make_join(left, right): + node = CompValue("Join") + node.p1 = left + node.p2 = right + return node + + +def make_union(left, right): + node = CompValue("Union") + node.p1 = left + node.p2 = right + return node + + +def make_slice(inner, start, length): + node = CompValue("Slice") + node.p = inner + node.start = start + node.length = length + return node + + +def make_distinct(inner): + node = CompValue("Distinct") + node.p = inner + return node + + +class TestQueryPattern: + """Tests for _query_pattern — the leaf that calls TriplesClient.""" + + @pytest.mark.asyncio + async def test_passes_correct_args(self): + tc = AsyncMock() + tc.query.return_value = [] + + await _query_pattern( + tc, + s=iri("http://example.com/s"), + p=iri("http://example.com/p"), + o=None, + collection="my-collection", + limit=100, + ) + + tc.query.assert_called_once_with( + s=iri("http://example.com/s"), + p=iri("http://example.com/p"), + o=None, + limit=100, + collection="my-collection", + ) + + @pytest.mark.asyncio + async def test_workspace_not_passed(self): + tc = AsyncMock() + tc.query.return_value = [] + + await _query_pattern(tc, None, None, None, "default", 10) + + kwargs = tc.query.call_args.kwargs + assert "workspace" not in kwargs + + @pytest.mark.asyncio + async def test_returns_query_results(self): + tc = AsyncMock() + triple = make_triple(iri("http://a"), iri("http://b"), lit("c")) + tc.query.return_value = [triple] + + results = await _query_pattern(tc, None, None, None, "default", 10) + + assert len(results) == 1 + assert results[0].s.iri == "http://a" + + +class TestEvalBgp: + """Tests for BGP evaluation — triple pattern queries.""" + + @pytest.mark.asyncio + async def test_single_pattern_all_variables(self): + tc = AsyncMock() + triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) + tc.query.return_value = [triple] + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + + solutions = await evaluate(bgp, tc, collection="default", limit=100) + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://s" + assert solutions[0]["p"].iri == "http://p" + assert solutions[0]["o"].value == "o" + + @pytest.mark.asyncio + async def test_single_pattern_bound_subject(self): + tc = AsyncMock() + tc.query.return_value = [ + make_triple(iri("http://s"), iri("http://p"), lit("val")), + ] + + bgp = make_bgp( + (URIRef("http://s"), Variable("p"), Variable("o")), + ) + + solutions = await evaluate(bgp, tc, collection="default") + + tc.query.assert_called_once() + kwargs = tc.query.call_args.kwargs + assert "workspace" not in kwargs + assert kwargs["collection"] == "default" + + @pytest.mark.asyncio + async def test_empty_bgp_returns_empty_solution(self): + tc = AsyncMock() + + bgp = make_bgp() + + solutions = await evaluate(bgp, tc, collection="default") + + assert solutions == [{}] + tc.query.assert_not_called() + + @pytest.mark.asyncio + async def test_no_results_returns_empty(self): + tc = AsyncMock() + tc.query.return_value = [] + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + + solutions = await evaluate(bgp, tc, collection="default") + + assert solutions == [] + + +class TestEvaluate: + """Tests for the top-level evaluate() dispatcher.""" + + @pytest.mark.asyncio + async def test_select_query_node(self): + tc = AsyncMock() + tc.query.return_value = [ + make_triple(iri("http://s"), iri("http://p"), lit("o")), + ] + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + select = make_select(make_project(bgp, ["s", "p"])) + + solutions = await evaluate(select, tc, collection="default") + + assert len(solutions) == 1 + assert "s" in solutions[0] + assert "p" in solutions[0] + assert "o" not in solutions[0] + + @pytest.mark.asyncio + async def test_workspace_never_in_query_calls(self): + """Verify that no matter the algebra structure, workspace is never + passed to TriplesClient.query().""" + tc = AsyncMock() + tc.query.return_value = [ + make_triple(iri("http://s"), iri("http://p"), lit("o")), + ] + + bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c"))) + tree = make_select(make_project( + make_union(bgp1, bgp2), ["s", "p", "o"] + )) + + await evaluate(tree, tc, collection="test-coll") + + for c in tc.query.call_args_list: + assert "workspace" not in c.kwargs + + @pytest.mark.asyncio + async def test_join(self): + tc = AsyncMock() + tc.query.side_effect = [ + [make_triple(iri("http://a"), iri("http://p"), lit("v"))], + [make_triple(iri("http://a"), iri("http://q"), lit("w"))], + ] + + bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1"))) + bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2"))) + tree = make_join(bgp1, bgp2) + + solutions = await evaluate(tree, tc, collection="default") + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://a" + + @pytest.mark.asyncio + async def test_slice(self): + tc = AsyncMock() + triples = [ + make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}")) + for i in range(5) + ] + tc.query.return_value = triples + + bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + tree = make_slice(bgp, start=1, length=2) + + solutions = await evaluate(tree, tc, collection="default") + + assert len(solutions) == 2 + + @pytest.mark.asyncio + async def test_distinct(self): + tc = AsyncMock() + triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) + tc.query.return_value = [triple, triple] + + bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + tree = make_distinct(bgp) + + solutions = await evaluate(tree, tc, collection="default") + + assert len(solutions) == 1 + + @pytest.mark.asyncio + async def test_unsupported_node_returns_empty_solution(self): + tc = AsyncMock() + + node = CompValue("SomethingUnknown") + + solutions = await evaluate(node, tc, collection="default") + + assert solutions == [{}] + tc.query.assert_not_called() + + @pytest.mark.asyncio + async def test_non_compvalue_returns_empty_solution(self): + tc = AsyncMock() + + solutions = await evaluate("not a node", tc, collection="default") + + assert solutions == [{}] diff --git a/tests/unit/test_query/test_triples_cassandra_query.py b/tests/unit/test_query/test_triples_cassandra_query.py index 09681214..980fa904 100644 --- a/tests/unit/test_query/test_triples_cassandra_query.py +++ b/tests/unit/test_query/test_triples_cassandra_query.py @@ -2,8 +2,10 @@ Tests for Cassandra triples query service """ +import asyncio + import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, AsyncMock from trustgraph.query.triples.cassandra.service import Processor, create_term from trustgraph.schema import Term, IRI, LITERAL @@ -18,7 +20,7 @@ class TestCassandraQueryProcessor: return Processor( taskgroup=MagicMock(), id='test-cassandra-query', - graph_host='localhost' + cassandra_host='localhost' ) def test_create_term_with_http_uri(self, processor): @@ -85,7 +87,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor( taskgroup=MagicMock(), @@ -110,8 +112,8 @@ class TestCassandraQueryProcessor: keyspace='test_user' ) - # Verify get_spo was called with correct parameters - mock_tg_instance.get_spo.assert_called_once_with( + # Verify async_get_spo was called with correct parameters + mock_tg_instance.async_get_spo.assert_called_once_with( 'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100 ) @@ -130,23 +132,25 @@ class TestCassandraQueryProcessor: assert processor.cassandra_host == ['cassandra'] # Updated default assert processor.cassandra_username is None assert processor.cassandra_password is None - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_custom_params(self): """Test processor initialization with custom parameters""" taskgroup_mock = MagicMock() - + processor = Processor( taskgroup=taskgroup_mock, cassandra_host='cassandra.example.com', cassandra_username='queryuser', cassandra_password='querypass' ) - + assert processor.cassandra_host == ['cassandra.example.com'] assert processor.cassandra_username == 'queryuser' assert processor.cassandra_password == 'querypass' - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) @pytest.mark.asyncio @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @@ -164,7 +168,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_sp.return_value = [mock_result] + mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -178,7 +182,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) + mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50) assert len(result) == 1 assert result[0].s.iri == 'test_subject' assert result[0].p.iri == 'test_predicate' @@ -200,7 +204,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_s.return_value = [mock_result] + mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -214,7 +218,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) + mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25) assert len(result) == 1 assert result[0].s.iri == 'test_subject' assert result[0].p.iri == 'result_predicate' @@ -236,7 +240,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_p.return_value = [mock_result] + mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -250,7 +254,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) + mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10) assert len(result) == 1 assert result[0].s.iri == 'result_subject' assert result[0].p.iri == 'test_predicate' @@ -272,7 +276,7 @@ class TestCassandraQueryProcessor: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_o.return_value = [mock_result] + mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -286,7 +290,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) + mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75) assert len(result) == 1 assert result[0].s.iri == 'result_subject' assert result[0].p.iri == 'result_predicate' @@ -305,11 +309,11 @@ class TestCassandraQueryProcessor: mock_result.s = 'all_subject' mock_result.p = 'all_predicate' mock_result.o = 'all_object' - mock_result.g = '' + mock_result.d = '' mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_all.return_value = [mock_result] + mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -323,7 +327,7 @@ class TestCassandraQueryProcessor: result = await processor.query_triples('test_user', query) - mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000) + mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000) assert len(result) == 1 assert result[0].s.iri == 'all_subject' assert result[0].p.iri == 'all_predicate' @@ -410,7 +414,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor( taskgroup=MagicMock(), @@ -451,7 +455,7 @@ class TestCassandraQueryProcessor: mock_result.dtype = None mock_result.lang = None mock_result.o = 'test_object' - mock_tg_instance.get_spo.return_value = [mock_result] + mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -489,8 +493,8 @@ class TestCassandraQueryProcessor: mock_result.lang = None mock_result.p = 'p' mock_result.o = 'o' - mock_tg_instance1.get_s.return_value = [mock_result] - mock_tg_instance2.get_s.return_value = [mock_result] + mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result]) + mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -504,7 +508,6 @@ class TestCassandraQueryProcessor: ) await processor.query_triples('user1', query1) - assert processor.table == 'user1' # Second query with different table query2 = TriplesQueryRequest( @@ -516,10 +519,11 @@ class TestCassandraQueryProcessor: ) await processor.query_triples('user2', query2) - assert processor.table == 'user2' - # Verify TrustGraph was created twice + # Verify TrustGraph was created twice for different workspaces assert mock_kg_class.call_count == 2 + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1') + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2') @pytest.mark.asyncio @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph') @@ -529,7 +533,7 @@ class TestCassandraQueryProcessor: mock_tg_instance = MagicMock() mock_kg_class.return_value = mock_tg_instance - mock_tg_instance.get_spo.side_effect = Exception("Query failed") + mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed")) processor = Processor(taskgroup=MagicMock()) @@ -566,7 +570,7 @@ class TestCassandraQueryProcessor: mock_result2.otype = None mock_result2.dtype = None mock_result2.lang = None - mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2] + mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2]) processor = Processor(taskgroup=MagicMock()) @@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_po.return_value = [mock_result] + mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('test_user', query) - # Verify get_po was called (should use optimized po_table) - mock_tg_instance.get_po.assert_called_once_with( + # Verify async_get_po was called (should use optimized po_table) + mock_tg_instance.async_get_po.assert_called_once_with( 'test_collection', 'test_predicate', 'test_object', g=None, limit=50 ) @@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.otype = None mock_result.dtype = None mock_result.lang = None - mock_tg_instance.get_os.return_value = [mock_result] + mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result]) processor = Processor(taskgroup=MagicMock()) @@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('test_user', query) - # Verify get_os was called (should use optimized subject_table with clustering) - mock_tg_instance.get_os.assert_called_once_with( + # Verify async_get_os was called (should use optimized subject_table with clustering) + mock_tg_instance.async_get_os.assert_called_once_with( 'test_collection', 'test_object', 'test_subject', g=None, limit=25 ) @@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations: mock_kg_class.return_value = mock_tg_instance # Mock empty results for all queries - mock_tg_instance.get_all.return_value = [] - mock_tg_instance.get_s.return_value = [] - mock_tg_instance.get_p.return_value = [] - mock_tg_instance.get_o.return_value = [] - mock_tg_instance.get_sp.return_value = [] - mock_tg_instance.get_po.return_value = [] - mock_tg_instance.get_os.return_value = [] - mock_tg_instance.get_spo.return_value = [] + mock_tg_instance.async_get_all = AsyncMock(return_value=[]) + mock_tg_instance.async_get_s = AsyncMock(return_value=[]) + mock_tg_instance.async_get_p = AsyncMock(return_value=[]) + mock_tg_instance.async_get_o = AsyncMock(return_value=[]) + mock_tg_instance.async_get_sp = AsyncMock(return_value=[]) + mock_tg_instance.async_get_po = AsyncMock(return_value=[]) + mock_tg_instance.async_get_os = AsyncMock(return_value=[]) + mock_tg_instance.async_get_spo = AsyncMock(return_value=[]) processor = Processor(taskgroup=MagicMock()) # Test each query pattern test_patterns = [ # (s, p, o, expected_method) - (None, None, None, 'get_all'), # All triples - ('s1', None, None, 'get_s'), # Subject only - (None, 'p1', None, 'get_p'), # Predicate only - (None, None, 'o1', 'get_o'), # Object only - ('s1', 'p1', None, 'get_sp'), # Subject + Predicate - (None, 'p1', 'o1', 'get_po'), # Predicate + Object (CRITICAL OPTIMIZATION) - ('s1', None, 'o1', 'get_os'), # Object + Subject - ('s1', 'p1', 'o1', 'get_spo'), # All three + (None, None, None, 'async_get_all'), # All triples + ('s1', None, None, 'async_get_s'), # Subject only + (None, 'p1', None, 'async_get_p'), # Predicate only + (None, None, 'o1', 'async_get_o'), # Object only + ('s1', 'p1', None, 'async_get_sp'), # Subject + Predicate + (None, 'p1', 'o1', 'async_get_po'), # Predicate + Object (CRITICAL OPTIMIZATION) + ('s1', None, 'o1', 'async_get_os'), # Object + Subject + ('s1', 'p1', 'o1', 'async_get_spo'), # All three ] for s, p, o, expected_method in test_patterns: @@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations: mock_result.lang = None mock_results.append(mock_result) - mock_tg_instance.get_po.return_value = mock_results + mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results) processor = Processor(taskgroup=MagicMock()) @@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations: result = await processor.query_triples('large_dataset_user', query) - # Verify optimized get_po was used (no ALLOW FILTERING needed!) - mock_tg_instance.get_po.assert_called_once_with( + # Verify optimized async_get_po was used (no ALLOW FILTERING needed!) + mock_tg_instance.async_get_po.assert_called_once_with( 'massive_collection', 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type', 'http://example.com/Person', diff --git a/tests/unit/test_reliability/test_null_embedding_protection.py b/tests/unit/test_reliability/test_null_embedding_protection.py index 2296e961..dbe06b40 100644 --- a/tests/unit/test_reliability/test_null_embedding_protection.py +++ b/tests/unit/test_reliability/test_null_embedding_protection.py @@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection: @pytest.mark.asyncio async def test_valid_embedding_upserted(self): + import asyncio from trustgraph.storage.doc_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "col1" @@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection: @pytest.mark.asyncio async def test_dimension_in_collection_name(self): """Collection name should include vector dimension.""" + import asyncio from trustgraph.storage.doc_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "docs" @@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection: @pytest.mark.asyncio async def test_valid_entity_and_vector_upserted(self): + import asyncio from trustgraph.storage.graph_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = True proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "col1" @@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection: @pytest.mark.asyncio async def test_lazy_collection_creation_on_new_dimension(self): + import asyncio from trustgraph.storage.graph_embeddings.qdrant.write import Processor proc = Processor.__new__(Processor) proc.qdrant = MagicMock() proc.qdrant.collection_exists.return_value = False proc.collection_exists = MagicMock(return_value=True) + proc._cache_lock = asyncio.Lock() + proc._known_collections = set() msg = MagicMock() msg.metadata.collection = "graphs" diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index e0f41357..d1979211 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -337,6 +337,57 @@ class TestQuery: cache_key = "test_collection:unlabeled_entity" mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity") + @pytest.mark.asyncio + async def test_triples_query_never_passes_workspace(self): + """Workspace isolation is handled by pub/sub topic routing, not + by passing workspace to TriplesClient.query(). Verify that + GraphRAG never passes workspace as a keyword argument.""" + mock_rag = MagicMock() + mock_cache = MagicMock() + mock_cache.get.return_value = None + mock_rag.label_cache = mock_cache + mock_triples_client = AsyncMock() + mock_rag.triples_client = mock_triples_client + + mock_triple = MagicMock() + mock_triple.o = "Label" + mock_triples_client.query.return_value = [mock_triple] + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False + ) + + await query.maybe_label("http://example.com/entity") + + for c in mock_triples_client.query.call_args_list: + assert "workspace" not in c.kwargs + + @pytest.mark.asyncio + async def test_follow_edges_never_passes_workspace(self): + """Verify follow_edges never passes workspace to query_stream.""" + mock_rag = MagicMock() + mock_triples_client = AsyncMock() + mock_rag.triples_client = mock_triples_client + + mock_triple = MagicMock() + mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1" + mock_triples_client.query_stream.return_value = [mock_triple] + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False, + triple_limit=10 + ) + + subgraph = set() + await query.follow_edges("e1", subgraph, path_length=1) + + for c in mock_triples_client.query_stream.call_args_list: + assert "workspace" not in c.kwargs + @pytest.mark.asyncio async def test_follow_edges_basic_functionality(self): """Test Query.follow_edges method basic triple discovery""" diff --git a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py index ce6e6b3d..360ac3dc 100644 --- a/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_qdrant_storage.py @@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase): # Assert expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions - # Verify collection existence is checked on each write - mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection) + # Second write uses cached collection state — no collection_exists check + mock_qdrant_instance.collection_exists.assert_not_called() # But upsert should still be called mock_qdrant_instance.upsert.assert_called_once() diff --git a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py index 8754f47c..44fdf516 100644 --- a/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py +++ b/tests/unit/test_storage/test_row_embeddings_qdrant_storage.py @@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) - processor.ensure_collection("test_collection", 384) + await processor.ensure_collection("test_collection", 384) mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection") mock_qdrant_instance.create_collection.assert_called_once() # Verify the collection is cached - assert "test_collection" in processor.created_collections + assert "test_collection" in processor._known_collections @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_ensure_collection_skips_existing(self, mock_qdrant_client): @@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): processor = Processor(**config) - processor.ensure_collection("existing_collection", 384) + await processor.ensure_collection("existing_collection", 384) mock_qdrant_instance.collection_exists.assert_called_once() mock_qdrant_instance.create_collection.assert_not_called() @@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.created_collections.add("cached_collection") + processor._known_collections.add("cached_collection") - processor.ensure_collection("cached_collection", 384) + await processor.ensure_collection("cached_collection", 384) # Should not check or create - just return mock_qdrant_instance.collection_exists.assert_not_called() @@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): } processor = Processor(**config) - processor.created_collections.add('rows_test_workspace_test_collection_schema1_384') + processor._known_collections.add('rows_test_workspace_test_collection_schema1_384') await processor.delete_collection('test_workspace', 'test_collection') @@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase): assert mock_qdrant_instance.delete_collection.call_count == 2 # Verify the cached collection was removed - assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections + assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') async def test_delete_collection_schema(self, mock_qdrant_client): diff --git a/tests/unit/test_storage/test_rows_cassandra_storage.py b/tests/unit/test_storage/test_rows_cassandra_storage.py index 852f01a1..3e5664ea 100644 --- a/tests/unit/test_storage/test_rows_cassandra_storage.py +++ b/tests/unit/test_storage/test_rows_cassandra_storage.py @@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic: @pytest.mark.asyncio async def test_schema_config_parsing(self): """Test parsing of schema configurations""" + import asyncio processor = MagicMock() processor.schemas = {} processor.config_key = "schema" processor.registered_partitions = set() + processor._setup_lock = asyncio.Lock() + processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor) processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor) # Create test configuration diff --git a/tests/unit/test_storage/test_triples_cassandra_storage.py b/tests/unit/test_storage/test_triples_cassandra_storage.py index 04acbb16..394f0e54 100644 --- a/tests/unit/test_storage/test_triples_cassandra_storage.py +++ b/tests/unit/test_storage/test_triples_cassandra_storage.py @@ -2,6 +2,8 @@ Tests for Cassandra triples storage service """ +import asyncio + import pytest from unittest.mock import MagicMock, patch, AsyncMock @@ -24,12 +26,13 @@ class TestCassandraStorageProcessor: assert processor.cassandra_host == ['cassandra'] # Updated default assert processor.cassandra_username is None assert processor.cassandra_password is None - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_custom_params(self): """Test processor initialization with custom parameters (new cassandra_* names)""" taskgroup_mock = MagicMock() - + processor = Processor( taskgroup=taskgroup_mock, id='custom-storage', @@ -37,11 +40,12 @@ class TestCassandraStorageProcessor: cassandra_username='testuser', cassandra_password='testpass' ) - + assert processor.cassandra_host == ['cassandra.example.com'] assert processor.cassandra_username == 'testuser' assert processor.cassandra_password == 'testpass' - assert processor.table is None + assert processor._connections == {} + assert isinstance(processor._conn_lock, asyncio.Lock) def test_processor_initialization_with_partial_auth(self): """Test processor initialization with only username (no password)""" @@ -92,6 +96,7 @@ class TestCassandraStorageProcessor: """Test table switching logic when authentication is provided""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor( @@ -114,7 +119,6 @@ class TestCassandraStorageProcessor: username='testuser', password='testpass' ) - assert processor.table == 'user1' @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -122,6 +126,7 @@ class TestCassandraStorageProcessor: """Test table switching logic when no authentication is provided""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -138,7 +143,6 @@ class TestCassandraStorageProcessor: hosts=['cassandra'], # Updated default keyspace='user2' ) - assert processor.table == 'user2' @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -146,6 +150,7 @@ class TestCassandraStorageProcessor: """Test that TrustGraph is not recreated when table hasn't changed""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -169,6 +174,7 @@ class TestCassandraStorageProcessor: """Test that triples are properly inserted into Cassandra""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -208,12 +214,12 @@ class TestCassandraStorageProcessor: await processor.store_triples('user1', mock_message) # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters) - assert mock_tg_instance.insert.call_count == 2 - mock_tg_instance.insert.assert_any_call( + assert mock_tg_instance.async_insert.call_count == 2 + mock_tg_instance.async_insert.assert_any_call( 'collection1', 'subject1', 'predicate1', 'object1', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) - mock_tg_instance.insert.assert_any_call( + mock_tg_instance.async_insert.assert_any_call( 'collection1', 'subject2', 'predicate2', 'object2', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) @@ -224,6 +230,7 @@ class TestCassandraStorageProcessor: """Test behavior when message has no triples""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -236,19 +243,17 @@ class TestCassandraStorageProcessor: await processor.store_triples('user1', mock_message) # Verify no triples were inserted - mock_tg_instance.insert.assert_not_called() + mock_tg_instance.async_insert.assert_not_called() @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') - @patch('trustgraph.storage.triples.cassandra.write.time.sleep') - async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class): + async def test_exception_handling_on_connection_failure(self, mock_kg_class): """Test exception handling during TrustGraph creation""" taskgroup_mock = MagicMock() mock_kg_class.side_effect = Exception("Connection failed") processor = Processor(taskgroup=taskgroup_mock) - # Create mock message mock_message = MagicMock() mock_message.metadata.collection = 'collection1' mock_message.triples = [] @@ -256,9 +261,6 @@ class TestCassandraStorageProcessor: with pytest.raises(Exception, match="Connection failed"): await processor.store_triples('user1', mock_message) - # Verify sleep was called before re-raising - mock_sleep.assert_called_once_with(1) - def test_add_args_method(self): """Test that add_args properly configures argument parser""" from argparse import ArgumentParser @@ -359,8 +361,6 @@ class TestCassandraStorageProcessor: mock_message1.triples = [] await processor.store_triples('user1', mock_message1) - assert processor.table == 'user1' - assert processor.tg == mock_tg_instance1 # Second message with different table mock_message2 = MagicMock() @@ -368,11 +368,11 @@ class TestCassandraStorageProcessor: mock_message2.triples = [] await processor.store_triples('user2', mock_message2) - assert processor.table == 'user2' - assert processor.tg == mock_tg_instance2 - # Verify TrustGraph was created twice for different tables + # Verify TrustGraph was created twice for different workspaces assert mock_kg_class.call_count == 2 + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1') + mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2') @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') @@ -380,6 +380,7 @@ class TestCassandraStorageProcessor: """Test storing triples with special characters and unicode""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -405,7 +406,7 @@ class TestCassandraStorageProcessor: await processor.store_triples('test_workspace', mock_message) # Verify the triple was inserted with special characters preserved - mock_tg_instance.insert.assert_called_once_with( + mock_tg_instance.async_insert.assert_called_once_with( 'test_collection', 'subject with spaces & symbols', 'predicate:with/colons', @@ -418,29 +419,29 @@ class TestCassandraStorageProcessor: @pytest.mark.asyncio @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph') - async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class): - """Test that table remains unchanged when TrustGraph creation fails""" + async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class): + """Test that a failed connection doesn't leave stale cached state""" taskgroup_mock = MagicMock() + mock_good_instance = MagicMock() processor = Processor(taskgroup=taskgroup_mock) - # Set an initial table - processor.table = ('old_user', 'old_collection') - - # Mock TrustGraph to raise exception - mock_kg_class.side_effect = Exception("Connection failed") - mock_message = MagicMock() - mock_message.metadata.collection = 'new_collection' + mock_message.metadata.collection = 'collection1' mock_message.triples = [] + # First call fails + mock_kg_class.side_effect = Exception("Connection failed") with pytest.raises(Exception, match="Connection failed"): - await processor.store_triples('new_user', mock_message) + await processor.store_triples('user1', mock_message) - # Table should remain unchanged since self.table = table happens after try/except - assert processor.table == ('old_user', 'old_collection') - # TrustGraph should be set to None though - assert processor.tg is None + # Second call succeeds — should retry connection, not use stale state + mock_kg_class.side_effect = None + mock_kg_class.return_value = mock_good_instance + await processor.store_triples('user1', mock_message) + + # Connection was attempted twice (failed + succeeded) + assert mock_kg_class.call_count == 2 class TestCassandraPerformanceOptimizations: @@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations: """Test that legacy mode still works with single table""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}): @@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations: """Test that optimized mode uses multi-table schema""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}): @@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations: """Test that all tables stay consistent during batch writes""" taskgroup_mock = MagicMock() mock_tg_instance = MagicMock() + mock_tg_instance.async_insert = AsyncMock() mock_kg_class.return_value = mock_tg_instance processor = Processor(taskgroup=taskgroup_mock) @@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations: await processor.store_triples('user1', mock_message) # Verify insert was called for the triple (implementation details tested in KnowledgeGraph) - mock_tg_instance.insert.assert_called_once_with( + mock_tg_instance.async_insert.assert_called_once_with( 'collection1', 'test_subject', 'test_predicate', 'test_object', g=DEFAULT_GRAPH, otype='l', dtype='', lang='' ) diff --git a/tests/unit/test_structured_data/test_row_embeddings_query.py b/tests/unit/test_structured_data/test_row_embeddings_query.py index 51cf834f..f1297e1c 100644 --- a/tests/unit/test_structured_data/test_row_embeddings_query.py +++ b/tests/unit/test_structured_data/test_row_embeddings_query.py @@ -89,7 +89,8 @@ class TestSanitizeName: class TestFindCollection: - def test_finds_matching_collection(self): + @pytest.mark.asyncio + async def test_finds_matching_collection(self): proc = _make_processor() mock_coll = MagicMock() mock_coll.name = "rows_test_workspace_test_col_customers_384" @@ -98,11 +99,12 @@ class TestFindCollection: mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-workspace", "test-col", "customers") + result = await proc.find_collection("test-workspace", "test-col", "customers") assert result == "rows_test_workspace_test_col_customers_384" - def test_returns_none_when_no_match(self): + @pytest.mark.asyncio + async def test_returns_none_when_no_match(self): proc = _make_processor() mock_coll = MagicMock() mock_coll.name = "rows_other_workspace_other_col_schema_768" @@ -111,14 +113,15 @@ class TestFindCollection: mock_collections.collections = [mock_coll] proc.qdrant.get_collections.return_value = mock_collections - result = proc.find_collection("test-workspace", "test-col", "customers") + result = await proc.find_collection("test-workspace", "test-col", "customers") assert result is None - def test_returns_none_on_error(self): + @pytest.mark.asyncio + async def test_returns_none_on_error(self): proc = _make_processor() proc.qdrant.get_collections.side_effect = Exception("connection error") - result = proc.find_collection("workspace", "col", "schema") + result = await proc.find_collection("workspace", "col", "schema") assert result is None @@ -139,7 +142,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_no_collection_returns_empty(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value=None) + proc.find_collection = AsyncMock(return_value=None) request = _make_request() result = await proc.query_row_embeddings("test-workspace", request) @@ -148,7 +151,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_successful_query_returns_matches(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") points = [ _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95), @@ -172,7 +175,7 @@ class TestQueryRowEmbeddings: async def test_index_name_filter_applied(self): """When index_name is specified, a Qdrant filter should be used.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] @@ -188,7 +191,7 @@ class TestQueryRowEmbeddings: async def test_no_index_name_no_filter(self): """When index_name is empty, no filter should be applied.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") mock_result = MagicMock() mock_result.points = [] @@ -204,7 +207,7 @@ class TestQueryRowEmbeddings: async def test_missing_payload_fields_default(self): """Points with missing payload fields should use defaults.""" proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") point = MagicMock() point.payload = {} # Empty payload @@ -225,7 +228,7 @@ class TestQueryRowEmbeddings: @pytest.mark.asyncio async def test_qdrant_error_propagates(self): proc = _make_processor() - proc.find_collection = MagicMock(return_value="rows_w_c_s_384") + proc.find_collection = AsyncMock(return_value="rows_w_c_s_384") proc.qdrant.query_points.side_effect = Exception("qdrant down") request = _make_request() diff --git a/trustgraph-base/trustgraph/api/knowledge.py b/trustgraph-base/trustgraph/api/knowledge.py index c3ec2308..06357d70 100644 --- a/trustgraph-base/trustgraph/api/knowledge.py +++ b/trustgraph-base/trustgraph/api/knowledge.py @@ -132,3 +132,34 @@ class Knowledge: self.request(request = input) + def list_de_cores(self): + + input = { + "operation": "list-de-cores", + "workspace": self.api.workspace, + } + + return self.request(request = input)["ids"] + + def delete_de_core(self, id): + + input = { + "operation": "delete-de-core", + "workspace": self.api.workspace, + "id": id, + } + + self.request(request = input) + + def load_de_core(self, id, flow="default", collection="default"): + + input = { + "operation": "load-de-core", + "workspace": self.api.workspace, + "id": id, + "flow": flow, + "collection": collection, + } + + self.request(request = input) + diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index aeb15f85..75a7be9a 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -491,6 +491,58 @@ class SocketClient: triples=raw_triples, ) + def get_kg_core(self, id: str) -> Iterator[Dict[str, Any]]: + request = { + "operation": "get-kg-core", + "workspace": self.workspace, + "id": id, + } + for response in self._send_request_sync( + "knowledge", None, request, streaming_raw=True, + ): + if response.get("eos"): + break + yield response + + def put_kg_core( + self, id: str, triples=None, graph_embeddings=None, + ) -> Dict[str, Any]: + request = { + "operation": "put-kg-core", + "workspace": self.workspace, + "id": id, + } + if triples is not None: + request["triples"] = triples + if graph_embeddings is not None: + request["graph-embeddings"] = graph_embeddings + return self._send_request_sync("knowledge", None, request) + + def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]: + request = { + "operation": "get-de-core", + "workspace": self.workspace, + "id": id, + } + for response in self._send_request_sync( + "knowledge", None, request, streaming_raw=True, + ): + if response.get("eos"): + break + yield response + + def put_de_core( + self, id: str, document_embeddings=None, + ) -> Dict[str, Any]: + request = { + "operation": "put-de-core", + "workspace": self.workspace, + "id": id, + } + if document_embeddings is not None: + request["document-embeddings"] = document_embeddings + return self._send_request_sync("knowledge", None, request) + def close(self) -> None: """Close the persistent WebSocket connection.""" if self._loop and not self._loop.is_closed(): diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index f2cc8e46..3830bf59 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -1,6 +1,7 @@ from typing import Dict, Any, Tuple, Optional from ...schema import ( KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings, + DocumentEmbeddings, ChunkEmbeddings, Metadata, EntityEmbeddings ) from .base import MessageTranslator @@ -43,6 +44,23 @@ class KnowledgeRequestTranslator(MessageTranslator): ] ) + document_embeddings = None + if "document-embeddings" in data: + document_embeddings = DocumentEmbeddings( + metadata=Metadata( + id=data["document-embeddings"]["metadata"]["id"], + root=data["document-embeddings"]["metadata"].get("root", ""), + collection=data["document-embeddings"]["metadata"]["collection"] + ), + chunks=[ + ChunkEmbeddings( + chunk_id=ch["chunk_id"], + vector=ch["vector"], + ) + for ch in data["document-embeddings"]["chunks"] + ] + ) + return KnowledgeRequest( operation=data.get("operation"), id=data.get("id"), @@ -50,6 +68,7 @@ class KnowledgeRequestTranslator(MessageTranslator): collection=data.get("collection"), triples=triples, graph_embeddings=graph_embeddings, + document_embeddings=document_embeddings, ) def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]: @@ -90,6 +109,22 @@ class KnowledgeRequestTranslator(MessageTranslator): ], } + if obj.document_embeddings: + result["document-embeddings"] = { + "metadata": { + "id": obj.document_embeddings.metadata.id, + "root": obj.document_embeddings.metadata.root, + "collection": obj.document_embeddings.metadata.collection, + }, + "chunks": [ + { + "chunk_id": ch.chunk_id, + "vector": ch.vector, + } + for ch in obj.document_embeddings.chunks + ], + } + return result @@ -140,6 +175,25 @@ class KnowledgeResponseTranslator(MessageTranslator): } } + # Streaming document embeddings response + if obj.document_embeddings: + return { + "document-embeddings": { + "metadata": { + "id": obj.document_embeddings.metadata.id, + "root": obj.document_embeddings.metadata.root, + "collection": obj.document_embeddings.metadata.collection, + }, + "chunks": [ + { + "chunk_id": ch.chunk_id, + "vector": ch.vector, + } + for ch in obj.document_embeddings.chunks + ], + } + } + # End of stream marker if obj.eos is True: return {"eos": True} @@ -155,7 +209,7 @@ class KnowledgeResponseTranslator(MessageTranslator): is_final = ( obj.ids is not None or # List response obj.eos is True or # End of stream - (not obj.triples and not obj.graph_embeddings) # Empty response + (not obj.triples and not obj.graph_embeddings and not obj.document_embeddings) # Empty response ) return response, is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index 64cb7082..a3879103 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -4,7 +4,7 @@ from ..core.topic import queue from ..core.metadata import Metadata from .document import Document, TextDocument from .graph import Triples -from .embeddings import GraphEmbeddings +from .embeddings import GraphEmbeddings, DocumentEmbeddings # get-kg-core # -> (???) @@ -41,6 +41,9 @@ class KnowledgeRequest: triples: Triples | None = None graph_embeddings: GraphEmbeddings | None = None + # put-de-core + document_embeddings: DocumentEmbeddings | None = None + @dataclass class KnowledgeResponse: error: Error | None = None @@ -48,6 +51,7 @@ class KnowledgeResponse: eos: bool = False # Indicates end of knowledge core stream triples: Triples | None = None graph_embeddings: GraphEmbeddings | None = None + document_embeddings: DocumentEmbeddings | None = None knowledge_request_queue = queue('knowledge', cls='request') knowledge_response_queue = queue('knowledge', cls='response') diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index e8062fba..10dca2e8 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -37,6 +37,7 @@ tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main" tg-dump-queues = "trustgraph.cli.dump_queues:main" tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main" tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main" +tg-get-de-core = "trustgraph.cli.get_de_core:main" tg-get-kg-core = "trustgraph.cli.get_kg_core:main" tg-get-document-content = "trustgraph.cli.get_document_content:main" tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main" @@ -77,6 +78,7 @@ tg-load-turtle = "trustgraph.cli.load_turtle:main" tg-load-knowledge = "trustgraph.cli.load_knowledge:main" tg-load-structured-data = "trustgraph.cli.load_structured_data:main" tg-put-flow-blueprint = "trustgraph.cli.put_flow_blueprint:main" +tg-put-de-core = "trustgraph.cli.put_de_core:main" tg-put-kg-core = "trustgraph.cli.put_kg_core:main" tg-remove-library-document = "trustgraph.cli.remove_library_document:main" tg-save-doc-embeds = "trustgraph.cli.save_doc_embeds:main" diff --git a/trustgraph-cli/trustgraph/cli/get_de_core.py b/trustgraph-cli/trustgraph/cli/get_de_core.py new file mode 100644 index 00000000..caf74ba9 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/get_de_core.py @@ -0,0 +1,111 @@ +""" +Uses the knowledge service to fetch a document embeddings core which is +saved to a local file in msgpack format. +""" + +import argparse +import os +import msgpack + +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def write_de(f, data): + msg = ( + "de", + { + "m": { + "i": data["metadata"]["id"], + "m": data["metadata"]["root"], + "c": data["metadata"]["collection"], + }, + "c": [ + { + "i": ch["chunk_id"], + "v": ch["vector"], + } + for ch in data["chunks"] + ] + } + ) + f.write(msgpack.packb(msg, use_bin_type=True)) + +def fetch(url, workspace, id, output, token=None): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + + try: + de = 0 + + with open(output, "wb") as f: + + for response in socket.get_de_core(id): + + if "document-embeddings" in response: + de += 1 + write_de(f, response["document-embeddings"]) + + print(f"Got: {de} document embeddings messages.") + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-get-de-core', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '--id', '--identifier', + required=True, + help=f'Document embeddings core ID', + ) + + parser.add_argument( + '-o', '--output', + required=True, + help=f'Output file' + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + args = parser.parse_args() + + try: + + fetch( + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/get_kg_core.py b/trustgraph-cli/trustgraph/cli/get_kg_core.py index 8bee4115..b4f37b81 100644 --- a/trustgraph-cli/trustgraph/cli/get_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/get_kg_core.py @@ -5,13 +5,11 @@ to a local file in msgpack format. import argparse import os -import uuid -import asyncio -import json -from websockets.asyncio.client import connect import msgpack -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @@ -21,7 +19,7 @@ def write_triple(f, data): { "m": { "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "m": data["metadata"]["root"], "c": data["metadata"]["collection"], }, "t": data["triples"], @@ -35,13 +33,13 @@ def write_ge(f, data): { "m": { "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "m": data["metadata"]["root"], "c": data["metadata"]["collection"], }, "e": [ { "e": ent["entity"], - "v": ent["vectors"], + "v": ent["vector"], } for ent in data["entities"] ] @@ -49,54 +47,18 @@ def write_ge(f, data): ) f.write(msgpack.packb(msg, use_bin_type=True)) -async def fetch(url, workspace, id, output, token=None): +def fetch(url, workspace, id, output, token=None): - if not url.endswith("/"): - url += "/" - - url = url + "api/v1/socket" - - if token: - url = f"{url}?token={token}" - - mid = str(uuid.uuid4()) - - async with connect(url) as ws: - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "get-kg-core", - "workspace": workspace, - "id": id, - } - }) - - await ws.send(req) + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + try: ge = 0 t = 0 with open(output, "wb") as f: - while True: - - msg = await ws.recv() - - obj = json.loads(msg) - - if "response" not in obj: - raise RuntimeError("No response?") - - response = obj["response"] - - if "error" in response: - raise RuntimeError(obj["error"]) - - if "eos" in response: - if response["eos"]: break + for response in socket.get_kg_core(id): if "triples" in response: t += 1 @@ -108,7 +70,8 @@ async def fetch(url, workspace, id, output, token=None): print(f"Got: {t} triple, {ge} GE messages.") - await ws.close() + finally: + socket.close() def main(): @@ -151,14 +114,12 @@ def main(): try: - asyncio.run( - fetch( - url=args.url, - workspace=args.workspace, - id=args.id, - output=args.output, - token=args.token, - ) + fetch( + url=args.url, + workspace=args.workspace, + id=args.id, + output=args.output, + token=args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 23d6bcac..f39cdab0 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -3,11 +3,8 @@ Uses the GraphRAG service to answer a question """ import argparse -import json import os import sys -import websockets -import asyncio from trustgraph.api import ( Api, ExplainabilityClient, @@ -31,607 +28,6 @@ default_max_path_length = 2 default_edge_score_limit = 30 default_edge_limit = 25 -# Provenance predicates -TG = "https://trustgraph.ai/ns/" -TG_QUERY = TG + "query" -TG_CONCEPT = TG + "concept" -TG_ENTITY = TG + "entity" -TG_EDGE_COUNT = TG + "edgeCount" -TG_SELECTED_EDGE = TG + "selectedEdge" -TG_EDGE = TG + "edge" -TG_REASONING = TG + "reasoning" -TG_DOCUMENT = TG + "document" -TG_CONTAINS = TG + "contains" -PROV = "http://www.w3.org/ns/prov#" -PROV_STARTED_AT_TIME = PROV + "startedAtTime" -PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom" -RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" - - -def _get_event_type(prov_id): - """Extract event type from provenance_id""" - if "question" in prov_id: - return "question" - elif "grounding" in prov_id: - return "grounding" - elif "exploration" in prov_id: - return "exploration" - elif "focus" in prov_id: - return "focus" - elif "synthesis" in prov_id: - return "synthesis" - return "provenance" - - -def _format_provenance_details(event_type, triples): - """Format provenance details based on event type and triples""" - lines = [] - - if event_type == "question": - # Show query and timestamp - for s, p, o in triples: - if p == TG_QUERY: - lines.append(f" Query: {o}") - elif p == PROV_STARTED_AT_TIME: - lines.append(f" Time: {o}") - - elif event_type == "grounding": - # Show extracted concepts - concepts = [o for s, p, o in triples if p == TG_CONCEPT] - if concepts: - lines.append(f" Concepts: {len(concepts)}") - for concept in concepts: - lines.append(f" - {concept}") - - elif event_type == "exploration": - # Show edge count (seed entities resolved separately with labels) - for s, p, o in triples: - if p == TG_EDGE_COUNT: - lines.append(f" Edges explored: {o}") - - elif event_type == "focus": - # For focus, just count edge selection URIs - # The actual edge details are fetched separately via edge_selections parameter - edge_sel_uris = [] - for s, p, o in triples: - if p == TG_SELECTED_EDGE: - edge_sel_uris.append(o) - if edge_sel_uris: - lines.append(f" Focused on {len(edge_sel_uris)} edge(s)") - - elif event_type == "synthesis": - # Show document reference (content already streamed) - for s, p, o in triples: - if p == TG_DOCUMENT: - lines.append(f" Document: {o}") - - return lines - - -async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False): - """Query triples for a provenance node (single attempt)""" - request = { - "id": "triples-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": prov_id}, - "collection": collection, - "limit": 100 - } - } - # Add graph filter if specified (for named graph queries) - if graph is not None: - request["request"]["g"] = graph - - if debug: - print(f" [debug] querying triples for s={prov_id}", file=sys.stderr) - - triples = [] - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if debug: - print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr) - - if response.get("id") != "triples-request": - continue - - if "error" in response: - if debug: - print(f" [debug] error: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - # Handle triples response - # Response format: {"response": [triples...]} - # Each triple uses compact keys: "i" for iri, "v" for value, "t" for type - triple_list = resp.get("response", []) - for t in triple_list: - s = t.get("s", {}).get("i", t.get("s", {}).get("v", "")) - p = t.get("p", {}).get("i", t.get("p", {}).get("v", "")) - # Handle quoted triples (type "t") and regular values - o_term = t.get("o", {}) - if o_term.get("t") == "t": - # Quoted triple - extract s, p, o from nested structure - tr = o_term.get("tr", {}) - o = { - "s": tr.get("s", {}).get("i", ""), - "p": tr.get("p", {}).get("i", ""), - "o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")), - } - else: - o = o_term.get("i", o_term.get("v", "")) - triples.append((s, p, o)) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception: {e}", file=sys.stderr) - - if debug: - print(f" [debug] got {len(triples)} triples", file=sys.stderr) - - return triples - - -async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False): - """Query triples for a provenance node with retries for race condition""" - for attempt in range(max_retries): - triples = await _query_triples_once(ws_url, flow_id, prov_id, collection, graph=graph, debug=debug) - if triples: - return triples - # Wait before retry if empty (triples may not be stored yet) - if attempt < max_retries - 1: - if debug: - print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr) - await asyncio.sleep(retry_delay) - return [] - - -async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False): - """ - Query for provenance of an edge (s, p, o) in the knowledge graph. - - Finds subgraphs that contain the edge via tg:contains, then follows - prov:wasDerivedFrom to find source documents. - - Returns list of source URIs (chunks, pages, documents). - """ - # Query for subgraphs that contain this edge: ?subgraph tg:contains <> - request = { - "id": "edge-prov-request", - "service": "triples", - "flow": flow_id, - "request": { - "p": {"t": "i", "i": TG_CONTAINS}, - "o": { - "t": "t", # Quoted triple type - "tr": { - "s": {"t": "i", "i": edge_s}, - "p": {"t": "i", "i": edge_p}, - "o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o}, - } - }, - "collection": collection, - "limit": 10 - } - } - - if debug: - print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr) - - stmt_uris = [] - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "edge-prov-request": - continue - - if "error" in response: - if debug: - print(f" [debug] error: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - for t in triple_list: - s = t.get("s", {}).get("i", "") - if s: - stmt_uris.append(s) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr) - - if debug: - print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr) - - # For each statement, query wasDerivedFrom to find sources - sources = [] - for stmt_uri in stmt_uris: - # Query: stmt_uri prov:wasDerivedFrom ?source - request = { - "id": "derived-from-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": stmt_uri}, - "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "collection": collection, - "limit": 10 - } - } - - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "derived-from-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - for t in triple_list: - o = t.get("o", {}).get("i", "") - if o: - sources.append(o) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr) - - if debug: - print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr) - - return sources - - -async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False): - """Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent.""" - request = { - "id": "parent-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": uri}, - "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM}, - "collection": collection, - "limit": 1 - } - } - - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "parent-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - if triple_list: - return triple_list[0].get("o", {}).get("i", None) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying parent: {e}", file=sys.stderr) - - return None - - -async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False): - """ - Trace the full provenance chain from a source URI up to the root document. - Returns a list of (uri, label) tuples from leaf to root. - """ - chain = [] - current = source_uri - max_depth = 10 # Prevent infinite loops - - for _ in range(max_depth): - if not current: - break - - # Get label for current entity - label = await _query_label(ws_url, flow_id, current, collection, label_cache, debug) - chain.append((current, label)) - - # Get parent - parent = await _query_derived_from(ws_url, flow_id, current, collection, debug) - if not parent or parent == current: - break - current = parent - - return chain - - -def _format_provenance_chain(chain): - """ - Format a provenance chain as a human-readable string. - Chain is [(uri, label), ...] from leaf to root. - """ - if not chain: - return "" - - # Show labels, from leaf to root - labels = [label for uri, label in chain] - return " → ".join(labels) - - -def _is_iri(value): - """Check if a value looks like an IRI.""" - if not isinstance(value, str): - return False - return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:") - - -async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False): - """ - Query for the rdfs:label of an IRI. - Uses label_cache to avoid repeated queries. - Returns the label if found, otherwise returns the IRI. - """ - if not _is_iri(iri): - return iri - - # Check cache first - if iri in label_cache: - return label_cache[iri] - - request = { - "id": "label-request", - "service": "triples", - "flow": flow_id, - "request": { - "s": {"t": "i", "i": iri}, - "p": {"t": "i", "i": RDFS_LABEL}, - "collection": collection, - "limit": 1 - } - } - - label = iri # Default to IRI if no label found - try: - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "label-request": - continue - - if "error" in response: - break - - if "response" in response: - resp = response["response"] - triple_list = resp.get("response", []) - if triple_list: - # Get the label value - o = triple_list[0].get("o", {}) - label = o.get("v", o.get("i", iri)) - - if resp.get("complete") or response.get("complete"): - break - except Exception as e: - if debug: - print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr) - - # Cache the result - label_cache[iri] = label - return label - - -async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False): - """ - Resolve labels for all IRI components of an edge triple. - Returns (s_label, p_label, o_label). - """ - s = edge_triple.get("s", "?") - p = edge_triple.get("p", "?") - o = edge_triple.get("o", "?") - - s_label = await _query_label(ws_url, flow_id, s, collection, label_cache, debug) - p_label = await _query_label(ws_url, flow_id, p, collection, label_cache, debug) - o_label = await _query_label(ws_url, flow_id, o, collection, label_cache, debug) - - return s_label, p_label, o_label - - -async def _question_explainable( - url, flow_id, question, collection, entity_limit, triple_limit, - max_subgraph_size, max_path_length, token=None, debug=False -): - """Execute graph RAG with explainability - shows provenance events with details""" - # Convert HTTP URL to WebSocket URL - if url.startswith("http://"): - ws_url = url.replace("http://", "ws://", 1) - elif url.startswith("https://"): - ws_url = url.replace("https://", "wss://", 1) - else: - ws_url = f"ws://{url}" - - ws_url = f"{ws_url.rstrip('/')}/api/v1/socket" - if token: - ws_url = f"{ws_url}?token={token}" - - # Cache for label lookups to avoid repeated queries - label_cache = {} - - request = { - "id": "cli-request", - "service": "graph-rag", - "flow": flow_id, - "request": { - "query": question, - "collection": collection, - "entity-limit": entity_limit, - "triple-limit": triple_limit, - "max-subgraph-size": max_subgraph_size, - "max-path-length": max_path_length, - "streaming": True - } - } - - async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket: - await websocket.send(json.dumps(request)) - - async for raw_message in websocket: - response = json.loads(raw_message) - - if response.get("id") != "cli-request": - continue - - if "error" in response: - print(f"\nError: {response['error']}", file=sys.stderr) - break - - if "response" in response: - resp = response["response"] - - # Check for errors in response - if "error" in resp and resp["error"]: - err = resp["error"] - print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr) - break - - message_type = resp.get("message_type", "") - - if debug: - print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr) - - if message_type == "explain": - # Display explain event with details - explain_id = resp.get("explain_id", "") - explain_graph = resp.get("explain_graph") # Named graph (e.g., urn:graph:retrieval) - if explain_id: - event_type = _get_event_type(explain_id) - print(f"\n [{event_type}] {explain_id}", file=sys.stderr) - - # Query triples for this explain node (using named graph filter) - triples = await _query_triples( - ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug - ) - - # Format and display details - details = _format_provenance_details(event_type, triples) - for line in details: - print(line, file=sys.stderr) - - # For exploration events, resolve entity labels - if event_type == "exploration": - entity_iris = [o for s, p, o in triples if p == TG_ENTITY] - if entity_iris: - print(f" Seed entities: {len(entity_iris)}", file=sys.stderr) - for iri in entity_iris: - label = await _query_label( - ws_url, flow_id, iri, collection, - label_cache, debug=debug - ) - print(f" - {label}", file=sys.stderr) - - # For focus events, query each edge selection for details - if event_type == "focus": - for s, p, o in triples: - if debug: - print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr) - if p == TG_SELECTED_EDGE and isinstance(o, str): - if debug: - print(f" [debug] querying edge selection: {o}", file=sys.stderr) - # Query the edge selection entity (using named graph filter) - edge_triples = await _query_triples( - ws_url, flow_id, o, collection, graph=explain_graph, debug=debug - ) - if debug: - print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr) - # Extract edge and reasoning - edge_triple = None # Store the actual triple for provenance lookup - reasoning = None - for es, ep, eo in edge_triples: - if debug: - print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr) - if ep == TG_EDGE and isinstance(eo, dict): - # eo is a quoted triple dict - edge_triple = eo - elif ep == TG_REASONING: - reasoning = eo - if edge_triple: - # Resolve labels for edge components - s_label, p_label, o_label = await _resolve_edge_labels( - ws_url, flow_id, edge_triple, collection, - label_cache, debug=debug - ) - print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) - if reasoning: - r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning - print(f" Reason: {r_short}", file=sys.stderr) - - # Trace edge provenance in the workspace collection (not explainability) - if edge_triple: - sources = await _query_edge_provenance( - ws_url, flow_id, - edge_triple.get("s", ""), - edge_triple.get("p", ""), - edge_triple.get("o", ""), - collection, # Use the query collection, not explainability - debug=debug - ) - if sources: - for src in sources: - # Trace full chain from source to root document - chain = await _trace_provenance_chain( - ws_url, flow_id, src, collection, - label_cache, debug=debug - ) - chain_str = _format_provenance_chain(chain) - print(f" Source: {chain_str}", file=sys.stderr) - - elif message_type == "chunk" or not message_type: - # Display response chunk - chunk = resp.get("response", "") - if chunk: - print(chunk, end="", flush=True) - - # Check if session is complete - if resp.get("end_of_session"): - break - - print() # Final newline - - def _question_explainable_api( url, flow_id, question_text, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=30, diff --git a/trustgraph-cli/trustgraph/cli/put_de_core.py b/trustgraph-cli/trustgraph/cli/put_de_core.py new file mode 100644 index 00000000..1d6589af --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/put_de_core.py @@ -0,0 +1,119 @@ +""" +Puts a document embeddings core into the knowledge manager via the API +socket. +""" + +import argparse +import os +import msgpack + +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def read_message(unpacked, id): + + if unpacked[0] == "de": + msg = unpacked[1] + return { + "metadata": { + "id": id, + "root": msg["m"]["m"], + "collection": "default", + }, + "chunks": [ + { + "chunk_id": ch["i"], + "vector": ch["v"], + } + for ch in msg["c"] + ], + } + else: + raise RuntimeError("Unexpected message type", unpacked[0]) + +def put(url, workspace, id, input, token=None): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + + try: + de = 0 + + with open(input, "rb") as f: + + unpacker = msgpack.Unpacker(f, raw=False) + + while True: + + try: + unpacked = unpacker.unpack() + except msgpack.OutOfData: + break + + msg = read_message(unpacked, id) + de += 1 + socket.put_de_core(id, document_embeddings=msg) + + print(f"Put: {de} document embeddings messages.") + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-put-de-core', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '--id', '--identifier', + required=True, + help=f'Document embeddings core ID', + ) + + parser.add_argument( + '-i', '--input', + required=True, + help=f'Input file' + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + args = parser.parse_args() + + try: + + put( + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/put_kg_core.py b/trustgraph-cli/trustgraph/cli/put_kg_core.py index bd3169c8..fe0981a5 100644 --- a/trustgraph-cli/trustgraph/cli/put_kg_core.py +++ b/trustgraph-cli/trustgraph/cli/put_kg_core.py @@ -4,13 +4,11 @@ Puts a knowledge core into the knowledge manager via the API socket. import argparse import os -import uuid -import asyncio -import json -from websockets.asyncio.client import connect import msgpack -default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") @@ -21,13 +19,13 @@ def read_message(unpacked, id): return "ge", { "metadata": { "id": id, - "metadata": msg["m"]["m"], - "collection": "default", # Not used? + "root": msg["m"]["m"], + "collection": "default", }, "entities": [ { "entity": ent["e"], - "vectors": ent["v"], + "vector": ent["v"], } for ent in msg["e"] ], @@ -37,26 +35,20 @@ def read_message(unpacked, id): return "t", { "metadata": { "id": id, - "metadata": msg["m"]["m"], - "collection": "default", # Not used by receiver? + "root": msg["m"]["m"], + "collection": "default", }, "triples": msg["t"], } else: raise RuntimeError("Unpacked unexpected messsage type", unpacked[0]) -async def put(url, workspace, id, input, token=None): +def put(url, workspace, id, input, token=None): - if not url.endswith("/"): - url += "/" - - url = url + "api/v1/socket" - - if token: - url = f"{url}?token={token}" - - async with connect(url) as ws: + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + try: ge = 0 t = 0 @@ -68,69 +60,26 @@ async def put(url, workspace, id, input, token=None): try: unpacked = unpacker.unpack() - except: + except msgpack.OutOfData: break kind, msg = read_message(unpacked, id) - mid = str(uuid.uuid4()) - if kind == "ge": - ge += 1 - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "put-kg-core", - "workspace": workspace, - "id": id, - "graph-embeddings": msg - } - }) + socket.put_kg_core(id, graph_embeddings=msg) elif kind == "t": - t += 1 - - req = json.dumps({ - "id": mid, - "workspace": workspace, - "service": "knowledge", - "request": { - "operation": "put-kg-core", - "workspace": workspace, - "id": id, - "triples": msg - } - }) + socket.put_kg_core(id, triples=msg) else: - raise RuntimeError("Unexpected message kind", kind) - await ws.send(req) - - # Retry loop, wait for right response to come back - while True: - - msg = await ws.recv() - msg = json.loads(msg) - - if msg["id"] != mid: - continue - - if "response" in msg: - if "error" in msg["response"]: - raise RuntimeError(msg["response"]["error"]) - - break - print(f"Put: {t} triple, {ge} GE messages.") - await ws.close() + finally: + socket.close() def main(): @@ -173,14 +122,12 @@ def main(): try: - asyncio.run( - put( - url=args.url, - workspace=args.workspace, - id=args.id, - input=args.input, - token=args.token, - ) + put( + url=args.url, + workspace=args.workspace, + id=args.id, + input=args.input, + token=args.token, ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py index 09c6137d..f1fa53f5 100644 --- a/trustgraph-flow/trustgraph/cores/knowledge.py +++ b/trustgraph-flow/trustgraph/cores/knowledge.py @@ -1,5 +1,6 @@ from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings +from .. schema import DocumentEmbeddings from .. knowledge import hash from .. exceptions import RequestError from .. tables.knowledge import KnowledgeTableStore @@ -157,6 +158,98 @@ class KnowledgeManager: ) ) + async def list_de_cores(self, request, respond, workspace): + + ids = await self.table_store.list_de_cores(workspace) + + await respond( + KnowledgeResponse( + error = None, + ids = ids, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def get_de_core(self, request, respond, workspace): + + logger.info("Getting document embeddings core...") + + async def publish_de(de): + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + document_embeddings = de, + ) + ) + + await self.table_store.get_document_embeddings( + workspace, + request.id, + publish_de, + ) + + logger.debug("Document embeddings core retrieval complete") + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = True, + triples = None, + graph_embeddings = None, + ) + ) + + async def put_de_core(self, request, respond, workspace): + + if request.document_embeddings: + await self.table_store.add_document_embeddings( + workspace, request.document_embeddings + ) + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def delete_de_core(self, request, respond, workspace): + + logger.info("Deleting document embeddings core...") + + await self.table_store.delete_document_embeddings( + workspace, request.id + ) + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None, + ) + ) + + async def load_de_core(self, request, respond, workspace): + + if self.background_task is None: + self.background_task = asyncio.create_task( + self.core_loader() + ) + + await self.loader_queue.put((request, respond, workspace)) + async def core_loader(self): logger.info("Knowledge background processor running...") @@ -165,7 +258,7 @@ class KnowledgeManager: logger.debug("Waiting for next load...") request, respond, workspace = await self.loader_queue.get() - logger.info(f"Loading knowledge: {request.id}") + logger.info(f"Loading: {request.operation} {request.id}") try: @@ -187,25 +280,14 @@ class KnowledgeManager: if "interfaces" not in flow: raise RuntimeError("No defined interfaces") - if "triples-store" not in flow["interfaces"]: - raise RuntimeError("Flow has no triples-store") - - if "graph-embeddings-store" not in flow["interfaces"]: - raise RuntimeError("Flow has no graph-embeddings-store") - - t_q = flow["interfaces"]["triples-store"]["flow"] - ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"] - - # Got this far, it should all work - await respond( - KnowledgeResponse( - error = None, - ids = None, - eos = False, - triples = None, - graph_embeddings = None + if request.operation == "load-de-core": + await self._load_de_core( + request, respond, workspace, flow, + ) + else: + await self._load_kg_core( + request, respond, workspace, flow, ) - ) except Exception as e: @@ -223,72 +305,145 @@ class KnowledgeManager: ) ) - - logger.debug("Starting knowledge loading process...") - - try: - - t_pub = None - ge_pub = None - - logger.debug(f"Triples queue: {t_q}") - logger.debug(f"Graph embeddings queue: {ge_q}") - - t_pub = Publisher( - self.flow_config.pubsub, t_q, - schema=Triples, - ) - ge_pub = Publisher( - self.flow_config.pubsub, ge_q, - schema=GraphEmbeddings - ) - - logger.debug("Starting publishers...") - - await t_pub.start() - await ge_pub.start() - - async def publish_triples(t): - # Override collection with request collection - if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'): - t.metadata.collection = request.collection or "default" - await t_pub.send(None, t) - - logger.debug("Publishing triples...") - - await self.table_store.get_triples( - workspace, - request.id, - publish_triples, - ) - - async def publish_ge(g): - # Override collection with request collection - if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'): - g.metadata.collection = request.collection or "default" - await ge_pub.send(None, g) - - logger.debug("Publishing graph embeddings...") - - await self.table_store.get_graph_embeddings( - workspace, - request.id, - publish_ge, - ) - - logger.debug("Knowledge loading completed") - - except Exception as e: - - logger.error(f"Knowledge exception: {e}", exc_info=True) - - finally: - - logger.debug("Stopping publishers...") - - if t_pub: await t_pub.stop() - if ge_pub: await ge_pub.stop() - logger.debug("Knowledge processing done") continue + + async def _load_kg_core(self, request, respond, workspace, flow): + + if "triples-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no triples-store") + + if "graph-embeddings-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no graph-embeddings-store") + + t_q = flow["interfaces"]["triples-store"]["flow"] + ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"] + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None + ) + ) + + t_pub = None + ge_pub = None + + try: + + logger.debug(f"Triples queue: {t_q}") + logger.debug(f"Graph embeddings queue: {ge_q}") + + t_pub = Publisher( + self.flow_config.pubsub, t_q, + schema=Triples, + ) + ge_pub = Publisher( + self.flow_config.pubsub, ge_q, + schema=GraphEmbeddings + ) + + logger.debug("Starting publishers...") + + await t_pub.start() + await ge_pub.start() + + async def publish_triples(t): + if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'): + t.metadata.collection = request.collection or "default" + await t_pub.send(None, t) + + logger.debug("Publishing triples...") + + await self.table_store.get_triples( + workspace, + request.id, + publish_triples, + ) + + async def publish_ge(g): + if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'): + g.metadata.collection = request.collection or "default" + await ge_pub.send(None, g) + + logger.debug("Publishing graph embeddings...") + + await self.table_store.get_graph_embeddings( + workspace, + request.id, + publish_ge, + ) + + logger.debug("Knowledge core loading completed") + + except Exception as e: + + logger.error(f"Knowledge exception: {e}", exc_info=True) + + finally: + + logger.debug("Stopping publishers...") + + if t_pub: await t_pub.stop() + if ge_pub: await ge_pub.stop() + + async def _load_de_core(self, request, respond, workspace, flow): + + if "document-embeddings-store" not in flow["interfaces"]: + raise RuntimeError("Flow has no document-embeddings-store") + + de_q = flow["interfaces"]["document-embeddings-store"]["flow"] + + await respond( + KnowledgeResponse( + error = None, + ids = None, + eos = False, + triples = None, + graph_embeddings = None + ) + ) + + de_pub = None + + try: + + logger.debug(f"Document embeddings queue: {de_q}") + + de_pub = Publisher( + self.flow_config.pubsub, de_q, + schema=DocumentEmbeddings, + ) + + logger.debug("Starting publisher...") + + await de_pub.start() + + async def publish_de(de): + if hasattr(de, 'metadata') and hasattr(de.metadata, 'collection'): + de.metadata.collection = request.collection or "default" + await de_pub.send(None, de) + + logger.debug("Publishing document embeddings...") + + await self.table_store.get_document_embeddings( + workspace, + request.id, + publish_de, + ) + + logger.debug("Document embeddings core loading completed") + + except Exception as e: + + logger.error(f"Knowledge exception: {e}", exc_info=True) + + finally: + + logger.debug("Stopping publisher...") + + if de_pub: await de_pub.stop() diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py index c84b536c..a04e42ca 100755 --- a/trustgraph-flow/trustgraph/cores/service.py +++ b/trustgraph-flow/trustgraph/cores/service.py @@ -187,6 +187,11 @@ class Processor(WorkspaceProcessor): "put-kg-core": self.knowledge.put_kg_core, "load-kg-core": self.knowledge.load_kg_core, "unload-kg-core": self.knowledge.unload_kg_core, + "list-de-cores": self.knowledge.list_de_cores, + "get-de-core": self.knowledge.get_de_core, + "delete-de-core": self.knowledge.delete_de_core, + "put-de-core": self.knowledge.put_de_core, + "load-de-core": self.knowledge.load_de_core, } if v.operation not in impls: diff --git a/trustgraph-flow/trustgraph/direct/cassandra_kg.py b/trustgraph-flow/trustgraph/direct/cassandra_kg.py index 59d2a2a1..d7abd1a9 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra_kg.py +++ b/trustgraph-flow/trustgraph/direct/cassandra_kg.py @@ -1,10 +1,14 @@ +import datetime +import os +import logging + from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider from cassandra.query import BatchStatement, SimpleStatement from ssl import SSLContext, PROTOCOL_TLSv1_2 -import os -import logging + +from ..tables.cassandra_async import async_execute # Global list to track clusters for cleanup _active_clusters = [] @@ -461,7 +465,6 @@ class KnowledgeGraph: def create_collection(self, collection): """Create collection by inserting metadata row""" try: - import datetime self.session.execute( f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", (collection, datetime.datetime.now()) @@ -954,7 +957,6 @@ class EntityCentricKnowledgeGraph: def create_collection(self, collection): """Create collection by inserting metadata row""" try: - import datetime self.session.execute( f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", (collection, datetime.datetime.now()) @@ -1045,6 +1047,222 @@ class EntityCentricKnowledgeGraph: logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads") + # ======================================================================== + # Async methods — use cassandra driver's native async API via async_execute + # ======================================================================== + + async def async_insert(self, collection, s, p, o, g=None, otype=None, dtype="", lang=""): + if g is None: + g = DEFAULT_GRAPH + if otype is None: + if o.startswith("http://") or o.startswith("https://"): + otype = "u" + else: + otype = "l" + + batch = BatchStatement() + batch.add(self.insert_entity_stmt, (collection, s, 'S', p, otype, s, o, g, dtype, lang)) + batch.add(self.insert_entity_stmt, (collection, p, 'P', p, otype, s, o, g, dtype, lang)) + if otype == 'u' or otype == 't': + batch.add(self.insert_entity_stmt, (collection, o, 'O', p, otype, s, o, g, dtype, lang)) + if g != DEFAULT_GRAPH: + batch.add(self.insert_entity_stmt, (collection, g, 'G', p, otype, s, o, g, dtype, lang)) + batch.add(self.insert_collection_stmt, (collection, g, s, p, o, otype, dtype, lang)) + + await async_execute(self.session, batch) + + async def async_get_all(self, collection, limit=50): + return await async_execute( + self.session, self.get_collection_all_stmt, (collection, limit) + ) + + async def async_get_s(self, collection, s, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_stmt, (collection, s, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_p(self, collection, p, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_p_stmt, (collection, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_o(self, collection, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_o_stmt, (collection, o, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=row.p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_sp(self, collection, s, p, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=p, o=row.o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_po(self, collection, p, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_o_p_stmt, (collection, o, p, limit) + ) + results = [] + for row in rows: + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=row.s, p=p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_os(self, collection, o, s, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_stmt, (collection, s, limit) + ) + results = [] + for row in rows: + if row.o != o: + continue + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=row.p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_spo(self, collection, s, p, o, g=None, limit=10): + rows = await async_execute( + self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit) + ) + results = [] + for row in rows: + if row.o != o: + continue + d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + if g is not None and d != g: + continue + results.append(QuadResult( + s=s, p=p, o=o, g=d, + otype=row.otype, dtype=row.dtype, lang=row.lang + )) + return results + + async def async_get_g(self, collection, g, limit=50): + if g is None: + g = DEFAULT_GRAPH + return await async_execute( + self.session, self.get_collection_by_graph_stmt, (collection, g, limit) + ) + + async def async_collection_exists(self, collection): + try: + result = await async_execute( + self.session, + f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1", + (collection,) + ) + return bool(result) + except Exception as e: + logger.error(f"Error checking collection existence: {e}") + return False + + async def async_create_collection(self, collection): + await async_execute( + self.session, + f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)", + (collection, datetime.datetime.now()) + ) + logger.info(f"Created collection metadata for {collection}") + + async def async_delete_collection(self, collection): + rows = await async_execute( + self.session, + f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s", + (collection,) + ) + + entities = set() + quads = [] + for row in rows: + d, s, p, o = row.d, row.s, row.p, row.o + otype = row.otype + dtype = row.dtype if hasattr(row, 'dtype') else '' + lang = row.lang if hasattr(row, 'lang') else '' + quads.append((d, s, p, o, otype, dtype, lang)) + entities.add(s) + entities.add(p) + if otype == 'u' or otype == 't': + entities.add(o) + if d != DEFAULT_GRAPH: + entities.add(d) + + batch = BatchStatement() + count = 0 + for entity in entities: + batch.add(self.delete_entity_partition_stmt, (collection, entity)) + count += 1 + if count % 50 == 0: + await async_execute(self.session, batch) + batch = BatchStatement() + if count % 50 != 0: + await async_execute(self.session, batch) + + batch = BatchStatement() + count = 0 + for d, s, p, o, otype, dtype, lang in quads: + batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang)) + count += 1 + if count % 50 == 0: + await async_execute(self.session, batch) + batch = BatchStatement() + if count % 50 != 0: + await async_execute(self.session, batch) + + await async_execute( + self.session, + f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s", + (collection,) + ) + logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads") + def close(self): """Close connections""" if hasattr(self, 'session') and self.session: diff --git a/trustgraph-flow/trustgraph/gateway/registry.py b/trustgraph-flow/trustgraph/gateway/registry.py index 5e3344f4..4d439097 100644 --- a/trustgraph-flow/trustgraph/gateway/registry.py +++ b/trustgraph-flow/trustgraph/gateway/registry.py @@ -457,6 +457,12 @@ for _op in ("put-kg-core", "delete-kg-core", "load-kg-core", "unload-kg-core"): _register_kind_op("knowledge", _op, "knowledge:write") +# knowledge: document-embeddings core service. +for _op in ("get-de-core", "list-de-cores"): + _register_kind_op("knowledge", _op, "knowledge:read") +for _op in ("put-de-core", "delete-de-core", "load-de-core"): + _register_kind_op("knowledge", _op, "knowledge:write") + # collection-management: workspace collection lifecycle. _register_kind_op("collection-management", "list-collections", "collections:read") diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 1d59c835..f6770744 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -4,11 +4,10 @@ Document embeddings query service. Input is vector, output is an array of chunk_ids """ +import asyncio import logging from qdrant_client import QdrantClient -from qdrant_client.models import PointStruct -from qdrant_client.models import Distance, VectorParams from .... schema import DocumentEmbeddingsResponse, ChunkMatch from .... schema import Error @@ -38,32 +37,6 @@ class Processor(DocumentEmbeddingsQueryService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - self.last_collection = None - - def ensure_collection_exists(self, collection, dim): - """Ensure collection exists, create if it doesn't""" - if collection != self.last_collection: - if not self.qdrant.collection_exists(collection): - try: - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.COSINE - ), - ) - logger.info(f"Created collection: {collection}") - except Exception as e: - logger.error(f"Qdrant collection creation failed: {e}") - raise e - self.last_collection = collection - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) async def query_document_embeddings(self, workspace, msg): @@ -73,21 +46,24 @@ class Processor(DocumentEmbeddingsQueryService): if not vec: return [] - # Use dimension suffix in collection name dim = len(vec) collection = f"d_{workspace}_{msg.collection}_{dim}" - # Check if collection exists - return empty if not - if not self.collection_exists(collection): + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection + ) + if not exists: logger.info(f"Collection {collection} does not exist, returning empty results") return [] - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=collection, query=vec, limit=msg.limit, with_payload=True, - ).points + ) + search_result = result.points chunks = [] for r in search_result: diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index b8fb1361..167130c9 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -4,11 +4,10 @@ Graph embeddings query service. Input is vector, output is list of entities """ +import asyncio import logging from qdrant_client import QdrantClient -from qdrant_client.models import PointStruct -from qdrant_client.models import Distance, VectorParams from .... schema import GraphEmbeddingsResponse, EntityMatch from .... schema import Error, Term, IRI, LITERAL @@ -38,32 +37,6 @@ class Processor(GraphEmbeddingsQueryService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) - self.last_collection = None - - def ensure_collection_exists(self, collection, dim): - """Ensure collection exists, create if it doesn't""" - if collection != self.last_collection: - if not self.qdrant.collection_exists(collection): - try: - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.COSINE - ), - ) - logger.info(f"Created collection: {collection}") - except Exception as e: - logger.error(f"Qdrant collection creation failed: {e}") - raise e - self.last_collection = collection - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) - - def collection_exists(self, collection): - """Check if collection exists (no implicit creation)""" - return self.qdrant.collection_exists(collection) def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): @@ -79,23 +52,26 @@ class Processor(GraphEmbeddingsQueryService): if not vec: return [] - # Use dimension suffix in collection name dim = len(vec) collection = f"t_{workspace}_{msg.collection}_{dim}" - # Check if collection exists - return empty if not - if not self.collection_exists(collection): + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection + ) + if not exists: logger.info(f"Collection {collection} does not exist") return [] # Heuristic hack, get (2*limit), so that we have more chance # of getting (limit) unique entities - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=collection, query=vec, limit=msg.limit * 2, with_payload=True, - ).points + ) + search_result = result.points entity_set = set() entities = [] diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index dd89a8d8..1534c044 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -6,6 +6,7 @@ Output is matching row index information (index_name, index_value) for use in subsequent Cassandra lookups. """ +import asyncio import logging import re from typing import Optional @@ -70,7 +71,7 @@ class Processor(FlowProcessor): safe_name = 'r_' + safe_name return safe_name.lower() - def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: + async def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]: """Find the Qdrant collection for a given workspace/collection/schema""" prefix = ( f"rows_{self.sanitize_name(workspace)}_" @@ -78,14 +79,15 @@ class Processor(FlowProcessor): ) try: - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching = [ coll.name for coll in all_collections if coll.name.startswith(prefix) ] if matching: - # Return first match (there should typically be only one per dimension) return matching[0] except Exception as e: @@ -100,8 +102,7 @@ class Processor(FlowProcessor): if not vec: return [] - # Find the collection for this workspace/collection/schema - qdrant_collection = self.find_collection( + qdrant_collection = await self.find_collection( workspace, request.collection, request.schema_name ) @@ -113,7 +114,6 @@ class Processor(FlowProcessor): return [] try: - # Build optional filter for index_name query_filter = None if request.index_name: query_filter = Filter( @@ -125,16 +125,16 @@ class Processor(FlowProcessor): ] ) - # Query Qdrant - search_result = self.qdrant.query_points( + result = await asyncio.to_thread( + self.qdrant.query_points, collection_name=qdrant_collection, query=vec, limit=request.limit, with_payload=True, query_filter=query_filter, - ).points + ) + search_result = result.points - # Convert to RowIndexMatch objects matches = [] for point in search_result: payload = point.payload or {} diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 73cfcd83..7157daae 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -11,6 +11,7 @@ Queries against the unified 'rows' table with schema: - source: text """ +import asyncio import json import logging import re @@ -97,34 +98,38 @@ class Processor(FlowProcessor): # Cassandra session self.cluster = None self.session = None + self._setup_lock = asyncio.Lock() # Known keyspaces self.known_keyspaces: Set[str] = set() - def connect_cassandra(self): + async def connect_cassandra(self): """Connect to Cassandra cluster""" - if self.session: - return + async with self._setup_lock: + if self.session: + return - try: - if self.cassandra_username and self.cassandra_password: - auth_provider = PlainTextAuthProvider( - username=self.cassandra_username, - password=self.cassandra_password - ) - self.cluster = Cluster( - contact_points=self.cassandra_host, - auth_provider=auth_provider - ) - else: - self.cluster = Cluster(contact_points=self.cassandra_host) + try: + if self.cassandra_username and self.cassandra_password: + auth_provider = PlainTextAuthProvider( + username=self.cassandra_username, + password=self.cassandra_password + ) + cluster = Cluster( + contact_points=self.cassandra_host, + auth_provider=auth_provider + ) + else: + cluster = Cluster(contact_points=self.cassandra_host) - self.session = self.cluster.connect() - logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") + session = await asyncio.to_thread(cluster.connect) + self.cluster = cluster + self.session = session + logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}") - except Exception as e: - logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) - raise + except Exception as e: + logger.error(f"Failed to connect to Cassandra: {e}", exc_info=True) + raise def sanitize_name(self, name: str) -> str: """Sanitize names for Cassandra compatibility""" @@ -140,14 +145,17 @@ class Processor(FlowProcessor): f"for workspace {workspace}" ) - # Replace existing schemas for this workspace + async with self._setup_lock: + await self._apply_schema_config(workspace, config) + + async def _apply_schema_config(self, workspace, config): + ws_schemas: Dict[str, RowSchema] = {} self.schemas[workspace] = ws_schemas builder = GraphQLSchemaBuilder() self.schema_builders[workspace] = builder - # Check if our config type exists if self.config_key not in config: logger.warning( f"No '{self.config_key}' type in configuration " @@ -156,16 +164,12 @@ class Processor(FlowProcessor): self.graphql_schemas[workspace] = None return - # Get the schemas dictionary for our type schemas_config = config[self.config_key] - # Process each schema in the schemas config for schema_name, schema_json in schemas_config.items(): try: - # Parse the JSON schema definition schema_def = json.loads(schema_json) - # Create Field objects fields = [] for field_def in schema_def.get("fields", []): field = SchemaField( @@ -180,7 +184,6 @@ class Processor(FlowProcessor): ) fields.append(field) - # Create RowSchema row_schema = RowSchema( name=schema_def.get("name", schema_name), description=schema_def.get("description", ""), @@ -202,7 +205,6 @@ class Processor(FlowProcessor): f"{len(ws_schemas)} schemas" ) - # Regenerate GraphQL schema for this workspace self.graphql_schemas[workspace] = builder.build(self.query_cassandra) def get_index_names(self, schema: RowSchema) -> List[str]: @@ -254,7 +256,7 @@ class Processor(FlowProcessor): For other queries, we need to scan and post-filter. """ # Connect if needed - self.connect_cassandra() + await self.connect_cassandra() safe_keyspace = self.sanitize_name(workspace) diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index bff9a336..76b1ad8e 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -30,14 +30,13 @@ class EvaluationError(Exception): pass -async def evaluate(node, triples_client, workspace, collection, limit=10000): +async def evaluate(node, triples_client, collection, limit=10000): """ Evaluate a SPARQL algebra node. Args: node: rdflib CompValue algebra node triples_client: TriplesClient instance for triple pattern queries - workspace: workspace/keyspace identifier collection: collection identifier limit: safety limit on results @@ -55,24 +54,24 @@ async def evaluate(node, triples_client, workspace, collection, limit=10000): logger.warning(f"Unsupported algebra node: {name}") return [{}] - return await handler(node, triples_client, workspace, collection, limit) + return await handler(node, triples_client, collection, limit) # --- Node handlers --- -async def _eval_select_query(node, tc, workspace, collection, limit): +async def _eval_select_query(node, tc, collection, limit): """Evaluate a SelectQuery node.""" - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) -async def _eval_project(node, tc, workspace, collection, limit): +async def _eval_project(node, tc, collection, limit): """Evaluate a Project node (SELECT variable projection).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) variables = [str(v) for v in node.PV] return project(solutions, variables) -async def _eval_bgp(node, tc, workspace, collection, limit): +async def _eval_bgp(node, tc, collection, limit): """ Evaluate a Basic Graph Pattern. @@ -107,7 +106,7 @@ async def _eval_bgp(node, tc, workspace, collection, limit): # Query the triples store results = await _query_pattern( - tc, s_val, p_val, o_val, workspace, collection, limit + tc, s_val, p_val, o_val, collection, limit ) # Map results back to variable bindings, @@ -130,17 +129,17 @@ async def _eval_bgp(node, tc, workspace, collection, limit): return solutions[:limit] -async def _eval_join(node, tc, workspace, collection, limit): +async def _eval_join(node, tc, collection, limit): """Evaluate a Join node.""" - left = await evaluate(node.p1, tc, workspace, collection, limit) - right = await evaluate(node.p2, tc, workspace, collection, limit) + left = await evaluate(node.p1, tc, collection, limit) + right = await evaluate(node.p2, tc, collection, limit) return hash_join(left, right)[:limit] -async def _eval_left_join(node, tc, workspace, collection, limit): +async def _eval_left_join(node, tc, collection, limit): """Evaluate a LeftJoin node (OPTIONAL).""" - left_sols = await evaluate(node.p1, tc, workspace, collection, limit) - right_sols = await evaluate(node.p2, tc, workspace, collection, limit) + left_sols = await evaluate(node.p1, tc, collection, limit) + right_sols = await evaluate(node.p2, tc, collection, limit) filter_fn = None if hasattr(node, "expr") and node.expr is not None: @@ -153,16 +152,16 @@ async def _eval_left_join(node, tc, workspace, collection, limit): return left_join(left_sols, right_sols, filter_fn)[:limit] -async def _eval_union(node, tc, workspace, collection, limit): +async def _eval_union(node, tc, collection, limit): """Evaluate a Union node.""" - left = await evaluate(node.p1, tc, workspace, collection, limit) - right = await evaluate(node.p2, tc, workspace, collection, limit) + left = await evaluate(node.p1, tc, collection, limit) + right = await evaluate(node.p2, tc, collection, limit) return union(left, right)[:limit] -async def _eval_filter(node, tc, workspace, collection, limit): +async def _eval_filter(node, tc, collection, limit): """Evaluate a Filter node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) expr = node.expr return [ sol for sol in solutions @@ -170,22 +169,22 @@ async def _eval_filter(node, tc, workspace, collection, limit): ] -async def _eval_distinct(node, tc, workspace, collection, limit): +async def _eval_distinct(node, tc, collection, limit): """Evaluate a Distinct node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) return distinct(solutions) -async def _eval_reduced(node, tc, workspace, collection, limit): +async def _eval_reduced(node, tc, collection, limit): """Evaluate a Reduced node (like Distinct but implementation-defined).""" # Treat same as Distinct - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) return distinct(solutions) -async def _eval_order_by(node, tc, workspace, collection, limit): +async def _eval_order_by(node, tc, collection, limit): """Evaluate an OrderBy node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) key_fns = [] for cond in node.expr: @@ -206,7 +205,7 @@ async def _eval_order_by(node, tc, workspace, collection, limit): return order_by(solutions, key_fns) -async def _eval_slice(node, tc, workspace, collection, limit): +async def _eval_slice(node, tc, collection, limit): """Evaluate a Slice node (LIMIT/OFFSET).""" # Pass tighter limit downstream if possible inner_limit = limit @@ -214,13 +213,13 @@ async def _eval_slice(node, tc, workspace, collection, limit): offset = node.start or 0 inner_limit = min(limit, offset + node.length) - solutions = await evaluate(node.p, tc, workspace, collection, inner_limit) + solutions = await evaluate(node.p, tc, collection, inner_limit) return slice_solutions(solutions, node.start or 0, node.length) -async def _eval_extend(node, tc, workspace, collection, limit): +async def _eval_extend(node, tc, collection, limit): """Evaluate an Extend node (BIND).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) var_name = str(node.var) expr = node.expr @@ -246,9 +245,9 @@ async def _eval_extend(node, tc, workspace, collection, limit): return result -async def _eval_group(node, tc, workspace, collection, limit): +async def _eval_group(node, tc, collection, limit): """Evaluate a Group node (GROUP BY with aggregation).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) # Extract grouping expressions group_exprs = [] @@ -289,9 +288,9 @@ async def _eval_group(node, tc, workspace, collection, limit): return result -async def _eval_aggregate_join(node, tc, workspace, collection, limit): +async def _eval_aggregate_join(node, tc, collection, limit): """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) result = [] for sol in solutions: @@ -310,7 +309,7 @@ async def _eval_aggregate_join(node, tc, workspace, collection, limit): return result -async def _eval_graph(node, tc, workspace, collection, limit): +async def _eval_graph(node, tc, collection, limit): """Evaluate a Graph node (GRAPH clause).""" term = node.term @@ -319,16 +318,16 @@ async def _eval_graph(node, tc, workspace, collection, limit): # We'd need to pass graph to triples queries # For now, evaluate inner pattern normally logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) elif isinstance(term, Variable): # GRAPH ?g { ... } — variable graph logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) else: - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) -async def _eval_values(node, tc, workspace, collection, limit): +async def _eval_values(node, tc, collection, limit): """Evaluate a VALUES clause (inline data).""" variables = [str(v) for v in node.var] solutions = [] @@ -343,9 +342,9 @@ async def _eval_values(node, tc, workspace, collection, limit): return solutions -async def _eval_to_multiset(node, tc, workspace, collection, limit): +async def _eval_to_multiset(node, tc, collection, limit): """Evaluate a ToMultiSet node (subquery).""" - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) # --- Aggregate computation --- @@ -487,7 +486,7 @@ def _resolve_term(tmpl, solution): return rdflib_term_to_term(tmpl) -async def _query_pattern(tc, s, p, o, workspace, collection, limit): +async def _query_pattern(tc, s, p, o, collection, limit): """ Issue a streaming triple pattern query via TriplesClient. @@ -496,7 +495,6 @@ async def _query_pattern(tc, s, p, o, workspace, collection, limit): results = await tc.query( s=s, p=p, o=o, limit=limit, - workspace=workspace, collection=collection, ) return results diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 983cd4f6..75c00dba 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -141,7 +141,6 @@ class Processor(FlowProcessor): solutions = await evaluate( parsed.algebra, triples_client, - workspace=flow.workspace, collection=request.collection or "default", limit=request.limit or 10000, ) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index a9bdbbac..822dba25 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -6,8 +6,8 @@ null. Output is a list of quads. import asyncio import logging - import json + from cassandra.query import SimpleStatement from .... direct.cassandra_kg import ( @@ -176,45 +176,42 @@ class Processor(TriplesQueryService): self.cassandra_host = hosts self.cassandra_username = username self.cassandra_password = password - self.table = None - def ensure_connection(self, workspace): - """Ensure we have a connection to the correct keyspace.""" - if workspace != self.table: - KGClass = EntityCentricKnowledgeGraph + self._connections = {} + self._conn_lock = asyncio.Lock() - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - self.table = workspace + async def _get_connection(self, workspace): + async with self._conn_lock: + if workspace not in self._connections: + if self.cassandra_username and self.cassandra_password: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + username=self.cassandra_username, + password=self.cassandra_password, + ) + else: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + ) + self._connections[workspace] = tg + return self._connections[workspace] async def query_triples(self, workspace, query): try: - # ensure_connection may construct a fresh - # EntityCentricKnowledgeGraph which does sync schema - # setup against Cassandra. Push it to a worker thread - # so the event loop doesn't block on first-use per workspace. - await asyncio.to_thread(self.ensure_connection, workspace) - - # Extract values from query s_val = get_term_value(query.s) p_val = get_term_value(query.p) o_val = get_term_value(query.o) - g_val = query.g # Already a string or None + g_val = query.g + + tg = await self._get_connection(workspace) def get_object_metadata(row): - """Extract term type metadata from result row""" return ( getattr(row, 'otype', None), getattr(row, 'dtype', None), @@ -223,33 +220,21 @@ class Processor(TriplesQueryService): quads = [] - # All self.tg.get_* calls below are sync wrappers around - # cassandra session.execute. Materialise inside a worker - # thread so iteration never triggers sync paging back on - # the event loop. - - # Route to appropriate query method based on which fields are specified if s_val is not None: if p_val is not None: if o_val is not None: - # SPO specified - find matching graphs - resp = await asyncio.to_thread( - lambda: list(self.tg.get_spo( - query.collection, s_val, p_val, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_spo( + query.collection, s_val, p_val, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((s_val, p_val, o_val, g, term_type, datatype, language)) else: - # SP specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_sp( - query.collection, s_val, p_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_sp( + query.collection, s_val, p_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -257,24 +242,18 @@ class Processor(TriplesQueryService): quads.append((s_val, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: - # SO specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_os( - query.collection, o_val, s_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_os( + query.collection, o_val, s_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((s_val, t.p, o_val, g, term_type, datatype, language)) else: - # S only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_s( - query.collection, s_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_s( + query.collection, s_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -283,24 +262,18 @@ class Processor(TriplesQueryService): else: if p_val is not None: if o_val is not None: - # PO specified - resp = await asyncio.to_thread( - lambda: list(self.tg.get_po( - query.collection, p_val, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_po( + query.collection, p_val, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((t.s, p_val, o_val, g, term_type, datatype, language)) else: - # P only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_p( - query.collection, p_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_p( + query.collection, p_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -308,40 +281,26 @@ class Processor(TriplesQueryService): quads.append((t.s, p_val, t.o, g, term_type, datatype, language)) else: if o_val is not None: - # O only - resp = await asyncio.to_thread( - lambda: list(self.tg.get_o( - query.collection, o_val, - g=g_val, limit=query.limit, - )) + resp = await tg.async_get_o( + query.collection, o_val, + g=g_val, limit=query.limit, ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH term_type, datatype, language = get_object_metadata(t) quads.append((t.s, t.p, o_val, g, term_type, datatype, language)) else: - # Nothing specified - get all - resp = await asyncio.to_thread( - lambda: list(self.tg.get_all( - query.collection, limit=query.limit, - )) + resp = await tg.async_get_all( + query.collection, limit=query.limit, ) for t in resp: - # Note: quads_by_collection uses 'd' for graph field g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH - # Filter by graph - # g_val=None means all graphs (no filter) - # g_val="" means default graph only - # otherwise filter to specific named graph if g_val is not None: if g != g_val: continue term_type, datatype, language = get_object_metadata(t) quads.append((t.s, t.p, t.o, g, term_type, datatype, language)) - # Convert to Triple objects (with g field) - # s and p are always IRIs in RDF - # Object uses term_type/datatype/language metadata from database triples = [ Triple( s=create_term(q[0], term_type='u'), @@ -365,51 +324,41 @@ class Processor(TriplesQueryService): Uses Cassandra's paging to fetch results incrementally. """ try: - await asyncio.to_thread(self.ensure_connection, workspace) batch_size = query.batch_size if query.batch_size > 0 else 20 limit = query.limit if query.limit > 0 else 10000 - # Extract query pattern s_val = get_term_value(query.s) p_val = get_term_value(query.p) o_val = get_term_value(query.o) g_val = query.g def get_object_metadata(row): - """Extract term type metadata from result row""" return ( getattr(row, 'otype', None), getattr(row, 'dtype', None), getattr(row, 'lang', None), ) - # For streaming, we need to execute with fetch_size - # Use the collection table for get_all queries (most common streaming case) - - # Determine which query to use based on pattern if s_val is None and p_val is None and o_val is None: - # Get all - use collection table with paging - cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s" + + tg = await self._get_connection(workspace) + + cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {tg.collection_table} WHERE collection = %s" params = [query.collection] + statement = SimpleStatement(cql, fetch_size=batch_size) + # async_execute only materialises the first page; + # this query needs all pages, so use sync execute + # in a worker thread where page iteration can block. + result_set = await asyncio.to_thread( + lambda: list(tg.session.execute(statement, params)) + ) + else: - # For specific patterns, fall back to non-streaming - # (these typically return small result sets anyway) async for batch, is_final in self._fallback_stream(workspace, query, batch_size): yield batch, is_final return - # Materialise in a worker thread. We lose true streaming - # paging (the driver fetches all pages eagerly inside the - # thread) but the event loop stays responsive, and result - # sets at this layer are typically small enough that this - # is acceptable. If true async paging is needed later, - # revisit using ResponseFuture page callbacks. - statement = SimpleStatement(cql, fetch_size=batch_size) - result_set = await asyncio.to_thread( - lambda: list(self.tg.session.execute(statement, params)) - ) - batch = [] count = 0 diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index fb7166b5..2bfef99c 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -3,11 +3,13 @@ Accepts entity/vector pairs and writes them to a Qdrant store. """ +import asyncio +import uuid +import logging + from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -import uuid -import logging from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer @@ -35,13 +37,35 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def ensure_collection(self, collection_name, dim): + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name + ) + if not exists: + logger.info( + f"Lazily creating Qdrant collection {collection_name} " + f"with dimension {dim}" + ) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) + async def store_document_embeddings(self, workspace, message): - # Validate collection exists in config before processing if not self.collection_exists(workspace, message.metadata.collection): logger.warning( f"Collection {message.metadata.collection} for workspace {workspace} " @@ -60,24 +84,15 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): if not vec: continue - # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( f"d_{workspace}_{message.metadata.collection}_{dim}" ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) + await self.ensure_collection(collection, dim) - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=collection, points=[ PointStruct( @@ -87,7 +102,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): "chunk_id": chunk_id, } ) - ] + ], ) @staticmethod @@ -124,8 +139,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): try: prefix = f"d_{workspace}_{collection}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -135,7 +151,11 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService): logger.info(f"No collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 391c2a04..13dcdba8 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -3,11 +3,13 @@ Accepts entity/vector pairs and writes them to a Qdrant store. """ +import asyncio +import uuid +import logging + from qdrant_client import QdrantClient from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams -import uuid -import logging from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler from .... base import AsyncProcessor, Consumer, Producer @@ -50,13 +52,35 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): ) self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def ensure_collection(self, collection_name, dim): + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name + ) + if not exists: + logger.info( + f"Lazily creating Qdrant collection {collection_name} " + f"with dimension {dim}" + ) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) + async def store_graph_embeddings(self, workspace, message): - # Validate collection exists in config before processing if not self.collection_exists(workspace, message.metadata.collection): logger.warning( f"Collection {message.metadata.collection} for workspace {workspace} " @@ -75,22 +99,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if not vec: continue - # Create collection name with dimension suffix for lazy creation dim = len(vec) collection = ( f"t_{workspace}_{message.metadata.collection}_{dim}" ) - # Lazily create collection if it doesn't exist (but only if authorized in config) - if not self.qdrant.collection_exists(collection): - logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}") - self.qdrant.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, - distance=Distance.COSINE - ) - ) + await self.ensure_collection(collection, dim) payload = { "entity": entity_value, @@ -98,7 +112,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): if entity.chunk_id: payload["chunk_id"] = entity.chunk_id - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=collection, points=[ PointStruct( @@ -106,7 +121,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): vector=vec, payload=payload, ) - ] + ], ) @staticmethod @@ -143,8 +158,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): try: prefix = f"t_{workspace}_{collection}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -154,7 +170,11 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService): logger.info(f"No collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}") diff --git a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py index 32d87871..a01629c5 100644 --- a/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/row_embeddings/qdrant/write.py @@ -16,10 +16,10 @@ Payload structure: - text: The text that was embedded (for debugging/display) """ +import asyncio import logging import re import uuid -from typing import Set, Tuple from qdrant_client import QdrantClient from qdrant_client.models import PointStruct, Distance, VectorParams @@ -63,11 +63,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Register config handler for collection management self.register_config_handler(self.on_collection_config, types=["collection"]) - # Cache of created Qdrant collections - self.created_collections: Set[str] = set() - - # Qdrant client self.qdrant = QdrantClient(url=store_uri, api_key=api_key) + self._cache_lock = asyncio.Lock() + self._known_collections: set[str] = set() def sanitize_name(self, name: str) -> str: """Sanitize names for Qdrant collection naming""" @@ -85,25 +83,28 @@ class Processor(CollectionConfigHandler, FlowProcessor): safe_schema = self.sanitize_name(schema_name) return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}" - def ensure_collection(self, collection_name: str, dimension: int): + async def ensure_collection(self, collection_name: str, dimension: int): """Create Qdrant collection if it doesn't exist""" - if collection_name in self.created_collections: - return - - if not self.qdrant.collection_exists(collection_name): - logger.info( - f"Creating Qdrant collection {collection_name} " - f"with dimension {dimension}" + async with self._cache_lock: + if collection_name in self._known_collections: + return + exists = await asyncio.to_thread( + self.qdrant.collection_exists, collection_name ) - self.qdrant.create_collection( - collection_name=collection_name, - vectors_config=VectorParams( - size=dimension, - distance=Distance.COSINE + if not exists: + logger.info( + f"Creating Qdrant collection {collection_name} " + f"with dimension {dimension}" ) - ) - - self.created_collections.add(collection_name) + await asyncio.to_thread( + self.qdrant.create_collection, + collection_name=collection_name, + vectors_config=VectorParams( + size=dimension, + distance=Distance.COSINE + ), + ) + self._known_collections.add(collection_name) async def on_embeddings(self, msg, consumer, flow): """Process incoming RowEmbeddings and write to Qdrant""" @@ -143,15 +144,14 @@ class Processor(CollectionConfigHandler, FlowProcessor): dimension = len(vector) - # Create/get collection name (lazily on first vector) if qdrant_collection is None: qdrant_collection = self.get_collection_name( workspace, collection, schema_name, dimension ) - self.ensure_collection(qdrant_collection, dimension) + await self.ensure_collection(qdrant_collection, dimension) - # Write to Qdrant - self.qdrant.upsert( + await asyncio.to_thread( + self.qdrant.upsert, collection_name=qdrant_collection, points=[ PointStruct( @@ -163,7 +163,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): "text": row_emb.text } ) - ] + ], ) embeddings_written += 1 @@ -181,8 +181,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): try: prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_" - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -192,8 +193,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"No Qdrant collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) - self.created_collections.discard(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info( f"Deleted {len(matching_collections)} collection(s) " @@ -217,8 +221,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" ) - # Get all collections and filter for matches - all_collections = self.qdrant.get_collections().collections + all_collections = await asyncio.to_thread( + lambda: self.qdrant.get_collections().collections + ) matching_collections = [ coll.name for coll in all_collections if coll.name.startswith(prefix) @@ -228,8 +233,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.info(f"No Qdrant collections found matching prefix {prefix}") else: for collection_name in matching_collections: - self.qdrant.delete_collection(collection_name) - self.created_collections.discard(collection_name) + await asyncio.to_thread( + self.qdrant.delete_collection, collection_name + ) + async with self._cache_lock: + self._known_collections.discard(collection_name) logger.info(f"Deleted Qdrant collection: {collection_name}") except Exception as e: diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index a5dad748..65eeee06 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -82,7 +82,7 @@ class Processor(CollectionConfigHandler, FlowProcessor): # Cache of known keyspaces and whether tables exist self.known_keyspaces: Set[str] = set() - self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables + self.tables_initialized: Set[str] = set() # Cache of registered (collection, schema_name) pairs self.registered_partitions: Set[Tuple[str, str]] = set() @@ -94,6 +94,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): self.cluster = None self.session = None + # Protects connection setup and cache mutations + self._setup_lock = asyncio.Lock() + def connect_cassandra(self): """Connect to Cassandra cluster""" if self.session: @@ -126,6 +129,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): f"for workspace {workspace}" ) + async with self._setup_lock: + return await self._apply_schema_config(workspace, config, version) + + async def _apply_schema_config(self, workspace, config, version): + # Track which schemas changed in this workspace old_schemas = self.schemas.get(workspace, {}) old_schema_names = set(old_schemas.keys()) @@ -391,16 +399,12 @@ class Processor(CollectionConfigHandler, FlowProcessor): schema_name = obj.schema_name source = getattr(obj.metadata, 'source', '') or '' - # Ensure tables exist (sync DDL — push to a worker thread - # so the event loop stays responsive when running in a - # processor group sharing the loop with siblings). - await asyncio.to_thread(self.ensure_tables, keyspace) - - # Register partitions if first time seeing this (collection, schema_name) - await asyncio.to_thread( - self.register_partitions, - keyspace, collection, schema_name, workspace, - ) + async with self._setup_lock: + await asyncio.to_thread(self.ensure_tables, keyspace) + await asyncio.to_thread( + self.register_partitions, + keyspace, collection, schema_name, workspace, + ) safe_keyspace = self.sanitize_name(keyspace) @@ -461,35 +465,27 @@ class Processor(CollectionConfigHandler, FlowProcessor): async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create/verify collection exists in Cassandra row store""" - # Connect if not already connected (sync, push to thread) - await asyncio.to_thread(self.connect_cassandra) - - # Ensure tables exist (sync DDL, push to thread) - await asyncio.to_thread(self.ensure_tables, workspace) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) + await asyncio.to_thread(self.ensure_tables, workspace) logger.info(f"Collection {collection} ready for workspace {workspace}") async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection using partition tracking""" - # Connect if not already connected - await asyncio.to_thread(self.connect_cassandra) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) + if workspace not in self.known_keyspaces: + safe_ks = self.sanitize_name(workspace) + check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s" + result = await async_execute(self.session, check_cql, (safe_ks,)) + if not result: + logger.info(f"Keyspace {safe_ks} does not exist, nothing to delete") + return + self.known_keyspaces.add(workspace) safe_keyspace = self.sanitize_name(workspace) - # Check if keyspace exists - if workspace not in self.known_keyspaces: - check_keyspace_cql = """ - SELECT keyspace_name FROM system_schema.keyspaces - WHERE keyspace_name = %s - """ - result = await async_execute( - self.session, check_keyspace_cql, (safe_keyspace,) - ) - if not result: - logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") - return - self.known_keyspaces.add(workspace) - # Discover all partitions for this collection select_partitions_cql = f""" SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions @@ -540,11 +536,11 @@ class Processor(CollectionConfigHandler, FlowProcessor): logger.error(f"Failed to clean up row_partitions for {collection}: {e}") raise - # Clear from local cache - self.registered_partitions = { - (col, sch) for col, sch in self.registered_partitions - if col != collection - } + async with self._setup_lock: + self.registered_partitions = { + (col, sch) for col, sch in self.registered_partitions + if col != collection + } logger.info( f"Deleted collection {collection}: {partitions_deleted} partitions " @@ -553,8 +549,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str): """Delete all data for a specific collection + schema combination""" - # Connect if not already connected - await asyncio.to_thread(self.connect_cassandra) + async with self._setup_lock: + await asyncio.to_thread(self.connect_cassandra) safe_keyspace = self.sanitize_name(workspace) @@ -614,8 +610,8 @@ class Processor(CollectionConfigHandler, FlowProcessor): ) raise - # Clear from local cache - self.registered_partitions.discard((collection, schema_name)) + async with self._setup_lock: + self.registered_partitions.discard((collection, schema_name)) logger.info( f"Deleted {collection}/{schema_name}: {partitions_deleted} partitions " diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 0774153b..79d6c549 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -4,12 +4,7 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph. """ import asyncio -import base64 -import os -import argparse -import time import logging -import json from .... direct.cassandra_kg import ( EntityCentricKnowledgeGraph, DEFAULT_GRAPH @@ -28,6 +23,8 @@ default_ident = "triples-write" def serialize_triple(triple): """Serialize a Triple object to JSON for storage.""" + import json + if triple is None: return None @@ -141,156 +138,84 @@ class Processor(CollectionConfigHandler, TriplesStoreService): self.cassandra_host = hosts self.cassandra_username = username self.cassandra_password = password - self.table = None - self.tg = None + + self._connections = {} + self._conn_lock = asyncio.Lock() # Register for config push notifications self.register_config_handler(self.on_collection_config, types=["collection"]) + async def _get_connection(self, workspace): + async with self._conn_lock: + if workspace not in self._connections: + if self.cassandra_username and self.cassandra_password: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + username=self.cassandra_username, + password=self.cassandra_password, + ) + else: + tg = await asyncio.to_thread( + EntityCentricKnowledgeGraph, + hosts=self.cassandra_host, + keyspace=workspace, + ) + self._connections[workspace] = tg + return self._connections[workspace] + async def store_triples(self, workspace, message): - # The cassandra-driver work below — connection, schema - # setup, and per-triple inserts — is all synchronous. - # Wrap the whole batch in a worker thread so the event - # loop stays responsive for sibling processors when - # running in a processor group. + tg = await self._get_connection(workspace) - def _do_store(): + for t in message.triples: + s_val = get_term_value(t.s) + p_val = get_term_value(t.p) + o_val = get_term_value(t.o) + g_val = t.g if t.g is not None else DEFAULT_GRAPH - if self.table is None or self.table != workspace: + otype = get_term_otype(t.o) + dtype = get_term_dtype(t.o) + lang = get_term_lang(t.o) - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Exception: {e}", exc_info=True) - time.sleep(1) - raise e - - self.table = workspace - - for t in message.triples: - # Extract values from Term objects - s_val = get_term_value(t.s) - p_val = get_term_value(t.p) - o_val = get_term_value(t.o) - # t.g is None for default graph, or a graph IRI - g_val = t.g if t.g is not None else DEFAULT_GRAPH - - # Extract object type metadata for entity-centric storage - otype = get_term_otype(t.o) - dtype = get_term_dtype(t.o) - lang = get_term_lang(t.o) - - self.tg.insert( - message.metadata.collection, - s_val, - p_val, - o_val, - g=g_val, - otype=otype, - dtype=dtype, - lang=lang, - ) - - await asyncio.to_thread(_do_store) + await tg.async_insert( + message.metadata.collection, + s_val, + p_val, + o_val, + g=g_val, + otype=otype, + dtype=dtype, + lang=lang, + ) async def create_collection(self, workspace: str, collection: str, metadata: dict): """Create a collection in Cassandra triple store via config push""" + try: + tg = await self._get_connection(workspace) - def _do_create(): - # Create or reuse connection for this workspace's keyspace - if self.table is None or self.table != workspace: - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") - raise - - self.table = workspace - - # Create collection using the built-in method logger.info(f"Creating collection {collection} for workspace {workspace}") - if self.tg.collection_exists(collection): + exists = await tg.async_collection_exists(collection) + if exists: logger.info(f"Collection {collection} already exists") else: - self.tg.create_collection(collection) + await tg.async_create_collection(collection) logger.info(f"Created collection {collection}") - try: - await asyncio.to_thread(_do_create) except Exception as e: logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True) raise async def delete_collection(self, workspace: str, collection: str): """Delete all data for a specific collection from the unified triples table""" + try: + tg = await self._get_connection(workspace) - def _do_delete(): - # Create or reuse connection for this workspace's keyspace - if self.table is None or self.table != workspace: - self.tg = None - - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - try: - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - username=self.cassandra_username, - password=self.cassandra_password, - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=workspace, - ) - except Exception as e: - logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}") - raise - - self.table = workspace - - # Delete all triples for this collection using the built-in method - self.tg.delete_collection(collection) + await tg.async_delete_collection(collection) logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}") - try: - await asyncio.to_thread(_do_delete) except Exception as e: logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True) raise diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py index 5d45358d..cf085fdd 100644 --- a/trustgraph-flow/trustgraph/tables/knowledge.py +++ b/trustgraph-flow/trustgraph/tables/knowledge.py @@ -1,6 +1,7 @@ from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings +from .. schema import DocumentEmbeddings, ChunkEmbeddings from cassandra.cluster import Cluster @@ -217,6 +218,16 @@ class KnowledgeTableStore: WHERE workspace = ? AND document_id = ? """) + self.delete_document_embeddings_stmt = self.cassandra.prepare(""" + DELETE FROM document_embeddings + WHERE workspace = ? AND document_id = ? + """) + + self.list_de_cores_stmt = self.cassandra.prepare(""" + SELECT DISTINCT workspace, document_id FROM document_embeddings + WHERE workspace = ? + """) + async def add_triples(self, workspace, m): when = int(time.time() * 1000) @@ -338,6 +349,50 @@ class KnowledgeTableStore: logger.error("Exception occurred", exc_info=True) raise + try: + await async_execute( + self.cassandra, + self.delete_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + async def delete_document_embeddings(self, workspace, document_id): + + logger.debug("Delete document embeddings...") + + try: + await async_execute( + self.cassandra, + self.delete_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + async def list_de_cores(self, workspace): + + logger.debug("List DE cores...") + + try: + rows = await async_execute( + self.cassandra, + self.list_de_cores_stmt, + (workspace,), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + lst = [row[1] for row in rows] + + logger.debug("Done") + + return lst + async def get_triples(self, workspace, document_id, receiver): logger.debug("Get triples...") @@ -417,3 +472,42 @@ class KnowledgeTableStore: logger.debug("Done") + async def get_document_embeddings(self, workspace, document_id, receiver): + + logger.debug("Get DE...") + + try: + rows = await async_execute( + self.cassandra, + self.get_document_embeddings_stmt, + (workspace, document_id), + ) + except Exception: + logger.error("Exception occurred", exc_info=True) + raise + + for row in rows: + + if row[3]: + chunks = [ + ChunkEmbeddings( + chunk_id=ch[0], + vector=ch[1], + ) + for ch in row[3] + ] + else: + chunks = [] + + await receiver( + DocumentEmbeddings( + metadata = Metadata( + id = document_id, + collection = "default", + ), + chunks = chunks + ) + ) + + logger.debug("Done") +