mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-16 19:05:14 +02:00
Embeddings API scores (#671)
- Put scores in all responses - Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
This commit is contained in:
parent
4fa7cc7d7c
commit
f2ae0e8623
65 changed files with 1339 additions and 1292 deletions
|
|
@ -7,7 +7,7 @@ entities
|
|||
import logging
|
||||
|
||||
from .... direct.milvus_graph_embeddings import EntityVectors
|
||||
from .... schema import GraphEmbeddingsResponse
|
||||
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||
from .... schema import Error, Term, IRI, LITERAL
|
||||
from .... base import GraphEmbeddingsQueryService
|
||||
|
||||
|
|
@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
|
||||
try:
|
||||
|
||||
entity_set = set()
|
||||
entities = []
|
||||
vec = msg.vector
|
||||
if not vec:
|
||||
return []
|
||||
|
||||
# Handle zero limit case
|
||||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
for vec in msg.vectors:
|
||||
resp = self.vecstore.search(
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
limit=msg.limit * 2
|
||||
)
|
||||
|
||||
resp = self.vecstore.search(
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
limit=msg.limit * 2
|
||||
)
|
||||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
for r in resp:
|
||||
ent = r["entity"]["entity"]
|
||||
|
||||
# De-dupe entities
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(ent)
|
||||
for r in resp:
|
||||
ent = r["entity"]["entity"]
|
||||
# Milvus returns distance, convert to similarity score
|
||||
distance = r.get("distance", 0.0)
|
||||
score = 1.0 - distance if distance else 0.0
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
# De-dupe entities, keep highest score
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(EntityMatch(
|
||||
entity=self.create_value(ent),
|
||||
score=score,
|
||||
))
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
|
||||
ents2 = []
|
||||
|
||||
for ent in entities:
|
||||
ents2.append(self.create_value(ent))
|
||||
|
||||
entities = ents2
|
||||
if len(entities) >= msg.limit:
|
||||
break
|
||||
|
||||
logger.debug("Send response...")
|
||||
return entities
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue