diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index 64ae4d32..2534d278 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -73,7 +73,8 @@ class Processor(ConsumerProducer): print(f"Handling input {id}...", flush=True) - entities = set() + entity_set = set() + entities = [] for vec in v.vectors: @@ -85,20 +86,30 @@ class Processor(ConsumerProducer): index = self.pinecone.Index(index_name) + # Heuristic hack, get (2*limit), so that we have more chance + # of getting (limit) entities results = index.query( namespace=v.collection, vector=vec, - top_k=v.limit, + top_k=v.limit * 2, include_values=False, include_metadata=True ) for r in results.matches: - ent = r.metadata["entity"] - entities.add(ent) - # Convert set to list - entities = list(entities) + ent = r.metadata["entity"] + + # De-dupe entities + if ent not in entity_set: + entity_set.add(ent) + entities.append(ent) + + # Keep adding entities until limit + if len(entity_set) >= v.limit: break + + # Keep adding entities until limit + if len(entity_set) >= v.limit: break ents2 = [] diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index 8991f9ea..c2dcaa4c 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -61,7 +61,8 @@ class Processor(ConsumerProducer): print(f"Handling input {id}...", flush=True) - entities = set() + entity_set = set() + entities = [] for vec in v.vectors: @@ -71,19 +72,28 @@ class Processor(ConsumerProducer): str(dim) ) + # Heuristic hack, get (2*limit), so that we have more chance + # of getting (limit) entities search_result = self.client.query_points( collection_name=collection, query=vec, - limit=v.limit, + limit=v.limit * 2, with_payload=True, ).points for r in search_result: ent = r.payload["entity"] - entities.add(ent) - # Convert set to list - entities = list(entities) + # De-dupe entities + if ent not in entity_set: + entity_set.add(ent) + entities.append(ent) + + # Keep adding entities until limit + if len(entity_set) >= v.limit: break + + # Keep adding entities until limit + if len(entity_set) >= v.limit: break ents2 = []