mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 17:06:22 +02:00
Embeddings API scores (#671)
- Put scores in all responses
- Remove the unused 'middle' vector layer: a vector of texts now maps to a vector of embeddings (one embedding vector per text)
This commit is contained in:
parent
4fa7cc7d7c
commit
f2ae0e8623
65 changed files with 1339 additions and 1292 deletions
|
|
@ -612,12 +612,12 @@ class AsyncFlowInstance:
|
|||
print(f"{entity['name']}: {entity['score']}")
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
# First convert text to embedding vector
|
||||
emb_result = await self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request_data = {
|
||||
"vectors": vectors,
|
||||
"vector": vector,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
|
|
@ -810,12 +810,12 @@ class AsyncFlowInstance:
|
|||
print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
# First convert text to embedding vector
|
||||
emb_result = await self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request_data = {
|
||||
"vectors": vectors,
|
||||
"vector": vector,
|
||||
"schema_name": schema_name,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
|
|
|
|||
|
|
@ -282,12 +282,12 @@ class AsyncSocketFlowInstance:
|
|||
|
||||
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
|
||||
"""Query graph embeddings for semantic search"""
|
||||
# First convert text to embeddings vectors
|
||||
# First convert text to embedding vector
|
||||
emb_result = await self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
"vector": vector,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
|
|
@ -352,12 +352,12 @@ class AsyncSocketFlowInstance:
|
|||
limit: int = 10, **kwargs
|
||||
):
|
||||
"""Query row embeddings for semantic search on structured data"""
|
||||
# First convert text to embeddings vectors
|
||||
# First convert text to embedding vector
|
||||
emb_result = await self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
"vector": vector,
|
||||
"schema_name": schema_name,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
|
|
|
|||
|
|
@ -602,13 +602,13 @@ class FlowInstance:
|
|||
```
|
||||
"""
|
||||
|
||||
# First convert text to embeddings vectors
|
||||
# First convert text to embedding vector
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
# Query graph embeddings for semantic search
|
||||
input = {
|
||||
"vectors": vectors,
|
||||
"vector": vector,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
|
|
@ -648,13 +648,13 @@ class FlowInstance:
|
|||
```
|
||||
"""
|
||||
|
||||
# First convert text to embeddings vectors
|
||||
# First convert text to embedding vector
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
# Query document embeddings for semantic search
|
||||
input = {
|
||||
"vectors": vectors,
|
||||
"vector": vector,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
|
|
@ -1362,13 +1362,13 @@ class FlowInstance:
|
|||
```
|
||||
"""
|
||||
|
||||
# First convert text to embeddings vectors
|
||||
# First convert text to embedding vector
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
# Query row embeddings for semantic search
|
||||
input = {
|
||||
"vectors": vectors,
|
||||
"vector": vector,
|
||||
"schema_name": schema_name,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
|
|
|
|||
|
|
@ -649,12 +649,12 @@ class SocketFlowInstance:
|
|||
)
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
# First convert text to embedding vector
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
"vector": vector,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
|
|
@ -698,12 +698,12 @@ class SocketFlowInstance:
|
|||
# results contains {"chunk_ids": ["doc1/p0/c0", ...]}
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
# First convert text to embedding vector
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
"vector": vector,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"limit": limit
|
||||
|
|
@ -936,12 +936,12 @@ class SocketFlowInstance:
|
|||
)
|
||||
```
|
||||
"""
|
||||
# First convert text to embeddings vectors
|
||||
# First convert text to embedding vector
|
||||
emb_result = self.embeddings(texts=[text])
|
||||
vectors = emb_result.get("vectors", [[]])[0]
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
request = {
|
||||
"vectors": vectors,
|
||||
"vector": vector,
|
||||
"schema_name": schema_name,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
|
|
|
|||
|
|
@ -9,12 +9,12 @@ from .. knowledge import Uri, Literal
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DocumentEmbeddingsClient(RequestResponse):
|
||||
async def query(self, vectors, limit=20, user="trustgraph",
|
||||
async def query(self, vector, limit=20, user="trustgraph",
|
||||
collection="default", timeout=30):
|
||||
|
||||
resp = await self.request(
|
||||
DocumentEmbeddingsRequest(
|
||||
vectors = vectors,
|
||||
vector = vector,
|
||||
limit = limit,
|
||||
user = user,
|
||||
collection = collection
|
||||
|
|
@ -27,7 +27,8 @@ class DocumentEmbeddingsClient(RequestResponse):
|
|||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.chunk_ids
|
||||
# Return ChunkMatch objects with chunk_id and score
|
||||
return resp.chunks
|
||||
|
||||
class DocumentEmbeddingsClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
docs = await self.query_document_embeddings(request)
|
||||
|
||||
logger.debug("Sending document embeddings query response...")
|
||||
r = DocumentEmbeddingsResponse(chunk_ids=docs, error=None)
|
||||
r = DocumentEmbeddingsResponse(chunks=docs, error=None)
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
||||
logger.debug("Document embeddings query request completed")
|
||||
|
|
@ -73,7 +73,7 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
type = "document-embeddings-query-error",
|
||||
message = str(e),
|
||||
),
|
||||
chunk_ids=[],
|
||||
chunks=[],
|
||||
)
|
||||
|
||||
await flow("response").send(r, properties={"id": id})
|
||||
|
|
|
|||
|
|
@ -19,12 +19,12 @@ def to_value(x):
|
|||
return Literal(x.value or x.iri)
|
||||
|
||||
class GraphEmbeddingsClient(RequestResponse):
|
||||
async def query(self, vectors, limit=20, user="trustgraph",
|
||||
async def query(self, vector, limit=20, user="trustgraph",
|
||||
collection="default", timeout=30):
|
||||
|
||||
resp = await self.request(
|
||||
GraphEmbeddingsRequest(
|
||||
vectors = vectors,
|
||||
vector = vector,
|
||||
limit = limit,
|
||||
user = user,
|
||||
collection = collection
|
||||
|
|
@ -37,10 +37,8 @@ class GraphEmbeddingsClient(RequestResponse):
|
|||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return [
|
||||
to_value(v)
|
||||
for v in resp.entities
|
||||
]
|
||||
# Return EntityMatch objects with entity and score
|
||||
return resp.entities
|
||||
|
||||
class GraphEmbeddingsClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
|||
|
||||
class RowEmbeddingsQueryClient(RequestResponse):
|
||||
async def row_embeddings_query(
|
||||
self, vectors, schema_name, user="trustgraph", collection="default",
|
||||
self, vector, schema_name, user="trustgraph", collection="default",
|
||||
index_name=None, limit=10, timeout=600
|
||||
):
|
||||
request = RowEmbeddingsRequest(
|
||||
vectors=vectors,
|
||||
vector=vector,
|
||||
schema_name=schema_name,
|
||||
user=user,
|
||||
collection=collection,
|
||||
|
|
|
|||
|
|
@ -41,11 +41,11 @@ class DocumentEmbeddingsClient(BaseClient):
|
|||
)
|
||||
|
||||
def request(
|
||||
self, vectors, user="trustgraph", collection="default",
|
||||
self, vector, user="trustgraph", collection="default",
|
||||
limit=10, timeout=300
|
||||
):
|
||||
return self.call(
|
||||
user=user, collection=collection,
|
||||
vectors=vectors, limit=limit, timeout=timeout
|
||||
vector=vector, limit=limit, timeout=timeout
|
||||
).chunks
|
||||
|
||||
|
|
|
|||
|
|
@ -41,11 +41,11 @@ class GraphEmbeddingsClient(BaseClient):
|
|||
)
|
||||
|
||||
def request(
|
||||
self, vectors, user="trustgraph", collection="default",
|
||||
self, vector, user="trustgraph", collection="default",
|
||||
limit=10, timeout=300
|
||||
):
|
||||
return self.call(
|
||||
user=user, collection=collection,
|
||||
vectors=vectors, limit=limit, timeout=timeout
|
||||
vector=vector, limit=limit, timeout=timeout
|
||||
).entities
|
||||
|
||||
|
|
|
|||
|
|
@ -41,12 +41,12 @@ class RowEmbeddingsClient(BaseClient):
|
|||
)
|
||||
|
||||
def request(
|
||||
self, vectors, schema_name, user="trustgraph", collection="default",
|
||||
self, vector, schema_name, user="trustgraph", collection="default",
|
||||
index_name=None, limit=10, timeout=300
|
||||
):
|
||||
kwargs = dict(
|
||||
user=user, collection=collection,
|
||||
vectors=vectors, schema_name=schema_name,
|
||||
vector=vector, schema_name=schema_name,
|
||||
limit=limit, timeout=timeout
|
||||
)
|
||||
if index_name:
|
||||
|
|
|
|||
|
|
@ -10,18 +10,18 @@ from .primitives import ValueTranslator
|
|||
|
||||
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
||||
"""Translator for DocumentEmbeddingsRequest schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
|
||||
return DocumentEmbeddingsRequest(
|
||||
vectors=data["vectors"],
|
||||
vector=data["vector"],
|
||||
limit=int(data.get("limit", 10)),
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default")
|
||||
)
|
||||
|
||||
|
||||
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
|
||||
return {
|
||||
"vectors": obj.vectors,
|
||||
"vector": obj.vector,
|
||||
"limit": obj.limit,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection
|
||||
|
|
@ -30,18 +30,24 @@ class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
|||
|
||||
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
||||
"""Translator for DocumentEmbeddingsResponse schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
|
||||
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
if obj.chunk_ids is not None:
|
||||
result["chunk_ids"] = list(obj.chunk_ids)
|
||||
if obj.chunks is not None:
|
||||
result["chunks"] = [
|
||||
{
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"score": chunk.score
|
||||
}
|
||||
for chunk in obj.chunks
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
|
@ -49,18 +55,18 @@ class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
|||
|
||||
class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
||||
"""Translator for GraphEmbeddingsRequest schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
|
||||
return GraphEmbeddingsRequest(
|
||||
vectors=data["vectors"],
|
||||
vector=data["vector"],
|
||||
limit=int(data.get("limit", 10)),
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default")
|
||||
)
|
||||
|
||||
|
||||
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
|
||||
return {
|
||||
"vectors": obj.vectors,
|
||||
"vector": obj.vector,
|
||||
"limit": obj.limit,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection
|
||||
|
|
@ -69,24 +75,27 @@ class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
|||
|
||||
class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
||||
"""Translator for GraphEmbeddingsResponse schema objects"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.value_translator = ValueTranslator()
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
|
||||
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
|
||||
if obj.entities is not None:
|
||||
result["entities"] = [
|
||||
self.value_translator.from_pulsar(entity)
|
||||
for entity in obj.entities
|
||||
{
|
||||
"entity": self.value_translator.from_pulsar(match.entity),
|
||||
"score": match.score
|
||||
}
|
||||
for match in obj.entities
|
||||
]
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
|
@ -97,7 +106,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
|
|||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
|
||||
return RowEmbeddingsRequest(
|
||||
vectors=data["vectors"],
|
||||
vector=data["vector"],
|
||||
limit=int(data.get("limit", 10)),
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
|
|
@ -107,7 +116,7 @@ class RowEmbeddingsRequestTranslator(MessageTranslator):
|
|||
|
||||
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
|
||||
result = {
|
||||
"vectors": obj.vectors,
|
||||
"vector": obj.vector,
|
||||
"limit": obj.limit,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from ..core.topic import topic
|
|||
@dataclass
|
||||
class EntityEmbeddings:
|
||||
entity: Term | None = None
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
vector: list[float] = field(default_factory=list)
|
||||
# Provenance: which chunk this embedding was derived from
|
||||
chunk_id: str = ""
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ class GraphEmbeddings:
|
|||
@dataclass
|
||||
class ChunkEmbeddings:
|
||||
chunk_id: str = ""
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
vector: list[float] = field(default_factory=list)
|
||||
|
||||
# This is a 'batching' mechanism for the above data
|
||||
@dataclass
|
||||
|
|
@ -44,7 +44,7 @@ class DocumentEmbeddings:
|
|||
@dataclass
|
||||
class ObjectEmbeddings:
|
||||
metadata: Metadata | None = None
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
vector: list[float] = field(default_factory=list)
|
||||
name: str = ""
|
||||
key_name: str = ""
|
||||
id: str = ""
|
||||
|
|
@ -56,7 +56,7 @@ class ObjectEmbeddings:
|
|||
@dataclass
|
||||
class StructuredObjectEmbedding:
|
||||
metadata: Metadata | None = None
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
vector: list[float] = field(default_factory=list)
|
||||
schema_name: str = ""
|
||||
object_id: str = "" # Primary key value
|
||||
field_embeddings: dict[str, list[float]] = field(default_factory=dict) # Per-field embeddings
|
||||
|
|
@ -72,7 +72,7 @@ class RowIndexEmbedding:
|
|||
index_name: str = "" # The indexed field name(s)
|
||||
index_value: list[str] = field(default_factory=list) # The field value(s)
|
||||
text: str = "" # Text that was embedded
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
vector: list[float] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class RowEmbeddings:
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class EmbeddingsRequest:
|
|||
@dataclass
|
||||
class EmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
vectors: list[list[list[float]]] = field(default_factory=list)
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
|
|||
|
|
@ -9,15 +9,21 @@ from ..core.topic import topic
|
|||
|
||||
@dataclass
|
||||
class GraphEmbeddingsRequest:
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
vector: list[float] = field(default_factory=list)
|
||||
limit: int = 0
|
||||
user: str = ""
|
||||
collection: str = ""
|
||||
|
||||
@dataclass
|
||||
class EntityMatch:
|
||||
"""A matching entity from a semantic search with similarity score"""
|
||||
entity: Term | None = None
|
||||
score: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class GraphEmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
entities: list[Term] = field(default_factory=list)
|
||||
entities: list[EntityMatch] = field(default_factory=list)
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
@ -44,15 +50,21 @@ class TriplesQueryResponse:
|
|||
|
||||
@dataclass
|
||||
class DocumentEmbeddingsRequest:
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
vector: list[float] = field(default_factory=list)
|
||||
limit: int = 0
|
||||
user: str = ""
|
||||
collection: str = ""
|
||||
|
||||
@dataclass
|
||||
class ChunkMatch:
|
||||
"""A matching chunk from a semantic search with similarity score"""
|
||||
chunk_id: str = ""
|
||||
score: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class DocumentEmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
chunk_ids: list[str] = field(default_factory=list)
|
||||
chunks: list[ChunkMatch] = field(default_factory=list)
|
||||
|
||||
document_embeddings_request_queue = topic(
|
||||
"document-embeddings-request", qos='q0', tenant='trustgraph', namespace='flow'
|
||||
|
|
@ -76,7 +88,7 @@ class RowIndexMatch:
|
|||
@dataclass
|
||||
class RowEmbeddingsRequest:
|
||||
"""Request for row embeddings semantic search"""
|
||||
vectors: list[list[float]] = field(default_factory=list) # Query vectors
|
||||
vector: list[float] = field(default_factory=list) # Query vector
|
||||
limit: int = 10 # Max results to return
|
||||
user: str = "" # User/keyspace
|
||||
collection: str = "" # Collection name
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue