Additional user fixes and test fixes

This commit is contained in:
Cyber MacGeddon 2026-04-21 10:53:15 +01:00
parent db05427d0e
commit 7f0f79dd15
62 changed files with 1078 additions and 1315 deletions

View file

@ -28,7 +28,6 @@ def sample_text_document():
"""Sample document with moderate length text.""" """Sample document with moderate length text."""
metadata = Metadata( metadata = Metadata(
id="test-doc-1", id="test-doc-1",
user="test-user",
collection="test-collection" collection="test-collection"
) )
text = "The quick brown fox jumps over the lazy dog. " * 20 text = "The quick brown fox jumps over the lazy dog. " * 20
@ -43,7 +42,6 @@ def long_text_document():
"""Long document for testing multiple chunks.""" """Long document for testing multiple chunks."""
metadata = Metadata( metadata = Metadata(
id="test-doc-long", id="test-doc-long",
user="test-user",
collection="test-collection" collection="test-collection"
) )
# Create a long text that will definitely be chunked # Create a long text that will definitely be chunked
@ -59,7 +57,6 @@ def unicode_text_document():
"""Document with various unicode characters.""" """Document with various unicode characters."""
metadata = Metadata( metadata = Metadata(
id="test-doc-unicode", id="test-doc-unicode",
user="test-user",
collection="test-collection" collection="test-collection"
) )
text = """ text = """
@ -84,7 +81,6 @@ def empty_text_document():
"""Empty document for edge case testing.""" """Empty document for edge case testing."""
metadata = Metadata( metadata = Metadata(
id="test-doc-empty", id="test-doc-empty",
user="test-user",
collection="test-collection" collection="test-collection"
) )
return TextDocument( return TextDocument(

View file

@ -185,7 +185,6 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_text_doc = MagicMock() mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata( mock_text_doc.metadata = Metadata(
id="test-doc-123", id="test-doc-123",
user="test-user",
collection="test-collection" collection="test-collection"
) )
mock_text_doc.text = b"This is test document content" mock_text_doc.text = b"This is test document content"

View file

@ -185,7 +185,6 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_text_doc = MagicMock() mock_text_doc = MagicMock()
mock_text_doc.metadata = Metadata( mock_text_doc.metadata = Metadata(
id="test-doc-456", id="test-doc-456",
user="test-user",
collection="test-collection" collection="test-collection"
) )
mock_text_doc.text = b"This is test document content for token chunking" mock_text_doc.text = b"This is test document content for token chunking"

View file

@ -73,7 +73,6 @@ def sample_triples():
return Triples( return Triples(
metadata=Metadata( metadata=Metadata(
id="test-doc-id", id="test-doc-id",
user="test-user",
collection="default", # This should be overridden collection="default", # This should be overridden
), ),
triples=[ triples=[
@ -92,7 +91,6 @@ def sample_graph_embeddings():
return GraphEmbeddings( return GraphEmbeddings(
metadata=Metadata( metadata=Metadata(
id="test-doc-id", id="test-doc-id",
user="test-user",
collection="default", # This should be overridden collection="default", # This should be overridden
), ),
entities=[ entities=[
@ -148,7 +146,6 @@ class TestKnowledgeManagerLoadCore:
mock_triples_pub.send.assert_called_once() mock_triples_pub.send.assert_called_once()
sent_triples = mock_triples_pub.send.call_args[0][1] sent_triples = mock_triples_pub.send.call_args[0][1]
assert sent_triples.metadata.collection == "test-collection" assert sent_triples.metadata.collection == "test-collection"
assert sent_triples.metadata.user == "test-user"
assert sent_triples.metadata.id == "test-doc-id" assert sent_triples.metadata.id == "test-doc-id"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -187,7 +184,6 @@ class TestKnowledgeManagerLoadCore:
mock_ge_pub.send.assert_called_once() mock_ge_pub.send.assert_called_once()
sent_ge = mock_ge_pub.send.call_args[0][1] sent_ge = mock_ge_pub.send.call_args[0][1]
assert sent_ge.metadata.collection == "test-collection" assert sent_ge.metadata.collection == "test-collection"
assert sent_ge.metadata.user == "test-user"
assert sent_ge.metadata.id == "test-doc-id" assert sent_ge.metadata.id == "test-doc-id"
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -237,7 +237,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
# Mock message with inline data # Mock message with inline data
content = b"# Document Title\nBody text content." content = b"# Document Title\nBody text content."
mock_metadata = Metadata(id="test-doc", user="testuser", mock_metadata = Metadata(id="test-doc",
collection="default") collection="default")
mock_document = Document( mock_document = Document(
metadata=mock_metadata, metadata=mock_metadata,
@ -294,7 +294,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
# Mock message # Mock message
content = b"fake pdf" content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser", mock_metadata = Metadata(id="test-doc",
collection="default") collection="default")
mock_document = Document( mock_document = Document(
metadata=mock_metadata, metadata=mock_metadata,
@ -345,7 +345,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
] ]
content = b"fake pdf" content = b"fake pdf"
mock_metadata = Metadata(id="test-doc", user="testuser", mock_metadata = Metadata(id="test-doc",
collection="default") collection="default")
mock_document = Document( mock_document = Document(
metadata=mock_metadata, metadata=mock_metadata,

View file

@ -20,9 +20,8 @@ def processor():
) )
def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", def _make_chunk_message(chunk_text="Hello world", doc_id="doc-1", collection="default"):
user="test", collection="default"): metadata = Metadata(id=doc_id, collection=collection)
metadata = Metadata(id=doc_id, user=user, collection=collection)
value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id) value = Chunk(metadata=metadata, chunk=chunk_text, document_id=doc_id)
msg = MagicMock() msg = MagicMock()
msg.value.return_value = value msg.value.return_value = value
@ -144,7 +143,6 @@ class TestDocumentEmbeddingsProcessor:
await processor.on_message(msg, MagicMock(), flow) await processor.on_message(msg, MagicMock(), flow)
result = mock_output.send.call_args[0][0] result = mock_output.send.call_args[0][0]
assert result.metadata.user == "alice"
assert result.metadata.collection == "reports" assert result.metadata.collection == "reports"
assert result.metadata.id == "d1" assert result.metadata.id == "d1"

View file

@ -27,8 +27,8 @@ def _make_entity_context(name, context, chunk_id="chunk-1"):
return MagicMock(entity=entity, context=context, chunk_id=chunk_id) return MagicMock(entity=entity, context=context, chunk_id=chunk_id)
def _make_message(entities, doc_id="doc-1", user="test", collection="default"): def _make_message(entities, doc_id="doc-1", collection="default"):
metadata = Metadata(id=doc_id, user=user, collection=collection) metadata = Metadata(id=doc_id, collection=collection)
value = EntityContexts(metadata=metadata, entities=entities) value = EntityContexts(metadata=metadata, entities=entities)
msg = MagicMock() msg = MagicMock()
msg.value.return_value = value msg.value.return_value = value
@ -151,7 +151,7 @@ class TestGraphEmbeddingsBatchProcessing:
_make_entity_context(f"E{i}", f"ctx {i}") _make_entity_context(f"E{i}", f"ctx {i}")
for i in range(5) for i in range(5)
] ]
msg = _make_message(entities, doc_id="doc-42", user="alice", collection="main") msg = _make_message(entities, doc_id="doc-42", collection="main")
mock_embed = AsyncMock(return_value=[[0.0]] * 5) mock_embed = AsyncMock(return_value=[[0.0]] * 5)
mock_output = AsyncMock() mock_output = AsyncMock()
@ -168,7 +168,6 @@ class TestGraphEmbeddingsBatchProcessing:
for call in mock_output.send.call_args_list: for call in mock_output.send.call_args_list:
result = call[0][0] result = call[0][0]
assert result.metadata.id == "doc-42" assert result.metadata.id == "doc-42"
assert result.metadata.user == "alice"
assert result.metadata.collection == "main" assert result.metadata.collection == "main"
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -34,11 +34,10 @@ def _make_defn(entity, definition):
return {"entity": entity, "definition": definition} return {"entity": entity, "definition": definition}
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
user="user-1", collection="col-1", document_id=""):
chunk = Chunk( chunk = Chunk(
metadata=Metadata( metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection, id=meta_id, root=root, collection=collection,
), ),
chunk=text.encode("utf-8"), chunk=text.encode("utf-8"),
document_id=document_id, document_id=document_id,
@ -229,8 +228,7 @@ class TestMetadataPreservation:
defs = [_make_defn("X", "def X")] defs = [_make_defn("X", "def X")]
flow, triples_pub, _, _ = _make_flow(defs) flow, triples_pub, _, _ = _make_flow(defs)
msg = _make_chunk_msg( msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1", "text", meta_id="c-1", root="r-1", collection="coll-1",
user="u-1", collection="coll-1",
) )
await proc.on_message(msg, MagicMock(), flow) await proc.on_message(msg, MagicMock(), flow)
@ -238,7 +236,6 @@ class TestMetadataPreservation:
for triples_msg in _sent_triples(triples_pub): for triples_msg in _sent_triples(triples_pub):
assert triples_msg.metadata.id == "c-1" assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1" assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1" assert triples_msg.metadata.collection == "coll-1"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -247,8 +244,7 @@ class TestMetadataPreservation:
defs = [_make_defn("X", "def X")] defs = [_make_defn("X", "def X")]
flow, _, ecs_pub, _ = _make_flow(defs) flow, _, ecs_pub, _ = _make_flow(defs)
msg = _make_chunk_msg( msg = _make_chunk_msg(
"text", meta_id="c-2", root="r-2", "text", meta_id="c-2", root="r-2", collection="coll-2",
user="u-2", collection="coll-2",
) )
await proc.on_message(msg, MagicMock(), flow) await proc.on_message(msg, MagicMock(), flow)

View file

@ -38,12 +38,11 @@ def _make_rel(subject, predicate, obj, object_entity=True):
} }
def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", def _make_chunk_msg(text, meta_id="chunk-1", root="root-1", collection="col-1", document_id=""):
user="user-1", collection="col-1", document_id=""):
"""Build a mock message wrapping a Chunk.""" """Build a mock message wrapping a Chunk."""
chunk = Chunk( chunk = Chunk(
metadata=Metadata( metadata=Metadata(
id=meta_id, root=root, user=user, collection=collection, id=meta_id, root=root, collection=collection,
), ),
chunk=text.encode("utf-8"), chunk=text.encode("utf-8"),
document_id=document_id, document_id=document_id,
@ -189,8 +188,7 @@ class TestMetadataPreservation:
rels = [_make_rel("X", "rel", "Y")] rels = [_make_rel("X", "rel", "Y")]
flow, pub, _ = _make_flow(rels) flow, pub, _ = _make_flow(rels)
msg = _make_chunk_msg( msg = _make_chunk_msg(
"text", meta_id="c-1", root="r-1", "text", meta_id="c-1", root="r-1", collection="coll-1",
user="u-1", collection="coll-1",
) )
await proc.on_message(msg, MagicMock(), flow) await proc.on_message(msg, MagicMock(), flow)
@ -198,7 +196,6 @@ class TestMetadataPreservation:
for triples_msg in _sent_triples(pub): for triples_msg in _sent_triples(pub):
assert triples_msg.metadata.id == "c-1" assert triples_msg.metadata.id == "c-1"
assert triples_msg.metadata.root == "r-1" assert triples_msg.metadata.root == "r-1"
assert triples_msg.metadata.user == "u-1"
assert triples_msg.metadata.collection == "coll-1" assert triples_msg.metadata.collection == "coll-1"

View file

@ -186,7 +186,6 @@ class TestEntityContextsImportMessageProcessing:
assert isinstance(sent, EntityContexts) assert isinstance(sent, EntityContexts)
assert isinstance(sent.metadata, Metadata) assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123" assert sent.metadata.id == "doc-123"
assert sent.metadata.user == "testuser"
assert sent.metadata.collection == "testcollection" assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2 assert len(sent.entities) == 2

View file

@ -188,7 +188,6 @@ class TestGraphEmbeddingsImportMessageProcessing:
assert isinstance(sent, GraphEmbeddings) assert isinstance(sent, GraphEmbeddings)
assert isinstance(sent.metadata, Metadata) assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123" assert sent.metadata.id == "doc-123"
assert sent.metadata.user == "testuser"
assert sent.metadata.collection == "testcollection" assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2 assert len(sent.entities) == 2

View file

@ -235,7 +235,6 @@ class TestRowsImportMessageProcessing:
# Check metadata # Check metadata
assert sent_object.metadata.id == "obj-123" assert sent_object.metadata.id == "obj-123"
assert sent_object.metadata.user == "testuser"
assert sent_object.metadata.collection == "testcollection" assert sent_object.metadata.collection == "testcollection"
@patch('trustgraph.gateway.dispatch.rows_import.Publisher') @patch('trustgraph.gateway.dispatch.rows_import.Publisher')

View file

@ -23,7 +23,6 @@ class TestTextDocumentTranslator:
) )
assert msg.metadata.id == "doc-1" assert msg.metadata.id == "doc-1"
assert msg.metadata.user == "alice"
assert msg.metadata.collection == "research" assert msg.metadata.collection == "research"
assert msg.text == payload.encode("utf-8") assert msg.text == payload.encode("utf-8")

View file

@ -108,7 +108,6 @@ def sample_triples(sample_triple):
"""Sample Triples batch object""" """Sample Triples batch object"""
metadata = Metadata( metadata = Metadata(
id="test-doc-123", id="test-doc-123",
user="test_user",
collection="test_collection", collection="test_collection",
) )
@ -123,7 +122,6 @@ def sample_chunk():
"""Sample text chunk for processing""" """Sample text chunk for processing"""
metadata = Metadata( metadata = Metadata(
id="test-chunk-456", id="test-chunk-456",
user="test_user",
collection="test_collection", collection="test_collection",
) )

View file

@ -322,7 +322,6 @@ This is not JSON at all
assert isinstance(sent_triples, Triples) assert isinstance(sent_triples, Triples)
# Check metadata fields individually since implementation creates new Metadata object # Check metadata fields individually since implementation creates new Metadata object
assert sent_triples.metadata.id == sample_metadata.id assert sent_triples.metadata.id == sample_metadata.id
assert sent_triples.metadata.user == sample_metadata.user
assert sent_triples.metadata.collection == sample_metadata.collection assert sent_triples.metadata.collection == sample_metadata.collection
assert len(sent_triples.triples) == 1 assert len(sent_triples.triples) == 1
assert sent_triples.triples[0].s.iri == "test:subject" assert sent_triples.triples[0].s.iri == "test:subject"
@ -346,7 +345,6 @@ This is not JSON at all
assert isinstance(sent_contexts, EntityContexts) assert isinstance(sent_contexts, EntityContexts)
# Check metadata fields individually since implementation creates new Metadata object # Check metadata fields individually since implementation creates new Metadata object
assert sent_contexts.metadata.id == sample_metadata.id assert sent_contexts.metadata.id == sample_metadata.id
assert sent_contexts.metadata.user == sample_metadata.user
assert sent_contexts.metadata.collection == sample_metadata.collection assert sent_contexts.metadata.collection == sample_metadata.collection
assert len(sent_contexts.entities) == 1 assert len(sent_contexts.entities) == 1
assert sent_contexts.entities[0].entity.iri == "test:entity" assert sent_contexts.entities[0].entity.iri == "test:entity"

View file

@ -312,7 +312,6 @@ class TestObjectExtractionBusinessLogic:
# Arrange # Arrange
metadata = Metadata( metadata = Metadata(
id="test-extraction-001", id="test-extraction-001",
user="test_user",
collection="test_collection", collection="test_collection",
) )
@ -337,7 +336,6 @@ class TestObjectExtractionBusinessLogic:
assert extracted_obj.values[0]["customer_id"] == "CUST001" assert extracted_obj.values[0]["customer_id"] == "CUST001"
assert extracted_obj.confidence == 0.95 assert extracted_obj.confidence == 0.95
assert "John Doe" in extracted_obj.source_span assert "John Doe" in extracted_obj.source_span
assert extracted_obj.metadata.user == "test_user"
def test_config_parsing_error_handling(self): def test_config_parsing_error_handling(self):
"""Test configuration parsing with invalid JSON""" """Test configuration parsing with invalid JSON"""

View file

@ -371,7 +371,6 @@ class TestTripleConstructionLogic:
metadata = Metadata( metadata = Metadata(
id="test-doc-123", id="test-doc-123",
user="test_user",
collection="test_collection", collection="test_collection",
) )
@ -384,7 +383,6 @@ class TestTripleConstructionLogic:
# Assert # Assert
assert isinstance(triples_batch, Triples) assert isinstance(triples_batch, Triples)
assert triples_batch.metadata.id == "test-doc-123" assert triples_batch.metadata.id == "test-doc-123"
assert triples_batch.metadata.user == "test_user"
assert triples_batch.metadata.collection == "test_collection" assert triples_batch.metadata.collection == "test_collection"
assert len(triples_batch.triples) == 2 assert len(triples_batch.triples) == 2

View file

@ -17,7 +17,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Create a mock message for testing""" """Create a mock message for testing"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Create test document embeddings # Create test document embeddings
@ -80,7 +79,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for a single chunk""" """Test storing document embeddings for a single chunk"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -89,7 +87,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify insert was called once for the single chunk with its vector # Verify insert was called once for the single chunk with its vector
processor.vecstore.insert.assert_called_once_with( processor.vecstore.insert.assert_called_once_with(
@ -122,7 +120,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunk (should be skipped)""" """Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -131,7 +128,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify no insert was called for empty chunk # Verify no insert was called for empty chunk
processor.vecstore.insert.assert_not_called() processor.vecstore.insert.assert_not_called()
@ -141,7 +138,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with None chunk_id""" """Test storing document embeddings with None chunk_id"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -150,7 +146,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Note: Implementation passes through None chunk_ids (only skips empty string "") # Note: Implementation passes through None chunk_ids (only skips empty string "")
processor.vecstore.insert.assert_called_once_with( processor.vecstore.insert.assert_called_once_with(
@ -162,7 +158,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with mix of valid and empty chunks""" """Test storing document embeddings with mix of valid and empty chunks"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
valid_chunk = ChunkEmbeddings( valid_chunk = ChunkEmbeddings(
@ -179,7 +174,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [valid_chunk, empty_chunk, another_valid] message.chunks = [valid_chunk, empty_chunk, another_valid]
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify valid chunks were inserted, empty string chunk was skipped # Verify valid chunks were inserted, empty string chunk was skipped
expected_calls = [ expected_calls = [
@ -200,11 +195,10 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunks list""" """Test storing document embeddings with empty chunks list"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
message.chunks = [] message.chunks = []
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify no insert was called # Verify no insert was called
processor.vecstore.insert.assert_not_called() processor.vecstore.insert.assert_not_called()
@ -214,7 +208,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for chunk with no vectors""" """Test storing document embeddings for chunk with no vectors"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -223,7 +216,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify no insert was called (no vectors to insert) # Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called() processor.vecstore.insert.assert_not_called()
@ -233,7 +226,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with different vector dimensions""" """Test storing document embeddings with different vector dimensions"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Each chunk has a single vector of different dimensions # Each chunk has a single vector of different dimensions
@ -251,7 +243,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [chunk1, chunk2, chunk3] message.chunks = [chunk1, chunk2, chunk3]
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify all vectors were inserted regardless of dimension with user/collection parameters # Verify all vectors were inserted regardless of dimension with user/collection parameters
expected_calls = [ expected_calls = [
@ -273,7 +265,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with Unicode content in chunk_id""" """Test storing document embeddings with Unicode content in chunk_id"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -282,7 +273,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify Unicode chunk_id was stored correctly with user/collection parameters # Verify Unicode chunk_id was stored correctly with user/collection parameters
processor.vecstore.insert.assert_called_once_with( processor.vecstore.insert.assert_called_once_with(
@ -294,7 +285,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with long chunk_id""" """Test storing document embeddings with long chunk_id"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Create a long chunk_id # Create a long chunk_id
@ -305,7 +295,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify long chunk_id was inserted with user/collection parameters # Verify long chunk_id was inserted with user/collection parameters
processor.vecstore.insert.assert_called_once_with( processor.vecstore.insert.assert_called_once_with(
@ -317,7 +307,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with whitespace-only chunk""" """Test storing document embeddings with whitespace-only chunk"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -326,7 +315,7 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify whitespace content was inserted (not filtered out) with user/collection parameters # Verify whitespace content was inserted (not filtered out) with user/collection parameters
processor.vecstore.insert.assert_called_once_with( processor.vecstore.insert.assert_called_once_with(
@ -343,12 +332,11 @@ class TestMilvusDocEmbeddingsStorageProcessor:
('test@domain.com', 'test-collection.v1'), ('test@domain.com', 'test-collection.v1'),
] ]
for user, collection in test_cases: for workspace, collection in test_cases:
processor.vecstore.reset_mock() # Reset mock for each test case processor.vecstore.reset_mock() # Reset mock for each test case
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = user
message.metadata.collection = collection message.metadata.collection = collection
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -357,11 +345,11 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings(workspace, message)
# Verify insert was called with the correct user/collection # Verify insert was called with the correct workspace/collection
processor.vecstore.insert.assert_called_once_with( processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Test content", user, collection [0.1, 0.2, 0.3], "Test content", workspace, collection
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@ -370,7 +358,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Store embeddings for user1/collection1 # Store embeddings for user1/collection1
message1 = MagicMock() message1 = MagicMock()
message1.metadata = MagicMock() message1.metadata = MagicMock()
message1.metadata.user = 'user1'
message1.metadata.collection = 'collection1' message1.metadata.collection = 'collection1'
chunk1 = ChunkEmbeddings( chunk1 = ChunkEmbeddings(
chunk_id="User1 content", chunk_id="User1 content",
@ -381,7 +368,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
# Store embeddings for user2/collection2 # Store embeddings for user2/collection2
message2 = MagicMock() message2 = MagicMock()
message2.metadata = MagicMock() message2.metadata = MagicMock()
message2.metadata.user = 'user2'
message2.metadata.collection = 'collection2' message2.metadata.collection = 'collection2'
chunk2 = ChunkEmbeddings( chunk2 = ChunkEmbeddings(
chunk_id="User2 content", chunk_id="User2 content",
@ -389,8 +375,8 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message2.chunks = [chunk2] message2.chunks = [chunk2]
await processor.store_document_embeddings(message1) await processor.store_document_embeddings('user1', message1)
await processor.store_document_embeddings(message2) await processor.store_document_embeddings('user2', message2)
# Verify both calls were made with correct parameters # Verify both calls were made with correct parameters
expected_calls = [ expected_calls = [
@ -411,7 +397,6 @@ class TestMilvusDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with special characters in user/collection names""" """Test storing document embeddings with special characters in user/collection names"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'user@domain.com' # Email-like user
message.metadata.collection = 'test-collection.v1' # Collection with special chars message.metadata.collection = 'test-collection.v1' # Collection with special chars
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -420,9 +405,9 @@ class TestMilvusDocEmbeddingsStorageProcessor:
) )
message.chunks = [chunk] message.chunks = [chunk]
await processor.store_document_embeddings(message) await processor.store_document_embeddings('user@domain.com', message)
# Verify the exact user/collection strings are passed (sanitization happens in DocVectors) # Verify the exact workspace/collection strings are passed (sanitization happens in DocVectors)
processor.vecstore.insert.assert_called_once_with( processor.vecstore.insert.assert_called_once_with(
[0.1, 0.2, 0.3], "Special chars test", 'user@domain.com', 'test-collection.v1' [0.1, 0.2, 0.3], "Special chars test", 'user@domain.com', 'test-collection.v1'
) )

View file

@ -21,7 +21,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Create a mock message for testing""" """Create a mock message for testing"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Create test document embeddings # Create test document embeddings
@ -120,7 +119,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for a single chunk""" """Test storing document embeddings for a single chunk"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -135,7 +133,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2']): with patch('uuid.uuid4', side_effect=['id1', 'id2']):
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify index name and operations (with dimension suffix) # Verify index name and operations (with dimension suffix)
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -185,7 +183,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that writing to non-existent index creates it lazily""" """Test that writing to non-existent index creates it lazily"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -200,7 +197,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'): with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify index was created with correct dimension # Verify index was created with correct dimension
expected_index_name = "d-test_user-test_collection-3" # 3 dimensions expected_index_name = "d-test_user-test_collection-3" # 3 dimensions
@ -217,7 +214,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunk (should be skipped)""" """Test storing document embeddings with empty chunk (should be skipped)"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -229,7 +225,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for empty chunk # Verify no upsert was called for empty chunk
mock_index.upsert.assert_not_called() mock_index.upsert.assert_not_called()
@ -239,7 +235,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with None chunk (should be skipped)""" """Test storing document embeddings with None chunk (should be skipped)"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -251,7 +246,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for None chunk # Verify no upsert was called for None chunk
mock_index.upsert.assert_not_called() mock_index.upsert.assert_not_called()
@ -261,7 +256,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with chunk that decodes to empty string""" """Test storing document embeddings with chunk that decodes to empty string"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -273,7 +267,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called for empty decoded chunk # Verify no upsert was called for empty decoded chunk
mock_index.upsert.assert_not_called() mock_index.upsert.assert_not_called()
@ -283,7 +277,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with different vector dimensions""" """Test storing document embeddings with different vector dimensions"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Each chunk has a single vector of different dimensions # Each chunk has a single vector of different dimensions
@ -325,14 +318,13 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with empty chunks list""" """Test storing document embeddings with empty chunks list"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
message.chunks = [] message.chunks = []
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify no operations were performed # Verify no operations were performed
processor.pinecone.Index.assert_not_called() processor.pinecone.Index.assert_not_called()
@ -343,7 +335,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings for chunk with no vectors""" """Test storing document embeddings for chunk with no vectors"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -355,7 +346,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify no upsert was called (no vectors to insert) # Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called() mock_index.upsert.assert_not_called()
@ -365,7 +356,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that lazy creation happens when index doesn't exist""" """Test that lazy creation happens when index doesn't exist"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -380,7 +370,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'): with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify index was created # Verify index was created
processor.pinecone.create_index.assert_called_once() processor.pinecone.create_index.assert_called_once()
@ -390,7 +380,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test that lazy creation works correctly""" """Test that lazy creation works correctly"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -405,7 +394,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'): with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify index was created and used # Verify index was created and used
processor.pinecone.create_index.assert_called_once() processor.pinecone.create_index.assert_called_once()
@ -416,7 +405,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with Unicode content""" """Test storing document embeddings with Unicode content"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
chunk = ChunkEmbeddings( chunk = ChunkEmbeddings(
@ -430,7 +418,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'): with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify Unicode content was properly decoded and stored # Verify Unicode content was properly decoded and stored
call_args = mock_index.upsert.call_args call_args = mock_index.upsert.call_args
@ -442,7 +430,6 @@ class TestPineconeDocEmbeddingsStorageProcessor:
"""Test storing document embeddings with large document chunks""" """Test storing document embeddings with large document chunks"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Create a large document chunk # Create a large document chunk
@ -458,7 +445,7 @@ class TestPineconeDocEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', return_value='test-id'): with patch('uuid.uuid4', return_value='test-id'):
await processor.store_document_embeddings(message) await processor.store_document_embeddings('test_user', message)
# Verify large content was stored # Verify large content was stored
call_args = mock_index.upsert.call_args call_args = mock_index.upsert.call_args

View file

@ -84,7 +84,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with chunks and vectors # Create mock message with chunks and vectors
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection' mock_message.metadata.collection = 'test_collection'
mock_chunk = MagicMock() mock_chunk = MagicMock()
@ -94,7 +93,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]
# Act # Act
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings('test_user', mock_message)
# Assert # Assert
# Verify collection existence was checked (with dimension suffix) # Verify collection existence was checked (with dimension suffix)
@ -138,7 +137,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple chunks # Create mock message with multiple chunks
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection' mock_message.metadata.collection = 'multi_collection'
mock_chunk1 = MagicMock() mock_chunk1 = MagicMock()
@ -152,7 +150,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2] mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act # Act
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings('multi_user', mock_message)
# Assert # Assert
# Should be called twice (once per chunk) # Should be called twice (once per chunk)
@ -198,7 +196,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple chunks, each having a single vector # Create mock message with multiple chunks, each having a single vector
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection' mock_message.metadata.collection = 'vector_collection'
mock_chunk1 = MagicMock() mock_chunk1 = MagicMock()
@ -216,7 +213,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3] mock_message.chunks = [mock_chunk1, mock_chunk2, mock_chunk3]
# Act # Act
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings('vector_user', mock_message)
# Assert # Assert
# Should be called 3 times (once per chunk) # Should be called 3 times (once per chunk)
@ -255,7 +252,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with empty chunk_id # Create mock message with empty chunk_id
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection' mock_message.metadata.collection = 'empty_collection'
mock_chunk_empty = MagicMock() mock_chunk_empty = MagicMock()
@ -265,7 +261,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk_empty] mock_message.chunks = [mock_chunk_empty]
# Act # Act
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings('empty_user', mock_message)
# Assert # Assert
# Should not call upsert for empty chunk_ids # Should not call upsert for empty chunk_ids
@ -298,7 +294,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection' mock_message.metadata.collection = 'new_collection'
mock_chunk = MagicMock() mock_chunk = MagicMock()
@ -308,7 +303,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]
# Act # Act
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings('new_user', mock_message)
# Assert - collection should be lazily created # Assert - collection should be lazily created
expected_collection = 'd_new_user_new_collection_5' # 5 dimensions expected_collection = 'd_new_user_new_collection_5' # 5 dimensions
@ -350,7 +345,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'error_user'
mock_message.metadata.collection = 'error_collection' mock_message.metadata.collection = 'error_collection'
mock_chunk = MagicMock() mock_chunk = MagicMock()
@ -361,7 +355,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Act & Assert - should propagate the creation error # Act & Assert - should propagate the creation error
with pytest.raises(Exception, match="Connection error"): with pytest.raises(Exception, match="Connection error"):
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings('error_user', mock_message)
@patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.doc_embeddings.qdrant.write.QdrantClient')
@patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid') @patch('trustgraph.storage.doc_embeddings.qdrant.write.uuid')
@ -388,7 +382,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create first mock message # Create first mock message
mock_message1 = MagicMock() mock_message1 = MagicMock()
mock_message1.metadata.user = 'cache_user'
mock_message1.metadata.collection = 'cache_collection' mock_message1.metadata.collection = 'cache_collection'
mock_chunk1 = MagicMock() mock_chunk1 = MagicMock()
@ -398,7 +391,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message1.chunks = [mock_chunk1] mock_message1.chunks = [mock_chunk1]
# First call # First call
await processor.store_document_embeddings(mock_message1) await processor.store_document_embeddings('cache_user', mock_message1)
# Reset mock to track second call # Reset mock to track second call
mock_qdrant_instance.reset_mock() mock_qdrant_instance.reset_mock()
@ -406,7 +399,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create second mock message with same dimensions # Create second mock message with same dimensions
mock_message2 = MagicMock() mock_message2 = MagicMock()
mock_message2.metadata.user = 'cache_user'
mock_message2.metadata.collection = 'cache_collection' mock_message2.metadata.collection = 'cache_collection'
mock_chunk2 = MagicMock() mock_chunk2 = MagicMock()
@ -416,7 +408,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message2.chunks = [mock_chunk2] mock_message2.chunks = [mock_chunk2]
# Act - Second call with same collection # Act - Second call with same collection
await processor.store_document_embeddings(mock_message2) await processor.store_document_embeddings('cache_user', mock_message2)
# Assert # Assert
expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions expected_collection = 'd_cache_user_cache_collection_3' # 3 dimensions
@ -452,7 +444,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with chunks of different dimensions # Create mock message with chunks of different dimensions
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'dim_user'
mock_message.metadata.collection = 'dim_collection' mock_message.metadata.collection = 'dim_collection'
mock_chunk1 = MagicMock() mock_chunk1 = MagicMock()
@ -466,7 +457,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk1, mock_chunk2] mock_message.chunks = [mock_chunk1, mock_chunk2]
# Act # Act
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings('dim_user', mock_message)
# Assert # Assert
# Should check existence of DIFFERENT collections for each dimension # Should check existence of DIFFERENT collections for each dimension
@ -526,7 +517,6 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with URI-style chunk_id # Create mock message with URI-style chunk_id
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'uri_user'
mock_message.metadata.collection = 'uri_collection' mock_message.metadata.collection = 'uri_collection'
mock_chunk = MagicMock() mock_chunk = MagicMock()
@ -536,7 +526,7 @@ class TestQdrantDocEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.chunks = [mock_chunk] mock_message.chunks = [mock_chunk]
# Act # Act
await processor.store_document_embeddings(mock_message) await processor.store_document_embeddings('uri_user', mock_message)
# Assert # Assert
# Verify the chunk_id was stored correctly # Verify the chunk_id was stored correctly

View file

@ -17,7 +17,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Create a mock message for testing""" """Create a mock message for testing"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Create test entities with embeddings # Create test entities with embeddings
@ -80,7 +79,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for a single entity""" """Test storing graph embeddings for a single entity"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -89,7 +87,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
) )
message.entities = [entity] message.entities = [entity]
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify insert was called once with the full vector # Verify insert was called once with the full vector
processor.vecstore.insert.assert_called_once() processor.vecstore.insert.assert_called_once()
@ -125,7 +123,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entity value (should be skipped)""" """Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -134,7 +131,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
) )
message.entities = [entity] message.entities = [entity]
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called for empty entity # Verify no insert was called for empty entity
processor.vecstore.insert.assert_not_called() processor.vecstore.insert.assert_not_called()
@ -144,7 +141,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with None entity value (should be skipped)""" """Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -153,7 +149,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
) )
message.entities = [entity] message.entities = [entity]
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called for None entity # Verify no insert was called for None entity
processor.vecstore.insert.assert_not_called() processor.vecstore.insert.assert_not_called()
@ -163,7 +159,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with mix of valid and invalid entities""" """Test storing graph embeddings with mix of valid and invalid entities"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
valid_entity = EntityEmbeddings( valid_entity = EntityEmbeddings(
@ -183,7 +178,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
) )
message.entities = [valid_entity, empty_entity, none_entity] message.entities = [valid_entity, empty_entity, none_entity]
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify only valid entity was inserted with user/collection/chunk_id parameters # Verify only valid entity was inserted with user/collection/chunk_id parameters
processor.vecstore.insert.assert_called_once_with( processor.vecstore.insert.assert_called_once_with(
@ -196,11 +191,10 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entities list""" """Test storing graph embeddings with empty entities list"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
message.entities = [] message.entities = []
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called # Verify no insert was called
processor.vecstore.insert.assert_not_called() processor.vecstore.insert.assert_not_called()
@ -210,7 +204,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for entity with no vectors""" """Test storing graph embeddings for entity with no vectors"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -219,7 +212,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
) )
message.entities = [entity] message.entities = [entity]
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify no insert was called (no vectors to insert) # Verify no insert was called (no vectors to insert)
processor.vecstore.insert.assert_not_called() processor.vecstore.insert.assert_not_called()
@ -229,7 +222,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with different vector dimensions""" """Test storing graph embeddings with different vector dimensions"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Each entity has a single vector of different dimensions # Each entity has a single vector of different dimensions
@ -247,7 +239,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
) )
message.entities = [entity1, entity2, entity3] message.entities = [entity1, entity2, entity3]
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify all vectors were inserted regardless of dimension # Verify all vectors were inserted regardless of dimension
expected_calls = [ expected_calls = [
@ -267,7 +259,6 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for both URI and literal entities""" """Test storing graph embeddings for both URI and literal entities"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
uri_entity = EntityEmbeddings( uri_entity = EntityEmbeddings(
@ -280,7 +271,7 @@ class TestMilvusGraphEmbeddingsStorageProcessor:
) )
message.entities = [uri_entity, literal_entity] message.entities = [uri_entity, literal_entity]
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify both entities were inserted # Verify both entities were inserted
expected_calls = [ expected_calls = [

View file

@ -21,7 +21,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Create a mock message for testing""" """Create a mock message for testing"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Create test entity embeddings (each entity has a single vector) # Create test entity embeddings (each entity has a single vector)
@ -124,7 +123,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for a single entity""" """Test storing graph embeddings for a single entity"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -139,7 +137,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1']): with patch('uuid.uuid4', side_effect=['id1']):
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify index name and operations (with dimension suffix) # Verify index name and operations (with dimension suffix)
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -189,7 +187,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that writing to non-existent index creates it lazily""" """Test that writing to non-existent index creates it lazily"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -204,7 +201,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'): with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify index was created with correct dimension # Verify index was created with correct dimension
expected_index_name = "t-test_user-test_collection-3" # 3 dimensions expected_index_name = "t-test_user-test_collection-3" # 3 dimensions
@ -221,7 +218,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entity value (should be skipped)""" """Test storing graph embeddings with empty entity value (should be skipped)"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -233,7 +229,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called for empty entity # Verify no upsert was called for empty entity
mock_index.upsert.assert_not_called() mock_index.upsert.assert_not_called()
@ -243,7 +239,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with None entity value (should be skipped)""" """Test storing graph embeddings with None entity value (should be skipped)"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -255,7 +250,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called for None entity # Verify no upsert was called for None entity
mock_index.upsert.assert_not_called() mock_index.upsert.assert_not_called()
@ -265,7 +260,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with different vector dimensions""" """Test storing graph embeddings with different vector dimensions"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Each entity has a single vector of different dimensions # Each entity has a single vector of different dimensions
@ -288,7 +282,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.has_index.return_value = True processor.pinecone.has_index.return_value = True
with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']): with patch('uuid.uuid4', side_effect=['id1', 'id2', 'id3']):
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify different indexes were used for different dimensions # Verify different indexes were used for different dimensions
index_calls = processor.pinecone.Index.call_args_list index_calls = processor.pinecone.Index.call_args_list
@ -307,14 +301,13 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings with empty entities list""" """Test storing graph embeddings with empty entities list"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
message.entities = [] message.entities = []
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify no operations were performed # Verify no operations were performed
processor.pinecone.Index.assert_not_called() processor.pinecone.Index.assert_not_called()
@ -325,7 +318,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test storing graph embeddings for entity with no vectors""" """Test storing graph embeddings for entity with no vectors"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -337,7 +329,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
mock_index = MagicMock() mock_index = MagicMock()
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify no upsert was called (no vectors to insert) # Verify no upsert was called (no vectors to insert)
mock_index.upsert.assert_not_called() mock_index.upsert.assert_not_called()
@ -347,7 +339,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that lazy creation happens when index doesn't exist""" """Test that lazy creation happens when index doesn't exist"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -362,7 +353,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'): with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify index was created # Verify index was created
processor.pinecone.create_index.assert_called_once() processor.pinecone.create_index.assert_called_once()
@ -372,7 +363,6 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
"""Test that lazy creation works correctly""" """Test that lazy creation works correctly"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
entity = EntityEmbeddings( entity = EntityEmbeddings(
@ -387,7 +377,7 @@ class TestPineconeGraphEmbeddingsStorageProcessor:
processor.pinecone.Index.return_value = mock_index processor.pinecone.Index.return_value = mock_index
with patch('uuid.uuid4', return_value='test-id'): with patch('uuid.uuid4', return_value='test-id'):
await processor.store_graph_embeddings(message) await processor.store_graph_embeddings('test_user', message)
# Verify index was created and used # Verify index was created and used
processor.pinecone.create_index.assert_called_once() processor.pinecone.create_index.assert_called_once()

View file

@ -64,7 +64,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with entities and vectors # Create mock message with entities and vectors
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection' mock_message.metadata.collection = 'test_collection'
mock_entity = MagicMock() mock_entity = MagicMock()
@ -75,7 +74,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity] mock_message.entities = [mock_entity]
# Act # Act
await processor.store_graph_embeddings(mock_message) await processor.store_graph_embeddings('test_user', mock_message)
# Assert # Assert
# Verify collection existence was checked (with dimension suffix) # Verify collection existence was checked (with dimension suffix)
@ -118,7 +117,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with multiple entities # Create mock message with multiple entities
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'multi_user'
mock_message.metadata.collection = 'multi_collection' mock_message.metadata.collection = 'multi_collection'
mock_entity1 = MagicMock() mock_entity1 = MagicMock()
@ -134,7 +132,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity1, mock_entity2] mock_message.entities = [mock_entity1, mock_entity2]
# Act # Act
await processor.store_graph_embeddings(mock_message) await processor.store_graph_embeddings('multi_user', mock_message)
# Assert # Assert
# Should be called twice (once per entity) # Should be called twice (once per entity)
@ -179,7 +177,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with three entities # Create mock message with three entities
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'vector_user'
mock_message.metadata.collection = 'vector_collection' mock_message.metadata.collection = 'vector_collection'
mock_entity1 = MagicMock() mock_entity1 = MagicMock()
@ -200,7 +197,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity1, mock_entity2, mock_entity3] mock_message.entities = [mock_entity1, mock_entity2, mock_entity3]
# Act # Act
await processor.store_graph_embeddings(mock_message) await processor.store_graph_embeddings('vector_user', mock_message)
# Assert # Assert
# Should be called 3 times (once per entity) # Should be called 3 times (once per entity)
@ -238,7 +235,6 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
# Create mock message with empty entity value # Create mock message with empty entity value
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'empty_user'
mock_message.metadata.collection = 'empty_collection' mock_message.metadata.collection = 'empty_collection'
mock_entity_empty = MagicMock() mock_entity_empty = MagicMock()
@ -253,7 +249,7 @@ class TestQdrantGraphEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_message.entities = [mock_entity_empty, mock_entity_none] mock_message.entities = [mock_entity_empty, mock_entity_none]
# Act # Act
await processor.store_graph_embeddings(mock_message) await processor.store_graph_embeddings('empty_user', mock_message)
# Assert # Assert
# Should not call upsert for empty entities # Should not call upsert for empty entities

View file

@ -1,5 +1,5 @@
""" """
Tests for Memgraph user/collection isolation in storage service Tests for Memgraph workspace/collection isolation in storage service.
""" """
import pytest import pytest
@ -8,12 +8,12 @@ from unittest.mock import MagicMock, patch
from trustgraph.storage.triples.memgraph.write import Processor from trustgraph.storage.triples.memgraph.write import Processor
class TestMemgraphUserCollectionIsolation: class TestMemgraphWorkspaceCollectionIsolation:
"""Test cases for Memgraph storage service with user/collection isolation""" """Test cases for Memgraph storage service with workspace/collection isolation"""
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db): def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
"""Test that storage creates both legacy and user/collection indexes""" """Test that storage creates both legacy and workspace/collection indexes"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock() mock_session = MagicMock()
@ -21,18 +21,17 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
# Verify all indexes were attempted (4 legacy + 4 user/collection = 8 total) # 4 legacy + 4 workspace/collection = 8 total
assert mock_session.run.call_count == 8 assert mock_session.run.call_count == 8
# Check some specific index creation calls
expected_calls = [ expected_calls = [
"CREATE INDEX ON :Node", "CREATE INDEX ON :Node",
"CREATE INDEX ON :Node(uri)", "CREATE INDEX ON :Node(uri)",
"CREATE INDEX ON :Literal", "CREATE INDEX ON :Literal",
"CREATE INDEX ON :Literal(value)", "CREATE INDEX ON :Literal(value)",
"CREATE INDEX ON :Node(user)", "CREATE INDEX ON :Node(workspace)",
"CREATE INDEX ON :Node(collection)", "CREATE INDEX ON :Node(collection)",
"CREATE INDEX ON :Literal(user)", "CREATE INDEX ON :Literal(workspace)",
"CREATE INDEX ON :Literal(collection)" "CREATE INDEX ON :Literal(collection)"
] ]
@ -41,14 +40,13 @@ class TestMemgraphUserCollectionIsolation:
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_triples_with_user_collection(self, mock_graph_db): async def test_store_triples_with_workspace_collection(self, mock_graph_db):
"""Test that store_triples includes user/collection in all operations""" """Test that store_triples includes workspace/collection in all operations"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock() mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock() mock_result = MagicMock()
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1 mock_summary.counters.nodes_created = 1
@ -58,45 +56,39 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
# Create mock triple with URI object from trustgraph.schema import IRI
triple = MagicMock() triple = MagicMock()
triple.s.value = "http://example.com/subject" triple.s.type = IRI
triple.p.value = "http://example.com/predicate" triple.s.iri = "http://example.com/subject"
triple.o.value = "http://example.com/object" triple.p.type = IRI
triple.o.is_uri = True triple.p.iri = "http://example.com/predicate"
triple.o.type = IRI
triple.o.iri = "http://example.com/object"
# Create mock message with metadata
mock_message = MagicMock() mock_message = MagicMock()
mock_message.triples = [triple] mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection" mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message) await processor.store_triples("test_workspace", mock_message)
# Verify user/collection parameters were passed to all operations # create_node (subject), create_node (object), relate_node = 3 calls
# Should have: create_node (subject), create_node (object), relate_node = 3 calls
assert mock_driver.execute_query.call_count == 3 assert mock_driver.execute_query.call_count == 3
# Check that user and collection were included in all calls for c in mock_driver.execute_query.call_args_list:
for call in mock_driver.execute_query.call_args_list: kwargs = c.kwargs
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] assert kwargs['workspace'] == "test_workspace"
assert 'user' in call_kwargs assert kwargs['collection'] == "test_collection"
assert 'collection' in call_kwargs
assert call_kwargs['user'] == "test_user"
assert call_kwargs['collection'] == "test_collection"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_triples_with_default_user_collection(self, mock_graph_db): async def test_store_triples_with_default_collection(self, mock_graph_db):
"""Test that defaults are used when user/collection not provided in metadata""" """Test that default collection is used when not provided in metadata"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock() mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock() mock_result = MagicMock()
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1 mock_summary.counters.nodes_created = 1
@ -106,38 +98,35 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
# Create mock triple from trustgraph.schema import IRI, LITERAL
triple = MagicMock() triple = MagicMock()
triple.s.value = "http://example.com/subject" triple.s.type = IRI
triple.p.value = "http://example.com/predicate" triple.s.iri = "http://example.com/subject"
triple.p.type = IRI
triple.p.iri = "http://example.com/predicate"
triple.o.type = LITERAL
triple.o.value = "literal_value" triple.o.value = "literal_value"
triple.o.is_uri = False
# Create mock message without user/collection metadata
mock_message = MagicMock() mock_message = MagicMock()
mock_message.triples = [triple] mock_message.triples = [triple]
mock_message.metadata.user = None
mock_message.metadata.collection = None mock_message.metadata.collection = None
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message) await processor.store_triples("default", mock_message)
# Verify defaults were used for c in mock_driver.execute_query.call_args_list:
for call in mock_driver.execute_query.call_args_list: kwargs = c.kwargs
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] assert kwargs['workspace'] == "default"
assert call_kwargs['user'] == "default" assert kwargs['collection'] == "default"
assert call_kwargs['collection'] == "default"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_node_includes_user_collection(self, mock_graph_db): def test_create_node_includes_workspace_collection(self, mock_graph_db):
"""Test that create_node includes user/collection properties""" """Test that create_node includes workspace/collection properties"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock() mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock() mock_result = MagicMock()
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1 mock_summary.counters.nodes_created = 1
@ -147,25 +136,24 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
processor.create_node("http://example.com/node", "test_user", "test_collection") processor.create_node("http://example.com/node", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with( mock_driver.execute_query.assert_called_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/node", uri="http://example.com/node",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="memgraph" database_="memgraph"
) )
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_create_literal_includes_user_collection(self, mock_graph_db): def test_create_literal_includes_workspace_collection(self, mock_graph_db):
"""Test that create_literal includes user/collection properties""" """Test that create_literal includes workspace/collection properties"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock() mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock() mock_result = MagicMock()
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1 mock_summary.counters.nodes_created = 1
@ -175,25 +163,24 @@ class TestMemgraphUserCollectionIsolation:
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
processor.create_literal("test_value", "test_user", "test_collection") processor.create_literal("test_value", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with( mock_driver.execute_query.assert_called_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="test_value", value="test_value",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="memgraph" database_="memgraph"
) )
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_relate_node_includes_user_collection(self, mock_graph_db): def test_relate_node_includes_workspace_collection(self, mock_graph_db):
"""Test that relate_node includes user/collection properties""" """Test that relate_node includes workspace/collection properties"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock() mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock() mock_result = MagicMock()
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0 mock_summary.counters.nodes_created = 0
@ -207,31 +194,30 @@ class TestMemgraphUserCollectionIsolation:
"http://example.com/subject", "http://example.com/subject",
"http://example.com/predicate", "http://example.com/predicate",
"http://example.com/object", "http://example.com/object",
"test_user", "test_workspace",
"test_collection" "test_collection"
) )
mock_driver.execute_query.assert_called_with( mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject", src="http://example.com/subject",
dest="http://example.com/object", dest="http://example.com/object",
uri="http://example.com/predicate", uri="http://example.com/predicate",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="memgraph" database_="memgraph"
) )
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
def test_relate_literal_includes_user_collection(self, mock_graph_db): def test_relate_literal_includes_workspace_collection(self, mock_graph_db):
"""Test that relate_literal includes user/collection properties""" """Test that relate_literal includes workspace/collection properties"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock() mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock() mock_result = MagicMock()
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 0 mock_summary.counters.nodes_created = 0
@ -245,18 +231,18 @@ class TestMemgraphUserCollectionIsolation:
"http://example.com/subject", "http://example.com/subject",
"http://example.com/predicate", "http://example.com/predicate",
"literal_value", "literal_value",
"test_user", "test_workspace",
"test_collection" "test_collection"
) )
mock_driver.execute_query.assert_called_with( mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject", src="http://example.com/subject",
dest="literal_value", dest="literal_value",
uri="http://example.com/predicate", uri="http://example.com/predicate",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="memgraph" database_="memgraph"
) )
@ -264,18 +250,13 @@ class TestMemgraphUserCollectionIsolation:
def test_add_args_includes_memgraph_parameters(self): def test_add_args_includes_memgraph_parameters(self):
"""Test that add_args properly configures Memgraph-specific parameters""" """Test that add_args properly configures Memgraph-specific parameters"""
from argparse import ArgumentParser from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser() parser = ArgumentParser()
# Mock the parent class add_args method
with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args: with patch('trustgraph.storage.triples.memgraph.write.TriplesStoreService.add_args') as mock_parent_add_args:
Processor.add_args(parser) Processor.add_args(parser)
# Verify parent add_args was called
mock_parent_add_args.assert_called_once() mock_parent_add_args.assert_called_once()
# Verify our specific arguments were added with Memgraph defaults
args = parser.parse_args([]) args = parser.parse_args([])
assert hasattr(args, 'graph_host') assert hasattr(args, 'graph_host')
@ -288,19 +269,18 @@ class TestMemgraphUserCollectionIsolation:
assert args.database == 'memgraph' assert args.database == 'memgraph'
class TestMemgraphUserCollectionRegression: class TestMemgraphWorkspaceCollectionRegression:
"""Regression tests to ensure user/collection isolation prevents data leakage""" """Regression tests to ensure workspace/collection isolation prevents data leakage"""
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regression_no_cross_user_data_access(self, mock_graph_db): async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
"""Regression test: Ensure users cannot access each other's data""" """Regression test: Ensure workspaces cannot access each other's data"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock() mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock() mock_result = MagicMock()
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1 mock_summary.counters.nodes_created = 1
@ -310,39 +290,37 @@ class TestMemgraphUserCollectionRegression:
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
# Store data for user1 from trustgraph.schema import IRI, LITERAL
triple = MagicMock() triple = MagicMock()
triple.s.value = "http://example.com/subject" triple.s.type = IRI
triple.p.value = "http://example.com/predicate" triple.s.iri = "http://example.com/subject"
triple.o.value = "user1_data" triple.p.type = IRI
triple.o.is_uri = False triple.p.iri = "http://example.com/predicate"
triple.o.type = LITERAL
triple.o.value = "ws1_data"
message_user1 = MagicMock() message_ws1 = MagicMock()
message_user1.triples = [triple] message_ws1.triples = [triple]
message_user1.metadata.user = "user1" message_ws1.metadata.collection = "collection1"
message_user1.metadata.collection = "collection1"
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message_user1) await processor.store_triples("workspace1", message_ws1)
# Verify that all storage operations included user1/collection1 parameters for c in mock_driver.execute_query.call_args_list:
for call in mock_driver.execute_query.call_args_list: kwargs = c.kwargs
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] if 'workspace' in kwargs:
if 'user' in call_kwargs: assert kwargs['workspace'] == "workspace1"
assert call_kwargs['user'] == "user1" assert kwargs['collection'] == "collection1"
assert call_kwargs['collection'] == "collection1"
@patch('trustgraph.storage.triples.memgraph.write.GraphDatabase') @patch('trustgraph.storage.triples.memgraph.write.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regression_same_uri_different_users(self, mock_graph_db): async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
"""Regression test: Same URI can exist for different users without conflict""" """Regression test: Same URI can exist in different workspaces without conflict"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
mock_session = MagicMock() mock_session = MagicMock()
mock_driver.session.return_value.__enter__.return_value = mock_session mock_driver.session.return_value.__enter__.return_value = mock_session
# Mock execute_query response
mock_result = MagicMock() mock_result = MagicMock()
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1 mock_summary.counters.nodes_created = 1
@ -352,18 +330,15 @@ class TestMemgraphUserCollectionRegression:
processor = Processor(taskgroup=MagicMock()) processor = Processor(taskgroup=MagicMock())
# Same URI for different users should create separate nodes processor.create_node("http://example.com/same-uri", "workspace1", "collection1")
processor.create_node("http://example.com/same-uri", "user1", "collection1") processor.create_node("http://example.com/same-uri", "workspace2", "collection2")
processor.create_node("http://example.com/same-uri", "user2", "collection2")
# Verify both calls were made with different user/collection parameters calls = mock_driver.execute_query.call_args_list[-2:]
calls = mock_driver.execute_query.call_args_list[-2:] # Get last 2 calls
call1_kwargs = calls[0].kwargs if hasattr(calls[0], 'kwargs') else calls[0][1] k1 = calls[0].kwargs
call2_kwargs = calls[1].kwargs if hasattr(calls[1], 'kwargs') else calls[1][1] k2 = calls[1].kwargs
assert call1_kwargs['user'] == "user1" and call1_kwargs['collection'] == "collection1" assert k1['workspace'] == "workspace1" and k1['collection'] == "collection1"
assert call2_kwargs['user'] == "user2" and call2_kwargs['collection'] == "collection2" assert k2['workspace'] == "workspace2" and k2['collection'] == "collection2"
# Both should have the same URI but different user/collection assert k1['uri'] == k2['uri'] == "http://example.com/same-uri"
assert call1_kwargs['uri'] == call2_kwargs['uri'] == "http://example.com/same-uri"

View file

@ -1,5 +1,5 @@
""" """
Tests for Neo4j user/collection isolation in triples storage and query Tests for Neo4j workspace/collection isolation in triples storage and query.
""" """
import pytest import pytest
@ -11,12 +11,12 @@ from trustgraph.schema import Triples, Triple, Term, Metadata, IRI, LITERAL
from trustgraph.schema import TriplesQueryRequest from trustgraph.schema import TriplesQueryRequest
class TestNeo4jUserCollectionIsolation: class TestNeo4jWorkspaceCollectionIsolation:
"""Test cases for Neo4j user/collection isolation functionality""" """Test cases for Neo4j workspace/collection isolation functionality"""
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
def test_storage_creates_indexes_with_user_collection(self, mock_graph_db): def test_storage_creates_indexes_with_workspace_collection(self, mock_graph_db):
"""Test that storage service creates compound indexes for user/collection""" """Test that storage service creates compound indexes for workspace/collection"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
@ -25,25 +25,23 @@ class TestNeo4jUserCollectionIsolation:
processor = StorageProcessor(taskgroup=taskgroup_mock) processor = StorageProcessor(taskgroup=taskgroup_mock)
# Verify both legacy and new compound indexes are created
expected_indexes = [ expected_indexes = [
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)", "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)", "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", "CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", "CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", "CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)" "CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
] ]
# Check that all expected indexes were created
for expected_query in expected_indexes: for expected_query in expected_indexes:
mock_session.run.assert_any_call(expected_query) mock_session.run.assert_any_call(expected_query)
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_triples_with_user_collection(self, mock_graph_db): async def test_store_triples_with_workspace_collection(self, mock_graph_db):
"""Test that triples are stored with user/collection properties""" """Test that triples are stored with workspace/collection properties"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
@ -52,12 +50,7 @@ class TestNeo4jUserCollectionIsolation:
processor = StorageProcessor(taskgroup=taskgroup_mock) processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create test message with user/collection metadata metadata = Metadata(id="test-id", collection="test_collection")
metadata = Metadata(
id="test-id",
user="test_user",
collection="test_collection"
)
triple = Triple( triple = Triple(
s=Term(type=IRI, iri="http://example.com/subject"), s=Term(type=IRI, iri="http://example.com/subject"),
@ -65,45 +58,39 @@ class TestNeo4jUserCollectionIsolation:
o=Term(type=LITERAL, value="literal_value") o=Term(type=LITERAL, value="literal_value")
) )
message = Triples( message = Triples(metadata=metadata, triples=[triple])
metadata=metadata,
triples=[triple]
)
# Mock execute_query to return summaries
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1 mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10 mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message) await processor.store_triples("test_workspace", message)
# Verify nodes and relationships were created with user/collection properties
expected_calls = [ expected_calls = [
call( call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject", uri="http://example.com/subject",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_='neo4j' database_='neo4j'
), ),
call( call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="literal_value", value="literal_value",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_='neo4j' database_='neo4j'
), ),
call( call(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject", src="http://example.com/subject",
dest="literal_value", dest="literal_value",
uri="http://example.com/predicate", uri="http://example.com/predicate",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_='neo4j' database_='neo4j'
) )
@ -114,8 +101,8 @@ class TestNeo4jUserCollectionIsolation:
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_store_triples_with_default_user_collection(self, mock_graph_db): async def test_store_triples_with_default_collection(self, mock_graph_db):
"""Test that default user/collection are used when not provided""" """Test that default collection is used when not provided"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
@ -124,7 +111,6 @@ class TestNeo4jUserCollectionIsolation:
processor = StorageProcessor(taskgroup=taskgroup_mock) processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create test message without user/collection
metadata = Metadata(id="test-id") metadata = Metadata(id="test-id")
triple = Triple( triple = Triple(
@ -133,49 +119,40 @@ class TestNeo4jUserCollectionIsolation:
o=Term(type=IRI, iri="http://example.com/object") o=Term(type=IRI, iri="http://example.com/object")
) )
message = Triples( message = Triples(metadata=metadata, triples=[triple])
metadata=metadata,
triples=[triple]
)
# Mock execute_query
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1 mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10 mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message) await processor.store_triples("default", message)
# Verify defaults were used
mock_driver.execute_query.assert_any_call( mock_driver.execute_query.assert_any_call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject", uri="http://example.com/subject",
user="default", workspace="default",
collection="default", collection="default",
database_='neo4j' database_='neo4j'
) )
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_triples_filters_by_user_collection(self, mock_graph_db): async def test_query_triples_filters_by_workspace_collection(self, mock_graph_db):
"""Test that query service filters results by user/collection""" """Test that query service filters results by workspace/collection"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock()) processor = QueryProcessor(taskgroup=MagicMock())
# Create test query
query = TriplesQueryRequest( query = TriplesQueryRequest(
user="test_user",
collection="test_collection", collection="test_collection",
s=Term(type=IRI, iri="http://example.com/subject"), s=Term(type=IRI, iri="http://example.com/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"), p=Term(type=IRI, iri="http://example.com/predicate"),
o=None o=None
) )
# Mock query results
mock_records = [ mock_records = [
MagicMock(data=lambda: {"dest": "http://example.com/object1"}), MagicMock(data=lambda: {"dest": "http://example.com/object1"}),
MagicMock(data=lambda: {"dest": "literal_value"}) MagicMock(data=lambda: {"dest": "literal_value"})
@ -183,64 +160,48 @@ class TestNeo4jUserCollectionIsolation:
mock_driver.execute_query.return_value = (mock_records, MagicMock(), MagicMock()) mock_driver.execute_query.return_value = (mock_records, MagicMock(), MagicMock())
result = await processor.query_triples(query) await processor.query_triples("test_workspace", query)
# Verify queries include user/collection filters
expected_literal_query = ( expected_literal_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) " "(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest" "RETURN dest.value as dest"
) )
expected_node_query = (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) "
"RETURN dest.uri as dest"
)
# Check that queries were executed with user/collection parameters
calls = mock_driver.execute_query.call_args_list calls = mock_driver.execute_query.call_args_list
assert any( assert any(
expected_literal_query in str(call) and expected_literal_query in str(c) and
"user='test_user'" in str(call) and "workspace='test_workspace'" in str(c) and
"collection='test_collection'" in str(call) "collection='test_collection'" in str(c)
for call in calls for c in calls
) )
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_triples_with_default_user_collection(self, mock_graph_db): async def test_query_triples_with_default_collection(self, mock_graph_db):
"""Test that query service uses defaults when user/collection not provided""" """Test that query service uses default collection when not provided"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock()) processor = QueryProcessor(taskgroup=MagicMock())
# Create test query without user/collection query = TriplesQueryRequest(s=None, p=None, o=None)
query = TriplesQueryRequest(
s=None,
p=None,
o=None
)
# Mock empty results
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
result = await processor.query_triples(query) await processor.query_triples("default", query)
# Verify defaults were used in queries
calls = mock_driver.execute_query.call_args_list calls = mock_driver.execute_query.call_args_list
assert any( assert any(
"user='default'" in str(call) and "collection='default'" in str(call) "workspace='default'" in str(c) and "collection='default'" in str(c)
for call in calls for c in calls
) )
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_data_isolation_between_users(self, mock_graph_db): async def test_data_isolation_between_workspaces(self, mock_graph_db):
"""Test that data from different users is properly isolated""" """Test that data from different workspaces is properly isolated"""
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
@ -249,102 +210,89 @@ class TestNeo4jUserCollectionIsolation:
processor = StorageProcessor(taskgroup=taskgroup_mock) processor = StorageProcessor(taskgroup=taskgroup_mock)
# Create messages for different users message_ws1 = Triples(
message_user1 = Triples( metadata=Metadata(collection="coll1"),
metadata=Metadata(user="user1", collection="coll1"),
triples=[ triples=[
Triple( Triple(
s=Term(type=IRI, iri="http://example.com/user1/subject"), s=Term(type=IRI, iri="http://example.com/ws1/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"), p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="user1_data") o=Term(type=LITERAL, value="ws1_data")
) )
] ]
) )
message_user2 = Triples( message_ws2 = Triples(
metadata=Metadata(user="user2", collection="coll2"), metadata=Metadata(collection="coll2"),
triples=[ triples=[
Triple( Triple(
s=Term(type=IRI, iri="http://example.com/user2/subject"), s=Term(type=IRI, iri="http://example.com/ws2/subject"),
p=Term(type=IRI, iri="http://example.com/predicate"), p=Term(type=IRI, iri="http://example.com/predicate"),
o=Term(type=LITERAL, value="user2_data") o=Term(type=LITERAL, value="ws2_data")
) )
] ]
) )
# Mock execute_query
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1 mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10 mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
# Store data for both users await processor.store_triples("workspace1", message_ws1)
await processor.store_triples(message_user1) await processor.store_triples("workspace2", message_ws2)
await processor.store_triples(message_user2)
# Verify user1 data was stored with user1/coll1
mock_driver.execute_query.assert_any_call( mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="user1_data", value="ws1_data",
user="user1", workspace="workspace1",
collection="coll1", collection="coll1",
database_='neo4j' database_='neo4j'
) )
# Verify user2 data was stored with user2/coll2
mock_driver.execute_query.assert_any_call( mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="user2_data", value="ws2_data",
user="user2", workspace="workspace2",
collection="coll2", collection="coll2",
database_='neo4j' database_='neo4j'
) )
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wildcard_query_respects_user_collection(self, mock_graph_db): async def test_wildcard_query_respects_workspace_collection(self, mock_graph_db):
"""Test that wildcard queries still filter by user/collection""" """Test that wildcard queries still filter by workspace/collection"""
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock()) processor = QueryProcessor(taskgroup=MagicMock())
# Create wildcard query (all nulls) with user/collection
query = TriplesQueryRequest( query = TriplesQueryRequest(
user="test_user",
collection="test_collection", collection="test_collection",
s=None, s=None, p=None, o=None,
p=None,
o=None
) )
# Mock results
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
result = await processor.query_triples(query) await processor.query_triples("test_workspace", query)
# Verify wildcard queries include user/collection filters
wildcard_query = ( wildcard_query = (
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) " "(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest" "RETURN src.uri as src, rel.uri as rel, dest.value as dest"
) )
calls = mock_driver.execute_query.call_args_list calls = mock_driver.execute_query.call_args_list
assert any( assert any(
wildcard_query in str(call) and wildcard_query in str(c) and
"user='test_user'" in str(call) and "workspace='test_workspace'" in str(c) and
"collection='test_collection'" in str(call) "collection='test_collection'" in str(c)
for call in calls for c in calls
) )
def test_add_args_includes_neo4j_parameters(self): def test_add_args_includes_neo4j_parameters(self):
"""Test that add_args includes Neo4j-specific parameters""" """Test that add_args includes Neo4j-specific parameters"""
from argparse import ArgumentParser from argparse import ArgumentParser
from unittest.mock import patch
parser = ArgumentParser() parser = ArgumentParser()
@ -358,61 +306,55 @@ class TestNeo4jUserCollectionIsolation:
assert hasattr(args, 'password') assert hasattr(args, 'password')
assert hasattr(args, 'database') assert hasattr(args, 'database')
# Check defaults
assert args.graph_host == 'bolt://neo4j:7687' assert args.graph_host == 'bolt://neo4j:7687'
assert args.username == 'neo4j' assert args.username == 'neo4j'
assert args.password == 'password' assert args.password == 'password'
assert args.database == 'neo4j' assert args.database == 'neo4j'
class TestNeo4jUserCollectionRegression: class TestNeo4jWorkspaceCollectionRegression:
"""Regression tests to ensure user/collection isolation prevents data leaks""" """Regression tests to ensure workspace/collection isolation prevents data leaks"""
@patch('trustgraph.query.triples.neo4j.service.GraphDatabase') @patch('trustgraph.query.triples.neo4j.service.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regression_no_cross_user_data_access(self, mock_graph_db): async def test_regression_no_cross_workspace_data_access(self, mock_graph_db):
""" """
Regression test: Ensure user1 cannot access user2's data Regression test: Ensure workspace1 cannot access workspace2's data.
This test guards against the bug where all users shared the same Guards against a bug where all data shared the same Neo4j graph
Neo4j graph space, causing data contamination between users. space, causing data contamination between workspaces.
""" """
mock_driver = MagicMock() mock_driver = MagicMock()
mock_graph_db.driver.return_value = mock_driver mock_graph_db.driver.return_value = mock_driver
processor = QueryProcessor(taskgroup=MagicMock()) processor = QueryProcessor(taskgroup=MagicMock())
# User1 queries for all triples query_ws1 = TriplesQueryRequest(
query_user1 = TriplesQueryRequest(
user="user1",
collection="collection1", collection="collection1",
s=None, p=None, o=None s=None, p=None, o=None
) )
# Mock that the database has data but none matching user1/collection1
mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock()) mock_driver.execute_query.return_value = ([], MagicMock(), MagicMock())
result = await processor.query_triples(query_user1) result = await processor.query_triples("workspace1", query_ws1)
# Verify empty results (user1 cannot see other users' data)
assert len(result) == 0 assert len(result) == 0
# Verify the query included user/collection filters
calls = mock_driver.execute_query.call_args_list calls = mock_driver.execute_query.call_args_list
for call in calls: for c in calls:
query_str = str(call) query_str = str(c)
if "MATCH" in query_str: if "MATCH" in query_str:
assert "user: $user" in query_str or "user='user1'" in query_str assert "workspace: $workspace" in query_str or "workspace='workspace1'" in query_str
assert "collection: $collection" in query_str or "collection='collection1'" in query_str assert "collection: $collection" in query_str or "collection='collection1'" in query_str
@patch('trustgraph.storage.triples.neo4j.write.GraphDatabase') @patch('trustgraph.storage.triples.neo4j.write.GraphDatabase')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regression_same_uri_different_users(self, mock_graph_db): async def test_regression_same_uri_different_workspaces(self, mock_graph_db):
""" """
Regression test: Same URI in different user contexts should create separate nodes Regression test: Same URI in different workspace contexts should create separate nodes.
This ensures that http://example.com/entity for user1 is completely separate Ensures http://example.com/entity in workspace1 is completely
from http://example.com/entity for user2. separate from the same URI in workspace2.
""" """
taskgroup_mock = MagicMock() taskgroup_mock = MagicMock()
mock_driver = MagicMock() mock_driver = MagicMock()
@ -422,57 +364,53 @@ class TestNeo4jUserCollectionRegression:
processor = StorageProcessor(taskgroup=taskgroup_mock) processor = StorageProcessor(taskgroup=taskgroup_mock)
# Same URI for different users
shared_uri = "http://example.com/shared_entity" shared_uri = "http://example.com/shared_entity"
message_user1 = Triples( message_ws1 = Triples(
metadata=Metadata(user="user1", collection="coll1"), metadata=Metadata(collection="coll1"),
triples=[ triples=[
Triple( Triple(
s=Term(type=IRI, iri=shared_uri), s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"), p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="user1_value") o=Term(type=LITERAL, value="ws1_value")
) )
] ]
) )
message_user2 = Triples( message_ws2 = Triples(
metadata=Metadata(user="user2", collection="coll2"), metadata=Metadata(collection="coll2"),
triples=[ triples=[
Triple( Triple(
s=Term(type=IRI, iri=shared_uri), s=Term(type=IRI, iri=shared_uri),
p=Term(type=IRI, iri="http://example.com/p"), p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="user2_value") o=Term(type=LITERAL, value="ws2_value")
) )
] ]
) )
# Mock execute_query
mock_summary = MagicMock() mock_summary = MagicMock()
mock_summary.counters.nodes_created = 1 mock_summary.counters.nodes_created = 1
mock_summary.result_available_after = 10 mock_summary.result_available_after = 10
mock_driver.execute_query.return_value.summary = mock_summary mock_driver.execute_query.return_value.summary = mock_summary
# Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message_user1) await processor.store_triples("workspace1", message_ws1)
await processor.store_triples(message_user2) await processor.store_triples("workspace2", message_ws2)
# Verify two separate nodes were created with same URI but different user/collection ws1_node_call = call(
user1_node_call = call( "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",
uri=shared_uri, uri=shared_uri,
user="user1", workspace="workspace1",
collection="coll1", collection="coll1",
database_='neo4j' database_='neo4j'
) )
user2_node_call = call( ws2_node_call = call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=shared_uri, uri=shared_uri,
user="user2", workspace="workspace2",
collection="coll2", collection="coll2",
database_='neo4j' database_='neo4j'
) )
mock_driver.execute_query.assert_has_calls([user1_node_call, user2_node_call], any_order=True) mock_driver.execute_query.assert_has_calls([ws1_node_call, ws2_node_call], any_order=True)

View file

@ -1,3 +1,12 @@
def _flow_mock(workspace):
"""Build a mock flow object that is callable and exposes .workspace."""
from unittest.mock import MagicMock
f = MagicMock()
f.workspace = workspace
return f
""" """
Unit tests for trustgraph.storage.row_embeddings.qdrant.write Unit tests for trustgraph.storage.row_embeddings.qdrant.write
Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant. Tests the Stage 2 processor that stores pre-computed row embeddings in Qdrant.
@ -92,13 +101,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
collection_name = processor.get_collection_name( collection_name = processor.get_collection_name(
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
schema_name="customer_data", schema_name="customer_data",
dimension=384 dimension=384
) )
assert collection_name == "rows_test_user_test_collection_customer_data_384" assert collection_name == "rows_test_workspace_test_collection_customer_data_384"
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_ensure_collection_creates_new(self, mock_qdrant_client): async def test_ensure_collection_creates_new(self, mock_qdrant_client):
@ -185,11 +194,10 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
} }
processor = Processor(**config) processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {} processor.known_collections[('test_workspace', 'test_collection')] = {}
# Create embeddings message # Create embeddings message
metadata = MagicMock() metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection' metadata.collection = 'test_collection'
metadata.id = 'doc-123' metadata.id = 'doc-123'
@ -210,14 +218,14 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Verify upsert was called # Verify upsert was called
mock_qdrant_instance.upsert.assert_called_once() mock_qdrant_instance.upsert.assert_called_once()
# Verify upsert parameters # Verify upsert parameters
upsert_call_args = mock_qdrant_instance.upsert.call_args upsert_call_args = mock_qdrant_instance.upsert.call_args
assert upsert_call_args[1]['collection_name'] == 'rows_test_user_test_collection_customers_3' assert upsert_call_args[1]['collection_name'] == 'rows_test_workspace_test_collection_customers_3'
point = upsert_call_args[1]['points'][0] point = upsert_call_args[1]['points'][0]
assert point.vector == [0.1, 0.2, 0.3] assert point.vector == [0.1, 0.2, 0.3]
@ -243,10 +251,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
} }
processor = Processor(**config) processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {} processor.known_collections[('test_workspace', 'test_collection')] = {}
metadata = MagicMock() metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection' metadata.collection = 'test_collection'
metadata.id = 'doc-123' metadata.id = 'doc-123'
@ -267,7 +274,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should be called once for the single embedding # Should be called once for the single embedding
assert mock_qdrant_instance.upsert.call_count == 1 assert mock_qdrant_instance.upsert.call_count == 1
@ -287,10 +294,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
} }
processor = Processor(**config) processor = Processor(**config)
processor.known_collections[('test_user', 'test_collection')] = {} processor.known_collections[('test_workspace', 'test_collection')] = {}
metadata = MagicMock() metadata = MagicMock()
metadata.user = 'test_user'
metadata.collection = 'test_collection' metadata.collection = 'test_collection'
metadata.id = 'doc-123' metadata.id = 'doc-123'
@ -311,7 +317,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should not call upsert for empty vectors # Should not call upsert for empty vectors
mock_qdrant_instance.upsert.assert_not_called() mock_qdrant_instance.upsert.assert_not_called()
@ -334,7 +340,6 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
# No collections registered # No collections registered
metadata = MagicMock() metadata = MagicMock()
metadata.user = 'unknown_user'
metadata.collection = 'unknown_collection' metadata.collection = 'unknown_collection'
metadata.id = 'doc-123' metadata.id = 'doc-123'
@ -354,7 +359,7 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_msg = MagicMock() mock_msg = MagicMock()
mock_msg.value.return_value = embeddings_msg mock_msg.value.return_value = embeddings_msg
await processor.on_embeddings(mock_msg, MagicMock(), MagicMock()) await processor.on_embeddings(mock_msg, MagicMock(), _flow_mock('test_workspace'))
# Should not call upsert for unknown collection # Should not call upsert for unknown collection
mock_qdrant_instance.upsert.assert_not_called() mock_qdrant_instance.upsert.assert_not_called()
@ -368,11 +373,11 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
# Mock collections list # Mock collections list
mock_coll1 = MagicMock() mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_schema1_384' mock_coll1.name = 'rows_test_workspace_test_collection_schema1_384'
mock_coll2 = MagicMock() mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_schema2_384' mock_coll2.name = 'rows_test_workspace_test_collection_schema2_384'
mock_coll3 = MagicMock() mock_coll3 = MagicMock()
mock_coll3.name = 'rows_other_user_other_collection_schema_384' mock_coll3.name = 'rows_other_workspace_other_collection_schema_384'
mock_collections = MagicMock() mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3] mock_collections.collections = [mock_coll1, mock_coll2, mock_coll3]
@ -386,15 +391,15 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
} }
processor = Processor(**config) processor = Processor(**config)
processor.created_collections.add('rows_test_user_test_collection_schema1_384') processor.created_collections.add('rows_test_workspace_test_collection_schema1_384')
await processor.delete_collection('test_user', 'test_collection') await processor.delete_collection('test_workspace', 'test_collection')
# Should delete only the matching collections # Should delete only the matching collections
assert mock_qdrant_instance.delete_collection.call_count == 2 assert mock_qdrant_instance.delete_collection.call_count == 2
# Verify the cached collection was removed # Verify the cached collection was removed
assert 'rows_test_user_test_collection_schema1_384' not in processor.created_collections assert 'rows_test_workspace_test_collection_schema1_384' not in processor.created_collections
@patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient') @patch('trustgraph.storage.row_embeddings.qdrant.write.QdrantClient')
async def test_delete_collection_schema(self, mock_qdrant_client): async def test_delete_collection_schema(self, mock_qdrant_client):
@ -404,9 +409,9 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
mock_qdrant_instance = MagicMock() mock_qdrant_instance = MagicMock()
mock_coll1 = MagicMock() mock_coll1 = MagicMock()
mock_coll1.name = 'rows_test_user_test_collection_customers_384' mock_coll1.name = 'rows_test_workspace_test_collection_customers_384'
mock_coll2 = MagicMock() mock_coll2 = MagicMock()
mock_coll2.name = 'rows_test_user_test_collection_orders_384' mock_coll2.name = 'rows_test_workspace_test_collection_orders_384'
mock_collections = MagicMock() mock_collections = MagicMock()
mock_collections.collections = [mock_coll1, mock_coll2] mock_collections.collections = [mock_coll1, mock_coll2]
@ -422,13 +427,13 @@ class TestQdrantRowEmbeddingsStorage(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
await processor.delete_collection_schema( await processor.delete_collection_schema(
'test_user', 'test_collection', 'customers' 'test_workspace', 'test_collection', 'customers'
) )
# Should only delete the customers schema collection # Should only delete the customers schema collection
mock_qdrant_instance.delete_collection.assert_called_once() mock_qdrant_instance.delete_collection.assert_called_once()
call_args = mock_qdrant_instance.delete_collection.call_args[0] call_args = mock_qdrant_instance.delete_collection.call_args[0]
assert call_args[0] == 'rows_test_user_test_collection_customers_384' assert call_args[0] == 'rows_test_workspace_test_collection_customers_384'
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -187,7 +187,7 @@ class TestRowsCassandraStorageLogic:
) )
} }
} }
processor.tables_initialized = {"test_user"} processor.tables_initialized = {"default"}
processor.registered_partitions = set() processor.registered_partitions = set()
processor.session = MagicMock() processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -204,7 +204,6 @@ class TestRowsCassandraStorageLogic:
test_obj = ExtractedObject( test_obj = ExtractedObject(
metadata=Metadata( metadata=Metadata(
id="test-001", id="test-001",
user="test_user",
collection="test_collection", collection="test_collection",
), ),
schema_name="test_schema", schema_name="test_schema",
@ -227,7 +226,7 @@ class TestRowsCassandraStorageLogic:
values = insert_call[0][2] values = insert_call[0][2]
# Verify using unified rows table # Verify using unified rows table
assert "INSERT INTO test_user.rows" in insert_cql assert "INSERT INTO default.rows" in insert_cql
# Values should be: (collection, schema_name, index_name, index_value, data, source) # Values should be: (collection, schema_name, index_name, index_value, data, source)
assert values[0] == "test_collection" # collection assert values[0] == "test_collection" # collection
@ -254,7 +253,7 @@ class TestRowsCassandraStorageLogic:
) )
} }
} }
processor.tables_initialized = {"test_user"} processor.tables_initialized = {"default"}
processor.registered_partitions = set() processor.registered_partitions = set()
processor.session = MagicMock() processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -270,7 +269,6 @@ class TestRowsCassandraStorageLogic:
test_obj = ExtractedObject( test_obj = ExtractedObject(
metadata=Metadata( metadata=Metadata(
id="test-001", id="test-001",
user="test_user",
collection="test_collection", collection="test_collection",
), ),
schema_name="multi_index_schema", schema_name="multi_index_schema",
@ -315,7 +313,7 @@ class TestRowsCassandraStorageBatchLogic:
) )
} }
} }
processor.tables_initialized = {"test_user"} processor.tables_initialized = {"default"}
processor.registered_partitions = set() processor.registered_partitions = set()
processor.session = MagicMock() processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -332,7 +330,6 @@ class TestRowsCassandraStorageBatchLogic:
batch_obj = ExtractedObject( batch_obj = ExtractedObject(
metadata=Metadata( metadata=Metadata(
id="batch-001", id="batch-001",
user="test_user",
collection="batch_collection", collection="batch_collection",
), ),
schema_name="batch_schema", schema_name="batch_schema",
@ -373,7 +370,7 @@ class TestRowsCassandraStorageBatchLogic:
) )
} }
} }
processor.tables_initialized = {"test_user"} processor.tables_initialized = {"default"}
processor.registered_partitions = set() processor.registered_partitions = set()
processor.session = MagicMock() processor.session = MagicMock()
processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor) processor.sanitize_name = Processor.sanitize_name.__get__(processor, Processor)
@ -388,7 +385,6 @@ class TestRowsCassandraStorageBatchLogic:
empty_batch_obj = ExtractedObject( empty_batch_obj = ExtractedObject(
metadata=Metadata( metadata=Metadata(
id="empty-001", id="empty-001",
user="test_user",
collection="empty_collection", collection="empty_collection",
), ),
schema_name="empty_schema", schema_name="empty_schema",
@ -446,7 +442,7 @@ class TestUnifiedTableStructure:
def test_ensure_tables_idempotent(self): def test_ensure_tables_idempotent(self):
"""Test that ensure_tables is idempotent""" """Test that ensure_tables is idempotent"""
processor = MagicMock() processor = MagicMock()
processor.tables_initialized = {"test_user"} # Already initialized processor.tables_initialized = {"default"} # Already initialized
processor.session = MagicMock() processor.session = MagicMock()
processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor) processor.ensure_tables = Processor.ensure_tables.__get__(processor, Processor)

View file

@ -102,11 +102,10 @@ class TestCassandraStorageProcessor:
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1' mock_message.metadata.collection = 'collection1'
mock_message.triples = [] mock_message.triples = []
await processor.store_triples(mock_message) await processor.store_triples('user1', mock_message)
# Verify KnowledgeGraph was called with auth parameters # Verify KnowledgeGraph was called with auth parameters
mock_kg_class.assert_called_once_with( mock_kg_class.assert_called_once_with(
@ -129,11 +128,10 @@ class TestCassandraStorageProcessor:
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'user2'
mock_message.metadata.collection = 'collection2' mock_message.metadata.collection = 'collection2'
mock_message.triples = [] mock_message.triples = []
await processor.store_triples(mock_message) await processor.store_triples('user2', mock_message)
# Verify KnowledgeGraph was called without auth parameters # Verify KnowledgeGraph was called without auth parameters
mock_kg_class.assert_called_once_with( mock_kg_class.assert_called_once_with(
@ -154,16 +152,15 @@ class TestCassandraStorageProcessor:
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1' mock_message.metadata.collection = 'collection1'
mock_message.triples = [] mock_message.triples = []
# First call should create TrustGraph # First call should create TrustGraph
await processor.store_triples(mock_message) await processor.store_triples('user1', mock_message)
assert mock_kg_class.call_count == 1 assert mock_kg_class.call_count == 1
# Second call with same table should reuse TrustGraph # Second call with same table should reuse TrustGraph
await processor.store_triples(mock_message) await processor.store_triples('user1', mock_message)
assert mock_kg_class.call_count == 1 # Should not increase assert mock_kg_class.call_count == 1 # Should not increase
@pytest.mark.asyncio @pytest.mark.asyncio
@ -205,11 +202,10 @@ class TestCassandraStorageProcessor:
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1' mock_message.metadata.collection = 'collection1'
mock_message.triples = [triple1, triple2] mock_message.triples = [triple1, triple2]
await processor.store_triples(mock_message) await processor.store_triples('user1', mock_message)
# Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters) # Verify both triples were inserted (with g=, otype=, dtype=, lang= parameters)
assert mock_tg_instance.insert.call_count == 2 assert mock_tg_instance.insert.call_count == 2
@ -234,11 +230,10 @@ class TestCassandraStorageProcessor:
# Create mock message with empty triples # Create mock message with empty triples
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1' mock_message.metadata.collection = 'collection1'
mock_message.triples = [] mock_message.triples = []
await processor.store_triples(mock_message) await processor.store_triples('user1', mock_message)
# Verify no triples were inserted # Verify no triples were inserted
mock_tg_instance.insert.assert_not_called() mock_tg_instance.insert.assert_not_called()
@ -255,12 +250,11 @@ class TestCassandraStorageProcessor:
# Create mock message # Create mock message
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1' mock_message.metadata.collection = 'collection1'
mock_message.triples = [] mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"): with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message) await processor.store_triples('user1', mock_message)
# Verify sleep was called before re-raising # Verify sleep was called before re-raising
mock_sleep.assert_called_once_with(1) mock_sleep.assert_called_once_with(1)
@ -361,21 +355,19 @@ class TestCassandraStorageProcessor:
# First message with table1 # First message with table1
mock_message1 = MagicMock() mock_message1 = MagicMock()
mock_message1.metadata.user = 'user1'
mock_message1.metadata.collection = 'collection1' mock_message1.metadata.collection = 'collection1'
mock_message1.triples = [] mock_message1.triples = []
await processor.store_triples(mock_message1) await processor.store_triples('user1', mock_message1)
assert processor.table == 'user1' assert processor.table == 'user1'
assert processor.tg == mock_tg_instance1 assert processor.tg == mock_tg_instance1
# Second message with different table # Second message with different table
mock_message2 = MagicMock() mock_message2 = MagicMock()
mock_message2.metadata.user = 'user2'
mock_message2.metadata.collection = 'collection2' mock_message2.metadata.collection = 'collection2'
mock_message2.triples = [] mock_message2.triples = []
await processor.store_triples(mock_message2) await processor.store_triples('user2', mock_message2)
assert processor.table == 'user2' assert processor.table == 'user2'
assert processor.tg == mock_tg_instance2 assert processor.tg == mock_tg_instance2
@ -407,11 +399,10 @@ class TestCassandraStorageProcessor:
triple.g = None triple.g = None
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'test_user'
mock_message.metadata.collection = 'test_collection' mock_message.metadata.collection = 'test_collection'
mock_message.triples = [triple] mock_message.triples = [triple]
await processor.store_triples(mock_message) await processor.store_triples('test_workspace', mock_message)
# Verify the triple was inserted with special characters preserved # Verify the triple was inserted with special characters preserved
mock_tg_instance.insert.assert_called_once_with( mock_tg_instance.insert.assert_called_once_with(
@ -440,12 +431,11 @@ class TestCassandraStorageProcessor:
mock_kg_class.side_effect = Exception("Connection failed") mock_kg_class.side_effect = Exception("Connection failed")
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'new_user'
mock_message.metadata.collection = 'new_collection' mock_message.metadata.collection = 'new_collection'
mock_message.triples = [] mock_message.triples = []
with pytest.raises(Exception, match="Connection failed"): with pytest.raises(Exception, match="Connection failed"):
await processor.store_triples(mock_message) await processor.store_triples('new_user', mock_message)
# Table should remain unchanged since self.table = table happens after try/except # Table should remain unchanged since self.table = table happens after try/except
assert processor.table == ('old_user', 'old_collection') assert processor.table == ('old_user', 'old_collection')
@ -468,11 +458,10 @@ class TestCassandraPerformanceOptimizations:
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1' mock_message.metadata.collection = 'collection1'
mock_message.triples = [] mock_message.triples = []
await processor.store_triples(mock_message) await processor.store_triples('user1', mock_message)
# Verify KnowledgeGraph instance uses legacy mode # Verify KnowledgeGraph instance uses legacy mode
assert mock_tg_instance is not None assert mock_tg_instance is not None
@ -489,11 +478,10 @@ class TestCassandraPerformanceOptimizations:
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1' mock_message.metadata.collection = 'collection1'
mock_message.triples = [] mock_message.triples = []
await processor.store_triples(mock_message) await processor.store_triples('user1', mock_message)
# Verify KnowledgeGraph instance is in optimized mode # Verify KnowledgeGraph instance is in optimized mode
assert mock_tg_instance is not None assert mock_tg_instance is not None
@ -523,11 +511,10 @@ class TestCassandraPerformanceOptimizations:
triple.g = None triple.g = None
mock_message = MagicMock() mock_message = MagicMock()
mock_message.metadata.user = 'user1'
mock_message.metadata.collection = 'collection1' mock_message.metadata.collection = 'collection1'
mock_message.triples = [triple] mock_message.triples = [triple]
await processor.store_triples(mock_message) await processor.store_triples('user1', mock_message)
# Verify insert was called for the triple (implementation details tested in KnowledgeGraph) # Verify insert was called for the triple (implementation details tested in KnowledgeGraph)
mock_tg_instance.insert.assert_called_once_with( mock_tg_instance.insert.assert_called_once_with(

View file

@ -17,7 +17,6 @@ class TestFalkorDBStorageProcessor:
"""Create a mock message for testing""" """Create a mock message for testing"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Create a test triple # Create a test triple
@ -89,13 +88,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result processor.io.query.return_value = mock_result
processor.create_node(test_uri, 'test_user', 'test_collection') processor.create_node(test_uri, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with( processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
params={ params={
"uri": test_uri, "uri": test_uri,
"user": 'test_user', "workspace": 'test_workspace',
"collection": 'test_collection', "collection": 'test_collection',
}, },
) )
@ -109,13 +108,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result processor.io.query.return_value = mock_result
processor.create_literal(test_value, 'test_user', 'test_collection') processor.create_literal(test_value, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with( processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
params={ params={
"value": test_value, "value": test_value,
"user": 'test_user', "workspace": 'test_workspace',
"collection": 'test_collection', "collection": 'test_collection',
}, },
) )
@ -132,17 +131,17 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result processor.io.query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri, 'test_user', 'test_collection') processor.relate_node(src_uri, pred_uri, dest_uri, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with( processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={ params={
"src": src_uri, "src": src_uri,
"dest": dest_uri, "dest": dest_uri,
"uri": pred_uri, "uri": pred_uri,
"user": 'test_user', "workspace": 'test_workspace',
"collection": 'test_collection', "collection": 'test_collection',
}, },
) )
@ -159,17 +158,17 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result processor.io.query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value, 'test_user', 'test_collection') processor.relate_literal(src_uri, pred_uri, literal_value, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with( processor.io.query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={ params={
"src": src_uri, "src": src_uri,
"dest": literal_value, "dest": literal_value,
"uri": pred_uri, "uri": pred_uri,
"user": 'test_user', "workspace": 'test_workspace',
"collection": 'test_collection', "collection": 'test_collection',
}, },
) )
@ -179,7 +178,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing triple with URI object""" """Test storing triple with URI object"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
triple = Triple( triple = Triple(
@ -200,21 +198,21 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message) await processor.store_triples('test_workspace', message)
# Verify queries were called in the correct order # Verify queries were called in the correct order
expected_calls = [ expected_calls = [
# Create subject node # Create subject node
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}), {"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create object node # Create object node
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
{"params": {"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection"}}), {"params": {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create relationship # Create relationship
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " (("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",), "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}), {"params": {"src": "http://example.com/subject", "dest": "http://example.com/object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
] ]
assert processor.io.query.call_count == 3 assert processor.io.query.call_count == 3
@ -242,16 +240,16 @@ class TestFalkorDBStorageProcessor:
# Verify queries were called in the correct order # Verify queries were called in the correct order
expected_calls = [ expected_calls = [
# Create subject node # Create subject node
(("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})",), (("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",),
{"params": {"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection"}}), {"params": {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create literal object # Create literal object
(("MERGE (n:Literal {value: $value, user: $user, collection: $collection})",), (("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",),
{"params": {"value": "literal object", "user": "test_user", "collection": "test_collection"}}), {"params": {"value": "literal object", "workspace": "test_workspace", "collection": "test_collection"}}),
# Create relationship # Create relationship
(("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " (("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)",), "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",),
{"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "user": "test_user", "collection": "test_collection"}}), {"params": {"src": "http://example.com/subject", "dest": "literal object", "uri": "http://example.com/predicate", "workspace": "test_workspace", "collection": "test_collection"}}),
] ]
assert processor.io.query.call_count == 3 assert processor.io.query.call_count == 3
@ -265,7 +263,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing multiple triples""" """Test storing multiple triples"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
triple1 = Triple( triple1 = Triple(
@ -291,7 +288,7 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message) await processor.store_triples('test_workspace', message)
# Verify total number of queries (3 per triple) # Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6 assert processor.io.query.call_count == 6
@ -313,7 +310,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing empty triples list""" """Test storing empty triples list"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
message.triples = [] message.triples = []
@ -323,7 +319,7 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message) await processor.store_triples('test_workspace', message)
# Verify no queries were made # Verify no queries were made
processor.io.query.assert_not_called() processor.io.query.assert_not_called()
@ -333,7 +329,6 @@ class TestFalkorDBStorageProcessor:
"""Test storing triples with mixed URI and literal objects""" """Test storing triples with mixed URI and literal objects"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
triple1 = Triple( triple1 = Triple(
@ -359,7 +354,7 @@ class TestFalkorDBStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message) await processor.store_triples('test_workspace', message)
# Verify total number of queries (3 per triple) # Verify total number of queries (3 per triple)
assert processor.io.query.call_count == 6 assert processor.io.query.call_count == 6
@ -450,13 +445,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result processor.io.query.return_value = mock_result
processor.create_node(test_uri, 'test_user', 'test_collection') processor.create_node(test_uri, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with( processor.io.query.assert_called_once_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
params={ params={
"uri": test_uri, "uri": test_uri,
"user": 'test_user', "workspace": 'test_workspace',
"collection": 'test_collection', "collection": 'test_collection',
}, },
) )
@ -470,13 +465,13 @@ class TestFalkorDBStorageProcessor:
processor.io.query.return_value = mock_result processor.io.query.return_value = mock_result
processor.create_literal(test_value, 'test_user', 'test_collection') processor.create_literal(test_value, 'test_workspace', 'test_collection')
processor.io.query.assert_called_once_with( processor.io.query.assert_called_once_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
params={ params={
"value": test_value, "value": test_value,
"user": 'test_user', "workspace": 'test_workspace',
"collection": 'test_collection', "collection": 'test_collection',
}, },
) )

View file

@ -17,7 +17,6 @@ class TestMemgraphStorageProcessor:
"""Create a mock message for testing""" """Create a mock message for testing"""
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
# Create a test triple # Create a test triple
@ -43,7 +42,7 @@ class TestMemgraphStorageProcessor:
taskgroup=MagicMock(), taskgroup=MagicMock(),
id='test-memgraph-storage', id='test-memgraph-storage',
graph_host='bolt://localhost:7687', graph_host='bolt://localhost:7687',
username='test_user', username='test_workspace',
password='test_pass', password='test_pass',
database='test_db' database='test_db'
) )
@ -105,9 +104,9 @@ class TestMemgraphStorageProcessor:
"CREATE INDEX ON :Node(uri)", "CREATE INDEX ON :Node(uri)",
"CREATE INDEX ON :Literal", "CREATE INDEX ON :Literal",
"CREATE INDEX ON :Literal(value)", "CREATE INDEX ON :Literal(value)",
"CREATE INDEX ON :Node(user)", "CREATE INDEX ON :Node(workspace)",
"CREATE INDEX ON :Node(collection)", "CREATE INDEX ON :Node(collection)",
"CREATE INDEX ON :Literal(user)", "CREATE INDEX ON :Literal(workspace)",
"CREATE INDEX ON :Literal(collection)" "CREATE INDEX ON :Literal(collection)"
] ]
@ -145,12 +144,12 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result processor.io.execute_query.return_value = mock_result
processor.create_node(test_uri, "test_user", "test_collection") processor.create_node(test_uri, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with( processor.io.execute_query.assert_called_once_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=test_uri, uri=test_uri,
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_=processor.db database_=processor.db
) )
@ -166,12 +165,12 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result processor.io.execute_query.return_value = mock_result
processor.create_literal(test_value, "test_user", "test_collection") processor.create_literal(test_value, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with( processor.io.execute_query.assert_called_once_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=test_value, value=test_value,
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_=processor.db database_=processor.db
) )
@ -190,14 +189,14 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result processor.io.execute_query.return_value = mock_result
processor.relate_node(src_uri, pred_uri, dest_uri, "test_user", "test_collection") processor.relate_node(src_uri, pred_uri, dest_uri, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with( processor.io.execute_query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src_uri, dest=dest_uri, uri=pred_uri, src=src_uri, dest=dest_uri, uri=pred_uri,
user="test_user", collection="test_collection", workspace="test_workspace", collection="test_collection",
database_=processor.db database_=processor.db
) )
@ -215,14 +214,14 @@ class TestMemgraphStorageProcessor:
processor.io.execute_query.return_value = mock_result processor.io.execute_query.return_value = mock_result
processor.relate_literal(src_uri, pred_uri, literal_value, "test_user", "test_collection") processor.relate_literal(src_uri, pred_uri, literal_value, "test_workspace", "test_collection")
processor.io.execute_query.assert_called_once_with( processor.io.execute_query.assert_called_once_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src_uri, dest=literal_value, uri=pred_uri, src=src_uri, dest=literal_value, uri=pred_uri,
user="test_user", collection="test_collection", workspace="test_workspace", collection="test_collection",
database_=processor.db database_=processor.db
) )
@ -236,22 +235,22 @@ class TestMemgraphStorageProcessor:
o=Term(type=IRI, iri='http://example.com/object') o=Term(type=IRI, iri='http://example.com/object')
) )
processor.create_triple(mock_tx, triple, "test_user", "test_collection") processor.create_triple(mock_tx, triple, "test_workspace", "test_collection")
# Verify transaction calls # Verify transaction calls
expected_calls = [ expected_calls = [
# Create subject node # Create subject node
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}), {'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create object node # Create object node
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{'uri': 'http://example.com/object', 'user': 'test_user', 'collection': 'test_collection'}), {'uri': 'http://example.com/object', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create relationship # Create relationship
("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " ("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate', {'src': 'http://example.com/subject', 'dest': 'http://example.com/object', 'uri': 'http://example.com/predicate',
'user': 'test_user', 'collection': 'test_collection'}) 'workspace': 'test_workspace', 'collection': 'test_collection'})
] ]
assert mock_tx.run.call_count == 3 assert mock_tx.run.call_count == 3
@ -270,22 +269,22 @@ class TestMemgraphStorageProcessor:
o=Term(type=LITERAL, value='literal object') o=Term(type=LITERAL, value='literal object')
) )
processor.create_triple(mock_tx, triple, "test_user", "test_collection") processor.create_triple(mock_tx, triple, "test_workspace", "test_collection")
# Verify transaction calls # Verify transaction calls
expected_calls = [ expected_calls = [
# Create subject node # Create subject node
("MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", ("MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{'uri': 'http://example.com/subject', 'user': 'test_user', 'collection': 'test_collection'}), {'uri': 'http://example.com/subject', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create literal object # Create literal object
("MERGE (n:Literal {value: $value, user: $user, collection: $collection})", ("MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
{'value': 'literal object', 'user': 'test_user', 'collection': 'test_collection'}), {'value': 'literal object', 'workspace': 'test_workspace', 'collection': 'test_collection'}),
# Create relationship # Create relationship
("MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " ("MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate', {'src': 'http://example.com/subject', 'dest': 'literal object', 'uri': 'http://example.com/predicate',
'user': 'test_user', 'collection': 'test_collection'}) 'workspace': 'test_workspace', 'collection': 'test_collection'})
] ]
assert mock_tx.run.call_count == 3 assert mock_tx.run.call_count == 3
@ -323,7 +322,7 @@ class TestMemgraphStorageProcessor:
# Verify user/collection parameters were included # Verify user/collection parameters were included
for call in processor.io.execute_query.call_args_list: for call in processor.io.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert 'user' in call_kwargs assert 'workspace' in call_kwargs
assert 'collection' in call_kwargs assert 'collection' in call_kwargs
@pytest.mark.asyncio @pytest.mark.asyncio
@ -343,7 +342,6 @@ class TestMemgraphStorageProcessor:
# Create message with multiple triples # Create message with multiple triples
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
triple1 = Triple( triple1 = Triple(
@ -364,7 +362,7 @@ class TestMemgraphStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message) await processor.store_triples('test_workspace', message)
# Verify execute_query was called: # Verify execute_query was called:
# Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls # Triple1: create_node(s) + create_literal(o) + relate_literal = 3 calls
@ -375,7 +373,7 @@ class TestMemgraphStorageProcessor:
# Verify user/collection parameters were included in all calls # Verify user/collection parameters were included in all calls
for call in processor.io.execute_query.call_args_list: for call in processor.io.execute_query.call_args_list:
call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1] call_kwargs = call.kwargs if hasattr(call, 'kwargs') else call[1]
assert call_kwargs['user'] == 'test_user' assert call_kwargs['workspace'] == 'test_workspace'
assert call_kwargs['collection'] == 'test_collection' assert call_kwargs['collection'] == 'test_collection'
@pytest.mark.asyncio @pytest.mark.asyncio
@ -389,7 +387,6 @@ class TestMemgraphStorageProcessor:
message = MagicMock() message = MagicMock()
message.metadata = MagicMock() message.metadata = MagicMock()
message.metadata.user = 'test_user'
message.metadata.collection = 'test_collection' message.metadata.collection = 'test_collection'
message.triples = [] message.triples = []
@ -399,7 +396,7 @@ class TestMemgraphStorageProcessor:
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(message) await processor.store_triples('test_workspace', message)
# Verify no session calls were made (no triples to process) # Verify no session calls were made (no triples to process)
processor.io.session.assert_not_called() processor.io.session.assert_not_called()

View file

@ -68,9 +68,9 @@ class TestNeo4jStorageProcessor:
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
"CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)", "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)",
"CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)", "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)",
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", "CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", "CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", "CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
"CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)" "CREATE INDEX rel_collection FOR ()-[r:Rel]-() ON (r.collection)"
] ]
@ -116,12 +116,12 @@ class TestNeo4jStorageProcessor:
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
# Test create_node # Test create_node
processor.create_node("http://example.com/node", "test_user", "test_collection") processor.create_node("http://example.com/node", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with( mock_driver.execute_query.assert_called_with(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/node", uri="http://example.com/node",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="neo4j" database_="neo4j"
) )
@ -146,12 +146,12 @@ class TestNeo4jStorageProcessor:
processor = Processor(taskgroup=taskgroup_mock) processor = Processor(taskgroup=taskgroup_mock)
# Test create_literal # Test create_literal
processor.create_literal("literal value", "test_user", "test_collection") processor.create_literal("literal value", "test_workspace", "test_collection")
mock_driver.execute_query.assert_called_with( mock_driver.execute_query.assert_called_with(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value="literal value", value="literal value",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="neo4j" database_="neo4j"
) )
@ -180,18 +180,18 @@ class TestNeo4jStorageProcessor:
"http://example.com/subject", "http://example.com/subject",
"http://example.com/predicate", "http://example.com/predicate",
"http://example.com/object", "http://example.com/object",
"test_user", "test_workspace",
"test_collection" "test_collection"
) )
mock_driver.execute_query.assert_called_with( mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject", src="http://example.com/subject",
dest="http://example.com/object", dest="http://example.com/object",
uri="http://example.com/predicate", uri="http://example.com/predicate",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="neo4j" database_="neo4j"
) )
@ -220,18 +220,18 @@ class TestNeo4jStorageProcessor:
"http://example.com/subject", "http://example.com/subject",
"http://example.com/predicate", "http://example.com/predicate",
"literal value", "literal value",
"test_user", "test_workspace",
"test_collection" "test_collection"
) )
mock_driver.execute_query.assert_called_with( mock_driver.execute_query.assert_called_with(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject", src="http://example.com/subject",
dest="literal value", dest="literal value",
uri="http://example.com/predicate", uri="http://example.com/predicate",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="neo4j" database_="neo4j"
) )
@ -268,36 +268,35 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata # Create mock message with metadata
mock_message = MagicMock() mock_message = MagicMock()
mock_message.triples = [triple] mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection" mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests # Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message) await processor.store_triples("test_workspace", mock_message)
# Verify create_node was called for subject and object # Verify create_node was called for subject and object
# Verify relate_node was called # Verify relate_node was called
expected_calls = [ expected_calls = [
# Subject node creation # Subject node creation
( (
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
), ),
# Object node creation # Object node creation
( (
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/object", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} {"uri": "http://example.com/object", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
), ),
# Relationship creation # Relationship creation
( (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{ {
"src": "http://example.com/subject", "src": "http://example.com/subject",
"dest": "http://example.com/object", "dest": "http://example.com/object",
"uri": "http://example.com/predicate", "uri": "http://example.com/predicate",
"user": "test_user", "workspace": "test_workspace",
"collection": "test_collection", "collection": "test_collection",
"database_": "neo4j" "database_": "neo4j"
} }
@ -340,12 +339,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata # Create mock message with metadata
mock_message = MagicMock() mock_message = MagicMock()
mock_message.triples = [triple] mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection" mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests # Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message) await processor.store_triples("test_workspace", mock_message)
# Verify create_node was called for subject # Verify create_node was called for subject
# Verify create_literal was called for object # Verify create_literal was called for object
@ -353,24 +351,24 @@ class TestNeo4jStorageProcessor:
expected_calls = [ expected_calls = [
# Subject node creation # Subject node creation
( (
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
{"uri": "http://example.com/subject", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} {"uri": "http://example.com/subject", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
), ),
# Literal creation # Literal creation
( (
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
{"value": "literal value", "user": "test_user", "collection": "test_collection", "database_": "neo4j"} {"value": "literal value", "workspace": "test_workspace", "collection": "test_collection", "database_": "neo4j"}
), ),
# Relationship creation # Relationship creation
( (
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
{ {
"src": "http://example.com/subject", "src": "http://example.com/subject",
"dest": "literal value", "dest": "literal value",
"uri": "http://example.com/predicate", "uri": "http://example.com/predicate",
"user": "test_user", "workspace": "test_workspace",
"collection": "test_collection", "collection": "test_collection",
"database_": "neo4j" "database_": "neo4j"
} }
@ -421,12 +419,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with metadata # Create mock message with metadata
mock_message = MagicMock() mock_message = MagicMock()
mock_message.triples = [triple1, triple2] mock_message.triples = [triple1, triple2]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection" mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests # Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message) await processor.store_triples("test_workspace", mock_message)
# Should have processed both triples # Should have processed both triples
# Triple1: 2 nodes + 1 relationship = 3 calls # Triple1: 2 nodes + 1 relationship = 3 calls
@ -449,12 +446,11 @@ class TestNeo4jStorageProcessor:
# Create mock message with empty triples and metadata # Create mock message with empty triples and metadata
mock_message = MagicMock() mock_message = MagicMock()
mock_message.triples = [] mock_message.triples = []
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection" mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests # Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message) await processor.store_triples("test_workspace", mock_message)
# Should not have made any execute_query calls beyond index creation # Should not have made any execute_query calls beyond index creation
# Only index creation calls should have been made during initialization # Only index creation calls should have been made during initialization
@ -568,38 +564,37 @@ class TestNeo4jStorageProcessor:
mock_message = MagicMock() mock_message = MagicMock()
mock_message.triples = [triple] mock_message.triples = [triple]
mock_message.metadata.user = "test_user"
mock_message.metadata.collection = "test_collection" mock_message.metadata.collection = "test_collection"
# Mock collection_exists to bypass validation in unit tests # Mock collection_exists to bypass validation in unit tests
with patch.object(processor, 'collection_exists', return_value=True): with patch.object(processor, 'collection_exists', return_value=True):
await processor.store_triples(mock_message) await processor.store_triples("test_workspace", mock_message)
# Verify the triple was processed with special characters preserved # Verify the triple was processed with special characters preserved
mock_driver.execute_query.assert_any_call( mock_driver.execute_query.assert_any_call(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri="http://example.com/subject with spaces", uri="http://example.com/subject with spaces",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="neo4j" database_="neo4j"
) )
mock_driver.execute_query.assert_any_call( mock_driver.execute_query.assert_any_call(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value='literal with "quotes" and unicode: ñáéíóú', value='literal with "quotes" and unicode: ñáéíóú',
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="neo4j" database_="neo4j"
) )
mock_driver.execute_query.assert_any_call( mock_driver.execute_query.assert_any_call(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src="http://example.com/subject with spaces", src="http://example.com/subject with spaces",
dest='literal with "quotes" and unicode: ñáéíóú', dest='literal with "quotes" and unicode: ñáéíóú',
uri="http://example.com/predicate:with/symbols", uri="http://example.com/predicate:with/symbols",
user="test_user", workspace="test_workspace",
collection="test_collection", collection="test_collection",
database_="neo4j" database_="neo4j"
) )

View file

@ -24,11 +24,10 @@ def _make_processor(qdrant_client=None):
return proc return proc
def _make_request(vector=None, user="test-user", collection="test-col", def _make_request(vector=None, collection="test-col",
schema_name="customers", limit=10, index_name=None): schema_name="customers", limit=10, index_name=None):
return RowEmbeddingsRequest( return RowEmbeddingsRequest(
vector=vector or [0.1, 0.2, 0.3], vector=vector or [0.1, 0.2, 0.3],
user=user,
collection=collection, collection=collection,
schema_name=schema_name, schema_name=schema_name,
limit=limit, limit=limit,
@ -36,6 +35,14 @@ def _make_request(vector=None, user="test-user", collection="test-col",
) )
def _make_flow(workspace="test-workspace", pub=None):
"""Make a mock flow object that is callable and has .workspace."""
flow = MagicMock()
flow.return_value = pub if pub is not None else AsyncMock()
flow.workspace = workspace
return flow
def _make_search_point(index_name, index_value, text, score): def _make_search_point(index_name, index_value, text, score):
point = MagicMock() point = MagicMock()
point.payload = { point.payload = {
@ -85,34 +92,33 @@ class TestFindCollection:
def test_finds_matching_collection(self): def test_finds_matching_collection(self):
proc = _make_processor() proc = _make_processor()
mock_coll = MagicMock() mock_coll = MagicMock()
mock_coll.name = "rows_test_user_test_col_customers_384" mock_coll.name = "rows_test_workspace_test_col_customers_384"
mock_collections = MagicMock() mock_collections = MagicMock()
mock_collections.collections = [mock_coll] mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-user", "test-col", "customers") result = proc.find_collection("test-workspace", "test-col", "customers")
# Prefix: rows_test_user_test_col_customers_ assert result == "rows_test_workspace_test_col_customers_384"
assert result == "rows_test_user_test_col_customers_384"
def test_returns_none_when_no_match(self): def test_returns_none_when_no_match(self):
proc = _make_processor() proc = _make_processor()
mock_coll = MagicMock() mock_coll = MagicMock()
mock_coll.name = "rows_other_user_other_col_schema_768" mock_coll.name = "rows_other_workspace_other_col_schema_768"
mock_collections = MagicMock() mock_collections = MagicMock()
mock_collections.collections = [mock_coll] mock_collections.collections = [mock_coll]
proc.qdrant.get_collections.return_value = mock_collections proc.qdrant.get_collections.return_value = mock_collections
result = proc.find_collection("test-user", "test-col", "customers") result = proc.find_collection("test-workspace", "test-col", "customers")
assert result is None assert result is None
def test_returns_none_on_error(self): def test_returns_none_on_error(self):
proc = _make_processor() proc = _make_processor()
proc.qdrant.get_collections.side_effect = Exception("connection error") proc.qdrant.get_collections.side_effect = Exception("connection error")
result = proc.find_collection("user", "col", "schema") result = proc.find_collection("workspace", "col", "schema")
assert result is None assert result is None
@ -127,7 +133,7 @@ class TestQueryRowEmbeddings:
proc = _make_processor() proc = _make_processor()
request = _make_request(vector=[]) request = _make_request(vector=[])
result = await proc.query_row_embeddings(request) result = await proc.query_row_embeddings("test-workspace", request)
assert result == [] assert result == []
@pytest.mark.asyncio @pytest.mark.asyncio
@ -136,13 +142,13 @@ class TestQueryRowEmbeddings:
proc.find_collection = MagicMock(return_value=None) proc.find_collection = MagicMock(return_value=None)
request = _make_request() request = _make_request()
result = await proc.query_row_embeddings(request) result = await proc.query_row_embeddings("test-workspace", request)
assert result == [] assert result == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_successful_query_returns_matches(self): async def test_successful_query_returns_matches(self):
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384") proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
points = [ points = [
_make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95), _make_search_point("name", ["Alice Smith"], "Alice Smith", 0.95),
@ -153,7 +159,7 @@ class TestQueryRowEmbeddings:
proc.qdrant.query_points.return_value = mock_result proc.qdrant.query_points.return_value = mock_result
request = _make_request() request = _make_request()
result = await proc.query_row_embeddings(request) result = await proc.query_row_embeddings("test-workspace", request)
assert len(result) == 2 assert len(result) == 2
assert isinstance(result[0], RowIndexMatch) assert isinstance(result[0], RowIndexMatch)
@ -166,14 +172,14 @@ class TestQueryRowEmbeddings:
async def test_index_name_filter_applied(self): async def test_index_name_filter_applied(self):
"""When index_name is specified, a Qdrant filter should be used.""" """When index_name is specified, a Qdrant filter should be used."""
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384") proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
mock_result = MagicMock() mock_result = MagicMock()
mock_result.points = [] mock_result.points = []
proc.qdrant.query_points.return_value = mock_result proc.qdrant.query_points.return_value = mock_result
request = _make_request(index_name="address") request = _make_request(index_name="address")
await proc.query_row_embeddings(request) await proc.query_row_embeddings("test-workspace", request)
call_kwargs = proc.qdrant.query_points.call_args[1] call_kwargs = proc.qdrant.query_points.call_args[1]
assert call_kwargs["query_filter"] is not None assert call_kwargs["query_filter"] is not None
@ -182,14 +188,14 @@ class TestQueryRowEmbeddings:
async def test_no_index_name_no_filter(self): async def test_no_index_name_no_filter(self):
"""When index_name is empty, no filter should be applied.""" """When index_name is empty, no filter should be applied."""
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384") proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
mock_result = MagicMock() mock_result = MagicMock()
mock_result.points = [] mock_result.points = []
proc.qdrant.query_points.return_value = mock_result proc.qdrant.query_points.return_value = mock_result
request = _make_request(index_name="") request = _make_request(index_name="")
await proc.query_row_embeddings(request) await proc.query_row_embeddings("test-workspace", request)
call_kwargs = proc.qdrant.query_points.call_args[1] call_kwargs = proc.qdrant.query_points.call_args[1]
assert call_kwargs["query_filter"] is None assert call_kwargs["query_filter"] is None
@ -198,7 +204,7 @@ class TestQueryRowEmbeddings:
async def test_missing_payload_fields_default(self): async def test_missing_payload_fields_default(self):
"""Points with missing payload fields should use defaults.""" """Points with missing payload fields should use defaults."""
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384") proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
point = MagicMock() point = MagicMock()
point.payload = {} # Empty payload point.payload = {} # Empty payload
@ -209,7 +215,7 @@ class TestQueryRowEmbeddings:
proc.qdrant.query_points.return_value = mock_result proc.qdrant.query_points.return_value = mock_result
request = _make_request() request = _make_request()
result = await proc.query_row_embeddings(request) result = await proc.query_row_embeddings("test-workspace", request)
assert len(result) == 1 assert len(result) == 1
assert result[0].index_name == "" assert result[0].index_name == ""
@ -219,13 +225,13 @@ class TestQueryRowEmbeddings:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_qdrant_error_propagates(self): async def test_qdrant_error_propagates(self):
proc = _make_processor() proc = _make_processor()
proc.find_collection = MagicMock(return_value="rows_u_c_s_384") proc.find_collection = MagicMock(return_value="rows_w_c_s_384")
proc.qdrant.query_points.side_effect = Exception("qdrant down") proc.qdrant.query_points.side_effect = Exception("qdrant down")
request = _make_request() request = _make_request()
with pytest.raises(Exception, match="qdrant down"): with pytest.raises(Exception, match="qdrant down"):
await proc.query_row_embeddings(request) await proc.query_row_embeddings("test-workspace", request)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -243,7 +249,7 @@ class TestOnMessage:
]) ])
mock_pub = AsyncMock() mock_pub = AsyncMock()
flow = lambda name: mock_pub flow = _make_flow(pub=mock_pub)
msg = MagicMock() msg = MagicMock()
msg.value.return_value = _make_request() msg.value.return_value = _make_request()
@ -264,7 +270,7 @@ class TestOnMessage:
) )
mock_pub = AsyncMock() mock_pub = AsyncMock()
flow = lambda name: mock_pub flow = _make_flow(pub=mock_pub)
msg = MagicMock() msg = MagicMock()
msg.value.return_value = _make_request() msg.value.return_value = _make_request()
@ -284,7 +290,7 @@ class TestOnMessage:
proc.query_row_embeddings = AsyncMock(return_value=[]) proc.query_row_embeddings = AsyncMock(return_value=[])
mock_pub = AsyncMock() mock_pub = AsyncMock()
flow = lambda name: mock_pub flow = _make_flow(pub=mock_pub)
msg = MagicMock() msg = MagicMock()
msg.value.return_value = _make_request() msg.value.return_value = _make_request()

View file

@ -45,12 +45,9 @@ class TestGetGraphEmbeddings:
with `vector=` (singular) the schema field name. A previous with `vector=` (singular) the schema field name. A previous
version used `vectors=` and TypeError'd at runtime. version used `vectors=` and TypeError'd at runtime.
""" """
# Arrange — fake row matching the get_triples_stmt result shape:
# row[0..2] are unused by the method, row[3] is the entities blob
fake_row = ( fake_row = (
None, None, None, None, None, None,
[ [
# ((value, is_uri), vector)
(("http://example.org/alice", True), [0.1, 0.2, 0.3]), (("http://example.org/alice", True), [0.1, 0.2, 0.3]),
(("http://example.org/bob", True), [0.4, 0.5, 0.6]), (("http://example.org/bob", True), [0.4, 0.5, 0.6]),
(("a literal entity", False), [0.7, 0.8, 0.9]), (("a literal entity", False), [0.7, 0.8, 0.9]),
@ -67,14 +64,8 @@ class TestGetGraphEmbeddings:
async def receiver(msg): async def receiver(msg):
received.append(msg) received.append(msg)
# Act await store.get_graph_embeddings("alice", "doc-1", receiver)
await store.get_graph_embeddings(
user="alice",
document_id="doc-1",
receiver=receiver,
)
# Assert
mock_async_execute.assert_called_once_with( mock_async_execute.assert_called_once_with(
store.cassandra, store.cassandra,
store.get_graph_embeddings_stmt, store.get_graph_embeddings_stmt,
@ -86,7 +77,6 @@ class TestGetGraphEmbeddings:
assert isinstance(ge, GraphEmbeddings) assert isinstance(ge, GraphEmbeddings)
assert isinstance(ge.metadata, Metadata) assert isinstance(ge.metadata, Metadata)
assert ge.metadata.id == "doc-1" assert ge.metadata.id == "doc-1"
assert ge.metadata.user == "alice"
assert len(ge.entities) == 3 assert len(ge.entities) == 3
assert all(isinstance(e, EntityEmbeddings) for e in ge.entities) assert all(isinstance(e, EntityEmbeddings) for e in ge.entities)
@ -122,7 +112,7 @@ class TestGetGraphEmbeddings:
async def receiver(msg): async def receiver(msg):
received.append(msg) received.append(msg)
await store.get_graph_embeddings("u", "d", receiver) await store.get_graph_embeddings("w", "d", receiver)
assert len(received) == 1 assert len(received) == 1
assert received[0].entities == [] assert received[0].entities == []
@ -149,7 +139,7 @@ class TestGetGraphEmbeddings:
async def receiver(msg): async def receiver(msg):
received.append(msg) received.append(msg)
await store.get_graph_embeddings("u", "d", receiver) await store.get_graph_embeddings("w", "d", receiver)
assert len(received) == 2 assert len(received) == 2
assert received[0].entities[0].entity.iri == "http://example.org/a" assert received[0].entities[0].entity.iri == "http://example.org/a"
@ -194,7 +184,6 @@ class TestGetTriples:
assert isinstance(triples_msg, Triples) assert isinstance(triples_msg, Triples)
assert isinstance(triples_msg.metadata, Metadata) assert isinstance(triples_msg.metadata, Metadata)
assert triples_msg.metadata.id == "doc-1" assert triples_msg.metadata.id == "doc-1"
assert triples_msg.metadata.user == "alice"
assert len(triples_msg.triples) == 1 assert len(triples_msg.triples) == 1
t = triples_msg.triples[0] t = triples_msg.triples[0]

View file

@ -30,7 +30,6 @@ def sample():
metadata=Metadata( metadata=Metadata(
id="doc-1", id="doc-1",
root="", root="",
user="alice",
collection="testcoll", collection="testcoll",
), ),
chunks=[ chunks=[
@ -56,7 +55,6 @@ class TestDocumentEmbeddingsTranslator:
assert isinstance(decoded, DocumentEmbeddings) assert isinstance(decoded, DocumentEmbeddings)
assert isinstance(decoded.metadata, Metadata) assert isinstance(decoded.metadata, Metadata)
assert decoded.metadata.id == "doc-1" assert decoded.metadata.id == "doc-1"
assert decoded.metadata.user == "alice"
assert decoded.metadata.collection == "testcoll" assert decoded.metadata.collection == "testcoll"
assert len(decoded.chunks) == 2 assert len(decoded.chunks) == 2

View file

@ -41,7 +41,6 @@ def translator():
def graph_embeddings_request(): def graph_embeddings_request():
return KnowledgeRequest( return KnowledgeRequest(
operation="put-kg-core", operation="put-kg-core",
user="alice",
id="doc-1", id="doc-1",
flow="default", flow="default",
collection="testcoll", collection="testcoll",
@ -49,7 +48,6 @@ def graph_embeddings_request():
metadata=Metadata( metadata=Metadata(
id="doc-1", id="doc-1",
root="", root="",
user="alice",
collection="testcoll", collection="testcoll",
), ),
entities=[ entities=[
@ -70,7 +68,6 @@ def graph_embeddings_request():
def triples_request(): def triples_request():
return KnowledgeRequest( return KnowledgeRequest(
operation="put-kg-core", operation="put-kg-core",
user="alice",
id="doc-1", id="doc-1",
flow="default", flow="default",
collection="testcoll", collection="testcoll",
@ -78,7 +75,6 @@ def triples_request():
metadata=Metadata( metadata=Metadata(
id="doc-1", id="doc-1",
root="", root="",
user="alice",
collection="testcoll", collection="testcoll",
), ),
triples=[ triples=[
@ -123,7 +119,6 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings:
assert isinstance(ge, GraphEmbeddings) assert isinstance(ge, GraphEmbeddings)
assert isinstance(ge.metadata, Metadata) assert isinstance(ge.metadata, Metadata)
assert ge.metadata.id == "doc-1" assert ge.metadata.id == "doc-1"
assert ge.metadata.user == "alice"
assert ge.metadata.collection == "testcoll" assert ge.metadata.collection == "testcoll"
assert len(ge.entities) == 2 assert len(ge.entities) == 2
@ -143,7 +138,6 @@ class TestKnowledgeRequestTranslatorTriples:
assert decoded.triples is not None assert decoded.triples is not None
assert isinstance(decoded.triples.metadata, Metadata) assert isinstance(decoded.triples.metadata, Metadata)
assert decoded.triples.metadata.id == "doc-1" assert decoded.triples.metadata.id == "doc-1"
assert decoded.triples.metadata.user == "alice"
assert decoded.triples.metadata.collection == "testcoll" assert decoded.triples.metadata.collection == "testcoll"
assert len(decoded.triples.triples) == 1 assert len(decoded.triples.triples) == 1

View file

@ -12,7 +12,6 @@ from trustgraph.api import Api
from trustgraph.api.types import hash, Uri, Literal, Triple from trustgraph.api.types import hash, Uri, Literal, Triple
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = 'trustgraph'
default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")

View file

@ -40,7 +40,6 @@ def load_structured_data(
sample_chars: int = 500, sample_chars: int = 500,
schema_name: str = None, schema_name: str = None,
flow: str = 'default', flow: str = 'default',
user: str = 'trustgraph',
collection: str = 'default', collection: str = 'default',
dry_run: bool = False, dry_run: bool = False,
verbose: bool = False, verbose: bool = False,
@ -64,7 +63,6 @@ def load_structured_data(
sample_chars: Maximum characters to read for sampling sample_chars: Maximum characters to read for sampling
schema_name: Target schema name for generation schema_name: Target schema name for generation
flow: TrustGraph flow name to use for prompts flow: TrustGraph flow name to use for prompts
user: User name for metadata (default: trustgraph)
collection: Collection name for metadata (default: default) collection: Collection name for metadata (default: default)
dry_run: If True, validate but don't import data dry_run: If True, validate but don't import data
verbose: Enable verbose logging verbose: Enable verbose logging
@ -112,7 +110,7 @@ def load_structured_data(
try: try:
# Use shared pipeline for preview (small sample) # Use shared pipeline for preview (small sample)
preview_objects, _ = _process_data_pipeline(input_file, temp_descriptor.name, user, collection, sample_size=5) preview_objects, _ = _process_data_pipeline(input_file, temp_descriptor.name, collection, sample_size=5)
# Show preview # Show preview
print("📊 Data Preview (first few records):") print("📊 Data Preview (first few records):")
@ -133,7 +131,7 @@ def load_structured_data(
print("🚀 Importing data to TrustGraph...") print("🚀 Importing data to TrustGraph...")
# Use shared pipeline for full processing (no sample limit) # Use shared pipeline for full processing (no sample limit)
output_objects, descriptor = _process_data_pipeline(input_file, temp_descriptor.name, user, collection) output_objects, descriptor = _process_data_pipeline(input_file, temp_descriptor.name, collection)
# Get batch size from descriptor # Get batch size from descriptor
batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000) batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000)
@ -244,7 +242,7 @@ def load_structured_data(
logger.info(f"Parsing {input_file} with descriptor {descriptor_file}...") logger.info(f"Parsing {input_file} with descriptor {descriptor_file}...")
# Use shared pipeline # Use shared pipeline
output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, user, collection, sample_size) output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, collection, sample_size)
# Output results # Output results
if output_file: if output_file:
@ -288,7 +286,7 @@ def load_structured_data(
logger.info(f"Loading {input_file} to TrustGraph using descriptor {descriptor_file}...") logger.info(f"Loading {input_file} to TrustGraph using descriptor {descriptor_file}...")
# Use shared pipeline (no sample_size limit for full load) # Use shared pipeline (no sample_size limit for full load)
output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, user, collection) output_records, descriptor = _process_data_pipeline(input_file, descriptor_file, collection)
# Get batch size from descriptor or use default # Get batch size from descriptor or use default
batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000) batch_size = descriptor.get('output', {}).get('options', {}).get('batch_size', 1000)
@ -529,7 +527,7 @@ def _apply_transformations(records, mappings):
return processed_records return processed_records
def _format_extracted_objects(processed_records, descriptor, user, collection): def _format_extracted_objects(processed_records, descriptor, collection):
"""Convert to TrustGraph ExtractedObject format""" """Convert to TrustGraph ExtractedObject format"""
output_records = [] output_records = []
schema_name = descriptor.get('output', {}).get('schema_name', 'default') schema_name = descriptor.get('output', {}).get('schema_name', 'default')
@ -540,7 +538,6 @@ def _format_extracted_objects(processed_records, descriptor, user, collection):
"metadata": { "metadata": {
"id": f"parsed-{len(output_records)+1}", "id": f"parsed-{len(output_records)+1}",
"metadata": [], # Empty metadata triples "metadata": [], # Empty metadata triples
"user": user,
"collection": collection "collection": collection
}, },
"schema_name": schema_name, "schema_name": schema_name,
@ -553,7 +550,7 @@ def _format_extracted_objects(processed_records, descriptor, user, collection):
return output_records return output_records
def _process_data_pipeline(input_file, descriptor_file, user, collection, sample_size=None): def _process_data_pipeline(input_file, descriptor_file, collection, sample_size=None):
"""Shared pipeline: load descriptor → read → parse → transform → format""" """Shared pipeline: load descriptor → read → parse → transform → format"""
# Load descriptor configuration # Load descriptor configuration
descriptor = _load_descriptor(descriptor_file) descriptor = _load_descriptor(descriptor_file)
@ -570,7 +567,7 @@ def _process_data_pipeline(input_file, descriptor_file, user, collection, sample
processed_records = _apply_transformations(parsed_records, mappings) processed_records = _apply_transformations(parsed_records, mappings)
# Format output for TrustGraph ExtractedObject structure # Format output for TrustGraph ExtractedObject structure
output_records = _format_extracted_objects(processed_records, descriptor, user, collection) output_records = _format_extracted_objects(processed_records, descriptor, collection)
return output_records, descriptor return output_records, descriptor
@ -1048,7 +1045,6 @@ For more information on the descriptor format, see:
sample_chars=args.sample_chars, sample_chars=args.sample_chars,
schema_name=args.schema_name, schema_name=args.schema_name,
flow=args.flow, flow=args.flow,
user=args.user,
collection=args.collection, collection=args.collection,
dry_run=args.dry_run, dry_run=args.dry_run,
verbose=args.verbose, verbose=args.verbose,

View file

@ -6,9 +6,9 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, prefix): def make_safe_collection_name(workspace, collection, prefix):
""" """
Create a safe Milvus collection name from user/collection parameters. Create a safe Milvus collection name from workspace/collection parameters.
Milvus only allows letters, numbers, and underscores. Milvus only allows letters, numbers, and underscores.
""" """
def sanitize(s): def sanitize(s):
@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix):
safe = 'default' safe = 'default'
return safe return safe
safe_user = sanitize(user) safe_workspace = sanitize(workspace)
safe_collection = sanitize(collection) safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}" return f"{prefix}_{safe_workspace}_{safe_collection}"
class DocVectors: class DocVectors:
@ -49,26 +49,26 @@ class DocVectors:
self.next_reload = time.time() + self.reload_time self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}") logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection): def collection_exists(self, workspace, collection):
""" """
Check if any collection exists for this user/collection combination. Check if any collection exists for this workspace/collection combination.
Since collections are dimension-specific, this checks if ANY dimension variant exists. Since collections are dimension-specific, this checks if ANY dimension variant exists.
""" """
base_name = make_safe_collection_name(user, collection, self.prefix) base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_" prefix = f"{base_name}_"
all_collections = self.client.list_collections() all_collections = self.client.list_collections()
return any(coll.startswith(prefix) for coll in all_collections) return any(coll.startswith(prefix) for coll in all_collections)
def create_collection(self, user, collection, dimension=384): def create_collection(self, workspace, collection, dimension=384):
""" """
No-op for explicit collection creation. No-op for explicit collection creation.
Collections are created lazily on first insert with actual dimension. Collections are created lazily on first insert with actual dimension.
""" """
logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert") logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert")
def init_collection(self, dimension, user, collection): def init_collection(self, dimension, workspace, collection):
base_name = make_safe_collection_name(user, collection, self.prefix) base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dimension}" collection_name = f"{base_name}_{dimension}"
pkey_field = FieldSchema( pkey_field = FieldSchema(
@ -116,15 +116,15 @@ class DocVectors:
index_params=index_params index_params=index_params
) )
self.collections[(dimension, user, collection)] = collection_name self.collections[(dimension, workspace, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, chunk_id, user, collection): def insert(self, embeds, chunk_id, workspace, collection):
dim = len(embeds) dim = len(embeds)
if (dim, user, collection) not in self.collections: if (dim, workspace, collection) not in self.collections:
self.init_collection(dim, user, collection) self.init_collection(dim, workspace, collection)
data = [ data = [
{ {
@ -134,25 +134,25 @@ class DocVectors:
] ]
self.client.insert( self.client.insert(
collection_name=self.collections[(dim, user, collection)], collection_name=self.collections[(dim, workspace, collection)],
data=data data=data
) )
def search(self, embeds, user, collection, fields=["chunk_id"], limit=10): def search(self, embeds, workspace, collection, fields=["chunk_id"], limit=10):
dim = len(embeds) dim = len(embeds)
# Check if collection exists - return empty if not # Check if collection exists - return empty if not
if (dim, user, collection) not in self.collections: if (dim, workspace, collection) not in self.collections:
base_name = make_safe_collection_name(user, collection, self.prefix) base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dim}" collection_name = f"{base_name}_{dim}"
if not self.client.has_collection(collection_name): if not self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} does not exist, returning empty results") logger.info(f"Collection {collection_name} does not exist, returning empty results")
return [] return []
# Collection exists but not in cache, add it # Collection exists but not in cache, add it
self.collections[(dim, user, collection)] = collection_name self.collections[(dim, workspace, collection)] = collection_name
coll = self.collections[(dim, user, collection)] coll = self.collections[(dim, workspace, collection)]
logger.debug("Loading...") logger.debug("Loading...")
self.client.load_collection( self.client.load_collection(
@ -181,12 +181,12 @@ class DocVectors:
return res return res
def delete_collection(self, user, collection): def delete_collection(self, workspace, collection):
""" """
Delete all dimension variants of the collection for the given user/collection. Delete all dimension variants of the collection for the given workspace/collection.
Since collections are created with dimension suffixes, we need to find and delete all. Since collections are created with dimension suffixes, we need to find and delete all.
""" """
base_name = make_safe_collection_name(user, collection, self.prefix) base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_" prefix = f"{base_name}_"
# Get all collections and filter for matches # Get all collections and filter for matches
@ -199,10 +199,10 @@ class DocVectors:
for collection_name in matching_collections: for collection_name in matching_collections:
self.client.drop_collection(collection_name) self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}") logger.info(f"Deleted Milvus collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
# Remove from our local cache # Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection]
for key in keys_to_remove: for key in keys_to_remove:
del self.collections[key] del self.collections[key]

View file

@ -6,9 +6,9 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def make_safe_collection_name(user, collection, prefix): def make_safe_collection_name(workspace, collection, prefix):
""" """
Create a safe Milvus collection name from user/collection parameters. Create a safe Milvus collection name from workspace/collection parameters.
Milvus only allows letters, numbers, and underscores. Milvus only allows letters, numbers, and underscores.
""" """
def sanitize(s): def sanitize(s):
@ -23,10 +23,10 @@ def make_safe_collection_name(user, collection, prefix):
safe = 'default' safe = 'default'
return safe return safe
safe_user = sanitize(user) safe_workspace = sanitize(workspace)
safe_collection = sanitize(collection) safe_collection = sanitize(collection)
return f"{prefix}_{safe_user}_{safe_collection}" return f"{prefix}_{safe_workspace}_{safe_collection}"
class EntityVectors: class EntityVectors:
@ -49,26 +49,26 @@ class EntityVectors:
self.next_reload = time.time() + self.reload_time self.next_reload = time.time() + self.reload_time
logger.debug(f"Reload at {self.next_reload}") logger.debug(f"Reload at {self.next_reload}")
def collection_exists(self, user, collection): def collection_exists(self, workspace, collection):
""" """
Check if any collection exists for this user/collection combination. Check if any collection exists for this workspace/collection combination.
Since collections are dimension-specific, this checks if ANY dimension variant exists. Since collections are dimension-specific, this checks if ANY dimension variant exists.
""" """
base_name = make_safe_collection_name(user, collection, self.prefix) base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_" prefix = f"{base_name}_"
all_collections = self.client.list_collections() all_collections = self.client.list_collections()
return any(coll.startswith(prefix) for coll in all_collections) return any(coll.startswith(prefix) for coll in all_collections)
def create_collection(self, user, collection, dimension=384): def create_collection(self, workspace, collection, dimension=384):
""" """
No-op for explicit collection creation. No-op for explicit collection creation.
Collections are created lazily on first insert with actual dimension. Collections are created lazily on first insert with actual dimension.
""" """
logger.info(f"Collection creation requested for {user}/{collection} - will be created lazily on first insert") logger.info(f"Collection creation requested for {workspace}/{collection} - will be created lazily on first insert")
def init_collection(self, dimension, user, collection): def init_collection(self, dimension, workspace, collection):
base_name = make_safe_collection_name(user, collection, self.prefix) base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dimension}" collection_name = f"{base_name}_{dimension}"
pkey_field = FieldSchema( pkey_field = FieldSchema(
@ -122,15 +122,15 @@ class EntityVectors:
index_params=index_params index_params=index_params
) )
self.collections[(dimension, user, collection)] = collection_name self.collections[(dimension, workspace, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}") logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, entity, user, collection, chunk_id=""): def insert(self, embeds, entity, workspace, collection, chunk_id=""):
dim = len(embeds) dim = len(embeds)
if (dim, user, collection) not in self.collections: if (dim, workspace, collection) not in self.collections:
self.init_collection(dim, user, collection) self.init_collection(dim, workspace, collection)
data = [ data = [
{ {
@ -141,25 +141,25 @@ class EntityVectors:
] ]
self.client.insert( self.client.insert(
collection_name=self.collections[(dim, user, collection)], collection_name=self.collections[(dim, workspace, collection)],
data=data data=data
) )
def search(self, embeds, user, collection, fields=["entity"], limit=10): def search(self, embeds, workspace, collection, fields=["entity"], limit=10):
dim = len(embeds) dim = len(embeds)
# Check if collection exists - return empty if not # Check if collection exists - return empty if not
if (dim, user, collection) not in self.collections: if (dim, workspace, collection) not in self.collections:
base_name = make_safe_collection_name(user, collection, self.prefix) base_name = make_safe_collection_name(workspace, collection, self.prefix)
collection_name = f"{base_name}_{dim}" collection_name = f"{base_name}_{dim}"
if not self.client.has_collection(collection_name): if not self.client.has_collection(collection_name):
logger.info(f"Collection {collection_name} does not exist, returning empty results") logger.info(f"Collection {collection_name} does not exist, returning empty results")
return [] return []
# Collection exists but not in cache, add it # Collection exists but not in cache, add it
self.collections[(dim, user, collection)] = collection_name self.collections[(dim, workspace, collection)] = collection_name
coll = self.collections[(dim, user, collection)] coll = self.collections[(dim, workspace, collection)]
logger.debug("Loading...") logger.debug("Loading...")
self.client.load_collection( self.client.load_collection(
@ -188,12 +188,12 @@ class EntityVectors:
return res return res
def delete_collection(self, user, collection): def delete_collection(self, workspace, collection):
""" """
Delete all dimension variants of the collection for the given user/collection. Delete all dimension variants of the collection for the given workspace/collection.
Since collections are created with dimension suffixes, we need to find and delete all. Since collections are created with dimension suffixes, we need to find and delete all.
""" """
base_name = make_safe_collection_name(user, collection, self.prefix) base_name = make_safe_collection_name(workspace, collection, self.prefix)
prefix = f"{base_name}_" prefix = f"{base_name}_"
# Get all collections and filter for matches # Get all collections and filter for matches
@ -206,10 +206,10 @@ class EntityVectors:
for collection_name in matching_collections: for collection_name in matching_collections:
self.client.drop_collection(collection_name) self.client.drop_collection(collection_name)
logger.info(f"Deleted Milvus collection: {collection_name}") logger.info(f"Deleted Milvus collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
# Remove from our local cache # Remove from our local cache
keys_to_remove = [key for key in self.collections.keys() if key[1] == user and key[2] == collection] keys_to_remove = [key for key in self.collections.keys() if key[1] == workspace and key[2] == collection]
for key in keys_to_remove: for key in keys_to_remove:
del self.collections[key] del self.collections[key]

View file

@ -70,7 +70,7 @@ class GraphQLSchemaBuilder:
Build the GraphQL schema with the provided query callback. Build the GraphQL schema with the provided query callback.
The query callback will be invoked when resolving queries, with: The query callback will be invoked when resolving queries, with:
- user: str - workspace: str
- collection: str - collection: str
- schema_name: str - schema_name: str
- row_schema: RowSchema - row_schema: RowSchema
@ -228,7 +228,7 @@ class GraphQLSchemaBuilder:
limit: Optional[int] = 100 limit: Optional[int] = 100
) -> List[graphql_type]: ) -> List[graphql_type]:
# Get context values # Get context values
user = info.context["user"] workspace = info.context["workspace"]
collection = info.context["collection"] collection = info.context["collection"]
# Parse the where clause # Parse the where clause
@ -236,7 +236,7 @@ class GraphQLSchemaBuilder:
# Call the query backend # Call the query backend
results = await query_callback( results = await query_callback(
user, collection, schema_name, row_schema, workspace, collection, schema_name, row_schema,
filters, limit, order_by, direction filters, limit, order_by, direction
) )

View file

@ -167,7 +167,7 @@ class QueryExplainer:
question_components, query_results, processing_metadata question_components, query_results, processing_metadata
) )
# Generate user-friendly explanation # Generate user-friendly explanation
user_friendly_explanation = self._generate_user_friendly_explanation( user_friendly_explanation = self._generate_user_friendly_explanation(
question, question_components, ontology_subsets, final_answer question, question_components, ontology_subsets, final_answer
) )
@ -503,7 +503,7 @@ class QueryExplainer:
question_components: QuestionComponents, question_components: QuestionComponents,
ontology_subsets: List[QueryOntologySubset], ontology_subsets: List[QueryOntologySubset],
final_answer: str) -> str: final_answer: str) -> str:
"""Generate user-friendly explanation of the process.""" """Generate user-friendly explanation of the process."""
explanation_parts = [] explanation_parts = []
# Introduction # Introduction

View file

@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class QueryRequest: class QueryRequest:
"""Query request from user.""" """Query request from user."""
question: str question: str
context: Optional[str] = None context: Optional[str] = None
ontology_hint: Optional[str] = None ontology_hint: Optional[str] = None

View file

@ -1,6 +1,6 @@
""" """
Question analyzer for ontology-sensitive query system. Question analyzer for ontology-sensitive query system.
Decomposes user questions into semantic components. Decomposes user questions into semantic components.
""" """
import logging import logging

View file

@ -1,7 +1,7 @@
""" """
Row embeddings query service for Qdrant. Row embeddings query service for Qdrant.
Input is query vectors plus user/collection/schema context. Input is query vectors plus workspace/collection/schema context.
Output is matching row index information (index_name, index_value) for Output is matching row index information (index_name, index_value) for
use in subsequent Cassandra lookups. use in subsequent Cassandra lookups.
""" """
@ -70,10 +70,10 @@ class Processor(FlowProcessor):
safe_name = 'r_' + safe_name safe_name = 'r_' + safe_name
return safe_name.lower() return safe_name.lower()
def find_collection(self, user: str, collection: str, schema_name: str) -> Optional[str]: def find_collection(self, workspace: str, collection: str, schema_name: str) -> Optional[str]:
"""Find the Qdrant collection for a given user/collection/schema""" """Find the Qdrant collection for a given workspace/collection/schema"""
prefix = ( prefix = (
f"rows_{self.sanitize_name(user)}_" f"rows_{self.sanitize_name(workspace)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
) )
@ -163,7 +163,7 @@ class Processor(FlowProcessor):
logger.debug( logger.debug(
f"Handling row embeddings query for " f"Handling row embeddings query for "
f"{request.user}/{request.collection}/{request.schema_name}..." f"{flow.workspace}/{request.collection}/{request.schema_name}..."
) )
# Execute query # Execute query

View file

@ -238,7 +238,7 @@ class Processor(FlowProcessor):
async def query_cassandra( async def query_cassandra(
self, self,
user: str, workspace: str,
collection: str, collection: str,
schema_name: str, schema_name: str,
row_schema: RowSchema, row_schema: RowSchema,
@ -256,7 +256,7 @@ class Processor(FlowProcessor):
# Connect if needed # Connect if needed
self.connect_cassandra() self.connect_cassandra()
safe_keyspace = self.sanitize_name(user) safe_keyspace = self.sanitize_name(workspace)
# Try to find an index that matches the filters # Try to find an index that matches the filters
index_match = self.find_matching_index(row_schema, filters) index_match = self.find_matching_index(row_schema, filters)
@ -409,7 +409,6 @@ class Processor(FlowProcessor):
query: str, query: str,
variables: Dict[str, Any], variables: Dict[str, Any],
operation_name: Optional[str], operation_name: Optional[str],
user: str,
collection: str collection: str
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Execute a GraphQL query against the workspace's schema""" """Execute a GraphQL query against the workspace's schema"""
@ -424,7 +423,7 @@ class Processor(FlowProcessor):
# Create context for the query # Create context for the query
context = { context = {
"processor": self, "processor": self,
"user": user, "workspace": workspace,
"collection": collection "collection": collection
} }
@ -479,7 +478,6 @@ class Processor(FlowProcessor):
query=request.query, query=request.query,
variables=dict(request.variables) if request.variables else {}, variables=dict(request.variables) if request.variables else {},
operation_name=request.operation_name, operation_name=request.operation_name,
user=request.user,
collection=request.collection collection=request.collection
) )

View file

@ -30,14 +30,14 @@ class EvaluationError(Exception):
pass pass
async def evaluate(node, triples_client, user, collection, limit=10000): async def evaluate(node, triples_client, workspace, collection, limit=10000):
""" """
Evaluate a SPARQL algebra node. Evaluate a SPARQL algebra node.
Args: Args:
node: rdflib CompValue algebra node node: rdflib CompValue algebra node
triples_client: TriplesClient instance for triple pattern queries triples_client: TriplesClient instance for triple pattern queries
user: user/keyspace identifier workspace: workspace/keyspace identifier
collection: collection identifier collection: collection identifier
limit: safety limit on results limit: safety limit on results
@ -55,24 +55,24 @@ async def evaluate(node, triples_client, user, collection, limit=10000):
logger.warning(f"Unsupported algebra node: {name}") logger.warning(f"Unsupported algebra node: {name}")
return [{}] return [{}]
return await handler(node, triples_client, user, collection, limit) return await handler(node, triples_client, workspace, collection, limit)
# --- Node handlers --- # --- Node handlers ---
async def _eval_select_query(node, tc, user, collection, limit): async def _eval_select_query(node, tc, workspace, collection, limit):
"""Evaluate a SelectQuery node.""" """Evaluate a SelectQuery node."""
return await evaluate(node.p, tc, user, collection, limit) return await evaluate(node.p, tc, workspace, collection, limit)
async def _eval_project(node, tc, user, collection, limit): async def _eval_project(node, tc, workspace, collection, limit):
"""Evaluate a Project node (SELECT variable projection).""" """Evaluate a Project node (SELECT variable projection)."""
solutions = await evaluate(node.p, tc, user, collection, limit) solutions = await evaluate(node.p, tc, workspace, collection, limit)
variables = [str(v) for v in node.PV] variables = [str(v) for v in node.PV]
return project(solutions, variables) return project(solutions, variables)
async def _eval_bgp(node, tc, user, collection, limit): async def _eval_bgp(node, tc, workspace, collection, limit):
""" """
Evaluate a Basic Graph Pattern. Evaluate a Basic Graph Pattern.
@ -107,7 +107,7 @@ async def _eval_bgp(node, tc, user, collection, limit):
# Query the triples store # Query the triples store
results = await _query_pattern( results = await _query_pattern(
tc, s_val, p_val, o_val, user, collection, limit tc, s_val, p_val, o_val, workspace, collection, limit
) )
# Map results back to variable bindings, # Map results back to variable bindings,
@ -130,17 +130,17 @@ async def _eval_bgp(node, tc, user, collection, limit):
return solutions[:limit] return solutions[:limit]
async def _eval_join(node, tc, user, collection, limit): async def _eval_join(node, tc, workspace, collection, limit):
"""Evaluate a Join node.""" """Evaluate a Join node."""
left = await evaluate(node.p1, tc, user, collection, limit) left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, user, collection, limit) right = await evaluate(node.p2, tc, workspace, collection, limit)
return hash_join(left, right)[:limit] return hash_join(left, right)[:limit]
async def _eval_left_join(node, tc, user, collection, limit): async def _eval_left_join(node, tc, workspace, collection, limit):
"""Evaluate a LeftJoin node (OPTIONAL).""" """Evaluate a LeftJoin node (OPTIONAL)."""
left_sols = await evaluate(node.p1, tc, user, collection, limit) left_sols = await evaluate(node.p1, tc, workspace, collection, limit)
right_sols = await evaluate(node.p2, tc, user, collection, limit) right_sols = await evaluate(node.p2, tc, workspace, collection, limit)
filter_fn = None filter_fn = None
if hasattr(node, "expr") and node.expr is not None: if hasattr(node, "expr") and node.expr is not None:
@ -153,16 +153,16 @@ async def _eval_left_join(node, tc, user, collection, limit):
return left_join(left_sols, right_sols, filter_fn)[:limit] return left_join(left_sols, right_sols, filter_fn)[:limit]
async def _eval_union(node, tc, user, collection, limit): async def _eval_union(node, tc, workspace, collection, limit):
"""Evaluate a Union node.""" """Evaluate a Union node."""
left = await evaluate(node.p1, tc, user, collection, limit) left = await evaluate(node.p1, tc, workspace, collection, limit)
right = await evaluate(node.p2, tc, user, collection, limit) right = await evaluate(node.p2, tc, workspace, collection, limit)
return union(left, right)[:limit] return union(left, right)[:limit]
async def _eval_filter(node, tc, user, collection, limit): async def _eval_filter(node, tc, workspace, collection, limit):
"""Evaluate a Filter node.""" """Evaluate a Filter node."""
solutions = await evaluate(node.p, tc, user, collection, limit) solutions = await evaluate(node.p, tc, workspace, collection, limit)
expr = node.expr expr = node.expr
return [ return [
sol for sol in solutions sol for sol in solutions
@ -170,22 +170,22 @@ async def _eval_filter(node, tc, user, collection, limit):
] ]
async def _eval_distinct(node, tc, user, collection, limit): async def _eval_distinct(node, tc, workspace, collection, limit):
"""Evaluate a Distinct node.""" """Evaluate a Distinct node."""
solutions = await evaluate(node.p, tc, user, collection, limit) solutions = await evaluate(node.p, tc, workspace, collection, limit)
return distinct(solutions) return distinct(solutions)
async def _eval_reduced(node, tc, user, collection, limit): async def _eval_reduced(node, tc, workspace, collection, limit):
"""Evaluate a Reduced node (like Distinct but implementation-defined).""" """Evaluate a Reduced node (like Distinct but implementation-defined)."""
# Treat same as Distinct # Treat same as Distinct
solutions = await evaluate(node.p, tc, user, collection, limit) solutions = await evaluate(node.p, tc, workspace, collection, limit)
return distinct(solutions) return distinct(solutions)
async def _eval_order_by(node, tc, user, collection, limit): async def _eval_order_by(node, tc, workspace, collection, limit):
"""Evaluate an OrderBy node.""" """Evaluate an OrderBy node."""
solutions = await evaluate(node.p, tc, user, collection, limit) solutions = await evaluate(node.p, tc, workspace, collection, limit)
key_fns = [] key_fns = []
for cond in node.expr: for cond in node.expr:
@ -206,7 +206,7 @@ async def _eval_order_by(node, tc, user, collection, limit):
return order_by(solutions, key_fns) return order_by(solutions, key_fns)
async def _eval_slice(node, tc, user, collection, limit): async def _eval_slice(node, tc, workspace, collection, limit):
"""Evaluate a Slice node (LIMIT/OFFSET).""" """Evaluate a Slice node (LIMIT/OFFSET)."""
# Pass tighter limit downstream if possible # Pass tighter limit downstream if possible
inner_limit = limit inner_limit = limit
@ -214,13 +214,13 @@ async def _eval_slice(node, tc, user, collection, limit):
offset = node.start or 0 offset = node.start or 0
inner_limit = min(limit, offset + node.length) inner_limit = min(limit, offset + node.length)
solutions = await evaluate(node.p, tc, user, collection, inner_limit) solutions = await evaluate(node.p, tc, workspace, collection, inner_limit)
return slice_solutions(solutions, node.start or 0, node.length) return slice_solutions(solutions, node.start or 0, node.length)
async def _eval_extend(node, tc, user, collection, limit): async def _eval_extend(node, tc, workspace, collection, limit):
"""Evaluate an Extend node (BIND).""" """Evaluate an Extend node (BIND)."""
solutions = await evaluate(node.p, tc, user, collection, limit) solutions = await evaluate(node.p, tc, workspace, collection, limit)
var_name = str(node.var) var_name = str(node.var)
expr = node.expr expr = node.expr
@ -246,9 +246,9 @@ async def _eval_extend(node, tc, user, collection, limit):
return result return result
async def _eval_group(node, tc, user, collection, limit): async def _eval_group(node, tc, workspace, collection, limit):
"""Evaluate a Group node (GROUP BY with aggregation).""" """Evaluate a Group node (GROUP BY with aggregation)."""
solutions = await evaluate(node.p, tc, user, collection, limit) solutions = await evaluate(node.p, tc, workspace, collection, limit)
# Extract grouping expressions # Extract grouping expressions
group_exprs = [] group_exprs = []
@ -289,9 +289,9 @@ async def _eval_group(node, tc, user, collection, limit):
return result return result
async def _eval_aggregate_join(node, tc, user, collection, limit): async def _eval_aggregate_join(node, tc, workspace, collection, limit):
"""Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" """Evaluate an AggregateJoin (aggregation functions after GROUP BY)."""
solutions = await evaluate(node.p, tc, user, collection, limit) solutions = await evaluate(node.p, tc, workspace, collection, limit)
result = [] result = []
for sol in solutions: for sol in solutions:
@ -310,7 +310,7 @@ async def _eval_aggregate_join(node, tc, user, collection, limit):
return result return result
async def _eval_graph(node, tc, user, collection, limit): async def _eval_graph(node, tc, workspace, collection, limit):
"""Evaluate a Graph node (GRAPH clause).""" """Evaluate a Graph node (GRAPH clause)."""
term = node.term term = node.term
@ -319,16 +319,16 @@ async def _eval_graph(node, tc, user, collection, limit):
# We'd need to pass graph to triples queries # We'd need to pass graph to triples queries
# For now, evaluate inner pattern normally # For now, evaluate inner pattern normally
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
return await evaluate(node.p, tc, user, collection, limit) return await evaluate(node.p, tc, workspace, collection, limit)
elif isinstance(term, Variable): elif isinstance(term, Variable):
# GRAPH ?g { ... } — variable graph # GRAPH ?g { ... } — variable graph
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
return await evaluate(node.p, tc, user, collection, limit) return await evaluate(node.p, tc, workspace, collection, limit)
else: else:
return await evaluate(node.p, tc, user, collection, limit) return await evaluate(node.p, tc, workspace, collection, limit)
async def _eval_values(node, tc, user, collection, limit): async def _eval_values(node, tc, workspace, collection, limit):
"""Evaluate a VALUES clause (inline data).""" """Evaluate a VALUES clause (inline data)."""
variables = [str(v) for v in node.var] variables = [str(v) for v in node.var]
solutions = [] solutions = []
@ -343,9 +343,9 @@ async def _eval_values(node, tc, user, collection, limit):
return solutions return solutions
async def _eval_to_multiset(node, tc, user, collection, limit): async def _eval_to_multiset(node, tc, workspace, collection, limit):
"""Evaluate a ToMultiSet node (subquery).""" """Evaluate a ToMultiSet node (subquery)."""
return await evaluate(node.p, tc, user, collection, limit) return await evaluate(node.p, tc, workspace, collection, limit)
# --- Aggregate computation --- # --- Aggregate computation ---
@ -487,7 +487,7 @@ def _resolve_term(tmpl, solution):
return rdflib_term_to_term(tmpl) return rdflib_term_to_term(tmpl)
async def _query_pattern(tc, s, p, o, user, collection, limit): async def _query_pattern(tc, s, p, o, workspace, collection, limit):
""" """
Issue a streaming triple pattern query via TriplesClient. Issue a streaming triple pattern query via TriplesClient.
@ -496,7 +496,7 @@ async def _query_pattern(tc, s, p, o, user, collection, limit):
results = await tc.query( results = await tc.query(
s=s, p=p, o=o, s=s, p=p, o=o,
limit=limit, limit=limit,
user=user, workspace=workspace,
collection=collection, collection=collection,
) )
return results return results

View file

@ -141,7 +141,7 @@ class Processor(FlowProcessor):
solutions = await evaluate( solutions = await evaluate(
parsed.algebra, parsed.algebra,
triples_client, triples_client,
user=flow.workspace, workspace=flow.workspace,
collection=request.collection or "default", collection=request.collection or "default",
limit=request.limit or 10000, limit=request.limit or 10000,
) )

View file

@ -178,24 +178,24 @@ class Processor(TriplesQueryService):
self.cassandra_password = password self.cassandra_password = password
self.table = None self.table = None
def ensure_connection(self, user): def ensure_connection(self, workspace):
"""Ensure we have a connection to the correct keyspace.""" """Ensure we have a connection to the correct keyspace."""
if user != self.table: if workspace != self.table:
KGClass = EntityCentricKnowledgeGraph KGClass = EntityCentricKnowledgeGraph
if self.cassandra_username and self.cassandra_password: if self.cassandra_username and self.cassandra_password:
self.tg = KGClass( self.tg = KGClass(
hosts=self.cassandra_host, hosts=self.cassandra_host,
keyspace=user, keyspace=workspace,
username=self.cassandra_username, username=self.cassandra_username,
password=self.cassandra_password password=self.cassandra_password
) )
else: else:
self.tg = KGClass( self.tg = KGClass(
hosts=self.cassandra_host, hosts=self.cassandra_host,
keyspace=user, keyspace=workspace,
) )
self.table = user self.table = workspace
async def query_triples(self, workspace, query): async def query_triples(self, workspace, query):

View file

@ -67,7 +67,7 @@ class Processor(TriplesQueryService):
try: try:
user = workspace workspace = workspace
collection = query.collection if query.collection else "default" collection = query.collection if query.collection else "default"
triples = [] triples = []
@ -79,13 +79,13 @@ class Processor(TriplesQueryService):
# SPO # SPO
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) " "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src " "RETURN $src as src "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o), src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -93,13 +93,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) " "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN $src as src " "RETURN $src as src "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o), src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -111,13 +111,13 @@ class Processor(TriplesQueryService):
# SP # SP
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) " "(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest " "RETURN dest.value as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -126,13 +126,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"])) triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) " "(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest " "RETURN dest.uri as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -147,13 +147,13 @@ class Processor(TriplesQueryService):
# SO # SO
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) " "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel " "RETURN rel.uri as rel "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), value=get_term_value(query.o), src=get_term_value(query.s), value=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -162,13 +162,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o))) triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) " "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel " "RETURN rel.uri as rel "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), uri=get_term_value(query.o), src=get_term_value(query.s), uri=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -181,13 +181,13 @@ class Processor(TriplesQueryService):
# S # S
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) " "(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest " "RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), src=get_term_value(query.s),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -196,13 +196,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], data["dest"])) triples.append((get_term_value(query.s), data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) " "(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest " "RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), src=get_term_value(query.s),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -220,13 +220,13 @@ class Processor(TriplesQueryService):
# PO # PO
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) " "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src " "RETURN src.uri as src "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
uri=get_term_value(query.p), value=get_term_value(query.o), uri=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -235,13 +235,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o))) triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) " "(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src " "RETURN src.uri as src "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
uri=get_term_value(query.p), dest=get_term_value(query.o), uri=get_term_value(query.p), dest=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -254,13 +254,13 @@ class Processor(TriplesQueryService):
# P # P
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) " "(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest " "RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
uri=get_term_value(query.p), uri=get_term_value(query.p),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -269,13 +269,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), data["dest"])) triples.append((data["src"], get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) " "(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest " "RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
uri=get_term_value(query.p), uri=get_term_value(query.p),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -290,13 +290,13 @@ class Processor(TriplesQueryService):
# O # O
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) " "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel " "RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
value=get_term_value(query.o), value=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -305,13 +305,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], get_term_value(query.o))) triples.append((data["src"], data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) " "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel " "RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
uri=get_term_value(query.o), uri=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -324,12 +324,12 @@ class Processor(TriplesQueryService):
# * # *
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) " "(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest " "RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -338,12 +338,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], data["dest"])) triples.append((data["src"], data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) " "(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )

View file

@ -67,7 +67,6 @@ class Processor(TriplesQueryService):
try: try:
user = workspace
collection = query.collection if query.collection else "default" collection = query.collection if query.collection else "default"
triples = [] triples = []
@ -79,13 +78,13 @@ class Processor(TriplesQueryService):
# SPO # SPO
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) " "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN $src as src " "RETURN $src as src "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o), src=get_term_value(query.s), rel=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -93,13 +92,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o))) triples.append((get_term_value(query.s), get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) " "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN $src as src " "RETURN $src as src "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o), src=get_term_value(query.s), rel=get_term_value(query.p), uri=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -111,13 +110,13 @@ class Processor(TriplesQueryService):
# SP # SP
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) " "(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN dest.value as dest " "RETURN dest.value as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -126,13 +125,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"])) triples.append((get_term_value(query.s), get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $rel, user: $user, collection: $collection}]->" "[rel:Rel {uri: $rel, workspace: $workspace, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) " "(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN dest.uri as dest " "RETURN dest.uri as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), rel=get_term_value(query.p), src=get_term_value(query.s), rel=get_term_value(query.p),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -147,13 +146,13 @@ class Processor(TriplesQueryService):
# SO # SO
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) " "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel " "RETURN rel.uri as rel "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), value=get_term_value(query.o), src=get_term_value(query.s), value=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -162,13 +161,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o))) triples.append((get_term_value(query.s), data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) " "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel " "RETURN rel.uri as rel "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), uri=get_term_value(query.o), src=get_term_value(query.s), uri=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -181,13 +180,13 @@ class Processor(TriplesQueryService):
# S # S
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) " "(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.value as dest " "RETURN rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), src=get_term_value(query.s),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -196,13 +195,13 @@ class Processor(TriplesQueryService):
triples.append((get_term_value(query.s), data["rel"], data["dest"])) triples.append((get_term_value(query.s), data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection})-" "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) " "(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN rel.uri as rel, dest.uri as dest " "RETURN rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
src=get_term_value(query.s), src=get_term_value(query.s),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -220,13 +219,13 @@ class Processor(TriplesQueryService):
# PO # PO
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) " "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src " "RETURN src.uri as src "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
uri=get_term_value(query.p), value=get_term_value(query.o), uri=get_term_value(query.p), value=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -235,13 +234,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), get_term_value(query.o))) triples.append((data["src"], get_term_value(query.p), get_term_value(query.o)))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $dest, user: $user, collection: $collection}) " "(dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src " "RETURN src.uri as src "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
uri=get_term_value(query.p), dest=get_term_value(query.o), uri=get_term_value(query.p), dest=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -254,13 +253,13 @@ class Processor(TriplesQueryService):
# P # P
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) " "(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.value as dest " "RETURN src.uri as src, dest.value as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
uri=get_term_value(query.p), uri=get_term_value(query.p),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -269,13 +268,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], get_term_value(query.p), data["dest"])) triples.append((data["src"], get_term_value(query.p), data["dest"]))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {uri: $uri, user: $user, collection: $collection}]->" "[rel:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) " "(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, dest.uri as dest " "RETURN src.uri as src, dest.uri as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
uri=get_term_value(query.p), uri=get_term_value(query.p),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -290,13 +289,13 @@ class Processor(TriplesQueryService):
# O # O
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {value: $value, user: $user, collection: $collection}) " "(dest:Literal {value: $value, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel " "RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
value=get_term_value(query.o), value=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -305,13 +304,13 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], get_term_value(query.o))) triples.append((data["src"], data["rel"], get_term_value(query.o)))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {uri: $uri, user: $user, collection: $collection}) " "(dest:Node {uri: $uri, workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel " "RETURN src.uri as src, rel.uri as rel "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
uri=get_term_value(query.o), uri=get_term_value(query.o),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -324,12 +323,12 @@ class Processor(TriplesQueryService):
# * # *
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Literal {user: $user, collection: $collection}) " "(dest:Literal {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.value as dest " "RETURN src.uri as src, rel.uri as rel, dest.value as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )
@ -338,12 +337,12 @@ class Processor(TriplesQueryService):
triples.append((data["src"], data["rel"], data["dest"])) triples.append((data["src"], data["rel"], data["dest"]))
records, summary, keys = self.io.execute_query( records, summary, keys = self.io.execute_query(
"MATCH (src:Node {user: $user, collection: $collection})-" "MATCH (src:Node {workspace: $workspace, collection: $collection})-"
"[rel:Rel {user: $user, collection: $collection}]->" "[rel:Rel {workspace: $workspace, collection: $collection}]->"
"(dest:Node {user: $user, collection: $collection}) " "(dest:Node {workspace: $workspace, collection: $collection}) "
"RETURN src.uri as src, rel.uri as rel, dest.uri as dest " "RETURN src.uri as src, rel.uri as rel, dest.uri as dest "
"LIMIT " + str(query.limit), "LIMIT " + str(query.limit),
user=user, collection=collection, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
) )

View file

@ -60,27 +60,27 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})' help=f'Milvus store URI (default: {default_store_uri})'
) )
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
""" """
Create collection via config push - collections are created lazily on first write Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings. with the correct dimension determined from the actual embeddings.
""" """
try: try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(user, collection) self.vecstore.create_collection(workspace, collection)
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push""" """Delete the collection for document embeddings via config push"""
try: try:
self.vecstore.delete_collection(user, collection) self.vecstore.delete_collection(workspace, collection)
logger.info(f"Successfully deleted collection {user}/{collection}") logger.info(f"Successfully deleted collection {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
def run(): def run():

View file

@ -165,22 +165,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}' help=f'Pinecone region, (default: {default_region}'
) )
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
""" """
Create collection via config push - indexes are created lazily on first write Create collection via config push - indexes are created lazily on first write
with the correct dimension determined from the actual embeddings. with the correct dimension determined from the actual embeddings.
""" """
try: try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push""" """Delete the collection for document embeddings via config push"""
try: try:
prefix = f"d-{user}-{collection}-" prefix = f"d-{workspace}-{collection}-"
# Get all indexes and filter for matches # Get all indexes and filter for matches
all_indexes = self.pinecone.list_indexes() all_indexes = self.pinecone.list_indexes()
@ -195,10 +195,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for index_name in matching_indexes: for index_name in matching_indexes:
self.pinecone.delete_index(index_name) self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}") logger.info(f"Deleted Pinecone index: {index_name}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}") logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
def run(): def run():

View file

@ -107,22 +107,22 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
help=f'Qdrant API key (default: None)' help=f'Qdrant API key (default: None)'
) )
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
""" """
Create collection via config push - collections are created lazily on first write Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings. with the correct dimension determined from the actual embeddings.
""" """
try: try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for document embeddings via config push""" """Delete the collection for document embeddings via config push"""
try: try:
prefix = f"d_{user}_{collection}_" prefix = f"d_{workspace}_{collection}_"
# Get all collections and filter for matches # Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections all_collections = self.qdrant.get_collections().collections
@ -137,10 +137,10 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for collection_name in matching_collections: for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name) self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
def run(): def run():

View file

@ -73,27 +73,27 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Milvus store URI (default: {default_store_uri})' help=f'Milvus store URI (default: {default_store_uri})'
) )
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
""" """
Create collection via config push - collections are created lazily on first write Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings. with the correct dimension determined from the actual embeddings.
""" """
try: try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
self.vecstore.create_collection(user, collection) self.vecstore.create_collection(workspace, collection)
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push""" """Delete the collection for graph embeddings via config push"""
try: try:
self.vecstore.delete_collection(user, collection) self.vecstore.delete_collection(workspace, collection)
logger.info(f"Successfully deleted collection {user}/{collection}") logger.info(f"Successfully deleted collection {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
def run(): def run():

View file

@ -183,22 +183,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Pinecone region, (default: {default_region}' help=f'Pinecone region, (default: {default_region}'
) )
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
""" """
Create collection via config push - indexes are created lazily on first write Create collection via config push - indexes are created lazily on first write
with the correct dimension determined from the actual embeddings. with the correct dimension determined from the actual embeddings.
""" """
try: try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push""" """Delete the collection for graph embeddings via config push"""
try: try:
prefix = f"t-{user}-{collection}-" prefix = f"t-{workspace}-{collection}-"
# Get all indexes and filter for matches # Get all indexes and filter for matches
all_indexes = self.pinecone.list_indexes() all_indexes = self.pinecone.list_indexes()
@ -213,10 +213,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
for index_name in matching_indexes: for index_name in matching_indexes:
self.pinecone.delete_index(index_name) self.pinecone.delete_index(index_name)
logger.info(f"Deleted Pinecone index: {index_name}") logger.info(f"Deleted Pinecone index: {index_name}")
logger.info(f"Deleted {len(matching_indexes)} index(es) for {user}/{collection}") logger.info(f"Deleted {len(matching_indexes)} index(es) for {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
def run(): def run():

View file

@ -126,22 +126,22 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
help=f'Qdrant API key' help=f'Qdrant API key'
) )
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
""" """
Create collection via config push - collections are created lazily on first write Create collection via config push - collections are created lazily on first write
with the correct dimension determined from the actual embeddings. with the correct dimension determined from the actual embeddings.
""" """
try: try:
logger.info(f"Collection create request for {user}/{collection} - will be created lazily on first write") logger.info(f"Collection create request for {workspace}/{collection} - will be created lazily on first write")
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for graph embeddings via config push""" """Delete the collection for graph embeddings via config push"""
try: try:
prefix = f"t_{user}_{collection}_" prefix = f"t_{workspace}_{collection}_"
# Get all collections and filter for matches # Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections all_collections = self.qdrant.get_collections().collections
@ -156,10 +156,10 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
for collection_name in matching_collections: for collection_name in matching_collections:
self.qdrant.delete_collection(collection_name) self.qdrant.delete_collection(collection_name)
logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info(f"Deleted {len(matching_collections)} collection(s) for {user}/{collection}") logger.info(f"Deleted {len(matching_collections)} collection(s) for {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
def run(): def run():

View file

@ -2,13 +2,13 @@
Row embeddings writer for Qdrant (Stage 2). Row embeddings writer for Qdrant (Stage 2).
Consumes RowEmbeddings messages (which already contain computed vectors) Consumes RowEmbeddings messages (which already contain computed vectors)
and writes them to Qdrant. One Qdrant collection per (user, collection, schema_name) pair. and writes them to Qdrant. One Qdrant collection per (workspace, collection, schema_name) pair.
This follows the two-stage pattern used by graph-embeddings and document-embeddings: This follows the two-stage pattern used by graph-embeddings and document-embeddings:
Stage 1 (row-embeddings): Compute embeddings Stage 1 (row-embeddings): Compute embeddings
Stage 2 (this processor): Store embeddings Stage 2 (this processor): Store embeddings
Collection naming: rows_{user}_{collection}_{schema_name}_{dimension} Collection naming: rows_{workspace}_{collection}_{schema_name}_{dimension}
Payload structure: Payload structure:
- index_name: The indexed field(s) this embedding represents - index_name: The indexed field(s) this embedding represents
@ -77,10 +77,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
return safe_name.lower() return safe_name.lower()
def get_collection_name( def get_collection_name(
self, user: str, collection: str, schema_name: str, dimension: int self, workspace: str, collection: str, schema_name: str, dimension: int
) -> str: ) -> str:
"""Generate Qdrant collection name""" """Generate Qdrant collection name"""
safe_user = self.sanitize_name(user) safe_user = self.sanitize_name(workspace)
safe_collection = self.sanitize_name(collection) safe_collection = self.sanitize_name(collection)
safe_schema = self.sanitize_name(schema_name) safe_schema = self.sanitize_name(schema_name)
return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}" return f"rows_{safe_user}_{safe_collection}_{safe_schema}_{dimension}"
@ -169,17 +169,17 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant") logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Collection creation via config push - collections created lazily on first write""" """Collection creation via config push - collections created lazily on first write"""
logger.info( logger.info(
f"Row embeddings collection create request for {user}/{collection} - " f"Row embeddings collection create request for {workspace}/{collection} - "
f"will be created lazily on first write" f"will be created lazily on first write"
) )
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete all Qdrant collections for a given user/collection""" """Delete all Qdrant collections for a given workspace/collection"""
try: try:
prefix = f"rows_{self.sanitize_name(user)}_{self.sanitize_name(collection)}_" prefix = f"rows_{self.sanitize_name(workspace)}_{self.sanitize_name(collection)}_"
# Get all collections and filter for matches # Get all collections and filter for matches
all_collections = self.qdrant.get_collections().collections all_collections = self.qdrant.get_collections().collections
@ -197,23 +197,23 @@ class Processor(CollectionConfigHandler, FlowProcessor):
logger.info(f"Deleted Qdrant collection: {collection_name}") logger.info(f"Deleted Qdrant collection: {collection_name}")
logger.info( logger.info(
f"Deleted {len(matching_collections)} collection(s) " f"Deleted {len(matching_collections)} collection(s) "
f"for {user}/{collection}" f"for {workspace}/{collection}"
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to delete collection {user}/{collection}: {e}", f"Failed to delete collection {workspace}/{collection}: {e}",
exc_info=True exc_info=True
) )
raise raise
async def delete_collection_schema( async def delete_collection_schema(
self, user: str, collection: str, schema_name: str self, workspace: str, collection: str, schema_name: str
): ):
"""Delete Qdrant collection for a specific user/collection/schema""" """Delete Qdrant collection for a specific workspace/collection/schema"""
try: try:
prefix = ( prefix = (
f"rows_{self.sanitize_name(user)}_" f"rows_{self.sanitize_name(workspace)}_"
f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_" f"{self.sanitize_name(collection)}_{self.sanitize_name(schema_name)}_"
) )
@ -234,7 +234,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to delete collection {user}/{collection}/{schema_name}: {e}", f"Failed to delete collection {workspace}/{collection}/{schema_name}: {e}",
exc_info=True exc_info=True
) )
raise raise

View file

@ -459,25 +459,25 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"({len(index_names)} indexes per row)" f"({len(index_names)} indexes per row)"
) )
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store""" """Create/verify collection exists in Cassandra row store"""
# Connect if not already connected (sync, push to thread) # Connect if not already connected (sync, push to thread)
await asyncio.to_thread(self.connect_cassandra) await asyncio.to_thread(self.connect_cassandra)
# Ensure tables exist (sync DDL, push to thread) # Ensure tables exist (sync DDL, push to thread)
await asyncio.to_thread(self.ensure_tables, user) await asyncio.to_thread(self.ensure_tables, workspace)
logger.info(f"Collection {collection} ready for user {user}") logger.info(f"Collection {collection} ready for workspace {workspace}")
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection using partition tracking""" """Delete all data for a specific collection using partition tracking"""
# Connect if not already connected # Connect if not already connected
await asyncio.to_thread(self.connect_cassandra) await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user) safe_keyspace = self.sanitize_name(workspace)
# Check if keyspace exists # Check if keyspace exists
if user not in self.known_keyspaces: if workspace not in self.known_keyspaces:
check_keyspace_cql = """ check_keyspace_cql = """
SELECT keyspace_name FROM system_schema.keyspaces SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s WHERE keyspace_name = %s
@ -488,7 +488,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
if not result: if not result:
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return return
self.known_keyspaces.add(user) self.known_keyspaces.add(workspace)
# Discover all partitions for this collection # Discover all partitions for this collection
select_partitions_cql = f""" select_partitions_cql = f"""
@ -551,12 +551,12 @@ class Processor(CollectionConfigHandler, FlowProcessor):
f"from keyspace {safe_keyspace}" f"from keyspace {safe_keyspace}"
) )
async def delete_collection_schema(self, user: str, collection: str, schema_name: str): async def delete_collection_schema(self, workspace: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination""" """Delete all data for a specific collection + schema combination"""
# Connect if not already connected # Connect if not already connected
await asyncio.to_thread(self.connect_cassandra) await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user) safe_keyspace = self.sanitize_name(workspace)
# Discover partitions for this collection + schema # Discover partitions for this collection + schema
select_partitions_cql = f""" select_partitions_cql = f"""

View file

@ -210,12 +210,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
await asyncio.to_thread(_do_store) await asyncio.to_thread(_do_store)
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create a collection in Cassandra triple store via config push""" """Create a collection in Cassandra triple store via config push"""
def _do_create(): def _do_create():
# Create or reuse connection for this user's keyspace # Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != user: if self.table is None or self.table != workspace:
self.tg = None self.tg = None
# Use factory function to select implementation # Use factory function to select implementation
@ -225,23 +225,23 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.cassandra_username and self.cassandra_password: if self.cassandra_username and self.cassandra_password:
self.tg = KGClass( self.tg = KGClass(
hosts=self.cassandra_host, hosts=self.cassandra_host,
keyspace=user, keyspace=workspace,
username=self.cassandra_username, username=self.cassandra_username,
password=self.cassandra_password, password=self.cassandra_password,
) )
else: else:
self.tg = KGClass( self.tg = KGClass(
hosts=self.cassandra_host, hosts=self.cassandra_host,
keyspace=user, keyspace=workspace,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {user}: {e}") logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise raise
self.table = user self.table = workspace
# Create collection using the built-in method # Create collection using the built-in method
logger.info(f"Creating collection {collection} for user {user}") logger.info(f"Creating collection {collection} for workspace {workspace}")
if self.tg.collection_exists(collection): if self.tg.collection_exists(collection):
logger.info(f"Collection {collection} already exists") logger.info(f"Collection {collection} already exists")
@ -252,15 +252,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try: try:
await asyncio.to_thread(_do_create) await asyncio.to_thread(_do_create)
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection from the unified triples table""" """Delete all data for a specific collection from the unified triples table"""
def _do_delete(): def _do_delete():
# Create or reuse connection for this user's keyspace # Create or reuse connection for this workspace's keyspace
if self.table is None or self.table != user: if self.table is None or self.table != workspace:
self.tg = None self.tg = None
# Use factory function to select implementation # Use factory function to select implementation
@ -270,29 +270,29 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
if self.cassandra_username and self.cassandra_password: if self.cassandra_username and self.cassandra_password:
self.tg = KGClass( self.tg = KGClass(
hosts=self.cassandra_host, hosts=self.cassandra_host,
keyspace=user, keyspace=workspace,
username=self.cassandra_username, username=self.cassandra_username,
password=self.cassandra_password, password=self.cassandra_password,
) )
else: else:
self.tg = KGClass( self.tg = KGClass(
hosts=self.cassandra_host, hosts=self.cassandra_host,
keyspace=user, keyspace=workspace,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to Cassandra for user {user}: {e}") logger.error(f"Failed to connect to Cassandra for workspace {workspace}: {e}")
raise raise
self.table = user self.table = workspace
# Delete all triples for this collection using the built-in method # Delete all triples for this collection using the built-in method
self.tg.delete_collection(collection) self.tg.delete_collection(collection)
logger.info(f"Deleted all triples for collection {collection} from keyspace {user}") logger.info(f"Deleted all triples for collection {collection} from keyspace {workspace}")
try: try:
await asyncio.to_thread(_do_delete) await asyncio.to_thread(_do_delete)
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
@staticmethod @staticmethod

View file

@ -59,15 +59,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Register for config push notifications # Register for config push notifications
self.register_config_handler(self.on_collection_config, types=["collection"]) self.register_config_handler(self.on_collection_config, types=["collection"])
def create_node(self, uri, user, collection): def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}") logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
res = self.io.query( res = self.io.query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
params={ params={
"uri": uri, "uri": uri,
"user": user, "workspace": workspace,
"collection": collection, "collection": collection,
}, },
) )
@ -77,15 +77,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms time=res.run_time_ms
)) ))
def create_literal(self, value, user, collection): def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}") logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
res = self.io.query( res = self.io.query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
params={ params={
"value": value, "value": value,
"user": user, "workspace": workspace,
"collection": collection, "collection": collection,
}, },
) )
@ -95,19 +95,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms time=res.run_time_ms
)) ))
def relate_node(self, src, uri, dest, user, collection): def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
res = self.io.query( res = self.io.query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={ params={
"src": src, "src": src,
"dest": dest, "dest": dest,
"uri": uri, "uri": uri,
"user": user, "workspace": workspace,
"collection": collection, "collection": collection,
}, },
) )
@ -117,19 +117,19 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms time=res.run_time_ms
)) ))
def relate_literal(self, src, uri, dest, user, collection): def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
res = self.io.query( res = self.io.query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
params={ params={
"src": src, "src": src,
"dest": dest, "dest": dest,
"uri": uri, "uri": uri,
"user": user, "workspace": workspace,
"collection": collection, "collection": collection,
}, },
) )
@ -139,28 +139,28 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=res.run_time_ms time=res.run_time_ms
)) ))
def collection_exists(self, user, collection): def collection_exists(self, workspace, collection):
"""Check if collection metadata node exists""" """Check if collection metadata node exists"""
result = self.io.query( result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"RETURN c LIMIT 1", "RETURN c LIMIT 1",
params={"user": user, "collection": collection} params={"workspace": workspace, "collection": collection}
) )
return result.result_set is not None and len(result.result_set) > 0 return result.result_set is not None and len(result.result_set) > 0
def create_collection(self, user, collection): def create_collection(self, workspace, collection):
"""Create collection metadata node""" """Create collection metadata node"""
import datetime import datetime
self.io.query( self.io.query(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at", "SET c.created_at = $created_at",
params={ params={
"user": user, "workspace": workspace,
"collection": collection, "collection": collection,
"created_at": datetime.datetime.now().isoformat() "created_at": datetime.datetime.now().isoformat()
} }
) )
logger.info(f"Created collection metadata node for {user}/{collection}") logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def store_triples(self, workspace, message): async def store_triples(self, workspace, message):
collection = message.metadata.collection if message.metadata.collection else "default" collection = message.metadata.collection if message.metadata.collection else "default"
@ -206,58 +206,58 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'FalkorDB database (default: {default_database})' help=f'FalkorDB database (default: {default_database})'
) )
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create collection metadata in FalkorDB via config push""" """Create collection metadata in FalkorDB via config push"""
try: try:
# Check if collection exists # Check if collection exists
result = self.io.query( result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) RETURN c LIMIT 1", "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) RETURN c LIMIT 1",
params={"user": user, "collection": collection} params={"workspace": workspace, "collection": collection}
) )
if result.result_set: if result.result_set:
logger.info(f"Collection {user}/{collection} already exists") logger.info(f"Collection {workspace}/{collection} already exists")
else: else:
# Create collection metadata node # Create collection metadata node
import datetime import datetime
self.io.query( self.io.query(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at", "SET c.created_at = $created_at",
params={ params={
"user": user, "workspace": workspace,
"collection": collection, "collection": collection,
"created_at": datetime.datetime.now().isoformat() "created_at": datetime.datetime.now().isoformat()
} }
) )
logger.info(f"Created collection {user}/{collection}") logger.info(f"Created collection {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete the collection for FalkorDB triples via config push""" """Delete the collection for FalkorDB triples via config push"""
try: try:
# Delete all nodes and literals for this user/collection # Delete all nodes and literals for this workspace/collection
node_result = self.io.query( node_result = self.io.query(
"MATCH (n:Node {user: $user, collection: $collection}) DETACH DELETE n", "MATCH (n:Node {workspace: $workspace, collection: $collection}) DETACH DELETE n",
params={"user": user, "collection": collection} params={"workspace": workspace, "collection": collection}
) )
literal_result = self.io.query( literal_result = self.io.query(
"MATCH (n:Literal {user: $user, collection: $collection}) DETACH DELETE n", "MATCH (n:Literal {workspace: $workspace, collection: $collection}) DETACH DELETE n",
params={"user": user, "collection": collection} params={"workspace": workspace, "collection": collection}
) )
# Delete collection metadata node # Delete collection metadata node
metadata_result = self.io.query( metadata_result = self.io.query(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) DELETE c", "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) DELETE c",
params={"user": user, "collection": collection} params={"workspace": workspace, "collection": collection}
) )
logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {user}/{collection}") logger.info(f"Deleted {node_result.nodes_deleted} nodes, {literal_result.nodes_deleted} literals, and {metadata_result.nodes_deleted} metadata nodes for collection {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
def run(): def run():

View file

@ -117,10 +117,10 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Maybe index already exists # Maybe index already exists
logger.warning("Index create failure ignored") logger.warning("Index create failure ignored")
# New indexes for user/collection filtering # New indexes for workspace/collection filtering
try: try:
session.run( session.run(
"CREATE INDEX ON :Node(user)" "CREATE INDEX ON :Node(workspace)"
) )
except Exception as e: except Exception as e:
logger.warning(f"User index create failure: {e}") logger.warning(f"User index create failure: {e}")
@ -136,7 +136,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try: try:
session.run( session.run(
"CREATE INDEX ON :Literal(user)" "CREATE INDEX ON :Literal(workspace)"
) )
except Exception as e: except Exception as e:
logger.warning(f"User index create failure: {e}") logger.warning(f"User index create failure: {e}")
@ -152,13 +152,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Index creation done") logger.info("Index creation done")
def create_node(self, uri, user, collection): def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}") logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query( summary = self.io.execute_query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=uri, user=user, collection=collection, uri=uri, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
).summary ).summary
@ -167,13 +167,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after time=summary.result_available_after
)) ))
def create_literal(self, value, user, collection): def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}") logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query( summary = self.io.execute_query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=value, user=user, collection=collection, value=value, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
).summary ).summary
@ -182,15 +182,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after time=summary.result_available_after
)) ))
def relate_node(self, src, uri, dest, user, collection): def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query( summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection, src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
).summary ).summary
@ -199,15 +199,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after time=summary.result_available_after
)) ))
def relate_literal(self, src, uri, dest, user, collection): def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query( summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection, src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
).summary ).summary
@ -216,7 +216,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after time=summary.result_available_after
)) ))
def create_triple(self, tx, t, user, collection): def create_triple(self, tx, t, workspace, collection):
s_val = get_term_value(t.s) s_val = get_term_value(t.s)
p_val = get_term_value(t.p) p_val = get_term_value(t.p)
@ -224,38 +224,38 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Create new s node with given uri, if not exists # Create new s node with given uri, if not exists
result = tx.run( result = tx.run(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=s_val, user=user, collection=collection uri=s_val, workspace=workspace, collection=collection
) )
if t.o.type == IRI: if t.o.type == IRI:
# Create new o node with given uri, if not exists # Create new o node with given uri, if not exists
result = tx.run( result = tx.run(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=o_val, user=user, collection=collection uri=o_val, workspace=workspace, collection=collection
) )
result = tx.run( result = tx.run(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection, src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection,
) )
else: else:
# Create new o literal with given uri, if not exists # Create new o literal with given uri, if not exists
result = tx.run( result = tx.run(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=o_val, user=user, collection=collection value=o_val, workspace=workspace, collection=collection
) )
result = tx.run( result = tx.run(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=s_val, dest=o_val, uri=p_val, user=user, collection=collection, src=s_val, dest=o_val, uri=p_val, workspace=workspace, collection=collection,
) )
async def store_triples(self, workspace, message): async def store_triples(self, workspace, message):
@ -288,7 +288,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
# Alternative implementation using transactions # Alternative implementation using transactions
# with self.io.session(database=self.db) as session: # with self.io.session(database=self.db) as session:
# session.execute_write(self.create_triple, t, user, collection) # session.execute_write(self.create_triple, t, workspace, collection)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
@ -319,72 +319,72 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'Memgraph database (default: {default_database})' help=f'Memgraph database (default: {default_database})'
) )
def _collection_exists_in_db(self, user, collection): def _collection_exists_in_db(self, workspace, collection):
"""Check if collection metadata node exists""" """Check if collection metadata node exists"""
with self.io.session(database=self.db) as session: with self.io.session(database=self.db) as session:
result = session.run( result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"RETURN c LIMIT 1", "RETURN c LIMIT 1",
user=user, collection=collection workspace=workspace, collection=collection
) )
return bool(list(result)) return bool(list(result))
def _create_collection_in_db(self, user, collection): def _create_collection_in_db(self, workspace, collection):
"""Create collection metadata node""" """Create collection metadata node"""
import datetime import datetime
with self.io.session(database=self.db) as session: with self.io.session(database=self.db) as session:
session.run( session.run(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at", "SET c.created_at = $created_at",
user=user, collection=collection, workspace=workspace, collection=collection,
created_at=datetime.datetime.now().isoformat() created_at=datetime.datetime.now().isoformat()
) )
logger.info(f"Created collection metadata node for {user}/{collection}") logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create collection metadata in Memgraph via config push""" """Create collection metadata in Memgraph via config push"""
try: try:
if self._collection_exists_in_db(user, collection): if self._collection_exists_in_db(workspace, collection):
logger.info(f"Collection {user}/{collection} already exists") logger.info(f"Collection {workspace}/{collection} already exists")
else: else:
self._create_collection_in_db(user, collection) self._create_collection_in_db(workspace, collection)
logger.info(f"Created collection {user}/{collection}") logger.info(f"Created collection {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection via config push""" """Delete all data for a specific collection via config push"""
try: try:
with self.io.session(database=self.db) as session: with self.io.session(database=self.db) as session:
# Delete all nodes for this user and collection # Delete all nodes for this workspace and collection
node_result = session.run( node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) " "MATCH (n:Node {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n", "DETACH DELETE n",
user=user, collection=collection workspace=workspace, collection=collection
) )
nodes_deleted = node_result.consume().counters.nodes_deleted nodes_deleted = node_result.consume().counters.nodes_deleted
# Delete all literals for this user and collection # Delete all literals for this workspace and collection
literal_result = session.run( literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) " "MATCH (n:Literal {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n", "DETACH DELETE n",
user=user, collection=collection workspace=workspace, collection=collection
) )
literals_deleted = literal_result.consume().counters.nodes_deleted literals_deleted = literal_result.consume().counters.nodes_deleted
# Delete collection metadata node # Delete collection metadata node
metadata_result = session.run( metadata_result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"DELETE c", "DELETE c",
user=user, collection=collection workspace=workspace, collection=collection
) )
metadata_deleted = metadata_result.consume().counters.nodes_deleted metadata_deleted = metadata_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE # Note: Relationships are automatically deleted with DETACH DELETE
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}") logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection: {e}") logger.error(f"Failed to delete collection: {e}")

View file

@ -80,14 +80,12 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Create indexes...") logger.info("Create indexes...")
# Legacy indexes for backwards compatibility
try: try:
session.run( session.run(
"CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)",
) )
except Exception as e: except Exception as e:
logger.warning(f"Index create failure: {e}") logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored") logger.warning("Index create failure ignored")
try: try:
@ -96,7 +94,6 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
) )
except Exception as e: except Exception as e:
logger.warning(f"Index create failure: {e}") logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored") logger.warning("Index create failure ignored")
try: try:
@ -105,13 +102,11 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
) )
except Exception as e: except Exception as e:
logger.warning(f"Index create failure: {e}") logger.warning(f"Index create failure: {e}")
# Maybe index already exists
logger.warning("Index create failure ignored") logger.warning("Index create failure ignored")
# New compound indexes for user/collection filtering
try: try:
session.run( session.run(
"CREATE INDEX node_user_collection_uri FOR (n:Node) ON (n.user, n.collection, n.uri)", "CREATE INDEX node_workspace_collection_uri FOR (n:Node) ON (n.workspace, n.collection, n.uri)",
) )
except Exception as e: except Exception as e:
logger.warning(f"Compound index create failure: {e}") logger.warning(f"Compound index create failure: {e}")
@ -119,17 +114,16 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
try: try:
session.run( session.run(
"CREATE INDEX literal_user_collection_value FOR (n:Literal) ON (n.user, n.collection, n.value)", "CREATE INDEX literal_workspace_collection_value FOR (n:Literal) ON (n.workspace, n.collection, n.value)",
) )
except Exception as e: except Exception as e:
logger.warning(f"Compound index create failure: {e}") logger.warning(f"Compound index create failure: {e}")
logger.warning("Index create failure ignored") logger.warning("Index create failure ignored")
# Note: Neo4j doesn't support compound indexes on relationships in all versions # Neo4j doesn't support compound indexes on relationships in all versions
# Try to create individual indexes on relationship properties
try: try:
session.run( session.run(
"CREATE INDEX rel_user FOR ()-[r:Rel]-() ON (r.user)", "CREATE INDEX rel_workspace FOR ()-[r:Rel]-() ON (r.workspace)",
) )
except Exception as e: except Exception as e:
logger.warning(f"Relationship index create failure: {e}") logger.warning(f"Relationship index create failure: {e}")
@ -145,13 +139,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
logger.info("Index creation done") logger.info("Index creation done")
def create_node(self, uri, user, collection): def create_node(self, uri, workspace, collection):
logger.debug(f"Create node {uri} for user={user}, collection={collection}") logger.debug(f"Create node {uri} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query( summary = self.io.execute_query(
"MERGE (n:Node {uri: $uri, user: $user, collection: $collection})", "MERGE (n:Node {uri: $uri, workspace: $workspace, collection: $collection})",
uri=uri, user=user, collection=collection, uri=uri, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
).summary ).summary
@ -160,13 +154,13 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after time=summary.result_available_after
)) ))
def create_literal(self, value, user, collection): def create_literal(self, value, workspace, collection):
logger.debug(f"Create literal {value} for user={user}, collection={collection}") logger.debug(f"Create literal {value} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query( summary = self.io.execute_query(
"MERGE (n:Literal {value: $value, user: $user, collection: $collection})", "MERGE (n:Literal {value: $value, workspace: $workspace, collection: $collection})",
value=value, user=user, collection=collection, value=value, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
).summary ).summary
@ -175,15 +169,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after time=summary.result_available_after
)) ))
def relate_node(self, src, uri, dest, user, collection): def relate_node(self, src, uri, dest, workspace, collection):
logger.debug(f"Create node rel {src} {uri} {dest} for user={user}, collection={collection}") logger.debug(f"Create node rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query( summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Node {uri: $dest, user: $user, collection: $collection}) " "MATCH (dest:Node {uri: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection, src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
).summary ).summary
@ -192,15 +186,15 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
time=summary.result_available_after time=summary.result_available_after
)) ))
def relate_literal(self, src, uri, dest, user, collection): def relate_literal(self, src, uri, dest, workspace, collection):
logger.debug(f"Create literal rel {src} {uri} {dest} for user={user}, collection={collection}") logger.debug(f"Create literal rel {src} {uri} {dest} for workspace={workspace}, collection={collection}")
summary = self.io.execute_query( summary = self.io.execute_query(
"MATCH (src:Node {uri: $src, user: $user, collection: $collection}) " "MATCH (src:Node {uri: $src, workspace: $workspace, collection: $collection}) "
"MATCH (dest:Literal {value: $dest, user: $user, collection: $collection}) " "MATCH (dest:Literal {value: $dest, workspace: $workspace, collection: $collection}) "
"MERGE (src)-[:Rel {uri: $uri, user: $user, collection: $collection}]->(dest)", "MERGE (src)-[:Rel {uri: $uri, workspace: $workspace, collection: $collection}]->(dest)",
src=src, dest=dest, uri=uri, user=user, collection=collection, src=src, dest=dest, uri=uri, workspace=workspace, collection=collection,
database_=self.db, database_=self.db,
).summary ).summary
@ -266,75 +260,70 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
help=f'Neo4j database (default: {default_database})' help=f'Neo4j database (default: {default_database})'
) )
def _collection_exists_in_db(self, user, collection): def _collection_exists_in_db(self, workspace, collection):
"""Check if collection metadata node exists""" """Check if collection metadata node exists"""
with self.io.session(database=self.db) as session: with self.io.session(database=self.db) as session:
result = session.run( result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"RETURN c LIMIT 1", "RETURN c LIMIT 1",
user=user, collection=collection workspace=workspace, collection=collection
) )
return bool(list(result)) return bool(list(result))
def _create_collection_in_db(self, user, collection): def _create_collection_in_db(self, workspace, collection):
"""Create collection metadata node""" """Create collection metadata node"""
import datetime import datetime
with self.io.session(database=self.db) as session: with self.io.session(database=self.db) as session:
session.run( session.run(
"MERGE (c:CollectionMetadata {user: $user, collection: $collection}) " "MERGE (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"SET c.created_at = $created_at", "SET c.created_at = $created_at",
user=user, collection=collection, workspace=workspace, collection=collection,
created_at=datetime.datetime.now().isoformat() created_at=datetime.datetime.now().isoformat()
) )
logger.info(f"Created collection metadata node for {user}/{collection}") logger.info(f"Created collection metadata node for {workspace}/{collection}")
async def create_collection(self, user: str, collection: str, metadata: dict): async def create_collection(self, workspace: str, collection: str, metadata: dict):
"""Create collection metadata in Neo4j via config push""" """Create collection metadata in Neo4j via config push"""
try: try:
if self._collection_exists_in_db(user, collection): if self._collection_exists_in_db(workspace, collection):
logger.info(f"Collection {user}/{collection} already exists") logger.info(f"Collection {workspace}/{collection} already exists")
else: else:
self._create_collection_in_db(user, collection) self._create_collection_in_db(workspace, collection)
logger.info(f"Created collection {user}/{collection}") logger.info(f"Created collection {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to create collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
async def delete_collection(self, user: str, collection: str): async def delete_collection(self, workspace: str, collection: str):
"""Delete all data for a specific collection via config push""" """Delete all data for a specific collection via config push"""
try: try:
with self.io.session(database=self.db) as session: with self.io.session(database=self.db) as session:
# Delete all nodes for this user and collection
node_result = session.run( node_result = session.run(
"MATCH (n:Node {user: $user, collection: $collection}) " "MATCH (n:Node {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n", "DETACH DELETE n",
user=user, collection=collection workspace=workspace, collection=collection
) )
nodes_deleted = node_result.consume().counters.nodes_deleted nodes_deleted = node_result.consume().counters.nodes_deleted
# Delete all literals for this user and collection
literal_result = session.run( literal_result = session.run(
"MATCH (n:Literal {user: $user, collection: $collection}) " "MATCH (n:Literal {workspace: $workspace, collection: $collection}) "
"DETACH DELETE n", "DETACH DELETE n",
user=user, collection=collection workspace=workspace, collection=collection
) )
literals_deleted = literal_result.consume().counters.nodes_deleted literals_deleted = literal_result.consume().counters.nodes_deleted
# Note: Relationships are automatically deleted with DETACH DELETE
# Delete collection metadata node
metadata_result = session.run( metadata_result = session.run(
"MATCH (c:CollectionMetadata {user: $user, collection: $collection}) " "MATCH (c:CollectionMetadata {workspace: $workspace, collection: $collection}) "
"DELETE c", "DELETE c",
user=user, collection=collection workspace=workspace, collection=collection
) )
metadata_deleted = metadata_result.consume().counters.nodes_deleted metadata_deleted = metadata_result.consume().counters.nodes_deleted
logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {user}/{collection}") logger.info(f"Deleted {nodes_deleted} nodes, {literals_deleted} literals, and {metadata_deleted} metadata nodes for {workspace}/{collection}")
except Exception as e: except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True) logger.error(f"Failed to delete collection {workspace}/{collection}: {e}", exc_info=True)
raise raise
def run(): def run():