mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 17:39:39 +02:00
Row embeddings APIs exposed (#646)
* Added row embeddings API and CLI support * Updated protocol specs * Row embeddings agent tool * Add new agent tool to CLI
This commit is contained in:
parent
1809c1f56d
commit
4bbc6d844f
25 changed files with 1090 additions and 29 deletions
|
|
@ -19,7 +19,8 @@ from .translators.prompt import PromptRequestTranslator, PromptResponseTranslato
|
|||
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
|
||||
from .translators.embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
|
||||
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
|
||||
)
|
||||
from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
||||
|
|
@ -107,11 +108,17 @@ TranslatorRegistry.register_service(
|
|||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"graph-embeddings-query",
|
||||
GraphEmbeddingsRequestTranslator(),
|
||||
"graph-embeddings-query",
|
||||
GraphEmbeddingsRequestTranslator(),
|
||||
GraphEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"row-embeddings-query",
|
||||
RowEmbeddingsRequestTranslator(),
|
||||
RowEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"rows-query",
|
||||
RowsQueryRequestTranslator(),
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ from .flow import FlowRequestTranslator, FlowResponseTranslator
|
|||
from .prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||
from .embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
|
||||
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
|
||||
)
|
||||
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import (
|
||||
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
|
||||
GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
GraphEmbeddingsRequest, GraphEmbeddingsResponse,
|
||||
RowEmbeddingsRequest, RowEmbeddingsResponse, RowIndexMatch
|
||||
)
|
||||
from .base import MessageTranslator
|
||||
from .primitives import ValueTranslator
|
||||
|
|
@ -92,3 +93,62 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
|||
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
||||
|
||||
class RowEmbeddingsRequestTranslator(MessageTranslator):
|
||||
"""Translator for RowEmbeddingsRequest schema objects"""
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
|
||||
return RowEmbeddingsRequest(
|
||||
vectors=data["vectors"],
|
||||
limit=int(data.get("limit", 10)),
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
schema_name=data.get("schema_name", ""),
|
||||
index_name=data.get("index_name")
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
|
||||
result = {
|
||||
"vectors": obj.vectors,
|
||||
"limit": obj.limit,
|
||||
"user": obj.user,
|
||||
"collection": obj.collection,
|
||||
"schema_name": obj.schema_name,
|
||||
}
|
||||
if obj.index_name:
|
||||
result["index_name"] = obj.index_name
|
||||
return result
|
||||
|
||||
|
||||
class RowEmbeddingsResponseTranslator(MessageTranslator):
|
||||
"""Translator for RowEmbeddingsResponse schema objects"""
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
def from_pulsar(self, obj: RowEmbeddingsResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
if obj.error is not None:
|
||||
result["error"] = {
|
||||
"type": obj.error.type,
|
||||
"message": obj.error.message
|
||||
}
|
||||
|
||||
if obj.matches is not None:
|
||||
result["matches"] = [
|
||||
{
|
||||
"index_name": match.index_name,
|
||||
"index_value": match.index_value,
|
||||
"text": match.text,
|
||||
"score": match.score
|
||||
}
|
||||
for match in obj.matches
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
def from_response_with_completion(self, obj: RowEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue