mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 01:46:22 +02:00
Make dimensions work dynamically
This commit is contained in:
parent
72814c2029
commit
3741b54566
1 changed file with 41 additions and 15 deletions
|
|
@@ -7,13 +7,21 @@ class TripleVectors:
|
||||||
|
|
||||||
self.client = MilvusClient(uri=uri)
|
self.client = MilvusClient(uri=uri)
|
||||||
|
|
||||||
self.collection = "edges"
|
# Strategy is to create collections per dimension. Probably only
|
||||||
self.dimension = 384
|
# going to be using 1 anyway, but that means we don't need to
|
||||||
|
# hard-code the dimension anywhere, and no big deal if more than
|
||||||
|
# one are created.
|
||||||
|
self.collections = {}
|
||||||
|
|
||||||
if not self.client.has_collection(collection_name=self.collection):
|
# self.collection = "edges"
|
||||||
self.init_collection()
|
# self.dimension = 384
|
||||||
|
|
||||||
def init_collection(self):
|
# if not self.client.has_collection(collection_name=self.collection):
|
||||||
|
# self.init_collection()
|
||||||
|
|
||||||
|
def init_collection(self, dimension):
|
||||||
|
|
||||||
|
collection_name = "triples_" + str(dimension)
|
||||||
|
|
||||||
pkey_field = FieldSchema(
|
pkey_field = FieldSchema(
|
||||||
name="id",
|
name="id",
|
||||||
|
|
@@ -25,7 +33,7 @@ class TripleVectors:
|
||||||
vec_field = FieldSchema(
|
vec_field = FieldSchema(
|
||||||
name="vector",
|
name="vector",
|
||||||
dtype=DataType.FLOAT_VECTOR,
|
dtype=DataType.FLOAT_VECTOR,
|
||||||
dim=self.dimension,
|
dim=dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
entity_field = FieldSchema(
|
entity_field = FieldSchema(
|
||||||
|
|
@@ -40,9 +48,9 @@ class TripleVectors:
|
||||||
)
|
)
|
||||||
|
|
||||||
self.client.create_collection(
|
self.client.create_collection(
|
||||||
collection_name=self.collection,
|
collection_name=collection_name,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
metric_type="IP",
|
metric_type="COSINE",
|
||||||
)
|
)
|
||||||
|
|
||||||
index_params = MilvusClient.prepare_index_params()
|
index_params = MilvusClient.prepare_index_params()
|
||||||
|
|
@@ -50,17 +58,24 @@ class TripleVectors:
|
||||||
index_params.add_index(
|
index_params.add_index(
|
||||||
field_name="vector",
|
field_name="vector",
|
||||||
metric_type="COSINE",
|
metric_type="COSINE",
|
||||||
index_type="FLAT", # IVF_FLAT?!
|
index_type="IVF_SQ8",
|
||||||
index_name="vector_index",
|
index_name="vector_index",
|
||||||
params={ "nlist": 128 }
|
params={ "nlist": 128 }
|
||||||
)
|
)
|
||||||
|
|
||||||
self.client.create_index(
|
self.client.create_index(
|
||||||
collection_name=self.collection,
|
collection_name=collection_name,
|
||||||
index_params=index_params
|
index_params=index_params
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.collections[dimension] = collection_name
|
||||||
|
|
||||||
def insert(self, embeds, entity):
|
def insert(self, embeds, entity):
|
||||||
|
|
||||||
|
dim = len(embeds)
|
||||||
|
|
||||||
|
if dim not in self.collections:
|
||||||
|
self.init_collection(dim)
|
||||||
|
|
||||||
data = [
|
data = [
|
||||||
{
|
{
|
||||||
|
|
@@ -69,10 +84,20 @@ class TripleVectors:
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
self.client.insert(collection_name=self.collection, data=data)
|
self.client.insert(
|
||||||
|
collection_name=self.collections[dim],
|
||||||
|
data=data
|
||||||
|
)
|
||||||
|
|
||||||
def search(self, embeds, fields=["entity"], limit=10):
|
def search(self, embeds, fields=["entity"], limit=10):
|
||||||
|
|
||||||
|
dim = len(embeds)
|
||||||
|
|
||||||
|
if dim not in self.collections:
|
||||||
|
self.init_collection(dim)
|
||||||
|
|
||||||
|
coll = self.collections[dim]
|
||||||
|
|
||||||
search_params = {
|
search_params = {
|
||||||
"metric_type": "COSINE",
|
"metric_type": "COSINE",
|
||||||
"params": {
|
"params": {
|
||||||
|
|
@@ -82,20 +107,21 @@ class TripleVectors:
|
||||||
}
|
}
|
||||||
|
|
||||||
self.client.load_collection(
|
self.client.load_collection(
|
||||||
collection_name=self.collection,
|
collection_name=coll,
|
||||||
# replica_number=1
|
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.client.search(
|
res = self.client.search(
|
||||||
collection_name=self.collection,
|
collection_name=coll,
|
||||||
data=[embeds],
|
data=[embeds],
|
||||||
limit=limit,
|
limit=limit,
|
||||||
output_fields=fields,
|
output_fields=fields,
|
||||||
search_params=search_params,
|
search_params=search_params,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
|
# FIXME: a lot of loading/unloading going on. How about using a
|
||||||
|
# time window?
|
||||||
self.client.release_collection(
|
self.client.release_collection(
|
||||||
collection_name=self.collection,
|
collection_name=coll,
|
||||||
)
|
)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue