Row embeddings APIs exposed (#646)

* Added row embeddings API and CLI support

* Updated protocol specs

* Row embeddings agent tool

* Add new agent tool to CLI
This commit is contained in:
cybermaggedon 2026-02-23 21:52:56 +00:00 committed by GitHub
parent 1809c1f56d
commit 4bbc6d844f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1090 additions and 29 deletions

View file

@ -766,3 +766,63 @@ class AsyncFlowInstance:
request_data.update(kwargs)
return await self.request("rows", request_data)
async def row_embeddings_query(
    self, text: str, schema_name: str, user: str = "trustgraph",
    collection: str = "default", index_name: Optional[str] = None,
    limit: int = 10, **kwargs: Any
):
    """
    Run a semantic search over row index embeddings.

    The query text is first turned into embedding vectors through this
    flow's ``embeddings`` call; those vectors are then submitted to the
    "row-embeddings" service together with the search scope, so that rows
    whose indexed field values resemble the text can be located
    (fuzzy/semantic matching on structured data).

    Args:
        text: Query text for the semantic search.
        schema_name: Name of the schema to search within.
        user: User identifier (default: "trustgraph").
        collection: Collection identifier (default: "default").
        index_name: Optional index name; when given (and non-empty) it is
            included in the request to restrict the search to that index.
        limit: Maximum number of results to return (default: 10).
        **kwargs: Additional service-specific parameters. They are merged
            into the request last, so they override the fields above.

    Returns:
        dict: The service response; matches carry index_name, index_value,
        text and score.

    Example:
        ```python
        async_flow = await api.async_flow()
        flow = async_flow.id("default")

        results = await flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )
        for match in results.get("matches", []):
            print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
        ```
    """
    # Step 1: embed the query text.
    embedded = await self.embeddings(text=text)

    # Step 2: assemble the request; index_name is only sent when provided.
    payload = {
        "vectors": embedded.get("vectors", []),
        "schema_name": schema_name,
        "user": user,
        "collection": collection,
        "limit": limit,
        **({"index_name": index_name} if index_name else {}),
        **kwargs,
    }

    return await self.request("row-embeddings", payload)

View file

@ -345,3 +345,26 @@ class AsyncSocketFlowInstance:
request.update(kwargs)
return await self.client._send_request("mcp-tool", self.flow_id, request)
async def row_embeddings_query(
    self, text: str, schema_name: str, user: str = "trustgraph",
    collection: str = "default", index_name: Optional[str] = None,
    limit: int = 10, **kwargs
):
    """
    Query row embeddings for semantic search on structured data.

    The text is embedded via ``self.embeddings`` and the resulting vectors
    are sent to the "row-embeddings" service for this flow.

    Args:
        text: Query text for the semantic search.
        schema_name: Name of the schema to search within.
        user: User identifier (default: "trustgraph").
        collection: Collection identifier (default: "default").
        index_name: Optional index name; only included in the request when
            it is non-empty.
        limit: Maximum number of results (default: 10).
        **kwargs: Extra request fields, merged last (they override the
            fields built from the named parameters).

    Returns:
        Whatever ``self.client._send_request`` yields for the request.
    """
    # The service works on vectors, so embed the text first.
    embedded = await self.embeddings(text=text)

    payload = dict(
        vectors=embedded.get("vectors", []),
        schema_name=schema_name,
        user=user,
        collection=collection,
        limit=limit,
    )
    if index_name:
        payload["index_name"] = index_name
    payload.update(kwargs)

    return await self.client._send_request(
        "row-embeddings", self.flow_id, payload
    )

View file

@ -1297,3 +1297,78 @@ class FlowInstance:
return response["schema-matches"]
def row_embeddings_query(
    self, text, schema_name, user="trustgraph", collection="default",
    index_name=None, limit=10
):
    """
    Query row data using semantic similarity on indexed fields.

    Finds rows whose indexed field values are semantically similar to the
    input text, using vector embeddings. This enables fuzzy/semantic matching
    on structured data.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User/keyspace identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name to filter search to specific index;
            only sent when non-empty
        limit: Maximum number of results (default: 10)

    Returns:
        dict: Query results with matches containing index_name, index_value,
        text, and score

    Raises:
        ProtocolException: If the response carries a non-empty "error"
            entry; the message is "<type>: <message>".

    Example:
        ```python
        flow = api.flow().id("default")

        # Search for customers by name similarity
        results = flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )

        # Filter to specific index
        results = flow.row_embeddings_query(
            text="machine learning engineer",
            schema_name="employees",
            index_name="job_title",
            limit=10
        )
        ```
    """
    # First convert text to embeddings vectors
    emb_result = self.embeddings(text=text)
    vectors = emb_result.get("vectors", [])

    # Build the row embeddings query. Named request_data (not `input`) so
    # the builtin of that name is not shadowed.
    request_data = {
        "vectors": vectors,
        "schema_name": schema_name,
        "user": user,
        "collection": collection,
        "limit": limit
    }
    if index_name:
        request_data["index_name"] = index_name

    response = self.request("service/row-embeddings", request_data)

    # Check for system-level error (a missing or falsy entry means success)
    error = response.get("error")
    if error:
        error_type = error.get("type", "unknown")
        error_message = error.get("message", "Unknown error")
        raise ProtocolException(f"{error_type}: {error_message}")

    return response

View file

@ -881,3 +881,73 @@ class SocketFlowInstance:
request.update(kwargs)
return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
def row_embeddings_query(
    self,
    text: str,
    schema_name: str,
    user: str = "trustgraph",
    collection: str = "default",
    index_name: Optional[str] = None,
    limit: int = 10,
    **kwargs: Any
) -> Dict[str, Any]:
    """
    Semantic search over indexed row fields via the socket client.

    The text is embedded with ``self.embeddings`` and the vectors are sent
    to the "row-embeddings" service of this flow, which looks up rows whose
    indexed field values are semantically close to the text.

    Args:
        text: Query text for the semantic search.
        schema_name: Name of the schema to search within.
        user: User/keyspace identifier (default: "trustgraph").
        collection: Collection identifier (default: "default").
        index_name: Optional index name restricting the search to one
            index; omitted from the request when empty.
        limit: Maximum number of results (default: 10).
        **kwargs: Additional parameters passed to the service; merged last,
            so they take precedence over the fields above.

    Returns:
        dict: Query results with matches containing index_name,
        index_value, text, and score.

    Example:
        ```python
        socket = api.socket()
        flow = socket.flow("default")

        results = flow.row_embeddings_query(
            text="machine learning engineer",
            schema_name="employees",
            index_name="job_title",
            limit=10
        )
        ```
    """
    # Embed the query text before contacting the row embeddings service.
    embedded = self.embeddings(text=text)

    payload = {
        "vectors": embedded.get("vectors", []),
        "schema_name": schema_name,
        "user": user,
        "collection": collection,
        "limit": limit,
        **({"index_name": index_name} if index_name else {}),
        **kwargs,
    }

    return self.client._send_request_sync(
        "row-embeddings", self.flow_id, payload, False
    )

View file

@ -34,5 +34,6 @@ from . tool_service import ToolService
from . tool_client import ToolClientSpec
from . agent_client import AgentClientSpec
from . structured_query_client import StructuredQueryClientSpec
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
from . collection_config_handler import CollectionConfigHandler

View file

@ -0,0 +1,45 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
class RowEmbeddingsQueryClient(RequestResponse):
    """Request/response client for row embeddings queries."""

    async def row_embeddings_query(
            self, vectors, schema_name, user="trustgraph", collection="default",
            index_name=None, limit=10, timeout=600
    ):
        """
        Send a RowEmbeddingsRequest and return its matches as plain dicts.

        Args:
            vectors: Query embedding vectors.
            schema_name: Schema name to search within.
            user: User identifier (default: "trustgraph").
            collection: Collection identifier (default: "default").
            index_name: Optional index name; only set on the request when
                non-empty.
            limit: Maximum number of results (default: 10).
            timeout: Timeout passed through to ``self.request``
                (default: 600).

        Returns:
            list[dict]: One dict per match with keys index_name,
            index_value, text and score; empty when the response has no
            matches.

        Raises:
            RuntimeError: If the response carries an error; the exception
                message is the error's message.
        """
        query = RowEmbeddingsRequest(
            vectors=vectors,
            schema_name=schema_name,
            user=user,
            collection=collection,
            limit=limit,
        )
        # index_name is assigned after construction, and only when provided.
        if index_name:
            query.index_name = index_name

        reply = await self.request(query, timeout=timeout)
        if reply.error:
            raise RuntimeError(reply.error.message)

        # Flatten the schema match objects into plain dictionaries.
        found = []
        for match in reply.matches or []:
            found.append({
                "index_name": match.index_name,
                "index_value": match.index_value,
                "text": match.text,
                "score": match.score,
            })
        return found
class RowEmbeddingsQueryClientSpec(RequestResponseSpec):
    """
    Request/response spec that wires the row embeddings schemas to
    RowEmbeddingsQueryClient.
    """

    def __init__(self, request_name, response_name):
        """
        Args:
            request_name: Name passed to the base spec as ``request_name``.
            response_name: Name passed to the base spec as ``response_name``.
        """
        spec_args = dict(
            request_name=request_name,
            request_schema=RowEmbeddingsRequest,
            response_name=response_name,
            response_schema=RowEmbeddingsResponse,
            impl=RowEmbeddingsQueryClient,
        )
        super().__init__(**spec_args)

View file

@ -0,0 +1,60 @@
import _pulsar
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
from .. schema import row_embeddings_request_queue
from .. schema import row_embeddings_response_queue
from . base import BaseClient
# Short module-level aliases for the Pulsar client's logger levels
# (the original author flagged this as ugly).
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class RowEmbeddingsClient(BaseClient):
    """Client for the row embeddings query service, built on BaseClient."""

    def __init__(
            self, log_level=ERROR,
            subscriber=None,
            input_queue=None,
            output_queue=None,
            pulsar_host="pulsar://pulsar:6650",
            pulsar_api_key=None,
    ):
        """
        Args:
            log_level: Logger level handed to the base client
                (default: ERROR).
            subscriber: Subscriber value handed to the base client.
            input_queue: Request queue; defaults to
                row_embeddings_request_queue when None.
            output_queue: Response queue; defaults to
                row_embeddings_response_queue when None.
            pulsar_host: Pulsar URL (default: "pulsar://pulsar:6650").
            pulsar_api_key: Optional Pulsar API key.
        """
        # Identity comparison: `== None` can be hijacked by a custom __eq__.
        if input_queue is None:
            input_queue = row_embeddings_request_queue

        if output_queue is None:
            output_queue = row_embeddings_response_queue

        super().__init__(
            log_level=log_level,
            subscriber=subscriber,
            input_queue=input_queue,
            output_queue=output_queue,
            pulsar_host=pulsar_host,
            pulsar_api_key=pulsar_api_key,
            input_schema=RowEmbeddingsRequest,
            output_schema=RowEmbeddingsResponse,
        )

    def request(
            self, vectors, schema_name, user="trustgraph", collection="default",
            index_name=None, limit=10, timeout=300
    ):
        """
        Issue a row embeddings query and return the matches.

        Args:
            vectors: Query embedding vectors.
            schema_name: Schema name to search within.
            user: User identifier (default: "trustgraph").
            collection: Collection identifier (default: "default").
            index_name: Optional index name; only forwarded when non-empty.
            limit: Maximum number of results (default: 10).
            timeout: Forwarded to ``self.call`` together with the request
                fields (default: 300).

        Returns:
            The ``matches`` attribute of the response.

        Raises:
            RuntimeError: If the response carries an error; the message is
                "<type>: <message>".
        """
        kwargs = dict(
            user=user, collection=collection,
            vectors=vectors, schema_name=schema_name,
            limit=limit, timeout=timeout
        )
        if index_name:
            kwargs["index_name"] = index_name

        response = self.call(**kwargs)

        if response.error:
            raise RuntimeError(f"{response.error.type}: {response.error.message}")

        return response.matches

View file

@ -19,7 +19,8 @@ from .translators.prompt import PromptRequestTranslator, PromptResponseTranslato
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
from .translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
)
from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
@ -107,11 +108,17 @@ TranslatorRegistry.register_service(
)
TranslatorRegistry.register_service(
"graph-embeddings-query",
GraphEmbeddingsRequestTranslator(),
"graph-embeddings-query",
GraphEmbeddingsRequestTranslator(),
GraphEmbeddingsResponseTranslator()
)
TranslatorRegistry.register_service(
"row-embeddings-query",
RowEmbeddingsRequestTranslator(),
RowEmbeddingsResponseTranslator()
)
TranslatorRegistry.register_service(
"rows-query",
RowsQueryRequestTranslator(),

View file

@ -15,7 +15,8 @@ from .flow import FlowRequestTranslator, FlowResponseTranslator
from .prompt import PromptRequestTranslator, PromptResponseTranslator
from .embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
)
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator

View file

@ -1,7 +1,8 @@
from typing import Dict, Any, Tuple
from ...schema import (
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
GraphEmbeddingsRequest, GraphEmbeddingsResponse
GraphEmbeddingsRequest, GraphEmbeddingsResponse,
RowEmbeddingsRequest, RowEmbeddingsResponse, RowIndexMatch
)
from .base import MessageTranslator
from .primitives import ValueTranslator
@ -92,3 +93,62 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator):
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
class RowEmbeddingsRequestTranslator(MessageTranslator):
    """Converts between plain dicts and RowEmbeddingsRequest objects."""

    def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
        """
        Build a RowEmbeddingsRequest from a dict.

        "vectors" is required (a missing key raises KeyError). The other
        fields fall back to defaults: limit 10 (coerced with int()), user
        "trustgraph", collection "default", schema_name "" and index_name
        None.
        """
        # Looked up in the same order as the constructor arguments so the
        # required-key failure happens before the int() coercion.
        vectors = data["vectors"]
        limit = int(data.get("limit", 10))
        user = data.get("user", "trustgraph")
        collection = data.get("collection", "default")
        schema_name = data.get("schema_name", "")
        index_name = data.get("index_name")

        return RowEmbeddingsRequest(
            vectors=vectors,
            limit=limit,
            user=user,
            collection=collection,
            schema_name=schema_name,
            index_name=index_name,
        )

    def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
        """
        Render a RowEmbeddingsRequest as a dict.

        "index_name" is only present in the output when the request has a
        non-empty index name.
        """
        as_dict = dict(
            vectors=obj.vectors,
            limit=obj.limit,
            user=obj.user,
            collection=obj.collection,
            schema_name=obj.schema_name,
        )
        if obj.index_name:
            as_dict["index_name"] = obj.index_name
        return as_dict
class RowEmbeddingsResponseTranslator(MessageTranslator):
    """Converts RowEmbeddingsResponse objects into plain dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsResponse:
        """Not supported for responses; always raises NotImplementedError."""
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: RowEmbeddingsResponse) -> Dict[str, Any]:
        """
        Render a RowEmbeddingsResponse as a dict.

        The output has an "error" entry (type, message) when the response's
        error is not None, and a "matches" list (index_name, index_value,
        text, score per match) when its matches are not None.
        """
        out: Dict[str, Any] = {}

        err = obj.error
        if err is not None:
            out["error"] = {"type": err.type, "message": err.message}

        if obj.matches is not None:
            rendered = []
            for match in obj.matches:
                rendered.append({
                    "index_name": match.index_name,
                    "index_value": match.index_value,
                    "text": match.text,
                    "score": match.score,
                })
            out["matches"] = rendered

        return out

    def from_response_with_completion(self, obj: RowEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); is_final is always True."""
        translated = self.from_pulsar(obj)
        return translated, True