Embeddings API scores (#671)

- Put scores in all responses
- Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
This commit is contained in:
cybermaggedon 2026-03-09 10:53:44 +00:00 committed by GitHub
parent 4fa7cc7d7c
commit f2ae0e8623
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 1339 additions and 1292 deletions

View file

@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal
logger = logging.getLogger(__name__)
class DocumentEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
DocumentEmbeddingsRequest(
vectors = vectors,
vector = vector,
limit = limit,
user = user,
collection = collection
@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
return resp.chunk_ids
# Return ChunkMatch objects with chunk_id and score
return resp.chunks
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
def __init__(

View file

@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
docs = await self.query_document_embeddings(request)
logger.debug("Sending document embeddings query response...")
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
await flow("response").send(r, properties={"id": id})
logger.debug("Document embeddings query request completed")
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
type = "document-embeddings-query-error",
message = str(e),
),
chunk_ids=[],
chunks=[],
)
await flow("response").send(r, properties={"id": id})

View file

@ -19,12 +19,12 @@ def to_value(x):
return Literal(x.value or x.iri)
class GraphEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
GraphEmbeddingsRequest(
vectors = vectors,
vector = vector,
limit = limit,
user = user,
collection = collection
@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
return [
to_value(v)
for v in resp.entities
]
# Return EntityMatch objects with entity and score
return resp.entities
class GraphEmbeddingsClientSpec(RequestResponseSpec):
def __init__(

View file

@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
class RowEmbeddingsQueryClient(RequestResponse):
async def row_embeddings_query(
self, vectors, schema_name, user="trustgraph", collection="default",
self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=600
):
request = RowEmbeddingsRequest(
vectors=vectors,
vector=vector,
schema_name=schema_name,
user=user,
collection=collection,