mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-16 19:05:14 +02:00
release/v2.4 -> master (#924)
* CLI auth migration, document embeddings core lifecycle (#913)

  Migrate the get_kg_core and put_kg_core CLI tools to use Api/SocketClient with first-frame auth (fixes the broken raw websocket path). Fix wire format field names (root/vector). Remove ~600 lines of dead raw websocket code from invoke_graph_rag.py. Add a document embeddings core lifecycle to the knowledge service: list/get/put/delete/load operations across schema, translator, Cassandra table store, knowledge manager, gateway registry, REST API, socket client, and CLI (tg-get-de-core, tg-put-de-core). Fix delete_kg_core to also clean up document embeddings rows.

* Remove spurious workspace parameter from SPARQL algebra evaluator (#915)

  Fix threading of the workspace parameter: the SPARQL algebra evaluator was threading a workspace parameter through every function and passing it to TriplesClient.query(), which doesn't accept it. Workspace isolation is handled by pub/sub topic routing — the TriplesClient is already scoped to a workspace-specific flow, the same as GraphRAG. Passing workspace explicitly was both incorrect and unnecessary.

  Update tests:

  - tests/unit/test_query/test_sparql_algebra.py (new) — tests _query_pattern, _eval_bgp, and evaluate() with various algebra nodes. Key tests assert that workspace is never in tc.query() kwargs, plus correctness tests for BGP, JOIN, UNION, SLICE, DISTINCT, and edge cases.
  - tests/unit/test_retrieval/test_graph_rag.py — added test_triples_query_never_passes_workspace (checks query()) and test_follow_edges_never_passes_workspace (checks query_stream()).

* Make all Cassandra and Qdrant I/O async-safe with proper concurrency controls (#916)

  Cassandra triples services were using synchronous EntityCentricKnowledgeGraph methods from async contexts, and connection state was managed with threading.local, which is wrong for asyncio coroutines sharing a single thread. Qdrant services had no async wrapping at all, blocking the event loop on every network call. Rows services had unprotected shared-state mutations across concurrent coroutines.

  - Add async methods to EntityCentricKnowledgeGraph (async_insert, async_get_s/p/o/sp/po/os/spo/all, async_collection_exists, async_create_collection, async_delete_collection) using the existing cassandra_async.async_execute bridge
  - Rewrite the triples write + query services: replace threading.local with an asyncio.Lock + dict cache for per-workspace connections, use the async ECKG methods for all data operations, and keep asyncio.to_thread only for the one-time blocking ECKG construction
  - Wrap all Qdrant calls in asyncio.to_thread across all 6 services (doc/graph/row embeddings write + query), and add an asyncio.Lock + set cache for collection existence checks
  - Add an asyncio.Lock to the rows write + query services to protect shared state (schemas, sessions, config caches) from concurrent mutation
  - Update all affected tests to match the new async patterns

* Fixed error where only a page of results was returned (#921)

  The root cause: async_execute only materialises the first result page (by design — its docstring says so). The streaming query set fetch_size=20 and expected to iterate all results, but only got back the first 20 rows. The fix uses asyncio.to_thread(lambda: list(tg.session.execute(...))), which lets the sync driver iterate all pages in a worker thread — exactly what the pre-async code did. A sketch of the concurrency pattern behind #916 and #921 follows.
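For orientation, a minimal sketch of the pattern the two commits above describe, assuming a blocking graph constructor and a sync Cassandra session. The helper names (WorkspaceConnectionCache, make_graph, fetch_all_rows) are illustrative, not TrustGraph's actual API:

```python
import asyncio


class WorkspaceConnectionCache:
    """Sketch of the per-workspace connection pattern from #916.

    `make_graph` stands in for the blocking EntityCentricKnowledgeGraph
    constructor; it is an assumption, not TrustGraph's real signature.
    """

    def __init__(self, make_graph, hosts):
        self.make_graph = make_graph
        self.hosts = hosts
        self._connections = {}            # keyspace -> graph connection
        self._conn_lock = asyncio.Lock()  # protects the dict across coroutines

    async def get(self, keyspace):
        # threading.local is wrong here: all coroutines share one thread, so
        # a plain dict guarded by an asyncio.Lock is used instead.
        async with self._conn_lock:
            if keyspace not in self._connections:
                # Only the one-time blocking construction goes to a worker thread.
                self._connections[keyspace] = await asyncio.to_thread(
                    self.make_graph, hosts=self.hosts, keyspace=keyspace
                )
            return self._connections[keyspace]


async def fetch_all_rows(session, statement):
    """Sketch of the #921 fix: async_execute only materialises the first
    result page, so drain every page by running list() over the sync
    iterator in a worker thread."""
    return await asyncio.to_thread(lambda: list(session.execute(statement)))
```

The lock matters because coroutines share one thread, so threading.local never isolates them; the list() inside to_thread matters because the sync driver transparently fetches subsequent pages during iteration.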
* Optional test warning suppression (#923)

  * Fix test collection module errors & silence upstream pytest warnings (#823)
  * chore: add virtual environment and .env directories to gitignore
  * test: filter upstream DeprecationWarning and UserWarning messages
  * fix(namespace): remove empty __init__.py files to fix PEP 420 implicit namespace routing for trustgraph sub-packages
  * Revert __init__.py deletions
  * Add the .ini changes, commented out; they will be useful at times

  ---------

  Co-authored-by: Salil M <d2kyt@protonmail.com>
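As a side note, the suppression kept commented out in pytest.ini (see the hunk further down) can also be reproduced programmatically, e.g. from a conftest.py. A sketch using Python's standard warnings API, with message regexes taken from that hunk; whether to enable it is the same trade-off the pytest.ini comment describes:

```python
import warnings

# Mirrors two of the commented-out pytest.ini filterwarnings entries.
# Message patterns are regexes matched against the start of the warning text.
warnings.filterwarnings(
    "ignore",
    message=r"builtin type SwigPyPacked has no __module__ attribute",
    category=DeprecationWarning,
)
warnings.filterwarnings(
    "ignore",
    message=r"Core Pydantic V1 functionality isn't compatible with Python 3\.14",
    category=UserWarning,
)
```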
parent 159b1e2824
commit 142dd0231c
42 changed files with 1910 additions and 1492 deletions
.gitignore (vendored): 3 changes

@@ -17,3 +17,6 @@ trustgraph-unstructured/trustgraph/unstructured_version.py
 trustgraph-mcp/trustgraph/mcp_version.py
 trustgraph/trustgraph/trustgraph_version.py
 vertexai/
+venv/
+.venv/
+.env

@@ -188,13 +188,14 @@ class TestConfigurationPriorityEndToEnd:
         )

     @pytest.mark.asyncio
-    @patch('trustgraph.direct.cassandra_kg.Cluster')
-    async def test_no_config_defaults_end_to_end(self, mock_cluster):
+    @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')
+    async def test_no_config_defaults_end_to_end(self, mock_kg_class):
         """Test that defaults are used when no configuration provided end-to-end."""
-        mock_cluster_instance = MagicMock()
-        mock_session = MagicMock()
-        mock_cluster_instance.connect.return_value = mock_session
-        mock_cluster.return_value = mock_cluster_instance
+        from unittest.mock import AsyncMock
+
+        mock_tg_instance = MagicMock()
+        mock_tg_instance.async_get_all = AsyncMock(return_value=[])
+        mock_kg_class.return_value = mock_tg_instance

         with patch.dict(os.environ, {}, clear=True):
             processor = TriplesQuery(taskgroup=MagicMock())

@@ -205,20 +206,16 @@ class TestConfigurationPriorityEndToEnd:
         mock_query.s = None
         mock_query.p = None
         mock_query.o = None
         mock_query.g = None
         mock_query.limit = 100

-        # Mock the get_all method to return empty list
-        mock_tg_instance = MagicMock()
-        mock_tg_instance.get_all.return_value = []
-        processor.tg = mock_tg_instance
-
         await processor.query_triples('default_user', mock_query)

         # Should use defaults
-        mock_cluster.assert_called_once()
-        call_args = mock_cluster.call_args
-        assert call_args.args[0] == ['cassandra']  # Default host
-        assert 'auth_provider' not in call_args.kwargs  # No auth with default config
+        mock_kg_class.assert_called_once_with(
+            hosts=['cassandra'],
+            keyspace='default_user'
+        )


 class TestNoBackwardCompatibilityEndToEnd:

@@ -101,6 +101,8 @@ class TestRowsCassandraIntegration:
         processor.session = None

         # Bind actual methods from the new unified table implementation
+        import asyncio
+        processor._setup_lock = asyncio.Lock()
         processor.connect_cassandra = Processor.connect_cassandra.__get__(processor, Processor)
         processor.ensure_keyspace = Processor.ensure_keyspace.__get__(processor, Processor)
         processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)

@@ -108,6 +110,7 @@ class TestRowsCassandraIntegration:
         processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
         processor.build_index_value = Processor.build_index_value.__get__(processor, Processor)
         processor.register_partitions = Processor.register_partitions.__get__(processor, Processor)
+        processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
         processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)
         processor.on_object = Processor.on_object.__get__(processor, Processor)
         processor.collection_exists = MagicMock(return_value=True)

@@ -184,7 +184,7 @@ class TestObjectsGraphQLQueryIntegration:
         await processor.on_schema_config("default", sample_schema_config, version=1)

         # Connect to Cassandra
-        processor.connect_cassandra()
+        await processor.connect_cassandra()
         assert processor.session is not None

         # Create test keyspace and table

@@ -219,7 +219,7 @@ class TestObjectsGraphQLQueryIntegration:
         """Test inserting data and querying via GraphQL"""
         # Load schema and connect
         await processor.on_schema_config("default", sample_schema_config, version=1)
-        processor.connect_cassandra()
+        await processor.connect_cassandra()

         # Setup test data
         keyspace = "test_user"

@@ -293,7 +293,7 @@ class TestObjectsGraphQLQueryIntegration:
         """Test GraphQL queries with filtering on indexed fields"""
         # Setup (reuse previous setup)
         await processor.on_schema_config("default", sample_schema_config, version=1)
-        processor.connect_cassandra()
+        await processor.connect_cassandra()

         keyspace = "test_user"
         collection = "filter_test"

@@ -387,7 +387,7 @@ class TestObjectsGraphQLQueryIntegration:
         """Test full message processing workflow"""
         # Setup
         await processor.on_schema_config("default", sample_schema_config, version=1)
-        processor.connect_cassandra()
+        await processor.connect_cassandra()

         # Create mock message
         request = RowsQueryRequest(

@@ -433,7 +433,7 @@ class TestObjectsGraphQLQueryIntegration:
         """Test handling multiple concurrent GraphQL queries"""
         # Setup
         await processor.on_schema_config("default", sample_schema_config, version=1)
-        processor.connect_cassandra()
+        await processor.connect_cassandra()

         # Create multiple query tasks
         queries = [

@@ -519,7 +519,7 @@ class TestObjectsGraphQLQueryIntegration:
         """Test handling of large query result sets"""
         # Setup
         await processor.on_schema_config("default", sample_schema_config, version=1)
-        processor.connect_cassandra()
+        await processor.connect_cassandra()

         keyspace = "large_test_user"
         collection = "large_collection"

@@ -17,3 +17,12 @@ markers =
     contract: marks tests as contract tests (service interface validation)
     vertexai: marks tests as vertex ai specific tests
     asyncio: marks tests that use asyncio
+# This is helpful if you're bored with deprecationwarnings. I prefer to
+# keep the warnings for now, it avoids masking problems.
+#
+# filterwarnings =
+#     ignore:Core Pydantic V1 functionality isn't compatible with Python 3.14.*:UserWarning
+#     ignore:builtin type SwigPyPacked has no __module__ attribute:DeprecationWarning
+#     ignore:builtin type SwigPyObject has no __module__ attribute:DeprecationWarning
+#     ignore:builtin type swigvarlink has no __module__ attribute:DeprecationWarning
+#     ignore:.*_UnionGenericAlias.*is deprecated and slated for removal in Python 3.17:DeprecationWarning

@@ -89,12 +89,15 @@ class TestRowsGraphQLQueryLogic:
     @pytest.mark.asyncio
     async def test_schema_config_parsing(self):
         """Test parsing of schema configuration"""
+        import asyncio
         processor = MagicMock()
         processor.schemas = {}
         processor.schema_builders = {}
         processor.graphql_schemas = {}
         processor.config_key = "schema"
         processor.query_cassandra = MagicMock()
+        processor._setup_lock = asyncio.Lock()
+        processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
         processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)

         # Create test config

@@ -335,7 +338,7 @@ class TestUnifiedTableQueries:
         """Test query execution with matching index"""
         processor = MagicMock()
         processor.session = MagicMock()
-        processor.connect_cassandra = MagicMock()
+        processor.connect_cassandra = AsyncMock()
         processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
         processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
         processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)

@@ -396,7 +399,7 @@ class TestUnifiedTableQueries:
         """Test query execution without matching index (scan mode)"""
         processor = MagicMock()
         processor.session = MagicMock()
-        processor.connect_cassandra = MagicMock()
+        processor.connect_cassandra = AsyncMock()
         processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
         processor.get_index_names = Processor.get_index_names.__get__(processor, Processor)
         processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)

tests/unit/test_query/test_sparql_algebra.py (new file): 302 lines

@@ -0,0 +1,302 @@
+"""
+Tests for the SPARQL algebra evaluator.
+
+Verifies that evaluate() and _query_pattern() call TriplesClient.query()
+with the correct arguments, and in particular that workspace is never
+passed — workspace isolation is handled by pub/sub topic routing.
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, call
+
+from rdflib.term import Variable, URIRef, Literal
+from rdflib.plugins.sparql.parserutils import CompValue
+
+from trustgraph.schema import Term, IRI, LITERAL
+from trustgraph.query.sparql.algebra import (
+    evaluate, _query_pattern, _eval_bgp,
+)
+
+
+# --- Helpers ---
+
+def iri(v):
+    return Term(type=IRI, iri=v)
+
+
+def lit(v):
+    return Term(type=LITERAL, value=v)
+
+
+def make_triple(s, p, o):
+    t = MagicMock()
+    t.s = s
+    t.p = p
+    t.o = o
+    return t
+
+
+def make_bgp(*patterns):
+    """Build a CompValue BGP node from (s, p, o) tuples of rdflib terms."""
+    node = CompValue("BGP")
+    node.triples = list(patterns)
+    return node
+
+
+def make_project(inner, variables):
+    node = CompValue("Project")
+    node.p = inner
+    node.PV = [Variable(v) for v in variables]
+    return node
+
+
+def make_select(inner):
+    node = CompValue("SelectQuery")
+    node.p = inner
+    return node
+
+
+def make_join(left, right):
+    node = CompValue("Join")
+    node.p1 = left
+    node.p2 = right
+    return node
+
+
+def make_union(left, right):
+    node = CompValue("Union")
+    node.p1 = left
+    node.p2 = right
+    return node
+
+
+def make_slice(inner, start, length):
+    node = CompValue("Slice")
+    node.p = inner
+    node.start = start
+    node.length = length
+    return node
+
+
+def make_distinct(inner):
+    node = CompValue("Distinct")
+    node.p = inner
+    return node
+
+
+class TestQueryPattern:
+    """Tests for _query_pattern — the leaf that calls TriplesClient."""
+
+    @pytest.mark.asyncio
+    async def test_passes_correct_args(self):
+        tc = AsyncMock()
+        tc.query.return_value = []
+
+        await _query_pattern(
+            tc,
+            s=iri("http://example.com/s"),
+            p=iri("http://example.com/p"),
+            o=None,
+            collection="my-collection",
+            limit=100,
+        )
+
+        tc.query.assert_called_once_with(
+            s=iri("http://example.com/s"),
+            p=iri("http://example.com/p"),
+            o=None,
+            limit=100,
+            collection="my-collection",
+        )
+
+    @pytest.mark.asyncio
+    async def test_workspace_not_passed(self):
+        tc = AsyncMock()
+        tc.query.return_value = []
+
+        await _query_pattern(tc, None, None, None, "default", 10)
+
+        kwargs = tc.query.call_args.kwargs
+        assert "workspace" not in kwargs
+
+    @pytest.mark.asyncio
+    async def test_returns_query_results(self):
+        tc = AsyncMock()
+        triple = make_triple(iri("http://a"), iri("http://b"), lit("c"))
+        tc.query.return_value = [triple]
+
+        results = await _query_pattern(tc, None, None, None, "default", 10)
+
+        assert len(results) == 1
+        assert results[0].s.iri == "http://a"
+
+
+class TestEvalBgp:
+    """Tests for BGP evaluation — triple pattern queries."""
+
+    @pytest.mark.asyncio
+    async def test_single_pattern_all_variables(self):
+        tc = AsyncMock()
+        triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
+        tc.query.return_value = [triple]
+
+        bgp = make_bgp(
+            (Variable("s"), Variable("p"), Variable("o")),
+        )
+
+        solutions = await evaluate(bgp, tc, collection="default", limit=100)
+
+        assert len(solutions) == 1
+        assert solutions[0]["s"].iri == "http://s"
+        assert solutions[0]["p"].iri == "http://p"
+        assert solutions[0]["o"].value == "o"
+
+    @pytest.mark.asyncio
+    async def test_single_pattern_bound_subject(self):
+        tc = AsyncMock()
+        tc.query.return_value = [
+            make_triple(iri("http://s"), iri("http://p"), lit("val")),
+        ]
+
+        bgp = make_bgp(
+            (URIRef("http://s"), Variable("p"), Variable("o")),
+        )
+
+        solutions = await evaluate(bgp, tc, collection="default")
+
+        tc.query.assert_called_once()
+        kwargs = tc.query.call_args.kwargs
+        assert "workspace" not in kwargs
+        assert kwargs["collection"] == "default"
+
+    @pytest.mark.asyncio
+    async def test_empty_bgp_returns_empty_solution(self):
+        tc = AsyncMock()
+
+        bgp = make_bgp()
+
+        solutions = await evaluate(bgp, tc, collection="default")
+
+        assert solutions == [{}]
+        tc.query.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_no_results_returns_empty(self):
+        tc = AsyncMock()
+        tc.query.return_value = []
+
+        bgp = make_bgp(
+            (Variable("s"), Variable("p"), Variable("o")),
+        )
+
+        solutions = await evaluate(bgp, tc, collection="default")
+
+        assert solutions == []
+
+
+class TestEvaluate:
+    """Tests for the top-level evaluate() dispatcher."""
+
+    @pytest.mark.asyncio
+    async def test_select_query_node(self):
+        tc = AsyncMock()
+        tc.query.return_value = [
+            make_triple(iri("http://s"), iri("http://p"), lit("o")),
+        ]
+
+        bgp = make_bgp(
+            (Variable("s"), Variable("p"), Variable("o")),
+        )
+        select = make_select(make_project(bgp, ["s", "p"]))
+
+        solutions = await evaluate(select, tc, collection="default")
+
+        assert len(solutions) == 1
+        assert "s" in solutions[0]
+        assert "p" in solutions[0]
+        assert "o" not in solutions[0]
+
+    @pytest.mark.asyncio
+    async def test_workspace_never_in_query_calls(self):
+        """Verify that no matter the algebra structure, workspace is never
+        passed to TriplesClient.query()."""
+        tc = AsyncMock()
+        tc.query.return_value = [
+            make_triple(iri("http://s"), iri("http://p"), lit("o")),
+        ]
+
+        bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o")))
+        bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c")))
+        tree = make_select(make_project(
+            make_union(bgp1, bgp2), ["s", "p", "o"]
+        ))
+
+        await evaluate(tree, tc, collection="test-coll")
+
+        for c in tc.query.call_args_list:
+            assert "workspace" not in c.kwargs
+
+    @pytest.mark.asyncio
+    async def test_join(self):
+        tc = AsyncMock()
+        tc.query.side_effect = [
+            [make_triple(iri("http://a"), iri("http://p"), lit("v"))],
+            [make_triple(iri("http://a"), iri("http://q"), lit("w"))],
+        ]
+
+        bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1")))
+        bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2")))
+        tree = make_join(bgp1, bgp2)
+
+        solutions = await evaluate(tree, tc, collection="default")

+        assert len(solutions) == 1
+        assert solutions[0]["s"].iri == "http://a"
+
+    @pytest.mark.asyncio
+    async def test_slice(self):
+        tc = AsyncMock()
+        triples = [
+            make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}"))
+            for i in range(5)
+        ]
+        tc.query.return_value = triples
+
+        bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
+        tree = make_slice(bgp, start=1, length=2)
+
+        solutions = await evaluate(tree, tc, collection="default")
+
+        assert len(solutions) == 2
+
+    @pytest.mark.asyncio
+    async def test_distinct(self):
+        tc = AsyncMock()
+        triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
+        tc.query.return_value = [triple, triple]
+
+        bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
+        tree = make_distinct(bgp)
+
+        solutions = await evaluate(tree, tc, collection="default")
+
+        assert len(solutions) == 1
+
+    @pytest.mark.asyncio
+    async def test_unsupported_node_returns_empty_solution(self):
+        tc = AsyncMock()
+
+        node = CompValue("SomethingUnknown")
+
+        solutions = await evaluate(node, tc, collection="default")
+
+        assert solutions == [{}]
+        tc.query.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_non_compvalue_returns_empty_solution(self):
+        tc = AsyncMock()
+
+        solutions = await evaluate("not a node", tc, collection="default")
+
+        assert solutions == [{}]

@@ -2,8 +2,10 @@
 Tests for Cassandra triples query service
 """

+import asyncio
+
 import pytest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, AsyncMock

 from trustgraph.query.triples.cassandra.service import Processor, create_term
 from trustgraph.schema import Term, IRI, LITERAL

@@ -18,7 +20,7 @@ class TestCassandraQueryProcessor:
         return Processor(
             taskgroup=MagicMock(),
             id='test-cassandra-query',
-            graph_host='localhost'
+            cassandra_host='localhost'
         )

     def test_create_term_with_http_uri(self, processor):

@@ -85,7 +87,7 @@ class TestCassandraQueryProcessor:
         mock_result.dtype = None
         mock_result.lang = None
         mock_result.o = 'test_object'
-        mock_tg_instance.get_spo.return_value = [mock_result]
+        mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])

         processor = Processor(
             taskgroup=MagicMock(),

@@ -110,8 +112,8 @@ class TestCassandraQueryProcessor:
             keyspace='test_user'
         )

-        # Verify get_spo was called with correct parameters
-        mock_tg_instance.get_spo.assert_called_once_with(
+        # Verify async_get_spo was called with correct parameters
+        mock_tg_instance.async_get_spo.assert_called_once_with(
             'test_collection', 'test_subject', 'test_predicate', 'test_object', g=None, limit=100
         )

@@ -130,7 +132,8 @@ class TestCassandraQueryProcessor:
         assert processor.cassandra_host == ['cassandra']  # Updated default
         assert processor.cassandra_username is None
         assert processor.cassandra_password is None
-        assert processor.table is None
+        assert processor._connections == {}
+        assert isinstance(processor._conn_lock, asyncio.Lock)

     def test_processor_initialization_with_custom_params(self):
         """Test processor initialization with custom parameters"""

@@ -146,7 +149,8 @@ class TestCassandraQueryProcessor:
         assert processor.cassandra_host == ['cassandra.example.com']
         assert processor.cassandra_username == 'queryuser'
         assert processor.cassandra_password == 'querypass'
-        assert processor.table is None
+        assert processor._connections == {}
+        assert isinstance(processor._conn_lock, asyncio.Lock)

     @pytest.mark.asyncio
     @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')

@@ -164,7 +168,7 @@ class TestCassandraQueryProcessor:
         mock_result.otype = None
         mock_result.dtype = None
         mock_result.lang = None
-        mock_tg_instance.get_sp.return_value = [mock_result]
+        mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result])

         processor = Processor(taskgroup=MagicMock())

@@ -178,7 +182,7 @@ class TestCassandraQueryProcessor:

         result = await processor.query_triples('test_user', query)

-        mock_tg_instance.get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
+        mock_tg_instance.async_get_sp.assert_called_once_with('test_collection', 'test_subject', 'test_predicate', g=None, limit=50)
         assert len(result) == 1
         assert result[0].s.iri == 'test_subject'
         assert result[0].p.iri == 'test_predicate'

@@ -200,7 +204,7 @@ class TestCassandraQueryProcessor:
         mock_result.otype = None
         mock_result.dtype = None
         mock_result.lang = None
-        mock_tg_instance.get_s.return_value = [mock_result]
+        mock_tg_instance.async_get_s = AsyncMock(return_value=[mock_result])

         processor = Processor(taskgroup=MagicMock())

@@ -214,7 +218,7 @@ class TestCassandraQueryProcessor:

         result = await processor.query_triples('test_user', query)

-        mock_tg_instance.get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
+        mock_tg_instance.async_get_s.assert_called_once_with('test_collection', 'test_subject', g=None, limit=25)
         assert len(result) == 1
         assert result[0].s.iri == 'test_subject'
         assert result[0].p.iri == 'result_predicate'

@@ -236,7 +240,7 @@ class TestCassandraQueryProcessor:
         mock_result.otype = None
         mock_result.dtype = None
         mock_result.lang = None
-        mock_tg_instance.get_p.return_value = [mock_result]
+        mock_tg_instance.async_get_p = AsyncMock(return_value=[mock_result])

         processor = Processor(taskgroup=MagicMock())

@@ -250,7 +254,7 @@ class TestCassandraQueryProcessor:

         result = await processor.query_triples('test_user', query)

-        mock_tg_instance.get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
+        mock_tg_instance.async_get_p.assert_called_once_with('test_collection', 'test_predicate', g=None, limit=10)
         assert len(result) == 1
         assert result[0].s.iri == 'result_subject'
         assert result[0].p.iri == 'test_predicate'

@@ -272,7 +276,7 @@ class TestCassandraQueryProcessor:
         mock_result.otype = None
         mock_result.dtype = None
         mock_result.lang = None
-        mock_tg_instance.get_o.return_value = [mock_result]
+        mock_tg_instance.async_get_o = AsyncMock(return_value=[mock_result])

         processor = Processor(taskgroup=MagicMock())

@@ -286,7 +290,7 @@ class TestCassandraQueryProcessor:

         result = await processor.query_triples('test_user', query)

-        mock_tg_instance.get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
+        mock_tg_instance.async_get_o.assert_called_once_with('test_collection', 'test_object', g=None, limit=75)
         assert len(result) == 1
         assert result[0].s.iri == 'result_subject'
         assert result[0].p.iri == 'result_predicate'

@@ -305,11 +309,11 @@ class TestCassandraQueryProcessor:
         mock_result.s = 'all_subject'
         mock_result.p = 'all_predicate'
         mock_result.o = 'all_object'
         mock_result.g = ''
         mock_result.d = ''
         mock_result.otype = None
         mock_result.dtype = None
         mock_result.lang = None
-        mock_tg_instance.get_all.return_value = [mock_result]
+        mock_tg_instance.async_get_all = AsyncMock(return_value=[mock_result])

         processor = Processor(taskgroup=MagicMock())

@@ -323,7 +327,7 @@ class TestCassandraQueryProcessor:

         result = await processor.query_triples('test_user', query)

-        mock_tg_instance.get_all.assert_called_once_with('test_collection', limit=1000)
+        mock_tg_instance.async_get_all.assert_called_once_with('test_collection', limit=1000)
         assert len(result) == 1
         assert result[0].s.iri == 'all_subject'
         assert result[0].p.iri == 'all_predicate'

@@ -410,7 +414,7 @@ class TestCassandraQueryProcessor:
         mock_result.dtype = None
         mock_result.lang = None
         mock_result.o = 'test_object'
-        mock_tg_instance.get_spo.return_value = [mock_result]
+        mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])

         processor = Processor(
             taskgroup=MagicMock(),

@@ -451,7 +455,7 @@ class TestCassandraQueryProcessor:
         mock_result.dtype = None
         mock_result.lang = None
         mock_result.o = 'test_object'
-        mock_tg_instance.get_spo.return_value = [mock_result]
+        mock_tg_instance.async_get_spo = AsyncMock(return_value=[mock_result])

         processor = Processor(taskgroup=MagicMock())

@@ -489,8 +493,8 @@ class TestCassandraQueryProcessor:
         mock_result.lang = None
         mock_result.p = 'p'
         mock_result.o = 'o'
-        mock_tg_instance1.get_s.return_value = [mock_result]
-        mock_tg_instance2.get_s.return_value = [mock_result]
+        mock_tg_instance1.async_get_s = AsyncMock(return_value=[mock_result])
+        mock_tg_instance2.async_get_s = AsyncMock(return_value=[mock_result])

         processor = Processor(taskgroup=MagicMock())

@@ -504,7 +508,6 @@ class TestCassandraQueryProcessor:
         )

         await processor.query_triples('user1', query1)
-        assert processor.table == 'user1'

         # Second query with different table
         query2 = TriplesQueryRequest(

@@ -516,10 +519,11 @@ class TestCassandraQueryProcessor:
         )

         await processor.query_triples('user2', query2)
-        assert processor.table == 'user2'

-        # Verify TrustGraph was created twice
+        # Verify TrustGraph was created twice for different workspaces
         assert mock_kg_class.call_count == 2
+        mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
+        mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')

     @pytest.mark.asyncio
     @patch('trustgraph.query.triples.cassandra.service.EntityCentricKnowledgeGraph')

@@ -529,7 +533,7 @@ class TestCassandraQueryProcessor:

         mock_tg_instance = MagicMock()
         mock_kg_class.return_value = mock_tg_instance
-        mock_tg_instance.get_spo.side_effect = Exception("Query failed")
+        mock_tg_instance.async_get_spo = AsyncMock(side_effect=Exception("Query failed"))

         processor = Processor(taskgroup=MagicMock())

@@ -566,7 +570,7 @@ class TestCassandraQueryProcessor:
         mock_result2.otype = None
         mock_result2.dtype = None
         mock_result2.lang = None
-        mock_tg_instance.get_sp.return_value = [mock_result1, mock_result2]
+        mock_tg_instance.async_get_sp = AsyncMock(return_value=[mock_result1, mock_result2])

         processor = Processor(taskgroup=MagicMock())

@@ -603,7 +607,7 @@ class TestCassandraQueryPerformanceOptimizations:
         mock_result.otype = None
         mock_result.dtype = None
         mock_result.lang = None
-        mock_tg_instance.get_po.return_value = [mock_result]
+        mock_tg_instance.async_get_po = AsyncMock(return_value=[mock_result])

         processor = Processor(taskgroup=MagicMock())

@@ -618,8 +622,8 @@ class TestCassandraQueryPerformanceOptimizations:

         result = await processor.query_triples('test_user', query)

-        # Verify get_po was called (should use optimized po_table)
-        mock_tg_instance.get_po.assert_called_once_with(
+        # Verify async_get_po was called (should use optimized po_table)
+        mock_tg_instance.async_get_po.assert_called_once_with(
             'test_collection', 'test_predicate', 'test_object', g=None, limit=50
         )

@@ -643,7 +647,7 @@ class TestCassandraQueryPerformanceOptimizations:
         mock_result.otype = None
         mock_result.dtype = None
         mock_result.lang = None
-        mock_tg_instance.get_os.return_value = [mock_result]
+        mock_tg_instance.async_get_os = AsyncMock(return_value=[mock_result])

         processor = Processor(taskgroup=MagicMock())

@@ -658,8 +662,8 @@ class TestCassandraQueryPerformanceOptimizations:

         result = await processor.query_triples('test_user', query)

-        # Verify get_os was called (should use optimized subject_table with clustering)
-        mock_tg_instance.get_os.assert_called_once_with(
+        # Verify async_get_os was called (should use optimized subject_table with clustering)
+        mock_tg_instance.async_get_os.assert_called_once_with(
             'test_collection', 'test_object', 'test_subject', g=None, limit=25
         )

@@ -678,28 +682,28 @@ class TestCassandraQueryPerformanceOptimizations:
         mock_kg_class.return_value = mock_tg_instance

         # Mock empty results for all queries
-        mock_tg_instance.get_all.return_value = []
-        mock_tg_instance.get_s.return_value = []
-        mock_tg_instance.get_p.return_value = []
-        mock_tg_instance.get_o.return_value = []
-        mock_tg_instance.get_sp.return_value = []
-        mock_tg_instance.get_po.return_value = []
-        mock_tg_instance.get_os.return_value = []
-        mock_tg_instance.get_spo.return_value = []
+        mock_tg_instance.async_get_all = AsyncMock(return_value=[])
+        mock_tg_instance.async_get_s = AsyncMock(return_value=[])
+        mock_tg_instance.async_get_p = AsyncMock(return_value=[])
+        mock_tg_instance.async_get_o = AsyncMock(return_value=[])
+        mock_tg_instance.async_get_sp = AsyncMock(return_value=[])
+        mock_tg_instance.async_get_po = AsyncMock(return_value=[])
+        mock_tg_instance.async_get_os = AsyncMock(return_value=[])
+        mock_tg_instance.async_get_spo = AsyncMock(return_value=[])

         processor = Processor(taskgroup=MagicMock())

         # Test each query pattern
         test_patterns = [
             # (s, p, o, expected_method)
-            (None, None, None, 'get_all'),     # All triples
-            ('s1', None, None, 'get_s'),       # Subject only
-            (None, 'p1', None, 'get_p'),       # Predicate only
-            (None, None, 'o1', 'get_o'),       # Object only
-            ('s1', 'p1', None, 'get_sp'),      # Subject + Predicate
-            (None, 'p1', 'o1', 'get_po'),      # Predicate + Object (CRITICAL OPTIMIZATION)
-            ('s1', None, 'o1', 'get_os'),      # Object + Subject
-            ('s1', 'p1', 'o1', 'get_spo'),     # All three
+            (None, None, None, 'async_get_all'),     # All triples
+            ('s1', None, None, 'async_get_s'),       # Subject only
+            (None, 'p1', None, 'async_get_p'),       # Predicate only
+            (None, None, 'o1', 'async_get_o'),       # Object only
+            ('s1', 'p1', None, 'async_get_sp'),      # Subject + Predicate
+            (None, 'p1', 'o1', 'async_get_po'),      # Predicate + Object (CRITICAL OPTIMIZATION)
+            ('s1', None, 'o1', 'async_get_os'),      # Object + Subject
+            ('s1', 'p1', 'o1', 'async_get_spo'),     # All three
         ]

         for s, p, o, expected_method in test_patterns:

@@ -759,7 +763,7 @@ class TestCassandraQueryPerformanceOptimizations:
             mock_result.lang = None
             mock_results.append(mock_result)

-        mock_tg_instance.get_po.return_value = mock_results
+        mock_tg_instance.async_get_po = AsyncMock(return_value=mock_results)

         processor = Processor(taskgroup=MagicMock())

@@ -774,8 +778,8 @@ class TestCassandraQueryPerformanceOptimizations:

         result = await processor.query_triples('large_dataset_user', query)

-        # Verify optimized get_po was used (no ALLOW FILTERING needed!)
-        mock_tg_instance.get_po.assert_called_once_with(
+        # Verify optimized async_get_po was used (no ALLOW FILTERING needed!)
+        mock_tg_instance.async_get_po.assert_called_once_with(
             'massive_collection',
             'http://www.w3.org/1999/02/22-rdf-syntax-ns#type',
             'http://example.com/Person',

@@ -113,12 +113,15 @@ class TestDocEmbeddingsNullProtection:

     @pytest.mark.asyncio
     async def test_valid_embedding_upserted(self):
+        import asyncio
         from trustgraph.storage.doc_embeddings.qdrant.write import Processor

         proc = Processor.__new__(Processor)
         proc.qdrant = MagicMock()
-        proc.qdrant.collection_exists.return_value = True
+        proc.collection_exists = MagicMock(return_value=True)
+        proc._cache_lock = asyncio.Lock()
+        proc._known_collections = set()

         msg = MagicMock()
         msg.metadata.collection = "col1"

@@ -134,12 +137,15 @@ class TestDocEmbeddingsNullProtection:
     @pytest.mark.asyncio
     async def test_dimension_in_collection_name(self):
         """Collection name should include vector dimension."""
+        import asyncio
         from trustgraph.storage.doc_embeddings.qdrant.write import Processor

         proc = Processor.__new__(Processor)
         proc.qdrant = MagicMock()
-        proc.qdrant.collection_exists.return_value = True
+        proc.collection_exists = MagicMock(return_value=True)
+        proc._cache_lock = asyncio.Lock()
+        proc._known_collections = set()

         msg = MagicMock()
         msg.metadata.collection = "docs"

@@ -220,12 +226,15 @@ class TestGraphEmbeddingsNullProtection:

     @pytest.mark.asyncio
     async def test_valid_entity_and_vector_upserted(self):
+        import asyncio
         from trustgraph.storage.graph_embeddings.qdrant.write import Processor

         proc = Processor.__new__(Processor)
         proc.qdrant = MagicMock()
-        proc.qdrant.collection_exists.return_value = True
+        proc.collection_exists = MagicMock(return_value=True)
+        proc._cache_lock = asyncio.Lock()
+        proc._known_collections = set()

         msg = MagicMock()
         msg.metadata.collection = "col1"

@@ -241,12 +250,15 @@ class TestGraphEmbeddingsNullProtection:

     @pytest.mark.asyncio
     async def test_lazy_collection_creation_on_new_dimension(self):
+        import asyncio
         from trustgraph.storage.graph_embeddings.qdrant.write import Processor

         proc = Processor.__new__(Processor)
         proc.qdrant = MagicMock()
-        proc.qdrant.collection_exists.return_value = False
+        proc.collection_exists = MagicMock(return_value=True)
+        proc._cache_lock = asyncio.Lock()
+        proc._known_collections = set()

         msg = MagicMock()
         msg.metadata.collection = "graphs"

@@ -337,6 +337,57 @@ class TestQuery:
         cache_key = "test_collection:unlabeled_entity"
         mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")

+    @pytest.mark.asyncio
+    async def test_triples_query_never_passes_workspace(self):
+        """Workspace isolation is handled by pub/sub topic routing, not
+        by passing workspace to TriplesClient.query(). Verify that
+        GraphRAG never passes workspace as a keyword argument."""
+        mock_rag = MagicMock()
+        mock_cache = MagicMock()
+        mock_cache.get.return_value = None
+        mock_rag.label_cache = mock_cache
+        mock_triples_client = AsyncMock()
+        mock_rag.triples_client = mock_triples_client
+
+        mock_triple = MagicMock()
+        mock_triple.o = "Label"
+        mock_triples_client.query.return_value = [mock_triple]
+
+        query = Query(
+            rag=mock_rag,
+            collection="test_collection",
+            verbose=False
+        )
+
+        await query.maybe_label("http://example.com/entity")
+
+        for c in mock_triples_client.query.call_args_list:
+            assert "workspace" not in c.kwargs
+
+    @pytest.mark.asyncio
+    async def test_follow_edges_never_passes_workspace(self):
+        """Verify follow_edges never passes workspace to query_stream."""
+        mock_rag = MagicMock()
+        mock_triples_client = AsyncMock()
+        mock_rag.triples_client = mock_triples_client
+
+        mock_triple = MagicMock()
+        mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
+        mock_triples_client.query_stream.return_value = [mock_triple]
+
+        query = Query(
+            rag=mock_rag,
+            collection="test_collection",
+            verbose=False,
+            triple_limit=10
+        )
+
+        subgraph = set()
+        await query.follow_edges("e1", subgraph, path_length=1)
+
+        for c in mock_triples_client.query_stream.call_args_list:
+            assert "workspace" not in c.kwargs
+
     @pytest.mark.asyncio
     async def test_follow_edges_basic_functionality(self):
         """Test Query.follow_edges method basic triple discovery"""

@@ -413,8 +413,8 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
         # Assert
         expected_collection = 'd_cache_user_cache_collection_3'  # 3 dimensions

-        # Verify collection existence is checked on each write
-        mock_qdrant_instance.collection_exists.assert_called_once_with(expected_collection)
+        # Second write uses cached collection state — no collection_exists check
+        mock_qdrant_instance.collection_exists.assert_not_called()

         # But upsert should still be called
         mock_qdrant_instance.upsert.assert_called_once()

@@ -125,13 +125,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

-        processor.ensure_collection("test_collection", 384)
+        await processor.ensure_collection("test_collection", 384)

         mock_qdrant_instance.collection_exists.assert_called_once_with("test_collection")
         mock_qdrant_instance.create_collection.assert_called_once()

         # Verify the collection is cached
-        assert "test_collection" in processor.created_collections
+        assert "test_collection" in processor._known_collections

     @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
     async def test_ensure_collection_skips_existing(self, mock_qdrant_client):

@@ -149,7 +149,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):

         processor = Processor(**config)

-        processor.ensure_collection("existing_collection", 384)
+        await processor.ensure_collection("existing_collection", 384)

         mock_qdrant_instance.collection_exists.assert_called_once()
         mock_qdrant_instance.create_collection.assert_not_called()

@@ -168,9 +168,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
         }

         processor = Processor(**config)
-        processor.created_collections.add("cached_collection")
+        processor._known_collections.add("cached_collection")

-        processor.ensure_collection("cached_collection", 384)
+        await processor.ensure_collection("cached_collection", 384)

         # Should not check or create - just return
         mock_qdrant_instance.collection_exists.assert_not_called()

@@ -391,7 +391,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
         }

         processor = Processor(**config)
-        processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
+        processor._known_collections.add('rows_test_workspace_test_collection_schema1_384')

         await processor.delete_collection('test_workspace', 'test_collection')

@@ -399,7 +399,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
         assert mock_qdrant_instance.delete_collection.call_count == 2

         # Verify the cached collection was removed
-        assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
+        assert 'rows_test_workspace_test_collection_schema1_384' not in processor._known_collections

     @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
     async def test_delete_collection_schema(self, mock_qdrant_client):

@@ -121,10 +121,13 @@ class TestRowsCassandraStorageLogic:
     @pytest.mark.asyncio
     async def test_schema_config_parsing(self):
         """Test parsing of schema configurations"""
+        import asyncio
         processor = MagicMock()
         processor.schemas = {}
         processor.config_key = "schema"
         processor.registered_partitions = set()
+        processor._setup_lock = asyncio.Lock()
+        processor._apply_schema_config = Processor._apply_schema_config.__get__(processor, Processor)
         processor.on_schema_config = Processor.on_schema_config.__get__(processor, Processor)

         # Create test configuration

@@ -2,6 +2,8 @@
 Tests for Cassandra triples storage service
 """

+import asyncio
+
 import pytest
 from unittest.mock import MagicMock, patch, AsyncMock

@@ -24,7 +26,8 @@ class TestCassandraStorageProcessor:
         assert processor.cassandra_host == ['cassandra']  # Updated default
         assert processor.cassandra_username is None
         assert processor.cassandra_password is None
-        assert processor.table is None
+        assert processor._connections == {}
+        assert isinstance(processor._conn_lock, asyncio.Lock)

     def test_processor_initialization_with_custom_params(self):
         """Test processor initialization with custom parameters (new cassandra_* names)"""

@@ -41,7 +44,8 @@ class TestCassandraStorageProcessor:
         assert processor.cassandra_host == ['cassandra.example.com']
         assert processor.cassandra_username == 'testuser'
         assert processor.cassandra_password == 'testpass'
-        assert processor.table is None
+        assert processor._connections == {}
+        assert isinstance(processor._conn_lock, asyncio.Lock)

     def test_processor_initialization_with_partial_auth(self):
         """Test processor initialization with only username (no password)"""

@@ -92,6 +96,7 @@ class TestCassandraStorageProcessor:
         """Test table switching logic when authentication is provided"""
         taskgroup_mock = MagicMock()
         mock_tg_instance = MagicMock()
+        mock_tg_instance.async_insert = AsyncMock()
         mock_kg_class.return_value = mock_tg_instance

         processor = Processor(

@@ -114,7 +119,6 @@ class TestCassandraStorageProcessor:
             username='testuser',
             password='testpass'
         )
-        assert processor.table == 'user1'

     @pytest.mark.asyncio
     @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')

@@ -122,6 +126,7 @@ class TestCassandraStorageProcessor:
         """Test table switching logic when no authentication is provided"""
         taskgroup_mock = MagicMock()
         mock_tg_instance = MagicMock()
+        mock_tg_instance.async_insert = AsyncMock()
         mock_kg_class.return_value = mock_tg_instance

         processor = Processor(taskgroup=taskgroup_mock)

@@ -138,7 +143,6 @@ class TestCassandraStorageProcessor:
             hosts=['cassandra'],  # Updated default
             keyspace='user2'
         )
-        assert processor.table == 'user2'

     @pytest.mark.asyncio
     @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')

@@ -146,6 +150,7 @@ class TestCassandraStorageProcessor:
         """Test that TrustGraph is not recreated when table hasn't changed"""
         taskgroup_mock = MagicMock()
         mock_tg_instance = MagicMock()
+        mock_tg_instance.async_insert = AsyncMock()
         mock_kg_class.return_value = mock_tg_instance

         processor = Processor(taskgroup=taskgroup_mock)

@@ -169,6 +174,7 @@ class TestCassandraStorageProcessor:
         """Test that triples are properly inserted into Cassandra"""
         taskgroup_mock = MagicMock()
         mock_tg_instance = MagicMock()
+        mock_tg_instance.async_insert = AsyncMock()
         mock_kg_class.return_value = mock_tg_instance

         processor = Processor(taskgroup=taskgroup_mock)

@@ -208,12 +214,12 @@ class TestCassandraStorageProcessor:
         await processor.store_triples('user1', mock_message)

         # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
-        assert mock_tg_instance.insert.call_count == 2
-        mock_tg_instance.insert.assert_any_call(
+        assert mock_tg_instance.async_insert.call_count == 2
+        mock_tg_instance.async_insert.assert_any_call(
             'collection1', 'subject1', 'predicate1', 'object1',
             g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
         )
-        mock_tg_instance.insert.assert_any_call(
+        mock_tg_instance.async_insert.assert_any_call(
             'collection1', 'subject2', 'predicate2', 'object2',
             g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
         )

@@ -224,6 +230,7 @@ class TestCassandraStorageProcessor:
         """Test behavior when message has no triples"""
         taskgroup_mock = MagicMock()
         mock_tg_instance = MagicMock()
+        mock_tg_instance.async_insert = AsyncMock()
         mock_kg_class.return_value = mock_tg_instance

         processor = Processor(taskgroup=taskgroup_mock)

@@ -236,19 +243,17 @@ class TestCassandraStorageProcessor:
         await processor.store_triples('user1', mock_message)

         # Verify no triples were inserted
-        mock_tg_instance.insert.assert_not_called()
+        mock_tg_instance.async_insert.assert_not_called()

     @pytest.mark.asyncio
     @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
-    @patch('trustgraph.storage.triples.cassandra.write.time.sleep')
-    async def test_exception_handling_with_retry(self, mock_sleep, mock_kg_class):
+    async def test_exception_handling_on_connection_failure(self, mock_kg_class):
         """Test exception handling during TrustGraph creation"""
         taskgroup_mock = MagicMock()
         mock_kg_class.side_effect = Exception("Connection failed")

         processor = Processor(taskgroup=taskgroup_mock)

         # Create mock message
         mock_message = MagicMock()
         mock_message.metadata.collection = 'collection1'
         mock_message.triples = []

@@ -256,9 +261,6 @@ class TestCassandraStorageProcessor:
         with pytest.raises(Exception, match="Connection failed"):
             await processor.store_triples('user1', mock_message)

-        # Verify sleep was called before re-raising
-        mock_sleep.assert_called_once_with(1)
-
     def test_add_args_method(self):
         """Test that add_args properly configures argument parser"""
         from argparse import ArgumentParser

@@ -359,8 +361,6 @@ class TestCassandraStorageProcessor:
         mock_message1.triples = []

         await processor.store_triples('user1', mock_message1)
-        assert processor.table == 'user1'
-        assert processor.tg == mock_tg_instance1

         # Second message with different table
         mock_message2 = MagicMock()

@@ -368,11 +368,11 @@ class TestCassandraStorageProcessor:
         mock_message2.triples = []

         await processor.store_triples('user2', mock_message2)
-        assert processor.table == 'user2'
-        assert processor.tg == mock_tg_instance2

-        # Verify TrustGraph was created twice for different tables
+        # Verify TrustGraph was created twice for different workspaces
         assert mock_kg_class.call_count == 2
+        mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user1')
+        mock_kg_class.assert_any_call(hosts=['cassandra'], keyspace='user2')

     @pytest.mark.asyncio
     @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')

@@ -380,6 +380,7 @@ class TestCassandraStorageProcessor:
         """Test storing triples with special characters and unicode"""
         taskgroup_mock = MagicMock()
         mock_tg_instance = MagicMock()
+        mock_tg_instance.async_insert = AsyncMock()
         mock_kg_class.return_value = mock_tg_instance

         processor = Processor(taskgroup=taskgroup_mock)

@@ -405,7 +406,7 @@ class TestCassandraStorageProcessor:
         await processor.store_triples('test_workspace', mock_message)

         # Verify the triple was inserted with special characters preserved
-        mock_tg_instance.insert.assert_called_once_with(
+        mock_tg_instance.async_insert.assert_called_once_with(
             'test_collection',
             'subject with spaces & symbols',
             'predicate:with/colons',

@@ -418,29 +419,29 @@ class TestCassandraStorageProcessor:

     @pytest.mark.asyncio
     @patch('trustgraph.storage.triples.cassandra.write.EntityCentricKnowledgeGraph')
-    async def test_store_triples_preserves_old_table_on_exception(self, mock_kg_class):
-        """Test that table remains unchanged when TrustGraph creation fails"""
+    async def test_connection_failure_does_not_cache_stale_state(self, mock_kg_class):
+        """Test that a failed connection doesn't leave stale cached state"""
         taskgroup_mock = MagicMock()
+        mock_good_instance = MagicMock()

         processor = Processor(taskgroup=taskgroup_mock)

-        # Set an initial table
-        processor.table = ('old_user', 'old_collection')
-
-        # Mock TrustGraph to raise exception
-        mock_kg_class.side_effect = Exception("Connection failed")
-
         mock_message = MagicMock()
-        mock_message.metadata.collection = 'new_collection'
+        mock_message.metadata.collection = 'collection1'
         mock_message.triples = []

+        # First call fails
+        mock_kg_class.side_effect = Exception("Connection failed")
         with pytest.raises(Exception, match="Connection failed"):
-            await processor.store_triples('new_user', mock_message)
+            await processor.store_triples('user1', mock_message)

-        # Table should remain unchanged since self.table = table happens after try/except
-        assert processor.table == ('old_user', 'old_collection')
-        # TrustGraph should be set to None though
-        assert processor.tg is None
+        # Second call succeeds — should retry connection, not use stale state
+        mock_kg_class.side_effect = None
+        mock_kg_class.return_value = mock_good_instance
+        await processor.store_triples('user1', mock_message)
+
+        # Connection was attempted twice (failed + succeeded)
+        assert mock_kg_class.call_count == 2


 class TestCassandraPerformanceOptimizations:

@@ -452,6 +453,7 @@ class TestCassandraPerformanceOptimizations:
         """Test that legacy mode still works with single table"""
         taskgroup_mock = MagicMock()
         mock_tg_instance = MagicMock()
+        mock_tg_instance.async_insert = AsyncMock()
         mock_kg_class.return_value = mock_tg_instance

         with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'true'}):

@@ -472,6 +474,7 @@ class TestCassandraPerformanceOptimizations:
         """Test that optimized mode uses multi-table schema"""
         taskgroup_mock = MagicMock()
         mock_tg_instance = MagicMock()
+        mock_tg_instance.async_insert = AsyncMock()
         mock_kg_class.return_value = mock_tg_instance

         with patch.dict('os.environ', {'CASSANDRA_USE_LEGACY': 'false'}):

@@ -492,6 +495,7 @@ class TestCassandraPerformanceOptimizations:
         """Test that all tables stay consistent during batch writes"""
         taskgroup_mock = MagicMock()
         mock_tg_instance = MagicMock()
+        mock_tg_instance.async_insert = AsyncMock()
         mock_kg_class.return_value = mock_tg_instance

         processor = Processor(taskgroup=taskgroup_mock)

@@ -517,7 +521,7 @@ class TestCassandraPerformanceOptimizations:
         await processor.store_triples('user1', mock_message)

         # Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
-        mock_tg_instance.insert.assert_called_once_with(
+        mock_tg_instance.async_insert.assert_called_once_with(
             'collection1', 'test_subject', 'test_predicate', 'test_object',
             g=DEFAULT_GRAPH, otype='l', dtype='', lang=''
         )

@@ -89,7 +89,8 @@ class TestSanitizeName:

class TestFindCollection:

-    def test_finds_matching_collection(self):
+    @pytest.mark.asyncio
+    async def test_finds_matching_collection(self):
        proc = _make_processor()
        mock_coll = MagicMock()
        mock_coll.name = "rows_test_workspace_test_col_customers_384"

@@ -98,11 +99,12 @@ class TestFindCollection:
        mock_collections.collections = [mock_coll]
        proc.qdrant.get_collections.return_value = mock_collections

-        result = proc.find_collection("test-workspace", "test-col", "customers")
+        result = await proc.find_collection("test-workspace", "test-col", "customers")

        assert result == "rows_test_workspace_test_col_customers_384"

-    def test_returns_none_when_no_match(self):
+    @pytest.mark.asyncio
+    async def test_returns_none_when_no_match(self):
        proc = _make_processor()
        mock_coll = MagicMock()
        mock_coll.name = "rows_other_workspace_other_col_schema_768"

@@ -111,14 +113,15 @@ class TestFindCollection:
        mock_collections.collections = [mock_coll]
        proc.qdrant.get_collections.return_value = mock_collections

-        result = proc.find_collection("test-workspace", "test-col", "customers")
+        result = await proc.find_collection("test-workspace", "test-col", "customers")
        assert result is None

-    def test_returns_none_on_error(self):
+    @pytest.mark.asyncio
+    async def test_returns_none_on_error(self):
        proc = _make_processor()
        proc.qdrant.get_collections.side_effect = Exception("connection error")

-        result = proc.find_collection("workspace", "col", "schema")
+        result = await proc.find_collection("workspace", "col", "schema")
        assert result is None


@@ -139,7 +142,7 @@ class TestQueryRowEmbeddings:
    @pytest.mark.asyncio
    async def test_no_collection_returns_empty(self):
        proc = _make_processor()
-        proc.find_collection = MagicMock(return_value=None)
+        proc.find_collection = AsyncMock(return_value=None)
        request = _make_request()

        result = await proc.query_row_embeddings("test-workspace", request)

@@ -148,7 +151,7 @@ class TestQueryRowEmbeddings:
    @pytest.mark.asyncio
    async def test_successful_query_returns_matches(self):
        proc = _make_processor()
-        proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
+        proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")

        points = [
            _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),

@@ -172,7 +175,7 @@ class TestQueryRowEmbeddings:
    async def test_index_name_filter_applied(self):
        """When index_name is specified, a Qdrant filter should be used."""
        proc = _make_processor()
-        proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
+        proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")

        mock_result = MagicMock()
        mock_result.points = []

@@ -188,7 +191,7 @@ class TestQueryRowEmbeddings:
    async def test_no_index_name_no_filter(self):
        """When index_name is empty, no filter should be applied."""
        proc = _make_processor()
-        proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
+        proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")

        mock_result = MagicMock()
        mock_result.points = []

@@ -204,7 +207,7 @@ class TestQueryRowEmbeddings:
    async def test_missing_payload_fields_default(self):
        """Points with missing payload fields should use defaults."""
        proc = _make_processor()
-        proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
+        proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")

        point = MagicMock()
        point.payload = {}  # Empty payload

@@ -225,7 +228,7 @@ class TestQueryRowEmbeddings:
    @pytest.mark.asyncio
    async def test_qdrant_error_propagates(self):
        proc = _make_processor()
-        proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
+        proc.find_collection = AsyncMock(return_value="rows_w_c_s_384")
        proc.qdrant.query_points.side_effect = Exception("qdrant down")

        request = _make_request()
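The MagicMock-to-AsyncMock switch reflects find_collection becoming a coroutine, with the blocking qdrant-client call pushed onto a worker thread. A rough sketch of that shape (the rows_<workspace>_<collection>_<schema>_<dim> naming comes from the fixtures above; sanitize is a stand-in for whatever TestSanitizeName covers):

    import asyncio

    async def find_collection(self, workspace, collection, schema):
        prefix = f"rows_{sanitize(workspace)}_{sanitize(collection)}_{sanitize(schema)}_"
        try:
            # qdrant-client is synchronous; run the network call off the loop
            collections = await asyncio.to_thread(self.qdrant.get_collections)
        except Exception:
            return None   # matches test_returns_none_on_error
        for coll in collections.collections:
            if coll.name.startswith(prefix):   # trailing part is the dimension
                return coll.name
        return None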
@@ -132,3 +132,34 @@ class Knowledge:

        self.request(request = input)

    def list_de_cores(self):

        input = {
            "operation": "list-de-cores",
            "workspace": self.api.workspace,
        }

        return self.request(request = input)["ids"]

    def delete_de_core(self, id):

        input = {
            "operation": "delete-de-core",
            "workspace": self.api.workspace,
            "id": id,
        }

        self.request(request = input)

    def load_de_core(self, id, flow="default", collection="default"):

        input = {
            "operation": "load-de-core",
            "workspace": self.api.workspace,
            "id": id,
            "flow": flow,
            "collection": collection,
        }

        self.request(request = input)
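A hypothetical session using these new client methods (the knowledge-accessor spelling, api.knowledge(), is assumed here, not taken from the diff):

    from trustgraph.api import Api

    api = Api(url="http://localhost:8088/", workspace="default")
    kn = api.knowledge()          # accessor name assumed

    # Enumerate stored document embeddings cores
    for core_id in kn.list_de_cores():
        print(core_id)

    # Replay a core into a flow's document-embeddings store, then tidy up
    kn.load_de_core("my-doc-embeddings", flow="default", collection="default")
    kn.delete_de_core("stale-core")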
@@ -491,6 +491,58 @@ class SocketClient:
            triples=raw_triples,
        )

    def get_kg_core(self, id: str) -> Iterator[Dict[str, Any]]:
        request = {
            "operation": "get-kg-core",
            "workspace": self.workspace,
            "id": id,
        }
        for response in self._send_request_sync(
            "knowledge", None, request, streaming_raw=True,
        ):
            if response.get("eos"):
                break
            yield response

    def put_kg_core(
        self, id: str, triples=None, graph_embeddings=None,
    ) -> Dict[str, Any]:
        request = {
            "operation": "put-kg-core",
            "workspace": self.workspace,
            "id": id,
        }
        if triples is not None:
            request["triples"] = triples
        if graph_embeddings is not None:
            request["graph-embeddings"] = graph_embeddings
        return self._send_request_sync("knowledge", None, request)

    def get_de_core(self, id: str) -> Iterator[Dict[str, Any]]:
        request = {
            "operation": "get-de-core",
            "workspace": self.workspace,
            "id": id,
        }
        for response in self._send_request_sync(
            "knowledge", None, request, streaming_raw=True,
        ):
            if response.get("eos"):
                break
            yield response

    def put_de_core(
        self, id: str, document_embeddings=None,
    ) -> Dict[str, Any]:
        request = {
            "operation": "put-de-core",
            "workspace": self.workspace,
            "id": id,
        }
        if document_embeddings is not None:
            request["document-embeddings"] = document_embeddings
        return self._send_request_sync("knowledge", None, request)

    def close(self) -> None:
        """Close the persistent WebSocket connection."""
        if self._loop and not self._loop.is_closed():
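get_kg_core and get_de_core are generators over a streaming_raw exchange: each yielded dict is one wire message, and the eos frame ends iteration without being yielded. Consuming the stream looks like this (mirrors the CLI tools later in this changeset):

    socket = api.socket()
    try:
        chunks = 0
        for response in socket.get_de_core("my-core"):
            # each message carries one document-embeddings payload
            if "document-embeddings" in response:
                chunks += len(response["document-embeddings"]["chunks"])
        print(f"{chunks} chunk vectors received")
    finally:
        socket.close()   # always release the persistent websocket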
@@ -1,6 +1,7 @@
from typing import Dict, Any, Tuple, Optional
from ...schema import (
    KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
+    DocumentEmbeddings, ChunkEmbeddings,
    Metadata, EntityEmbeddings
)
from .base import MessageTranslator

@@ -43,6 +44,23 @@ class KnowledgeRequestTranslator(MessageTranslator):
            ]
        )

        document_embeddings = None
        if "document-embeddings" in data:
            document_embeddings = DocumentEmbeddings(
                metadata=Metadata(
                    id=data["document-embeddings"]["metadata"]["id"],
                    root=data["document-embeddings"]["metadata"].get("root", ""),
                    collection=data["document-embeddings"]["metadata"]["collection"]
                ),
                chunks=[
                    ChunkEmbeddings(
                        chunk_id=ch["chunk_id"],
                        vector=ch["vector"],
                    )
                    for ch in data["document-embeddings"]["chunks"]
                ]
            )

        return KnowledgeRequest(
            operation=data.get("operation"),
            id=data.get("id"),

@@ -50,6 +68,7 @@ class KnowledgeRequestTranslator(MessageTranslator):
            collection=data.get("collection"),
            triples=triples,
            graph_embeddings=graph_embeddings,
+            document_embeddings=document_embeddings,
        )

    def encode(self, obj: KnowledgeRequest) -> Dict[str, Any]:

@@ -90,6 +109,22 @@ class KnowledgeRequestTranslator(MessageTranslator):
            ],
        }

        if obj.document_embeddings:
            result["document-embeddings"] = {
                "metadata": {
                    "id": obj.document_embeddings.metadata.id,
                    "root": obj.document_embeddings.metadata.root,
                    "collection": obj.document_embeddings.metadata.collection,
                },
                "chunks": [
                    {
                        "chunk_id": ch.chunk_id,
                        "vector": ch.vector,
                    }
                    for ch in obj.document_embeddings.chunks
                ],
            }

        return result


@@ -140,6 +175,25 @@ class KnowledgeResponseTranslator(MessageTranslator):
            }
        }

        # Streaming document embeddings response
        if obj.document_embeddings:
            return {
                "document-embeddings": {
                    "metadata": {
                        "id": obj.document_embeddings.metadata.id,
                        "root": obj.document_embeddings.metadata.root,
                        "collection": obj.document_embeddings.metadata.collection,
                    },
                    "chunks": [
                        {
                            "chunk_id": ch.chunk_id,
                            "vector": ch.vector,
                        }
                        for ch in obj.document_embeddings.chunks
                    ],
                }
            }

        # End of stream marker
        if obj.eos is True:
            return {"eos": True}

@@ -155,7 +209,7 @@ class KnowledgeResponseTranslator(MessageTranslator):
        is_final = (
            obj.ids is not None or  # List response
            obj.eos is True or  # End of stream
-            (not obj.triples and not obj.graph_embeddings)  # Empty response
+            (not obj.triples and not obj.graph_embeddings and not obj.document_embeddings)  # Empty response
        )

        return response, is_final
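With the wire field names settled on "root" and "vector", a decode/encode round trip through the translator should be lossless. An illustrative payload (values invented; this assumes the decode entry point wraps the construction shown above):

    wire = {
        "operation": "put-de-core",
        "id": "core-1",
        "document-embeddings": {
            "metadata": {"id": "doc-1", "root": "doc-1", "collection": "default"},
            "chunks": [
                {"chunk_id": "chunk-0", "vector": [0.1, 0.2, 0.3]},
            ],
        },
    }

    t = KnowledgeRequestTranslator()
    req = t.decode(wire)   # -> KnowledgeRequest dataclass
    assert req.document_embeddings.chunks[0].vector == [0.1, 0.2, 0.3]
    assert t.encode(req)["document-embeddings"] == wire["document-embeddings"]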
@@ -4,7 +4,7 @@ from ..core.topic import queue
from ..core.metadata import Metadata
from .document import Document, TextDocument
from .graph import Triples
-from .embeddings import GraphEmbeddings
+from .embeddings import GraphEmbeddings, DocumentEmbeddings

# get-kg-core
# -> (???)

@@ -41,6 +41,9 @@ class KnowledgeRequest:
    triples: Triples | None = None
    graph_embeddings: GraphEmbeddings | None = None

+    # put-de-core
+    document_embeddings: DocumentEmbeddings | None = None

@dataclass
class KnowledgeResponse:
    error: Error | None = None

@@ -48,6 +51,7 @@ class KnowledgeResponse:
    eos: bool = False  # Indicates end of knowledge core stream
    triples: Triples | None = None
    graph_embeddings: GraphEmbeddings | None = None
+    document_embeddings: DocumentEmbeddings | None = None

knowledge_request_queue = queue('knowledge', cls='request')
knowledge_response_queue = queue('knowledge', cls='response')
@@ -37,6 +37,7 @@ tg-dump-msgpack = "trustgraph.cli.dump_msgpack:main"
tg-dump-queues = "trustgraph.cli.dump_queues:main"
tg-monitor-prompts = "trustgraph.cli.monitor_prompts:main"
tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main"
+tg-get-de-core = "trustgraph.cli.get_de_core:main"
tg-get-kg-core = "trustgraph.cli.get_kg_core:main"
tg-get-document-content = "trustgraph.cli.get_document_content:main"
tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main"

@@ -77,6 +78,7 @@ tg-load-turtle = "trustgraph.cli.load_turtle:main"
tg-load-knowledge = "trustgraph.cli.load_knowledge:main"
tg-load-structured-data = "trustgraph.cli.load_structured_data:main"
tg-put-flow-blueprint = "trustgraph.cli.put_flow_blueprint:main"
+tg-put-de-core = "trustgraph.cli.put_de_core:main"
tg-put-kg-core = "trustgraph.cli.put_kg_core:main"
tg-remove-library-document = "trustgraph.cli.remove_library_document:main"
tg-save-doc-embeds = "trustgraph.cli.save_doc_embeds:main"
trustgraph-cli/trustgraph/cli/get_de_core.py (new file, 111 lines)

@@ -0,0 +1,111 @@
"""
Uses the knowledge service to fetch a document embeddings core which is
saved to a local file in msgpack format.
"""

import argparse
import os
import msgpack

from trustgraph.api import Api

default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")

def write_de(f, data):
    msg = (
        "de",
        {
            "m": {
                "i": data["metadata"]["id"],
                "m": data["metadata"]["root"],
                "c": data["metadata"]["collection"],
            },
            "c": [
                {
                    "i": ch["chunk_id"],
                    "v": ch["vector"],
                }
                for ch in data["chunks"]
            ]
        }
    )
    f.write(msgpack.packb(msg, use_bin_type=True))

def fetch(url, workspace, id, output, token=None):

    api = Api(url=url, token=token, workspace=workspace)
    socket = api.socket()

    try:
        de = 0

        with open(output, "wb") as f:

            for response in socket.get_de_core(id):

                if "document-embeddings" in response:
                    de += 1
                    write_de(f, response["document-embeddings"])

        print(f"Got: {de} document embeddings messages.")

    finally:
        socket.close()

def main():

    parser = argparse.ArgumentParser(
        prog='tg-get-de-core',
        description=__doc__,
    )

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'API URL (default: {default_url})',
    )

    parser.add_argument(
        '-w', '--workspace',
        default=default_workspace,
        help=f'Workspace (default: {default_workspace})',
    )

    parser.add_argument(
        '--id', '--identifier',
        required=True,
        help='Document embeddings core ID',
    )

    parser.add_argument(
        '-o', '--output',
        required=True,
        help='Output file'
    )

    parser.add_argument(
        '-t', '--token',
        default=default_token,
        help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
    )

    args = parser.parse_args()

    try:

        fetch(
            url=args.url,
            workspace=args.workspace,
            id=args.id,
            output=args.output,
            token=args.token,
        )

    except Exception as e:

        print("Exception:", e, flush=True)

if __name__ == "__main__":
    main()
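The file tg-get-de-core writes is a stream of msgpack tuples, ("de", {...}), with single-letter keys to keep cores compact: m/i/c for metadata/id/collection and i/v for chunk id/vector, exactly as write_de emits them. A sketch of reading one back:

    import msgpack

    with open("core.msgpack", "rb") as f:
        unpacker = msgpack.Unpacker(f, raw=False)
        for kind, body in unpacker:       # each item is a ("de", {...}) pair
            assert kind == "de"
            meta = body["m"]              # {"i": id, "m": root, "c": collection}
            for chunk in body["c"]:
                print(meta["i"], chunk["i"], len(chunk["v"]))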
@@ -5,13 +5,11 @@ to a local file in msgpack format.

import argparse
import os
-import uuid
-import asyncio
-import json
-from websockets.asyncio.client import connect
import msgpack

-default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
+from trustgraph.api import Api
+
+default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")

@@ -21,7 +19,7 @@ def write_triple(f, data):
        {
            "m": {
                "i": data["metadata"]["id"],
-                "m": data["metadata"]["metadata"],
+                "m": data["metadata"]["root"],
                "c": data["metadata"]["collection"],
            },
            "t": data["triples"],

@@ -35,13 +33,13 @@ def write_ge(f, data):
        {
            "m": {
                "i": data["metadata"]["id"],
-                "m": data["metadata"]["metadata"],
+                "m": data["metadata"]["root"],
                "c": data["metadata"]["collection"],
            },
            "e": [
                {
                    "e": ent["entity"],
-                    "v": ent["vectors"],
+                    "v": ent["vector"],
                }
                for ent in data["entities"]
            ]

@@ -49,54 +47,18 @@ def write_ge(f, data):
    )
    f.write(msgpack.packb(msg, use_bin_type=True))

-async def fetch(url, workspace, id, output, token=None):
+def fetch(url, workspace, id, output, token=None):

-    if not url.endswith("/"):
-        url += "/"
-
-    url = url + "api/v1/socket"
-
-    if token:
-        url = f"{url}?token={token}"
-
-    mid = str(uuid.uuid4())
-
-    async with connect(url) as ws:
-
-        req = json.dumps({
-            "id": mid,
-            "workspace": workspace,
-            "service": "knowledge",
-            "request": {
-                "operation": "get-kg-core",
-                "workspace": workspace,
-                "id": id,
-            }
-        })
-
-        await ws.send(req)
+    api = Api(url=url, token=token, workspace=workspace)
+    socket = api.socket()

+    try:
        ge = 0
        t = 0

        with open(output, "wb") as f:

-            while True:
-
-                msg = await ws.recv()
-
-                obj = json.loads(msg)
-
-                if "response" not in obj:
-                    raise RuntimeError("No response?")
-
-                response = obj["response"]
-
-                if "error" in response:
-                    raise RuntimeError(obj["error"])
-
-                if "eos" in response:
-                    if response["eos"]: break
+            for response in socket.get_kg_core(id):

                if "triples" in response:
                    t += 1

@@ -108,7 +70,8 @@ def fetch(url, workspace, id, output, token=None):

        print(f"Got: {t} triple, {ge} GE messages.")

-    await ws.close()
+    finally:
+        socket.close()

def main():

@@ -151,7 +114,6 @@ def main():

    try:

-        asyncio.run(
        fetch(
            url=args.url,
            workspace=args.workspace,

@@ -159,7 +121,6 @@ def main():
            output=args.output,
            token=args.token,
        )
-        )

    except Exception as e:
@@ -3,11 +3,8 @@ Uses the GraphRAG service to answer a question
"""

import argparse
-import json
import os
import sys
-import websockets
-import asyncio
from trustgraph.api import (
    Api,
    ExplainabilityClient,

@@ -31,607 +28,6 @@ default_max_path_length = 2
default_edge_score_limit = 30
default_edge_limit = 25

-# Provenance predicates
-TG = "https://trustgraph.ai/ns/"
-TG_QUERY = TG + "query"
-TG_CONCEPT = TG + "concept"
-TG_ENTITY = TG + "entity"
-TG_EDGE_COUNT = TG + "edgeCount"
-TG_SELECTED_EDGE = TG + "selectedEdge"
-TG_EDGE = TG + "edge"
-TG_REASONING = TG + "reasoning"
-TG_DOCUMENT = TG + "document"
-TG_CONTAINS = TG + "contains"
-PROV = "http://www.w3.org/ns/prov#"
-PROV_STARTED_AT_TIME = PROV + "startedAtTime"
-PROV_WAS_DERIVED_FROM = PROV + "wasDerivedFrom"
-RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
-
-
-def _get_event_type(prov_id):
-    """Extract event type from provenance_id"""
-    if "question" in prov_id:
-        return "question"
-    elif "grounding" in prov_id:
-        return "grounding"
-    elif "exploration" in prov_id:
-        return "exploration"
-    elif "focus" in prov_id:
-        return "focus"
-    elif "synthesis" in prov_id:
-        return "synthesis"
-    return "provenance"
-
-
-def _format_provenance_details(event_type, triples):
-    """Format provenance details based on event type and triples"""
-    lines = []
-
-    if event_type == "question":
-        # Show query and timestamp
-        for s, p, o in triples:
-            if p == TG_QUERY:
-                lines.append(f" Query: {o}")
-            elif p == PROV_STARTED_AT_TIME:
-                lines.append(f" Time: {o}")
-
-    elif event_type == "grounding":
-        # Show extracted concepts
-        concepts = [o for s, p, o in triples if p == TG_CONCEPT]
-        if concepts:
-            lines.append(f" Concepts: {len(concepts)}")
-            for concept in concepts:
-                lines.append(f" - {concept}")
-
-    elif event_type == "exploration":
-        # Show edge count (seed entities resolved separately with labels)
-        for s, p, o in triples:
-            if p == TG_EDGE_COUNT:
-                lines.append(f" Edges explored: {o}")
-
-    elif event_type == "focus":
-        # For focus, just count edge selection URIs
-        # The actual edge details are fetched separately via edge_selections parameter
-        edge_sel_uris = []
-        for s, p, o in triples:
-            if p == TG_SELECTED_EDGE:
-                edge_sel_uris.append(o)
-        if edge_sel_uris:
-            lines.append(f" Focused on {len(edge_sel_uris)} edge(s)")
-
-    elif event_type == "synthesis":
-        # Show document reference (content already streamed)
-        for s, p, o in triples:
-            if p == TG_DOCUMENT:
-                lines.append(f" Document: {o}")
-
-    return lines
-
-
-async def _query_triples_once(ws_url, flow_id, prov_id, collection, graph=None, debug=False):
-    """Query triples for a provenance node (single attempt)"""
-    request = {
-        "id": "triples-request",
-        "service": "triples",
-        "flow": flow_id,
-        "request": {
-            "s": {"t": "i", "i": prov_id},
-            "collection": collection,
-            "limit": 100
-        }
-    }
-    # Add graph filter if specified (for named graph queries)
-    if graph is not None:
-        request["request"]["g"] = graph
-
-    if debug:
-        print(f" [debug] querying triples for s={prov_id}", file=sys.stderr)
-
-    triples = []
-    try:
-        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
-            await websocket.send(json.dumps(request))
-
-            async for raw_message in websocket:
-                response = json.loads(raw_message)
-
-                if debug:
-                    print(f" [debug] response: {json.dumps(response)[:200]}", file=sys.stderr)
-
-                if response.get("id") != "triples-request":
-                    continue
-
-                if "error" in response:
-                    if debug:
-                        print(f" [debug] error: {response['error']}", file=sys.stderr)
-                    break
-
-                if "response" in response:
-                    resp = response["response"]
-                    # Handle triples response
-                    # Response format: {"response": [triples...]}
-                    # Each triple uses compact keys: "i" for iri, "v" for value, "t" for type
-                    triple_list = resp.get("response", [])
-                    for t in triple_list:
-                        s = t.get("s", {}).get("i", t.get("s", {}).get("v", ""))
-                        p = t.get("p", {}).get("i", t.get("p", {}).get("v", ""))
-                        # Handle quoted triples (type "t") and regular values
-                        o_term = t.get("o", {})
-                        if o_term.get("t") == "t":
-                            # Quoted triple - extract s, p, o from nested structure
-                            tr = o_term.get("tr", {})
-                            o = {
-                                "s": tr.get("s", {}).get("i", ""),
-                                "p": tr.get("p", {}).get("i", ""),
-                                "o": tr.get("o", {}).get("i", tr.get("o", {}).get("v", "")),
-                            }
-                        else:
-                            o = o_term.get("i", o_term.get("v", ""))
-                        triples.append((s, p, o))
-
-                    if resp.get("complete") or response.get("complete"):
-                        break
-    except Exception as e:
-        if debug:
-            print(f" [debug] exception: {e}", file=sys.stderr)
-
-    if debug:
-        print(f" [debug] got {len(triples)} triples", file=sys.stderr)
-
-    return triples
-
-
-async def _query_triples(ws_url, flow_id, prov_id, collection, graph=None, max_retries=5, retry_delay=0.2, debug=False):
-    """Query triples for a provenance node with retries for race condition"""
-    for attempt in range(max_retries):
-        triples = await _query_triples_once(ws_url, flow_id, prov_id, collection, graph=graph, debug=debug)
-        if triples:
-            return triples
-        # Wait before retry if empty (triples may not be stored yet)
-        if attempt < max_retries - 1:
-            if debug:
-                print(f" [debug] retry {attempt + 1}/{max_retries}...", file=sys.stderr)
-            await asyncio.sleep(retry_delay)
-    return []
-
-
-async def _query_edge_provenance(ws_url, flow_id, edge_s, edge_p, edge_o, collection, debug=False):
-    """
-    Query for provenance of an edge (s, p, o) in the knowledge graph.
-
-    Finds subgraphs that contain the edge via tg:contains, then follows
-    prov:wasDerivedFrom to find source documents.
-
-    Returns list of source URIs (chunks, pages, documents).
-    """
-    # Query for subgraphs that contain this edge: ?subgraph tg:contains <<s p o>>
-    request = {
-        "id": "edge-prov-request",
-        "service": "triples",
-        "flow": flow_id,
-        "request": {
-            "p": {"t": "i", "i": TG_CONTAINS},
-            "o": {
-                "t": "t",  # Quoted triple type
-                "tr": {
-                    "s": {"t": "i", "i": edge_s},
-                    "p": {"t": "i", "i": edge_p},
-                    "o": {"t": "i", "i": edge_o} if edge_o.startswith("http") or edge_o.startswith("urn:") else {"t": "l", "v": edge_o},
-                }
-            },
-            "collection": collection,
-            "limit": 10
-        }
-    }
-
-    if debug:
-        print(f" [debug] querying edge provenance for ({edge_s}, {edge_p}, {edge_o})", file=sys.stderr)
-
-    stmt_uris = []
-    try:
-        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
-            await websocket.send(json.dumps(request))
-
-            async for raw_message in websocket:
-                response = json.loads(raw_message)
-
-                if response.get("id") != "edge-prov-request":
-                    continue
-
-                if "error" in response:
-                    if debug:
-                        print(f" [debug] error: {response['error']}", file=sys.stderr)
-                    break
-
-                if "response" in response:
-                    resp = response["response"]
-                    triple_list = resp.get("response", [])
-                    for t in triple_list:
-                        s = t.get("s", {}).get("i", "")
-                        if s:
-                            stmt_uris.append(s)
-
-                    if resp.get("complete") or response.get("complete"):
-                        break
-    except Exception as e:
-        if debug:
-            print(f" [debug] exception querying edge provenance: {e}", file=sys.stderr)
-
-    if debug:
-        print(f" [debug] found {len(stmt_uris)} reifying statements", file=sys.stderr)
-
-    # For each statement, query wasDerivedFrom to find sources
-    sources = []
-    for stmt_uri in stmt_uris:
-        # Query: stmt_uri prov:wasDerivedFrom ?source
-        request = {
-            "id": "derived-from-request",
-            "service": "triples",
-            "flow": flow_id,
-            "request": {
-                "s": {"t": "i", "i": stmt_uri},
-                "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
-                "collection": collection,
-                "limit": 10
-            }
-        }
-
-        try:
-            async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
-                await websocket.send(json.dumps(request))
-
-                async for raw_message in websocket:
-                    response = json.loads(raw_message)
-
-                    if response.get("id") != "derived-from-request":
-                        continue
-
-                    if "error" in response:
-                        break
-
-                    if "response" in response:
-                        resp = response["response"]
-                        triple_list = resp.get("response", [])
-                        for t in triple_list:
-                            o = t.get("o", {}).get("i", "")
-                            if o:
-                                sources.append(o)
-
-                        if resp.get("complete") or response.get("complete"):
-                            break
-        except Exception as e:
-            if debug:
-                print(f" [debug] exception querying wasDerivedFrom: {e}", file=sys.stderr)
-
-    if debug:
-        print(f" [debug] found {len(sources)} source(s): {sources}", file=sys.stderr)
-
-    return sources
-
-
-async def _query_derived_from(ws_url, flow_id, uri, collection, debug=False):
-    """Query for the prov:wasDerivedFrom parent of a URI. Returns None if no parent."""
-    request = {
-        "id": "parent-request",
-        "service": "triples",
-        "flow": flow_id,
-        "request": {
-            "s": {"t": "i", "i": uri},
-            "p": {"t": "i", "i": PROV_WAS_DERIVED_FROM},
-            "collection": collection,
-            "limit": 1
-        }
-    }
-
-    try:
-        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
-            await websocket.send(json.dumps(request))
-
-            async for raw_message in websocket:
-                response = json.loads(raw_message)
-
-                if response.get("id") != "parent-request":
-                    continue
-
-                if "error" in response:
-                    break
-
-                if "response" in response:
-                    resp = response["response"]
-                    triple_list = resp.get("response", [])
-                    if triple_list:
-                        return triple_list[0].get("o", {}).get("i", None)
-
-                    if resp.get("complete") or response.get("complete"):
-                        break
-    except Exception as e:
-        if debug:
-            print(f" [debug] exception querying parent: {e}", file=sys.stderr)
-
-    return None
-
-
-async def _trace_provenance_chain(ws_url, flow_id, source_uri, collection, label_cache, debug=False):
-    """
-    Trace the full provenance chain from a source URI up to the root document.
-    Returns a list of (uri, label) tuples from leaf to root.
-    """
-    chain = []
-    current = source_uri
-    max_depth = 10  # Prevent infinite loops
-
-    for _ in range(max_depth):
-        if not current:
-            break
-
-        # Get label for current entity
-        label = await _query_label(ws_url, flow_id, current, collection, label_cache, debug)
-        chain.append((current, label))
-
-        # Get parent
-        parent = await _query_derived_from(ws_url, flow_id, current, collection, debug)
-        if not parent or parent == current:
-            break
-        current = parent
-
-    return chain
-
-
-def _format_provenance_chain(chain):
-    """
-    Format a provenance chain as a human-readable string.
-    Chain is [(uri, label), ...] from leaf to root.
-    """
-    if not chain:
-        return ""
-
-    # Show labels, from leaf to root
-    labels = [label for uri, label in chain]
-    return " → ".join(labels)
-
-
-def _is_iri(value):
-    """Check if a value looks like an IRI."""
-    if not isinstance(value, str):
-        return False
-    return value.startswith("http://") or value.startswith("https://") or value.startswith("urn:")
-
-
-async def _query_label(ws_url, flow_id, iri, collection, label_cache, debug=False):
-    """
-    Query for the rdfs:label of an IRI.
-    Uses label_cache to avoid repeated queries.
-    Returns the label if found, otherwise returns the IRI.
-    """
-    if not _is_iri(iri):
-        return iri
-
-    # Check cache first
-    if iri in label_cache:
-        return label_cache[iri]
-
-    request = {
-        "id": "label-request",
-        "service": "triples",
-        "flow": flow_id,
-        "request": {
-            "s": {"t": "i", "i": iri},
-            "p": {"t": "i", "i": RDFS_LABEL},
-            "collection": collection,
-            "limit": 1
-        }
-    }
-
-    label = iri  # Default to IRI if no label found
-    try:
-        async with websockets.connect(ws_url, ping_interval=20, ping_timeout=30) as websocket:
-            await websocket.send(json.dumps(request))
-
-            async for raw_message in websocket:
-                response = json.loads(raw_message)
-
-                if response.get("id") != "label-request":
-                    continue
-
-                if "error" in response:
-                    break
-
-                if "response" in response:
-                    resp = response["response"]
-                    triple_list = resp.get("response", [])
-                    if triple_list:
-                        # Get the label value
-                        o = triple_list[0].get("o", {})
-                        label = o.get("v", o.get("i", iri))
-
-                    if resp.get("complete") or response.get("complete"):
-                        break
-    except Exception as e:
-        if debug:
-            print(f" [debug] exception querying label for {iri}: {e}", file=sys.stderr)
-
-    # Cache the result
-    label_cache[iri] = label
-    return label
-
-
-async def _resolve_edge_labels(ws_url, flow_id, edge_triple, collection, label_cache, debug=False):
-    """
-    Resolve labels for all IRI components of an edge triple.
-    Returns (s_label, p_label, o_label).
-    """
-    s = edge_triple.get("s", "?")
-    p = edge_triple.get("p", "?")
-    o = edge_triple.get("o", "?")
-
-    s_label = await _query_label(ws_url, flow_id, s, collection, label_cache, debug)
-    p_label = await _query_label(ws_url, flow_id, p, collection, label_cache, debug)
-    o_label = await _query_label(ws_url, flow_id, o, collection, label_cache, debug)
-
-    return s_label, p_label, o_label
-
-
-async def _question_explainable(
-    url, flow_id, question, collection, entity_limit, triple_limit,
-    max_subgraph_size, max_path_length, token=None, debug=False
-):
-    """Execute graph RAG with explainability - shows provenance events with details"""
-    # Convert HTTP URL to WebSocket URL
-    if url.startswith("http://"):
-        ws_url = url.replace("http://", "ws://", 1)
-    elif url.startswith("https://"):
-        ws_url = url.replace("https://", "wss://", 1)
-    else:
-        ws_url = f"ws://{url}"
-
-    ws_url = f"{ws_url.rstrip('/')}/api/v1/socket"
-    if token:
-        ws_url = f"{ws_url}?token={token}"
-
-    # Cache for label lookups to avoid repeated queries
-    label_cache = {}
-
-    request = {
-        "id": "cli-request",
-        "service": "graph-rag",
-        "flow": flow_id,
-        "request": {
-            "query": question,
-            "collection": collection,
-            "entity-limit": entity_limit,
-            "triple-limit": triple_limit,
-            "max-subgraph-size": max_subgraph_size,
-            "max-path-length": max_path_length,
-            "streaming": True
-        }
-    }
-
-    async with websockets.connect(ws_url, ping_interval=20, ping_timeout=300) as websocket:
-        await websocket.send(json.dumps(request))
-
-        async for raw_message in websocket:
-            response = json.loads(raw_message)
-
-            if response.get("id") != "cli-request":
-                continue
-
-            if "error" in response:
-                print(f"\nError: {response['error']}", file=sys.stderr)
-                break
-
-            if "response" in response:
-                resp = response["response"]
-
-                # Check for errors in response
-                if "error" in resp and resp["error"]:
-                    err = resp["error"]
-                    print(f"\nError: {err.get('message', 'Unknown error')}", file=sys.stderr)
-                    break
-
-                message_type = resp.get("message_type", "")
-
-                if debug:
-                    print(f" [debug] message_type={message_type}, keys={list(resp.keys())}", file=sys.stderr)
-
-                if message_type == "explain":
-                    # Display explain event with details
-                    explain_id = resp.get("explain_id", "")
-                    explain_graph = resp.get("explain_graph")  # Named graph (e.g., urn:graph:retrieval)
-                    if explain_id:
-                        event_type = _get_event_type(explain_id)
-                        print(f"\n [{event_type}] {explain_id}", file=sys.stderr)
-
-                        # Query triples for this explain node (using named graph filter)
-                        triples = await _query_triples(
-                            ws_url, flow_id, explain_id, collection, graph=explain_graph, debug=debug
-                        )
-
-                        # Format and display details
-                        details = _format_provenance_details(event_type, triples)
-                        for line in details:
-                            print(line, file=sys.stderr)
-
-                        # For exploration events, resolve entity labels
-                        if event_type == "exploration":
-                            entity_iris = [o for s, p, o in triples if p == TG_ENTITY]
-                            if entity_iris:
-                                print(f" Seed entities: {len(entity_iris)}", file=sys.stderr)
-                                for iri in entity_iris:
-                                    label = await _query_label(
-                                        ws_url, flow_id, iri, collection,
-                                        label_cache, debug=debug
-                                    )
-                                    print(f" - {label}", file=sys.stderr)
-
-                        # For focus events, query each edge selection for details
-                        if event_type == "focus":
-                            for s, p, o in triples:
-                                if debug:
-                                    print(f" [debug] triple: p={p}, o={o}, o_type={type(o).__name__}", file=sys.stderr)
-                                if p == TG_SELECTED_EDGE and isinstance(o, str):
-                                    if debug:
-                                        print(f" [debug] querying edge selection: {o}", file=sys.stderr)
-                                    # Query the edge selection entity (using named graph filter)
-                                    edge_triples = await _query_triples(
-                                        ws_url, flow_id, o, collection, graph=explain_graph, debug=debug
-                                    )
-                                    if debug:
-                                        print(f" [debug] got {len(edge_triples)} edge triples", file=sys.stderr)
-                                    # Extract edge and reasoning
-                                    edge_triple = None  # Store the actual triple for provenance lookup
-                                    reasoning = None
-                                    for es, ep, eo in edge_triples:
-                                        if debug:
-                                            print(f" [debug] edge triple: ep={ep}, eo={eo}", file=sys.stderr)
-                                        if ep == TG_EDGE and isinstance(eo, dict):
-                                            # eo is a quoted triple dict
-                                            edge_triple = eo
-                                        elif ep == TG_REASONING:
-                                            reasoning = eo
-                                    if edge_triple:
-                                        # Resolve labels for edge components
-                                        s_label, p_label, o_label = await _resolve_edge_labels(
-                                            ws_url, flow_id, edge_triple, collection,
-                                            label_cache, debug=debug
-                                        )
-                                        print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
-                                        if reasoning:
-                                            r_short = reasoning[:100] + "..." if len(reasoning) > 100 else reasoning
-                                            print(f" Reason: {r_short}", file=sys.stderr)
-
-                                    # Trace edge provenance in the workspace collection (not explainability)
-                                    if edge_triple:
-                                        sources = await _query_edge_provenance(
-                                            ws_url, flow_id,
-                                            edge_triple.get("s", ""),
-                                            edge_triple.get("p", ""),
-                                            edge_triple.get("o", ""),
-                                            collection,  # Use the query collection, not explainability
-                                            debug=debug
-                                        )
-                                        if sources:
-                                            for src in sources:
-                                                # Trace full chain from source to root document
-                                                chain = await _trace_provenance_chain(
-                                                    ws_url, flow_id, src, collection,
-                                                    label_cache, debug=debug
-                                                )
-                                                chain_str = _format_provenance_chain(chain)
-                                                print(f" Source: {chain_str}", file=sys.stderr)
-
-                elif message_type == "chunk" or not message_type:
-                    # Display response chunk
-                    chunk = resp.get("response", "")
-                    if chunk:
-                        print(chunk, end="", flush=True)
-
-                # Check if session is complete
-                if resp.get("end_of_session"):
-                    break
-
-    print()  # Final newline
-
-
def _question_explainable_api(
    url, flow_id, question_text, collection, entity_limit, triple_limit,
    max_subgraph_size, max_path_length, edge_score_limit=30,
trustgraph-cli/trustgraph/cli/put_de_core.py (new file, 119 lines)

@@ -0,0 +1,119 @@
"""
Puts a document embeddings core into the knowledge manager via the API
socket.
"""

import argparse
import os
import msgpack

from trustgraph.api import Api

default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")

def read_message(unpacked, id):

    if unpacked[0] == "de":
        msg = unpacked[1]
        return {
            "metadata": {
                "id": id,
                "root": msg["m"]["m"],
                "collection": "default",
            },
            "chunks": [
                {
                    "chunk_id": ch["i"],
                    "vector": ch["v"],
                }
                for ch in msg["c"]
            ],
        }
    else:
        raise RuntimeError("Unexpected message type", unpacked[0])

def put(url, workspace, id, input, token=None):

    api = Api(url=url, token=token, workspace=workspace)
    socket = api.socket()

    try:
        de = 0

        with open(input, "rb") as f:

            unpacker = msgpack.Unpacker(f, raw=False)

            while True:

                try:
                    unpacked = unpacker.unpack()
                except msgpack.OutOfData:
                    break

                msg = read_message(unpacked, id)
                de += 1
                socket.put_de_core(id, document_embeddings=msg)

        print(f"Put: {de} document embeddings messages.")

    finally:
        socket.close()

def main():

    parser = argparse.ArgumentParser(
        prog='tg-put-de-core',
        description=__doc__,
    )

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'API URL (default: {default_url})',
    )

    parser.add_argument(
        '-w', '--workspace',
        default=default_workspace,
        help=f'Workspace (default: {default_workspace})',
    )

    parser.add_argument(
        '--id', '--identifier',
        required=True,
        help='Document embeddings core ID',
    )

    parser.add_argument(
        '-i', '--input',
        required=True,
        help='Input file'
    )

    parser.add_argument(
        '-t', '--token',
        default=default_token,
        help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
    )

    args = parser.parse_args()

    try:

        put(
            url=args.url,
            workspace=args.workspace,
            id=args.id,
            input=args.input,
            token=args.token,
        )

    except Exception as e:

        print("Exception:", e, flush=True)

if __name__ == "__main__":
    main()
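read_message defines the on-disk contract, so producing a core tg-put-de-core will accept only requires emitting the same tuples. A sketch of writing one from scratch (ids and vectors invented):

    import msgpack

    chunks = [("chunk-0", [0.1, 0.2]), ("chunk-1", [0.3, 0.4])]

    with open("core.msgpack", "wb") as f:
        msg = (
            "de",
            {
                # m = root document id, i = id, c = collection
                "m": {"i": "doc-1", "m": "doc-1", "c": "default"},
                "c": [{"i": cid, "v": vec} for cid, vec in chunks],
            },
        )
        f.write(msgpack.packb(msg, use_bin_type=True))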
@@ -4,13 +4,11 @@ Puts a knowledge core into the knowledge manager via the API socket.

import argparse
import os
-import uuid
-import asyncio
-import json
-from websockets.asyncio.client import connect
import msgpack

-default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
+from trustgraph.api import Api
+
+default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")

@@ -21,13 +19,13 @@ def read_message(unpacked, id):
        return "ge", {
            "metadata": {
                "id": id,
-                "metadata": msg["m"]["m"],
-                "collection": "default", # Not used?
+                "root": msg["m"]["m"],
+                "collection": "default",
            },
            "entities": [
                {
                    "entity": ent["e"],
-                    "vectors": ent["v"],
+                    "vector": ent["v"],
                }
                for ent in msg["e"]
            ],

@@ -37,26 +35,20 @@ def read_message(unpacked, id):
        return "t", {
            "metadata": {
                "id": id,
-                "metadata": msg["m"]["m"],
-                "collection": "default", # Not used by receiver?
+                "root": msg["m"]["m"],
+                "collection": "default",
            },
            "triples": msg["t"],
        }
    else:
        raise RuntimeError("Unpacked unexpected messsage type", unpacked[0])

-async def put(url, workspace, id, input, token=None):
+def put(url, workspace, id, input, token=None):

-    if not url.endswith("/"):
-        url += "/"
-
-    url = url + "api/v1/socket"
-
-    if token:
-        url = f"{url}?token={token}"
-
-    async with connect(url) as ws:
+    api = Api(url=url, token=token, workspace=workspace)
+    socket = api.socket()

+    try:
        ge = 0
        t = 0

@@ -68,69 +60,26 @@ def put(url, workspace, id, input, token=None):

            try:
                unpacked = unpacker.unpack()
-            except:
+            except msgpack.OutOfData:
                break

            kind, msg = read_message(unpacked, id)

-            mid = str(uuid.uuid4())
-
            if kind == "ge":

                ge += 1

-                req = json.dumps({
-                    "id": mid,
-                    "workspace": workspace,
-                    "service": "knowledge",
-                    "request": {
-                        "operation": "put-kg-core",
-                        "workspace": workspace,
-                        "id": id,
-                        "graph-embeddings": msg
-                    }
-                })
+                socket.put_kg_core(id, graph_embeddings=msg)

            elif kind == "t":

                t += 1

-                req = json.dumps({
-                    "id": mid,
-                    "workspace": workspace,
-                    "service": "knowledge",
-                    "request": {
-                        "operation": "put-kg-core",
-                        "workspace": workspace,
-                        "id": id,
-                        "triples": msg
-                    }
-                })
+                socket.put_kg_core(id, triples=msg)

            else:

                raise RuntimeError("Unexpected message kind", kind)

-            await ws.send(req)
-
-            # Retry loop, wait for right response to come back
-            while True:
-
-                msg = await ws.recv()
-                msg = json.loads(msg)
-
-                if msg["id"] != mid:
-                    continue
-
-                if "response" in msg:
-                    if "error" in msg["response"]:
-                        raise RuntimeError(msg["response"]["error"])
-
-                    break
-
        print(f"Put: {t} triple, {ge} GE messages.")

-        await ws.close()
+    finally:
+        socket.close()

def main():

@@ -173,7 +122,6 @@ def main():

    try:

-        asyncio.run(
        put(
            url=args.url,
            workspace=args.workspace,

@@ -181,7 +129,6 @@ def main():
            input=args.input,
            token=args.token,
        )
-        )

    except Exception as e:
@@ -1,5 +1,6 @@

from .. schema import KnowledgeResponse, Error, Triples, GraphEmbeddings
+from .. schema import DocumentEmbeddings
from .. knowledge import hash
from .. exceptions import RequestError
from .. tables.knowledge import KnowledgeTableStore

@@ -157,6 +158,98 @@ class KnowledgeManager:
            )
        )

    async def list_de_cores(self, request, respond, workspace):

        ids = await self.table_store.list_de_cores(workspace)

        await respond(
            KnowledgeResponse(
                error = None,
                ids = ids,
                eos = False,
                triples = None,
                graph_embeddings = None,
            )
        )

    async def get_de_core(self, request, respond, workspace):

        logger.info("Getting document embeddings core...")

        async def publish_de(de):
            await respond(
                KnowledgeResponse(
                    error = None,
                    ids = None,
                    eos = False,
                    triples = None,
                    graph_embeddings = None,
                    document_embeddings = de,
                )
            )

        await self.table_store.get_document_embeddings(
            workspace,
            request.id,
            publish_de,
        )

        logger.debug("Document embeddings core retrieval complete")

        await respond(
            KnowledgeResponse(
                error = None,
                ids = None,
                eos = True,
                triples = None,
                graph_embeddings = None,
            )
        )

    async def put_de_core(self, request, respond, workspace):

        if request.document_embeddings:
            await self.table_store.add_document_embeddings(
                workspace, request.document_embeddings
            )

        await respond(
            KnowledgeResponse(
                error = None,
                ids = None,
                eos = False,
                triples = None,
                graph_embeddings = None,
            )
        )

    async def delete_de_core(self, request, respond, workspace):

        logger.info("Deleting document embeddings core...")

        await self.table_store.delete_document_embeddings(
            workspace, request.id
        )

        await respond(
            KnowledgeResponse(
                error = None,
                ids = None,
                eos = False,
                triples = None,
                graph_embeddings = None,
            )
        )

    async def load_de_core(self, request, respond, workspace):

        if self.background_task is None:
            self.background_task = asyncio.create_task(
                self.core_loader()
            )

        await self.loader_queue.put((request, respond, workspace))

    async def core_loader(self):

        logger.info("Knowledge background processor running...")

@@ -165,7 +258,7 @@ class KnowledgeManager:
            logger.debug("Waiting for next load...")
            request, respond, workspace = await self.loader_queue.get()

-            logger.info(f"Loading knowledge: {request.id}")
+            logger.info(f"Loading: {request.operation} {request.id}")

            try:

@@ -187,24 +280,13 @@ class KnowledgeManager:
                if "interfaces" not in flow:
                    raise RuntimeError("No defined interfaces")

-                if "triples-store" not in flow["interfaces"]:
-                    raise RuntimeError("Flow has no triples-store")
-
-                if "graph-embeddings-store" not in flow["interfaces"]:
-                    raise RuntimeError("Flow has no graph-embeddings-store")
-
-                t_q = flow["interfaces"]["triples-store"]["flow"]
-                ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"]
-
-                # Got this far, it should all work
-                await respond(
-                    KnowledgeResponse(
-                        error = None,
-                        ids = None,
-                        eos = False,
-                        triples = None,
-                        graph_embeddings = None
-                    )
-                )
+                if request.operation == "load-de-core":
+                    await self._load_de_core(
+                        request, respond, workspace, flow,
+                    )
+                else:
+                    await self._load_kg_core(
+                        request, respond, workspace, flow,
+                    )

            except Exception as e:

@@ -223,14 +305,36 @@ class KnowledgeManager:
                        )
                    )

            logger.debug("Knowledge processing done")

-            logger.debug("Starting knowledge loading process...")
+            continue

-            try:
+    async def _load_kg_core(self, request, respond, workspace, flow):
+
+        if "triples-store" not in flow["interfaces"]:
+            raise RuntimeError("Flow has no triples-store")
+
+        if "graph-embeddings-store" not in flow["interfaces"]:
+            raise RuntimeError("Flow has no graph-embeddings-store")
+
+        t_q = flow["interfaces"]["triples-store"]["flow"]
+        ge_q = flow["interfaces"]["graph-embeddings-store"]["flow"]
+
+        await respond(
+            KnowledgeResponse(
+                error = None,
+                ids = None,
+                eos = False,
+                triples = None,
+                graph_embeddings = None
+            )
+        )
+
+        t_pub = None
+        ge_pub = None
+
+        try:
+
+            logger.debug(f"Triples queue: {t_q}")
+            logger.debug(f"Graph embeddings queue: {ge_q}")

@@ -249,7 +353,6 @@ class KnowledgeManager:
            await ge_pub.start()

            async def publish_triples(t):
-                # Override collection with request collection
                if hasattr(t, 'metadata') and hasattr(t.metadata, 'collection'):
                    t.metadata.collection = request.collection or "default"
                await t_pub.send(None, t)

@@ -263,7 +366,6 @@ class KnowledgeManager:
            )

            async def publish_ge(g):
-                # Override collection with request collection
                if hasattr(g, 'metadata') and hasattr(g.metadata, 'collection'):
                    g.metadata.collection = request.collection or "default"
                await ge_pub.send(None, g)

@@ -276,7 +378,7 @@ class KnowledgeManager:
                publish_ge,
            )

-            logger.debug("Knowledge loading completed")
+            logger.debug("Knowledge core loading completed")

        except Exception as e:

@@ -289,6 +391,59 @@ class KnowledgeManager:
            if t_pub: await t_pub.stop()
            if ge_pub: await ge_pub.stop()

-            logger.debug("Knowledge processing done")
-
-            continue
+    async def _load_de_core(self, request, respond, workspace, flow):
+
+        if "document-embeddings-store" not in flow["interfaces"]:
+            raise RuntimeError("Flow has no document-embeddings-store")
+
+        de_q = flow["interfaces"]["document-embeddings-store"]["flow"]
+
+        await respond(
+            KnowledgeResponse(
+                error = None,
+                ids = None,
+                eos = False,
+                triples = None,
+                graph_embeddings = None
+            )
+        )
+
+        de_pub = None
+
+        try:
+
+            logger.debug(f"Document embeddings queue: {de_q}")
+
+            de_pub = Publisher(
+                self.flow_config.pubsub, de_q,
+                schema=DocumentEmbeddings,
+            )
+
+            logger.debug("Starting publisher...")
+
+            await de_pub.start()
+
+            async def publish_de(de):
+                if hasattr(de, 'metadata') and hasattr(de.metadata, 'collection'):
+                    de.metadata.collection = request.collection or "default"
+                await de_pub.send(None, de)
+
+            logger.debug("Publishing document embeddings...")
+
+            await self.table_store.get_document_embeddings(
+                workspace,
+                request.id,
+                publish_de,
+            )
+
+            logger.debug("Document embeddings core loading completed")
+
+        except Exception as e:
+
+            logger.error(f"Knowledge exception: {e}", exc_info=True)
+
+        finally:
+
+            logger.debug("Stopping publisher...")
+
+            if de_pub: await de_pub.stop()
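load-de-core and load-kg-core requests share one background loader: they are queued and a single task drains them, so loads are serialised and the request handler returns promptly. The pattern in isolation (a sketch, not the service code):

    import asyncio

    class Loader:
        def __init__(self, handle):
            self.handle = handle          # coroutine: the actual core loader
            self.queue = asyncio.Queue()
            self.task = None

        async def submit(self, request):
            if self.task is None:
                # start the single background drainer on first use
                self.task = asyncio.create_task(self.run())
            await self.queue.put(request)

        async def run(self):
            while True:
                request = await self.queue.get()
                try:
                    await self.handle(request)
                except Exception as e:
                    # a failed load must not kill the drainer
                    print("load failed:", e)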
@@ -187,6 +187,11 @@ class Processor(WorkspaceProcessor):
            "put-kg-core": self.knowledge.put_kg_core,
            "load-kg-core": self.knowledge.load_kg_core,
            "unload-kg-core": self.knowledge.unload_kg_core,
+            "list-de-cores": self.knowledge.list_de_cores,
+            "get-de-core": self.knowledge.get_de_core,
+            "delete-de-core": self.knowledge.delete_de_core,
+            "put-de-core": self.knowledge.put_de_core,
+            "load-de-core": self.knowledge.load_de_core,
        }

        if v.operation not in impls:
@ -1,10 +1,14 @@
|
|||
|
||||
import datetime
|
||||
import os
|
||||
import logging
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
from cassandra.query import BatchStatement, SimpleStatement
|
||||
from ssl import SSLContext, PROTOCOL_TLSv1_2
|
||||
import os
|
||||
import logging
|
||||
|
||||
from ..tables.cassandra_async import async_execute
|
||||
|
||||
# Global list to track clusters for cleanup
|
||||
_active_clusters = []
|
||||
|
|
@ -461,7 +465,6 @@ class KnowledgeGraph:
|
|||
def create_collection(self, collection):
|
||||
"""Create collection by inserting metadata row"""
|
||||
try:
|
||||
import datetime
|
||||
self.session.execute(
|
||||
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
|
||||
(collection, datetime.datetime.now())
|
||||
|
|
@ -954,7 +957,6 @@ class EntityCentricKnowledgeGraph:
|
|||
def create_collection(self, collection):
|
||||
"""Create collection by inserting metadata row"""
|
||||
try:
|
||||
import datetime
|
||||
self.session.execute(
|
||||
f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
|
||||
(collection, datetime.datetime.now())
|
||||
|
|
@ -1045,6 +1047,222 @@ class EntityCentricKnowledgeGraph:
        logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")

    # ========================================================================
    # Async methods — use cassandra driver's native async API via async_execute
    # ========================================================================

    async def async_insert(self, collection, s, p, o, g=None, otype=None, dtype="", lang=""):
        if g is None:
            g = DEFAULT_GRAPH
        if otype is None:
            if o.startswith("http://") or o.startswith("https://"):
                otype = "u"
            else:
                otype = "l"

        batch = BatchStatement()
        batch.add(self.insert_entity_stmt, (collection, s, 'S', p, otype, s, o, g, dtype, lang))
        batch.add(self.insert_entity_stmt, (collection, p, 'P', p, otype, s, o, g, dtype, lang))
        if otype == 'u' or otype == 't':
            batch.add(self.insert_entity_stmt, (collection, o, 'O', p, otype, s, o, g, dtype, lang))
        if g != DEFAULT_GRAPH:
            batch.add(self.insert_entity_stmt, (collection, g, 'G', p, otype, s, o, g, dtype, lang))
        batch.add(self.insert_collection_stmt, (collection, g, s, p, o, otype, dtype, lang))

        await async_execute(self.session, batch)

    async def async_get_all(self, collection, limit=50):
        return await async_execute(
            self.session, self.get_collection_all_stmt, (collection, limit)
        )

    async def async_get_s(self, collection, s, g=None, limit=10):
        rows = await async_execute(
            self.session, self.get_entity_as_s_stmt, (collection, s, limit)
        )
        results = []
        for row in rows:
            d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
            if g is not None and d != g:
                continue
            results.append(QuadResult(
                s=row.s, p=row.p, o=row.o, g=d,
                otype=row.otype, dtype=row.dtype, lang=row.lang
            ))
        return results

    async def async_get_p(self, collection, p, g=None, limit=10):
        rows = await async_execute(
            self.session, self.get_entity_as_p_stmt, (collection, p, limit)
        )
        results = []
        for row in rows:
            d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
            if g is not None and d != g:
                continue
            results.append(QuadResult(
                s=row.s, p=row.p, o=row.o, g=d,
                otype=row.otype, dtype=row.dtype, lang=row.lang
            ))
        return results

    async def async_get_o(self, collection, o, g=None, limit=10):
        rows = await async_execute(
            self.session, self.get_entity_as_o_stmt, (collection, o, limit)
        )
        results = []
        for row in rows:
            d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
            if g is not None and d != g:
                continue
            results.append(QuadResult(
                s=row.s, p=row.p, o=row.o, g=d,
                otype=row.otype, dtype=row.dtype, lang=row.lang
            ))
        return results

    async def async_get_sp(self, collection, s, p, g=None, limit=10):
        rows = await async_execute(
            self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit)
        )
        results = []
        for row in rows:
            d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
            if g is not None and d != g:
                continue
            results.append(QuadResult(
                s=s, p=p, o=row.o, g=d,
                otype=row.otype, dtype=row.dtype, lang=row.lang
            ))
        return results

    async def async_get_po(self, collection, p, o, g=None, limit=10):
        rows = await async_execute(
            self.session, self.get_entity_as_o_p_stmt, (collection, o, p, limit)
        )
        results = []
        for row in rows:
            d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
            if g is not None and d != g:
                continue
            results.append(QuadResult(
                s=row.s, p=p, o=o, g=d,
                otype=row.otype, dtype=row.dtype, lang=row.lang
            ))
        return results

    async def async_get_os(self, collection, o, s, g=None, limit=10):
        rows = await async_execute(
            self.session, self.get_entity_as_s_stmt, (collection, s, limit)
        )
        results = []
        for row in rows:
            if row.o != o:
                continue
            d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
            if g is not None and d != g:
                continue
            results.append(QuadResult(
                s=s, p=row.p, o=o, g=d,
                otype=row.otype, dtype=row.dtype, lang=row.lang
            ))
        return results

    async def async_get_spo(self, collection, s, p, o, g=None, limit=10):
        rows = await async_execute(
            self.session, self.get_entity_as_s_p_stmt, (collection, s, p, limit)
        )
        results = []
        for row in rows:
            if row.o != o:
                continue
            d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
            if g is not None and d != g:
                continue
            results.append(QuadResult(
                s=s, p=p, o=o, g=d,
                otype=row.otype, dtype=row.dtype, lang=row.lang
            ))
        return results

    async def async_get_g(self, collection, g, limit=50):
        if g is None:
            g = DEFAULT_GRAPH
        return await async_execute(
            self.session, self.get_collection_by_graph_stmt, (collection, g, limit)
        )

    async def async_collection_exists(self, collection):
        try:
            result = await async_execute(
                self.session,
                f"SELECT collection FROM {self.collection_metadata_table} WHERE collection = %s LIMIT 1",
                (collection,)
            )
            return bool(result)
        except Exception as e:
            logger.error(f"Error checking collection existence: {e}")
            return False

    async def async_create_collection(self, collection):
        await async_execute(
            self.session,
            f"INSERT INTO {self.collection_metadata_table} (collection, created_at) VALUES (%s, %s)",
            (collection, datetime.datetime.now())
        )
        logger.info(f"Created collection metadata for {collection}")

    async def async_delete_collection(self, collection):
        rows = await async_execute(
            self.session,
            f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s",
            (collection,)
        )

        entities = set()
        quads = []
        for row in rows:
            d, s, p, o = row.d, row.s, row.p, row.o
            otype = row.otype
            dtype = row.dtype if hasattr(row, 'dtype') else ''
            lang = row.lang if hasattr(row, 'lang') else ''
            quads.append((d, s, p, o, otype, dtype, lang))
            entities.add(s)
            entities.add(p)
            if otype == 'u' or otype == 't':
                entities.add(o)
            if d != DEFAULT_GRAPH:
                entities.add(d)

        batch = BatchStatement()
        count = 0
        for entity in entities:
            batch.add(self.delete_entity_partition_stmt, (collection, entity))
            count += 1
            if count % 50 == 0:
                await async_execute(self.session, batch)
                batch = BatchStatement()
        if count % 50 != 0:
            await async_execute(self.session, batch)

        batch = BatchStatement()
        count = 0
        for d, s, p, o, otype, dtype, lang in quads:
            batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang))
            count += 1
            if count % 50 == 0:
                await async_execute(self.session, batch)
                batch = BatchStatement()
        if count % 50 != 0:
            await async_execute(self.session, batch)

        await async_execute(
            self.session,
            f"DELETE FROM {self.collection_metadata_table} WHERE collection = %s",
            (collection,)
        )
        logger.info(f"Deleted collection {collection}: {len(entities)} entity partitions, {len(quads)} quads")

    def close(self):
        """Close connections"""
        if hasattr(self, 'session') and self.session:
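Taken together, these methods give callers a fully awaitable path into the entity-centric store. A minimal usage sketch follows; the import path and constructor arguments are assumptions inferred from the diffs in this release, not confirmed public API:

    # Hedged usage sketch for the new async methods. The module path and
    # constructor signature are assumed from the surrounding diffs.
    import asyncio

    from trustgraph.direct.cassandra_kg import EntityCentricKnowledgeGraph

    async def main():
        # Construction still performs synchronous schema setup, so push it
        # to a worker thread, exactly as the rewritten services do.
        kg = await asyncio.to_thread(
            EntityCentricKnowledgeGraph, hosts=["localhost"], keyspace="ws1",
        )
        await kg.async_insert(
            "default", "http://example.org/s", "http://example.org/p", "an object",
        )
        # Reads return QuadResult rows without blocking the event loop.
        for quad in await kg.async_get_s("default", "http://example.org/s"):
            print(quad.s, quad.p, quad.o, quad.g)

    asyncio.run(main())
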
@ -457,6 +457,12 @@ for _op in ("put-kg-core", "delete-kg-core",
              "load-kg-core", "unload-kg-core"):
    _register_kind_op("knowledge", _op, "knowledge:write")

# knowledge: document-embeddings core service.
for _op in ("get-de-core", "list-de-cores"):
    _register_kind_op("knowledge", _op, "knowledge:read")
for _op in ("put-de-core", "delete-de-core", "load-de-core"):
    _register_kind_op("knowledge", _op, "knowledge:write")


# collection-management: workspace collection lifecycle.
_register_kind_op("collection-management", "list-collections", "collections:read")
@ -4,11 +4,10 @@ Document embeddings query service. Input is vector, output is an array
of chunk_ids
"""

import asyncio
import logging

from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams

from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error

@ -38,32 +37,6 @@ class Processor(DocumentEmbeddingsQueryService):
        )

        self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
        self.last_collection = None

    def ensure_collection_exists(self, collection, dim):
        """Ensure collection exists, create if it doesn't"""
        if collection != self.last_collection:
            if not self.qdrant.collection_exists(collection):
                try:
                    self.qdrant.create_collection(
                        collection_name=collection,
                        vectors_config=VectorParams(
                            size=dim, distance=Distance.COSINE
                        ),
                    )
                    logger.info(f"Created collection: {collection}")
                except Exception as e:
                    logger.error(f"Qdrant collection creation failed: {e}")
                    raise e
            self.last_collection = collection

    def collection_exists(self, collection):
        """Check if collection exists (no implicit creation)"""
        return self.qdrant.collection_exists(collection)

    def collection_exists(self, collection):
        """Check if collection exists (no implicit creation)"""
        return self.qdrant.collection_exists(collection)

    async def query_document_embeddings(self, workspace, msg):

@ -73,21 +46,24 @@ class Processor(DocumentEmbeddingsQueryService):
        if not vec:
            return []

        # Use dimension suffix in collection name
        dim = len(vec)
        collection = f"d_{workspace}_{msg.collection}_{dim}"

        # Check if collection exists - return empty if not
        if not self.collection_exists(collection):
        exists = await asyncio.to_thread(
            self.qdrant.collection_exists, collection
        )
        if not exists:
            logger.info(f"Collection {collection} does not exist, returning empty results")
            return []

        search_result = self.qdrant.query_points(
        result = await asyncio.to_thread(
            self.qdrant.query_points,
            collection_name=collection,
            query=vec,
            limit=msg.limit,
            with_payload=True,
        ).points
        )
        search_result = result.points

        chunks = []
        for r in search_result:
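The rewrite leaves the qdrant-client calls themselves synchronous and pushes each one onto a worker thread with asyncio.to_thread, so the event loop is never blocked on Qdrant network I/O. The general shape as a self-contained sketch; the URL, workspace, and collection values are placeholders:

    # Sketch of the asyncio.to_thread wrapping pattern used above, with
    # the dimension-suffixed collection naming. Values are placeholders.
    import asyncio

    from qdrant_client import QdrantClient

    async def query_chunks(vec, workspace="ws1", coll="default", limit=10):
        qdrant = QdrantClient(url="http://localhost:6333")

        # The collection name carries the vector dimension, so embeddings
        # from differently sized models land in different collections.
        collection = f"d_{workspace}_{coll}_{len(vec)}"

        exists = await asyncio.to_thread(qdrant.collection_exists, collection)
        if not exists:
            return []

        # query_points is a blocking HTTP call; run it in a worker thread.
        result = await asyncio.to_thread(
            qdrant.query_points,
            collection_name=collection,
            query=vec,
            limit=limit,
            with_payload=True,
        )
        return [(p.payload or {}).get("chunk_id") for p in result.points]
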
@ -4,11 +4,10 @@ Graph embeddings query service. Input is vector, output is list of
entities
"""

import asyncio
import logging

from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams

from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL

@ -38,32 +37,6 @@ class Processor(GraphEmbeddingsQueryService):
        )

        self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
        self.last_collection = None

    def ensure_collection_exists(self, collection, dim):
        """Ensure collection exists, create if it doesn't"""
        if collection != self.last_collection:
            if not self.qdrant.collection_exists(collection):
                try:
                    self.qdrant.create_collection(
                        collection_name=collection,
                        vectors_config=VectorParams(
                            size=dim, distance=Distance.COSINE
                        ),
                    )
                    logger.info(f"Created collection: {collection}")
                except Exception as e:
                    logger.error(f"Qdrant collection creation failed: {e}")
                    raise e
            self.last_collection = collection

    def collection_exists(self, collection):
        """Check if collection exists (no implicit creation)"""
        return self.qdrant.collection_exists(collection)

    def collection_exists(self, collection):
        """Check if collection exists (no implicit creation)"""
        return self.qdrant.collection_exists(collection)

    def create_value(self, ent):
        if ent.startswith("http://") or ent.startswith("https://"):

@ -79,23 +52,26 @@ class Processor(GraphEmbeddingsQueryService):
        if not vec:
            return []

        # Use dimension suffix in collection name
        dim = len(vec)
        collection = f"t_{workspace}_{msg.collection}_{dim}"

        # Check if collection exists - return empty if not
        if not self.collection_exists(collection):
        exists = await asyncio.to_thread(
            self.qdrant.collection_exists, collection
        )
        if not exists:
            logger.info(f"Collection {collection} does not exist")
            return []

        # Heuristic: fetch 2*limit points so we have a better chance
        # of collecting (limit) unique entities
        search_result = self.qdrant.query_points(
        result = await asyncio.to_thread(
            self.qdrant.query_points,
            collection_name=collection,
            query=vec,
            limit=msg.limit * 2,
            with_payload=True,
        ).points
        )
        search_result = result.points

        entity_set = set()
        entities = []
@ -6,6 +6,7 @@ Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""

import asyncio
import logging
import re
from typing import Optional

@ -70,7 +71,7 @@ class Processor(FlowProcessor):
        safe_name = 'r_' + safe_name
        return safe_name.lower()

    def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
    async def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
        """Find the Qdrant collection for a given workspace/collection/schema"""
        prefix = (
            f"rows_{self.sanitize_name(workspace)}_"

@ -78,14 +79,15 @@ class Processor(FlowProcessor):
        )

        try:
            all_collections = self.qdrant.get_collections().collections
            all_collections = await asyncio.to_thread(
                lambda: self.qdrant.get_collections().collections
            )
            matching = [
                coll.name for coll in all_collections
                if coll.name.startswith(prefix)
            ]

            if matching:
                # Return first match (there should typically be only one per dimension)
                return matching[0]

        except Exception as e:

@ -100,8 +102,7 @@ class Processor(FlowProcessor):
        if not vec:
            return []

        # Find the collection for this workspace/collection/schema
        qdrant_collection = self.find_collection(
        qdrant_collection = await self.find_collection(
            workspace, request.collection, request.schema_name
        )

@ -113,7 +114,6 @@ class Processor(FlowProcessor):
            return []

        try:
            # Build optional filter for index_name
            query_filter = None
            if request.index_name:
                query_filter = Filter(

@ -125,16 +125,16 @@ class Processor(FlowProcessor):
                    ]
                )

            # Query Qdrant
            search_result = self.qdrant.query_points(
            result = await asyncio.to_thread(
                self.qdrant.query_points,
                collection_name=qdrant_collection,
                query=vec,
                limit=request.limit,
                with_payload=True,
                query_filter=query_filter,
            ).points
            )
            search_result = result.points

            # Convert to RowIndexMatch objects
            matches = []
            for point in search_result:
                payload = point.payload or {}
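Because the row collections embed a dimension suffix that the caller does not know in advance, discovery is a prefix scan over get_collections(). A standalone sketch of that lookup, with placeholder names:

    # Sketch of prefix-based Qdrant collection discovery, as used by
    # find_collection above. Names are placeholders.
    import asyncio

    from qdrant_client import QdrantClient

    async def find_rows_collection(qdrant: QdrantClient, prefix: str):
        # get_collections() is a blocking call; the lambda also keeps the
        # attribute walk off the event loop.
        all_collections = await asyncio.to_thread(
            lambda: qdrant.get_collections().collections
        )
        matching = [c.name for c in all_collections if c.name.startswith(prefix)]
        # Typically one collection per dimension; take the first match.
        return matching[0] if matching else None
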
@ -11,6 +11,7 @@ Queries against the unified 'rows' table with schema:
- source: text
"""

import asyncio
import json
import logging
import re

@ -97,12 +98,14 @@ class Processor(FlowProcessor):
        # Cassandra session
        self.cluster = None
        self.session = None
        self._setup_lock = asyncio.Lock()

        # Known keyspaces
        self.known_keyspaces: Set[str] = set()

    def connect_cassandra(self):
    async def connect_cassandra(self):
        """Connect to Cassandra cluster"""
        async with self._setup_lock:
            if self.session:
                return

@ -112,14 +115,16 @@ class Processor(FlowProcessor):
                    username=self.cassandra_username,
                    password=self.cassandra_password
                )
                self.cluster = Cluster(
                cluster = Cluster(
                    contact_points=self.cassandra_host,
                    auth_provider=auth_provider
                )
            else:
                self.cluster = Cluster(contact_points=self.cassandra_host)
                cluster = Cluster(contact_points=self.cassandra_host)

            self.session = self.cluster.connect()
            session = await asyncio.to_thread(cluster.connect)
            self.cluster = cluster
            self.session = session
            logger.info(f"Connected to Cassandra cluster at {self.cassandra_host}")

        except Exception as e:

@ -140,14 +145,17 @@ class Processor(FlowProcessor):
                f"for workspace {workspace}"
            )

        # Replace existing schemas for this workspace
        async with self._setup_lock:
            await self._apply_schema_config(workspace, config)

    async def _apply_schema_config(self, workspace, config):

        ws_schemas: Dict[str, RowSchema] = {}
        self.schemas[workspace] = ws_schemas

        builder = GraphQLSchemaBuilder()
        self.schema_builders[workspace] = builder

        # Check if our config type exists
        if self.config_key not in config:
            logger.warning(
                f"No '{self.config_key}' type in configuration "

@ -156,16 +164,12 @@ class Processor(FlowProcessor):
            self.graphql_schemas[workspace] = None
            return

        # Get the schemas dictionary for our type
        schemas_config = config[self.config_key]

        # Process each schema in the schemas config
        for schema_name, schema_json in schemas_config.items():
            try:
                # Parse the JSON schema definition
                schema_def = json.loads(schema_json)

                # Create Field objects
                fields = []
                for field_def in schema_def.get("fields", []):
                    field = SchemaField(

@ -180,7 +184,6 @@ class Processor(FlowProcessor):
                    )
                    fields.append(field)

                # Create RowSchema
                row_schema = RowSchema(
                    name=schema_def.get("name", schema_name),
                    description=schema_def.get("description", ""),

@ -202,7 +205,6 @@ class Processor(FlowProcessor):
            f"{len(ws_schemas)} schemas"
        )

        # Regenerate GraphQL schema for this workspace
        self.graphql_schemas[workspace] = builder.build(self.query_cassandra)

    def get_index_names(self, schema: RowSchema) -> List[str]:

@ -254,7 +256,7 @@ class Processor(FlowProcessor):
        For other queries, we need to scan and post-filter.
        """
        # Connect if needed
        self.connect_cassandra()
        await self.connect_cassandra()

        safe_keyspace = self.sanitize_name(workspace)
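Connection setup is now lazy, idempotent, and coroutine-safe: the asyncio.Lock serialises first-use setup, the early return makes repeat calls cheap, and the blocking cluster.connect() runs on a worker thread. A distilled sketch of the same pattern; host details are placeholders:

    # Distilled sketch of the lock-guarded lazy connection pattern above.
    import asyncio

    from cassandra.cluster import Cluster

    class LazyCassandra:
        def __init__(self, hosts):
            self.hosts = hosts
            self.cluster = None
            self.session = None
            self._setup_lock = asyncio.Lock()

        async def connect(self):
            # The lock means concurrent coroutines cannot race to create
            # two clusters; the early return makes later calls no-ops.
            async with self._setup_lock:
                if self.session:
                    return
                cluster = Cluster(contact_points=self.hosts)
                # cluster.connect() blocks on network I/O; keep it off the loop.
                self.session = await asyncio.to_thread(cluster.connect)
                self.cluster = cluster
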
@ -30,14 +30,13 @@ class EvaluationError(Exception):
    pass


async def evaluate(node, triples_client, workspace, collection, limit=10000):
async def evaluate(node, triples_client, collection, limit=10000):
    """
    Evaluate a SPARQL algebra node.

    Args:
        node: rdflib CompValue algebra node
        triples_client: TriplesClient instance for triple pattern queries
        workspace: workspace/keyspace identifier
        collection: collection identifier
        limit: safety limit on results

@ -55,24 +54,24 @@ async def evaluate(node, triples_client, workspace, collection, limit=10000):
        logger.warning(f"Unsupported algebra node: {name}")
        return [{}]

    return await handler(node, triples_client, workspace, collection, limit)
    return await handler(node, triples_client, collection, limit)


# --- Node handlers ---

async def _eval_select_query(node, tc, workspace, collection, limit):
async def _eval_select_query(node, tc, collection, limit):
    """Evaluate a SelectQuery node."""
    return await evaluate(node.p, tc, workspace, collection, limit)
    return await evaluate(node.p, tc, collection, limit)


async def _eval_project(node, tc, workspace, collection, limit):
async def _eval_project(node, tc, collection, limit):
    """Evaluate a Project node (SELECT variable projection)."""
    solutions = await evaluate(node.p, tc, workspace, collection, limit)
    solutions = await evaluate(node.p, tc, collection, limit)
    variables = [str(v) for v in node.PV]
    return project(solutions, variables)


async def _eval_bgp(node, tc, workspace, collection, limit):
async def _eval_bgp(node, tc, collection, limit):
    """
    Evaluate a Basic Graph Pattern.

@ -107,7 +106,7 @@ async def _eval_bgp(node, tc, workspace, collection, limit):

        # Query the triples store
        results = await _query_pattern(
            tc, s_val, p_val, o_val, workspace, collection, limit
            tc, s_val, p_val, o_val, collection, limit
        )

        # Map results back to variable bindings,

@ -130,17 +129,17 @@ async def _eval_bgp(node, tc, workspace, collection, limit):
    return solutions[:limit]


async def _eval_join(node, tc, workspace, collection, limit):
async def _eval_join(node, tc, collection, limit):
    """Evaluate a Join node."""
    left = await evaluate(node.p1, tc, workspace, collection, limit)
    right = await evaluate(node.p2, tc, workspace, collection, limit)
    left = await evaluate(node.p1, tc, collection, limit)
    right = await evaluate(node.p2, tc, collection, limit)
    return hash_join(left, right)[:limit]


async def _eval_left_join(node, tc, workspace, collection, limit):
async def _eval_left_join(node, tc, collection, limit):
    """Evaluate a LeftJoin node (OPTIONAL)."""
    left_sols = await evaluate(node.p1, tc, workspace, collection, limit)
    right_sols = await evaluate(node.p2, tc, workspace, collection, limit)
    left_sols = await evaluate(node.p1, tc, collection, limit)
    right_sols = await evaluate(node.p2, tc, collection, limit)

    filter_fn = None
    if hasattr(node, "expr") and node.expr is not None:

@ -153,16 +152,16 @@ async def _eval_left_join(node, tc, workspace, collection, limit):
    return left_join(left_sols, right_sols, filter_fn)[:limit]


async def _eval_union(node, tc, workspace, collection, limit):
async def _eval_union(node, tc, collection, limit):
    """Evaluate a Union node."""
    left = await evaluate(node.p1, tc, workspace, collection, limit)
    right = await evaluate(node.p2, tc, workspace, collection, limit)
    left = await evaluate(node.p1, tc, collection, limit)
    right = await evaluate(node.p2, tc, collection, limit)
    return union(left, right)[:limit]


async def _eval_filter(node, tc, workspace, collection, limit):
async def _eval_filter(node, tc, collection, limit):
    """Evaluate a Filter node."""
    solutions = await evaluate(node.p, tc, workspace, collection, limit)
    solutions = await evaluate(node.p, tc, collection, limit)
    expr = node.expr
    return [
        sol for sol in solutions

@ -170,22 +169,22 @@ async def _eval_filter(node, tc, workspace, collection, limit):
    ]


async def _eval_distinct(node, tc, workspace, collection, limit):
async def _eval_distinct(node, tc, collection, limit):
    """Evaluate a Distinct node."""
    solutions = await evaluate(node.p, tc, workspace, collection, limit)
    solutions = await evaluate(node.p, tc, collection, limit)
    return distinct(solutions)


async def _eval_reduced(node, tc, workspace, collection, limit):
async def _eval_reduced(node, tc, collection, limit):
    """Evaluate a Reduced node (like Distinct but implementation-defined)."""
    # Treat same as Distinct
    solutions = await evaluate(node.p, tc, workspace, collection, limit)
    solutions = await evaluate(node.p, tc, collection, limit)
    return distinct(solutions)


async def _eval_order_by(node, tc, workspace, collection, limit):
async def _eval_order_by(node, tc, collection, limit):
    """Evaluate an OrderBy node."""
    solutions = await evaluate(node.p, tc, workspace, collection, limit)
    solutions = await evaluate(node.p, tc, collection, limit)

    key_fns = []
    for cond in node.expr:

@ -206,7 +205,7 @@ async def _eval_order_by(node, tc, workspace, collection, limit):
    return order_by(solutions, key_fns)


async def _eval_slice(node, tc, workspace, collection, limit):
async def _eval_slice(node, tc, collection, limit):
    """Evaluate a Slice node (LIMIT/OFFSET)."""
    # Pass tighter limit downstream if possible
    inner_limit = limit

@ -214,13 +213,13 @@ async def _eval_slice(node, tc, workspace, collection, limit):
        offset = node.start or 0
        inner_limit = min(limit, offset + node.length)

    solutions = await evaluate(node.p, tc, workspace, collection, inner_limit)
    solutions = await evaluate(node.p, tc, collection, inner_limit)
    return slice_solutions(solutions, node.start or 0, node.length)


async def _eval_extend(node, tc, workspace, collection, limit):
async def _eval_extend(node, tc, collection, limit):
    """Evaluate an Extend node (BIND)."""
    solutions = await evaluate(node.p, tc, workspace, collection, limit)
    solutions = await evaluate(node.p, tc, collection, limit)
    var_name = str(node.var)
    expr = node.expr

@ -246,9 +245,9 @@ async def _eval_extend(node, tc, workspace, collection, limit):
    return result


async def _eval_group(node, tc, workspace, collection, limit):
async def _eval_group(node, tc, collection, limit):
    """Evaluate a Group node (GROUP BY with aggregation)."""
    solutions = await evaluate(node.p, tc, workspace, collection, limit)
    solutions = await evaluate(node.p, tc, collection, limit)

    # Extract grouping expressions
    group_exprs = []

@ -289,9 +288,9 @@ async def _eval_group(node, tc, workspace, collection, limit):
    return result


async def _eval_aggregate_join(node, tc, workspace, collection, limit):
async def _eval_aggregate_join(node, tc, collection, limit):
    """Evaluate an AggregateJoin (aggregation functions after GROUP BY)."""
    solutions = await evaluate(node.p, tc, workspace, collection, limit)
    solutions = await evaluate(node.p, tc, collection, limit)

    result = []
    for sol in solutions:

@ -310,7 +309,7 @@ async def _eval_aggregate_join(node, tc, workspace, collection, limit):
    return result


async def _eval_graph(node, tc, workspace, collection, limit):
async def _eval_graph(node, tc, collection, limit):
    """Evaluate a Graph node (GRAPH clause)."""
    term = node.term

@ -319,16 +318,16 @@ async def _eval_graph(node, tc, workspace, collection, limit):
        # We'd need to pass graph to triples queries
        # For now, evaluate inner pattern normally
        logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
        return await evaluate(node.p, tc, workspace, collection, limit)
        return await evaluate(node.p, tc, collection, limit)
    elif isinstance(term, Variable):
        # GRAPH ?g { ... } — variable graph
        logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
        return await evaluate(node.p, tc, workspace, collection, limit)
        return await evaluate(node.p, tc, collection, limit)
    else:
        return await evaluate(node.p, tc, workspace, collection, limit)
        return await evaluate(node.p, tc, collection, limit)


async def _eval_values(node, tc, workspace, collection, limit):
async def _eval_values(node, tc, collection, limit):
    """Evaluate a VALUES clause (inline data)."""
    variables = [str(v) for v in node.var]
    solutions = []

@ -343,9 +342,9 @@ async def _eval_values(node, tc, workspace, collection, limit):
    return solutions


async def _eval_to_multiset(node, tc, workspace, collection, limit):
async def _eval_to_multiset(node, tc, collection, limit):
    """Evaluate a ToMultiSet node (subquery)."""
    return await evaluate(node.p, tc, workspace, collection, limit)
    return await evaluate(node.p, tc, collection, limit)


# --- Aggregate computation ---

@ -487,7 +486,7 @@ def _resolve_term(tmpl, solution):
    return rdflib_term_to_term(tmpl)


async def _query_pattern(tc, s, p, o, workspace, collection, limit):
async def _query_pattern(tc, s, p, o, collection, limit):
    """
    Issue a streaming triple pattern query via TriplesClient.

@ -496,7 +495,6 @@ async def _query_pattern(tc, s, p, o, workspace, collection, limit):
    results = await tc.query(
        s=s, p=p, o=o,
        limit=limit,
        workspace=workspace,
        collection=collection,
    )
    return results
@ -141,7 +141,6 @@ class Processor(FlowProcessor):
        solutions = await evaluate(
            parsed.algebra,
            triples_client,
            workspace=flow.workspace,
            collection=request.collection or "default",
            limit=request.limit or 10000,
        )
@ -6,8 +6,8 @@ null. Output is a list of quads.

import asyncio
import logging
import json

from cassandra.query import SimpleStatement

from .... direct.cassandra_kg import (

@ -176,45 +176,42 @@ class Processor(TriplesQueryService):
        self.cassandra_host = hosts
        self.cassandra_username = username
        self.cassandra_password = password
        self.table = None

    def ensure_connection(self, workspace):
        """Ensure we have a connection to the correct keyspace."""
        if workspace != self.table:
            KGClass = EntityCentricKnowledgeGraph
        self._connections = {}
        self._conn_lock = asyncio.Lock()

    async def _get_connection(self, workspace):
        async with self._conn_lock:
            if workspace not in self._connections:
                if self.cassandra_username and self.cassandra_password:
                    self.tg = KGClass(
                    tg = await asyncio.to_thread(
                        EntityCentricKnowledgeGraph,
                        hosts=self.cassandra_host,
                        keyspace=workspace,
                        username=self.cassandra_username,
                        password=self.cassandra_password
                        password=self.cassandra_password,
                    )
                else:
                    self.tg = KGClass(
                    tg = await asyncio.to_thread(
                        EntityCentricKnowledgeGraph,
                        hosts=self.cassandra_host,
                        keyspace=workspace,
                    )
                self.table = workspace
                self._connections[workspace] = tg
        return self._connections[workspace]

    async def query_triples(self, workspace, query):

        try:

            # ensure_connection may construct a fresh
            # EntityCentricKnowledgeGraph which does sync schema
            # setup against Cassandra. Push it to a worker thread
            # so the event loop doesn't block on first-use per workspace.
            await asyncio.to_thread(self.ensure_connection, workspace)

            # Extract values from query
            s_val = get_term_value(query.s)
            p_val = get_term_value(query.p)
            o_val = get_term_value(query.o)
            g_val = query.g  # Already a string or None
            g_val = query.g

            tg = await self._get_connection(workspace)

            def get_object_metadata(row):
                """Extract term type metadata from result row"""
                return (
                    getattr(row, 'otype', None),
                    getattr(row, 'dtype', None),

@ -223,33 +220,21 @@ class Processor(TriplesQueryService):

            quads = []

            # All self.tg.get_* calls below are sync wrappers around
            # cassandra session.execute. Materialise inside a worker
            # thread so iteration never triggers sync paging back on
            # the event loop.

            # Route to appropriate query method based on which fields are specified
            if s_val is not None:
                if p_val is not None:
                    if o_val is not None:
                        # SPO specified - find matching graphs
                        resp = await asyncio.to_thread(
                            lambda: list(self.tg.get_spo(
                        resp = await tg.async_get_spo(
                            query.collection, s_val, p_val, o_val,
                            g=g_val, limit=query.limit,
                        ))
                        )
                        for t in resp:
                            g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
                            term_type, datatype, language = get_object_metadata(t)
                            quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
                    else:
                        # SP specified
                        resp = await asyncio.to_thread(
                            lambda: list(self.tg.get_sp(
                        resp = await tg.async_get_sp(
                            query.collection, s_val, p_val,
                            g=g_val, limit=query.limit,
                        ))
                        )
                        for t in resp:
                            g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH

@ -257,24 +242,18 @@ class Processor(TriplesQueryService):
                            quads.append((s_val, p_val, t.o, g, term_type, datatype, language))
                else:
                    if o_val is not None:
                        # SO specified
                        resp = await asyncio.to_thread(
                            lambda: list(self.tg.get_os(
                        resp = await tg.async_get_os(
                            query.collection, o_val, s_val,
                            g=g_val, limit=query.limit,
                        ))
                        )
                        for t in resp:
                            g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
                            term_type, datatype, language = get_object_metadata(t)
                            quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
                    else:
                        # S only
                        resp = await asyncio.to_thread(
                            lambda: list(self.tg.get_s(
                        resp = await tg.async_get_s(
                            query.collection, s_val,
                            g=g_val, limit=query.limit,
                        ))
                        )
                        for t in resp:
                            g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH

@ -283,24 +262,18 @@ class Processor(TriplesQueryService):
            else:
                if p_val is not None:
                    if o_val is not None:
                        # PO specified
                        resp = await asyncio.to_thread(
                            lambda: list(self.tg.get_po(
                        resp = await tg.async_get_po(
                            query.collection, p_val, o_val,
                            g=g_val, limit=query.limit,
                        ))
                        )
                        for t in resp:
                            g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
                            term_type, datatype, language = get_object_metadata(t)
                            quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
                    else:
                        # P only
                        resp = await asyncio.to_thread(
                            lambda: list(self.tg.get_p(
                        resp = await tg.async_get_p(
                            query.collection, p_val,
                            g=g_val, limit=query.limit,
                        ))
                        )
                        for t in resp:
                            g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH

@ -308,40 +281,26 @@ class Processor(TriplesQueryService):
                            quads.append((t.s, p_val, t.o, g, term_type, datatype, language))
                else:
                    if o_val is not None:
                        # O only
                        resp = await asyncio.to_thread(
                            lambda: list(self.tg.get_o(
                        resp = await tg.async_get_o(
                            query.collection, o_val,
                            g=g_val, limit=query.limit,
                        ))
                        )
                        for t in resp:
                            g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
                            term_type, datatype, language = get_object_metadata(t)
                            quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
                    else:
                        # Nothing specified - get all
                        resp = await asyncio.to_thread(
                            lambda: list(self.tg.get_all(
                        resp = await tg.async_get_all(
                            query.collection, limit=query.limit,
                        ))
                        )
                        for t in resp:
                            # Note: quads_by_collection uses 'd' for graph field
                            g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
                            # Filter by graph
                            # g_val=None means all graphs (no filter)
                            # g_val="" means default graph only
                            # otherwise filter to specific named graph
                            if g_val is not None:
                                if g != g_val:
                                    continue
                            term_type, datatype, language = get_object_metadata(t)
                            quads.append((t.s, t.p, t.o, g, term_type, datatype, language))

            # Convert to Triple objects (with g field)
            # s and p are always IRIs in RDF
            # Object uses term_type/datatype/language metadata from database
            triples = [
                Triple(
                    s=create_term(q[0], term_type='u'),

@ -365,51 +324,41 @@ class Processor(TriplesQueryService):
        Uses Cassandra's paging to fetch results incrementally.
        """
        try:
            await asyncio.to_thread(self.ensure_connection, workspace)

            batch_size = query.batch_size if query.batch_size > 0 else 20
            limit = query.limit if query.limit > 0 else 10000

            # Extract query pattern
            s_val = get_term_value(query.s)
            p_val = get_term_value(query.p)
            o_val = get_term_value(query.o)
            g_val = query.g

            def get_object_metadata(row):
                """Extract term type metadata from result row"""
                return (
                    getattr(row, 'otype', None),
                    getattr(row, 'dtype', None),
                    getattr(row, 'lang', None),
                )

            # For streaming, we need to execute with fetch_size
            # Use the collection table for get_all queries (most common streaming case)

            # Determine which query to use based on pattern
            if s_val is None and p_val is None and o_val is None:
                # Get all - use collection table with paging
                cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"

                tg = await self._get_connection(workspace)

                cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {tg.collection_table} WHERE collection = %s"
                params = [query.collection]
                statement = SimpleStatement(cql, fetch_size=batch_size)
                # async_execute only materialises the first page;
                # this query needs all pages, so use sync execute
                # in a worker thread where page iteration can block.
                result_set = await asyncio.to_thread(
                    lambda: list(tg.session.execute(statement, params))
                )

            else:
                # For specific patterns, fall back to non-streaming
                # (these typically return small result sets anyway)
                async for batch, is_final in self._fallback_stream(workspace, query, batch_size):
                    yield batch, is_final
                return

            # Materialise in a worker thread. We lose true streaming
            # paging (the driver fetches all pages eagerly inside the
            # thread) but the event loop stays responsive, and result
            # sets at this layer are typically small enough that this
            # is acceptable. If true async paging is needed later,
            # revisit using ResponseFuture page callbacks.
            statement = SimpleStatement(cql, fetch_size=batch_size)
            result_set = await asyncio.to_thread(
                lambda: list(self.tg.session.execute(statement, params))
            )

            batch = []
            count = 0
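This is the fix for the single-page bug called out in the release notes: async_execute deliberately materialises only the first result page, so a fetch_size-limited statement must instead be drained by the sync driver in a worker thread. A distilled sketch contrasting the two behaviours; `session` is a cassandra-driver Session, and the first-page helper shape is an assumption based on the diff's comments:

    # Sketch contrasting first-page-only async execution with full
    # pagination in a worker thread.
    import asyncio

    from cassandra.query import SimpleStatement

    async def first_page_only(session, cql, params):
        # Typical ResponseFuture bridge: resolves with just the first page.
        future = session.execute_async(cql, params)
        loop = asyncio.get_running_loop()
        aio_future = loop.create_future()
        future.add_callbacks(
            lambda rows: loop.call_soon_threadsafe(aio_future.set_result, rows),
            lambda exc: loop.call_soon_threadsafe(aio_future.set_exception, exc),
        )
        return await aio_future  # at most fetch_size rows

    async def all_pages(session, cql, params, batch_size=20):
        # list() drives the driver's implicit paging to exhaustion, but in
        # a worker thread so the event loop never blocks on page fetches.
        statement = SimpleStatement(cql, fetch_size=batch_size)
        return await asyncio.to_thread(
            lambda: list(session.execute(statement, params))
        )
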
@ -3,11 +3,13 @@
Accepts entity/vector pairs and writes them to a Qdrant store.
"""

import asyncio
import uuid
import logging

from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
import logging

from .... base import DocumentEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer

@ -35,13 +37,35 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
        )

        self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
        self._cache_lock = asyncio.Lock()
        self._known_collections: set[str] = set()

        # Register for config push notifications
        self.register_config_handler(self.on_collection_config, types=["collection"])

    async def ensure_collection(self, collection_name, dim):
        async with self._cache_lock:
            if collection_name in self._known_collections:
                return
            exists = await asyncio.to_thread(
                self.qdrant.collection_exists, collection_name
            )
            if not exists:
                logger.info(
                    f"Lazily creating Qdrant collection {collection_name} "
                    f"with dimension {dim}"
                )
                await asyncio.to_thread(
                    self.qdrant.create_collection,
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=dim, distance=Distance.COSINE
                    ),
                )
            self._known_collections.add(collection_name)

    async def store_document_embeddings(self, workspace, message):

        # Validate collection exists in config before processing
        if not self.collection_exists(workspace, message.metadata.collection):
            logger.warning(
                f"Collection {message.metadata.collection} for workspace {workspace} "

@ -60,24 +84,15 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
            if not vec:
                continue

            # Create collection name with dimension suffix for lazy creation
            dim = len(vec)
            collection = (
                f"d_{workspace}_{message.metadata.collection}_{dim}"
            )

            # Lazily create collection if it doesn't exist (but only if authorized in config)
            if not self.qdrant.collection_exists(collection):
                logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
                self.qdrant.create_collection(
                    collection_name=collection,
                    vectors_config=VectorParams(
                        size=dim,
                        distance=Distance.COSINE
                    )
                )
            await self.ensure_collection(collection, dim)

            self.qdrant.upsert(
            await asyncio.to_thread(
                self.qdrant.upsert,
                collection_name=collection,
                points=[
                    PointStruct(

@ -87,7 +102,7 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
                            "chunk_id": chunk_id,
                        }
                    )
                ]
                ],
            )

    @staticmethod

@ -124,8 +139,9 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
        try:
            prefix = f"d_{workspace}_{collection}_"

            # Get all collections and filter for matches
            all_collections = self.qdrant.get_collections().collections
            all_collections = await asyncio.to_thread(
                lambda: self.qdrant.get_collections().collections
            )
            matching_collections = [
                coll.name for coll in all_collections
                if coll.name.startswith(prefix)

@ -135,7 +151,11 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
                logger.info(f"No collections found matching prefix {prefix}")
            else:
                for collection_name in matching_collections:
                    self.qdrant.delete_collection(collection_name)
                    await asyncio.to_thread(
                        self.qdrant.delete_collection, collection_name
                    )
                    async with self._cache_lock:
                        self._known_collections.discard(collection_name)
                    logger.info(f"Deleted Qdrant collection: {collection_name}")
                logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
@ -3,11 +3,13 @@
Accepts entity/vector pairs and writes them to a Qdrant store.
"""

import asyncio
import uuid
import logging

from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
import uuid
import logging

from .... base import GraphEmbeddingsStoreService, CollectionConfigHandler
from .... base import AsyncProcessor, Consumer, Producer

@ -50,13 +52,35 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
        )

        self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
        self._cache_lock = asyncio.Lock()
        self._known_collections: set[str] = set()

        # Register for config push notifications
        self.register_config_handler(self.on_collection_config, types=["collection"])

    async def ensure_collection(self, collection_name, dim):
        async with self._cache_lock:
            if collection_name in self._known_collections:
                return
            exists = await asyncio.to_thread(
                self.qdrant.collection_exists, collection_name
            )
            if not exists:
                logger.info(
                    f"Lazily creating Qdrant collection {collection_name} "
                    f"with dimension {dim}"
                )
                await asyncio.to_thread(
                    self.qdrant.create_collection,
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=dim, distance=Distance.COSINE
                    ),
                )
            self._known_collections.add(collection_name)

    async def store_graph_embeddings(self, workspace, message):

        # Validate collection exists in config before processing
        if not self.collection_exists(workspace, message.metadata.collection):
            logger.warning(
                f"Collection {message.metadata.collection} for workspace {workspace} "

@ -75,22 +99,12 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
            if not vec:
                continue

            # Create collection name with dimension suffix for lazy creation
            dim = len(vec)
            collection = (
                f"t_{workspace}_{message.metadata.collection}_{dim}"
            )

            # Lazily create collection if it doesn't exist (but only if authorized in config)
            if not self.qdrant.collection_exists(collection):
                logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
                self.qdrant.create_collection(
                    collection_name=collection,
                    vectors_config=VectorParams(
                        size=dim,
                        distance=Distance.COSINE
                    )
                )
            await self.ensure_collection(collection, dim)

            payload = {
                "entity": entity_value,

@ -98,7 +112,8 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
            if entity.chunk_id:
                payload["chunk_id"] = entity.chunk_id

            self.qdrant.upsert(
            await asyncio.to_thread(
                self.qdrant.upsert,
                collection_name=collection,
                points=[
                    PointStruct(

@ -106,7 +121,7 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
                        vector=vec,
                        payload=payload,
                    )
                ]
                ],
            )

    @staticmethod

@ -143,8 +158,9 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
        try:
            prefix = f"t_{workspace}_{collection}_"

            # Get all collections and filter for matches
            all_collections = self.qdrant.get_collections().collections
            all_collections = await asyncio.to_thread(
                lambda: self.qdrant.get_collections().collections
            )
            matching_collections = [
                coll.name for coll in all_collections
                if coll.name.startswith(prefix)

@ -154,7 +170,11 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
                logger.info(f"No collections found matching prefix {prefix}")
            else:
                for collection_name in matching_collections:
                    self.qdrant.delete_collection(collection_name)
                    await asyncio.to_thread(
                        self.qdrant.delete_collection, collection_name
                    )
                    async with self._cache_lock:
                        self._known_collections.discard(collection_name)
                    logger.info(f"Deleted Qdrant collection: {collection_name}")
                logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
@ -16,10 +16,10 @@ Payload structure:
- text: The text that was embedded (for debugging/display)
"""

import asyncio
import logging
import re
import uuid
from typing import Set, Tuple

from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Distance, VectorParams

@ -63,11 +63,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
        # Register config handler for collection management
        self.register_config_handler(self.on_collection_config, types=["collection"])

        # Cache of created Qdrant collections
        self.created_collections: Set[str] = set()

        # Qdrant client
        self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
        self._cache_lock = asyncio.Lock()
        self._known_collections: set[str] = set()

    def sanitize_name(self, name: str) -> str:
        """Sanitize names for Qdrant collection naming"""

@ -85,25 +83,28 @@ class Processor(CollectionConfigHandler, FlowProcessor):
        safe_schema = self.sanitize_name(schema_name)
        return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"

    def ensure_collection(self, collection_name: str, dimension: int):
    async def ensure_collection(self, collection_name: str, dimension: int):
        """Create Qdrant collection if it doesn't exist"""
        if collection_name in self.created_collections:
        async with self._cache_lock:
            if collection_name in self._known_collections:
                return

            if not self.qdrant.collection_exists(collection_name):
            exists = await asyncio.to_thread(
                self.qdrant.collection_exists, collection_name
            )
            if not exists:
                logger.info(
                    f"Creating Qdrant collection {collection_name} "
                    f"with dimension {dimension}"
                )
                self.qdrant.create_collection(
                await asyncio.to_thread(
                    self.qdrant.create_collection,
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=dimension,
                        distance=Distance.COSINE
                    ),
                )
                )

            self.created_collections.add(collection_name)
            self._known_collections.add(collection_name)

    async def on_embeddings(self, msg, consumer, flow):
        """Process incoming RowEmbeddings and write to Qdrant"""

@ -143,15 +144,14 @@ class Processor(CollectionConfigHandler, FlowProcessor):

            dimension = len(vector)

            # Create/get collection name (lazily on first vector)
            if qdrant_collection is None:
                qdrant_collection = self.get_collection_name(
                    workspace, collection, schema_name, dimension
                )
                self.ensure_collection(qdrant_collection, dimension)
                await self.ensure_collection(qdrant_collection, dimension)

            # Write to Qdrant
            self.qdrant.upsert(
            await asyncio.to_thread(
                self.qdrant.upsert,
                collection_name=qdrant_collection,
                points=[
                    PointStruct(

@ -163,7 +163,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
                            "text": row_emb.text
                        }
                    )
                ]
                ],
            )
            embeddings_written += 1

@ -181,8 +181,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
        try:
            prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"

            # Get all collections and filter for matches
            all_collections = self.qdrant.get_collections().collections
            all_collections = await asyncio.to_thread(
                lambda: self.qdrant.get_collections().collections
            )
            matching_collections = [
                coll.name for coll in all_collections
                if coll.name.startswith(prefix)

@ -192,8 +193,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
                logger.info(f"No Qdrant collections found matching prefix {prefix}")
            else:
                for collection_name in matching_collections:
                    self.qdrant.delete_collection(collection_name)
                    self.created_collections.discard(collection_name)
                    await asyncio.to_thread(
                        self.qdrant.delete_collection, collection_name
                    )
                    async with self._cache_lock:
                        self._known_collections.discard(collection_name)
                    logger.info(f"Deleted Qdrant collection: {collection_name}")
                logger.info(
                    f"Deleted {len(matching_collections)} collection(s) "

@ -217,8 +221,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
                f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
            )

            # Get all collections and filter for matches
            all_collections = self.qdrant.get_collections().collections
            all_collections = await asyncio.to_thread(
                lambda: self.qdrant.get_collections().collections
            )
            matching_collections = [
                coll.name for coll in all_collections
                if coll.name.startswith(prefix)

@ -228,8 +233,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
                logger.info(f"No Qdrant collections found matching prefix {prefix}")
            else:
                for collection_name in matching_collections:
                    self.qdrant.delete_collection(collection_name)
                    self.created_collections.discard(collection_name)
                    await asyncio.to_thread(
                        self.qdrant.delete_collection, collection_name
                    )
                    async with self._cache_lock:
                        self._known_collections.discard(collection_name)
                    logger.info(f"Deleted Qdrant collection: {collection_name}")

        except Exception as e:
@ -82,7 +82,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
|||
|
||||
# Cache of known keyspaces and whether tables exist
|
||||
self.known_keyspaces: Set[str] = set()
|
||||
self.tables_initialized: Set[str] = set() # keyspaces with rows/row_partitions tables
|
||||
self.tables_initialized: Set[str] = set()
|
||||
|
||||
# Cache of registered (collection, schema_name) pairs
|
||||
self.registered_partitions: Set[Tuple[str, str]] = set()
|
||||
|
|
@ -94,6 +94,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
|||
self.cluster = None
|
||||
self.session = None
|
||||
|
||||
# Protects connection setup and cache mutations
|
||||
self._setup_lock = asyncio.Lock()
|
||||
|
||||
def connect_cassandra(self):
|
||||
"""Connect to Cassandra cluster"""
|
||||
if self.session:
|
||||
|
|
@ -126,6 +129,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
|||
f"for workspace {workspace}"
|
||||
)
|
||||
|
||||
async with self._setup_lock:
|
||||
return await self._apply_schema_config(workspace, config, version)
|
||||
|
||||
async def _apply_schema_config(self, workspace, config, version):
|
||||
|
||||
# Track which schemas changed in this workspace
|
||||
old_schemas = self.schemas.get(workspace, {})
|
||||
old_schema_names = set(old_schemas.keys())
|
||||
|
|
@ -391,12 +399,8 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
|||
schema_name = obj.schema_name
|
||||
source = getattr(obj.metadata, 'source', '') or ''
|
||||
|
||||
# Ensure tables exist (sync DDL — push to a worker thread
|
||||
# so the event loop stays responsive when running in a
|
||||
# processor group sharing the loop with siblings).
|
||||
async with self._setup_lock:
|
||||
await asyncio.to_thread(self.ensure_tables, keyspace)
|
||||
|
||||
# Register partitions if first time seeing this (collection, schema_name)
|
||||
await asyncio.to_thread(
|
||||
self.register_partitions,
|
||||
keyspace, collection, schema_name, workspace,
|
||||
|
|
@@ -461,35 +465,27 @@ class Processor(CollectionConfigHandler, FlowProcessor):
 
     async def create_collection(self, workspace: str, collection: str, metadata: dict):
         """Create/verify collection exists in Cassandra row store"""
         # Connect if not already connected (sync, push to thread)
+        async with self._setup_lock:
+            await asyncio.to_thread(self.connect_cassandra)
 
         # Ensure tables exist (sync DDL, push to thread)
         await asyncio.to_thread(self.ensure_tables, workspace)
 
         logger.info(f"Collection {collection} ready for workspace {workspace}")
 
     async def delete_collection(self, workspace: str, collection: str):
         """Delete all data for a specific collection using partition tracking"""
         # Connect if not already connected
+        async with self._setup_lock:
+            await asyncio.to_thread(self.connect_cassandra)
 
-        safe_keyspace = self.sanitize_name(workspace)
-
         # Check if keyspace exists
         if workspace not in self.known_keyspaces:
-            check_keyspace_cql = """
-                SELECT keyspace_name FROM system_schema.keyspaces
-                WHERE keyspace_name = %s
-            """
-            result = await async_execute(
-                self.session, check_keyspace_cql, (safe_keyspace,)
-            )
+            safe_ks = self.sanitize_name(workspace)
+            check_cql = "SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s"
+            result = await async_execute(self.session, check_cql, (safe_ks,))
             if not result:
-                logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
+                logger.info(f"Keyspace {safe_ks} does not exist, nothing to delete")
                 return
             self.known_keyspaces.add(workspace)
 
         safe_keyspace = self.sanitize_name(workspace)
 
         # Discover all partitions for this collection
         select_partitions_cql = f"""
             SELECT schema_name, index_name FROM {safe_keyspace}.row_partitions
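The compacted probe works because Cassandra exposes schema metadata as ordinary tables: selecting from system_schema.keyspaces returns a row only when the keyspace exists. A standalone sketch against the synchronous driver, with the contact point and names as placeholders:

    from cassandra.cluster import Cluster

    session = Cluster(["localhost"]).connect()   # placeholder contact point

    def keyspace_exists(name):
        rows = session.execute(
            "SELECT keyspace_name FROM system_schema.keyspaces "
            "WHERE keyspace_name = %s",
            (name,),
        )
        return rows.one() is not None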
@@ -540,7 +536,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
             logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
             raise
 
         # Clear from local cache
-        self.registered_partitions = {
-            (col, sch) for col, sch in self.registered_partitions
-            if col != collection
+        async with self._setup_lock:
+            self.registered_partitions = {
+                (col, sch) for col, sch in self.registered_partitions
+                if col != collection
@@ -553,7 +549,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
 
     async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str):
         """Delete all data for a specific collection + schema combination"""
         # Connect if not already connected
+        async with self._setup_lock:
+            await asyncio.to_thread(self.connect_cassandra)
 
         safe_keyspace = self.sanitize_name(workspace)
@@ -614,7 +610,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
             )
             raise
 
         # Clear from local cache
+        async with self._setup_lock:
+            self.registered_partitions.discard((collection, schema_name))
 
         logger.info(
@@ -4,12 +4,7 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph.
 """
 
 import asyncio
-import base64
-import os
-import argparse
-import time
 import logging
-import json
 
 from .... direct.cassandra_kg import (
     EntityCentricKnowledgeGraph, DEFAULT_GRAPH
@@ -28,6 +23,8 @@ default_ident = "triples-write"
 
 def serialize_triple(triple):
     """Serialize a Triple object to JSON for storage."""
+    import json
+
     if triple is None:
         return None
@@ -141,63 +138,48 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
         self.cassandra_host = hosts
         self.cassandra_username = username
         self.cassandra_password = password
-        self.table = None
-        self.tg = None
+        self._connections = {}
+        self._conn_lock = asyncio.Lock()
 
         # Register for config push notifications
         self.register_config_handler(self.on_collection_config, types=["collection"])
 
-    async def store_triples(self, workspace, message):
-
-        # The cassandra-driver work below — connection, schema
-        # setup, and per-triple inserts — is all synchronous.
-        # Wrap the whole batch in a worker thread so the event
-        # loop stays responsive for sibling processors when
-        # running in a processor group.
-
-        def _do_store():
-
-            if self.table is None or self.table != workspace:
-
-                self.tg = None
-
-                # Use factory function to select implementation
-                KGClass = EntityCentricKnowledgeGraph
-
-                try:
-                    if self.cassandra_username and self.cassandra_password:
-                        self.tg = KGClass(
-                            hosts=self.cassandra_host,
-                            keyspace=workspace,
-                            username=self.cassandra_username,
-                            password=self.cassandra_password,
-                        )
-                    else:
-                        self.tg = KGClass(
-                            hosts=self.cassandra_host,
-                            keyspace=workspace,
-                        )
-                except Exception as e:
-                    logger.error(f"Exception: {e}", exc_info=True)
-                    time.sleep(1)
-                    raise e
-
-            self.table = workspace
+    async def _get_connection(self, workspace):
+        async with self._conn_lock:
+            if workspace not in self._connections:
+                if self.cassandra_username and self.cassandra_password:
+                    tg = await asyncio.to_thread(
+                        EntityCentricKnowledgeGraph,
+                        hosts=self.cassandra_host,
+                        keyspace=workspace,
+                        username=self.cassandra_username,
+                        password=self.cassandra_password,
+                    )
+                else:
+                    tg = await asyncio.to_thread(
+                        EntityCentricKnowledgeGraph,
+                        hosts=self.cassandra_host,
+                        keyspace=workspace,
+                    )
+                self._connections[workspace] = tg
+            return self._connections[workspace]
+
+    async def store_triples(self, workspace, message):
+
+        tg = await self._get_connection(workspace)
 
         for t in message.triples:
             # Extract values from Term objects
             s_val = get_term_value(t.s)
             p_val = get_term_value(t.p)
             o_val = get_term_value(t.o)
             # t.g is None for default graph, or a graph IRI
             g_val = t.g if t.g is not None else DEFAULT_GRAPH
 
             # Extract object type metadata for entity-centric storage
             otype = get_term_otype(t.o)
             dtype = get_term_dtype(t.o)
             lang = get_term_lang(t.o)
 
-            self.tg.insert(
+            await tg.async_insert(
                 message.metadata.collection,
                 s_val,
                 p_val,
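The rewrite replaces the mutable self.table/self.tg pair, which kept one live connection and tore it down whenever the workspace changed, with a dict keyed by workspace that is populated once under a lock. The essential shape of that cache, with factory standing in for the blocking EntityCentricKnowledgeGraph constructor:

    import asyncio

    class ConnectionCache:
        def __init__(self, factory):
            self._factory = factory     # blocking constructor
            self._connections = {}
            self._lock = asyncio.Lock()

        async def get(self, workspace):
            async with self._lock:
                if workspace not in self._connections:
                    # Construction opens sockets and blocks, so push it
                    # to a worker thread; cache the result so later
                    # callers reuse the same connection.
                    conn = await asyncio.to_thread(self._factory, workspace)
                    self._connections[workspace] = conn
                return self._connections[workspace]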
@@ -208,89 +190,32 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
                 lang=lang,
             )
 
-        await asyncio.to_thread(_do_store)
-
     async def create_collection(self, workspace: str, collection: str, metadata: dict):
         """Create a collection in Cassandra triple store via config push"""
 
-        def _do_create():
-            # Create or reuse connection for this workspace's keyspace
-            if self.table is None or self.table != workspace:
-                self.tg = None
-
-                # Use factory function to select implementation
-                KGClass = EntityCentricKnowledgeGraph
-
-                try:
-                    if self.cassandra_username and self.cassandra_password:
-                        self.tg = KGClass(
-                            hosts=self.cassandra_host,
-                            keyspace=workspace,
-                            username=self.cassandra_username,
-                            password=self.cassandra_password,
-                        )
-                    else:
-                        self.tg = KGClass(
-                            hosts=self.cassandra_host,
-                            keyspace=workspace,
-                        )
-                except Exception as e:
-                    logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
-                    raise
-
-            self.table = workspace
-
-            # Create collection using the built-in method
-            logger.info(f"Creating collection {collection} for workspace {workspace}")
-
-            if self.tg.collection_exists(collection):
-                logger.info(f"Collection {collection} already exists")
-            else:
-                self.tg.create_collection(collection)
-                logger.info(f"Created collection {collection}")
-
-        try:
-            await asyncio.to_thread(_do_create)
-        except Exception as e:
-            logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
-            raise
+        tg = await self._get_connection(workspace)
+
+        # Create collection using the built-in method
+        logger.info(f"Creating collection {collection} for workspace {workspace}")
+
+        exists = await tg.async_collection_exists(collection)
+        if exists:
+            logger.info(f"Collection {collection} already exists")
+        else:
+            await tg.async_create_collection(collection)
+            logger.info(f"Created collection {collection}")
 
     async def delete_collection(self, workspace: str, collection: str):
         """Delete all data for a specific collection from the unified triples table"""
 
-        def _do_delete():
-            # Create or reuse connection for this workspace's keyspace
-            if self.table is None or self.table != workspace:
-                self.tg = None
-
-                # Use factory function to select implementation
-                KGClass = EntityCentricKnowledgeGraph
-
-                try:
-                    if self.cassandra_username and self.cassandra_password:
-                        self.tg = KGClass(
-                            hosts=self.cassandra_host,
-                            keyspace=workspace,
-                            username=self.cassandra_username,
-                            password=self.cassandra_password,
-                        )
-                    else:
-                        self.tg = KGClass(
-                            hosts=self.cassandra_host,
-                            keyspace=workspace,
-                        )
-                except Exception as e:
-                    logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
-                    raise
-
-            self.table = workspace
-
-            # Delete all triples for this collection using the built-in method
-            self.tg.delete_collection(collection)
-            logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
-
-        try:
-            await asyncio.to_thread(_do_delete)
-        except Exception as e:
-            logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
-            raise
+        tg = await self._get_connection(workspace)
+
+        # Delete all triples for this collection using the built-in method
+        await tg.async_delete_collection(collection)
+        logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
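The async_collection_exists, async_create_collection and async_delete_collection calls above follow one mechanical recipe: each async_ variant delegates its synchronous counterpart to a worker thread. A hedged sketch of how such a wrapper can look; this is the shape of the idea, not the class's actual implementation:

    import asyncio

    class KnowledgeGraphSketch:
        def create_collection(self, collection):
            ...  # stand-in for synchronous cassandra-driver work

        async def async_create_collection(self, collection):
            # Same signature and behaviour, just moved off the event
            # loop so awaiting callers don't stall sibling coroutines.
            return await asyncio.to_thread(self.create_collection, collection)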
@@ -1,6 +1,7 @@
 
 from .. schema import KnowledgeResponse, Triple, Triples, EntityEmbeddings
 from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
+from .. schema import DocumentEmbeddings, ChunkEmbeddings
 
 from cassandra.cluster import Cluster
@@ -217,6 +218,16 @@ class KnowledgeTableStore:
             WHERE workspace = ? AND document_id = ?
         """)
 
+        self.delete_document_embeddings_stmt = self.cassandra.prepare("""
+            DELETE FROM document_embeddings
+            WHERE workspace = ? AND document_id = ?
+        """)
+
+        self.list_de_cores_stmt = self.cassandra.prepare("""
+            SELECT DISTINCT workspace, document_id FROM document_embeddings
+            WHERE workspace = ?
+        """)
+
     async def add_triples(self, workspace, m):
 
         when = int(time.time() * 1000)
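Prepared statements like these are parsed by the server once and then bound per execution; note that the driver's placeholder syntax switches from %s in simple statements to ? in prepared ones. A small usage sketch against the synchronous driver, with keyspace and values purely illustrative:

    from cassandra.cluster import Cluster

    session = Cluster(["localhost"]).connect("my_keyspace")  # placeholders

    # Prepare once; the server caches the parsed statement.
    stmt = session.prepare(
        "DELETE FROM document_embeddings "
        "WHERE workspace = ? AND document_id = ?"
    )

    # Bind and execute repeatedly with different values.
    session.execute(stmt, ("ws-1", "doc-42"))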
@@ -338,6 +349,50 @@ class KnowledgeTableStore:
             logger.error("Exception occurred", exc_info=True)
             raise
 
+        try:
+            await async_execute(
+                self.cassandra,
+                self.delete_document_embeddings_stmt,
+                (workspace, document_id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise
+
+    async def delete_document_embeddings(self, workspace, document_id):
+
+        logger.debug("Delete document embeddings...")
+
+        try:
+            await async_execute(
+                self.cassandra,
+                self.delete_document_embeddings_stmt,
+                (workspace, document_id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise
+
+    async def list_de_cores(self, workspace):
+
+        logger.debug("List DE cores...")
+
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.list_de_cores_stmt,
+                (workspace,),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise
+
+        lst = [row[1] for row in rows]
+
+        logger.debug("Done")
+
+        return lst
+
     async def get_triples(self, workspace, document_id, receiver):
 
         logger.debug("Get triples...")
@@ -417,3 +472,42 @@ class KnowledgeTableStore:
 
         logger.debug("Done")
 
+    async def get_document_embeddings(self, workspace, document_id, receiver):
+
+        logger.debug("Get DE...")
+
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_document_embeddings_stmt,
+                (workspace, document_id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise
+
+        for row in rows:
+
+            if row[3]:
+                chunks = [
+                    ChunkEmbeddings(
+                        chunk_id=ch[0],
+                        vector=ch[1],
+                    )
+                    for ch in row[3]
+                ]
+            else:
+                chunks = []
+
+            await receiver(
+                DocumentEmbeddings(
+                    metadata = Metadata(
+                        id = document_id,
+                        collection = "default",
+                    ),
+                    chunks = chunks
+                )
+            )
+
+        logger.debug("Done")
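Rather than returning a list, get_document_embeddings pushes each decoded DocumentEmbeddings through an async receiver callback, so a caller can stream rows onward as they arrive. A sketch of one possible caller; collect_embeddings is hypothetical:

    async def collect_embeddings(store, workspace, document_id):
        results = []

        async def receiver(msg):
            # Invoked once per row; could equally forward the message
            # to a queue or socket instead of buffering it.
            results.append(msg)

        await store.get_document_embeddings(workspace, document_id, receiver)
        return results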