From 3741b545660ba29ea4e1583baacd5a47fcf810f7 Mon Sep 17 00:00:00 2001
From: Cyber MacGeddon
Date: Tue, 16 Jul 2024 17:43:29 +0100
Subject: [PATCH] Make dimensions work dynamically

---
 trustgraph/triple_vectors.py | 56 ++++++++++++++++++++++++++----------
 1 file changed, 41 insertions(+), 15 deletions(-)

diff --git a/trustgraph/triple_vectors.py b/trustgraph/triple_vectors.py
index f631c506..1d0988cd 100644
--- a/trustgraph/triple_vectors.py
+++ b/trustgraph/triple_vectors.py
@@ -7,13 +7,21 @@ class TripleVectors:
 
         self.client = MilvusClient(uri=uri)
 
-        self.collection = "edges"
-        self.dimension = 384
+        # Strategy is to create collections per dimension. Probably only
+        # going to be using 1 anyway, but that means we don't need to
+        # hard-code the dimension anywhere, and no big deal if more than
+        # one are created.
+        self.collections = {}
 
-        if not self.client.has_collection(collection_name=self.collection):
-            self.init_collection()
+#        self.collection = "edges"
+#        self.dimension = 384
 
-    def init_collection(self):
+#        if not self.client.has_collection(collection_name=self.collection):
+#            self.init_collection()
+
+    def init_collection(self, dimension):
+
+        collection_name = "triples_" + str(dimension)
 
         pkey_field = FieldSchema(
             name="id",
@@ -25,7 +33,7 @@ class TripleVectors:
         vec_field = FieldSchema(
             name="vector",
             dtype=DataType.FLOAT_VECTOR,
-            dim=self.dimension,
+            dim=dimension,
         )
 
         entity_field = FieldSchema(
@@ -40,9 +48,9 @@ class TripleVectors:
         )
 
         self.client.create_collection(
-            collection_name=self.collection,
+            collection_name=collection_name,
             schema=schema,
-            metric_type="IP",
+            metric_type="COSINE",
         )
 
         index_params = MilvusClient.prepare_index_params()
@@ -50,17 +58,24 @@ class TripleVectors:
         index_params.add_index(
             field_name="vector",
             metric_type="COSINE",
-            index_type="FLAT", # IVF_FLAT?!
+            index_type="IVF_SQ8",
             index_name="vector_index",
             params={ "nlist": 128 }
         )
 
         self.client.create_index(
-            collection_name=self.collection,
+            collection_name=collection_name,
             index_params=index_params
         )
 
+        self.collections[dimension] = collection_name
+
     def insert(self, embeds, entity):
+
+        dim = len(embeds)
+
+        if dim not in self.collections:
+            self.init_collection(dim)
 
         data = [
             {
@@ -69,10 +84,20 @@ class TripleVectors:
             }
         ]
 
-        self.client.insert(collection_name=self.collection, data=data)
+        self.client.insert(
+            collection_name=self.collections[dim],
+            data=data
+        )
 
     def search(self, embeds, fields=["entity"], limit=10):
 
+        dim = len(embeds)
+
+        if dim not in self.collections:
+            self.init_collection(dim)
+
+        coll = self.collections[dim]
+
         search_params = {
             "metric_type": "COSINE",
             "params": {
@@ -82,20 +107,21 @@ class TripleVectors:
         }
 
         self.client.load_collection(
-            collection_name=self.collection,
-#            replica_number=1
+            collection_name=coll,
         )
 
         res = self.client.search(
-            collection_name=self.collection,
+            collection_name=coll,
             data=[embeds],
             limit=limit,
             output_fields=fields,
             search_params=search_params,
         )[0]
 
+        # FIXME: a lot of loading/unloading going on. How about using a
+        # time window?
         self.client.release_collection(
-            collection_name=self.collection,
+            collection_name=coll,
         )
 
         return res