mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Row embeddings APIs exposed (#646)
* Added row embeddings API and CLI support * Updated protocol specs * Row embeddings agent tool * Add new agent tool to CLI
This commit is contained in:
parent
1809c1f56d
commit
4bbc6d844f
25 changed files with 1090 additions and 29 deletions
|
|
@ -368,8 +368,154 @@ This separation keeps concerns clean:
|
|||
- Embeddings API handles semantic similarity
|
||||
- User workflow: fuzzy search via embeddings to find candidates, then exact query to get full row data
|
||||
|
||||
#### Request/Response Schema
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class RowEmbeddingsRequest:
|
||||
vectors: list[list[float]] # Query vectors (pre-computed embeddings)
|
||||
user: str = ""
|
||||
collection: str = ""
|
||||
schema_name: str = ""
|
||||
index_name: str = "" # Optional: filter to specific index
|
||||
limit: int = 10 # Max results per vector
|
||||
|
||||
@dataclass
|
||||
class RowIndexMatch:
|
||||
index_name: str = "" # The matched index field(s)
|
||||
index_value: list[str] = field(default_factory=list)  # The matched value(s)
|
||||
text: str = "" # Original text that was embedded
|
||||
score: float = 0.0 # Similarity score
|
||||
|
||||
@dataclass
|
||||
class RowEmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
matches: list[RowIndexMatch] = field(default_factory=list)
|
||||
```
|
||||
|
||||
#### Query Processor
|
||||
|
||||
Module: `trustgraph-flow/trustgraph/query/row_embeddings/qdrant`
|
||||
|
||||
Entry point: `row-embeddings-query-qdrant`
|
||||
|
||||
The processor:
|
||||
1. Receives `RowEmbeddingsRequest` with query vectors
|
||||
2. Finds the appropriate Qdrant collection by prefix matching
|
||||
3. Searches for nearest vectors with optional `index_name` filter
|
||||
4. Returns `RowEmbeddingsResponse` with matching index information
|
||||
|
||||
#### API Gateway Integration
|
||||
|
||||
The gateway exposes row embeddings queries via the standard request/response pattern:
|
||||
|
||||
| Component | Location |
|
||||
|-----------|----------|
|
||||
| Dispatcher | `trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py` |
|
||||
| Registration | Add `"row-embeddings"` to `request_response_dispatchers` in `manager.py` |
|
||||
|
||||
Flow interface name: `row-embeddings`
|
||||
|
||||
Interface definition in flow blueprint:
|
||||
```json
|
||||
{
|
||||
"interfaces": {
|
||||
"row-embeddings": {
|
||||
"request": "non-persistent://tg/request/row-embeddings:{id}",
|
||||
"response": "non-persistent://tg/response/row-embeddings:{id}"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Python SDK Support
|
||||
|
||||
The SDK provides methods for row embeddings queries:
|
||||
|
||||
```python
|
||||
# Flow-scoped query (preferred)
|
||||
api = Api(url)
|
||||
flow = api.flow().id("default")
|
||||
|
||||
# Query with text (SDK computes embeddings)
|
||||
matches = flow.row_embeddings_query(
|
||||
text="Chestnut Street",
|
||||
collection="my_collection",
|
||||
schema_name="addresses",
|
||||
index_name="street_name", # Optional filter
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Query with pre-computed vectors
|
||||
matches = flow.row_embeddings_query(
|
||||
vectors=[[0.1, 0.2, ...]],
|
||||
collection="my_collection",
|
||||
schema_name="addresses"
|
||||
)
|
||||
|
||||
# Each match contains:
|
||||
for match in matches:
|
||||
print(match.index_name) # e.g., "street_name"
|
||||
print(match.index_value) # e.g., ["CHESTNUT ST"]
|
||||
print(match.text) # e.g., "CHESTNUT ST"
|
||||
print(match.score) # e.g., 0.95
|
||||
```
|
||||
|
||||
#### CLI Utility
|
||||
|
||||
Command: `tg-invoke-row-embeddings`
|
||||
|
||||
```bash
|
||||
# Query by text (computes embedding automatically)
|
||||
tg-invoke-row-embeddings \
|
||||
--text "Chestnut Street" \
|
||||
--collection my_collection \
|
||||
--schema addresses \
|
||||
--index street_name \
|
||||
--limit 10
|
||||
|
||||
# Query by vector file
|
||||
tg-invoke-row-embeddings \
|
||||
--vectors vectors.json \
|
||||
--collection my_collection \
|
||||
--schema addresses
|
||||
|
||||
# Output formats
|
||||
tg-invoke-row-embeddings --text "..." --format json
|
||||
tg-invoke-row-embeddings --text "..." --format table
|
||||
```
|
||||
|
||||
#### Typical Usage Pattern
|
||||
|
||||
The row embeddings query is typically used as part of a fuzzy-to-exact lookup flow:
|
||||
|
||||
```python
|
||||
# Step 1: Fuzzy search via embeddings
|
||||
matches = flow.row_embeddings_query(
|
||||
text="chestnut street",
|
||||
collection="geo",
|
||||
schema_name="streets"
|
||||
)
|
||||
|
||||
# Step 2: Exact lookup via GraphQL for full row data
|
||||
for match in matches:
|
||||
query = f'''
|
||||
query {{
|
||||
streets(where: {{ {match.index_name}: {{ eq: "{match.index_value[0]}" }} }}) {{
|
||||
street_name
|
||||
city
|
||||
zip_code
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
rows = flow.rows_query(query, collection="geo")
|
||||
```
|
||||
|
||||
This two-step pattern enables:
|
||||
- Finding "CHESTNUT ST" when user searches for "Chestnut Street"
|
||||
- Retrieving complete row data with all fields
|
||||
- Combining semantic similarity with structured data access
|
||||
|
||||
### Row Data Ingestion
|
||||
|
||||
Deferred to a subsequent phase. Will be designed alongside other ingestion changes.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
type: object
|
||||
description: |
|
||||
Row embeddings query request - find similar rows by vector similarity on indexed fields.
|
||||
Enables semantic/fuzzy matching on structured data.
|
||||
required:
|
||||
- vectors
|
||||
- schema_name
|
||||
properties:
|
||||
vectors:
|
||||
type: array
|
||||
description: Query embedding vector
|
||||
items:
|
||||
type: number
|
||||
example: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156]
|
||||
schema_name:
|
||||
type: string
|
||||
description: Schema name to search within
|
||||
example: customers
|
||||
index_name:
|
||||
type: string
|
||||
description: Optional index name to filter search to specific index
|
||||
example: full_name
|
||||
limit:
|
||||
type: integer
|
||||
description: Maximum number of matches to return
|
||||
default: 10
|
||||
minimum: 1
|
||||
maximum: 1000
|
||||
example: 20
|
||||
user:
|
||||
type: string
|
||||
description: User identifier
|
||||
default: trustgraph
|
||||
example: alice
|
||||
collection:
|
||||
type: string
|
||||
description: Collection to search
|
||||
default: default
|
||||
example: sales
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
type: object
|
||||
description: Row embeddings query response with matching row index information
|
||||
properties:
|
||||
error:
|
||||
type: object
|
||||
description: Error information if query failed
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
description: Error type identifier
|
||||
example: row-embeddings-query-error
|
||||
message:
|
||||
type: string
|
||||
description: Human-readable error message
|
||||
example: Schema not found
|
||||
matches:
|
||||
type: array
|
||||
description: List of matching row index entries with similarity scores
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
index_name:
|
||||
type: string
|
||||
description: Name of the indexed field(s)
|
||||
example: full_name
|
||||
index_value:
|
||||
type: array
|
||||
description: Values of the indexed fields for this row
|
||||
items:
|
||||
type: string
|
||||
example: ["John", "Smith"]
|
||||
text:
|
||||
type: string
|
||||
description: The text that was embedded for this index entry
|
||||
example: "John Smith"
|
||||
score:
|
||||
type: number
|
||||
description: Similarity score (higher is more similar)
|
||||
example: 0.89
|
||||
example:
|
||||
matches:
|
||||
- index_name: full_name
|
||||
index_value: ["John", "Smith"]
|
||||
text: "John Smith"
|
||||
score: 0.95
|
||||
- index_name: full_name
|
||||
index_value: ["Jon", "Smythe"]
|
||||
text: "Jon Smythe"
|
||||
score: 0.82
|
||||
- index_name: full_name
|
||||
index_value: ["Jonathan", "Schmidt"]
|
||||
text: "Jonathan Schmidt"
|
||||
score: 0.76
|
||||
|
|
@ -133,6 +133,8 @@ paths:
|
|||
$ref: './paths/flow/graph-embeddings.yaml'
|
||||
/api/v1/flow/{flow}/service/document-embeddings:
|
||||
$ref: './paths/flow/document-embeddings.yaml'
|
||||
/api/v1/flow/{flow}/service/row-embeddings:
|
||||
$ref: './paths/flow/row-embeddings.yaml'
|
||||
/api/v1/flow/{flow}/service/text-load:
|
||||
$ref: './paths/flow/text-load.yaml'
|
||||
/api/v1/flow/{flow}/service/document-load:
|
||||
|
|
|
|||
101
specs/api/paths/flow/row-embeddings.yaml
Normal file
101
specs/api/paths/flow/row-embeddings.yaml
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
post:
|
||||
tags:
|
||||
- Flow Services
|
||||
summary: Row Embeddings Query - semantic search on structured data
|
||||
description: |
|
||||
Query row embeddings to find similar rows by vector similarity on indexed fields.
|
||||
Enables fuzzy/semantic matching on structured data.
|
||||
|
||||
## Row Embeddings Query Overview
|
||||
|
||||
Find rows whose indexed field values are semantically similar to a query:
|
||||
- **Input**: Query embedding vector, schema name, optional index filter
|
||||
- **Search**: Compare against stored row index embeddings
|
||||
- **Output**: Matching rows with index values and similarity scores
|
||||
|
||||
Core component of semantic search on structured data.
|
||||
|
||||
## Use Cases
|
||||
|
||||
- **Fuzzy name matching**: Find customers by approximate name
|
||||
- **Semantic field search**: Find products by description similarity
|
||||
- **Data deduplication**: Identify potential duplicate records
|
||||
- **Entity resolution**: Match records across datasets
|
||||
|
||||
## Process
|
||||
|
||||
1. Obtain query embedding (via embeddings service)
|
||||
2. Query stored row index embeddings for the specified schema
|
||||
3. Calculate cosine similarity
|
||||
4. Return top N most similar index entries
|
||||
5. Use index values to retrieve full rows via GraphQL
|
||||
|
||||
## Response Format
|
||||
|
||||
Each match includes:
|
||||
- `index_name`: The indexed field(s) that matched
|
||||
- `index_value`: The actual values for those fields
|
||||
- `text`: The text that was embedded
|
||||
- `score`: Similarity score (higher = more similar)
|
||||
|
||||
operationId: rowEmbeddingsQueryService
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: flow
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: Flow instance ID
|
||||
example: my-flow
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '../../components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml'
|
||||
examples:
|
||||
basicQuery:
|
||||
summary: Find similar customer names
|
||||
value:
|
||||
vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178]
|
||||
schema_name: customers
|
||||
limit: 10
|
||||
user: alice
|
||||
collection: sales
|
||||
filteredQuery:
|
||||
summary: Search specific index
|
||||
value:
|
||||
vectors: [0.1, -0.2, 0.3, -0.4, 0.5]
|
||||
schema_name: products
|
||||
index_name: description
|
||||
limit: 20
|
||||
responses:
|
||||
'200':
|
||||
description: Successful response
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '../../components/schemas/embeddings-query/RowEmbeddingsQueryResponse.yaml'
|
||||
examples:
|
||||
similarRows:
|
||||
summary: Similar rows found
|
||||
value:
|
||||
matches:
|
||||
- index_name: full_name
|
||||
index_value: ["John", "Smith"]
|
||||
text: "John Smith"
|
||||
score: 0.95
|
||||
- index_name: full_name
|
||||
index_value: ["Jon", "Smythe"]
|
||||
text: "Jon Smythe"
|
||||
score: 0.82
|
||||
- index_name: full_name
|
||||
index_value: ["Jonathan", "Schmidt"]
|
||||
text: "Jonathan Schmidt"
|
||||
score: 0.76
|
||||
'401':
|
||||
$ref: '../../components/responses/Unauthorized.yaml'
|
||||
'500':
|
||||
$ref: '../../components/responses/Error.yaml'
|
||||
|
|
@ -31,6 +31,7 @@ payload:
|
|||
- $ref: './requests/StructuredDiagRequest.yaml'
|
||||
- $ref: './requests/GraphEmbeddingsRequest.yaml'
|
||||
- $ref: './requests/DocumentEmbeddingsRequest.yaml'
|
||||
- $ref: './requests/RowEmbeddingsRequest.yaml'
|
||||
- $ref: './requests/TextLoadRequest.yaml'
|
||||
- $ref: './requests/DocumentLoadRequest.yaml'
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
type: object
|
||||
description: WebSocket request for row-embeddings service (flow-hosted service)
|
||||
required:
|
||||
- id
|
||||
- service
|
||||
- flow
|
||||
- request
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: Unique request identifier
|
||||
service:
|
||||
type: string
|
||||
const: row-embeddings
|
||||
description: Service identifier for row-embeddings service
|
||||
flow:
|
||||
type: string
|
||||
description: Flow ID
|
||||
request:
|
||||
$ref: '../../../../api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml'
|
||||
examples:
|
||||
- id: req-1
|
||||
service: row-embeddings
|
||||
flow: my-flow
|
||||
request:
|
||||
vectors: [0.023, -0.142, 0.089, 0.234]
|
||||
schema_name: customers
|
||||
limit: 10
|
||||
user: trustgraph
|
||||
collection: default
|
||||
|
|
@ -766,3 +766,63 @@ class AsyncFlowInstance:
|
|||
request_data.update(kwargs)
|
||||
|
||||
return await self.request("rows", request_data)
|
||||
|
||||
async def row_embeddings_query(
    self, text: str, schema_name: str, user: str = "trustgraph",
    collection: str = "default", index_name: Optional[str] = None,
    limit: int = 10, **kwargs: Any
):
    """
    Run a semantic (fuzzy) search over row index embeddings.

    The query text is first turned into vectors through this flow's
    ``embeddings`` call; those vectors are then sent to the
    ``row-embeddings`` service together with the schema, user,
    collection and limit.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name restricting the search to one
            index; left out of the request when empty or None
        limit: Maximum number of results to return (default: 10)
        **kwargs: Extra request fields, merged in last so they override
            the fields built here

    Returns:
        dict: Service response; matches are expected under "matches",
            each carrying index_name, index_value, text and score

    Example:
        ```python
        async_flow = await api.async_flow()
        flow = async_flow.id("default")

        # Search for customers by name similarity
        results = await flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )

        for match in results.get("matches", []):
            print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
        ```
    """
    # The service matches on vectors, so embed the query text first.
    embedded = await self.embeddings(text=text)

    payload = dict(
        vectors=embedded.get("vectors", []),
        schema_name=schema_name,
        user=user,
        collection=collection,
        limit=limit,
    )
    if index_name:
        payload["index_name"] = index_name
    payload.update(kwargs)

    return await self.request("row-embeddings", payload)
|
||||
|
|
|
|||
|
|
@ -345,3 +345,26 @@ class AsyncSocketFlowInstance:
|
|||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request("mcp-tool", self.flow_id, request)
|
||||
|
||||
async def row_embeddings_query(
    self, text: str, schema_name: str, user: str = "trustgraph",
    collection: str = "default", index_name: Optional[str] = None,
    limit: int = 10, **kwargs
):
    """
    Query row embeddings for semantic search on structured data.

    Embeds ``text`` via this flow's ``embeddings`` call, then sends the
    resulting vectors to the ``row-embeddings`` service over the socket
    client.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name filter; omitted from the request
            when empty or None
        limit: Maximum number of results (default: 10)
        **kwargs: Extra request fields, merged in last so they override
            the fields built here

    Returns:
        Whatever the client's ``_send_request`` returns for this request.
    """
    # The service matches on vectors, so embed the query text first.
    embedded = await self.embeddings(text=text)

    payload = dict(
        vectors=embedded.get("vectors", []),
        schema_name=schema_name,
        user=user,
        collection=collection,
        limit=limit,
    )
    if index_name:
        payload["index_name"] = index_name
    payload.update(kwargs)

    return await self.client._send_request(
        "row-embeddings", self.flow_id, payload
    )
|
||||
|
|
|
|||
|
|
@ -1297,3 +1297,78 @@ class FlowInstance:
|
|||
|
||||
return response["schema-matches"]
|
||||
|
||||
def row_embeddings_query(
    self, text, schema_name, user="trustgraph", collection="default",
    index_name=None, limit=10
):
    """
    Query row data using semantic similarity on indexed fields.

    Finds rows whose indexed field values are semantically similar to the
    input text, using vector embeddings. This enables fuzzy/semantic matching
    on structured data.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User/keyspace identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name to filter search to specific index;
            omitted from the request when empty or None
        limit: Maximum number of results (default: 10)

    Returns:
        dict: Query results with matches containing index_name, index_value,
            text, and score

    Raises:
        ProtocolException: If the response carries a non-empty "error"
            entry; the message is "<type>: <message>".

    Example:
        ```python
        flow = api.flow().id("default")

        # Search for customers by name similarity
        results = flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )

        # Filter to specific index
        results = flow.row_embeddings_query(
            text="machine learning engineer",
            schema_name="employees",
            index_name="job_title",
            limit=10
        )
        ```
    """

    # First convert text to embeddings vectors
    emb_result = self.embeddings(text=text)
    vectors = emb_result.get("vectors", [])

    # Build the query for the row embeddings service. Named request_data
    # rather than `input` so the builtin is not shadowed.
    request_data = {
        "vectors": vectors,
        "schema_name": schema_name,
        "user": user,
        "collection": collection,
        "limit": limit
    }

    if index_name:
        request_data["index_name"] = index_name

    response = self.request(
        "service/row-embeddings",
        request_data
    )

    # Check for system-level error (a missing or empty entry means success)
    error = response.get("error")
    if error:
        error_type = error.get("type", "unknown")
        error_message = error.get("message", "Unknown error")
        raise ProtocolException(f"{error_type}: {error_message}")

    return response
|
||||
|
||||
|
|
|
|||
|
|
@ -881,3 +881,73 @@ class SocketFlowInstance:
|
|||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
|
||||
|
||||
def row_embeddings_query(
    self,
    text: str,
    schema_name: str,
    user: str = "trustgraph",
    collection: str = "default",
    index_name: Optional[str] = None,
    limit: int = 10,
    **kwargs: Any
) -> Dict[str, Any]:
    """
    Search structured rows by semantic similarity of their indexed fields.

    The query text is embedded through this flow's ``embeddings`` call and
    the resulting vectors are sent to the ``row-embeddings`` service over
    the synchronous socket client.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User/keyspace identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name restricting the search to one
            index; omitted from the request when empty or None
        limit: Maximum number of results (default: 10)
        **kwargs: Additional parameters passed to the service; merged in
            last so they override the fields built here

    Returns:
        dict: Query results with matches containing index_name, index_value,
            text, and score

    Example:
        ```python
        socket = api.socket()
        flow = socket.flow("default")

        # Search for customers by name similarity
        results = flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )

        # Filter to specific index
        results = flow.row_embeddings_query(
            text="machine learning engineer",
            schema_name="employees",
            index_name="job_title",
            limit=10
        )
        ```
    """
    # The service matches on vectors, so embed the query text first.
    embedded = self.embeddings(text=text)

    payload = dict(
        vectors=embedded.get("vectors", []),
        schema_name=schema_name,
        user=user,
        collection=collection,
        limit=limit,
    )
    if index_name:
        payload["index_name"] = index_name
    payload.update(kwargs)

    return self.client._send_request_sync(
        "row-embeddings", self.flow_id, payload, False
    )
|
||||
|
|
|
|||
|
|
@ -34,5 +34,6 @@ from . tool_service import ToolService
|
|||
from . tool_client import ToolClientSpec
|
||||
from . agent_client import AgentClientSpec
|
||||
from . structured_query_client import StructuredQueryClientSpec
|
||||
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
|
||||
from . collection_config_handler import CollectionConfigHandler
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||
|
||||
class RowEmbeddingsQueryClient(RequestResponse):
    """Request/response client for the row embeddings query service."""

    async def row_embeddings_query(
            self, vectors, schema_name, user="trustgraph", collection="default",
            index_name=None, limit=10, timeout=600
    ):
        """
        Send a row embeddings query and return the matches as plain dicts.

        Args:
            vectors: Query vectors, passed through to the request unchanged
            schema_name: Schema name to search within
            user: User identifier (default: "trustgraph")
            collection: Collection identifier (default: "default")
            index_name: Optional index name filter; only set on the request
                when truthy
            limit: Maximum number of results (default: 10)
            timeout: Timeout handed to ``self.request`` (default: 600)

        Returns:
            list[dict]: One dict per match with the keys ``index_name``,
            ``index_value``, ``text`` and ``score``; empty when the
            response has no matches.

        Raises:
            RuntimeError: If the response carries an error; the exception
                message is the error's message.
        """
        query = RowEmbeddingsRequest(
            vectors=vectors,
            schema_name=schema_name,
            user=user,
            collection=collection,
            limit=limit
        )
        if index_name:
            query.index_name = index_name

        reply = await self.request(query, timeout=timeout)

        if reply.error:
            raise RuntimeError(reply.error.message)

        # Flatten the schema objects into dicts for the caller.
        found = []
        for hit in (reply.matches or []):
            found.append({
                "index_name": hit.index_name,
                "index_value": hit.index_value,
                "text": hit.text,
                "score": hit.score
            })
        return found
|
||||
|
||||
class RowEmbeddingsQueryClientSpec(RequestResponseSpec):
    """
    Spec binding a request/response name pair to RowEmbeddingsQueryClient,
    using the RowEmbeddingsRequest / RowEmbeddingsResponse schemas.
    """

    def __init__(self, request_name, response_name):
        """
        Args:
            request_name: Name forwarded to the base spec as ``request_name``
            response_name: Name forwarded to the base spec as ``response_name``
        """
        super().__init__(
            request_name=request_name,
            request_schema=RowEmbeddingsRequest,
            response_name=response_name,
            response_schema=RowEmbeddingsResponse,
            impl=RowEmbeddingsQueryClient,
        )
|
||||
60
trustgraph-base/trustgraph/clients/row_embeddings_client.py
Normal file
60
trustgraph-base/trustgraph/clients/row_embeddings_client.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||
from .. schema import row_embeddings_request_queue
|
||||
from .. schema import row_embeddings_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class RowEmbeddingsClient(BaseClient):
    """Pulsar client for the row embeddings query service."""

    def __init__(
            self, log_level=ERROR,
            subscriber=None,
            input_queue=None,
            output_queue=None,
            pulsar_host="pulsar://pulsar:6650",
            pulsar_api_key=None,
    ):
        """
        Args:
            log_level: Pulsar logger level (default: ERROR)
            subscriber: Forwarded to the base client unchanged
            input_queue: Request queue; defaults to
                ``row_embeddings_request_queue`` when None
            output_queue: Response queue; defaults to
                ``row_embeddings_response_queue`` when None
            pulsar_host: Pulsar service URL
            pulsar_api_key: Optional API key forwarded to the base client
        """

        # Identity comparison: only a genuinely missing queue falls back
        # to the schema-defined default.
        if input_queue is None:
            input_queue = row_embeddings_request_queue

        if output_queue is None:
            output_queue = row_embeddings_response_queue

        super(RowEmbeddingsClient, self).__init__(
            log_level=log_level,
            subscriber=subscriber,
            input_queue=input_queue,
            output_queue=output_queue,
            pulsar_host=pulsar_host,
            pulsar_api_key=pulsar_api_key,
            input_schema=RowEmbeddingsRequest,
            output_schema=RowEmbeddingsResponse,
        )

    def request(
            self, vectors, schema_name, user="trustgraph", collection="default",
            index_name=None, limit=10, timeout=300
    ):
        """
        Send a row embeddings query and return the response's matches.

        Args:
            vectors: Query vectors, passed through unchanged
            schema_name: Schema name to search within
            user: User identifier (default: "trustgraph")
            collection: Collection identifier (default: "default")
            index_name: Optional index name filter; only sent when truthy
            limit: Maximum number of results (default: 10)
            timeout: Timeout forwarded to ``self.call`` (default: 300)

        Returns:
            The ``matches`` attribute of the response.

        Raises:
            RuntimeError: If the response carries an error; the message is
                "<type>: <message>".
        """
        kwargs = dict(
            user=user, collection=collection,
            vectors=vectors, schema_name=schema_name,
            limit=limit, timeout=timeout
        )
        if index_name:
            kwargs["index_name"] = index_name

        response = self.call(**kwargs)

        if response.error:
            raise RuntimeError(f"{response.error.type}: {response.error.message}")

        return response.matches
|
||||
|
|
@ -19,7 +19,8 @@ from .translators.prompt import PromptRequestTranslator, PromptResponseTranslato
|
|||
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
|
||||
from .translators.embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
|
||||
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
|
||||
)
|
||||
from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
||||
|
|
@ -107,11 +108,17 @@ TranslatorRegistry.register_service(
|
|||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"graph-embeddings-query",
|
||||
GraphEmbeddingsRequestTranslator(),
|
||||
"graph-embeddings-query",
|
||||
GraphEmbeddingsRequestTranslator(),
|
||||
GraphEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"row-embeddings-query",
|
||||
RowEmbeddingsRequestTranslator(),
|
||||
RowEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"rows-query",
|
||||
RowsQueryRequestTranslator(),
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ from .flow import FlowRequestTranslator, FlowResponseTranslator
|
|||
from .prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||
from .embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
|
||||
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
|
||||
)
|
||||
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import (
|
||||
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
|
||||
GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
GraphEmbeddingsRequest, GraphEmbeddingsResponse,
|
||||
RowEmbeddingsRequest, RowEmbeddingsResponse, RowIndexMatch
|
||||
)
|
||||
from .base import MessageTranslator
|
||||
from .primitives import ValueTranslator
|
||||
|
|
@ -92,3 +93,62 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
|||
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
||||
|
||||
class RowEmbeddingsRequestTranslator(MessageTranslator):
    """Translator for RowEmbeddingsRequest schema objects"""

    def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
        """
        Build a RowEmbeddingsRequest from a request dict.

        ``vectors`` is mandatory (a missing key raises KeyError); every
        other field falls back to a default when absent.
        """
        return RowEmbeddingsRequest(
            vectors=data["vectors"],
            limit=int(data.get("limit", 10)),
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
            schema_name=data.get("schema_name", ""),
            # A missing or null index_name means "no index filter". Use the
            # empty string rather than None so the field stays a str, the
            # same value the other clients leave in place when no index is
            # given.
            index_name=data.get("index_name") or ""
        )

    def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
        """
        Convert a RowEmbeddingsRequest back into a dict.

        ``index_name`` is only included when it is non-empty.
        """
        result = {
            "vectors": obj.vectors,
            "limit": obj.limit,
            "user": obj.user,
            "collection": obj.collection,
            "schema_name": obj.schema_name,
        }
        if obj.index_name:
            result["index_name"] = obj.index_name
        return result
|
||||
|
||||
|
||||
class RowEmbeddingsResponseTranslator(MessageTranslator):
    """Translator for RowEmbeddingsResponse schema objects"""

    def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsResponse:
        """Not supported for responses; always raises NotImplementedError."""
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: RowEmbeddingsResponse) -> Dict[str, Any]:
        """
        Convert a RowEmbeddingsResponse into a plain dict.

        The "error" key (with "type" and "message") is present only when
        the response has an error; the "matches" key is present only when
        the response's matches are not None.
        """
        payload: Dict[str, Any] = {}

        failure = obj.error
        if failure is not None:
            payload["error"] = {
                "type": failure.type,
                "message": failure.message
            }

        hits = obj.matches
        if hits is not None:
            converted = []
            for hit in hits:
                converted.append({
                    "index_name": hit.index_name,
                    "index_value": hit.index_value,
                    "text": hit.text,
                    "score": hit.score
                })
            payload["matches"] = converted

        return payload

    def from_response_with_completion(self, obj: RowEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); is_final is always True here."""
        translated = self.from_pulsar(obj)
        return translated, True
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ tg-invoke-document-embeddings = "trustgraph.cli.invoke_document_embeddings:main"
|
|||
tg-invoke-mcp-tool = "trustgraph.cli.invoke_mcp_tool:main"
|
||||
tg-invoke-nlp-query = "trustgraph.cli.invoke_nlp_query:main"
|
||||
tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
|
||||
tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main"
|
||||
tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
|
||||
tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
|
||||
tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"
|
||||
|
|
|
|||
126
trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py
Normal file
126
trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
"""
|
||||
Queries row data by text similarity using vector embeddings on indexed fields.
|
||||
Returns matching rows with their index values and similarity scores.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from trustgraph.api import Api
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
|
||||
def query(url, flow_id, query_text, schema_name, user, collection, index_name, limit, token=None):
    """
    Run a row embeddings query over the socket API and print each match.

    Args:
        url: API URL
        flow_id: Flow to query
        query_text: Text to search for
        schema_name: Schema name to search within
        user: User/keyspace identifier
        collection: Collection identifier
        index_name: Optional index name filter (None for no filter)
        limit: Maximum number of results
        token: Optional authentication token
    """

    # Open a socket connection and select the flow
    connection = Api(url=url, token=token).socket()
    flow = connection.flow(flow_id)

    try:

        response = flow.row_embeddings_query(
            text=query_text,
            schema_name=schema_name,
            user=user,
            collection=collection,
            index_name=index_name,
            limit=limit
        )

        # One stanza per match, separated by a blank line
        for hit in response.get("matches", []):
            print(f"Index: {hit['index_name']}")
            print(f" Values: {hit['index_value']}")
            print(f" Text: {hit['text']}")
            print(f" Score: {hit['score']:.4f}")
            print()

    finally:
        # Always release the socket, even when the query fails
        connection.close()
|
||||
|
||||
def main():
    """Command-line entry point for ``tg-invoke-row-embeddings``.

    Declares the command-line options, parses the command line and hands
    the parsed values to ``query``.  An ``Exception`` raised by that call
    is printed to stdout as ``Exception: <message>`` instead of being
    propagated.
    """

    parser = argparse.ArgumentParser(
        prog='tg-invoke-row-embeddings',
        description=__doc__,
    )

    # Optional arguments, registered in the order they are listed here.
    option_table = [
        (('-u', '--url'), {
            'default': default_url,
            'help': f'API URL (default: {default_url})',
        }),
        (('-t', '--token'), {
            'default': default_token,
            'help': 'Authentication token (default: $TRUSTGRAPH_TOKEN)',
        }),
        (('-f', '--flow-id'), {
            'default': "default",
            'help': 'Flow ID (default: default)',
        }),
        (('-U', '--user'), {
            'default': "trustgraph",
            'help': 'User/keyspace (default: trustgraph)',
        }),
        (('-c', '--collection'), {
            'default': "default",
            'help': 'Collection (default: default)',
        }),
        (('-s', '--schema-name'), {
            'required': True,
            'help': 'Schema name to search within (required)',
        }),
        (('-i', '--index-name'), {
            'default': None,
            'help': 'Index name to filter search (optional)',
        }),
        (('-l', '--limit'), {
            'type': int,
            'default': 10,
            'help': 'Maximum number of results (default: 10)',
        }),
    ]

    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    # The single positional argument; nargs=1 makes argparse store a
    # one-element list, hence the [0] below.
    parser.add_argument(
        'query',
        nargs=1,
        help='Query text to search for similar row index values',
    )

    args = parser.parse_args()

    try:

        query(
            url=args.url,
            flow_id=args.flow_id,
            query_text=args.query[0],
            schema_name=args.schema_name,
            user=args.user,
            collection=args.collection,
            index_name=args.index_name,
            limit=args.limit,
            token=args.token,
        )

    except Exception as e:

        print("Exception:", e, flush=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -2,8 +2,9 @@
|
|||
Configures and registers tools in the TrustGraph system.
|
||||
|
||||
This script allows you to define agent tools with various types including:
|
||||
- knowledge-query: Query knowledge bases
|
||||
- knowledge-query: Query knowledge bases
|
||||
- structured-query: Query structured data using natural language
|
||||
- row-embeddings-query: Semantic search on structured data indexes
|
||||
- text-completion: Text generation
|
||||
- mcp-tool: Reference to MCP (Model Context Protocol) tools
|
||||
- prompt: Prompt template execution
|
||||
|
|
@ -64,6 +65,9 @@ def set_tool(
|
|||
mcp_tool : str,
|
||||
collection : str,
|
||||
template : str,
|
||||
schema_name : str,
|
||||
index_name : str,
|
||||
limit : int,
|
||||
arguments : List[Argument],
|
||||
group : List[str],
|
||||
state : str,
|
||||
|
|
@ -89,6 +93,12 @@ def set_tool(
|
|||
|
||||
if template: object["template"] = template
|
||||
|
||||
if schema_name: object["schema-name"] = schema_name
|
||||
|
||||
if index_name: object["index-name"] = index_name
|
||||
|
||||
if limit: object["limit"] = limit
|
||||
|
||||
if arguments:
|
||||
object["arguments"] = [
|
||||
{
|
||||
|
|
@ -120,30 +130,37 @@ def main():
|
|||
description=__doc__,
|
||||
epilog=textwrap.dedent('''
|
||||
Valid tool types:
|
||||
knowledge-query - Query knowledge bases (fixed args)
|
||||
structured-query - Query structured data using natural language (fixed args)
|
||||
text-completion - Text completion/generation (fixed args)
|
||||
mcp-tool - Model Control Protocol tool (configurable args)
|
||||
prompt - Prompt template query (configurable args)
|
||||
|
||||
Note: Tools marked "(fixed args)" have predefined arguments and don't need
|
||||
knowledge-query - Query knowledge bases (fixed args)
|
||||
structured-query - Query structured data using natural language (fixed args)
|
||||
row-embeddings-query - Semantic search on structured data indexes (fixed args)
|
||||
text-completion - Text completion/generation (fixed args)
|
||||
mcp-tool - Model Context Protocol tool (configurable args)
|
||||
prompt - Prompt template query (configurable args)
|
||||
|
||||
Note: Tools marked "(fixed args)" have predefined arguments and don't need
|
||||
--argument specified. Tools marked "(configurable args)" require --argument.
|
||||
|
||||
|
||||
Valid argument types:
|
||||
string - String/text parameter
|
||||
string - String/text parameter
|
||||
number - Numeric parameter
|
||||
|
||||
|
||||
Examples:
|
||||
%(prog)s --id weather_tool --name get_weather \\
|
||||
--type knowledge-query \\
|
||||
--description "Get weather information for a location" \\
|
||||
--collection weather_data
|
||||
|
||||
|
||||
%(prog)s --id data_query_tool --name query_data \\
|
||||
--type structured-query \\
|
||||
--description "Query structured data using natural language" \\
|
||||
--collection sales_data
|
||||
|
||||
|
||||
%(prog)s --id customer_search --name find_customer \\
|
||||
--type row-embeddings-query \\
|
||||
--description "Find customers by name using semantic search" \\
|
||||
--schema-name customers --collection sales \\
|
||||
--index-name full_name --limit 20
|
||||
|
||||
%(prog)s --id calc_tool --name calculate --type mcp-tool \\
|
||||
--description "Perform mathematical calculations" \\
|
||||
--mcp-tool calculator \\
|
||||
|
|
@ -181,7 +198,7 @@ def main():
|
|||
|
||||
parser.add_argument(
|
||||
'--type',
|
||||
help=f'Tool type, one of: knowledge-query, structured-query, text-completion, mcp-tool, prompt',
|
||||
help=f'Tool type, one of: knowledge-query, structured-query, row-embeddings-query, text-completion, mcp-tool, prompt',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -191,7 +208,23 @@ def main():
|
|||
|
||||
parser.add_argument(
|
||||
'--collection',
|
||||
help=f'For knowledge-query and structured-query types: collection to query',
|
||||
help=f'For knowledge-query, structured-query, and row-embeddings-query types: collection to query',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--schema-name',
|
||||
help=f'For row-embeddings-query type: schema name to search within (required)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--index-name',
|
||||
help=f'For row-embeddings-query type: specific index to filter search (optional)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--limit',
|
||||
type=int,
|
||||
help=f'For row-embeddings-query type: maximum results to return (default: 10)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -227,7 +260,8 @@ def main():
|
|||
try:
|
||||
|
||||
valid_types = [
|
||||
"knowledge-query", "structured-query", "text-completion", "mcp-tool", "prompt"
|
||||
"knowledge-query", "structured-query", "row-embeddings-query",
|
||||
"text-completion", "mcp-tool", "prompt"
|
||||
]
|
||||
|
||||
if args.id is None:
|
||||
|
|
@ -261,6 +295,9 @@ def main():
|
|||
mcp_tool=mcp_tool,
|
||||
collection=args.collection,
|
||||
template=args.template,
|
||||
schema_name=args.schema_name,
|
||||
index_name=args.index_name,
|
||||
limit=args.limit,
|
||||
arguments=arguments,
|
||||
group=args.group,
|
||||
state=args.state,
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@ Displays the current agent tool configurations
|
|||
Shows all configured tools including their types:
|
||||
- knowledge-query: Tools that query knowledge bases
|
||||
- structured-query: Tools that query structured data using natural language
|
||||
- row-embeddings-query: Tools for semantic search on structured data indexes
|
||||
- text-completion: Tools for text generation
|
||||
- mcp-tool: References to MCP (Model Context Protocol) tools
|
||||
- mcp-tool: References to MCP (Model Context Protocol) tools
|
||||
- prompt: Tools that execute prompt templates
|
||||
"""
|
||||
|
||||
|
|
@ -41,11 +42,19 @@ def show_config(url, token=None):
|
|||
|
||||
if tp == "mcp-tool":
|
||||
table.append(("mcp-tool", data["mcp-tool"]))
|
||||
|
||||
if tp == "knowledge-query" or tp == "structured-query":
|
||||
|
||||
if tp in ("knowledge-query", "structured-query", "row-embeddings-query"):
|
||||
if "collection" in data:
|
||||
table.append(("collection", data["collection"]))
|
||||
|
||||
if tp == "row-embeddings-query":
|
||||
if "schema-name" in data:
|
||||
table.append(("schema-name", data["schema-name"]))
|
||||
if "index-name" in data:
|
||||
table.append(("index-name", data["index-name"]))
|
||||
if "limit" in data:
|
||||
table.append(("limit", data["limit"]))
|
||||
|
||||
if tp == "prompt":
|
||||
table.append(("template", data["template"]))
|
||||
for n, arg in enumerate(data["arguments"]):
|
||||
|
|
|
|||
|
|
@ -13,10 +13,11 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
||||
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
|
||||
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl
|
||||
from . agent_manager import AgentManager
|
||||
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
|
||||
|
||||
|
|
@ -87,6 +88,20 @@ class Processor(AgentService):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
EmbeddingsClientSpec(
|
||||
request_name = "embeddings-request",
|
||||
response_name = "embeddings-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RowEmbeddingsQueryClientSpec(
|
||||
request_name = "row-embeddings-query-request",
|
||||
response_name = "row-embeddings-query-response",
|
||||
)
|
||||
)
|
||||
|
||||
async def on_tools_config(self, config, version):
|
||||
|
||||
logger.info(f"Loading configuration version {version}")
|
||||
|
|
@ -147,11 +162,21 @@ class Processor(AgentService):
|
|||
)
|
||||
elif impl_id == "structured-query":
|
||||
impl = functools.partial(
|
||||
StructuredQueryImpl,
|
||||
StructuredQueryImpl,
|
||||
collection=data.get("collection"),
|
||||
user=None # User will be provided dynamically via context
|
||||
)
|
||||
arguments = StructuredQueryImpl.get_arguments()
|
||||
elif impl_id == "row-embeddings-query":
|
||||
impl = functools.partial(
|
||||
RowEmbeddingsQueryImpl,
|
||||
schema_name=data.get("schema-name"),
|
||||
collection=data.get("collection"),
|
||||
user=None, # User will be provided dynamically via context
|
||||
index_name=data.get("index-name"), # Optional filter
|
||||
limit=int(data.get("limit", 10)) # Max results
|
||||
)
|
||||
arguments = RowEmbeddingsQueryImpl.get_arguments()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Tool type {impl_id} not known"
|
||||
|
|
@ -327,11 +352,11 @@ class Processor(AgentService):
|
|||
def __init__(self, flow, user):
|
||||
self._flow = flow
|
||||
self._user = user
|
||||
|
||||
|
||||
def __call__(self, service_name):
|
||||
client = self._flow(service_name)
|
||||
# For structured query clients, store user context
|
||||
if service_name == "structured-query-request":
|
||||
# For query clients that need user context, store it
|
||||
if service_name in ("structured-query-request", "row-embeddings-query-request"):
|
||||
client._current_user = self._user
|
||||
return client
|
||||
|
||||
|
|
|
|||
|
|
@ -128,6 +128,62 @@ class StructuredQueryImpl:
|
|||
return str(result)
|
||||
|
||||
|
||||
# This tool implementation knows how to query row embeddings for semantic search
|
||||
class RowEmbeddingsQueryImpl:
    """Agent tool that runs a row embeddings query for a piece of text.

    ``context`` is called with a service name and must return the client
    for that service.  The remaining constructor arguments are stored and
    forwarded to the row embeddings query on every invocation.
    """

    def __init__(self, context, schema_name, collection=None, user=None, index_name=None, limit=10):
        self.context = context
        self.schema_name = schema_name
        self.collection = collection
        self.user = user
        # Optional: restrict the search to one index
        self.index_name = index_name
        # Maximum number of results requested from the query service
        self.limit = limit

    @staticmethod
    def get_arguments():
        """Return the argument list this tool exposes: a single ``query`` string."""
        query_argument = Argument(
            name="query",
            type="string",
            description="Text to search for semantically similar values in the structured data index"
        )
        return [query_argument]

    async def invoke(self, **arguments):
        """Embed the ``query`` argument, search row embeddings, and format the hits.

        Returns ``"No matching records found"`` when the query yields
        nothing; otherwise a ``"Matching records:"`` header followed by one
        ``- <index_name>: <values> (score: <score>)`` line per match.
        """

        # Step 1: turn the query text into vectors
        embedder = self.context("embeddings-request")
        logger.debug("Getting embeddings for row query...")

        text = arguments.get("query")
        vectors = await embedder.embed(text)

        # Step 2: look those vectors up in the row embeddings store
        client = self.context("row-embeddings-query-request")
        logger.debug("Row embeddings query...")

        # Prefer a user attached to the client; otherwise use the configured
        # user, and finally the "trustgraph" default.
        fallback_user = self.user or "trustgraph"
        user = getattr(client, '_current_user', fallback_user)

        matches = await client.row_embeddings_query(
            vectors=vectors,
            schema_name=self.schema_name,
            user=user,
            collection=self.collection or "default",
            index_name=self.index_name,
            limit=self.limit
        )

        # Step 3: render the matches as text for the agent
        if not matches:
            return "No matching records found"

        lines = [
            f"- {hit['index_name']}: {', '.join(hit['index_value'])} (score: {hit['score']:.3f})"
            for hit in matches
        ]

        return "Matching records:\n" + "\n".join(lines)
|
||||
|
||||
|
||||
# This tool implementation knows how to execute prompt templates
|
||||
class PromptImpl:
|
||||
def __init__(self, context, template_id, arguments=None):
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from . structured_diag import StructuredDiagRequestor
|
|||
from . embeddings import EmbeddingsRequestor
|
||||
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
|
||||
from . row_embeddings_query import RowEmbeddingsQueryRequestor
|
||||
from . mcp_tool import McpToolRequestor
|
||||
from . text_load import TextLoad
|
||||
from . document_load import DocumentLoad
|
||||
|
|
@ -62,6 +63,7 @@ request_response_dispatchers = {
|
|||
"nlp-query": NLPQueryRequestor,
|
||||
"structured-query": StructuredQueryRequestor,
|
||||
"structured-diag": StructuredDiagRequestor,
|
||||
"row-embeddings": RowEmbeddingsQueryRequestor,
|
||||
}
|
||||
|
||||
global_dispatchers = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
from ... schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class RowEmbeddingsQueryRequestor(ServiceRequestor):
    """Service requestor for row embeddings queries.

    Configures the ``ServiceRequestor`` base with the
    ``RowEmbeddingsRequest`` / ``RowEmbeddingsResponse`` schemas and uses
    the translators registered under ``"row-embeddings-query"`` to convert
    request bodies and response messages.
    """

    def __init__(
            self, backend, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):

        # Note the renaming: `subscriber` feeds the base class's
        # `subscription` argument and `consumer` feeds `consumer_name`.
        super().__init__(
            backend=backend,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=RowEmbeddingsRequest,
            response_schema=RowEmbeddingsResponse,
            subscription=subscriber,
            consumer_name=consumer,
            timeout=timeout,
        )

        self.request_translator = TranslatorRegistry.get_request_translator(
            "row-embeddings-query"
        )
        self.response_translator = TranslatorRegistry.get_response_translator(
            "row-embeddings-query"
        )

    def to_request(self, body):
        """Convert ``body`` with the request translator's ``to_pulsar``."""
        translator = self.request_translator
        return translator.to_pulsar(body)

    def from_response(self, message):
        """Convert ``message`` with the response translator's ``from_response_with_completion``."""
        translator = self.response_translator
        return translator.from_response_with_completion(message)
|
||||
Loading…
Add table
Add a link
Reference in a new issue