mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 09:56:22 +02:00
From vector DB, often get dupes, which means we end up returning (#210)
fewer than top_k elements. So, fetch top_k=(2 * limit) and limit to just (limit)
This commit is contained in:
parent
cd8d0c8cbc
commit
07f9b1f244
2 changed files with 32 additions and 11 deletions
|
|
@ -73,7 +73,8 @@ class Processor(ConsumerProducer):
|
|||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
entities = set()
|
||||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
for vec in v.vectors:
|
||||
|
||||
|
|
@ -85,20 +86,30 @@ class Processor(ConsumerProducer):
|
|||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
# Heuristic hack, get (2*limit), so that we have more chance
|
||||
# of getting (limit) entities
|
||||
results = index.query(
|
||||
namespace=v.collection,
|
||||
vector=vec,
|
||||
top_k=v.limit,
|
||||
top_k=v.limit * 2,
|
||||
include_values=False,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
for r in results.matches:
|
||||
ent = r.metadata["entity"]
|
||||
entities.add(ent)
|
||||
|
||||
# Convert set to list
|
||||
entities = list(entities)
|
||||
ent = r.metadata["entity"]
|
||||
|
||||
# De-dupe entities
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(ent)
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= v.limit: break
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= v.limit: break
|
||||
|
||||
ents2 = []
|
||||
|
||||
|
|
|
|||
|
|
@ -61,7 +61,8 @@ class Processor(ConsumerProducer):
|
|||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
entities = set()
|
||||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
for vec in v.vectors:
|
||||
|
||||
|
|
@ -71,19 +72,28 @@ class Processor(ConsumerProducer):
|
|||
str(dim)
|
||||
)
|
||||
|
||||
# Heuristic hack, get (2*limit), so that we have more chance
|
||||
# of getting (limit) entities
|
||||
search_result = self.client.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
limit=v.limit,
|
||||
limit=v.limit * 2,
|
||||
with_payload=True,
|
||||
).points
|
||||
|
||||
for r in search_result:
|
||||
ent = r.payload["entity"]
|
||||
entities.add(ent)
|
||||
|
||||
# Convert set to list
|
||||
entities = list(entities)
|
||||
# De-dupe entities
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(ent)
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= v.limit: break
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= v.limit: break
|
||||
|
||||
ents2 = []
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue