mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 01:46:22 +02:00
Make dimensions work dynamically
This commit is contained in:
parent
72814c2029
commit
3741b54566
1 changed files with 41 additions and 15 deletions
|
|
@ -7,13 +7,21 @@ class TripleVectors:
|
|||
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
self.collection = "edges"
|
||||
self.dimension = 384
|
||||
# Strategy is to create collections per dimension. Probably only
|
||||
# going to be using 1 anyway, but that means we don't need to
|
||||
# hard-code the dimension anywhere, and no big deal if more than
|
||||
# one are created.
|
||||
self.collections = {}
|
||||
|
||||
if not self.client.has_collection(collection_name=self.collection):
|
||||
self.init_collection()
|
||||
# self.collection = "edges"
|
||||
# self.dimension = 384
|
||||
|
||||
def init_collection(self):
|
||||
# if not self.client.has_collection(collection_name=self.collection):
|
||||
# self.init_collection()
|
||||
|
||||
def init_collection(self, dimension):
|
||||
|
||||
collection_name = "triples_" + str(dimension)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
|
|
@ -25,7 +33,7 @@ class TripleVectors:
|
|||
vec_field = FieldSchema(
|
||||
name="vector",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
dim=self.dimension,
|
||||
dim=dimension,
|
||||
)
|
||||
|
||||
entity_field = FieldSchema(
|
||||
|
|
@ -40,9 +48,9 @@ class TripleVectors:
|
|||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=self.collection,
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
metric_type="IP",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
|
|
@ -50,17 +58,24 @@ class TripleVectors:
|
|||
index_params.add_index(
|
||||
field_name="vector",
|
||||
metric_type="COSINE",
|
||||
index_type="FLAT", # IVF_FLAT?!
|
||||
index_type="IVF_SQ8",
|
||||
index_name="vector_index",
|
||||
params={ "nlist": 128 }
|
||||
)
|
||||
|
||||
self.client.create_index(
|
||||
collection_name=self.collection,
|
||||
collection_name=collection_name,
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[dimension] = collection_name
|
||||
|
||||
def insert(self, embeds, entity):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
|
||||
data = [
|
||||
{
|
||||
|
|
@ -69,10 +84,20 @@ class TripleVectors:
|
|||
}
|
||||
]
|
||||
|
||||
self.client.insert(collection_name=self.collection, data=data)
|
||||
self.client.insert(
|
||||
collection_name=self.collections[dim],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, fields=["entity"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
|
||||
coll = self.collections[dim]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
|
|
@ -82,20 +107,21 @@ class TripleVectors:
|
|||
}
|
||||
|
||||
self.client.load_collection(
|
||||
collection_name=self.collection,
|
||||
# replica_number=1
|
||||
collection_name=coll,
|
||||
)
|
||||
|
||||
res = self.client.search(
|
||||
collection_name=self.collection,
|
||||
collection_name=coll,
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
)[0]
|
||||
|
||||
# FIXME: a lot of loading/unloading going on. How about using a
|
||||
# time window?
|
||||
self.client.release_collection(
|
||||
collection_name=self.collection,
|
||||
collection_name=coll,
|
||||
)
|
||||
|
||||
return res
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue