From 7da94add4b8187127f2a8fb56ff11f388ea8e227 Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Tue, 9 Sep 2025 21:29:51 +0100 Subject: [PATCH] - Remove object embeddings, were currently broken and not used - Fixed Milvus collection names --- .../direct/milvus_doc_embeddings.py | 45 +++-- .../direct/milvus_graph_embeddings.py | 45 +++-- .../direct/milvus_object_embeddings.py | 157 ------------------ .../query/doc_embeddings/milvus/service.py | 7 +- .../query/graph_embeddings/milvus/service.py | 7 +- .../storage/doc_embeddings/milvus/write.py | 6 +- .../storage/graph_embeddings/milvus/write.py | 6 +- .../object_embeddings/milvus/__init__.py | 3 - .../object_embeddings/milvus/__main__.py | 7 - .../storage/object_embeddings/milvus/write.py | 61 ------- 10 files changed, 90 insertions(+), 254 deletions(-) delete mode 100644 trustgraph-flow/trustgraph/direct/milvus_object_embeddings.py delete mode 100644 trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__init__.py delete mode 100755 trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__main__.py delete mode 100755 trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py diff --git a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py index 6d203858..220c8d7b 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py @@ -2,9 +2,32 @@ from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType import time import logging +import re logger = logging.getLogger(__name__) +def make_safe_collection_name(user, collection, dimension, prefix): + """ + Create a safe Milvus collection name from user/collection parameters. + Milvus only allows letters, numbers, and underscores. + """ + def sanitize(s): + # Replace non-alphanumeric characters (except underscore) with underscore + # Then collapse multiple underscores into single underscore + safe = re.sub(r'[^a-zA-Z0-9_]', '_', s) + safe = re.sub(r'_+', '_', safe) + # Remove leading/trailing underscores + safe = safe.strip('_') + # Ensure it's not empty + if not safe: + safe = 'default' + return safe + + safe_user = sanitize(user) + safe_collection = sanitize(collection) + + return f"{prefix}_{safe_user}_{safe_collection}_{dimension}" + class DocVectors: def __init__(self, uri="http://localhost:19530", prefix='doc'): @@ -26,9 +49,9 @@ class DocVectors: self.next_reload = time.time() + self.reload_time logger.debug(f"Reload at {self.next_reload}") - def init_collection(self, dimension): + def init_collection(self, dimension, user, collection): - collection_name = self.prefix + "_" + str(dimension) + collection_name = make_safe_collection_name(user, collection, dimension, self.prefix) pkey_field = FieldSchema( name="id", @@ -75,14 +98,14 @@ class DocVectors: index_params=index_params ) - self.collections[dimension] = collection_name + self.collections[(dimension, user, collection)] = collection_name - def insert(self, embeds, doc): + def insert(self, embeds, doc, user, collection): dim = len(embeds) - if dim not in self.collections: - self.init_collection(dim) + if (dim, user, collection) not in self.collections: + self.init_collection(dim, user, collection) data = [ { @@ -92,18 +115,18 @@ class DocVectors: ] self.client.insert( - collection_name=self.collections[dim], + collection_name=self.collections[(dim, user, collection)], data=data ) - def search(self, embeds, fields=["doc"], limit=10): + def search(self, embeds, user, collection, fields=["doc"], limit=10): dim = len(embeds) - if dim not in self.collections: - self.init_collection(dim) + if (dim, user, collection) not in self.collections: + self.init_collection(dim, user, collection) - coll = self.collections[dim] + coll = self.collections[(dim, user, collection)] search_params = { "metric_type": "COSINE", diff --git a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py index 99cfb0b4..b179c7de 100644 --- a/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py +++ b/trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py @@ -2,9 +2,32 @@ from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType import time import logging +import re logger = logging.getLogger(__name__) +def make_safe_collection_name(user, collection, dimension, prefix): + """ + Create a safe Milvus collection name from user/collection parameters. + Milvus only allows letters, numbers, and underscores. + """ + def sanitize(s): + # Replace non-alphanumeric characters (except underscore) with underscore + # Then collapse multiple underscores into single underscore + safe = re.sub(r'[^a-zA-Z0-9_]', '_', s) + safe = re.sub(r'_+', '_', safe) + # Remove leading/trailing underscores + safe = safe.strip('_') + # Ensure it's not empty + if not safe: + safe = 'default' + return safe + + safe_user = sanitize(user) + safe_collection = sanitize(collection) + + return f"{prefix}_{safe_user}_{safe_collection}_{dimension}" + class EntityVectors: def __init__(self, uri="http://localhost:19530", prefix='entity'): @@ -26,9 +49,9 @@ class EntityVectors: self.next_reload = time.time() + self.reload_time logger.debug(f"Reload at {self.next_reload}") - def init_collection(self, dimension): + def init_collection(self, dimension, user, collection): - collection_name = self.prefix + "_" + str(dimension) + collection_name = make_safe_collection_name(user, collection, dimension, self.prefix) pkey_field = FieldSchema( name="id", @@ -75,14 +98,14 @@ class EntityVectors: index_params=index_params ) - self.collections[dimension] = collection_name + self.collections[(dimension, user, collection)] = collection_name - def insert(self, embeds, entity): + def insert(self, embeds, entity, user, collection): dim = len(embeds) - if dim not in self.collections: - self.init_collection(dim) + if (dim, user, collection) not in self.collections: + self.init_collection(dim, user, collection) data = [ { @@ -92,18 +115,18 @@ class EntityVectors: ] self.client.insert( - collection_name=self.collections[dim], + collection_name=self.collections[(dim, user, collection)], data=data ) - def search(self, embeds, fields=["entity"], limit=10): + def search(self, embeds, user, collection, fields=["entity"], limit=10): dim = len(embeds) - if dim not in self.collections: - self.init_collection(dim) + if (dim, user, collection) not in self.collections: + self.init_collection(dim, user, collection) - coll = self.collections[dim] + coll = self.collections[(dim, user, collection)] search_params = { "metric_type": "COSINE", diff --git a/trustgraph-flow/trustgraph/direct/milvus_object_embeddings.py b/trustgraph-flow/trustgraph/direct/milvus_object_embeddings.py deleted file mode 100644 index 290f5155..00000000 --- a/trustgraph-flow/trustgraph/direct/milvus_object_embeddings.py +++ /dev/null @@ -1,157 +0,0 @@ - -from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType -import time -import logging - -logger = logging.getLogger(__name__) - -class ObjectVectors: - - def __init__(self, uri="http://localhost:19530", prefix='obj'): - - self.client = MilvusClient(uri=uri) - - # Strategy is to create collections per dimension. Probably only - # going to be using 1 anyway, but that means we don't need to - # hard-code the dimension anywhere, and no big deal if more than - # one are created. - self.collections = {} - - self.prefix = prefix - - # Time between reloads - self.reload_time = 90 - - # Next time to reload - this forces a reload at next window - self.next_reload = time.time() + self.reload_time - logger.debug(f"Reload at {self.next_reload}") - - def init_collection(self, dimension, name): - - collection_name = self.prefix + "_" + name + "_" + str(dimension) - - pkey_field = FieldSchema( - name="id", - dtype=DataType.INT64, - is_primary=True, - auto_id=True, - ) - - vec_field = FieldSchema( - name="vector", - dtype=DataType.FLOAT_VECTOR, - dim=dimension, - ) - - name_field = FieldSchema( - name="name", - dtype=DataType.VARCHAR, - max_length=65535, - ) - - key_name_field = FieldSchema( - name="key_name", - dtype=DataType.VARCHAR, - max_length=65535, - ) - - key_field = FieldSchema( - name="key", - dtype=DataType.VARCHAR, - max_length=65535, - ) - - schema = CollectionSchema( - fields = [ - pkey_field, vec_field, name_field, key_name_field, key_field - ], - description = "Object embedding schema", - ) - - self.client.create_collection( - collection_name=collection_name, - schema=schema, - metric_type="COSINE", - ) - - index_params = MilvusClient.prepare_index_params() - - index_params.add_index( - field_name="vector", - metric_type="COSINE", - index_type="IVF_SQ8", - index_name="vector_index", - params={ "nlist": 128 } - ) - - self.client.create_index( - collection_name=collection_name, - index_params=index_params - ) - - self.collections[(dimension, name)] = collection_name - - def insert(self, embeds, name, key_name, key): - - dim = len(embeds) - - if (dim, name) not in self.collections: - self.init_collection(dim, name) - - data = [ - { - "vector": embeds, - "name": name, - "key_name": key_name, - "key": key, - } - ] - - self.client.insert( - collection_name=self.collections[(dim, name)], - data=data - ) - - def search(self, embeds, name, fields=["key_name", "name"], limit=10): - - dim = len(embeds) - - if dim not in self.collections: - self.init_collection(dim, name) - - coll = self.collections[(dim, name)] - - search_params = { - "metric_type": "COSINE", - "params": { - "radius": 0.1, - "range_filter": 0.8 - } - } - - logger.debug("Loading...") - self.client.load_collection( - collection_name=coll, - ) - - logger.debug("Searching...") - - res = self.client.search( - collection_name=coll, - data=[embeds], - limit=limit, - output_fields=fields, - search_params=search_params, - )[0] - - - # If reload time has passed, unload collection - if time.time() > self.next_reload: - logger.debug(f"Unloading, reload at {self.next_reload}") - self.client.release_collection( - collection_name=coll, - ) - self.next_reload = time.time() + self.reload_time - - return res - diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index dab4a892..2915184c 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -43,7 +43,12 @@ class Processor(DocumentEmbeddingsQueryService): for vec in msg.vectors: - resp = self.vecstore.search(vec, limit=msg.limit) + resp = self.vecstore.search( + vec, + msg.user, + msg.collection, + limit=msg.limit + ) for r in resp: chunk = r["entity"]["doc"] diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index 750dd99b..cb9255c2 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -50,7 +50,12 @@ class Processor(GraphEmbeddingsQueryService): for vec in msg.vectors: - resp = self.vecstore.search(vec, limit=msg.limit * 2) + resp = self.vecstore.search( + vec, + msg.user, + msg.collection, + limit=msg.limit * 2 + ) for r in resp: ent = r["entity"]["entity"] diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index 05027d75..b1d401aa 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -33,7 +33,11 @@ class Processor(DocumentEmbeddingsStoreService): if chunk == "": continue for vec in emb.vectors: - self.vecstore.insert(vec, chunk) + self.vecstore.insert( + vec, chunk, + message.metadata.user, + message.metadata.collection + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index f140ab76..68e56c0f 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -29,7 +29,11 @@ class Processor(GraphEmbeddingsStoreService): if entity.entity.value != "" and entity.entity.value is not None: for vec in entity.vectors: - self.vecstore.insert(vec, entity.entity.value) + self.vecstore.insert( + vec, entity.entity.value, + message.metadata.user, + message.metadata.collection + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__init__.py b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__init__.py deleted file mode 100644 index d891d55f..00000000 --- a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ - -from . write import * - diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__main__.py b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__main__.py deleted file mode 100755 index c05d8c6d..00000000 --- a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/__main__.py +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env python3 - -from . write import run - -if __name__ == '__main__': - run() - diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py deleted file mode 100755 index d1ad139a..00000000 --- a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py +++ /dev/null @@ -1,61 +0,0 @@ - -""" -Accepts entity/vector pairs and writes them to a Milvus store. -""" - -from .... schema import ObjectEmbeddings -from .... schema import object_embeddings_store_queue -from .... log_level import LogLevel -from .... direct.milvus_object_embeddings import ObjectVectors -from .... base import Consumer - -module = "oe-write" - -default_input_queue = object_embeddings_store_queue -default_subscriber = module -default_store_uri = 'http://localhost:19530' - -class Processor(Consumer): - - def __init__(self, **params): - - input_queue = params.get("input_queue", default_input_queue) - subscriber = params.get("subscriber", default_subscriber) - store_uri = params.get("store_uri", default_store_uri) - - super(Processor, self).__init__( - **params | { - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": ObjectEmbeddings, - "store_uri": store_uri, - } - ) - - self.vecstore = ObjectVectors(store_uri) - - async def handle(self, msg): - - v = msg.value() - - if v.id != "" and v.id is not None: - for vec in v.vectors: - self.vecstore.insert(vec, v.name, v.key_name, v.id) - - @staticmethod - def add_args(parser): - - Consumer.add_args( - parser, default_input_queue, default_subscriber, - ) - - parser.add_argument( - '-t', '--store-uri', - default=default_store_uri, - help=f'Milvus store URI (default: {default_store_uri})' - ) - -def run(): - - Processor.launch(module, __doc__) -