mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-02 03:42:36 +02:00
Feature/fix milvus (#507)
- Remove object embeddings, were currently broken and not used - Fixed Milvus collection names * Updating tests * Remove unused entrypoint
This commit is contained in:
parent
6ac8a7c2d9
commit
314ce76b81
15 changed files with 256 additions and 303 deletions
|
|
@ -2,9 +2,32 @@
|
|||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def make_safe_collection_name(user, collection, dimension, prefix):
|
||||
"""
|
||||
Create a safe Milvus collection name from user/collection parameters.
|
||||
Milvus only allows letters, numbers, and underscores.
|
||||
"""
|
||||
def sanitize(s):
|
||||
# Replace non-alphanumeric characters (except underscore) with underscore
|
||||
# Then collapse multiple underscores into single underscore
|
||||
safe = re.sub(r'[^a-zA-Z0-9_]', '_', s)
|
||||
safe = re.sub(r'_+', '_', safe)
|
||||
# Remove leading/trailing underscores
|
||||
safe = safe.strip('_')
|
||||
# Ensure it's not empty
|
||||
if not safe:
|
||||
safe = 'default'
|
||||
return safe
|
||||
|
||||
safe_user = sanitize(user)
|
||||
safe_collection = sanitize(collection)
|
||||
|
||||
return f"{prefix}_{safe_user}_{safe_collection}_{dimension}"
|
||||
|
||||
class DocVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='doc'):
|
||||
|
|
@ -26,9 +49,9 @@ class DocVectors:
|
|||
self.next_reload = time.time() + self.reload_time
|
||||
logger.debug(f"Reload at {self.next_reload}")
|
||||
|
||||
def init_collection(self, dimension):
|
||||
def init_collection(self, dimension, user, collection):
|
||||
|
||||
collection_name = self.prefix + "_" + str(dimension)
|
||||
collection_name = make_safe_collection_name(user, collection, dimension, self.prefix)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
|
|
@ -75,14 +98,14 @@ class DocVectors:
|
|||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[dimension] = collection_name
|
||||
self.collections[(dimension, user, collection)] = collection_name
|
||||
|
||||
def insert(self, embeds, doc):
|
||||
def insert(self, embeds, doc, user, collection):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
data = [
|
||||
{
|
||||
|
|
@ -92,18 +115,18 @@ class DocVectors:
|
|||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[dim],
|
||||
collection_name=self.collections[(dim, user, collection)],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, fields=["doc"], limit=10):
|
||||
def search(self, embeds, user, collection, fields=["doc"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
coll = self.collections[dim]
|
||||
coll = self.collections[(dim, user, collection)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
|
|
|
|||
|
|
@ -2,9 +2,32 @@
|
|||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def make_safe_collection_name(user, collection, dimension, prefix):
|
||||
"""
|
||||
Create a safe Milvus collection name from user/collection parameters.
|
||||
Milvus only allows letters, numbers, and underscores.
|
||||
"""
|
||||
def sanitize(s):
|
||||
# Replace non-alphanumeric characters (except underscore) with underscore
|
||||
# Then collapse multiple underscores into single underscore
|
||||
safe = re.sub(r'[^a-zA-Z0-9_]', '_', s)
|
||||
safe = re.sub(r'_+', '_', safe)
|
||||
# Remove leading/trailing underscores
|
||||
safe = safe.strip('_')
|
||||
# Ensure it's not empty
|
||||
if not safe:
|
||||
safe = 'default'
|
||||
return safe
|
||||
|
||||
safe_user = sanitize(user)
|
||||
safe_collection = sanitize(collection)
|
||||
|
||||
return f"{prefix}_{safe_user}_{safe_collection}_{dimension}"
|
||||
|
||||
class EntityVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='entity'):
|
||||
|
|
@ -26,9 +49,9 @@ class EntityVectors:
|
|||
self.next_reload = time.time() + self.reload_time
|
||||
logger.debug(f"Reload at {self.next_reload}")
|
||||
|
||||
def init_collection(self, dimension):
|
||||
def init_collection(self, dimension, user, collection):
|
||||
|
||||
collection_name = self.prefix + "_" + str(dimension)
|
||||
collection_name = make_safe_collection_name(user, collection, dimension, self.prefix)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
|
|
@ -75,14 +98,14 @@ class EntityVectors:
|
|||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[dimension] = collection_name
|
||||
self.collections[(dimension, user, collection)] = collection_name
|
||||
|
||||
def insert(self, embeds, entity):
|
||||
def insert(self, embeds, entity, user, collection):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
data = [
|
||||
{
|
||||
|
|
@ -92,18 +115,18 @@ class EntityVectors:
|
|||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[dim],
|
||||
collection_name=self.collections[(dim, user, collection)],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, fields=["entity"], limit=10):
|
||||
def search(self, embeds, user, collection, fields=["entity"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
coll = self.collections[dim]
|
||||
coll = self.collections[(dim, user, collection)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
|
|
|
|||
|
|
@ -1,157 +0,0 @@
|
|||
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ObjectVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='obj'):
|
||||
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
# Strategy is to create collections per dimension. Probably only
|
||||
# going to be using 1 anyway, but that means we don't need to
|
||||
# hard-code the dimension anywhere, and no big deal if more than
|
||||
# one are created.
|
||||
self.collections = {}
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
# Time between reloads
|
||||
self.reload_time = 90
|
||||
|
||||
# Next time to reload - this forces a reload at next window
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
logger.debug(f"Reload at {self.next_reload}")
|
||||
|
||||
def init_collection(self, dimension, name):
|
||||
|
||||
collection_name = self.prefix + "_" + name + "_" + str(dimension)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
is_primary=True,
|
||||
auto_id=True,
|
||||
)
|
||||
|
||||
vec_field = FieldSchema(
|
||||
name="vector",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
)
|
||||
|
||||
name_field = FieldSchema(
|
||||
name="name",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
key_name_field = FieldSchema(
|
||||
name="key_name",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
key_field = FieldSchema(
|
||||
name="key",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [
|
||||
pkey_field, vec_field, name_field, key_name_field, key_field
|
||||
],
|
||||
description = "Object embedding schema",
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
metric_type="COSINE",
|
||||
index_type="IVF_SQ8",
|
||||
index_name="vector_index",
|
||||
params={ "nlist": 128 }
|
||||
)
|
||||
|
||||
self.client.create_index(
|
||||
collection_name=collection_name,
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[(dimension, name)] = collection_name
|
||||
|
||||
def insert(self, embeds, name, key_name, key):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if (dim, name) not in self.collections:
|
||||
self.init_collection(dim, name)
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"name": name,
|
||||
"key_name": key_name,
|
||||
"key": key,
|
||||
}
|
||||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[(dim, name)],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, name, fields=["key_name", "name"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim, name)
|
||||
|
||||
coll = self.collections[(dim, name)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
|
||||
logger.debug("Searching...")
|
||||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
)[0]
|
||||
|
||||
|
||||
# If reload time has passed, unload collection
|
||||
if time.time() > self.next_reload:
|
||||
logger.debug(f"Unloading, reload at {self.next_reload}")
|
||||
self.client.release_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
|
||||
return res
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue