mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Row embeddings APIs exposed (#646)
* Added row embeddings API and CLI support * Updated protocol specs * Row embeddings agent tool * Add new agent tool to CLI
This commit is contained in:
parent
1809c1f56d
commit
4bbc6d844f
25 changed files with 1090 additions and 29 deletions
|
|
@ -766,3 +766,63 @@ class AsyncFlowInstance:
|
|||
request_data.update(kwargs)
|
||||
|
||||
return await self.request("rows", request_data)
|
||||
|
||||
async def row_embeddings_query(
    self, text: str, schema_name: str, user: str = "trustgraph",
    collection: str = "default", index_name: Optional[str] = None,
    limit: int = 10, **kwargs: Any
):
    """
    Run a semantic search over row embeddings of structured data.

    The query text is first converted to embedding vectors through
    ``self.embeddings``; those vectors are then sent to the
    "row-embeddings" service together with the search scope, so rows
    whose indexed field values resemble the text can be matched.

    Args:
        text: Query text to embed and search with
        schema_name: Name of the schema to search within
        user: User identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name; only sent when truthy
        limit: Maximum number of results to return (default: 10)
        **kwargs: Extra request fields; merged last, so they override
            fields built from the named parameters

    Returns:
        dict: Service response; matches carry index_name, index_value,
        text, and score

    Example:
        ```python
        async_flow = await api.async_flow()
        flow = async_flow.id("default")

        # Search for customers by name similarity
        results = await flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )

        for match in results.get("matches", []):
            print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
        ```
    """
    # The service takes vectors rather than raw text, so embed first.
    embedded = await self.embeddings(text=text)

    payload = {
        "vectors": embedded.get("vectors", []),
        "schema_name": schema_name,
        "user": user,
        "collection": collection,
        "limit": limit,
    }
    if index_name:
        payload["index_name"] = index_name
    payload.update(kwargs)

    return await self.request("row-embeddings", payload)
|
||||
|
|
|
|||
|
|
@ -345,3 +345,26 @@ class AsyncSocketFlowInstance:
|
|||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request("mcp-tool", self.flow_id, request)
|
||||
|
||||
async def row_embeddings_query(
    self, text: str, schema_name: str, user: str = "trustgraph",
    collection: str = "default", index_name: Optional[str] = None,
    limit: int = 10, **kwargs
):
    """
    Query row embeddings for semantic search on structured data.

    Embeds ``text`` via ``self.embeddings`` and forwards the resulting
    vectors to the "row-embeddings" service for this flow.

    Args:
        text: Query text to embed and search with
        schema_name: Name of the schema to search within
        user: User identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name; only sent when truthy
        limit: Maximum number of results (default: 10)
        **kwargs: Extra request fields; merged last, so they override
            fields built from the named parameters

    Returns:
        Whatever ``self.client._send_request`` returns for the request.
    """
    # The service takes vectors rather than raw text, so embed first.
    embedded = await self.embeddings(text=text)

    optional_fields = {"index_name": index_name} if index_name else {}
    payload = {
        "vectors": embedded.get("vectors", []),
        "schema_name": schema_name,
        "user": user,
        "collection": collection,
        "limit": limit,
        **optional_fields,
        **kwargs,
    }

    return await self.client._send_request(
        "row-embeddings", self.flow_id, payload
    )
|
||||
|
|
|
|||
|
|
@ -1297,3 +1297,78 @@ class FlowInstance:
|
|||
|
||||
return response["schema-matches"]
|
||||
|
||||
def row_embeddings_query(
    self, text, schema_name, user="trustgraph", collection="default",
    index_name=None, limit=10
):
    """
    Query row data using semantic similarity on indexed fields.

    Finds rows whose indexed field values are semantically similar to the
    input text, using vector embeddings. This enables fuzzy/semantic matching
    on structured data. The text is embedded through ``self.embeddings`` and
    the vectors are posted to the "service/row-embeddings" endpoint.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User/keyspace identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name to filter search to specific index;
            only sent when truthy
        limit: Maximum number of results (default: 10)

    Returns:
        dict: Query results with matches containing index_name, index_value,
        text, and score

    Raises:
        ProtocolException: If the response carries a non-empty "error"
            entry; the message is "<type>: <message>".

    Example:
        ```python
        flow = api.flow().id("default")

        # Search for customers by name similarity
        results = flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )

        # Filter to specific index
        results = flow.row_embeddings_query(
            text="machine learning engineer",
            schema_name="employees",
            index_name="job_title",
            limit=10
        )
        ```
    """

    # First convert text to embeddings vectors
    emb_result = self.embeddings(text=text)
    vectors = emb_result.get("vectors", [])

    # Named request_data (not `input`) so the builtin is not shadowed.
    request_data = {
        "vectors": vectors,
        "schema_name": schema_name,
        "user": user,
        "collection": collection,
        "limit": limit
    }

    if index_name:
        request_data["index_name"] = index_name

    response = self.request(
        "service/row-embeddings",
        request_data
    )

    # Check for system-level error (absent, None or empty means success)
    error = response.get("error")
    if error:
        error_type = error.get("type", "unknown")
        error_message = error.get("message", "Unknown error")
        raise ProtocolException(f"{error_type}: {error_message}")

    return response
|
||||
|
||||
|
|
|
|||
|
|
@ -881,3 +881,73 @@ class SocketFlowInstance:
|
|||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
|
||||
|
||||
def row_embeddings_query(
    self,
    text: str,
    schema_name: str,
    user: str = "trustgraph",
    collection: str = "default",
    index_name: Optional[str] = None,
    limit: int = 10,
    **kwargs: Any
) -> Dict[str, Any]:
    """
    Search structured rows by semantic similarity to a piece of text.

    The text is embedded via ``self.embeddings`` and the vectors are sent
    synchronously to the "row-embeddings" service of this flow, which
    matches them against indexed field values for fuzzy/semantic lookup.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User/keyspace identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name restricting the search to one
            index; only sent when truthy
        limit: Maximum number of results (default: 10)
        **kwargs: Additional parameters passed to the service; merged
            last, so they override fields built from named parameters

    Returns:
        dict: Query results with matches containing index_name,
        index_value, text, and score

    Example:
        ```python
        socket = api.socket()
        flow = socket.flow("default")

        # Search for customers by name similarity
        results = flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )

        # Filter to specific index
        results = flow.row_embeddings_query(
            text="machine learning engineer",
            schema_name="employees",
            index_name="job_title",
            limit=10
        )
        ```
    """
    # The service takes vectors rather than raw text, so embed first.
    embedded = self.embeddings(text=text)

    payload = dict(
        vectors=embedded.get("vectors", []),
        schema_name=schema_name,
        user=user,
        collection=collection,
        limit=limit,
    )
    if index_name:
        payload["index_name"] = index_name
    payload.update(kwargs)

    return self.client._send_request_sync(
        "row-embeddings", self.flow_id, payload, False
    )
|
||||
|
|
|
|||
|
|
@ -34,5 +34,6 @@ from . tool_service import ToolService
|
|||
from . tool_client import ToolClientSpec
|
||||
from . agent_client import AgentClientSpec
|
||||
from . structured_query_client import StructuredQueryClientSpec
|
||||
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
|
||||
from . collection_config_handler import CollectionConfigHandler
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||
|
||||
class RowEmbeddingsQueryClient(RequestResponse):
    """Request/response client for row embeddings queries."""

    async def row_embeddings_query(
        self, vectors, schema_name, user="trustgraph", collection="default",
        index_name=None, limit=10, timeout=600
    ):
        """
        Send a RowEmbeddingsRequest and return its matches as dicts.

        Args:
            vectors: Query vectors placed in the request as-is
            schema_name: Schema name to search within
            user: User identifier (default: "trustgraph")
            collection: Collection identifier (default: "default")
            index_name: Optional index name; only set on the request
                when truthy
            limit: Maximum number of results (default: 10)
            timeout: Timeout passed to ``self.request`` (default: 600)

        Returns:
            list: One dict per match with the keys index_name,
            index_value, text and score; empty when the response has
            no matches.

        Raises:
            RuntimeError: If the response carries an error; the
                exception message is the error's message.
        """
        query = RowEmbeddingsRequest(
            vectors=vectors,
            schema_name=schema_name,
            user=user,
            collection=collection,
            limit=limit,
        )
        if index_name:
            query.index_name = index_name

        resp = await self.request(query, timeout=timeout)

        if resp.error:
            raise RuntimeError(resp.error.message)

        # Flatten the schema match objects into plain dicts
        matches = []
        for found in resp.matches or []:
            matches.append({
                "index_name": found.index_name,
                "index_value": found.index_value,
                "text": found.text,
                "score": found.score,
            })
        return matches
|
||||
|
||||
class RowEmbeddingsQueryClientSpec(RequestResponseSpec):
    """
    Request/response spec binding the row embeddings schemas to
    RowEmbeddingsQueryClient.
    """

    def __init__(
        self, request_name, response_name,
    ):
        """
        Args:
            request_name: Passed to the base spec as ``request_name``
            response_name: Passed to the base spec as ``response_name``
        """
        # Schemas and implementation are fixed; only the names vary.
        super().__init__(
            request_name=request_name,
            request_schema=RowEmbeddingsRequest,
            response_name=response_name,
            response_schema=RowEmbeddingsResponse,
            impl=RowEmbeddingsQueryClient,
        )
|
||||
60
trustgraph-base/trustgraph/clients/row_embeddings_client.py
Normal file
60
trustgraph-base/trustgraph/clients/row_embeddings_client.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||
from .. schema import row_embeddings_request_queue
|
||||
from .. schema import row_embeddings_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
# Ugly: module-level aliases for the _pulsar logger levels
ERROR = _pulsar.LoggerLevel.Error
WARN = _pulsar.LoggerLevel.Warn
INFO = _pulsar.LoggerLevel.Info
DEBUG = _pulsar.LoggerLevel.Debug
|
||||
|
||||
class RowEmbeddingsClient(BaseClient):
    """Client for the row embeddings query service, built on BaseClient."""

    def __init__(
            self, log_level=ERROR,
            subscriber=None,
            input_queue=None,
            output_queue=None,
            pulsar_host="pulsar://pulsar:6650",
            pulsar_api_key=None,
    ):
        """
        Args:
            log_level: Logger level passed to BaseClient (default: ERROR)
            subscriber: Passed through to BaseClient
            input_queue: Request queue; defaults to
                row_embeddings_request_queue when None
            output_queue: Response queue; defaults to
                row_embeddings_response_queue when None
            pulsar_host: Passed through to BaseClient
                (default: "pulsar://pulsar:6650")
            pulsar_api_key: Passed through to BaseClient
        """

        # Identity comparison is the correct test for the None singleton
        if input_queue is None:
            input_queue = row_embeddings_request_queue

        if output_queue is None:
            output_queue = row_embeddings_response_queue

        super(RowEmbeddingsClient, self).__init__(
            log_level=log_level,
            subscriber=subscriber,
            input_queue=input_queue,
            output_queue=output_queue,
            pulsar_host=pulsar_host,
            pulsar_api_key=pulsar_api_key,
            input_schema=RowEmbeddingsRequest,
            output_schema=RowEmbeddingsResponse,
        )

    def request(
            self, vectors, schema_name, user="trustgraph", collection="default",
            index_name=None, limit=10, timeout=300
    ):
        """
        Issue a row embeddings query and return the response matches.

        Args:
            vectors: Query vectors passed to ``self.call`` as-is
            schema_name: Schema name to search within
            user: User identifier (default: "trustgraph")
            collection: Collection identifier (default: "default")
            index_name: Optional index name; only passed when truthy
            limit: Maximum number of results (default: 10)
            timeout: Timeout passed to ``self.call`` (default: 300)

        Returns:
            The ``matches`` attribute of the response.

        Raises:
            RuntimeError: If the response carries an error; the message
                is "<type>: <message>".
        """

        kwargs = dict(
            user=user, collection=collection,
            vectors=vectors, schema_name=schema_name,
            limit=limit, timeout=timeout
        )

        if index_name:
            kwargs["index_name"] = index_name

        response = self.call(**kwargs)

        if response.error:
            raise RuntimeError(f"{response.error.type}: {response.error.message}")

        return response.matches
|
||||
|
|
@ -19,7 +19,8 @@ from .translators.prompt import PromptRequestTranslator, PromptResponseTranslato
|
|||
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
|
||||
from .translators.embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
|
||||
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
|
||||
)
|
||||
from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
||||
|
|
@ -107,11 +108,17 @@ TranslatorRegistry.register_service(
|
|||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"graph-embeddings-query",
|
||||
GraphEmbeddingsRequestTranslator(),
|
||||
"graph-embeddings-query",
|
||||
GraphEmbeddingsRequestTranslator(),
|
||||
GraphEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"row-embeddings-query",
|
||||
RowEmbeddingsRequestTranslator(),
|
||||
RowEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"rows-query",
|
||||
RowsQueryRequestTranslator(),
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ from .flow import FlowRequestTranslator, FlowResponseTranslator
|
|||
from .prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||
from .embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
|
||||
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
|
||||
)
|
||||
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import (
|
||||
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
|
||||
GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
GraphEmbeddingsRequest, GraphEmbeddingsResponse,
|
||||
RowEmbeddingsRequest, RowEmbeddingsResponse, RowIndexMatch
|
||||
)
|
||||
from .base import MessageTranslator
|
||||
from .primitives import ValueTranslator
|
||||
|
|
@ -92,3 +93,62 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
|||
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
||||
|
||||
class RowEmbeddingsRequestTranslator(MessageTranslator):
    """Convert row embeddings requests between dicts and schema objects."""

    def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
        """
        Build a RowEmbeddingsRequest from a request dict.

        "vectors" is required (a missing key raises KeyError); the other
        fields fall back to defaults, and "limit" is coerced with int().
        """
        limit = int(data.get("limit", 10))
        return RowEmbeddingsRequest(
            vectors=data["vectors"],
            limit=limit,
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
            schema_name=data.get("schema_name", ""),
            index_name=data.get("index_name"),
        )

    def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
        """
        Turn a RowEmbeddingsRequest into a plain dict.

        "index_name" is only included when the request has a truthy one.
        """
        fields: Dict[str, Any] = dict(
            vectors=obj.vectors,
            limit=obj.limit,
            user=obj.user,
            collection=obj.collection,
            schema_name=obj.schema_name,
        )
        if obj.index_name:
            fields["index_name"] = obj.index_name
        return fields
|
||||
|
||||
|
||||
class RowEmbeddingsResponseTranslator(MessageTranslator):
    """Convert RowEmbeddingsResponse schema objects into plain dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsResponse:
        """Unsupported direction; always raises NotImplementedError."""
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: RowEmbeddingsResponse) -> Dict[str, Any]:
        """
        Turn a RowEmbeddingsResponse into a plain dict.

        The result has an "error" entry (type and message) when the
        response error is not None, and a "matches" list when the
        response matches are not None; either key is omitted otherwise.
        """
        translated: Dict[str, Any] = {}

        error = obj.error
        if error is not None:
            translated["error"] = {
                "type": error.type,
                "message": error.message,
            }

        matches = obj.matches
        if matches is not None:
            rows = []
            for found in matches:
                rows.append({
                    "index_name": found.index_name,
                    "index_value": found.index_value,
                    "text": found.text,
                    "score": found.score,
                })
            translated["matches"] = rows

        return translated

    def from_response_with_completion(self, obj: RowEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); is_final is always True."""
        response_dict = self.from_pulsar(obj)
        return response_dict, True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue