Row embeddings APIs exposed (#646)

* Added row embeddings API and CLI support

* Updated protocol specs

* Row embeddings agent tool

* Add new agent tool to CLI
This commit is contained in:
cybermaggedon 2026-02-23 21:52:56 +00:00 committed by GitHub
parent 1809c1f56d
commit 4bbc6d844f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1090 additions and 29 deletions

View file

@ -13,10 +13,11 @@ logger = logging.getLogger(__name__)
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl
from . agent_manager import AgentManager
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
@ -87,6 +88,20 @@ class Processor(AgentService):
)
)
self.register_specification(
EmbeddingsClientSpec(
request_name = "embeddings-request",
response_name = "embeddings-response",
)
)
self.register_specification(
RowEmbeddingsQueryClientSpec(
request_name = "row-embeddings-query-request",
response_name = "row-embeddings-query-response",
)
)
async def on_tools_config(self, config, version):
logger.info(f"Loading configuration version {version}")
@ -147,11 +162,21 @@ class Processor(AgentService):
)
elif impl_id == "structured-query":
impl = functools.partial(
StructuredQueryImpl,
StructuredQueryImpl,
collection=data.get("collection"),
user=None # User will be provided dynamically via context
)
arguments = StructuredQueryImpl.get_arguments()
elif impl_id == "row-embeddings-query":
impl = functools.partial(
RowEmbeddingsQueryImpl,
schema_name=data.get("schema-name"),
collection=data.get("collection"),
user=None, # User will be provided dynamically via context
index_name=data.get("index-name"), # Optional filter
limit=int(data.get("limit", 10)) # Max results
)
arguments = RowEmbeddingsQueryImpl.get_arguments()
else:
raise RuntimeError(
f"Tool type {impl_id} not known"
@ -327,11 +352,11 @@ class Processor(AgentService):
def __init__(self, flow, user):
    """Remember the flow accessor and the user for later client lookups.

    Args:
        flow: Callable that is invoked with a service name to obtain
            a client (see ``__call__``).
        user: User identity attached to query clients that need it.
    """
    self._user = user
    self._flow = flow
def __call__(self, service_name):
    """Return the flow client for ``service_name``.

    For the structured-query and row-embeddings-query request
    services the stored user is attached to the client as
    ``_current_user`` so the tool implementation can pick it up.

    Args:
        service_name: Name of the service whose client is wanted.

    Returns:
        The client object produced by the wrapped flow callable.
    """
    client = self._flow(service_name)

    # For query clients that need user context, store it
    if service_name in ("structured-query-request", "row-embeddings-query-request"):
        client._current_user = self._user

    return client

View file

@ -128,6 +128,62 @@ class StructuredQueryImpl:
return str(result)
# This tool implementation knows how to query row embeddings for semantic search
class RowEmbeddingsQueryImpl:
    """Agent tool that performs a semantic search over row embeddings.

    The query text is first converted to vectors through the
    ``embeddings-request`` client, then those vectors are sent to the
    ``row-embeddings-query-request`` client.  Matches are rendered as a
    plain-text bullet list for the agent.
    """

    def __init__(
            self, context, schema_name, collection=None, user=None,
            index_name=None, limit=10,
    ):
        """Store the tool configuration.

        Args:
            context: Callable taking a service name and returning the
                client for that service.
            schema_name: Schema name forwarded to the row embeddings query.
            collection: Collection to query; ``"default"`` is used at
                query time when this is falsy.
            user: Fallback user when the query client carries no
                ``_current_user`` attribute.
            index_name: Optional: filter to a specific index.
            limit: Max results to return.
        """
        self.context = context
        self.schema_name = schema_name
        self.collection = collection
        self.user = user
        self.index_name = index_name  # Optional: filter to specific index
        self.limit = limit  # Max results to return

    @staticmethod
    def get_arguments():
        """Return the argument specification exposed to the agent."""
        return [
            Argument(
                name="query",
                type="string",
                description="Text to search for semantically similar values in the structured data index"
            )
        ]

    async def invoke(self, **arguments):
        """Run the semantic search and format the matches.

        Args:
            **arguments: Tool arguments; ``query`` holds the search text.

        Returns:
            ``"No matching records found"`` when the query yields nothing,
            otherwise a ``"Matching records:"`` header followed by one
            ``- index_name: values (score: x.xxx)`` line per match.
        """
        # First get embeddings for the query text
        embeddings_client = self.context("embeddings-request")
        logger.debug("Getting embeddings for row query...")

        query_text = arguments.get("query")
        vectors = await embeddings_client.embed(query_text)

        # Now query row embeddings
        client = self.context("row-embeddings-query-request")
        logger.debug("Row embeddings query...")

        # Get user from client context if available
        user = getattr(client, '_current_user', self.user or "trustgraph")

        matches = await client.row_embeddings_query(
            vectors=vectors,
            schema_name=self.schema_name,
            user=user,
            collection=self.collection or "default",
            index_name=self.index_name,
            limit=self.limit
        )

        # Format results for agent consumption
        if not matches:
            return "No matching records found"

        results = []
        for match in matches:
            # Coerce each value to str: str.join raises TypeError on
            # non-string items such as numeric index values.
            values = ', '.join(str(value) for value in match['index_value'])
            results.append(
                f"- {match['index_name']}: {values} (score: {match['score']:.3f})"
            )

        return "Matching records:\n" + "\n".join(results)
# This tool implementation knows how to execute prompt templates
class PromptImpl:
def __init__(self, context, template_id, arguments=None):

View file

@ -27,6 +27,7 @@ from . structured_diag import StructuredDiagRequestor
from . embeddings import EmbeddingsRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
from . row_embeddings_query import RowEmbeddingsQueryRequestor
from . mcp_tool import McpToolRequestor
from . text_load import TextLoad
from . document_load import DocumentLoad
@ -62,6 +63,7 @@ request_response_dispatchers = {
"nlp-query": NLPQueryRequestor,
"structured-query": StructuredQueryRequestor,
"structured-diag": StructuredDiagRequestor,
"row-embeddings": RowEmbeddingsQueryRequestor,
}
global_dispatchers = {

View file

@ -0,0 +1,31 @@
from ... schema import RowEmbeddingsRequest, RowEmbeddingsResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class RowEmbeddingsQueryRequestor(ServiceRequestor):
    """Request/response dispatcher for row embeddings queries.

    Configures the base requestor with the row embeddings request and
    response schemas and converts message bodies with the translators
    registered under ``"row-embeddings-query"``.
    """

    def __init__(
            self, backend, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):
        """Set up the base requestor and look up the translators.

        ``consumer`` is handed to the base class as ``consumer_name``
        and ``subscriber`` as ``subscription``; the remaining arguments
        are forwarded under their own names.
        """
        super().__init__(
            backend=backend,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=RowEmbeddingsRequest,
            response_schema=RowEmbeddingsResponse,
            subscription=subscriber,
            consumer_name=consumer,
            timeout=timeout,
        )

        # Both directions use translators registered under the same key.
        translator_key = "row-embeddings-query"
        self.request_translator = TranslatorRegistry.get_request_translator(
            translator_key
        )
        self.response_translator = TranslatorRegistry.get_response_translator(
            translator_key
        )

    def to_request(self, body):
        """Translate an incoming request body via the request translator."""
        translator = self.request_translator
        return translator.to_pulsar(body)

    def from_response(self, message):
        """Translate a response message via the response translator."""
        translator = self.response_translator
        return translator.from_response_with_completion(message)