mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 17:39:39 +02:00
Embeddings API scores (#671)
- Put scores in all responses - Remove unused 'middle' vector layer. Vector of texts -> vector of (vector embedding)
This commit is contained in:
parent
4fa7cc7d7c
commit
f2ae0e8623
65 changed files with 1339 additions and 1292 deletions
|
|
@ -10,18 +10,18 @@ from .primitives import ValueTranslator
|
|||
|
||||
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
||||
"""Translator for DocumentEmbeddingsRequest schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
|
||||
return DocumentEmbeddingsRequest(
|
||||
vectors=data["vectors"],
|
||||
vector=data["vector"],
|
||||
limit=int(data.get("limit", 10)),
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default")
|
||||
)
|
||||
|
||||
|
||||
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
|
||||
return {
|
||||
"vectors": obj.vectors,
|
||||
"vector": obj.vector,
|
||||
"limit": obj.limit,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection
|
||||
|
|
@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
|||
|
||||
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
||||
"""Translator for DocumentEmbeddingsResponse schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
|
||||
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
if obj.chunk_ids is not None:
|
||||
result["chunk_ids"] = list(obj.chunk_ids)
|
||||
if obj.chunks is not None:
|
||||
result["chunks"] = [
|
||||
{
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"score": chunk.score
|
||||
}
|
||||
for chunk in obj.chunks
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
|
@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
|||
|
||||
class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
||||
"""Translator for GraphEmbeddingsRequest schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
|
||||
return GraphEmbeddingsRequest(
|
||||
vectors=data["vectors"],
|
||||
vector=data["vector"],
|
||||
limit=int(data.get("limit", 10)),
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default")
|
||||
)
|
||||
|
||||
|
||||
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
|
||||
return {
|
||||
"vectors": obj.vectors,
|
||||
"vector": obj.vector,
|
||||
"limit": obj.limit,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection
|
||||
|
|
@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
|||
|
||||
class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
||||
"""Translator for GraphEmbeddingsResponse schema objects"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.value_translator = ValueTranslator()
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
|
||||
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
|
||||
if obj.entities is not None:
|
||||
result["entities"] = [
|
||||
self.value_translator.from_pulsar(entity)
|
||||
for entity in obj.entities
|
||||
{
|
||||
"entity": self.value_translator.from_pulsar(match.entity),
|
||||
"score": match.score
|
||||
}
|
||||
for match in obj.entities
|
||||
]
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
|
@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
|
|||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
|
||||
return RowEmbeddingsRequest(
|
||||
vectors=data["vectors"],
|
||||
vector=data["vector"],
|
||||
limit=int(data.get("limit", 10)),
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
|
|
@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
|
|||
|
||||
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
|
||||
result = {
|
||||
"vectors": obj.vectors,
|
||||
"vector": obj.vector,
|
||||
"limit": obj.limit,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue