From vector DB, often get dupes, which means we end up returning (#210)

fewer than top_k elements.  So, fetch top_k=(2 * limit) and limit to
just (limit)
This commit is contained in:
cybermaggedon 2024-12-10 22:37:54 +00:00 committed by GitHub
parent cd8d0c8cbc
commit 07f9b1f244
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 32 additions and 11 deletions

View file

@ -73,7 +73,8 @@ class Processor(ConsumerProducer):
print(f"Handling input {id}...", flush=True)
entities = set()
entity_set = set()
entities = []
for vec in v.vectors:
@ -85,20 +86,30 @@ class Processor(ConsumerProducer):
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
results = index.query(
namespace=v.collection,
vector=vec,
top_k=v.limit,
top_k=v.limit * 2,
include_values=False,
include_metadata=True
)
for r in results.matches:
ent = r.metadata["entity"]
entities.add(ent)
# Convert set to list
entities = list(entities)
ent = r.metadata["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= v.limit: break
# Keep adding entities until limit
if len(entity_set) >= v.limit: break
ents2 = []

View file

@ -61,7 +61,8 @@ class Processor(ConsumerProducer):
print(f"Handling input {id}...", flush=True)
entities = set()
entity_set = set()
entities = []
for vec in v.vectors:
@ -71,19 +72,28 @@ class Processor(ConsumerProducer):
str(dim)
)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
search_result = self.client.query_points(
collection_name=collection,
query=vec,
limit=v.limit,
limit=v.limit * 2,
with_payload=True,
).points
for r in search_result:
ent = r.payload["entity"]
entities.add(ent)
# Convert set to list
entities = list(entities)
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= v.limit: break
# Keep adding entities until limit
if len(entity_set) >= v.limit: break
ents2 = []