Additional user fixes and test fixes

This commit is contained in:
Cyber MacGeddon 2026-04-21 10:53:15 +01:00
parent db05427d0e
commit 7f0f79dd15
62 changed files with 1078 additions and 1315 deletions

View file

@ -28,7 +28,6 @@ def sample_text_document():
"""Sample document with moderate length text."""
metadata = Metadata(
id="test-doc-1",
user="test-user",
collection="test-collection"
)
text = "The quick brown fox jumps over the lazy dog. " * 20
@ -43,7 +42,6 @@ def long_text_document():
"""Long document for testing multiple chunks."""
metadata = Metadata(
id="test-doc-long",
user="test-user",
collection="test-collection"
)
# Create a long text that will definitely be chunked
@ -59,7 +57,6 @@ def unicode_text_document():
"""Document with various unicode characters."""
metadata = Metadata(
id="test-doc-unicode",
user="test-user",
collection="test-collection"
)
text = """
@ -84,7 +81,6 @@ def empty_text_document():
"""Empty document for edge case testing."""
metadata = Metadata(
id="test-doc-empty",
user="test-user",
collection="test-collection"
)
return TextDocument(

View file

@ -185,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-123",
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content"

View file

@ -185,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata(
id="test-doc-456",
user="test-user",
collection="test-collection"
)
mock_text_doc.text = b"This is test document content for token chunking"

View file

@ -73,7 +73,6 @@ def sample_triples():
return Triples(
metadata=Metadata(
id="test-doc-id",
user="test-user",
collection="default", # This should be overridden
),
triples=[
@ -92,7 +91,6 @@ def sample_graph_embeddings():
return GraphEmbeddings(
metadata=Metadata(
id="test-doc-id",
user="test-user",
collection="default", # This should be overridden
),
entities=[
@ -148,7 +146,6 @@ class TestKnowledgeManagerLoadCore:
mock_triples_pub.send.assert_called_once()
sent_triples = mock_triples_pub.send.call_args[0][1]
assert sent_triples.metadata.collection == "test-collection"
assert sent_triples.metadata.user == "test-user"
assert sent_triples.metadata.id == "test-doc-id"
@pytest.mark.asyncio
@ -187,7 +184,6 @@ class TestKnowledgeManagerLoadCore:
mock_ge_pub.send.assert_called_once()
sent_ge = mock_ge_pub.send.call_args[0][1]
assert sent_ge.metadata.collection == "test-collection"
assert sent_ge.metadata.user == "test-user"
assert sent_ge.metadata.id == "test-doc-id"
@pytest.mark.asyncio

View file

@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
# Mock message with inline data
content = b"# Document Title\nBody text content."
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,
@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
# Mock message
content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,
@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
]
content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser",
mock_metadata = Metadata(id="test-doc",
collection="default")
mock_document = Document(
metadata=mock_metadata,

View file

@ -20,9 +20,8 @@ def processor():
)
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1",
user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"):
metadata = Metadata(id=doc_id, collection=collection)
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
msg = MagicMock()
msg.value.return_value = value
@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor:
await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0]
assert result.metadata.user == "alice"
assert result.metadata.collection == "reports"
assert result.metadata.id == "d1"

View file

@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"):
return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
def _make_message(entities, doc_id="doc-1", user="test", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection)
def _make_message(entities, doc_id="doc-1", collection="default"):
metadata = Metadata(id=doc_id, collection=collection)
value = EntityContexts(metadata=metadata, entities=entities)
msg = MagicMock()
msg.value.return_value = value
@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing:
_make_entity_context(f"E{i}", f"ctx {i}")
for i in range(5)
]
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main")
msg = _make_message(entities, doc_id="doc-42", collection="main")
mock_embed = AsyncMock(return_value=[[0.0]] * 5)
mock_output = AsyncMock()
@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing:
for call in mock_output.send.call_args_list:
result = call[0][0]
assert result.metadata.id == "doc-42"
assert result.metadata.user == "alice"
assert result.metadata.collection == "main"
@pytest.mark.asyncio

View file

@ -34,11 +34,10 @@ def _make_defn(entity, definition):
return {"entity": entity, "definition": definition}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
id=meta_id, root=root, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
@ -229,8 +228,7 @@ class TestMetadataPreservation:
defs = [_make_defn("X", "def X")]
flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
"text", meta_id="c-1", root="r-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
@ -238,7 +236,6 @@ class TestMetadataPreservation:
for triples_msg in _sent_triples(triples_pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"
@pytest.mark.asyncio
@ -247,8 +244,7 @@ class TestMetadataPreservation:
defs = [_make_defn("X", "def X")]
flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg(
"text", meta_id="c-2", root="r-2",
user="u-2", collection="coll-2",
"text", meta_id="c-2", root="r-2", collection="coll-2",
)
await proc.on_message(msg, MagicMock(), flow)

View file

@ -38,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True):
}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1",
user="user-1", collection="col-1", document_id=""):
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
"""Build a mock message wrapping a Chunk."""
chunk = Chunk(
metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection,
id=meta_id, root=root, collection=collection,
),
chunk=text.encode("utf-8"),
document_id=document_id,
@ -189,8 +188,7 @@ class TestMetadataPreservation:
rels = [_make_rel("X", "rel", "Y")]
flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1",
user="u-1", collection="coll-1",
"text", meta_id="c-1", root="r-1", collection="coll-1",
)
await proc.on_message(msg, MagicMock(), flow)
@ -198,7 +196,6 @@ class TestMetadataPreservation:
for triples_msg in _sent_triples(pub):
assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1"

View file

@ -186,7 +186,6 @@ class TestEntityContextsImportMessageProcessing:
assert isinstance(sent, EntityContexts)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.user == "testuser"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2

View file

@ -188,7 +188,6 @@ class TestGraphEmbeddingsImportMessageProcessing:
assert isinstance(sent, GraphEmbeddings)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.user == "testuser"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2

View file

@ -235,7 +235,6 @@ class TestRowsImportMessageProcessing:
# Check metadata
assert sent_object.metadata.id == "obj-123"
assert sent_object.metadata.user == "testuser"
assert sent_object.metadata.collection == "testcollection"
@patch('trustgraph.gateway.dispatch.rows_import.Publisher')

View file

@ -23,7 +23,6 @@ class TestTextDocumentTranslator:
)
assert msg.metadata.id == "doc-1"
assert msg.metadata.user == "alice"
assert msg.metadata.collection == "research"
assert msg.text == payload.encode("utf-8")

View file

@ -108,7 +108,6 @@ def sample_triples(sample_triple):
"""Sample Triples batch object"""
metadata = Metadata(
id="test-doc-123",
user="test_user",
collection="test_collection",
)
@ -123,7 +122,6 @@ def sample_chunk():
"""Sample text chunk for processing"""
metadata = Metadata(
id="test-chunk-456",
user="test_user",
collection="test_collection",
)

View file

@ -322,7 +322,6 @@ This is not JSON at all
assert isinstance(sent_triples, Triples)
# Check metadata fields individually since implementation creates new Metadata object
assert sent_triples.metadata.id == sample_metadata.id
assert sent_triples.metadata.user == sample_metadata.user
assert sent_triples.metadata.collection == sample_metadata.collection
assert len(sent_triples.triples) == 1
assert sent_triples.triples[0].s.iri == "test:subject"
@ -346,7 +345,6 @@ This is not JSON at all
assert isinstance(sent_contexts, EntityContexts)
# Check metadata fields individually since implementation creates new Metadata object
assert sent_contexts.metadata.id == sample_metadata.id
assert sent_contexts.metadata.user == sample_metadata.user
assert sent_contexts.metadata.collection == sample_metadata.collection
assert len(sent_contexts.entities) == 1
assert sent_contexts.entities[0].entity.iri == "test:entity"

View file

@ -311,8 +311,7 @@ class TestObjectExtractionBusinessLogic:
"""Test ExtractedObject creation and properties"""
# Arrange
metadata = Metadata(
id="test-extraction-001",
user="test_user",
id="test-extraction-001",
collection="test_collection",
)
@ -337,7 +336,6 @@ class TestObjectExtractionBusinessLogic:
assert extracted_obj.values[0]["customer_id"] == "CUST001"
assert extracted_obj.confidence == 0.95
assert "John Doe" in extracted_obj.source_span
assert extracted_obj.metadata.user == "test_user"
def test_config_parsing_error_handling(self):
"""Test configuration parsing with invalid JSON"""

View file

@ -371,7 +371,6 @@ class TestTripleConstructionLogic:
metadata = Metadata(
id="test-doc-123",
user="test_user",
collection="test_collection",
)
@ -384,7 +383,6 @@ class TestTripleConstructionLogic:
# Assert
assert isinstance(triples_batch, Triples)
assert triples_batch.metadata.id == "test-doc-123"
assert triples_batch.metadata.user == "test_user"
assert triples_batch.metadata.collection == "test_collection"
assert len(triples_batch.triples) == 2

View file

@ -17,7 +17,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test document embeddings
@ -80,7 +79,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for a single chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -89,7 +87,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify insert was called once for the single chunk with its vector
processor.vecstore.insert.assert_called_once_with(
@ -122,7 +120,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -131,7 +128,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no insert was called for empty chunk
processor.vecstore.insert.assert_not_called()
@ -141,7 +138,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with None chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -150,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Note: Implementation passes through None chunk_ids (only skips empty string "")
processor.vecstore.insert.assert_called_once_with(
@ -162,7 +158,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with mix of valid and empty chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_chunk = ChunkEmbeddings(
@ -179,7 +174,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [valid_chunk, empty_chunk, another_valid]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify valid chunks were inserted, empty string chunk was skipped
expected_calls = [
@ -200,11 +195,10 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@ -214,7 +208,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -223,7 +216,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
@ -233,7 +226,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each chunk has a single vector of different dimensions
@ -251,7 +243,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk1, chunk2, chunk3]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify all vectors were inserted regardless of dimension with user/collection parameters
expected_calls = [
@ -273,7 +265,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with Unicode content in chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -282,7 +273,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify Unicode chunk_id was stored correctly with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
@ -294,7 +285,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with long chunk_id"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a long chunk_id
@ -305,7 +295,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify long chunk_id was inserted with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
@ -317,7 +307,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with whitespace-only chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -326,7 +315,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify whitespace content was inserted (not filtered out) with user/collection parameters
processor.vecstore.insert.assert_called_once_with(
@ -343,25 +332,24 @@ class TestMilvusDocEmbeddingsStorageProcessor:
('test@domain.com', 'test-collection.v1'),
]
for user, collection in test_cases:
for workspace, collection in test_cases:
processor.vecstore.reset_mock() # Reset mock for each test case
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = user
message.metadata.collection = collection
chunk = ChunkEmbeddings(
chunk_id="Test content",
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify insert was called with the correct user/collection
await processor.store_document_embeddings(workspace, message)
# Verify insert was called with the correct workspace/collection
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Test content", user, collection
[0.1, 0.2, 0.3], "Test content", workspace, collection
)
@pytest.mark.asyncio
@ -370,7 +358,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Store embeddings for user1/collection1
message1 = MagicMock()
message1.metadata = MagicMock()
message1.metadata.user = 'user1'
message1.metadata.collection = 'collection1'
chunk1 = ChunkEmbeddings(
chunk_id="User1 content",
@ -381,7 +368,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Store embeddings for user2/collection2
message2 = MagicMock()
message2.metadata = MagicMock()
message2.metadata.user = 'user2'
message2.metadata.collection = 'collection2'
chunk2 = ChunkEmbeddings(
chunk_id="User2 content",
@ -389,8 +375,8 @@ class TestMilvusDocEmbeddingsStorageProcessor:
)
message2.chunks = [chunk2]
await processor.store_document_embeddings(message1)
await processor.store_document_embeddings(message2)
await processor.store_document_embeddings('user1', message1)
await processor.store_document_embeddings('user2', message2)
# Verify both calls were made with correct parameters
expected_calls = [
@ -411,18 +397,17 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with special characters in user/collection names"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'user@domain.com' # Email-like user
message.metadata.collection = 'test-collection.v1' # Collection with special chars
chunk = ChunkEmbeddings(
chunk_id="Special chars test",
vector=[0.1, 0.2, 0.3]
)
message.chunks = [chunk]
await processor.store_document_embeddings(message)
# Verify the exact user/collection strings are passed (sanitization happens in DocVectors)
await processor.store_document_embeddings('user@domain.com', message)
# Verify the exact workspace/collection strings are passed (sanitization happens in DocVectors)
processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Special chars test", 'user@domain.com', 'test-collection.v1'
)

View file

@ -21,7 +21,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test document embeddings
@ -120,7 +119,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for a single chunk"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -135,7 +133,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index name and operations (with dimension suffix)
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -185,7 +183,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that writing to non-existent index creates it lazily"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -200,7 +197,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index was created with correct dimension
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -217,7 +214,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -229,7 +225,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for empty chunk
mock_index.upsert.assert_not_called()
@ -239,7 +235,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with None chunk (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -251,7 +246,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for None chunk
mock_index.upsert.assert_not_called()
@ -261,7 +256,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with chunk that decodes to empty string"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -273,7 +267,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for empty decoded chunk
mock_index.upsert.assert_not_called()
@ -283,7 +277,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each chunk has a single vector of different dimensions
@ -325,14 +318,13 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunks list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.chunks = []
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no operations were performed
processor.pinecone.Index.assert_not_called()
@ -343,7 +335,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for chunk with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -355,7 +346,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called()
@ -365,7 +356,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that lazy creation happens when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -380,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index was created
processor.pinecone.create_index.assert_called_once()
@ -390,7 +380,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that lazy creation works correctly"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -405,7 +394,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify index was created and used
processor.pinecone.create_index.assert_called_once()
@ -416,7 +405,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with Unicode content"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings(
@ -430,7 +418,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify Unicode content was properly decoded and stored
call_args = mock_index.upsert.call_args
@ -442,7 +430,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with large document chunks"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a large document chunk
@ -458,7 +445,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message)
await processor.store_document_embeddings('test_user', message)
# Verify large content was stored
call_args = mock_index.upsert.call_args

View file

@ -84,7 +84,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with chunks and vectors
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_chunk = MagicMock()
@ -94,7 +93,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('test_user', mock_message)
# Assert
# Verify collection existence was checked (with dimension suffix)
@ -138,7 +137,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple chunks
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_chunk1 = MagicMock()
@ -152,7 +150,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('multi_user', mock_message)
# Assert
# Should be called twice (once per chunk)
@ -198,7 +196,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple chunks, each having a single vector
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_chunk1 = MagicMock()
@ -216,7 +213,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('vector_user', mock_message)
# Assert
# Should be called 3 times (once per chunk)
@ -255,7 +252,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with empty chunk_id
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_chunk_empty = MagicMock()
@ -265,7 +261,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk_empty]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('empty_user', mock_message)
# Assert
# Should not call upsert for empty chunk_ids
@ -298,7 +294,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_chunk = MagicMock()
@ -308,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('new_user', mock_message)
# Assert - collection should be lazily created
expected_collection = 'd_new_user_new_collection_5' # 5 dimensions
@ -350,7 +345,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'error_user'
mock_message.metadata.collection = 'error_collection'
mock_chunk = MagicMock()
@ -361,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Act & Assert - should propagate the creation error
with pytest.raises(Exception, match="Connection error"):
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('error_user', mock_message)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@ -388,7 +382,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create first mock message
mock_message1 = MagicMock()
mock_message1.metadata.user = 'cache_user'
mock_message1.metadata.collection = 'cache_collection'
mock_chunk1 = MagicMock()
@ -398,7 +391,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message1.chunks = [mock_chunk1]
# First call
await processor.store_document_embeddings(mock_message1)
await processor.store_document_embeddings('cache_user', mock_message1)
# Reset mock to track second call
mock_qdrant_instance.reset_mock()
@ -406,7 +399,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create second mock message with same dimensions
mock_message2 = MagicMock()
mock_message2.metadata.user = 'cache_user'
mock_message2.metadata.collection = 'cache_collection'
mock_chunk2 = MagicMock()
@ -416,7 +408,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message2.chunks = [mock_chunk2]
# Act - Second call with same collection
await processor.store_document_embeddings(mock_message2)
await processor.store_document_embeddings('cache_user', mock_message2)
# Assert
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
@ -452,7 +444,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with chunks of different dimensions
mock_message = MagicMock()
mock_message.metadata.user = 'dim_user'
mock_message.metadata.collection = 'dim_collection'
mock_chunk1 = MagicMock()
@ -466,7 +457,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('dim_user', mock_message)
# Assert
# Should check existence of DIFFERENT collections for each dimension
@ -526,7 +517,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with URI-style chunk_id
mock_message = MagicMock()
mock_message.metadata.user = 'uri_user'
mock_message.metadata.collection = 'uri_collection'
mock_chunk = MagicMock()
@ -536,7 +526,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk]
# Act
await processor.store_document_embeddings(mock_message)
await processor.store_document_embeddings('uri_user', mock_message)
# Assert
# Verify the chunk_id was stored correctly

View file

@ -17,7 +17,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entities with embeddings
@ -80,7 +79,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -89,7 +87,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify insert was called once with the full vector
processor.vecstore.insert.assert_called_once()
@ -125,7 +123,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -134,7 +131,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called for empty entity
processor.vecstore.insert.assert_not_called()
@ -144,7 +141,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -153,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called for None entity
processor.vecstore.insert.assert_not_called()
@ -163,7 +159,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with mix of valid and invalid entities"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
valid_entity = EntityEmbeddings(
@ -183,7 +178,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [valid_entity, empty_entity, none_entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify only valid entity was inserted with user/collection/chunk_id parameters
processor.vecstore.insert.assert_called_once_with(
@ -196,11 +191,10 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called
processor.vecstore.insert.assert_not_called()
@ -210,7 +204,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for entity with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -219,7 +212,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called()
@ -229,7 +222,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each entity has a single vector of different dimensions
@ -247,7 +239,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [entity1, entity2, entity3]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify all vectors were inserted regardless of dimension
expected_calls = [
@ -267,7 +259,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for both URI and literal entities"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
uri_entity = EntityEmbeddings(
@ -280,7 +271,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
)
message.entities = [uri_entity, literal_entity]
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify both entities were inserted
expected_calls = [

View file

@ -21,7 +21,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create test entity embeddings (each entity has a single vector)
@ -124,7 +123,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for a single entity"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -139,7 +137,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1']):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index name and operations (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -189,7 +187,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that writing to non-existent index creates it lazily"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -204,7 +201,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index was created with correct dimension
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -221,7 +218,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -233,7 +229,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called for empty entity
mock_index.upsert.assert_not_called()
@ -243,7 +239,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -255,7 +250,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called for None entity
mock_index.upsert.assert_not_called()
@ -265,7 +260,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with different vector dimensions"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Each entity has a single vector of different dimensions
@ -288,7 +282,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify different indexes were used for different dimensions
index_calls = processor.pinecone.Index.call_args_list
@ -307,14 +301,13 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entities list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.entities = []
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no operations were performed
processor.pinecone.Index.assert_not_called()
@ -325,7 +318,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for entity with no vectors"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -337,7 +329,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called()
@ -347,7 +339,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that lazy creation happens when index doesn't exist"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -362,7 +353,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index was created
processor.pinecone.create_index.assert_called_once()
@ -372,7 +363,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that lazy creation works correctly"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
entity = EntityEmbeddings(
@ -387,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message)
await processor.store_graph_embeddings('test_user', message)
# Verify index was created and used
processor.pinecone.create_index.assert_called_once()

View file

@ -64,7 +64,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with entities and vectors
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_entity = MagicMock()
@ -75,7 +74,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('test_user', mock_message)
# Assert
# Verify collection existence was checked (with dimension suffix)
@ -118,7 +117,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple entities
mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection'
mock_entity1 = MagicMock()
@ -134,7 +132,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity1, mock_entity2]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('multi_user', mock_message)
# Assert
# Should be called twice (once per entity)
@ -179,7 +177,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with three entities
mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection'
mock_entity1 = MagicMock()
@ -200,7 +197,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('vector_user', mock_message)
# Assert
# Should be called 3 times (once per entity)
@ -238,7 +235,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with empty entity value
mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection'
mock_entity_empty = MagicMock()
@ -253,7 +249,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity_empty, mock_entity_none]
# Act
await processor.store_graph_embeddings(mock_message)
await processor.store_graph_embeddings('empty_user', mock_message)
# Assert
# Should not call upsert for empty entities

View file

@ -1,5 +1,5 @@
"""
Tests for Memgraph user/collection isolation in storage service
Tests for Memgraph workspace/collection isolation in storage service.
"""
import pytest
@ -8,47 +8,45 @@ from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.memgraph.write import Processor
class TestMemgraphUserCollectionIsolation:
"""Test cases for Memgraph storage service with user/collection isolation"""
class TestMemgraphWorkspaceCollectionIsolation:
"""Test cases for Memgraph storage service with workspace/collection isolation"""
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db):
"""Test that storage creates both legacy and user/collection indexes"""
def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
"""Test that storage creates both legacy and workspace/collection indexes"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = Processor(taskgroup=MagicMock())
# Verify all indexes were attempted (4 legacy + 4 user/collection = 8 total)
# 4 legacy + 4 workspace/collection = 8 total
assert mock_session.run.call_count == 8
# Check some specific index creation calls
expected_calls = [
"CREATE INDEX ON :Node",
"CREATE INDEX ON :Node(uri)",
"CREATE INDEX ON :Literal",
"CREATE INDEX ON :Literal(value)",
"CREATE INDEX ON :Node(user)",
"CREATE INDEX ON :Node(workspace)",
"CREATE INDEX ON :Node(collection)",
"CREATE INDEX ON :Literal(user)",
"CREATE INDEX ON :Literal(workspace)",
"CREATE INDEX ON :Literal(collection)"
]
for expected_call in expected_calls:
mock_session.run.assert_any_call(expected_call)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_user_collection(self, mock_graph_db):
"""Test that store_triples includes user/collection in all operations"""
async def test_store_triples_with_workspace_collection(self, mock_graph_db):
"""Test that store_triples includes workspace/collection in all operations"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
@ -58,45 +56,39 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
# Create mock triple with URI object
from trustgraph.schema import IRI
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "http://example.com/object"
triple.o.is_uri = True
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = IRI
triple.o.iri = "http://example.com/object"
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify user/collection parameters were passed to all operations
# Should have: create_node (subject), create_node (object), relate_node = 3 calls
# create_node (subject), create_node (object), relate_node = 3 calls
assert mock_driver.execute_query.call_count == 3
# Check that user and collection were included in all calls
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert 'user' in call_kwargs
assert 'collection' in call_kwargs
assert call_kwargs['user'] == "test_user"
assert call_kwargs['collection'] == "test_collection"
for c in mock_driver.execute_query.call_args_list:
kwargs = c.kwargs
assert kwargs['workspace'] == "test_workspace"
assert kwargs['collection'] == "test_collection"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_default_user_collection(self, mock_graph_db):
"""Test that defaults are used when user/collection not provided in metadata"""
async def test_store_triples_with_default_collection(self, mock_graph_db):
"""Test that default collection is used when not provided in metadata"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
@ -106,157 +98,151 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock())
# Create mock triple
from trustgraph.schema import IRI, LITERAL
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = LITERAL
triple.o.value = "literal_value"
triple.o.is_uri = False
# Create mock message without user/collection metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = None
mock_message.metadata.collection = None
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("default", mock_message)
# Verify defaults were used
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert call_kwargs['user'] == "default"
assert call_kwargs['collection'] == "default"
for c in mock_driver.execute_query.call_args_list:
kwargs = c.kwargs
assert kwargs['workspace'] == "default"
assert kwargs['collection'] == "default"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_node_includes_user_collection(self, mock_graph_db):
"""Test that create_node includes user/collection properties"""
def test_create_node_includes_workspace_collection(self, mock_graph_db):
"""Test that create_node includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.create_node("http://example.com/node", "test_user", "test_collection")
processor.create_node("http://example.com/node", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/node",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_literal_includes_user_collection(self, mock_graph_db):
"""Test that create_literal includes user/collection properties"""
def test_create_literal_includes_workspace_collection(self, mock_graph_db):
"""Test that create_literal includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.create_literal("test_value", "test_user", "test_collection")
processor.create_literal("test_value", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="test_value",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_relate_node_includes_user_collection(self, mock_graph_db):
"""Test that relate_node includes user/collection properties"""
def test_relate_node_includes_workspace_collection(self, mock_graph_db):
"""Test that relate_node includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.relate_node(
"http://example.com/subject",
"http://example.com/predicate",
"http://example.com/predicate",
"http://example.com/object",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="http://example.com/object",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_relate_literal_includes_user_collection(self, mock_graph_db):
"""Test that relate_literal includes user/collection properties"""
def test_relate_literal_includes_workspace_collection(self, mock_graph_db):
"""Test that relate_literal includes workspace/collection properties"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
processor.relate_literal(
"http://example.com/subject",
"http://example.com/predicate",
"literal_value",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="literal_value",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="memgraph"
)
@ -264,20 +250,15 @@ class TestMemgraphUserCollectionIsolation:
def test_add_args_includes_memgraph_parameters(self):
"""Test that add_args properly configures Memgraph-specific parameters"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added with Memgraph defaults
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert args.graph_host == 'bolt://memgraph:7687'
assert hasattr(args, 'username')
@ -288,19 +269,18 @@ class TestMemgraphUserCollectionIsolation:
assert args.database == 'memgraph'
class TestMemgraphUserCollectionRegression:
"""Regression tests to ensure user/collection isolation prevents data leakage"""
class TestMemgraphWorkspaceCollectionRegression:
"""Regression tests to ensure workspace/collection isolation prevents data leakage"""
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_no_cross_user_data_access(self, mock_graph_db):
"""Regression test: Ensure users cannot access each other's data"""
async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
"""Regression test: Ensure workspaces cannot access each other's data"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
@ -310,60 +290,55 @@ class TestMemgraphUserCollectionRegression:
processor = Processor(taskgroup=MagicMock())
# Store data for user1
from trustgraph.schema import IRI, LITERAL
triple = MagicMock()
triple.s.value = "http://example.com/subject"
triple.p.value = "http://example.com/predicate"
triple.o.value = "user1_data"
triple.o.is_uri = False
triple.s.type = IRI
triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = LITERAL
triple.o.value = "ws1_data"
message_user1 = MagicMock()
message_user1.triples = [triple]
message_user1.metadata.user = "user1"
message_user1.metadata.collection = "collection1"
message_ws1 = MagicMock()
message_ws1.triples = [triple]
message_ws1.metadata.collection = "collection1"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message_user1)
await processor.store_triples("workspace1", message_ws1)
# Verify that all storage operations included user1/collection1 parameters
for call in mock_driver.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
if 'user' in call_kwargs:
assert call_kwargs['user'] == "user1"
assert call_kwargs['collection'] == "collection1"
for c in mock_driver.execute_query.call_args_list:
kwargs = c.kwargs
if 'workspace' in kwargs:
assert kwargs['workspace'] == "workspace1"
assert kwargs['collection'] == "collection1"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_same_uri_different_users(self, mock_graph_db):
"""Regression test: Same URI can exist for different users without conflict"""
async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
"""Regression test: Same URI can exist in different workspaces without conflict"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock()
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_result.summary = mock_summary
mock_driver.execute_query.return_value = mock_result
processor = Processor(taskgroup=MagicMock())
# Same URI for different users should create separate nodes
processor.create_node("http://example.com/same-uri", "user1", "collection1")
processor.create_node("http://example.com/same-uri", "user2", "collection2")
# Verify both calls were made with different user/collection parameters
calls = mock_driver.execute_query.call_args_list[-2:] # Get last 2 calls
call1_kwargs = calls[0].kwargs if hasattr(calls[0], 'kwargs') else calls[0][1]
call2_kwargs = calls[1].kwargs if hasattr(calls[1], 'kwargs') else calls[1][1]
assert call1_kwargs['user'] == "user1" and call1_kwargs['collection'] == "collection1"
assert call2_kwargs['user'] == "user2" and call2_kwargs['collection'] == "collection2"
# Both should have the same URI but different user/collection
assert call1_kwargs['uri'] == call2_kwargs['uri'] == "http://example.com/same-uri"
processor.create_node("http://example.com/same-uri", "workspace1", "collection1")
processor.create_node("http://example.com/same-uri", "workspace2", "collection2")
calls = mock_driver.execute_query.call_args_list[-2:]
k1 = calls[0].kwargs
k2 = calls[1].kwargs
assert k1['workspace'] == "workspace1" and k1['collection'] == "collection1"
assert k2['workspace'] == "workspace2" and k2['collection'] == "collection2"
assert k1['uri'] == k2['uri'] == "http://example.com/same-uri"

View file

@ -1,5 +1,5 @@
"""
Tests for Neo4j user/collection isolation in triples storage and query
Tests for Neo4j workspace/collection isolation in triples storage and query.
"""
import pytest
@ -11,468 +11,406 @@ from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL
from trustgraph.schema import TriplesQueryRequest
class TestNeo4jUserCollectionIsolation:
"""Test cases for Neo4j user/collection isolation functionality"""
class TestNeo4jWorkspaceCollectionIsolation:
"""Test cases for Neo4j workspace/collection isolation functionality"""
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db):
"""Test that storage service creates compound indexes for user/collection"""
def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
"""Test that storage service creates compound indexes for workspace/collection"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Verify both legacy and new compound indexes are created
expected_indexes = [
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
]
# Check that all expected indexes were created
for expected_query in expected_indexes:
mock_session.run.assert_any_call(expected_query)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_user_collection(self, mock_graph_db):
"""Test that triples are stored with user/collection properties"""
async def test_store_triples_with_workspace_collection(self, mock_graph_db):
"""Test that triples are stored with workspace/collection properties"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create test message with user/collection metadata
metadata = Metadata(
id="test-id",
user="test_user",
collection="test_collection"
)
metadata = Metadata(id="test-id", collection="test_collection")
triple = Triple(
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="literal_value")
)
message = Triples(
metadata=metadata,
triples=[triple]
)
# Mock execute_query to return summaries
message = Triples(metadata=metadata, triples=[triple])
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify nodes and relationships were created with user/collection properties
await processor.store_triples("test_workspace", message)
expected_calls = [
call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_='neo4j'
),
call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="literal_value",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_='neo4j'
),
call(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="literal_value",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_='neo4j'
)
]
for expected_call in expected_calls:
mock_driver.execute_query.assert_any_call(*expected_call.args, **expected_call.kwargs)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_store_triples_with_default_user_collection(self, mock_graph_db):
"""Test that default user/collection are used when not provided"""
async def test_store_triples_with_default_collection(self, mock_graph_db):
"""Test that default collection is used when not provided"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create test message without user/collection
metadata = Metadata(id="test-id")
triple = Triple(
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=IRI, iri="http://example.com/object")
)
message = Triples(
metadata=metadata,
triples=[triple]
)
# Mock execute_query
message = Triples(metadata=metadata, triples=[triple])
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
# Verify defaults were used
await processor.store_triples("default", message)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject",
user="default",
workspace="default",
collection="default",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_filters_by_user_collection(self, mock_graph_db):
"""Test that query service filters results by user/collection"""
async def test_query_triples_filters_by_workspace_collection(self, mock_graph_db):
"""Test that query service filters results by workspace/collection"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# Create test query
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=None
)
# Mock query results
mock_records = [
MagicMock(data=lambda: {"dest": "http://example.com/object1"}),
MagicMock(data=lambda: {"dest": "literal_value"})
]
mock_driver.execute_query.return_value = (mock_records, MagicMock(), MagicMock())
result = await processor.query_triples(query)
# Verify queries include user/collection filters
await processor.query_triples("test_workspace", query)
expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest"
)
expected_node_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN dest.uri as dest"
)
# Check that queries were executed with user/collection parameters
calls = mock_driver.execute_query.call_args_list
assert any(
expected_literal_query in str(call) and
"user='test_user'" in str(call) and
"collection='test_collection'" in str(call)
for call in calls
expected_literal_query in str(c) and
"workspace='test_workspace'" in str(c) and
"collection='test_collection'" in str(c)
for c in calls
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_query_triples_with_default_user_collection(self, mock_graph_db):
"""Test that query service uses defaults when user/collection not provided"""
async def test_query_triples_with_default_collection(self, mock_graph_db):
"""Test that query service uses default collection when not provided"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# Create test query without user/collection
query = TriplesQueryRequest(
s=None,
p=None,
o=None
)
# Mock empty results
query = TriplesQueryRequest(s=None, p=None, o=None)
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
result = await processor.query_triples(query)
# Verify defaults were used in queries
await processor.query_triples("default", query)
calls = mock_driver.execute_query.call_args_list
assert any(
"user='default'" in str(call) and "collection='default'" in str(call)
for call in calls
"workspace='default'" in str(c) and "collection='default'" in str(c)
for c in calls
)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_data_isolation_between_users(self, mock_graph_db):
"""Test that data from different users is properly isolated"""
async def test_data_isolation_between_workspaces(self, mock_graph_db):
"""Test that data from different workspaces is properly isolated"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create messages for different users
message_user1 = Triples(
metadata=Metadata(user="user1", collection="coll1"),
message_ws1 = Triples(
metadata=Metadata(collection="coll1"),
triples=[
Triple(
s=Term(type=IRI, iri="http://example.com/user1/subject"),
s=Term(type=IRI, iri="http://example.com/ws1/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="user1_data")
o=Term(type=LITERAL, value="ws1_data")
)
]
)
message_user2 = Triples(
metadata=Metadata(user="user2", collection="coll2"),
message_ws2 = Triples(
metadata=Metadata(collection="coll2"),
triples=[
Triple(
s=Term(type=IRI, iri="http://example.com/user2/subject"),
s=Term(type=IRI, iri="http://example.com/ws2/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="user2_data")
o=Term(type=LITERAL, value="ws2_data")
)
]
)
# Mock execute_query
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
# Store data for both users
await processor.store_triples(message_user1)
await processor.store_triples(message_user2)
# Verify user1 data was stored with user1/coll1
await processor.store_triples("workspace1", message_ws1)
await processor.store_triples("workspace2", message_ws2)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value="user1_data",
user="user1",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="ws1_data",
workspace="workspace1",
collection="coll1",
database_='neo4j'
)
# Verify user2 data was stored with user2/coll2
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value="user2_data",
user="user2",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="ws2_data",
workspace="workspace2",
collection="coll2",
database_='neo4j'
)
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_wildcard_query_respects_user_collection(self, mock_graph_db):
"""Test that wildcard queries still filter by user/collection"""
async def test_wildcard_query_respects_workspace_collection(self, mock_graph_db):
"""Test that wildcard queries still filter by workspace/collection"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# Create wildcard query (all nulls) with user/collection
query = TriplesQueryRequest(
user="test_user",
collection="test_collection",
s=None,
p=None,
o=None
s=None, p=None, o=None,
)
# Mock results
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
result = await processor.query_triples(query)
# Verify wildcard queries include user/collection filters
await processor.query_triples("test_workspace", query)
wildcard_query = (
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest"
)
calls = mock_driver.execute_query.call_args_list
assert any(
wildcard_query in str(call) and
"user='test_user'" in str(call) and
"collection='test_collection'" in str(call)
for call in calls
wildcard_query in str(c) and
"workspace='test_workspace'" in str(c) and
"collection='test_collection'" in str(c)
for c in calls
)
def test_add_args_includes_neo4j_parameters(self):
"""Test that add_args includes Neo4j-specific parameters"""
from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser()
with patch('trustgraph.storage.triples.neo4j.write.TriplesStoreService.add_args'):
StorageProcessor.add_args(parser)
args = parser.parse_args([])
assert hasattr(args, 'graph_host')
assert hasattr(args, 'username')
assert hasattr(args, 'password')
assert hasattr(args, 'database')
# Check defaults
assert args.graph_host == 'bolt://neo4j:7687'
assert args.username == 'neo4j'
assert args.password == 'password'
assert args.database == 'neo4j'
class TestNeo4jUserCollectionRegression:
"""Regression tests to ensure user/collection isolation prevents data leaks"""
class TestNeo4jWorkspaceCollectionRegression:
"""Regression tests to ensure workspace/collection isolation prevents data leaks"""
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_no_cross_user_data_access(self, mock_graph_db):
@pytest.mark.asyncio
async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
"""
Regression test: Ensure user1 cannot access user2's data
This test guards against the bug where all users shared the same
Neo4j graph space, causing data contamination between users.
Regression test: Ensure workspace1 cannot access workspace2's data.
Guards against a bug where all data shared the same Neo4j graph
space, causing data contamination between workspaces.
"""
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock())
# User1 queries for all triples
query_user1 = TriplesQueryRequest(
user="user1",
query_ws1 = TriplesQueryRequest(
collection="collection1",
s=None, p=None, o=None
)
# Mock that the database has data but none matching user1/collection1
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
result = await processor.query_triples(query_user1)
# Verify empty results (user1 cannot see other users' data)
result = await processor.query_triples("workspace1", query_ws1)
assert len(result) == 0
# Verify the query included user/collection filters
calls = mock_driver.execute_query.call_args_list
for call in calls:
query_str = str(call)
for c in calls:
query_str = str(c)
if "MATCH" in query_str:
assert "user: $user" in query_str or "user='user1'" in query_str
assert "workspace: $workspace" in query_str or "workspace='workspace1'" in query_str
assert "collection: $collection" in query_str or "collection='collection1'" in query_str
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio
async def test_regression_same_uri_different_users(self, mock_graph_db):
async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
"""
Regression test: Same URI in different user contexts should create separate nodes
This ensures that http://example.com/entity for user1 is completely separate
from http://example.com/entity for user2.
Regression test: Same URI in different workspace contexts should create separate nodes.
Ensures http://example.com/entity in workspace1 is completely
separate from the same URI in workspace2.
"""
taskgroup_mock = MagicMock()
mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session
processor = StorageProcessor(taskgroup=taskgroup_mock)
# Same URI for different users
shared_uri = "http://example.com/shared_entity"
message_user1 = Triples(
metadata=Metadata(user="user1", collection="coll1"),
message_ws1 = Triples(
metadata=Metadata(collection="coll1"),
triples=[
Triple(
s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="user1_value")
o=Term(type=LITERAL, value="ws1_value")
)
]
)
message_user2 = Triples(
metadata=Metadata(user="user2", collection="coll2"),
message_ws2 = Triples(
metadata=Metadata(collection="coll2"),
triples=[
Triple(
s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="user2_value")
o=Term(type=LITERAL, value="ws2_value")
)
]
)
# Mock execute_query
mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message_user1)
await processor.store_triples(message_user2)
# Verify two separate nodes were created with same URI but different user/collection
user1_node_call = call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
await processor.store_triples("workspace1", message_ws1)
await processor.store_triples("workspace2", message_ws2)
ws1_node_call = call(
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=shared_uri,
user="user1",
workspace="workspace1",
collection="coll1",
database_='neo4j'
)
user2_node_call = call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
ws2_node_call = call(
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=shared_uri,
user="user2",
workspace="workspace2",
collection="coll2",
database_='neo4j'
)
mock_driver.execute_query.assert_has_calls([user1_node_call, user2_node_call], any_order=True)
mock_driver.execute_query.assert_has_calls([ws1_node_call, ws2_node_call], any_order=True)

View file

@ -1,3 +1,12 @@
def _flow_mock(workspace):
"""Build a mock flow object that is callable and exposes .workspace."""
from unittest.mock import MagicMock
f = MagicMock()
f.workspace = workspace
return f
"""
Unit tests for trustgraph.storage.row_embeddings.qdrant.write
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
@ -92,13 +101,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
collection_name = processor.get_collection_name(
user="test_user",
workspace="test_workspace",
collection="test_collection",
schema_name="customer_data",
dimension=384
)
assert collection_name == "rows_test_user_test_collection_customer_data_384"
assert collection_name == "rows_test_workspace_test_collection_customer_data_384"
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_creates_new(self, mock_qdrant_client):
@ -185,11 +194,10 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('test_workspace', 'test_collection')] = {}
# Create embeddings message
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
@ -210,14 +218,14 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Verify upsert was called
mock_qdrant_instance.upsert.assert_called_once()
# Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3'
assert upsert_call_args[1]['collection_name'] == 'rows_test_workspace_test_collection_customers_3'
point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3]
@ -243,10 +251,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('test_workspace', 'test_collection')] = {}
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
@ -267,7 +274,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should be called once for the single embedding
assert mock_qdrant_instance.upsert.call_count == 1
@ -287,10 +294,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {}
processor.known_collections[('test_workspace', 'test_collection')] = {}
metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection'
metadata.id = 'doc-123'
@ -311,7 +317,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should not call upsert for empty vectors
mock_qdrant_instance.upsert.assert_not_called()
@ -334,7 +340,6 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
# No collections registered
metadata = MagicMock()
metadata.user = 'unknown_user'
metadata.collection = 'unknown_collection'
metadata.id = 'doc-123'
@ -354,7 +359,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock())
await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should not call upsert for unknown collection
mock_qdrant_instance.upsert.assert_not_called()
@ -368,11 +373,11 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
# Mock collections list
mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_schema1_384'
mock_coll1.name = 'rows_test_workspace_test_collection_schema1_384'
mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_schema2_384'
mock_coll2.name = 'rows_test_workspace_test_collection_schema2_384'
mock_coll3 = MagicMock()
mock_coll3.name = 'rows_other_user_other_collection_schema_384'
mock_coll3.name = 'rows_other_workspace_other_collection_schema_384'
mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
@ -386,15 +391,15 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
}
processor = Processor(**config)
processor.created_collections.add('rows_test_user_test_collection_schema1_384')
processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
await processor.delete_collection('test_user', 'test_collection')
await processor.delete_collection('test_workspace', 'test_collection')
# Should delete only the matching collections
assert mock_qdrant_instance.delete_collection.call_count == 2
# Verify the cached collection was removed
assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections
assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client):
@ -404,9 +409,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_qdrant_instance = MagicMock()
mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_customers_384'
mock_coll1.name = 'rows_test_workspace_test_collection_customers_384'
mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_orders_384'
mock_coll2.name = 'rows_test_workspace_test_collection_orders_384'
mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2]
@ -422,13 +427,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config)
await processor.delete_collection_schema(
'test_user', 'test_collection', 'customers'
'test_workspace', 'test_collection', 'customers'
)
# Should only delete the customers schema collection
mock_qdrant_instance.delete_collection.assert_called_once()
call_args = mock_qdrant_instance.delete_collection.call_args[0]
assert call_args[0] == 'rows_test_user_test_collection_customers_384'
assert call_args[0] == 'rows_test_workspace_test_collection_customers_384'
if __name__ == '__main__':

View file

@ -187,7 +187,7 @@ class TestRowsCassandraStorageLogic:
)
}
}
processor.tables_initialized = {"test_user"}
processor.tables_initialized = {"default"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -204,7 +204,6 @@ class TestRowsCassandraStorageLogic:
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
),
schema_name="test_schema",
@ -227,7 +226,7 @@ class TestRowsCassandraStorageLogic:
values = insert_call[0][2]
# Verify using unified rows table
assert "INSERT INTO test_user.rows" in insert_cql
assert "INSERT INTO default.rows" in insert_cql
# Values should be: (collection, schema_name, index_name, index_value, data, source)
assert values[0] == "test_collection" # collection
@ -254,7 +253,7 @@ class TestRowsCassandraStorageLogic:
)
}
}
processor.tables_initialized = {"test_user"}
processor.tables_initialized = {"default"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -270,7 +269,6 @@ class TestRowsCassandraStorageLogic:
test_obj = ExtractedObject(
metadata=Metadata(
id="test-001",
user="test_user",
collection="test_collection",
),
schema_name="multi_index_schema",
@ -315,7 +313,7 @@ class TestRowsCassandraStorageBatchLogic:
)
}
}
processor.tables_initialized = {"test_user"}
processor.tables_initialized = {"default"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -332,7 +330,6 @@ class TestRowsCassandraStorageBatchLogic:
batch_obj = ExtractedObject(
metadata=Metadata(
id="batch-001",
user="test_user",
collection="batch_collection",
),
schema_name="batch_schema",
@ -373,7 +370,7 @@ class TestRowsCassandraStorageBatchLogic:
)
}
}
processor.tables_initialized = {"test_user"}
processor.tables_initialized = {"default"}
processor.registered_partitions = set()
processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -388,7 +385,6 @@ class TestRowsCassandraStorageBatchLogic:
empty_batch_obj = ExtractedObject(
metadata=Metadata(
id="empty-001",
user="test_user",
collection="empty_collection",
),
schema_name="empty_schema",
@ -446,7 +442,7 @@ class TestUnifiedTableStructure:
def test_ensure_tables_idempotent(self):
"""Test that ensure_tables is idempotent"""
processor = MagicMock()
processor.tables_initialized = {"test_user"} # Already initialized
processor.tables_initialized = {"default"} # Already initialized
processor.session = MagicMock()
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)

View file

@ -102,11 +102,10 @@ class TestCassandraStorageProcessor:
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify KnowledgeGraph was called with auth parameters
mock_kg_class.assert_called_once_with(
@ -129,11 +128,10 @@ class TestCassandraStorageProcessor:
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user2'
mock_message.metadata.collection = 'collection2'
mock_message.triples = []
await processor.store_triples(mock_message)
await processor.store_triples('user2', mock_message)
# Verify KnowledgeGraph was called without auth parameters
mock_kg_class.assert_called_once_with(
@ -154,16 +152,15 @@ class TestCassandraStorageProcessor:
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
# First call should create TrustGraph
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
assert mock_kg_class.call_count == 1
# Second call with same table should reuse TrustGraph
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
assert mock_kg_class.call_count == 1 # Should not increase
@pytest.mark.asyncio
@ -205,11 +202,10 @@ class TestCassandraStorageProcessor:
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = [triple1, triple2]
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
assert mock_tg_instance.insert.call_count == 2
@ -234,11 +230,10 @@ class TestCassandraStorageProcessor:
# Create mock message with empty triples
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify no triples were inserted
mock_tg_instance.insert.assert_not_called()
@ -255,12 +250,11 @@ class TestCassandraStorageProcessor:
# Create mock message
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify sleep was called before re-raising
mock_sleep.assert_called_once_with(1)
@ -361,21 +355,19 @@ class TestCassandraStorageProcessor:
# First message with table1
mock_message1 = MagicMock()
mock_message1.metadata.user = 'user1'
mock_message1.metadata.collection = 'collection1'
mock_message1.triples = []
await processor.store_triples(mock_message1)
await processor.store_triples('user1', mock_message1)
assert processor.table == 'user1'
assert processor.tg == mock_tg_instance1
# Second message with different table
mock_message2 = MagicMock()
mock_message2.metadata.user = 'user2'
mock_message2.metadata.collection = 'collection2'
mock_message2.triples = []
await processor.store_triples(mock_message2)
await processor.store_triples('user2', mock_message2)
assert processor.table == 'user2'
assert processor.tg == mock_tg_instance2
@ -407,11 +399,10 @@ class TestCassandraStorageProcessor:
triple.g = None
mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection'
mock_message.triples = [triple]
await processor.store_triples(mock_message)
await processor.store_triples('test_workspace', mock_message)
# Verify the triple was inserted with special characters preserved
mock_tg_instance.insert.assert_called_once_with(
@ -440,12 +431,11 @@ class TestCassandraStorageProcessor:
mock_kg_class.side_effect = Exception("Connection failed")
mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection'
mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message)
await processor.store_triples('new_user', mock_message)
# Table should remain unchanged since self.table = table happens after try/except
assert processor.table == ('old_user', 'old_collection')
@ -468,11 +458,10 @@ class TestCassandraPerformanceOptimizations:
processor = Processor(taskgroup=taskgroup_mock)
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify KnowledgeGraph instance uses legacy mode
assert mock_tg_instance is not None
@ -489,11 +478,10 @@ class TestCassandraPerformanceOptimizations:
processor = Processor(taskgroup=taskgroup_mock)
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = []
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify KnowledgeGraph instance is in optimized mode
assert mock_tg_instance is not None
@ -523,11 +511,10 @@ class TestCassandraPerformanceOptimizations:
triple.g = None
mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1'
mock_message.triples = [triple]
await processor.store_triples(mock_message)
await processor.store_triples('user1', mock_message)
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
mock_tg_instance.insert.assert_called_once_with(

View file

@ -17,7 +17,6 @@ class TestFalkorDBStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a test triple
@ -89,13 +88,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.create_node(test_uri, 'test_user', 'test_collection')
processor.create_node(test_uri, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
params={
"uri": test_uri,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)
@ -109,13 +108,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.create_literal(test_value, 'test_user', 'test_collection')
processor.create_literal(test_value, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
params={
"value": test_value,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)
@ -132,17 +131,17 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri, 'test_user', 'test_collection')
processor.relate_node(src_uri, pred_uri, dest_uri, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={
"src": src_uri,
"dest": dest_uri,
"uri": pred_uri,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)
@ -159,17 +158,17 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value, 'test_user', 'test_collection')
processor.relate_literal(src_uri, pred_uri, literal_value, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={
"src": src_uri,
"dest": literal_value,
"uri": pred_uri,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)
@ -179,7 +178,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing triple with URI object"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple = Triple(
@ -200,21 +198,21 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create object node
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
{"params": {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection"}}),
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
{"params": {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create relationship
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
(("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
]
assert processor.io.query.call_count == 3
@ -242,16 +240,16 @@ class TestFalkorDBStorageProcessor:
# Verify queries were called in the correct order
expected_calls = [
# Create subject node
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}),
(("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create literal object
(("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",),
{"params": {"value": "literal object", "user": "test_user", "collection": "test_collection"}}),
(("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",),
{"params": {"value": "literal object", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create relationship
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}),
(("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
]
assert processor.io.query.call_count == 3
@ -265,7 +263,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing multiple triples"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
@ -291,7 +288,7 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
@ -313,7 +310,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing empty triples list"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.triples = []
@ -323,7 +319,7 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify no queries were made
processor.io.query.assert_not_called()
@ -333,7 +329,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing triples with mixed URI and literal objects"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
@ -359,7 +354,7 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6
@ -450,13 +445,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.create_node(test_uri, 'test_user', 'test_collection')
processor.create_node(test_uri, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
params={
"uri": test_uri,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)
@ -470,13 +465,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result
processor.create_literal(test_value, 'test_user', 'test_collection')
processor.create_literal(test_value, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
params={
"value": test_value,
"user": 'test_user',
"workspace": 'test_workspace',
"collection": 'test_collection',
},
)

View file

@ -17,7 +17,6 @@ class TestMemgraphStorageProcessor:
"""Create a mock message for testing"""
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
# Create a test triple
@ -43,7 +42,7 @@ class TestMemgraphStorageProcessor:
taskgroup=MagicMock(),
id='test-memgraph-storage',
graph_host='bolt://localhost:7687',
username='test_user',
username='test_workspace',
password='test_pass',
database='test_db'
)
@ -105,9 +104,9 @@ class TestMemgraphStorageProcessor:
"CREATE INDEX ON :Node(uri)",
"CREATE INDEX ON :Literal",
"CREATE INDEX ON :Literal(value)",
"CREATE INDEX ON :Node(user)",
"CREATE INDEX ON :Node(workspace)",
"CREATE INDEX ON :Node(collection)",
"CREATE INDEX ON :Literal(user)",
"CREATE INDEX ON :Literal(workspace)",
"CREATE INDEX ON :Literal(collection)"
]
@ -145,12 +144,12 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result
processor.create_node(test_uri, "test_user", "test_collection")
processor.create_node(test_uri, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=test_uri,
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_=processor.db
)
@ -166,12 +165,12 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result
processor.create_literal(test_value, "test_user", "test_collection")
processor.create_literal(test_value, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=test_value,
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_=processor.db
)
@ -190,14 +189,14 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri, "test_user", "test_collection")
processor.relate_node(src_uri, pred_uri, dest_uri, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src_uri, dest=dest_uri, uri=pred_uri,
user="test_user", collection="test_collection",
workspace="test_workspace", collection="test_collection",
database_=processor.db
)
@ -215,14 +214,14 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value, "test_user", "test_collection")
processor.relate_literal(src_uri, pred_uri, literal_value, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src_uri, dest=literal_value, uri=pred_uri,
user="test_user", collection="test_collection",
workspace="test_workspace", collection="test_collection",
database_=processor.db
)
@ -236,22 +235,22 @@ class TestMemgraphStorageProcessor:
o=Term(type=IRI, iri='http://example.com/object')
)
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
processor.create_triple(mock_tx, triple, "test_workspace", "test_collection")
# Verify transaction calls
expected_calls = [
# Create subject node
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}),
("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create object node
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{'uri': 'http://example.com/object', 'user': 'test_user', 'collection': 'test_collection'}),
("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{'uri': 'http://example.com/object', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create relationship
("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate',
'user': 'test_user', 'collection': 'test_collection'})
'workspace': 'test_workspace', 'collection': 'test_collection'})
]
assert mock_tx.run.call_count == 3
@ -270,22 +269,22 @@ class TestMemgraphStorageProcessor:
o=Term(type=LITERAL, value='literal object')
)
processor.create_triple(mock_tx, triple, "test_user", "test_collection")
processor.create_triple(mock_tx, triple, "test_workspace", "test_collection")
# Verify transaction calls
expected_calls = [
# Create subject node
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}),
("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create literal object
("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
{'value': 'literal object', 'user': 'test_user', 'collection': 'test_collection'}),
("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
{'value': 'literal object', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create relationship
("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate',
'user': 'test_user', 'collection': 'test_collection'})
'workspace': 'test_workspace', 'collection': 'test_collection'})
]
assert mock_tx.run.call_count == 3
@ -323,7 +322,7 @@ class TestMemgraphStorageProcessor:
# Verify user/collection parameters were included
for call in processor.io.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert 'user' in call_kwargs
assert 'workspace' in call_kwargs
assert 'collection' in call_kwargs
@pytest.mark.asyncio
@ -343,7 +342,6 @@ class TestMemgraphStorageProcessor:
# Create message with multiple triples
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
triple1 = Triple(
@ -364,7 +362,7 @@ class TestMemgraphStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify execute_query was called:
# Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls
@ -375,7 +373,7 @@ class TestMemgraphStorageProcessor:
# Verify user/collection parameters were included in all calls
for call in processor.io.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert call_kwargs['user'] == 'test_user'
assert call_kwargs['workspace'] == 'test_workspace'
assert call_kwargs['collection'] == 'test_collection'
@pytest.mark.asyncio
@ -389,7 +387,6 @@ class TestMemgraphStorageProcessor:
message = MagicMock()
message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection'
message.triples = []
@ -399,7 +396,7 @@ class TestMemgraphStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message)
await processor.store_triples('test_workspace', message)
# Verify no session calls were made (no triples to process)
processor.io.session.assert_not_called()

View file

@ -68,9 +68,9 @@ class TestNeo4jStorageProcessor:
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
]
@ -116,12 +116,12 @@ class TestNeo4jStorageProcessor:
processor = Processor(taskgroup=taskgroup_mock)
# Test create_node
processor.create_node("http://example.com/node", "test_user", "test_collection")
processor.create_node("http://example.com/node", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/node",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -146,12 +146,12 @@ class TestNeo4jStorageProcessor:
processor = Processor(taskgroup=taskgroup_mock)
# Test create_literal
processor.create_literal("literal value", "test_user", "test_collection")
processor.create_literal("literal value", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="literal value",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -180,18 +180,18 @@ class TestNeo4jStorageProcessor:
"http://example.com/subject",
"http://example.com/predicate",
"http://example.com/object",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="http://example.com/object",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -220,18 +220,18 @@ class TestNeo4jStorageProcessor:
"http://example.com/subject",
"http://example.com/predicate",
"literal value",
"test_user",
"test_workspace",
"test_collection"
)
mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject",
dest="literal value",
uri="http://example.com/predicate",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
@ -268,36 +268,35 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify create_node was called for subject and object
# Verify relate_node was called
expected_calls = [
# Subject node creation
(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Object node creation
(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Relationship creation
(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{
"src": "http://example.com/subject",
"dest": "http://example.com/object",
"uri": "http://example.com/predicate",
"user": "test_user",
"workspace": "test_workspace",
"collection": "test_collection",
"database_": "neo4j"
}
@ -340,12 +339,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify create_node was called for subject
# Verify create_literal was called for object
@ -353,24 +351,24 @@ class TestNeo4jStorageProcessor:
expected_calls = [
# Subject node creation
(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
{"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Literal creation
(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
{"value": "literal value", "user": "test_user", "collection": "test_collection", "database_": "neo4j"}
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
{"value": "literal value", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
),
# Relationship creation
(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{
"src": "http://example.com/subject",
"dest": "literal value",
"uri": "http://example.com/predicate",
"user": "test_user",
"workspace": "test_workspace",
"collection": "test_collection",
"database_": "neo4j"
}
@ -421,12 +419,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata
mock_message = MagicMock()
mock_message.triples = [triple1, triple2]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Should have processed both triples
# Triple1: 2 nodes + 1 relationship = 3 calls
@ -449,12 +446,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with empty triples and metadata
mock_message = MagicMock()
mock_message.triples = []
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Should not have made any execute_query calls beyond index creation
# Only index creation calls should have been made during initialization
@ -568,38 +564,37 @@ class TestNeo4jStorageProcessor:
mock_message = MagicMock()
mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message)
await processor.store_triples("test_workspace", mock_message)
# Verify the triple was processed with special characters preserved
mock_driver.execute_query.assert_any_call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject with spaces",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value='literal with "quotes" and unicode: ñáéíóú',
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)
mock_driver.execute_query.assert_any_call(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject with spaces",
dest='literal with "quotes" and unicode: ñáéíóú',
uri="http://example.com/predicate:with/symbols",
user="test_user",
workspace="test_workspace",
collection="test_collection",
database_="neo4j"
)

View file

@ -24,11 +24,10 @@ def _make_processor(qdrant_client=None):
return proc
def _make_request(vector=None, user="test-user", collection="test-col",
def _make_request(vector=None, collection="test-col",
schema_name="customers", limit=10, index_name=None):
return RowEmbeddingsRequest(
vector=vector or [0.1, 0.2, 0.3],
user=user,
collection=collection,
schema_name=schema_name,
limit=limit,
@ -36,6 +35,14 @@ def _make_request(vector=None, user="test-user", collection="test-col",
)
def _make_flow(workspace="test-workspace", pub=None):
"""Make a mock flow object that is callable and has .workspace."""
flow = MagicMock()
flow.return_value = pub if pub is not None else AsyncMock()
flow.workspace = workspace
return flow
def _make_search_point(index_name, index_value, text, score):
point = MagicMock()
point.payload = {
@ -85,34 +92,33 @@ class TestFindCollection:
def test_finds_matching_collection(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_test_user_test_col_customers_384"
mock_coll.name = "rows_test_workspace_test_col_customers_384"
mock_collections = MagicMock()
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-user", "test-col", "customers")
result = proc.find_collection("test-workspace", "test-col", "customers")
# Prefix: rows_test_user_test_col_customers_
assert result == "rows_test_user_test_col_customers_384"
assert result == "rows_test_workspace_test_col_customers_384"
def test_returns_none_when_no_match(self):
proc = _make_processor()
mock_coll = MagicMock()
mock_coll.name = "rows_other_user_other_col_schema_768"
mock_coll.name = "rows_other_workspace_other_col_schema_768"
mock_collections = MagicMock()
mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-user", "test-col", "customers")
result = proc.find_collection("test-workspace", "test-col", "customers")
assert result is None
def test_returns_none_on_error(self):
proc = _make_processor()
proc.qdrant.get_collections.side_effect = Exception("connection error")
result = proc.find_collection("user", "col", "schema")
result = proc.find_collection("workspace", "col", "schema")
assert result is None
@ -127,7 +133,7 @@ class TestQueryRowEmbeddings:
proc = _make_processor()
request = _make_request(vector=[])
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert result == []
@pytest.mark.asyncio
@ -136,13 +142,13 @@ class TestQueryRowEmbeddings:
proc.find_collection = MagicMock(return_value=None)
request = _make_request()
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert result == []
@pytest.mark.asyncio
async def test_successful_query_returns_matches(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
points = [
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
@ -153,7 +159,7 @@ class TestQueryRowEmbeddings:
proc.qdrant.query_points.return_value = mock_result
request = _make_request()
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert len(result) == 2
assert isinstance(result[0], RowIndexMatch)
@ -166,14 +172,14 @@ class TestQueryRowEmbeddings:
async def test_index_name_filter_applied(self):
"""When index_name is specified, a Qdrant filter should be used."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
mock_result = MagicMock()
mock_result.points = []
proc.qdrant.query_points.return_value = mock_result
request = _make_request(index_name="address")
await proc.query_row_embeddings(request)
await proc.query_row_embeddings("test-workspace", request)
call_kwargs = proc.qdrant.query_points.call_args[1]
assert call_kwargs["query_filter"] is not None
@ -182,14 +188,14 @@ class TestQueryRowEmbeddings:
async def test_no_index_name_no_filter(self):
"""When index_name is empty, no filter should be applied."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
mock_result = MagicMock()
mock_result.points = []
proc.qdrant.query_points.return_value = mock_result
request = _make_request(index_name="")
await proc.query_row_embeddings(request)
await proc.query_row_embeddings("test-workspace", request)
call_kwargs = proc.qdrant.query_points.call_args[1]
assert call_kwargs["query_filter"] is None
@ -198,7 +204,7 @@ class TestQueryRowEmbeddings:
async def test_missing_payload_fields_default(self):
"""Points with missing payload fields should use defaults."""
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
point = MagicMock()
point.payload = {} # Empty payload
@ -209,7 +215,7 @@ class TestQueryRowEmbeddings:
proc.qdrant.query_points.return_value = mock_result
request = _make_request()
result = await proc.query_row_embeddings(request)
result = await proc.query_row_embeddings("test-workspace", request)
assert len(result) == 1
assert result[0].index_name == ""
@ -219,13 +225,13 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio
async def test_qdrant_error_propagates(self):
proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384")
proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.qdrant.query_points.side_effect = Exception("qdrant down")
request = _make_request()
with pytest.raises(Exception, match="qdrant down"):
await proc.query_row_embeddings(request)
await proc.query_row_embeddings("test-workspace", request)
# ---------------------------------------------------------------------------
@ -243,7 +249,7 @@ class TestOnMessage:
])
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow = _make_flow(pub=mock_pub)
msg = MagicMock()
msg.value.return_value = _make_request()
@ -264,7 +270,7 @@ class TestOnMessage:
)
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow = _make_flow(pub=mock_pub)
msg = MagicMock()
msg.value.return_value = _make_request()
@ -284,7 +290,7 @@ class TestOnMessage:
proc.query_row_embeddings = AsyncMock(return_value=[])
mock_pub = AsyncMock()
flow = lambda name: mock_pub
flow = _make_flow(pub=mock_pub)
msg = MagicMock()
msg.value.return_value = _make_request()

View file

@ -45,12 +45,9 @@ class TestGetGraphEmbeddings:
with `vector=` (singular) the schema field name. A previous
version used `vectors=` and TypeError'd at runtime.
"""
# Arrange — fake row matching the get_triples_stmt result shape:
# row[0..2] are unused by the method, row[3] is the entities blob
fake_row = (
None, None, None,
[
# ((value, is_uri), vector)
(("http://example.org/alice", True), [0.1, 0.2, 0.3]),
(("http://example.org/bob", True), [0.4, 0.5, 0.6]),
(("a literal entity", False), [0.7, 0.8, 0.9]),
@ -67,14 +64,8 @@ class TestGetGraphEmbeddings:
async def receiver(msg):
received.append(msg)
# Act
await store.get_graph_embeddings(
user="alice",
document_id="doc-1",
receiver=receiver,
)
await store.get_graph_embeddings("alice", "doc-1", receiver)
# Assert
mock_async_execute.assert_called_once_with(
store.cassandra,
store.get_graph_embeddings_stmt,
@ -86,7 +77,6 @@ class TestGetGraphEmbeddings:
assert isinstance(ge, GraphEmbeddings)
assert isinstance(ge.metadata, Metadata)
assert ge.metadata.id == "doc-1"
assert ge.metadata.user == "alice"
assert len(ge.entities) == 3
assert all(isinstance(e, EntityEmbeddings) for e in ge.entities)
@ -122,7 +112,7 @@ class TestGetGraphEmbeddings:
async def receiver(msg):
received.append(msg)
await store.get_graph_embeddings("u", "d", receiver)
await store.get_graph_embeddings("w", "d", receiver)
assert len(received) == 1
assert received[0].entities == []
@ -149,7 +139,7 @@ class TestGetGraphEmbeddings:
async def receiver(msg):
received.append(msg)
await store.get_graph_embeddings("u", "d", receiver)
await store.get_graph_embeddings("w", "d", receiver)
assert len(received) == 2
assert received[0].entities[0].entity.iri == "http://example.org/a"
@ -194,7 +184,6 @@ class TestGetTriples:
assert isinstance(triples_msg, Triples)
assert isinstance(triples_msg.metadata, Metadata)
assert triples_msg.metadata.id == "doc-1"
assert triples_msg.metadata.user == "alice"
assert len(triples_msg.triples) == 1
t = triples_msg.triples[0]

View file

@ -30,7 +30,6 @@ def sample():
metadata=Metadata(
id="doc-1",
root="",
user="alice",
collection="testcoll",
),
chunks=[
@ -56,7 +55,6 @@ class TestDocumentEmbeddingsTranslator:
assert isinstance(decoded, DocumentEmbeddings)
assert isinstance(decoded.metadata, Metadata)
assert decoded.metadata.id == "doc-1"
assert decoded.metadata.user == "alice"
assert decoded.metadata.collection == "testcoll"
assert len(decoded.chunks) == 2

View file

@ -41,7 +41,6 @@ def translator():
def graph_embeddings_request():
return KnowledgeRequest(
operation="put-kg-core",
user="alice",
id="doc-1",
flow="default",
collection="testcoll",
@ -49,7 +48,6 @@ def graph_embeddings_request():
metadata=Metadata(
id="doc-1",
root="",
user="alice",
collection="testcoll",
),
entities=[
@ -70,7 +68,6 @@ def graph_embeddings_request():
def triples_request():
return KnowledgeRequest(
operation="put-kg-core",
user="alice",
id="doc-1",
flow="default",
collection="testcoll",
@ -78,7 +75,6 @@ def triples_request():
metadata=Metadata(
id="doc-1",
root="",
user="alice",
collection="testcoll",
),
triples=[
@ -123,7 +119,6 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings:
assert isinstance(ge, GraphEmbeddings)
assert isinstance(ge.metadata, Metadata)
assert ge.metadata.id == "doc-1"
assert ge.metadata.user == "alice"
assert ge.metadata.collection == "testcoll"
assert len(ge.entities) == 2
@ -143,7 +138,6 @@ class TestKnowledgeRequestTranslatorTriples:
assert decoded.triples is not None
assert isinstance(decoded.triples.metadata, Metadata)
assert decoded.triples.metadata.id == "doc-1"
assert decoded.triples.metadata.user == "alice"
assert decoded.triples.metadata.collection == "testcoll"
assert len(decoded.triples.triples) == 1

View file

@ -12,7 +12,6 @@ from trustgraph.api import Api
from trustgraph.api.types import hash, Uri, Literal, Triple
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = 'trustgraph'
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")

View file

@ -40,7 +40,6 @@ def load_structured_data(
sample_chars: int = 500,
schema_name: str = None,
flow: str = 'default',
user: str = 'trustgraph',
collection: str = 'default',
dry_run: bool = False,
verbose: bool = False,
@ -64,7 +63,6 @@ def load_structured_data(
sample_chars: Maximum characters to read for sampling
schema_name: Target schema name for generation
flow: TrustGraph flow name to use for prompts
user: User name for metadata (default: trustgraph)
collection: Collection name for metadata (default: default)
dry_run: If True, validate but don't import data
verbose: Enable verbose logging
@ -112,7 +110,7 @@ def load_structured_data(
try:
# Use shared pipeline for preview (small sample)
preview_objects, _ = _process_data_pipeline(input_file, temp_descriptor.name, user, collection, sample_size=5)
preview_objects, _ = _process_data_pipeline(input_file, temp_descriptor.name, collection, sample_size=5)
# Show preview
print("📊 Data Preview (first few records):")
@ -133,7 +131,7 @@ def load_structured_data(
print("🚀 Importing data to TrustGraph...")
# Use shared pipeline for full processing (no sample limit)
output_objects, descriptor = _process_data_pipeline(input_file, temp_descriptor.name, user, collection)
output_objects, descriptor = _process_data_pipeline(input_file, temp_descriptor.name, collection)
# Get batch size from descriptor
batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000)
@ -244,7 +242,7 @@ def load_structured_data(
logger.info(f"Parsing {input_file} with descriptor {descriptor_file}...")
# Use shared pipeline
output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, user, collection, sample_size)
output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, collection, sample_size)
# Output results
if output_file:
@ -288,7 +286,7 @@ def load_structured_data(
logger.info(f"Loading {input_file} to TrustGraph using descriptor {descriptor_file}...")
# Use shared pipeline (no sample_size limit for full load)
output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, user, collection)
output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, collection)
# Get batch size from descriptor or use default
batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000)
@ -529,18 +527,17 @@ def _apply_transformations(records, mappings):
return processed_records
def _format_extracted_objects(processed_records, descriptor, user, collection):
def _format_extracted_objects(processed_records, descriptor, collection):
"""Convert to TrustGraph ExtractedObject format"""
output_records = []
schema_name = descriptor.get('output', {}).get('schema_name', 'default')
confidence = descriptor.get('output', {}).get('options', {}).get('confidence', 0.9)
for record in processed_records:
output_record = {
"metadata": {
"id": f"parsed-{len(output_records)+1}",
"metadata": [], # Empty metadata triples
"user": user,
"collection": collection
},
"schema_name": schema_name,
@ -553,7 +550,7 @@ def _format_extracted_objects(processed_records, descriptor, user, collection):
return output_records
def _process_data_pipeline(input_file, descriptor_file, user, collection, sample_size=None):
def _process_data_pipeline(input_file, descriptor_file, collection, sample_size=None):
"""Shared pipeline: load descriptor → read → parse → transform → format"""
# Load descriptor configuration
descriptor = _load_descriptor(descriptor_file)
@ -570,7 +567,7 @@ def _process_data_pipeline(input_file, descriptor_file, user, collection, sample
processed_records = _apply_transformations(parsed_records, mappings)
# Format output for TrustGraph ExtractedObject structure
output_records = _format_extracted_objects(processed_records, descriptor, user, collection)
output_records = _format_extracted_objects(processed_records, descriptor, collection)
return output_records, descriptor
@ -1048,7 +1045,6 @@ For more information on the descriptor format, see:
sample_chars=args.sample_chars,
schema_name=args.schema_name,
flow=args.flow,
user=args.user,
collection=args.collection,
dry_run=args.dry_run,
verbose=args.verbose,

View file

@ -6,9 +6,9 @@ import re
logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, prefix):
def make_safe_collection_name(workspace, collection, prefix):
"""
Create a safe Milvus collection name from user/collection parameters.
Create a safe Milvus collection name from workspace/collection parameters.
Milvus only allows letters, numbers, and underscores.
"""
def sanitize(s):
@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix):
safe = 'default'
return safe
safe_user = sanitize(user)
safe_workspace = sanitize(workspace)
safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}"
return f"{prefix}_{safe_workspace}_{safe_collection}"
class DocVectors:
@ -49,26 +49,26 @@ class DocVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
def collection_exists(self, workspace, collection):
"""
Check if any collection exists for this user/collection combination.
Check if any collection exists for this workspace/collection combination.
Since collections are dimension-specific, this checks if ANY dimension variant exists.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
all_collections = self.client.list_collections()
return any(coll.startswith(prefix) for coll in all_collections)
def create_collection(self, user, collection, dimension=384):
def create_collection(self, workspace, collection, dimension=384):
"""
No-op for explicit collection creation.
Collections are created lazily on first insert with actual dimension.
"""
logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert")
logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert")
def init_collection(self, dimension, user, collection):
def init_collection(self, dimension, workspace, collection):
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dimension}"
pkey_field = FieldSchema(
@ -116,15 +116,15 @@ class DocVectors:
index_params=index_params
)
self.collections[(dimension, user, collection)] = collection_name
self.collections[(dimension, workspace, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, chunk_id, user, collection):
def insert(self, embeds, chunk_id, workspace, collection):
dim = len(embeds)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
if (dim, workspace, collection) not in self.collections:
self.init_collection(dim, workspace, collection)
data = [
{
@ -134,25 +134,25 @@ class DocVectors:
]
self.client.insert(
collection_name=self.collections[(dim, user, collection)],
collection_name=self.collections[(dim, workspace, collection)],
data=data
)
def search(self, embeds, user, collection, fields=["chunk_id"], limit=10):
def search(self, embeds, workspace, collection, fields=["chunk_id"], limit=10):
dim = len(embeds)
# Check if collection exists - return empty if not
if (dim, user, collection) not in self.collections:
base_name = make_safe_collection_name(user, collection, self.prefix)
if (dim, workspace, collection) not in self.collections:
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dim}"
if not self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} does not exist, returning empty results")
return []
# Collection exists but not in cache, add it
self.collections[(dim, user, collection)] = collection_name
self.collections[(dim, workspace, collection)] = collection_name
coll = self.collections[(dim, user, collection)]
coll = self.collections[(dim, workspace, collection)]
logger.debug("Loading...")
self.client.load_collection(
@ -181,12 +181,12 @@ class DocVectors:
return res
def delete_collection(self, user, collection):
def delete_collection(self, workspace, collection):
"""
Delete all dimension variants of the collection for the given user/collection.
Delete all dimension variants of the collection for the given workspace/collection.
Since collections are created with dimension suffixes, we need to find and delete all.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
# Get all collections and filter for matches
@ -199,10 +199,10 @@ class DocVectors:
for collection_name in matching_collections:
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]

View file

@ -6,9 +6,9 @@ import re
logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, prefix):
def make_safe_collection_name(workspace, collection, prefix):
"""
Create a safe Milvus collection name from user/collection parameters.
Create a safe Milvus collection name from workspace/collection parameters.
Milvus only allows letters, numbers, and underscores.
"""
def sanitize(s):
@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix):
safe = 'default'
return safe
safe_user = sanitize(user)
safe_workspace = sanitize(workspace)
safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}"
return f"{prefix}_{safe_workspace}_{safe_collection}"
class EntityVectors:
@ -49,26 +49,26 @@ class EntityVectors:
self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection):
def collection_exists(self, workspace, collection):
"""
Check if any collection exists for this user/collection combination.
Check if any collection exists for this workspace/collection combination.
Since collections are dimension-specific, this checks if ANY dimension variant exists.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
all_collections = self.client.list_collections()
return any(coll.startswith(prefix) for coll in all_collections)
def create_collection(self, user, collection, dimension=384):
def create_collection(self, workspace, collection, dimension=384):
"""
No-op for explicit collection creation.
Collections are created lazily on first insert with actual dimension.
"""
logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert")
logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert")
def init_collection(self, dimension, user, collection):
def init_collection(self, dimension, workspace, collection):
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dimension}"
pkey_field = FieldSchema(
@ -122,15 +122,15 @@ class EntityVectors:
index_params=index_params
)
self.collections[(dimension, user, collection)] = collection_name
self.collections[(dimension, workspace, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, entity, user, collection, chunk_id=""):
def insert(self, embeds, entity, workspace, collection, chunk_id=""):
dim = len(embeds)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
if (dim, workspace, collection) not in self.collections:
self.init_collection(dim, workspace, collection)
data = [
{
@ -141,25 +141,25 @@ class EntityVectors:
]
self.client.insert(
collection_name=self.collections[(dim, user, collection)],
collection_name=self.collections[(dim, workspace, collection)],
data=data
)
def search(self, embeds, user, collection, fields=["entity"], limit=10):
def search(self, embeds, workspace, collection, fields=["entity"], limit=10):
dim = len(embeds)
# Check if collection exists - return empty if not
if (dim, user, collection) not in self.collections:
base_name = make_safe_collection_name(user, collection, self.prefix)
if (dim, workspace, collection) not in self.collections:
base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dim}"
if not self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} does not exist, returning empty results")
return []
# Collection exists but not in cache, add it
self.collections[(dim, user, collection)] = collection_name
self.collections[(dim, workspace, collection)] = collection_name
coll = self.collections[(dim, user, collection)]
coll = self.collections[(dim, workspace, collection)]
logger.debug("Loading...")
self.client.load_collection(
@ -188,12 +188,12 @@ class EntityVectors:
return res
def delete_collection(self, user, collection):
def delete_collection(self, workspace, collection):
"""
Delete all dimension variants of the collection for the given user/collection.
Delete all dimension variants of the collection for the given workspace/collection.
Since collections are created with dimension suffixes, we need to find and delete all.
"""
base_name = make_safe_collection_name(user, collection, self.prefix)
base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_"
# Get all collections and filter for matches
@ -206,10 +206,10 @@ class EntityVectors:
for collection_name in matching_collections:
self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
# Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection]
keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection]
for key in keys_to_remove:
del self.collections[key]

View file

@ -70,7 +70,7 @@ class GraphQLSchemaBuilder:
Build the GraphQL schema with the provided query callback.
The query callback will be invoked when resolving queries, with:
- user: str
- workspace: str
- collection: str
- schema_name: str
- row_schema: RowSchema
@ -228,7 +228,7 @@ class GraphQLSchemaBuilder:
limit: Optional[int] = 100
) -> List[graphql_type]:
# Get context values
user = info.context["user"]
workspace = info.context["workspace"]
collection = info.context["collection"]
# Parse the where clause
@ -236,7 +236,7 @@ class GraphQLSchemaBuilder:
# Call the query backend
results = await query_callback(
user, collection, schema_name, row_schema,
workspace, collection, schema_name, row_schema,
filters, limit, order_by, direction
)

View file

@ -167,7 +167,7 @@ class QueryExplainer:
question_components, query_results, processing_metadata
)
# Generate user-friendly explanation
# Generate workspace-friendly explanation
user_friendly_explanation = self._generate_user_friendly_explanation(
question, question_components, ontology_subsets, final_answer
)
@ -503,7 +503,7 @@ class QueryExplainer:
question_components: QuestionComponents,
ontology_subsets: List[QueryOntologySubset],
final_answer: str) -> str:
"""Generate user-friendly explanation of the process."""
"""Generate workspace-friendly explanation of the process."""
explanation_parts = []
# Introduction

View file

@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
@dataclass
class QueryRequest:
"""Query request from user."""
"""Query request from workspace."""
question: str
context: Optional[str] = None
ontology_hint: Optional[str] = None

View file

@ -1,6 +1,6 @@
"""
Question analyzer for ontology-sensitive query system.
Decomposes user questions into semantic components.
Decomposes workspace questions into semantic components.
"""
import logging

View file

@ -1,7 +1,7 @@
"""
Row embeddings query service for Qdrant.
Input is query vectors plus user/collection/schema context.
Input is query vectors plus workspace/collection/schema context.
Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups.
"""
@ -70,10 +70,10 @@ class Processor(FlowProcessor):
safe_name = 'r_' + safe_name
return safe_name.lower()
def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given user/collection/schema"""
def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given workspace/collection/schema"""
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"rows_{self.sanitize_name(workspace)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
@ -163,7 +163,7 @@ class Processor(FlowProcessor):
logger.debug(
f"Handling row embeddings query for "
f"{request.user}/{request.collection}/{request.schema_name}..."
f"{flow.workspace}/{request.collection}/{request.schema_name}..."
)
# Execute query

View file

@ -238,7 +238,7 @@ class Processor(FlowProcessor):
async def query_cassandra(
self,
user: str,
workspace: str,
collection: str,
schema_name: str,
row_schema: RowSchema,
@ -256,7 +256,7 @@ class Processor(FlowProcessor):
# Connect if needed
self.connect_cassandra()
safe_keyspace = self.sanitize_name(user)
safe_keyspace = self.sanitize_name(workspace)
# Try to find an index that matches the filters
index_match = self.find_matching_index(row_schema, filters)
@ -409,7 +409,6 @@ class Processor(FlowProcessor):
query: str,
variables: Dict[str, Any],
operation_name: Optional[str],
user: str,
collection: str
) -> Dict[str, Any]:
"""Execute a GraphQL query against the workspace's schema"""
@ -424,7 +423,7 @@ class Processor(FlowProcessor):
# Create context for the query
context = {
"processor": self,
"user": user,
"workspace": workspace,
"collection": collection
}
@ -479,7 +478,6 @@ class Processor(FlowProcessor):
query=request.query,
variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name,
user=request.user,
collection=request.collection
)

View file

@ -30,14 +30,14 @@ class EvaluationError(Exception):
pass
async def evaluate(node, triples_client, user, collection, limit=10000):
async def evaluate(node, triples_client, workspace, collection, limit=10000):
"""
Evaluate a SPARQL algebra node.
Args:
node: rdflib CompValue algebra node
triples_client: TriplesClient instance for triple pattern queries
user: user/keyspace identifier
workspace: workspace/keyspace identifier
collection: collection identifier
limit: safety limit on results
@ -55,24 +55,24 @@ async def evaluate(node, triples_client, user, collection, limit=10000):
logger.warning(f"Unsupported algebra node: {name}")
return [{}]
return await handler(node, triples_client, user, collection, limit)
return await handler(node, triples_client, workspace, collection, limit)
# --- Node handlers ---
async def _eval_select_query(node, tc, user, collection, limit):
async def _eval_select_query(node, tc, workspace, collection, limit):
"""Evaluate a SelectQuery node."""
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
async def _eval_project(node, tc, user, collection, limit):
async def _eval_project(node, tc, workspace, collection, limit):
"""Evaluate a Project node (SELECT variable projection)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
variables = [str(v) for v in node.PV]
return project(solutions, variables)
async def _eval_bgp(node, tc, user, collection, limit):
async def _eval_bgp(node, tc, workspace, collection, limit):
"""
Evaluate a Basic Graph Pattern.
@ -107,7 +107,7 @@ async def _eval_bgp(node, tc, user, collection, limit):
# Query the triples store
results = await _query_pattern(
tc, s_val, p_val, o_val, user, collection, limit
tc, s_val, p_val, o_val, workspace, collection, limit
)
# Map results back to variable bindings,
@ -130,17 +130,17 @@ async def _eval_bgp(node, tc, user, collection, limit):
return solutions[:limit]
async def _eval_join(node, tc, user, collection, limit):
async def _eval_join(node, tc, workspace, collection, limit):
"""Evaluate a Join node."""
left = await evaluate(node.p1, tc, user, collection, limit)
right = await evaluate(node.p2, tc, user, collection, limit)
left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit)
return hash_join(left, right)[:limit]
async def _eval_left_join(node, tc, user, collection, limit):
async def _eval_left_join(node, tc, workspace, collection, limit):
"""Evaluate a LeftJoin node (OPTIONAL)."""
left_sols = await evaluate(node.p1, tc, user, collection, limit)
right_sols = await evaluate(node.p2, tc, user, collection, limit)
left_sols = await evaluate(node.p1, tc, workspace, collection, limit)
right_sols = await evaluate(node.p2, tc, workspace, collection, limit)
filter_fn = None
if hasattr(node, "expr") and node.expr is not None:
@ -153,16 +153,16 @@ async def _eval_left_join(node, tc, user, collection, limit):
return left_join(left_sols, right_sols, filter_fn)[:limit]
async def _eval_union(node, tc, user, collection, limit):
async def _eval_union(node, tc, workspace, collection, limit):
"""Evaluate a Union node."""
left = await evaluate(node.p1, tc, user, collection, limit)
right = await evaluate(node.p2, tc, user, collection, limit)
left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit)
return union(left, right)[:limit]
async def _eval_filter(node, tc, user, collection, limit):
async def _eval_filter(node, tc, workspace, collection, limit):
"""Evaluate a Filter node."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
expr = node.expr
return [
sol for sol in solutions
@ -170,22 +170,22 @@ async def _eval_filter(node, tc, user, collection, limit):
]
async def _eval_distinct(node, tc, user, collection, limit):
async def _eval_distinct(node, tc, workspace, collection, limit):
"""Evaluate a Distinct node."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
return distinct(solutions)
async def _eval_reduced(node, tc, user, collection, limit):
async def _eval_reduced(node, tc, workspace, collection, limit):
"""Evaluate a Reduced node (like Distinct but implementation-defined)."""
# Treat same as Distinct
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
return distinct(solutions)
async def _eval_order_by(node, tc, user, collection, limit):
async def _eval_order_by(node, tc, workspace, collection, limit):
"""Evaluate an OrderBy node."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
key_fns = []
for cond in node.expr:
@ -206,7 +206,7 @@ async def _eval_order_by(node, tc, user, collection, limit):
return order_by(solutions, key_fns)
async def _eval_slice(node, tc, user, collection, limit):
async def _eval_slice(node, tc, workspace, collection, limit):
"""Evaluate a Slice node (LIMIT/OFFSET)."""
# Pass tighter limit downstream if possible
inner_limit = limit
@ -214,13 +214,13 @@ async def _eval_slice(node, tc, user, collection, limit):
offset = node.start or 0
inner_limit = min(limit, offset + node.length)
solutions = await evaluate(node.p, tc, user, collection, inner_limit)
solutions = await evaluate(node.p, tc, workspace, collection, inner_limit)
return slice_solutions(solutions, node.start or 0, node.length)
async def _eval_extend(node, tc, user, collection, limit):
async def _eval_extend(node, tc, workspace, collection, limit):
"""Evaluate an Extend node (BIND)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
var_name = str(node.var)
expr = node.expr
@ -246,9 +246,9 @@ async def _eval_extend(node, tc, user, collection, limit):
return result
async def _eval_group(node, tc, user, collection, limit):
async def _eval_group(node, tc, workspace, collection, limit):
"""Evaluate a Group node (GROUP BY with aggregation)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
# Extract grouping expressions
group_exprs = []
@ -289,9 +289,9 @@ async def _eval_group(node, tc, user, collection, limit):
return result
async def _eval_aggregate_join(node, tc, user, collection, limit):
async def _eval_aggregate_join(node, tc, workspace, collection, limit):
"""Evaluate an AggregateJoin (aggregation functions after GROUP BY)."""
solutions = await evaluate(node.p, tc, user, collection, limit)
solutions = await evaluate(node.p, tc, workspace, collection, limit)
result = []
for sol in solutions:
@ -310,7 +310,7 @@ async def _eval_aggregate_join(node, tc, user, collection, limit):
return result
async def _eval_graph(node, tc, user, collection, limit):
async def _eval_graph(node, tc, workspace, collection, limit):
"""Evaluate a Graph node (GRAPH clause)."""
term = node.term
@ -319,16 +319,16 @@ async def _eval_graph(node, tc, user, collection, limit):
# We'd need to pass graph to triples queries
# For now, evaluate inner pattern normally
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
elif isinstance(term, Variable):
# GRAPH ?g { ... } — variable graph
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
else:
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
async def _eval_values(node, tc, user, collection, limit):
async def _eval_values(node, tc, workspace, collection, limit):
"""Evaluate a VALUES clause (inline data)."""
variables = [str(v) for v in node.var]
solutions = []
@ -343,9 +343,9 @@ async def _eval_values(node, tc, user, collection, limit):
return solutions
async def _eval_to_multiset(node, tc, user, collection, limit):
async def _eval_to_multiset(node, tc, workspace, collection, limit):
"""Evaluate a ToMultiSet node (subquery)."""
return await evaluate(node.p, tc, user, collection, limit)
return await evaluate(node.p, tc, workspace, collection, limit)
# --- Aggregate computation ---
@ -487,7 +487,7 @@ def _resolve_term(tmpl, solution):
return rdflib_term_to_term(tmpl)
async def _query_pattern(tc, s, p, o, user, collection, limit):
async def _query_pattern(tc, s, p, o, workspace, collection, limit):
"""
Issue a streaming triple pattern query via TriplesClient.
@ -496,7 +496,7 @@ async def _query_pattern(tc, s, p, o, user, collection, limit):
results = await tc.query(
s=s, p=p, o=o,
limit=limit,
user=user,
workspace=workspace,
collection=collection,
)
return results

View file

@ -141,7 +141,7 @@ class Processor(FlowProcessor):
solutions = await evaluate(
parsed.algebra,
triples_client,
user=flow.workspace,
workspace=flow.workspace,
collection=request.collection or "default",
limit=request.limit or 10000,
)

View file

@ -178,24 +178,24 @@ class Processor(TriplesQueryService):
self.cassandra_password = password
self.table = None
def ensure_connection(self, user):
def ensure_connection(self, workspace):
"""Ensure we have a connection to the correct keyspace."""
if user != self.table:
if workspace != self.table:
KGClass = EntityCentricKnowledgeGraph
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
)
self.table = user
self.table = workspace
async def query_triples(self, workspace, query):

View file

@ -67,7 +67,7 @@ class Processor(TriplesQueryService):
try:
user = workspace
workspace = workspace
collection = query.collection if query.collection else "default"
triples = []
@ -79,13 +79,13 @@ class Processor(TriplesQueryService):
# SPO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -93,13 +93,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -111,13 +111,13 @@ class Processor(TriplesQueryService):
# SP
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -126,13 +126,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -147,13 +147,13 @@ class Processor(TriplesQueryService):
# SO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -162,13 +162,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -181,13 +181,13 @@ class Processor(TriplesQueryService):
# S
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -196,13 +196,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -220,13 +220,13 @@ class Processor(TriplesQueryService):
# PO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -235,13 +235,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), dest=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -254,13 +254,13 @@ class Processor(TriplesQueryService):
# P
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -269,13 +269,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -290,13 +290,13 @@ class Processor(TriplesQueryService):
# O
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -305,13 +305,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -324,12 +324,12 @@ class Processor(TriplesQueryService):
# *
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -338,12 +338,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)

View file

@ -67,9 +67,8 @@ class Processor(TriplesQueryService):
try:
user = workspace
collection = query.collection if query.collection else "default"
triples = []
if query.s is not None:
@ -79,13 +78,13 @@ class Processor(TriplesQueryService):
# SPO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -93,13 +92,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN $src as src "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -111,13 +110,13 @@ class Processor(TriplesQueryService):
# SP
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -126,13 +125,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -147,13 +146,13 @@ class Processor(TriplesQueryService):
# SO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -162,13 +161,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel "
"LIMIT " + str(query.limit),
src=get_term_value(query.s), uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -181,13 +180,13 @@ class Processor(TriplesQueryService):
# S
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -196,13 +195,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
src=get_term_value(query.s),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -220,13 +219,13 @@ class Processor(TriplesQueryService):
# PO
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -235,13 +234,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p), dest=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -254,13 +253,13 @@ class Processor(TriplesQueryService):
# P
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -269,13 +268,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit),
uri=get_term_value(query.p),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -290,13 +289,13 @@ class Processor(TriplesQueryService):
# O
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
value=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -305,13 +304,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit),
uri=get_term_value(query.o),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -324,12 +323,12 @@ class Processor(TriplesQueryService):
# *
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -338,12 +337,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit),
user=user, collection=collection,
workspace=workspace, collection=collection,
database_=self.db,
)
@ -366,7 +365,7 @@ class Processor(TriplesQueryService):
logger.error(f"Exception querying triples: {e}", exc_info=True)
raise e
@staticmethod
def add_args(parser):

View file

@ -60,27 +60,27 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(user, collection)
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(workspace, collection)
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push"""
try:
self.vecstore.delete_collection(user, collection)
logger.info(f"Successfully deleted collection {user}/{collection}")
self.vecstore.delete_collection(workspace, collection)
logger.info(f"Successfully deleted collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -165,22 +165,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - indexes are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push"""
try:
prefix = f"d-{user}-{collection}-"
prefix = f"d-{workspace}-{collection}-"
# Get all indexes and filter for matches
all_indexes = self.pinecone.list_indexes()
@ -195,10 +195,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for index_name in matching_indexes:
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -107,22 +107,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Qdrant API key (default: None)'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push"""
try:
prefix = f"d_{user}_{collection}_"
prefix = f"d_{workspace}_{collection}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
@ -137,10 +137,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -73,27 +73,27 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(user, collection)
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(workspace, collection)
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push"""
try:
self.vecstore.delete_collection(user, collection)
logger.info(f"Successfully deleted collection {user}/{collection}")
self.vecstore.delete_collection(workspace, collection)
logger.info(f"Successfully deleted collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -183,22 +183,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - indexes are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push"""
try:
prefix = f"t-{user}-{collection}-"
prefix = f"t-{workspace}-{collection}-"
# Get all indexes and filter for matches
all_indexes = self.pinecone.list_indexes()
@ -213,10 +213,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
for index_name in matching_indexes:
self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -126,22 +126,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Qdrant API key'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""
Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings.
"""
try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write")
logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push"""
try:
prefix = f"t_{user}_{collection}_"
prefix = f"t_{workspace}_{collection}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
@ -156,10 +156,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -2,13 +2,13 @@
Row embeddings writer for Qdrant (Stage 2).
Consumes RowEmbeddings messages (which already contain computed vectors)
and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair.
and writes them to Qdrant. One Qdrant collection per (workspace, collection, schema_name) pair.
This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (row-embeddings): Compute embeddings
Stage 2 (this processor): Store embeddings
Collection naming: rows_{user}_{collection}_{schema_name}_{dimension}
Collection naming: rows_{workspace}_{collection}_{schema_name}_{dimension}
Payload structure:
- index_name: The indexed field(s) this embedding represents
@ -77,10 +77,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
return safe_name.lower()
def get_collection_name(
self, user: str, collection: str, schema_name: str, dimension: int
self, workspace: str, collection: str, schema_name: str, dimension: int
) -> str:
"""Generate Qdrant collection name"""
safe_user = self.sanitize_name(user)
safe_user = self.sanitize_name(workspace)
safe_collection = self.sanitize_name(collection)
safe_schema = self.sanitize_name(schema_name)
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
@ -169,17 +169,17 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Collection creation via config push - collections created lazily on first write"""
logger.info(
f"Row embeddings collection create request for {user}/{collection} - "
f"Row embeddings collection create request for {workspace}/{collection} - "
f"will be created lazily on first write"
)
async def delete_collection(self, user: str, collection: str):
"""Delete all Qdrant collections for a given user/collection"""
async def delete_collection(self, workspace: str, collection: str):
"""Delete all Qdrant collections for a given workspace/collection"""
try:
prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_"
prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"
# Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections
@ -197,23 +197,23 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(
f"Deleted {len(matching_collections)} collection(s) "
f"for {user}/{collection}"
f"for {workspace}/{collection}"
)
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}: {e}",
f"Failed to delete collection {workspace}/{collection}: {e}",
exc_info=True
)
raise
async def delete_collection_schema(
self, user: str, collection: str, schema_name: str
self, workspace: str, collection: str, schema_name: str
):
"""Delete Qdrant collection for a specific user/collection/schema"""
"""Delete Qdrant collection for a specific workspace/collection/schema"""
try:
prefix = (
f"rows_{self.sanitize_name(user)}_"
f"rows_{self.sanitize_name(workspace)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
)
@ -234,7 +234,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
except Exception as e:
logger.error(
f"Failed to delete collection {user}/{collection}/{schema_name}: {e}",
f"Failed to delete collection {workspace}/{collection}/{schema_name}: {e}",
exc_info=True
)
raise

View file

@ -459,25 +459,25 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"({len(index_names)} indexes per row)"
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store"""
# Connect if not already connected (sync, push to thread)
await asyncio.to_thread(self.connect_cassandra)
# Ensure tables exist (sync DDL, push to thread)
await asyncio.to_thread(self.ensure_tables, user)
await asyncio.to_thread(self.ensure_tables, workspace)
logger.info(f"Collection {collection} ready for user {user}")
logger.info(f"Collection {collection} ready for workspace {workspace}")
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection using partition tracking"""
# Connect if not already connected
await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user)
safe_keyspace = self.sanitize_name(workspace)
# Check if keyspace exists
if user not in self.known_keyspaces:
if workspace not in self.known_keyspaces:
check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
@ -488,7 +488,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
if not result:
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(user)
self.known_keyspaces.add(workspace)
# Discover all partitions for this collection
select_partitions_cql = f"""
@ -551,12 +551,12 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"from keyspace {safe_keyspace}"
)
async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination"""
# Connect if not already connected
await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user)
safe_keyspace = self.sanitize_name(workspace)
# Discover partitions for this collection + schema
select_partitions_cql = f"""

View file

@ -210,12 +210,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
await asyncio.to_thread(_do_store)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create a collection in Cassandra triple store via config push"""
def _do_create():
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != user:
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
@ -225,23 +225,23 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {user}: {e}")
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
self.table = user
self.table = workspace
# Create collection using the built-in method
logger.info(f"Creating collection {collection} for user {user}")
logger.info(f"Creating collection {collection} for workspace {workspace}")
if self.tg.collection_exists(collection):
logger.info(f"Collection {collection} already exists")
@ -252,15 +252,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try:
await asyncio.to_thread(_do_create)
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection from the unified triples table"""
def _do_delete():
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != user:
# Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != workspace:
self.tg = None
# Use factory function to select implementation
@ -270,29 +270,29 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
username=self.cassandra_username,
password=self.cassandra_password,
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=user,
keyspace=workspace,
)
except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {user}: {e}")
logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise
self.table = user
self.table = workspace
# Delete all triples for this collection using the built-in method
self.tg.delete_collection(collection)
logger.info(f"Deleted all triples for collection {collection} from keyspace {user}")
logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
try:
await asyncio.to_thread(_do_delete)
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
@staticmethod

View file

@ -59,15 +59,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"])
def create_node(self, uri, user, collection):
def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
params={
"uri": uri,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@ -77,15 +77,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def create_literal(self, value, user, collection):
def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
params={
"value": value,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@ -95,19 +95,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def relate_node(self, src, uri, dest, user, collection):
def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={
"src": src,
"dest": dest,
"uri": uri,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@ -117,19 +117,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def relate_literal(self, src, uri, dest, user, collection):
def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
res = self.io.query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={
"src": src,
"dest": dest,
"uri": uri,
"user": user,
"workspace": workspace,
"collection": collection,
},
)
@ -139,28 +139,28 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms
))
def collection_exists(self, user, collection):
def collection_exists(self, workspace, collection):
"""Check if collection metadata node exists"""
result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"RETURN c LIMIT 1",
params={"user": user, "collection": collection}
params={"workspace": workspace, "collection": collection}
)
return result.result_set is not None and len(result.result_set) > 0
def create_collection(self, user, collection):
def create_collection(self, workspace, collection):
"""Create collection metadata node"""
import datetime
self.io.query(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at",
params={
"user": user,
"workspace": workspace,
"collection": collection,
"created_at": datetime.datetime.now().isoformat()
}
)
logger.info(f"Created collection metadata node for {user}/{collection}")
logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def store_triples(self, workspace, message):
collection = message.metadata.collection if message.metadata.collection else "default"
@ -206,58 +206,58 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'FalkorDB database (default: {default_database})'
)
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create collection metadata in FalkorDB via config push"""
try:
# Check if collection exists
result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) RETURN c LIMIT 1",
params={"user": user, "collection": collection}
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) RETURN c LIMIT 1",
params={"workspace": workspace, "collection": collection}
)
if result.result_set:
logger.info(f"Collection {user}/{collection} already exists")
logger.info(f"Collection {workspace}/{collection} already exists")
else:
# Create collection metadata node
import datetime
self.io.query(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at",
params={
"user": user,
"workspace": workspace,
"collection": collection,
"created_at": datetime.datetime.now().isoformat()
}
)
logger.info(f"Created collection {user}/{collection}")
logger.info(f"Created collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for FalkorDB triples via config push"""
try:
# Delete all nodes and literals for this user/collection
# Delete all nodes and literals for this workspace/collection
node_result = self.io.query(
"MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": user, "collection": collection}
"MATCH (n:Node {workspace: $workspace, collection: $collection}) DETACH DELETE n",
params={"workspace": workspace, "collection": collection}
)
literal_result = self.io.query(
"MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n",
params={"user": user, "collection": collection}
"MATCH (n:Literal {workspace: $workspace, collection: $collection}) DETACH DELETE n",
params={"workspace": workspace, "collection": collection}
)
# Delete collection metadata node
metadata_result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) DELETE c",
params={"user": user, "collection": collection}
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) DELETE c",
params={"workspace": workspace, "collection": collection}
)
logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {user}/{collection}")
logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():

View file

@ -117,10 +117,10 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Maybe index already exists
logger.warning("Index create failure ignored")
# New indexes for user/collection filtering
# New indexes for workspace/collection filtering
try:
session.run(
"CREATE INDEX ON :Node(user)"
"CREATE INDEX ON :Node(workspace)"
)
except Exception as e:
logger.warning(f"User index create failure: {e}")
@ -136,7 +136,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try:
session.run(
"CREATE INDEX ON :Literal(user)"
"CREATE INDEX ON :Literal(workspace)"
)
except Exception as e:
logger.warning(f"User index create failure: {e}")
@ -152,13 +152,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Index creation done")
def create_node(self, uri, user, collection):
def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=uri, user=user, collection=collection,
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -167,13 +167,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def create_literal(self, value, user, collection):
def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=value, user=user, collection=collection,
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=value, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -182,15 +182,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_node(self, src, uri, dest, user, collection):
def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -199,15 +199,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_literal(self, src, uri, dest, user, collection):
def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def create_triple(self, tx, t, user, collection):
def create_triple(self, tx, t, workspace, collection):
s_val = get_term_value(t.s)
p_val = get_term_value(t.p)
@ -224,38 +224,38 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Create new s node with given uri, if not exists
result = tx.run(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=s_val, user=user, collection=collection
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=s_val, workspace=workspace, collection=collection
)
if t.o.type == IRI:
# Create new o node with given uri, if not exists
result = tx.run(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=o_val, user=user, collection=collection
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=o_val, workspace=workspace, collection=collection
)
result = tx.run(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection,
)
else:
# Create new o literal with given uri, if not exists
result = tx.run(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=o_val, user=user, collection=collection
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=o_val, workspace=workspace, collection=collection
)
result = tx.run(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection,
)
async def store_triples(self, workspace, message):
@ -288,7 +288,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Alternative implementation using transactions
# with self.io.session(database=self.db) as session:
# session.execute_write(self.create_triple, t, user, collection)
# session.execute_write(self.create_triple, t, workspace, collection)
@staticmethod
def add_args(parser):
@ -319,72 +319,72 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'Memgraph database (default: {default_database})'
)
def _collection_exists_in_db(self, user, collection):
def _collection_exists_in_db(self, workspace, collection):
"""Check if collection metadata node exists"""
with self.io.session(database=self.db) as session:
result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"RETURN c LIMIT 1",
user=user, collection=collection
workspace=workspace, collection=collection
)
return bool(list(result))
def _create_collection_in_db(self, user, collection):
def _create_collection_in_db(self, workspace, collection):
"""Create collection metadata node"""
import datetime
with self.io.session(database=self.db) as session:
session.run(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at",
user=user, collection=collection,
workspace=workspace, collection=collection,
created_at=datetime.datetime.now().isoformat()
)
logger.info(f"Created collection metadata node for {user}/{collection}")
logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create collection metadata in Memgraph via config push"""
try:
if self._collection_exists_in_db(user, collection):
logger.info(f"Collection {user}/{collection} already exists")
if self._collection_exists_in_db(workspace, collection):
logger.info(f"Collection {workspace}/{collection} already exists")
else:
self._create_collection_in_db(user, collection)
logger.info(f"Created collection {user}/{collection}")
self._create_collection_in_db(workspace, collection)
logger.info(f"Created collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection via config push"""
try:
with self.io.session(database=self.db) as session:
# Delete all nodes for this user and collection
# Delete all nodes for this workspace and collection
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"MATCH (n:Node {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n",
user=user, collection=collection
workspace=workspace, collection=collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
# Delete all literals for this user and collection
# Delete all literals for this workspace and collection
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"MATCH (n:Literal {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n",
user=user, collection=collection
workspace=workspace, collection=collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Delete collection metadata node
metadata_result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"DELETE c",
user=user, collection=collection
workspace=workspace, collection=collection
)
metadata_deleted = metadata_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}")
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection: {e}")

View file

@ -80,14 +80,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Create indexes...")
# Legacy indexes for backwards compatibility
try:
session.run(
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
)
except Exception as e:
logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored")
try:
@ -96,7 +94,6 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
)
except Exception as e:
logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored")
try:
@ -105,13 +102,11 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
)
except Exception as e:
logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored")
# New compound indexes for user/collection filtering
try:
session.run(
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)",
"CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
)
except Exception as e:
logger.warning(f"Compound index create failure: {e}")
@ -119,17 +114,16 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try:
session.run(
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)",
"CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
)
except Exception as e:
logger.warning(f"Compound index create failure: {e}")
logger.warning("Index create failure ignored")
# Note: Neo4j doesn't support compound indexes on relationships in all versions
# Try to create individual indexes on relationship properties
# Neo4j doesn't support compound indexes on relationships in all versions
try:
session.run(
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)",
"CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
)
except Exception as e:
logger.warning(f"Relationship index create failure: {e}")
@ -145,13 +139,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Index creation done")
def create_node(self, uri, user, collection):
def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}")
logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=uri, user=user, collection=collection,
"MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -160,13 +154,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def create_literal(self, value, user, collection):
def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}")
logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})",
value=value, user=user, collection=collection,
"MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=value, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -175,15 +169,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_node(self, src, uri, dest, user, collection):
def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -192,15 +186,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after
))
def relate_literal(self, src, uri, dest, user, collection):
def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}")
logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection,
"MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db,
).summary
@ -266,75 +260,70 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'Neo4j database (default: {default_database})'
)
def _collection_exists_in_db(self, user, collection):
def _collection_exists_in_db(self, workspace, collection):
"""Check if collection metadata node exists"""
with self.io.session(database=self.db) as session:
result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"RETURN c LIMIT 1",
user=user, collection=collection
workspace=workspace, collection=collection
)
return bool(list(result))
def _create_collection_in_db(self, user, collection):
def _create_collection_in_db(self, workspace, collection):
"""Create collection metadata node"""
import datetime
with self.io.session(database=self.db) as session:
session.run(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) "
"MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at",
user=user, collection=collection,
workspace=workspace, collection=collection,
created_at=datetime.datetime.now().isoformat()
)
logger.info(f"Created collection metadata node for {user}/{collection}")
logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def create_collection(self, user: str, collection: str, metadata: dict):
async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create collection metadata in Neo4j via config push"""
try:
if self._collection_exists_in_db(user, collection):
logger.info(f"Collection {user}/{collection} already exists")
if self._collection_exists_in_db(workspace, collection):
logger.info(f"Collection {workspace}/{collection} already exists")
else:
self._create_collection_in_db(user, collection)
logger.info(f"Created collection {user}/{collection}")
self._create_collection_in_db(workspace, collection)
logger.info(f"Created collection {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection via config push"""
try:
with self.io.session(database=self.db) as session:
# Delete all nodes for this user and collection
node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) "
"MATCH (n:Node {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n",
user=user, collection=collection
workspace=workspace, collection=collection
)
nodes_deleted = node_result.consume().counters.nodes_deleted
# Delete all literals for this user and collection
literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) "
"MATCH (n:Literal {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n",
user=user, collection=collection
workspace=workspace, collection=collection
)
literals_deleted = literal_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
# Delete collection metadata node
metadata_result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) "
"MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"DELETE c",
user=user, collection=collection
workspace=workspace, collection=collection
)
metadata_deleted = metadata_result.consume().counters.nodes_deleted
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}")
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}")
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise
def run():