diff --git a/tests/unit/test_query/test_doc_embeddings_milvus_query.py b/tests/unit/test_query/test_doc_embeddings_milvus_query.py index 10ea54d2..622529e5 100644 --- a/tests/unit/test_query/test_doc_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_doc_embeddings_milvus_query.py @@ -85,8 +85,10 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) - # Verify search was called with correct parameters - processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5) + # Verify search was called with correct parameters including user/collection + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5 + ) # Verify results are document chunks assert len(result) == 3 @@ -116,10 +118,10 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) - # Verify search was called twice with correct parameters + # Verify search was called twice with correct parameters including user/collection expected_calls = [ - (([0.1, 0.2, 0.3],), {"limit": 3}), - (([0.4, 0.5, 0.6],), {"limit": 3}), + (([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 3}), + (([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 3}), ] assert processor.vecstore.search.call_count == 2 for i, (expected_args, expected_kwargs) in enumerate(expected_calls): @@ -155,7 +157,9 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) # Verify search was called with the specified limit - processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=2) + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=2 + ) # Verify all results are returned (Milvus handles limit internally) assert len(result) == 4 @@ -194,7 +198,9 @@ class TestMilvusDocEmbeddingsQueryProcessor: result = await processor.query_document_embeddings(query) # Verify search was called - processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=5) + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=5 + ) # Verify empty results assert len(result) == 0 diff --git a/tests/unit/test_query/test_graph_embeddings_milvus_query.py b/tests/unit/test_query/test_graph_embeddings_milvus_query.py index 5fbb74d5..ebacfaaf 100644 --- a/tests/unit/test_query/test_graph_embeddings_milvus_query.py +++ b/tests/unit/test_query/test_graph_embeddings_milvus_query.py @@ -133,8 +133,10 @@ class TestMilvusGraphEmbeddingsQueryProcessor: result = await processor.query_graph_embeddings(query) - # Verify search was called with correct parameters - processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10) + # Verify search was called with correct parameters including user/collection + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10 + ) # Verify results are converted to Value objects assert len(result) == 3 @@ -171,10 +173,10 @@ class TestMilvusGraphEmbeddingsQueryProcessor: result = await processor.query_graph_embeddings(query) - # Verify search was called twice with correct parameters + # Verify search was called twice with correct parameters including user/collection expected_calls = [ - (([0.1, 0.2, 0.3],), {"limit": 6}), - (([0.4, 0.5, 0.6],), {"limit": 6}), + (([0.1, 0.2, 0.3], 'test_user', 'test_collection'), {"limit": 6}), + (([0.4, 0.5, 0.6], 'test_user', 'test_collection'), {"limit": 6}), ] assert processor.vecstore.search.call_count == 2 for i, (expected_args, expected_kwargs) in enumerate(expected_calls): @@ -211,7 +213,9 @@ class TestMilvusGraphEmbeddingsQueryProcessor: result = await processor.query_graph_embeddings(query) # Verify search was called with 2*limit for better deduplication - processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4) + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4 + ) # Verify results are limited to the requested limit assert len(result) == 2 @@ -269,7 +273,9 @@ class TestMilvusGraphEmbeddingsQueryProcessor: result = await processor.query_graph_embeddings(query) # Verify only first vector was searched (limit reached) - processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=4) + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=4 + ) # Verify results are limited assert len(result) == 2 @@ -308,7 +314,9 @@ class TestMilvusGraphEmbeddingsQueryProcessor: result = await processor.query_graph_embeddings(query) # Verify search was called - processor.vecstore.search.assert_called_once_with([0.1, 0.2, 0.3], limit=10) + processor.vecstore.search.assert_called_once_with( + [0.1, 0.2, 0.3], 'test_user', 'test_collection', limit=10 + ) # Verify empty results assert len(result) == 0 diff --git a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py index 5e6bcfb9..d957d711 100644 --- a/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_doc_embeddings_milvus_storage.py @@ -91,37 +91,41 @@ class TestMilvusDocEmbeddingsStorageProcessor: await processor.store_document_embeddings(message) - # Verify insert was called for each vector + # Verify insert was called for each vector with user/collection parameters expected_calls = [ - ([0.1, 0.2, 0.3], "Test document content"), - ([0.4, 0.5, 0.6], "Test document content"), + ([0.1, 0.2, 0.3], "Test document content", 'test_user', 'test_collection'), + ([0.4, 0.5, 0.6], "Test document content", 'test_user', 'test_collection'), ] assert processor.vecstore.insert.call_count == 2 - for i, (expected_vec, expected_doc) in enumerate(expected_calls): + for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec assert actual_call[0][1] == expected_doc + assert actual_call[0][2] == expected_user + assert actual_call[0][3] == expected_collection @pytest.mark.asyncio async def test_store_document_embeddings_multiple_chunks(self, processor, mock_message): """Test storing document embeddings for multiple chunks""" await processor.store_document_embeddings(mock_message) - # Verify insert was called for each vector of each chunk + # Verify insert was called for each vector of each chunk with user/collection parameters expected_calls = [ # Chunk 1 vectors - ([0.1, 0.2, 0.3], "This is the first document chunk"), - ([0.4, 0.5, 0.6], "This is the first document chunk"), + ([0.1, 0.2, 0.3], "This is the first document chunk", 'test_user', 'test_collection'), + ([0.4, 0.5, 0.6], "This is the first document chunk", 'test_user', 'test_collection'), # Chunk 2 vectors - ([0.7, 0.8, 0.9], "This is the second document chunk"), + ([0.7, 0.8, 0.9], "This is the second document chunk", 'test_user', 'test_collection'), ] assert processor.vecstore.insert.call_count == 3 - for i, (expected_vec, expected_doc) in enumerate(expected_calls): + for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec assert actual_call[0][1] == expected_doc + assert actual_call[0][2] == expected_user + assert actual_call[0][3] == expected_collection @pytest.mark.asyncio async def test_store_document_embeddings_empty_chunk(self, processor): @@ -185,9 +189,9 @@ class TestMilvusDocEmbeddingsStorageProcessor: await processor.store_document_embeddings(message) - # Verify only valid chunk was inserted + # Verify only valid chunk was inserted with user/collection parameters processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], "Valid document content" + [0.1, 0.2, 0.3], "Valid document content", 'test_user', 'test_collection' ) @pytest.mark.asyncio @@ -243,18 +247,20 @@ class TestMilvusDocEmbeddingsStorageProcessor: await processor.store_document_embeddings(message) - # Verify all vectors were inserted regardless of dimension + # Verify all vectors were inserted regardless of dimension with user/collection parameters expected_calls = [ - ([0.1, 0.2], "Document with mixed dimensions"), - ([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions"), - ([0.7, 0.8, 0.9], "Document with mixed dimensions"), + ([0.1, 0.2], "Document with mixed dimensions", 'test_user', 'test_collection'), + ([0.3, 0.4, 0.5, 0.6], "Document with mixed dimensions", 'test_user', 'test_collection'), + ([0.7, 0.8, 0.9], "Document with mixed dimensions", 'test_user', 'test_collection'), ] assert processor.vecstore.insert.call_count == 3 - for i, (expected_vec, expected_doc) in enumerate(expected_calls): + for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec assert actual_call[0][1] == expected_doc + assert actual_call[0][2] == expected_user + assert actual_call[0][3] == expected_collection @pytest.mark.asyncio async def test_store_document_embeddings_unicode_content(self, processor): @@ -272,9 +278,9 @@ class TestMilvusDocEmbeddingsStorageProcessor: await processor.store_document_embeddings(message) - # Verify Unicode content was properly decoded and inserted + # Verify Unicode content was properly decoded and inserted with user/collection parameters processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀" + [0.1, 0.2, 0.3], "Document with Unicode: éñ中文🚀", 'test_user', 'test_collection' ) @pytest.mark.asyncio @@ -295,9 +301,9 @@ class TestMilvusDocEmbeddingsStorageProcessor: await processor.store_document_embeddings(message) - # Verify large content was inserted + # Verify large content was inserted with user/collection parameters processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], large_content + [0.1, 0.2, 0.3], large_content, 'test_user', 'test_collection' ) @pytest.mark.asyncio @@ -316,9 +322,103 @@ class TestMilvusDocEmbeddingsStorageProcessor: await processor.store_document_embeddings(message) - # Verify whitespace content was inserted (not filtered out) + # Verify whitespace content was inserted (not filtered out) with user/collection parameters processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], " \n\t " + [0.1, 0.2, 0.3], " \n\t ", 'test_user', 'test_collection' + ) + + @pytest.mark.asyncio + async def test_store_document_embeddings_different_user_collection_combinations(self, processor): + """Test storing document embeddings with different user/collection combinations""" + test_cases = [ + ('user1', 'collection1'), + ('user2', 'collection2'), + ('admin', 'production'), + ('test@domain.com', 'test-collection.v1'), + ] + + for user, collection in test_cases: + processor.vecstore.reset_mock() # Reset mock for each test case + + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = user + message.metadata.collection = collection + + chunk = ChunkEmbeddings( + chunk=b"Test content", + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + await processor.store_document_embeddings(message) + + # Verify insert was called with the correct user/collection + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3], "Test content", user, collection + ) + + @pytest.mark.asyncio + async def test_store_document_embeddings_user_collection_parameter_isolation(self, processor): + """Test that different user/collection combinations are properly isolated""" + # Store embeddings for user1/collection1 + message1 = MagicMock() + message1.metadata = MagicMock() + message1.metadata.user = 'user1' + message1.metadata.collection = 'collection1' + chunk1 = ChunkEmbeddings( + chunk=b"User1 content", + vectors=[[0.1, 0.2, 0.3]] + ) + message1.chunks = [chunk1] + + # Store embeddings for user2/collection2 + message2 = MagicMock() + message2.metadata = MagicMock() + message2.metadata.user = 'user2' + message2.metadata.collection = 'collection2' + chunk2 = ChunkEmbeddings( + chunk=b"User2 content", + vectors=[[0.4, 0.5, 0.6]] + ) + message2.chunks = [chunk2] + + await processor.store_document_embeddings(message1) + await processor.store_document_embeddings(message2) + + # Verify both calls were made with correct parameters + expected_calls = [ + ([0.1, 0.2, 0.3], "User1 content", 'user1', 'collection1'), + ([0.4, 0.5, 0.6], "User2 content", 'user2', 'collection2'), + ] + + assert processor.vecstore.insert.call_count == 2 + for i, (expected_vec, expected_doc, expected_user, expected_collection) in enumerate(expected_calls): + actual_call = processor.vecstore.insert.call_args_list[i] + assert actual_call[0][0] == expected_vec + assert actual_call[0][1] == expected_doc + assert actual_call[0][2] == expected_user + assert actual_call[0][3] == expected_collection + + @pytest.mark.asyncio + async def test_store_document_embeddings_special_character_user_collection(self, processor): + """Test storing document embeddings with special characters in user/collection names""" + message = MagicMock() + message.metadata = MagicMock() + message.metadata.user = 'user@domain.com' # Email-like user + message.metadata.collection = 'test-collection.v1' # Collection with special chars + + chunk = ChunkEmbeddings( + chunk=b"Special chars test", + vectors=[[0.1, 0.2, 0.3]] + ) + message.chunks = [chunk] + + await processor.store_document_embeddings(message) + + # Verify the exact user/collection strings are passed (sanitization happens in DocVectors) + processor.vecstore.insert.assert_called_once_with( + [0.1, 0.2, 0.3], "Special chars test", 'user@domain.com', 'test-collection.v1' ) def test_add_args_method(self): diff --git a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py index ae300574..a22173ab 100644 --- a/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py +++ b/tests/unit/test_storage/test_graph_embeddings_milvus_storage.py @@ -91,37 +91,41 @@ class TestMilvusGraphEmbeddingsStorageProcessor: await processor.store_graph_embeddings(message) - # Verify insert was called for each vector + # Verify insert was called for each vector with user/collection parameters expected_calls = [ - ([0.1, 0.2, 0.3], 'http://example.com/entity'), - ([0.4, 0.5, 0.6], 'http://example.com/entity'), + ([0.1, 0.2, 0.3], 'http://example.com/entity', 'test_user', 'test_collection'), + ([0.4, 0.5, 0.6], 'http://example.com/entity', 'test_user', 'test_collection'), ] assert processor.vecstore.insert.call_count == 2 - for i, (expected_vec, expected_entity) in enumerate(expected_calls): + for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec assert actual_call[0][1] == expected_entity + assert actual_call[0][2] == expected_user + assert actual_call[0][3] == expected_collection @pytest.mark.asyncio async def test_store_graph_embeddings_multiple_entities(self, processor, mock_message): """Test storing graph embeddings for multiple entities""" await processor.store_graph_embeddings(mock_message) - # Verify insert was called for each vector of each entity + # Verify insert was called for each vector of each entity with user/collection parameters expected_calls = [ # Entity 1 vectors - ([0.1, 0.2, 0.3], 'http://example.com/entity1'), - ([0.4, 0.5, 0.6], 'http://example.com/entity1'), + ([0.1, 0.2, 0.3], 'http://example.com/entity1', 'test_user', 'test_collection'), + ([0.4, 0.5, 0.6], 'http://example.com/entity1', 'test_user', 'test_collection'), # Entity 2 vectors - ([0.7, 0.8, 0.9], 'literal entity'), + ([0.7, 0.8, 0.9], 'literal entity', 'test_user', 'test_collection'), ] assert processor.vecstore.insert.call_count == 3 - for i, (expected_vec, expected_entity) in enumerate(expected_calls): + for i, (expected_vec, expected_entity, expected_user, expected_collection) in enumerate(expected_calls): actual_call = processor.vecstore.insert.call_args_list[i] assert actual_call[0][0] == expected_vec assert actual_call[0][1] == expected_entity + assert actual_call[0][2] == expected_user + assert actual_call[0][3] == expected_collection @pytest.mark.asyncio async def test_store_graph_embeddings_empty_entity_value(self, processor): @@ -185,9 +189,9 @@ class TestMilvusGraphEmbeddingsStorageProcessor: await processor.store_graph_embeddings(message) - # Verify only valid entity was inserted + # Verify only valid entity was inserted with user/collection parameters processor.vecstore.insert.assert_called_once_with( - [0.1, 0.2, 0.3], 'http://example.com/valid' + [0.1, 0.2, 0.3], 'http://example.com/valid', 'test_user', 'test_collection' ) @pytest.mark.asyncio diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 0e9b5b0d..a0929eee 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -90,7 +90,6 @@ metering = "trustgraph.metering:run" nlp-query = "trustgraph.retrieval.nlp_query:run" objects-write-cassandra = "trustgraph.storage.objects.cassandra:run" objects-query-cassandra = "trustgraph.query.objects.cassandra:run" -oe-write-milvus = "trustgraph.storage.object_embeddings.milvus:run" pdf-decoder = "trustgraph.decoding.pdf:run" pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run" prompt-template = "trustgraph.prompt.template:run" diff --git a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py index 6d203858..220c8d7b 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py @@ -2,9 +2,32 @@ from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType import time import logging +import re logger = logging.getLogger(__name__) +def make_safe_collection_name(user, collection, dimension, prefix): + """ + Create a safe Milvus collection name from user/collection parameters. + Milvus only allows letters, numbers, and underscores. + """ + def sanitize(s): + # Replace non-alphanumeric characters (except underscore) with underscore + # Then collapse multiple underscores into single underscore + safe = re.sub(r'[^a-zA-Z0-9_]', '_', s) + safe = re.sub(r'_+', '_', safe) + # Remove leading/trailing underscores + safe = safe.strip('_') + # Ensure it's not empty + if not safe: + safe = 'default' + return safe + + safe_user = sanitize(user) + safe_collection = sanitize(collection) + + return f"{prefix}_{safe_user}_{safe_collection}_{dimension}" + class DocVectors: def __init__(self, uri="http://localhost:19530", prefix='doc'): @@ -26,9 +49,9 @@ class DocVectors: self.next_reload = time.time() + self.reload_time logger.debug(f"Reload at {self.next_reload}") - def init_collection(self, dimension): + def init_collection(self, dimension, user, collection): - collection_name = self.prefix + "_" + str(dimension) + collection_name = make_safe_collection_name(user, collection, dimension, self.prefix) pkey_field = FieldSchema( name="id", @@ -75,14 +98,14 @@ class DocVectors: index_params=index_params ) - self.collections[dimension] = collection_name + self.collections[(dimension, user, collection)] = collection_name - def insert(self, embeds, doc): + def insert(self, embeds, doc, user, collection): dim = len(embeds) - if dim not in self.collections: - self.init_collection(dim) + if (dim, user, collection) not in self.collections: + self.init_collection(dim, user, collection) data = [ { @@ -92,18 +115,18 @@ class DocVectors: ] self.client.insert( - collection_name=self.collections[dim], + collection_name=self.collections[(dim, user, collection)], data=data ) - def search(self, embeds, fields=["doc"], limit=10): + def search(self, embeds, user, collection, fields=["doc"], limit=10): dim = len(embeds) - if dim not in self.collections: - self.init_collection(dim) + if (dim, user, collection) not in self.collections: + self.init_collection(dim, user, collection) - coll = self.collections[dim] + coll = self.collections[(dim, user, collection)] search_params = { "metric_type": "COSINE", diff --git a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py index 99cfb0b4..b179c7de 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py @@ -2,9 +2,32 @@ from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType import time import logging +import re logger = logging.getLogger(__name__) +def make_safe_collection_name(user, collection, dimension, prefix): + """ + Create a safe Milvus collection name from user/collection parameters. + Milvus only allows letters, numbers, and underscores. + """ + def sanitize(s): + # Replace non-alphanumeric characters (except underscore) with underscore + # Then collapse multiple underscores into single underscore + safe = re.sub(r'[^a-zA-Z0-9_]', '_', s) + safe = re.sub(r'_+', '_', safe) + # Remove leading/trailing underscores + safe = safe.strip('_') + # Ensure it's not empty + if not safe: + safe = 'default' + return safe + + safe_user = sanitize(user) + safe_collection = sanitize(collection) + + return f"{prefix}_{safe_user}_{safe_collection}_{dimension}" + class EntityVectors: def __init__(self, uri="http://localhost:19530", prefix='entity'): @@ -26,9 +49,9 @@ class EntityVectors: self.next_reload = time.time() + self.reload_time logger.debug(f"Reload at {self.next_reload}") - def init_collection(self, dimension): + def init_collection(self, dimension, user, collection): - collection_name = self.prefix + "_" + str(dimension) + collection_name = make_safe_collection_name(user, collection, dimension, self.prefix) pkey_field = FieldSchema( name="id", @@ -75,14 +98,14 @@ class EntityVectors: index_params=index_params ) - self.collections[dimension] = collection_name + self.collections[(dimension, user, collection)] = collection_name - def insert(self, embeds, entity): + def insert(self, embeds, entity, user, collection): dim = len(embeds) - if dim not in self.collections: - self.init_collection(dim) + if (dim, user, collection) not in self.collections: + self.init_collection(dim, user, collection) data = [ { @@ -92,18 +115,18 @@ class EntityVectors: ] self.client.insert( - collection_name=self.collections[dim], + collection_name=self.collections[(dim, user, collection)], data=data ) - def search(self, embeds, fields=["entity"], limit=10): + def search(self, embeds, user, collection, fields=["entity"], limit=10): dim = len(embeds) - if dim not in self.collections: - self.init_collection(dim) + if (dim, user, collection) not in self.collections: + self.init_collection(dim, user, collection) - coll = self.collections[dim] + coll = self.collections[(dim, user, collection)] search_params = { "metric_type": "COSINE", diff --git a/trustgraph-flow/trustgraph/direct/milvus_object_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_object_embeddings.py deleted file mode 100644 index 290f5155..00000000 --- a/trustgraph-flow/trustgraph/direct/milvus_object_embeddings.py +++ /dev/null @@ -1,157 +0,0 @@ - -from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType -import time -import logging - -logger = logging.getLogger(__name__) - -class ObjectVectors: - - def __init__(self, uri="http://localhost:19530", prefix='obj'): - - self.client = MilvusClient(uri=uri) - - # Strategy is to create collections per dimension. Probably only - # going to be using 1 anyway, but that means we don't need to - # hard-code the dimension anywhere, and no big deal if more than - # one are created. - self.collections = {} - - self.prefix = prefix - - # Time between reloads - self.reload_time = 90 - - # Next time to reload - this forces a reload at next window - self.next_reload = time.time() + self.reload_time - logger.debug(f"Reload at {self.next_reload}") - - def init_collection(self, dimension, name): - - collection_name = self.prefix + "_" + name + "_" + str(dimension) - - pkey_field = FieldSchema( - name="id", - dtype=DataType.INT64, - is_primary=True, - auto_id=True, - ) - - vec_field = FieldSchema( - name="vector", - dtype=DataType.FLOAT_VECTOR, - dim=dimension, - ) - - name_field = FieldSchema( - name="name", - dtype=DataType.VARCHAR, - max_length=65535, - ) - - key_name_field = FieldSchema( - name="key_name", - dtype=DataType.VARCHAR, - max_length=65535, - ) - - key_field = FieldSchema( - name="key", - dtype=DataType.VARCHAR, - max_length=65535, - ) - - schema = CollectionSchema( - fields = [ - pkey_field, vec_field, name_field, key_name_field, key_field - ], - description = "Object embedding schema", - ) - - self.client.create_collection( - collection_name=collection_name, - schema=schema, - metric_type="COSINE", - ) - - index_params = MilvusClient.prepare_index_params() - - index_params.add_index( - field_name="vector", - metric_type="COSINE", - index_type="IVF_SQ8", - index_name="vector_index", - params={ "nlist": 128 } - ) - - self.client.create_index( - collection_name=collection_name, - index_params=index_params - ) - - self.collections[(dimension, name)] = collection_name - - def insert(self, embeds, name, key_name, key): - - dim = len(embeds) - - if (dim, name) not in self.collections: - self.init_collection(dim, name) - - data = [ - { - "vector": embeds, - "name": name, - "key_name": key_name, - "key": key, - } - ] - - self.client.insert( - collection_name=self.collections[(dim, name)], - data=data - ) - - def search(self, embeds, name, fields=["key_name", "name"], limit=10): - - dim = len(embeds) - - if dim not in self.collections: - self.init_collection(dim, name) - - coll = self.collections[(dim, name)] - - search_params = { - "metric_type": "COSINE", - "params": { - "radius": 0.1, - "range_filter": 0.8 - } - } - - logger.debug("Loading...") - self.client.load_collection( - collection_name=coll, - ) - - logger.debug("Searching...") - - res = self.client.search( - collection_name=coll, - data=[embeds], - limit=limit, - output_fields=fields, - search_params=search_params, - )[0] - - - # If reload time has passed, unload collection - if time.time() > self.next_reload: - logger.debug(f"Unloading, reload at {self.next_reload}") - self.client.release_collection( - collection_name=coll, - ) - self.next_reload = time.time() + self.reload_time - - return res - diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index dab4a892..2915184c 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -43,7 +43,12 @@ class Processor(DocumentEmbeddingsQueryService): for vec in msg.vectors: - resp = self.vecstore.search(vec, limit=msg.limit) + resp = self.vecstore.search( + vec, + msg.user, + msg.collection, + limit=msg.limit + ) for r in resp: chunk = r["entity"]["doc"] diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index 750dd99b..cb9255c2 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -50,7 +50,12 @@ class Processor(GraphEmbeddingsQueryService): for vec in msg.vectors: - resp = self.vecstore.search(vec, limit=msg.limit * 2) + resp = self.vecstore.search( + vec, + msg.user, + msg.collection, + limit=msg.limit * 2 + ) for r in resp: ent = r["entity"]["entity"] diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index 05027d75..b1d401aa 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -33,7 +33,11 @@ class Processor(DocumentEmbeddingsStoreService): if chunk == "": continue for vec in emb.vectors: - self.vecstore.insert(vec, chunk) + self.vecstore.insert( + vec, chunk, + message.metadata.user, + message.metadata.collection + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index f140ab76..68e56c0f 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -29,7 +29,11 @@ class Processor(GraphEmbeddingsStoreService): if entity.entity.value != "" and entity.entity.value is not None: for vec in entity.vectors: - self.vecstore.insert(vec, entity.entity.value) + self.vecstore.insert( + vec, entity.entity.value, + message.metadata.user, + message.metadata.collection + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__init__.py b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__init__.py deleted file mode 100644 index d891d55f..00000000 --- a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ - -from . write import * - diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__main__.py b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__main__.py deleted file mode 100755 index c05d8c6d..00000000 --- a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__main__.py +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env python3 - -from . write import run - -if __name__ == '__main__': - run() - diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py deleted file mode 100755 index d1ad139a..00000000 --- a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py +++ /dev/null @@ -1,61 +0,0 @@ - -""" -Accepts entity/vector pairs and writes them to a Milvus store. -""" - -from .... schema import ObjectEmbeddings -from .... schema import object_embeddings_store_queue -from .... log_level import LogLevel -from .... direct.milvus_object_embeddings import ObjectVectors -from .... base import Consumer - -module = "oe-write" - -default_input_queue = object_embeddings_store_queue -default_subscriber = module -default_store_uri = 'http://localhost:19530' - -class Processor(Consumer): - - def __init__(self, **params): - - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) - store_uri = params.get("store_uri", default_store_uri) - - super(Processor, self).__init__( - **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": ObjectEmbeddings, - "store_uri": store_uri, - } - ) - - self.vecstore = ObjectVectors(store_uri) - - async def handle(self, msg): - - v = msg.value() - - if v.id != "" and v.id is not None: - for vec in v.vectors: - self.vecstore.insert(vec, v.name, v.key_name, v.id) - - @staticmethod - def add_args(parser): - - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) - - parser.add_argument( - '-t', '--store-uri', - default=default_store_uri, - help=f'Milvus store URI (default: {default_store_uri})' - ) - -def run(): - - Processor.launch(module, __doc__) -