mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 18:06:21 +02:00
- Remove object embeddings, were currently broken and not used
- Fixed Milvus collection names
This commit is contained in:
parent
6ac8a7c2d9
commit
7da94add4b
10 changed files with 90 additions and 254 deletions
|
|
@ -2,9 +2,32 @@
|
||||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def make_safe_collection_name(user, collection, dimension, prefix):
|
||||||
|
"""
|
||||||
|
Create a safe Milvus collection name from user/collection parameters.
|
||||||
|
Milvus only allows letters, numbers, and underscores.
|
||||||
|
"""
|
||||||
|
def sanitize(s):
|
||||||
|
# Replace non-alphanumeric characters (except underscore) with underscore
|
||||||
|
# Then collapse multiple underscores into single underscore
|
||||||
|
safe = re.sub(r'[^a-zA-Z0-9_]', '_', s)
|
||||||
|
safe = re.sub(r'_+', '_', safe)
|
||||||
|
# Remove leading/trailing underscores
|
||||||
|
safe = safe.strip('_')
|
||||||
|
# Ensure it's not empty
|
||||||
|
if not safe:
|
||||||
|
safe = 'default'
|
||||||
|
return safe
|
||||||
|
|
||||||
|
safe_user = sanitize(user)
|
||||||
|
safe_collection = sanitize(collection)
|
||||||
|
|
||||||
|
return f"{prefix}_{safe_user}_{safe_collection}_{dimension}"
|
||||||
|
|
||||||
class DocVectors:
|
class DocVectors:
|
||||||
|
|
||||||
def __init__(self, uri="http://localhost:19530", prefix='doc'):
|
def __init__(self, uri="http://localhost:19530", prefix='doc'):
|
||||||
|
|
@ -26,9 +49,9 @@ class DocVectors:
|
||||||
self.next_reload = time.time() + self.reload_time
|
self.next_reload = time.time() + self.reload_time
|
||||||
logger.debug(f"Reload at {self.next_reload}")
|
logger.debug(f"Reload at {self.next_reload}")
|
||||||
|
|
||||||
def init_collection(self, dimension):
|
def init_collection(self, dimension, user, collection):
|
||||||
|
|
||||||
collection_name = self.prefix + "_" + str(dimension)
|
collection_name = make_safe_collection_name(user, collection, dimension, self.prefix)
|
||||||
|
|
||||||
pkey_field = FieldSchema(
|
pkey_field = FieldSchema(
|
||||||
name="id",
|
name="id",
|
||||||
|
|
@ -75,14 +98,14 @@ class DocVectors:
|
||||||
index_params=index_params
|
index_params=index_params
|
||||||
)
|
)
|
||||||
|
|
||||||
self.collections[dimension] = collection_name
|
self.collections[(dimension, user, collection)] = collection_name
|
||||||
|
|
||||||
def insert(self, embeds, doc):
|
def insert(self, embeds, doc, user, collection):
|
||||||
|
|
||||||
dim = len(embeds)
|
dim = len(embeds)
|
||||||
|
|
||||||
if dim not in self.collections:
|
if (dim, user, collection) not in self.collections:
|
||||||
self.init_collection(dim)
|
self.init_collection(dim, user, collection)
|
||||||
|
|
||||||
data = [
|
data = [
|
||||||
{
|
{
|
||||||
|
|
@ -92,18 +115,18 @@ class DocVectors:
|
||||||
]
|
]
|
||||||
|
|
||||||
self.client.insert(
|
self.client.insert(
|
||||||
collection_name=self.collections[dim],
|
collection_name=self.collections[(dim, user, collection)],
|
||||||
data=data
|
data=data
|
||||||
)
|
)
|
||||||
|
|
||||||
def search(self, embeds, fields=["doc"], limit=10):
|
def search(self, embeds, user, collection, fields=["doc"], limit=10):
|
||||||
|
|
||||||
dim = len(embeds)
|
dim = len(embeds)
|
||||||
|
|
||||||
if dim not in self.collections:
|
if (dim, user, collection) not in self.collections:
|
||||||
self.init_collection(dim)
|
self.init_collection(dim, user, collection)
|
||||||
|
|
||||||
coll = self.collections[dim]
|
coll = self.collections[(dim, user, collection)]
|
||||||
|
|
||||||
search_params = {
|
search_params = {
|
||||||
"metric_type": "COSINE",
|
"metric_type": "COSINE",
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,32 @@
|
||||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def make_safe_collection_name(user, collection, dimension, prefix):
|
||||||
|
"""
|
||||||
|
Create a safe Milvus collection name from user/collection parameters.
|
||||||
|
Milvus only allows letters, numbers, and underscores.
|
||||||
|
"""
|
||||||
|
def sanitize(s):
|
||||||
|
# Replace non-alphanumeric characters (except underscore) with underscore
|
||||||
|
# Then collapse multiple underscores into single underscore
|
||||||
|
safe = re.sub(r'[^a-zA-Z0-9_]', '_', s)
|
||||||
|
safe = re.sub(r'_+', '_', safe)
|
||||||
|
# Remove leading/trailing underscores
|
||||||
|
safe = safe.strip('_')
|
||||||
|
# Ensure it's not empty
|
||||||
|
if not safe:
|
||||||
|
safe = 'default'
|
||||||
|
return safe
|
||||||
|
|
||||||
|
safe_user = sanitize(user)
|
||||||
|
safe_collection = sanitize(collection)
|
||||||
|
|
||||||
|
return f"{prefix}_{safe_user}_{safe_collection}_{dimension}"
|
||||||
|
|
||||||
class EntityVectors:
|
class EntityVectors:
|
||||||
|
|
||||||
def __init__(self, uri="http://localhost:19530", prefix='entity'):
|
def __init__(self, uri="http://localhost:19530", prefix='entity'):
|
||||||
|
|
@ -26,9 +49,9 @@ class EntityVectors:
|
||||||
self.next_reload = time.time() + self.reload_time
|
self.next_reload = time.time() + self.reload_time
|
||||||
logger.debug(f"Reload at {self.next_reload}")
|
logger.debug(f"Reload at {self.next_reload}")
|
||||||
|
|
||||||
def init_collection(self, dimension):
|
def init_collection(self, dimension, user, collection):
|
||||||
|
|
||||||
collection_name = self.prefix + "_" + str(dimension)
|
collection_name = make_safe_collection_name(user, collection, dimension, self.prefix)
|
||||||
|
|
||||||
pkey_field = FieldSchema(
|
pkey_field = FieldSchema(
|
||||||
name="id",
|
name="id",
|
||||||
|
|
@ -75,14 +98,14 @@ class EntityVectors:
|
||||||
index_params=index_params
|
index_params=index_params
|
||||||
)
|
)
|
||||||
|
|
||||||
self.collections[dimension] = collection_name
|
self.collections[(dimension, user, collection)] = collection_name
|
||||||
|
|
||||||
def insert(self, embeds, entity):
|
def insert(self, embeds, entity, user, collection):
|
||||||
|
|
||||||
dim = len(embeds)
|
dim = len(embeds)
|
||||||
|
|
||||||
if dim not in self.collections:
|
if (dim, user, collection) not in self.collections:
|
||||||
self.init_collection(dim)
|
self.init_collection(dim, user, collection)
|
||||||
|
|
||||||
data = [
|
data = [
|
||||||
{
|
{
|
||||||
|
|
@ -92,18 +115,18 @@ class EntityVectors:
|
||||||
]
|
]
|
||||||
|
|
||||||
self.client.insert(
|
self.client.insert(
|
||||||
collection_name=self.collections[dim],
|
collection_name=self.collections[(dim, user, collection)],
|
||||||
data=data
|
data=data
|
||||||
)
|
)
|
||||||
|
|
||||||
def search(self, embeds, fields=["entity"], limit=10):
|
def search(self, embeds, user, collection, fields=["entity"], limit=10):
|
||||||
|
|
||||||
dim = len(embeds)
|
dim = len(embeds)
|
||||||
|
|
||||||
if dim not in self.collections:
|
if (dim, user, collection) not in self.collections:
|
||||||
self.init_collection(dim)
|
self.init_collection(dim, user, collection)
|
||||||
|
|
||||||
coll = self.collections[dim]
|
coll = self.collections[(dim, user, collection)]
|
||||||
|
|
||||||
search_params = {
|
search_params = {
|
||||||
"metric_type": "COSINE",
|
"metric_type": "COSINE",
|
||||||
|
|
|
||||||
|
|
@ -1,157 +0,0 @@
|
||||||
|
|
||||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class ObjectVectors:
|
|
||||||
|
|
||||||
def __init__(self, uri="http://localhost:19530", prefix='obj'):
|
|
||||||
|
|
||||||
self.client = MilvusClient(uri=uri)
|
|
||||||
|
|
||||||
# Strategy is to create collections per dimension. Probably only
|
|
||||||
# going to be using 1 anyway, but that means we don't need to
|
|
||||||
# hard-code the dimension anywhere, and no big deal if more than
|
|
||||||
# one are created.
|
|
||||||
self.collections = {}
|
|
||||||
|
|
||||||
self.prefix = prefix
|
|
||||||
|
|
||||||
# Time between reloads
|
|
||||||
self.reload_time = 90
|
|
||||||
|
|
||||||
# Next time to reload - this forces a reload at next window
|
|
||||||
self.next_reload = time.time() + self.reload_time
|
|
||||||
logger.debug(f"Reload at {self.next_reload}")
|
|
||||||
|
|
||||||
def init_collection(self, dimension, name):
|
|
||||||
|
|
||||||
collection_name = self.prefix + "_" + name + "_" + str(dimension)
|
|
||||||
|
|
||||||
pkey_field = FieldSchema(
|
|
||||||
name="id",
|
|
||||||
dtype=DataType.INT64,
|
|
||||||
is_primary=True,
|
|
||||||
auto_id=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
vec_field = FieldSchema(
|
|
||||||
name="vector",
|
|
||||||
dtype=DataType.FLOAT_VECTOR,
|
|
||||||
dim=dimension,
|
|
||||||
)
|
|
||||||
|
|
||||||
name_field = FieldSchema(
|
|
||||||
name="name",
|
|
||||||
dtype=DataType.VARCHAR,
|
|
||||||
max_length=65535,
|
|
||||||
)
|
|
||||||
|
|
||||||
key_name_field = FieldSchema(
|
|
||||||
name="key_name",
|
|
||||||
dtype=DataType.VARCHAR,
|
|
||||||
max_length=65535,
|
|
||||||
)
|
|
||||||
|
|
||||||
key_field = FieldSchema(
|
|
||||||
name="key",
|
|
||||||
dtype=DataType.VARCHAR,
|
|
||||||
max_length=65535,
|
|
||||||
)
|
|
||||||
|
|
||||||
schema = CollectionSchema(
|
|
||||||
fields = [
|
|
||||||
pkey_field, vec_field, name_field, key_name_field, key_field
|
|
||||||
],
|
|
||||||
description = "Object embedding schema",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.client.create_collection(
|
|
||||||
collection_name=collection_name,
|
|
||||||
schema=schema,
|
|
||||||
metric_type="COSINE",
|
|
||||||
)
|
|
||||||
|
|
||||||
index_params = MilvusClient.prepare_index_params()
|
|
||||||
|
|
||||||
index_params.add_index(
|
|
||||||
field_name="vector",
|
|
||||||
metric_type="COSINE",
|
|
||||||
index_type="IVF_SQ8",
|
|
||||||
index_name="vector_index",
|
|
||||||
params={ "nlist": 128 }
|
|
||||||
)
|
|
||||||
|
|
||||||
self.client.create_index(
|
|
||||||
collection_name=collection_name,
|
|
||||||
index_params=index_params
|
|
||||||
)
|
|
||||||
|
|
||||||
self.collections[(dimension, name)] = collection_name
|
|
||||||
|
|
||||||
def insert(self, embeds, name, key_name, key):
|
|
||||||
|
|
||||||
dim = len(embeds)
|
|
||||||
|
|
||||||
if (dim, name) not in self.collections:
|
|
||||||
self.init_collection(dim, name)
|
|
||||||
|
|
||||||
data = [
|
|
||||||
{
|
|
||||||
"vector": embeds,
|
|
||||||
"name": name,
|
|
||||||
"key_name": key_name,
|
|
||||||
"key": key,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
self.client.insert(
|
|
||||||
collection_name=self.collections[(dim, name)],
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
def search(self, embeds, name, fields=["key_name", "name"], limit=10):
|
|
||||||
|
|
||||||
dim = len(embeds)
|
|
||||||
|
|
||||||
if dim not in self.collections:
|
|
||||||
self.init_collection(dim, name)
|
|
||||||
|
|
||||||
coll = self.collections[(dim, name)]
|
|
||||||
|
|
||||||
search_params = {
|
|
||||||
"metric_type": "COSINE",
|
|
||||||
"params": {
|
|
||||||
"radius": 0.1,
|
|
||||||
"range_filter": 0.8
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug("Loading...")
|
|
||||||
self.client.load_collection(
|
|
||||||
collection_name=coll,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("Searching...")
|
|
||||||
|
|
||||||
res = self.client.search(
|
|
||||||
collection_name=coll,
|
|
||||||
data=[embeds],
|
|
||||||
limit=limit,
|
|
||||||
output_fields=fields,
|
|
||||||
search_params=search_params,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
|
|
||||||
# If reload time has passed, unload collection
|
|
||||||
if time.time() > self.next_reload:
|
|
||||||
logger.debug(f"Unloading, reload at {self.next_reload}")
|
|
||||||
self.client.release_collection(
|
|
||||||
collection_name=coll,
|
|
||||||
)
|
|
||||||
self.next_reload = time.time() + self.reload_time
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
@ -43,7 +43,12 @@ class Processor(DocumentEmbeddingsQueryService):
|
||||||
|
|
||||||
for vec in msg.vectors:
|
for vec in msg.vectors:
|
||||||
|
|
||||||
resp = self.vecstore.search(vec, limit=msg.limit)
|
resp = self.vecstore.search(
|
||||||
|
vec,
|
||||||
|
msg.user,
|
||||||
|
msg.collection,
|
||||||
|
limit=msg.limit
|
||||||
|
)
|
||||||
|
|
||||||
for r in resp:
|
for r in resp:
|
||||||
chunk = r["entity"]["doc"]
|
chunk = r["entity"]["doc"]
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,12 @@ class Processor(GraphEmbeddingsQueryService):
|
||||||
|
|
||||||
for vec in msg.vectors:
|
for vec in msg.vectors:
|
||||||
|
|
||||||
resp = self.vecstore.search(vec, limit=msg.limit * 2)
|
resp = self.vecstore.search(
|
||||||
|
vec,
|
||||||
|
msg.user,
|
||||||
|
msg.collection,
|
||||||
|
limit=msg.limit * 2
|
||||||
|
)
|
||||||
|
|
||||||
for r in resp:
|
for r in resp:
|
||||||
ent = r["entity"]["entity"]
|
ent = r["entity"]["entity"]
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,11 @@ class Processor(DocumentEmbeddingsStoreService):
|
||||||
if chunk == "": continue
|
if chunk == "": continue
|
||||||
|
|
||||||
for vec in emb.vectors:
|
for vec in emb.vectors:
|
||||||
self.vecstore.insert(vec, chunk)
|
self.vecstore.insert(
|
||||||
|
vec, chunk,
|
||||||
|
message.metadata.user,
|
||||||
|
message.metadata.collection
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,11 @@ class Processor(GraphEmbeddingsStoreService):
|
||||||
|
|
||||||
if entity.entity.value != "" and entity.entity.value is not None:
|
if entity.entity.value != "" and entity.entity.value is not None:
|
||||||
for vec in entity.vectors:
|
for vec in entity.vectors:
|
||||||
self.vecstore.insert(vec, entity.entity.value)
|
self.vecstore.insert(
|
||||||
|
vec, entity.entity.value,
|
||||||
|
message.metadata.user,
|
||||||
|
message.metadata.collection
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -1,3 +0,0 @@
|
||||||
|
|
||||||
from . write import *
|
|
||||||
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
from . write import run
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
run()
|
|
||||||
|
|
||||||
|
|
@ -1,61 +0,0 @@
|
||||||
|
|
||||||
"""
|
|
||||||
Accepts entity/vector pairs and writes them to a Milvus store.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .... schema import ObjectEmbeddings
|
|
||||||
from .... schema import object_embeddings_store_queue
|
|
||||||
from .... log_level import LogLevel
|
|
||||||
from .... direct.milvus_object_embeddings import ObjectVectors
|
|
||||||
from .... base import Consumer
|
|
||||||
|
|
||||||
module = "oe-write"
|
|
||||||
|
|
||||||
default_input_queue = object_embeddings_store_queue
|
|
||||||
default_subscriber = module
|
|
||||||
default_store_uri = 'http://localhost:19530'
|
|
||||||
|
|
||||||
class Processor(Consumer):
|
|
||||||
|
|
||||||
def __init__(self, **params):
|
|
||||||
|
|
||||||
input_queue = params.get("input_queue", default_input_queue)
|
|
||||||
subscriber = params.get("subscriber", default_subscriber)
|
|
||||||
store_uri = params.get("store_uri", default_store_uri)
|
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
|
||||||
**params | {
|
|
||||||
"input_queue": input_queue,
|
|
||||||
"subscriber": subscriber,
|
|
||||||
"input_schema": ObjectEmbeddings,
|
|
||||||
"store_uri": store_uri,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.vecstore = ObjectVectors(store_uri)
|
|
||||||
|
|
||||||
async def handle(self, msg):
|
|
||||||
|
|
||||||
v = msg.value()
|
|
||||||
|
|
||||||
if v.id != "" and v.id is not None:
|
|
||||||
for vec in v.vectors:
|
|
||||||
self.vecstore.insert(vec, v.name, v.key_name, v.id)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def add_args(parser):
|
|
||||||
|
|
||||||
Consumer.add_args(
|
|
||||||
parser, default_input_queue, default_subscriber,
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-t', '--store-uri',
|
|
||||||
default=default_store_uri,
|
|
||||||
help=f'Milvus store URI (default: {default_store_uri})'
|
|
||||||
)
|
|
||||||
|
|
||||||
def run():
|
|
||||||
|
|
||||||
Processor.launch(module, __doc__)
|
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue