from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
import logging
import re

logger = logging.getLogger(__name__)


def make_safe_collection_name(workspace, collection, prefix):
    """
    Create a safe Milvus collection name from workspace/collection parameters.

    Milvus only allows letters, numbers, and underscores.  The workspace and
    collection are sanitized independently and joined as
    ``{prefix}_{workspace}_{collection}`` (the prefix itself is used as-is).

    Note that distinct inputs can map to the same name (e.g. ``"a-b"`` and
    ``"a_b"`` both become ``"a_b"``), because every disallowed character is
    turned into an underscore.
    """
    def sanitize(s):
        # Replace non-alphanumeric characters (except underscore) with
        # underscore, then collapse runs of underscores into one.
        safe = re.sub(r'[^a-zA-Z0-9_]', '_', s)
        safe = re.sub(r'_+', '_', safe)
        # Remove leading/trailing underscores
        safe = safe.strip('_')
        # Ensure it's not empty
        return safe or 'default'

    safe_workspace = sanitize(workspace)
    safe_collection = sanitize(collection)
    return f"{prefix}_{safe_workspace}_{safe_collection}"


class DocVectors:
    """
    Store and search document-chunk embeddings in Milvus.

    One Milvus collection is used per (workspace, collection, dimension),
    named ``{prefix}_{workspace}_{collection}_{dimension}``.  Collections are
    created lazily on the first insert of a vector with a given dimension.
    """

    def __init__(self, uri="http://localhost:19530", prefix='doc'):

        self.client = MilvusClient(uri=uri)

        # Strategy is to create collections per dimension. Probably only
        # going to be using 1 anyway, but that means we don't need to
        # hard-code the dimension anywhere, and no big deal if more than
        # one are created.
        # Maps (dimension, workspace, collection) -> Milvus collection name.
        self.collections = {}

        self.prefix = prefix

        # Seconds between releases of a searched collection
        self.reload_time = 90

        # Next time to release (and therefore reload on the next search)
        self.next_reload = time.time() + self.reload_time
        logger.debug("Reload at %s", self.next_reload)

    def _dimension_collection_name(self, dimension, workspace, collection):
        """Return the Milvus collection name for one dimension variant."""
        base_name = make_safe_collection_name(workspace, collection, self.prefix)
        return f"{base_name}_{dimension}"

    def _existing_collection_names(self, workspace, collection):
        """
        Return the existing Milvus collections for workspace/collection.

        Only names of the exact form ``{base_name}_{dimension}`` match.  A
        plain prefix test would also match collections belonging to *other*
        workspace/collection pairs that merely share the prefix: collection
        ``"b"`` would match ``"doc_a_b_c_384"``, which belongs to ``"b_c"``.
        """
        base_name = make_safe_collection_name(workspace, collection, self.prefix)
        pattern = re.compile(re.escape(base_name) + r"_[0-9]+")
        return [
            name for name in self.client.list_collections()
            if pattern.fullmatch(name)
        ]

    def collection_exists(self, workspace, collection):
        """
        Check if any collection exists for this workspace/collection combination.

        Since collections are dimension-specific, this checks if ANY
        dimension variant exists.
        """
        return bool(self._existing_collection_names(workspace, collection))

    def create_collection(self, workspace, collection, dimension=384):
        """
        No-op for explicit collection creation.

        Collections are created lazily on first insert with actual dimension,
        so ``dimension`` is ignored.
        """
        logger.info(
            "Collection creation requested for %s/%s - will be created lazily on first insert",
            workspace, collection,
        )

    def init_collection(self, dimension, workspace, collection):
        """
        Ensure the collection for ``dimension`` exists and is indexed, then cache it.

        An existing Milvus collection (e.g. one created before a restart, when
        the in-memory cache is empty) is reused rather than re-created.
        """
        collection_name = self._dimension_collection_name(
            dimension, workspace, collection
        )

        if self.client.has_collection(collection_name):
            logger.info("Using existing Milvus collection %s", collection_name)
        else:
            pkey_field = FieldSchema(
                name="id",
                dtype=DataType.INT64,
                is_primary=True,
                auto_id=True,
            )

            vec_field = FieldSchema(
                name="vector",
                dtype=DataType.FLOAT_VECTOR,
                dim=dimension,
            )

            chunk_id_field = FieldSchema(
                name="chunk_id",
                dtype=DataType.VARCHAR,
                max_length=65535,
            )

            schema = CollectionSchema(
                fields=[pkey_field, vec_field, chunk_id_field],
                description="Document embedding schema",
            )

            self.client.create_collection(
                collection_name=collection_name,
                schema=schema,
                metric_type="COSINE",
            )
            logger.info(
                "Created Milvus collection %s with dimension %s",
                collection_name, dimension,
            )

        # Create the index unless it is already there; this also repairs a
        # collection whose creation was interrupted before indexing.
        existing_indexes = self.client.list_indexes(collection_name=collection_name)
        if "vector_index" not in existing_indexes:
            index_params = MilvusClient.prepare_index_params()
            index_params.add_index(
                field_name="vector",
                metric_type="COSINE",
                index_type="IVF_SQ8",
                index_name="vector_index",
                params={"nlist": 128},
            )
            self.client.create_index(
                collection_name=collection_name,
                index_params=index_params,
            )

        self.collections[(dimension, workspace, collection)] = collection_name

    def insert(self, embeds, chunk_id, workspace, collection):
        """Insert one embedding vector tagged with ``chunk_id``."""
        dim = len(embeds)
        key = (dim, workspace, collection)

        if key not in self.collections:
            self.init_collection(dim, workspace, collection)

        data = [
            {
                "vector": embeds,
                "chunk_id": chunk_id,
            }
        ]

        self.client.insert(
            collection_name=self.collections[key],
            data=data,
        )

    def search(self, embeds, workspace, collection, fields=("chunk_id",), limit=10):
        """
        Return up to ``limit`` nearest hits for ``embeds`` (cosine metric).

        Returns an empty list if no collection exists for this dimension.
        ``fields`` names the output fields to return with each hit.
        """
        dim = len(embeds)
        key = (dim, workspace, collection)

        # Check if collection exists - return empty if not
        if key not in self.collections:
            collection_name = self._dimension_collection_name(dim, workspace, collection)
            if not self.client.has_collection(collection_name):
                logger.info(
                    "Collection %s does not exist, returning empty results",
                    collection_name,
                )
                return []
            # Collection exists but not in cache, add it
            self.collections[key] = collection_name

        coll = self.collections[key]

        logger.debug("Loading...")
        self.client.load_collection(collection_name=coll)

        logger.debug("Searching...")
        res = self.client.search(
            collection_name=coll,
            anns_field="vector",
            data=[embeds],
            limit=limit,
            output_fields=None if fields is None else list(fields),
            search_params={"metric_type": "COSINE"},
        )[0]

        # If reload time has passed, unload collection.
        # NOTE(review): the timer is shared by all collections; only the
        # collection searched when it expires is released.
        if time.time() > self.next_reload:
            logger.debug("Unloading, reload at %s", self.next_reload)
            self.client.release_collection(collection_name=coll)
            self.next_reload = time.time() + self.reload_time

        return res

    def delete_collection(self, workspace, collection):
        """
        Delete all dimension variants of the collection for the given workspace/collection.

        Since collections are created with dimension suffixes, every existing
        ``{base_name}_<dimension>`` collection is dropped.  Matching cache
        entries are evicted so later inserts re-create the collection.
        """
        matching_collections = self._existing_collection_names(workspace, collection)

        if not matching_collections:
            base_name = make_safe_collection_name(workspace, collection, self.prefix)
            logger.info("No collections found matching %s_<dimension>", base_name)
        else:
            for collection_name in matching_collections:
                self.client.drop_collection(collection_name)
                logger.info("Deleted Milvus collection: %s", collection_name)
            logger.info(
                "Deleted %d collection(s) for %s/%s",
                len(matching_collections), workspace, collection,
            )

        # Remove from our local cache: entries for this workspace/collection,
        # plus any entry pointing at a collection we just dropped (possible
        # when different inputs sanitize to the same name).
        dropped = set(matching_collections)
        stale_keys = [
            key for key, name in self.collections.items()
            if name in dropped or (key[1], key[2]) == (workspace, collection)
        ]
        for key in stale_keys:
            del self.collections[key]