Embeddings API scores (#671)

- Put scores in all responses
- Remove unused 'middle' vector layer: embeddings now map a vector of texts to a vector of embedding vectors (one embedding per text), instead of a vector of vectors of embeddings
This commit is contained in:
cybermaggedon 2026-03-09 10:53:44 +00:00 committed by GitHub
parent 4fa7cc7d7c
commit f2ae0e8623
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 1339 additions and 1292 deletions

View file

@ -612,12 +612,12 @@ class AsyncFlowInstance:
print(f"{entity['name']}: {entity['score']}")
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request_data = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -810,12 +810,12 @@ class AsyncFlowInstance:
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request_data = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -282,12 +282,12 @@ class AsyncSocketFlowInstance:
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
"""Query graph embeddings for semantic search"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -352,12 +352,12 @@ class AsyncSocketFlowInstance:
limit: int = 10, **kwargs
):
"""Query row embeddings for semantic search on structured data"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -602,13 +602,13 @@ class FlowInstance:
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
# Query graph embeddings for semantic search
input = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -648,13 +648,13 @@ class FlowInstance:
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
# Query document embeddings for semantic search
input = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -1362,13 +1362,13 @@ class FlowInstance:
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
# Query row embeddings for semantic search
input = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -649,12 +649,12 @@ class SocketFlowInstance:
)
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -698,12 +698,12 @@ class SocketFlowInstance:
# results contains {"chunk_ids": ["doc1/p0/c0", ...]}
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"user": user,
"collection": collection,
"limit": limit
@ -936,12 +936,12 @@ class SocketFlowInstance:
)
```
"""
# First convert text to embeddings vectors
# First convert text to embedding vector
emb_result = self.embeddings(texts=[text])
vectors = emb_result.get("vectors", [[]])[0]
vector = emb_result.get("vectors", [[]])[0]
request = {
"vectors": vectors,
"vector": vector,
"schema_name": schema_name,
"user": user,
"collection": collection,

View file

@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal
logger = logging.getLogger(__name__)
class DocumentEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
DocumentEmbeddingsRequest(
vectors = vectors,
vector = vector,
limit = limit,
user = user,
collection = collection
@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
return resp.chunk_ids
# Return ChunkMatch objects with chunk_id and score
return resp.chunks
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
def __init__(

View file

@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
docs = await self.query_document_embeddings(request)
logger.debug("Sending document embeddings query response...")
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
await flow("response").send(r, properties={"id": id})
logger.debug("Document embeddings query request completed")
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
type = "document-embeddings-query-error",
message = str(e),
),
chunk_ids=[],
chunks=[],
)
await flow("response").send(r, properties={"id": id})

View file

@ -19,12 +19,12 @@ def to_value(x):
return Literal(x.value or x.iri)
class GraphEmbeddingsClient(RequestResponse):
async def query(self, vectors, limit=20, user="trustgraph",
async def query(self, vector, limit=20, user="trustgraph",
collection="default", timeout=30):
resp = await self.request(
GraphEmbeddingsRequest(
vectors = vectors,
vector = vector,
limit = limit,
user = user,
collection = collection
@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
return [
to_value(v)
for v in resp.entities
]
# Return EntityMatch objects with entity and score
return resp.entities
class GraphEmbeddingsClientSpec(RequestResponseSpec):
def __init__(

View file

@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
class RowEmbeddingsQueryClient(RequestResponse):
async def row_embeddings_query(
self, vectors, schema_name, user="trustgraph", collection="default",
self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=600
):
request = RowEmbeddingsRequest(
vectors=vectors,
vector=vector,
schema_name=schema_name,
user=user,
collection=collection,

View file

@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient):
)
def request(
self, vectors, user="trustgraph", collection="default",
self, vector, user="trustgraph", collection="default",
limit=10, timeout=300
):
return self.call(
user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout
vector=vector, limit=limit, timeout=timeout
).chunks

View file

@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient):
)
def request(
self, vectors, user="trustgraph", collection="default",
self, vector, user="trustgraph", collection="default",
limit=10, timeout=300
):
return self.call(
user=user, collection=collection,
vectors=vectors, limit=limit, timeout=timeout
vector=vector, limit=limit, timeout=timeout
).entities

View file

@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient):
)
def request(
self, vectors, schema_name, user="trustgraph", collection="default",
self, vector, schema_name, user="trustgraph", collection="default",
index_name=None, limit=10, timeout=300
):
kwargs = dict(
user=user, collection=collection,
vectors=vectors, schema_name=schema_name,
vector=vector, schema_name=schema_name,
limit=limit, timeout=timeout
)
if index_name:

View file

@ -10,18 +10,18 @@ from .primitives import ValueTranslator
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for DocumentEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
return DocumentEmbeddingsRequest(
vectors=data["vectors"],
vector=data["vector"],
limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default")
)
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
return {
"vectors": obj.vectors,
"vector": obj.vector,
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection
@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator):
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for DocumentEmbeddingsResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
result = {}
if obj.chunk_ids is not None:
result["chunk_ids"] = list(obj.chunk_ids)
if obj.chunks is not None:
result["chunks"] = [
{
"chunk_id": chunk.chunk_id,
"score": chunk.score
}
for chunk in obj.chunks
]
return result
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
class GraphEmbeddingsRequestTranslator(MessageTranslator):
"""Translator for GraphEmbeddingsRequest schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
return GraphEmbeddingsRequest(
vectors=data["vectors"],
vector=data["vector"],
limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default")
)
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
return {
"vectors": obj.vectors,
"vector": obj.vector,
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection
@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator):
class GraphEmbeddingsResponseTranslator(MessageTranslator):
"""Translator for GraphEmbeddingsResponse schema objects"""
def __init__(self):
self.value_translator = ValueTranslator()
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
result = {}
if obj.entities is not None:
result["entities"] = [
self.value_translator.from_pulsar(entity)
for entity in obj.entities
{
"entity": self.value_translator.from_pulsar(match.entity),
"score": match.score
}
for match in obj.entities
]
return result
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
return RowEmbeddingsRequest(
vectors=data["vectors"],
vector=data["vector"],
limit=int(data.get("limit", 10)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
result = {
"vectors": obj.vectors,
"vector": obj.vector,
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection,

View file

@ -11,7 +11,7 @@ from ..core.topic import topic
@dataclass
class EntityEmbeddings:
entity: Term | None = None
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
# Provenance: which chunk this embedding was derived from
chunk_id: str = ""
@ -28,7 +28,7 @@ class GraphEmbeddings:
@dataclass
class ChunkEmbeddings:
chunk_id: str = ""
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
# This is a 'batching' mechanism for the above data
@dataclass
@ -44,7 +44,7 @@ class DocumentEmbeddings:
@dataclass
class ObjectEmbeddings:
metadata: Metadata | None = None
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
name: str = ""
key_name: str = ""
id: str = ""
@ -56,7 +56,7 @@ class ObjectEmbeddings:
@dataclass
class StructuredObjectEmbedding:
metadata: Metadata | None = None
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
schema_name: str = ""
object_id: str = "" # Primary key value
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
@ -72,7 +72,7 @@ class RowIndexEmbedding:
index_name: str = "" # The indexed field name(s)
index_value: list[str] = field(default_factory=list) # The field value(s)
text: str = "" # Text that was embedded
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
@dataclass
class RowEmbeddings:

View file

@ -34,7 +34,7 @@ class EmbeddingsRequest:
@dataclass
class EmbeddingsResponse:
error: Error | None = None
vectors: list[list[list[float]]] = field(default_factory=list)
vectors: list[list[float]] = field(default_factory=list)
############################################################################

View file

@ -9,15 +9,21 @@ from ..core.topic import topic
@dataclass
class GraphEmbeddingsRequest:
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
limit: int = 0
user: str = ""
collection: str = ""
@dataclass
class EntityMatch:
"""A matching entity from a semantic search with similarity score"""
entity: Term | None = None
score: float = 0.0
@dataclass
class GraphEmbeddingsResponse:
error: Error | None = None
entities: list[Term] = field(default_factory=list)
entities: list[EntityMatch] = field(default_factory=list)
############################################################################
@ -44,15 +50,21 @@ class TriplesQueryResponse:
@dataclass
class DocumentEmbeddingsRequest:
vectors: list[list[float]] = field(default_factory=list)
vector: list[float] = field(default_factory=list)
limit: int = 0
user: str = ""
collection: str = ""
@dataclass
class ChunkMatch:
"""A matching chunk from a semantic search with similarity score"""
chunk_id: str = ""
score: float = 0.0
@dataclass
class DocumentEmbeddingsResponse:
error: Error | None = None
chunk_ids: list[str] = field(default_factory=list)
chunks: list[ChunkMatch] = field(default_factory=list)
document_embeddings_request_queue = topic(
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
@ -76,7 +88,7 @@ class RowIndexMatch:
@dataclass
class RowEmbeddingsRequest:
"""Request for row embeddings semantic search"""
vectors: list[list[float]] = field(default_factory=list) # Query vectors
vector: list[float] = field(default_factory=list) # Query vector
limit: int = 10 # Max results to return
user: str = "" # User/keyspace
collection: str = "" # Collection name