mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-03 12:22:37 +02:00
Row embeddings APIs exposed (#646)
* Added row embeddings API and CLI support * Updated protocol specs * Row embeddings agent tool * Add new agent tool to CLI
This commit is contained in:
parent
1809c1f56d
commit
4bbc6d844f
25 changed files with 1090 additions and 29 deletions
|
|
@ -13,10 +13,11 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
||||
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
|
||||
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl
|
||||
from . agent_manager import AgentManager
|
||||
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
|
||||
|
||||
|
|
@ -87,6 +88,20 @@ class Processor(AgentService):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
EmbeddingsClientSpec(
|
||||
request_name = "embeddings-request",
|
||||
response_name = "embeddings-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RowEmbeddingsQueryClientSpec(
|
||||
request_name = "row-embeddings-query-request",
|
||||
response_name = "row-embeddings-query-response",
|
||||
)
|
||||
)
|
||||
|
||||
async def on_tools_config(self, config, version):
|
||||
|
||||
logger.info(f"Loading configuration version {version}")
|
||||
|
|
@ -147,11 +162,21 @@ class Processor(AgentService):
|
|||
)
|
||||
elif impl_id == "structured-query":
|
||||
impl = functools.partial(
|
||||
StructuredQueryImpl,
|
||||
StructuredQueryImpl,
|
||||
collection=data.get("collection"),
|
||||
user=None # User will be provided dynamically via context
|
||||
)
|
||||
arguments = StructuredQueryImpl.get_arguments()
|
||||
elif impl_id == "row-embeddings-query":
|
||||
impl = functools.partial(
|
||||
RowEmbeddingsQueryImpl,
|
||||
schema_name=data.get("schema-name"),
|
||||
collection=data.get("collection"),
|
||||
user=None, # User will be provided dynamically via context
|
||||
index_name=data.get("index-name"), # Optional filter
|
||||
limit=int(data.get("limit", 10)) # Max results
|
||||
)
|
||||
arguments = RowEmbeddingsQueryImpl.get_arguments()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Tool type {impl_id} not known"
|
||||
|
|
@ -327,11 +352,11 @@ class Processor(AgentService):
|
|||
def __init__(self, flow, user):
|
||||
self._flow = flow
|
||||
self._user = user
|
||||
|
||||
|
||||
def __call__(self, service_name):
|
||||
client = self._flow(service_name)
|
||||
# For structured query clients, store user context
|
||||
if service_name == "structured-query-request":
|
||||
# For query clients that need user context, store it
|
||||
if service_name in ("structured-query-request", "row-embeddings-query-request"):
|
||||
client._current_user = self._user
|
||||
return client
|
||||
|
||||
|
|
|
|||
|
|
@ -128,6 +128,62 @@ class StructuredQueryImpl:
|
|||
return str(result)
|
||||
|
||||
|
||||
# This tool implementation knows how to query row embeddings for semantic search
|
||||
class RowEmbeddingsQueryImpl:
|
||||
def __init__(self, context, schema_name, collection=None, user=None, index_name=None, limit=10):
|
||||
self.context = context
|
||||
self.schema_name = schema_name
|
||||
self.collection = collection
|
||||
self.user = user
|
||||
self.index_name = index_name # Optional: filter to specific index
|
||||
self.limit = limit # Max results to return
|
||||
|
||||
@staticmethod
|
||||
def get_arguments():
|
||||
return [
|
||||
Argument(
|
||||
name="query",
|
||||
type="string",
|
||||
description="Text to search for semantically similar values in the structured data index"
|
||||
)
|
||||
]
|
||||
|
||||
async def invoke(self, **arguments):
|
||||
# First get embeddings for the query text
|
||||
embeddings_client = self.context("embeddings-request")
|
||||
logger.debug("Getting embeddings for row query...")
|
||||
|
||||
query_text = arguments.get("query")
|
||||
vectors = await embeddings_client.embed(query_text)
|
||||
|
||||
# Now query row embeddings
|
||||
client = self.context("row-embeddings-query-request")
|
||||
logger.debug("Row embeddings query...")
|
||||
|
||||
# Get user from client context if available
|
||||
user = getattr(client, '_current_user', self.user or "trustgraph")
|
||||
|
||||
matches = await client.row_embeddings_query(
|
||||
vectors=vectors,
|
||||
schema_name=self.schema_name,
|
||||
user=user,
|
||||
collection=self.collection or "default",
|
||||
index_name=self.index_name,
|
||||
limit=self.limit
|
||||
)
|
||||
|
||||
# Format results for agent consumption
|
||||
if not matches:
|
||||
return "No matching records found"
|
||||
|
||||
results = []
|
||||
for match in matches:
|
||||
result = f"- {match['index_name']}: {', '.join(match['index_value'])} (score: {match['score']:.3f})"
|
||||
results.append(result)
|
||||
|
||||
return "Matching records:\n" + "\n".join(results)
|
||||
|
||||
|
||||
# This tool implementation knows how to execute prompt templates
|
||||
class PromptImpl:
|
||||
def __init__(self, context, template_id, arguments=None):
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from . structured_diag import StructuredDiagRequestor
|
|||
from . embeddings import EmbeddingsRequestor
|
||||
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
|
||||
from . row_embeddings_query import RowEmbeddingsQueryRequestor
|
||||
from . mcp_tool import McpToolRequestor
|
||||
from . text_load import TextLoad
|
||||
from . document_load import DocumentLoad
|
||||
|
|
@ -62,6 +63,7 @@ request_response_dispatchers = {
|
|||
"nlp-query": NLPQueryRequestor,
|
||||
"structured-query": StructuredQueryRequestor,
|
||||
"structured-diag": StructuredDiagRequestor,
|
||||
"row-embeddings": RowEmbeddingsQueryRequestor,
|
||||
}
|
||||
|
||||
global_dispatchers = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
from ... schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class RowEmbeddingsQueryRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, backend, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(RowEmbeddingsQueryRequestor, self).__init__(
|
||||
backend=backend,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=RowEmbeddingsRequest,
|
||||
response_schema=RowEmbeddingsResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("row-embeddings-query")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("row-embeddings-query")
|
||||
|
||||
def to_request(self, body):
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
Loading…
Add table
Add a link
Reference in a new issue