mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Row embeddings APIs exposed (#646)
* Added row embeddings API and CLI support * Updated protocol specs * Row embeddings agent tool * Add new agent tool to CLI
This commit is contained in:
parent
1809c1f56d
commit
4bbc6d844f
25 changed files with 1090 additions and 29 deletions
|
|
@ -368,8 +368,154 @@ This separation keeps concerns clean:
|
||||||
- Embeddings API handles semantic similarity
|
- Embeddings API handles semantic similarity
|
||||||
- User workflow: fuzzy search via embeddings to find candidates, then exact query to get full row data
|
- User workflow: fuzzy search via embeddings to find candidates, then exact query to get full row data
|
||||||
|
|
||||||
|
#### Request/Response Schema
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class RowEmbeddingsRequest:
|
||||||
|
vectors: list[list[float]] # Query vectors (pre-computed embeddings)
|
||||||
|
user: str = ""
|
||||||
|
collection: str = ""
|
||||||
|
schema_name: str = ""
|
||||||
|
index_name: str = "" # Optional: filter to specific index
|
||||||
|
limit: int = 10 # Max results per vector
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RowIndexMatch:
|
||||||
|
index_name: str = "" # The matched index field(s)
|
||||||
|
index_value: list[str] = field(default_factory=list)  # The matched value(s); dataclasses reject a bare [] default
|
||||||
|
text: str = "" # Original text that was embedded
|
||||||
|
score: float = 0.0 # Similarity score
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RowEmbeddingsResponse:
|
||||||
|
error: Error | None = None
|
||||||
|
matches: list[RowIndexMatch] = field(default_factory=list)  # dataclasses reject a bare [] default
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Query Processor
|
||||||
|
|
||||||
Module: `trustgraph-flow/trustgraph/query/row_embeddings/qdrant`
|
Module: `trustgraph-flow/trustgraph/query/row_embeddings/qdrant`
|
||||||
|
|
||||||
|
Entry point: `row-embeddings-query-qdrant`
|
||||||
|
|
||||||
|
The processor:
|
||||||
|
1. Receives `RowEmbeddingsRequest` with query vectors
|
||||||
|
2. Finds the appropriate Qdrant collection by prefix matching
|
||||||
|
3. Searches for nearest vectors with optional `index_name` filter
|
||||||
|
4. Returns `RowEmbeddingsResponse` with matching index information
|
||||||
|
|
||||||
|
#### API Gateway Integration
|
||||||
|
|
||||||
|
The gateway exposes row embeddings queries via the standard request/response pattern:
|
||||||
|
|
||||||
|
| Component | Location |
|
||||||
|
|-----------|----------|
|
||||||
|
| Dispatcher | `trustgraph-flow/trustgraph/gateway/dispatch/row_embeddings_query.py` |
|
||||||
|
| Registration | Add `"row-embeddings"` to `request_response_dispatchers` in `manager.py` |
|
||||||
|
|
||||||
|
Flow interface name: `row-embeddings`
|
||||||
|
|
||||||
|
Interface definition in flow blueprint:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"interfaces": {
|
||||||
|
"row-embeddings": {
|
||||||
|
"request": "non-persistent://tg/request/row-embeddings:{id}",
|
||||||
|
"response": "non-persistent://tg/response/row-embeddings:{id}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Python SDK Support
|
||||||
|
|
||||||
|
The SDK provides methods for row embeddings queries:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Flow-scoped query (preferred)
|
||||||
|
api = Api(url)
|
||||||
|
flow = api.flow().id("default")
|
||||||
|
|
||||||
|
# Query with text (SDK computes embeddings)
|
||||||
|
matches = flow.row_embeddings_query(
|
||||||
|
text="Chestnut Street",
|
||||||
|
collection="my_collection",
|
||||||
|
schema_name="addresses",
|
||||||
|
index_name="street_name", # Optional filter
|
||||||
|
limit=10
|
||||||
|
)
|
||||||
|
|
||||||
|
# Query with pre-computed vectors
|
||||||
|
matches = flow.row_embeddings_query(
|
||||||
|
vectors=[[0.1, 0.2, ...]],
|
||||||
|
collection="my_collection",
|
||||||
|
schema_name="addresses"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Each match contains:
|
||||||
|
for match in matches:
|
||||||
|
print(match.index_name) # e.g., "street_name"
|
||||||
|
print(match.index_value) # e.g., ["CHESTNUT ST"]
|
||||||
|
print(match.text) # e.g., "CHESTNUT ST"
|
||||||
|
print(match.score) # e.g., 0.95
|
||||||
|
```
|
||||||
|
|
||||||
|
#### CLI Utility
|
||||||
|
|
||||||
|
Command: `tg-invoke-row-embeddings`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Query by text (computes embedding automatically)
|
||||||
|
tg-invoke-row-embeddings \
|
||||||
|
--text "Chestnut Street" \
|
||||||
|
--collection my_collection \
|
||||||
|
--schema addresses \
|
||||||
|
--index street_name \
|
||||||
|
--limit 10
|
||||||
|
|
||||||
|
# Query by vector file
|
||||||
|
tg-invoke-row-embeddings \
|
||||||
|
--vectors vectors.json \
|
||||||
|
--collection my_collection \
|
||||||
|
--schema addresses
|
||||||
|
|
||||||
|
# Output formats
|
||||||
|
tg-invoke-row-embeddings --text "..." --format json
|
||||||
|
tg-invoke-row-embeddings --text "..." --format table
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Typical Usage Pattern
|
||||||
|
|
||||||
|
The row embeddings query is typically used as part of a fuzzy-to-exact lookup flow:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Step 1: Fuzzy search via embeddings
|
||||||
|
matches = flow.row_embeddings_query(
|
||||||
|
text="chestnut street",
|
||||||
|
collection="geo",
|
||||||
|
schema_name="streets"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: Exact lookup via GraphQL for full row data
|
||||||
|
for match in matches:
|
||||||
|
query = f'''
|
||||||
|
query {{
|
||||||
|
streets(where: {{ {match.index_name}: {{ eq: "{match.index_value[0]}" }} }}) {{
|
||||||
|
street_name
|
||||||
|
city
|
||||||
|
zip_code
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
'''
|
||||||
|
rows = flow.rows_query(query, collection="geo")
|
||||||
|
```
|
||||||
|
|
||||||
|
This two-step pattern enables:
|
||||||
|
- Finding "CHESTNUT ST" when user searches for "Chestnut Street"
|
||||||
|
- Retrieving complete row data with all fields
|
||||||
|
- Combining semantic similarity with structured data access
|
||||||
|
|
||||||
### Row Data Ingestion
|
### Row Data Ingestion
|
||||||
|
|
||||||
Deferred to a subsequent phase. Will be designed alongside other ingestion changes.
|
Deferred to a subsequent phase. Will be designed alongside other ingestion changes.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
type: object
|
||||||
|
description: |
|
||||||
|
Row embeddings query request - find similar rows by vector similarity on indexed fields.
|
||||||
|
Enables semantic/fuzzy matching on structured data.
|
||||||
|
required:
|
||||||
|
- vectors
|
||||||
|
- schema_name
|
||||||
|
properties:
|
||||||
|
vectors:
|
||||||
|
type: array
|
||||||
|
description: Query embedding vector
|
||||||
|
items:
|
||||||
|
type: number
|
||||||
|
example: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156]
|
||||||
|
schema_name:
|
||||||
|
type: string
|
||||||
|
description: Schema name to search within
|
||||||
|
example: customers
|
||||||
|
index_name:
|
||||||
|
type: string
|
||||||
|
description: Optional index name to filter search to specific index
|
||||||
|
example: full_name
|
||||||
|
limit:
|
||||||
|
type: integer
|
||||||
|
description: Maximum number of matches to return
|
||||||
|
default: 10
|
||||||
|
minimum: 1
|
||||||
|
maximum: 1000
|
||||||
|
example: 20
|
||||||
|
user:
|
||||||
|
type: string
|
||||||
|
description: User identifier
|
||||||
|
default: trustgraph
|
||||||
|
example: alice
|
||||||
|
collection:
|
||||||
|
type: string
|
||||||
|
description: Collection to search
|
||||||
|
default: default
|
||||||
|
example: sales
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
type: object
|
||||||
|
description: Row embeddings query response with matching row index information
|
||||||
|
properties:
|
||||||
|
error:
|
||||||
|
type: object
|
||||||
|
description: Error information if query failed
|
||||||
|
properties:
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
description: Error type identifier
|
||||||
|
example: row-embeddings-query-error
|
||||||
|
message:
|
||||||
|
type: string
|
||||||
|
description: Human-readable error message
|
||||||
|
example: Schema not found
|
||||||
|
matches:
|
||||||
|
type: array
|
||||||
|
description: List of matching row index entries with similarity scores
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
index_name:
|
||||||
|
type: string
|
||||||
|
description: Name of the indexed field(s)
|
||||||
|
example: full_name
|
||||||
|
index_value:
|
||||||
|
type: array
|
||||||
|
description: Values of the indexed fields for this row
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
example: ["John", "Smith"]
|
||||||
|
text:
|
||||||
|
type: string
|
||||||
|
description: The text that was embedded for this index entry
|
||||||
|
example: "John Smith"
|
||||||
|
score:
|
||||||
|
type: number
|
||||||
|
description: Similarity score (higher is more similar)
|
||||||
|
example: 0.89
|
||||||
|
example:
|
||||||
|
matches:
|
||||||
|
- index_name: full_name
|
||||||
|
index_value: ["John", "Smith"]
|
||||||
|
text: "John Smith"
|
||||||
|
score: 0.95
|
||||||
|
- index_name: full_name
|
||||||
|
index_value: ["Jon", "Smythe"]
|
||||||
|
text: "Jon Smythe"
|
||||||
|
score: 0.82
|
||||||
|
- index_name: full_name
|
||||||
|
index_value: ["Jonathan", "Schmidt"]
|
||||||
|
text: "Jonathan Schmidt"
|
||||||
|
score: 0.76
|
||||||
|
|
@ -133,6 +133,8 @@ paths:
|
||||||
$ref: './paths/flow/graph-embeddings.yaml'
|
$ref: './paths/flow/graph-embeddings.yaml'
|
||||||
/api/v1/flow/{flow}/service/document-embeddings:
|
/api/v1/flow/{flow}/service/document-embeddings:
|
||||||
$ref: './paths/flow/document-embeddings.yaml'
|
$ref: './paths/flow/document-embeddings.yaml'
|
||||||
|
/api/v1/flow/{flow}/service/row-embeddings:
|
||||||
|
$ref: './paths/flow/row-embeddings.yaml'
|
||||||
/api/v1/flow/{flow}/service/text-load:
|
/api/v1/flow/{flow}/service/text-load:
|
||||||
$ref: './paths/flow/text-load.yaml'
|
$ref: './paths/flow/text-load.yaml'
|
||||||
/api/v1/flow/{flow}/service/document-load:
|
/api/v1/flow/{flow}/service/document-load:
|
||||||
|
|
|
||||||
101
specs/api/paths/flow/row-embeddings.yaml
Normal file
101
specs/api/paths/flow/row-embeddings.yaml
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
post:
|
||||||
|
tags:
|
||||||
|
- Flow Services
|
||||||
|
summary: Row Embeddings Query - semantic search on structured data
|
||||||
|
description: |
|
||||||
|
Query row embeddings to find similar rows by vector similarity on indexed fields.
|
||||||
|
Enables fuzzy/semantic matching on structured data.
|
||||||
|
|
||||||
|
## Row Embeddings Query Overview
|
||||||
|
|
||||||
|
Find rows whose indexed field values are semantically similar to a query:
|
||||||
|
- **Input**: Query embedding vector, schema name, optional index filter
|
||||||
|
- **Search**: Compare against stored row index embeddings
|
||||||
|
- **Output**: Matching rows with index values and similarity scores
|
||||||
|
|
||||||
|
Core component of semantic search on structured data.
|
||||||
|
|
||||||
|
## Use Cases
|
||||||
|
|
||||||
|
- **Fuzzy name matching**: Find customers by approximate name
|
||||||
|
- **Semantic field search**: Find products by description similarity
|
||||||
|
- **Data deduplication**: Identify potential duplicate records
|
||||||
|
- **Entity resolution**: Match records across datasets
|
||||||
|
|
||||||
|
## Process
|
||||||
|
|
||||||
|
1. Obtain query embedding (via embeddings service)
|
||||||
|
2. Query stored row index embeddings for the specified schema
|
||||||
|
3. Calculate cosine similarity
|
||||||
|
4. Return top N most similar index entries
|
||||||
|
5. Use index values to retrieve full rows via GraphQL
|
||||||
|
|
||||||
|
## Response Format
|
||||||
|
|
||||||
|
Each match includes:
|
||||||
|
- `index_name`: The indexed field(s) that matched
|
||||||
|
- `index_value`: The actual values for those fields
|
||||||
|
- `text`: The text that was embedded
|
||||||
|
- `score`: Similarity score (higher = more similar)
|
||||||
|
|
||||||
|
operationId: rowEmbeddingsQueryService
|
||||||
|
security:
|
||||||
|
- bearerAuth: []
|
||||||
|
parameters:
|
||||||
|
- name: flow
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: Flow instance ID
|
||||||
|
example: my-flow
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '../../components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml'
|
||||||
|
examples:
|
||||||
|
basicQuery:
|
||||||
|
summary: Find similar customer names
|
||||||
|
value:
|
||||||
|
vectors: [0.023, -0.142, 0.089, 0.234, -0.067, 0.156, 0.201, -0.178]
|
||||||
|
schema_name: customers
|
||||||
|
limit: 10
|
||||||
|
user: alice
|
||||||
|
collection: sales
|
||||||
|
filteredQuery:
|
||||||
|
summary: Search specific index
|
||||||
|
value:
|
||||||
|
vectors: [0.1, -0.2, 0.3, -0.4, 0.5]
|
||||||
|
schema_name: products
|
||||||
|
index_name: description
|
||||||
|
limit: 20
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Successful response
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '../../components/schemas/embeddings-query/RowEmbeddingsQueryResponse.yaml'
|
||||||
|
examples:
|
||||||
|
similarRows:
|
||||||
|
summary: Similar rows found
|
||||||
|
value:
|
||||||
|
matches:
|
||||||
|
- index_name: full_name
|
||||||
|
index_value: ["John", "Smith"]
|
||||||
|
text: "John Smith"
|
||||||
|
score: 0.95
|
||||||
|
- index_name: full_name
|
||||||
|
index_value: ["Jon", "Smythe"]
|
||||||
|
text: "Jon Smythe"
|
||||||
|
score: 0.82
|
||||||
|
- index_name: full_name
|
||||||
|
index_value: ["Jonathan", "Schmidt"]
|
||||||
|
text: "Jonathan Schmidt"
|
||||||
|
score: 0.76
|
||||||
|
'401':
|
||||||
|
$ref: '../../components/responses/Unauthorized.yaml'
|
||||||
|
'500':
|
||||||
|
$ref: '../../components/responses/Error.yaml'
|
||||||
|
|
@ -31,6 +31,7 @@ payload:
|
||||||
- $ref: './requests/StructuredDiagRequest.yaml'
|
- $ref: './requests/StructuredDiagRequest.yaml'
|
||||||
- $ref: './requests/GraphEmbeddingsRequest.yaml'
|
- $ref: './requests/GraphEmbeddingsRequest.yaml'
|
||||||
- $ref: './requests/DocumentEmbeddingsRequest.yaml'
|
- $ref: './requests/DocumentEmbeddingsRequest.yaml'
|
||||||
|
- $ref: './requests/RowEmbeddingsRequest.yaml'
|
||||||
- $ref: './requests/TextLoadRequest.yaml'
|
- $ref: './requests/TextLoadRequest.yaml'
|
||||||
- $ref: './requests/DocumentLoadRequest.yaml'
|
- $ref: './requests/DocumentLoadRequest.yaml'
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
type: object
|
||||||
|
description: WebSocket request for row-embeddings service (flow-hosted service)
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- service
|
||||||
|
- flow
|
||||||
|
- request
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
description: Unique request identifier
|
||||||
|
service:
|
||||||
|
type: string
|
||||||
|
const: row-embeddings
|
||||||
|
description: Service identifier for row-embeddings service
|
||||||
|
flow:
|
||||||
|
type: string
|
||||||
|
description: Flow ID
|
||||||
|
request:
|
||||||
|
$ref: '../../../../api/components/schemas/embeddings-query/RowEmbeddingsQueryRequest.yaml'
|
||||||
|
examples:
|
||||||
|
- id: req-1
|
||||||
|
service: row-embeddings
|
||||||
|
flow: my-flow
|
||||||
|
request:
|
||||||
|
vectors: [0.023, -0.142, 0.089, 0.234]
|
||||||
|
schema_name: customers
|
||||||
|
limit: 10
|
||||||
|
user: trustgraph
|
||||||
|
collection: default
|
||||||
|
|
@ -766,3 +766,63 @@ class AsyncFlowInstance:
|
||||||
request_data.update(kwargs)
|
request_data.update(kwargs)
|
||||||
|
|
||||||
return await self.request("rows", request_data)
|
return await self.request("rows", request_data)
|
||||||
|
|
||||||
|
async def row_embeddings_query(
    self,
    text: str,
    schema_name: str,
    user: str = "trustgraph",
    collection: str = "default",
    index_name: Optional[str] = None,
    limit: int = 10,
    **kwargs: Any
):
    """
    Run a semantic search over row index embeddings.

    The query text is first turned into embedding vectors through
    ``self.embeddings``; those vectors are then sent to the
    ``row-embeddings`` service through ``self.request``, whose result is
    returned unchanged.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name restricting the search to one
            index; only sent when truthy
        limit: Maximum number of results to return (default: 10)
        **kwargs: Additional service-specific parameters, merged into the
            request last

    Returns:
        dict: Response containing matches with index_name, index_value,
        text, and score

    Example:
        ```python
        async_flow = await api.async_flow()
        flow = async_flow.id("default")

        # Search for customers by name similarity
        results = await flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )

        for match in results.get("matches", []):
            print(f"{match['index_name']}: {match['index_value']} (score: {match['score']})")
        ```
    """
    # Step 1: text -> embedding vectors (empty list if the reply has none)
    embedded = await self.embeddings(text=text)

    # Step 2: assemble the service request
    payload = dict(
        vectors=embedded.get("vectors", []),
        schema_name=schema_name,
        user=user,
        collection=collection,
        limit=limit,
    )
    if index_name:
        payload["index_name"] = index_name

    # Extras go in last, so they win over any key set above
    payload.update(kwargs)

    return await self.request("row-embeddings", payload)
|
||||||
|
|
|
||||||
|
|
@ -345,3 +345,26 @@ class AsyncSocketFlowInstance:
|
||||||
request.update(kwargs)
|
request.update(kwargs)
|
||||||
|
|
||||||
return await self.client._send_request("mcp-tool", self.flow_id, request)
|
return await self.client._send_request("mcp-tool", self.flow_id, request)
|
||||||
|
|
||||||
|
async def row_embeddings_query(
    self,
    text: str,
    schema_name: str,
    user: str = "trustgraph",
    collection: str = "default",
    index_name: Optional[str] = None,
    limit: int = 10,
    **kwargs
):
    """
    Query row embeddings for semantic search on structured data.

    Embeds ``text`` through ``self.embeddings`` and forwards the resulting
    vectors to the ``row-embeddings`` service for this flow through the
    socket client.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name filter; only sent when truthy
        limit: Maximum number of results (default: 10)
        **kwargs: Extra request fields, merged into the request last

    Returns:
        Whatever ``self.client._send_request`` returns for the request.
    """
    # Text must be turned into vectors before the similarity query
    embedded = await self.embeddings(text=text)

    payload = dict(
        vectors=embedded.get("vectors", []),
        schema_name=schema_name,
        user=user,
        collection=collection,
        limit=limit,
    )
    if index_name:
        payload["index_name"] = index_name

    # Extras are applied last and therefore take precedence
    payload.update(kwargs)

    return await self.client._send_request(
        "row-embeddings", self.flow_id, payload
    )
|
||||||
|
|
|
||||||
|
|
@ -1297,3 +1297,78 @@ class FlowInstance:
|
||||||
|
|
||||||
return response["schema-matches"]
|
return response["schema-matches"]
|
||||||
|
|
||||||
|
def row_embeddings_query(
    self, text, schema_name, user="trustgraph", collection="default",
    index_name=None, limit=10
):
    """
    Query row data using semantic similarity on indexed fields.

    Finds rows whose indexed field values are semantically similar to the
    input text, using vector embeddings. This enables fuzzy/semantic matching
    on structured data.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User/keyspace identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name to filter search to specific index;
            only sent when truthy
        limit: Maximum number of results (default: 10)

    Returns:
        dict: Query results with matches containing index_name, index_value,
        text, and score

    Raises:
        ProtocolException: If the response carries a non-empty "error"
            entry; the message is "<type>: <message>".

    Example:
        ```python
        flow = api.flow().id("default")

        # Search for customers by name similarity
        results = flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )

        # Filter to specific index
        results = flow.row_embeddings_query(
            text="machine learning engineer",
            schema_name="employees",
            index_name="job_title",
            limit=10
        )
        ```
    """

    # First convert text to embeddings vectors
    emb_result = self.embeddings(text=text)
    vectors = emb_result.get("vectors", [])

    # Query row embeddings for semantic search.  The dict is deliberately
    # not called `input`, which would shadow the builtin of that name.
    request_data = {
        "vectors": vectors,
        "schema_name": schema_name,
        "user": user,
        "collection": collection,
        "limit": limit
    }

    if index_name:
        request_data["index_name"] = index_name

    response = self.request(
        "service/row-embeddings",
        request_data
    )

    # Check for system-level error
    if "error" in response and response["error"]:
        error_type = response["error"].get("type", "unknown")
        error_message = response["error"].get("message", "Unknown error")
        raise ProtocolException(f"{error_type}: {error_message}")

    return response
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -881,3 +881,73 @@ class SocketFlowInstance:
|
||||||
request.update(kwargs)
|
request.update(kwargs)
|
||||||
|
|
||||||
return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
|
return self.client._send_request_sync("mcp-tool", self.flow_id, request, False)
|
||||||
|
|
||||||
|
def row_embeddings_query(
    self,
    text: str,
    schema_name: str,
    user: str = "trustgraph",
    collection: str = "default",
    index_name: Optional[str] = None,
    limit: int = 10,
    **kwargs: Any
) -> Dict[str, Any]:
    """
    Semantic (fuzzy) lookup of rows by their indexed field values.

    The text is embedded through ``self.embeddings`` and the resulting
    vectors are sent to the ``row-embeddings`` service of this flow via the
    socket client's synchronous request call.

    Args:
        text: Query text for semantic search
        schema_name: Schema name to search within
        user: User/keyspace identifier (default: "trustgraph")
        collection: Collection identifier (default: "default")
        index_name: Optional index name to filter search to specific index;
            only sent when truthy
        limit: Maximum number of results (default: 10)
        **kwargs: Additional parameters passed to the service, merged into
            the request last

    Returns:
        dict: Query results with matches containing index_name, index_value,
        text, and score

    Example:
        ```python
        socket = api.socket()
        flow = socket.flow("default")

        # Search for customers by name similarity
        results = flow.row_embeddings_query(
            text="John Smith",
            schema_name="customers",
            user="trustgraph",
            collection="sales",
            limit=5
        )

        # Filter to specific index
        results = flow.row_embeddings_query(
            text="machine learning engineer",
            schema_name="employees",
            index_name="job_title",
            limit=10
        )
        ```
    """
    # The service works on vectors, so embed the text up front
    embedded = self.embeddings(text=text)

    payload = dict(
        vectors=embedded.get("vectors", []),
        schema_name=schema_name,
        user=user,
        collection=collection,
        limit=limit,
    )
    if index_name:
        payload["index_name"] = index_name

    # Extras are merged last and override anything set above
    payload.update(kwargs)

    return self.client._send_request_sync(
        "row-embeddings", self.flow_id, payload, False
    )
|
||||||
|
|
|
||||||
|
|
@ -34,5 +34,6 @@ from . tool_service import ToolService
|
||||||
from . tool_client import ToolClientSpec
|
from . tool_client import ToolClientSpec
|
||||||
from . agent_client import AgentClientSpec
|
from . agent_client import AgentClientSpec
|
||||||
from . structured_query_client import StructuredQueryClientSpec
|
from . structured_query_client import StructuredQueryClientSpec
|
||||||
|
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
|
||||||
from . collection_config_handler import CollectionConfigHandler
|
from . collection_config_handler import CollectionConfigHandler
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,45 @@
|
||||||
|
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||||
|
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||||
|
|
||||||
|
class RowEmbeddingsQueryClient(RequestResponse):
    """Request/response client for the row embeddings query service."""

    async def row_embeddings_query(
            self, vectors, schema_name, user="trustgraph", collection="default",
            index_name=None, limit=10, timeout=600
    ):
        """
        Send a row embeddings query and return the matches as plain dicts.

        Args:
            vectors: Query vectors to search with
            schema_name: Schema name to search within
            user: User identifier (default: "trustgraph")
            collection: Collection identifier (default: "default")
            index_name: Optional index name filter; only set on the request
                when truthy
            limit: Maximum number of results (default: 10)
            timeout: Timeout passed to the underlying request (default: 600)

        Returns:
            list[dict]: One dict per match with the keys "index_name",
            "index_value", "text" and "score"; empty when the response has
            no matches.

        Raises:
            RuntimeError: If the response has an error set; carries the
                error's message.
        """
        query = RowEmbeddingsRequest(
            vectors=vectors,
            schema_name=schema_name,
            user=user,
            collection=collection,
            limit=limit
        )
        if index_name:
            query.index_name = index_name

        reply = await self.request(query, timeout=timeout)

        if reply.error:
            raise RuntimeError(reply.error.message)

        # Flatten the schema objects into plain dicts for the caller
        results = []
        for hit in (reply.matches or []):
            results.append({
                "index_name": hit.index_name,
                "index_value": hit.index_value,
                "text": hit.text,
                "score": hit.score
            })
        return results
|
||||||
|
|
||||||
|
class RowEmbeddingsQueryClientSpec(RequestResponseSpec):
    """
    Request/response spec binding the row embeddings schemas to
    RowEmbeddingsQueryClient.
    """

    def __init__(self, request_name, response_name):
        """
        Args:
            request_name: Name passed through as the spec's request name
            response_name: Name passed through as the spec's response name
        """
        super().__init__(
            request_name=request_name,
            request_schema=RowEmbeddingsRequest,
            response_name=response_name,
            response_schema=RowEmbeddingsResponse,
            impl=RowEmbeddingsQueryClient,
        )
|
||||||
60
trustgraph-base/trustgraph/clients/row_embeddings_client.py
Normal file
60
trustgraph-base/trustgraph/clients/row_embeddings_client.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
|
||||||
|
import _pulsar
|
||||||
|
|
||||||
|
from .. schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||||
|
from .. schema import row_embeddings_request_queue
|
||||||
|
from .. schema import row_embeddings_response_queue
|
||||||
|
from . base import BaseClient
|
||||||
|
|
||||||
|
# Short module-level aliases for the _pulsar.LoggerLevel values; ERROR is the
# default log level of RowEmbeddingsClient.
ERROR = _pulsar.LoggerLevel.Error
WARN = _pulsar.LoggerLevel.Warn
INFO = _pulsar.LoggerLevel.Info
DEBUG = _pulsar.LoggerLevel.Debug
|
||||||
|
|
||||||
|
class RowEmbeddingsClient(BaseClient):
    """
    Client for the row embeddings query service, built on BaseClient with
    RowEmbeddingsRequest / RowEmbeddingsResponse as its schemas.
    """

    def __init__(
            self, log_level=ERROR,
            subscriber=None,
            input_queue=None,
            output_queue=None,
            pulsar_host="pulsar://pulsar:6650",
            pulsar_api_key=None,
    ):
        """
        Args:
            log_level: Log level passed to BaseClient (default: ERROR)
            subscriber: Subscriber passed to BaseClient
            input_queue: Request queue; defaults to
                row_embeddings_request_queue when None
            output_queue: Response queue; defaults to
                row_embeddings_response_queue when None
            pulsar_host: Pulsar host URL passed to BaseClient
            pulsar_api_key: Pulsar API key passed to BaseClient
        """

        # Identity test: `== None` would defer to a custom __eq__ on the
        # argument and is not the idiomatic None check
        if input_queue is None:
            input_queue = row_embeddings_request_queue

        if output_queue is None:
            output_queue = row_embeddings_response_queue

        super(RowEmbeddingsClient, self).__init__(
            log_level=log_level,
            subscriber=subscriber,
            input_queue=input_queue,
            output_queue=output_queue,
            pulsar_host=pulsar_host,
            pulsar_api_key=pulsar_api_key,
            input_schema=RowEmbeddingsRequest,
            output_schema=RowEmbeddingsResponse,
        )

    def request(
            self, vectors, schema_name, user="trustgraph", collection="default",
            index_name=None, limit=10, timeout=300
    ):
        """
        Run a row embeddings query and return the response's matches.

        Args:
            vectors: Query vectors to search with
            schema_name: Schema name to search within
            user: User identifier (default: "trustgraph")
            collection: Collection identifier (default: "default")
            index_name: Optional index name filter; only sent when truthy
            limit: Maximum number of results (default: 10)
            timeout: Timeout forwarded to ``self.call`` (default: 300)

        Returns:
            The ``matches`` attribute of the response.

        Raises:
            RuntimeError: If the response has an error set; the message is
                "<type>: <message>".
        """
        kwargs = dict(
            user=user, collection=collection,
            vectors=vectors, schema_name=schema_name,
            limit=limit, timeout=timeout
        )
        if index_name:
            kwargs["index_name"] = index_name

        response = self.call(**kwargs)

        if response.error:
            raise RuntimeError(f"{response.error.type}: {response.error.message}")

        return response.matches
|
||||||
|
|
@ -19,7 +19,8 @@ from .translators.prompt import PromptRequestTranslator, PromptResponseTranslato
|
||||||
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
|
from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
|
||||||
from .translators.embeddings_query import (
|
from .translators.embeddings_query import (
|
||||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
|
||||||
|
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
|
||||||
)
|
)
|
||||||
from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||||
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
||||||
|
|
@ -107,11 +108,17 @@ TranslatorRegistry.register_service(
|
||||||
)
|
)
|
||||||
|
|
||||||
TranslatorRegistry.register_service(
|
TranslatorRegistry.register_service(
|
||||||
"graph-embeddings-query",
|
"graph-embeddings-query",
|
||||||
GraphEmbeddingsRequestTranslator(),
|
GraphEmbeddingsRequestTranslator(),
|
||||||
GraphEmbeddingsResponseTranslator()
|
GraphEmbeddingsResponseTranslator()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Request/response translators for the "row-embeddings-query" service
TranslatorRegistry.register_service(
    "row-embeddings-query",
    RowEmbeddingsRequestTranslator(),
    RowEmbeddingsResponseTranslator(),
)
|
||||||
|
|
||||||
TranslatorRegistry.register_service(
|
TranslatorRegistry.register_service(
|
||||||
"rows-query",
|
"rows-query",
|
||||||
RowsQueryRequestTranslator(),
|
RowsQueryRequestTranslator(),
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,8 @@ from .flow import FlowRequestTranslator, FlowResponseTranslator
|
||||||
from .prompt import PromptRequestTranslator, PromptResponseTranslator
|
from .prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||||
from .embeddings_query import (
|
from .embeddings_query import (
|
||||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator,
|
||||||
|
RowEmbeddingsRequestTranslator, RowEmbeddingsResponseTranslator
|
||||||
)
|
)
|
||||||
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||||
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from typing import Dict, Any, Tuple
|
from typing import Dict, Any, Tuple
|
||||||
from ...schema import (
|
from ...schema import (
|
||||||
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
|
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
|
||||||
GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
GraphEmbeddingsRequest, GraphEmbeddingsResponse,
|
||||||
|
RowEmbeddingsRequest, RowEmbeddingsResponse, RowIndexMatch
|
||||||
)
|
)
|
||||||
from .base import MessageTranslator
|
from .base import MessageTranslator
|
||||||
from .primitives import ValueTranslator
|
from .primitives import ValueTranslator
|
||||||
|
|
@ -92,3 +93,62 @@ class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
||||||
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
"""Returns (response_dict, is_final)"""
|
"""Returns (response_dict, is_final)"""
|
||||||
return self.from_pulsar(obj), True
|
return self.from_pulsar(obj), True
|
||||||
|
|
||||||
|
|
||||||
|
class RowEmbeddingsRequestTranslator(MessageTranslator):
    """Translator for RowEmbeddingsRequest schema objects.

    Converts between the plain-dict form of a row embeddings query and the
    RowEmbeddingsRequest schema object.
    """

    def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsRequest:
        """Build a RowEmbeddingsRequest from a request dict.

        Args:
            data: Request body. "vectors" is mandatory; "limit", "user",
                "collection", "schema_name" and "index_name" are optional.

        Returns:
            The populated RowEmbeddingsRequest.

        Raises:
            KeyError: If "vectors" is missing from ``data``.
        """
        return RowEmbeddingsRequest(
            vectors=data["vectors"],
            limit=int(data.get("limit", 10)),
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
            schema_name=data.get("schema_name", ""),
            # An absent or null index_name means "no index filter".
            # Normalise it to "" so the field always holds a string, the
            # same way schema_name falls back to "" above; from_pulsar()
            # treats "" as "no filter" and omits the key.
            index_name=data.get("index_name") or "",
        )

    def from_pulsar(self, obj: RowEmbeddingsRequest) -> Dict[str, Any]:
        """Convert a RowEmbeddingsRequest back into a plain dict.

        "index_name" is only included when it is set to a non-empty value.
        """
        result = {
            "vectors": obj.vectors,
            "limit": obj.limit,
            "user": obj.user,
            "collection": obj.collection,
            "schema_name": obj.schema_name,
        }
        if obj.index_name:
            result["index_name"] = obj.index_name
        return result
|
||||||
|
|
||||||
|
|
||||||
|
class RowEmbeddingsResponseTranslator(MessageTranslator):
    """Translator for RowEmbeddingsResponse schema objects.

    Only the schema-object-to-dict direction is implemented.
    """

    def to_pulsar(self, data: Dict[str, Any]) -> RowEmbeddingsResponse:
        """Unsupported direction; always raises NotImplementedError."""
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: RowEmbeddingsResponse) -> Dict[str, Any]:
        """Convert a RowEmbeddingsResponse into a plain dict.

        The "error" key is present only when the response carries an
        error; the "matches" key is present whenever ``obj.matches`` is
        not None (an empty list is still emitted).
        """
        payload: Dict[str, Any] = {}

        failure = obj.error
        if failure is not None:
            payload["error"] = {
                "type": failure.type,
                "message": failure.message,
            }

        found = obj.matches
        if found is not None:
            rows = []
            for entry in found:
                rows.append({
                    "index_name": entry.index_name,
                    "index_value": entry.index_value,
                    "text": entry.text,
                    "score": entry.score,
                })
            payload["matches"] = rows

        return payload

    def from_response_with_completion(self, obj: RowEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); is_final is always True."""
        translated = self.from_pulsar(obj)
        return translated, True
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ tg-invoke-document-embeddings = "trustgraph.cli.invoke_document_embeddings:main"
|
||||||
tg-invoke-mcp-tool = "trustgraph.cli.invoke_mcp_tool:main"
|
tg-invoke-mcp-tool = "trustgraph.cli.invoke_mcp_tool:main"
|
||||||
tg-invoke-nlp-query = "trustgraph.cli.invoke_nlp_query:main"
|
tg-invoke-nlp-query = "trustgraph.cli.invoke_nlp_query:main"
|
||||||
tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
|
tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
|
||||||
|
tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main"
|
||||||
tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
|
tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
|
||||||
tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
|
tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
|
||||||
tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"
|
tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"
|
||||||
|
|
|
||||||
126
trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py
Normal file
126
trustgraph-cli/trustgraph/cli/invoke_row_embeddings.py
Normal file
|
|
@ -0,0 +1,126 @@
|
||||||
|
"""
|
||||||
|
Queries row data by text similarity using vector embeddings on indexed fields.
|
||||||
|
Returns matching rows with their index values and similarity scores.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from trustgraph.api import Api
|
||||||
|
|
||||||
|
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||||
|
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||||
|
|
||||||
|
def query(url, flow_id, query_text, schema_name, user, collection, index_name, limit, token=None):
    """Run a row embeddings query against a flow and print each match.

    Opens a socket connection through the API client, invokes the flow's
    row embeddings query with the supplied text and filters, prints the
    index name, index values, text and score of every match, and closes
    the socket afterwards whether or not the query succeeded.
    """

    # Set up the API client and the per-flow socket handle
    client = Api(url=url, token=token)
    socket = client.socket()
    flow = socket.flow(flow_id)

    try:
        response = flow.row_embeddings_query(
            text=query_text,
            schema_name=schema_name,
            user=user,
            collection=collection,
            index_name=index_name,
            limit=limit,
        )

        for hit in response.get("matches", []):
            print(f"Index: {hit['index_name']}")
            print(f" Values: {hit['index_value']}")
            print(f" Text: {hit['text']}")
            print(f" Score: {hit['score']:.4f}")
            print()

    finally:
        # Always release the socket connection
        socket.close()
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point for tg-invoke-row-embeddings.

    Parses the command-line options, runs the query and prints any
    exception raised by it instead of propagating it.
    """

    cli = argparse.ArgumentParser(
        prog="tg-invoke-row-embeddings",
        description=__doc__,
    )

    # Connection options
    cli.add_argument(
        "-u", "--url",
        default=default_url,
        help=f"API URL (default: {default_url})",
    )
    cli.add_argument(
        "-t", "--token",
        default=default_token,
        help="Authentication token (default: $TRUSTGRAPH_TOKEN)",
    )
    cli.add_argument(
        "-f", "--flow-id",
        default="default",
        help="Flow ID (default: default)",
    )

    # Scope of the search
    cli.add_argument(
        "-U", "--user",
        default="trustgraph",
        help="User/keyspace (default: trustgraph)",
    )
    cli.add_argument(
        "-c", "--collection",
        default="default",
        help="Collection (default: default)",
    )
    cli.add_argument(
        "-s", "--schema-name",
        required=True,
        help="Schema name to search within (required)",
    )
    cli.add_argument(
        "-i", "--index-name",
        default=None,
        help="Index name to filter search (optional)",
    )
    cli.add_argument(
        "-l", "--limit",
        type=int,
        default=10,
        help="Maximum number of results (default: 10)",
    )

    # The query text itself; nargs=1 yields a one-element list
    cli.add_argument(
        "query",
        nargs=1,
        help="Query text to search for similar row index values",
    )

    opts = cli.parse_args()

    try:
        query(
            url=opts.url,
            flow_id=opts.flow_id,
            query_text=opts.query[0],
            schema_name=opts.schema_name,
            user=opts.user,
            collection=opts.collection,
            index_name=opts.index_name,
            limit=opts.limit,
            token=opts.token,
        )
    except Exception as exc:
        print("Exception:", exc, flush=True)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -2,8 +2,9 @@
|
||||||
Configures and registers tools in the TrustGraph system.
|
Configures and registers tools in the TrustGraph system.
|
||||||
|
|
||||||
This script allows you to define agent tools with various types including:
|
This script allows you to define agent tools with various types including:
|
||||||
- knowledge-query: Query knowledge bases
|
- knowledge-query: Query knowledge bases
|
||||||
- structured-query: Query structured data using natural language
|
- structured-query: Query structured data using natural language
|
||||||
|
- row-embeddings-query: Semantic search on structured data indexes
|
||||||
- text-completion: Text generation
|
- text-completion: Text generation
|
||||||
- mcp-tool: Reference to MCP (Model Context Protocol) tools
|
- mcp-tool: Reference to MCP (Model Context Protocol) tools
|
||||||
- prompt: Prompt template execution
|
- prompt: Prompt template execution
|
||||||
|
|
@ -64,6 +65,9 @@ def set_tool(
|
||||||
mcp_tool : str,
|
mcp_tool : str,
|
||||||
collection : str,
|
collection : str,
|
||||||
template : str,
|
template : str,
|
||||||
|
schema_name : str,
|
||||||
|
index_name : str,
|
||||||
|
limit : int,
|
||||||
arguments : List[Argument],
|
arguments : List[Argument],
|
||||||
group : List[str],
|
group : List[str],
|
||||||
state : str,
|
state : str,
|
||||||
|
|
@ -89,6 +93,12 @@ def set_tool(
|
||||||
|
|
||||||
if template: object["template"] = template
|
if template: object["template"] = template
|
||||||
|
|
||||||
|
if schema_name: object["schema-name"] = schema_name
|
||||||
|
|
||||||
|
if index_name: object["index-name"] = index_name
|
||||||
|
|
||||||
|
if limit: object["limit"] = limit
|
||||||
|
|
||||||
if arguments:
|
if arguments:
|
||||||
object["arguments"] = [
|
object["arguments"] = [
|
||||||
{
|
{
|
||||||
|
|
@ -120,30 +130,37 @@ def main():
|
||||||
description=__doc__,
|
description=__doc__,
|
||||||
epilog=textwrap.dedent('''
|
epilog=textwrap.dedent('''
|
||||||
Valid tool types:
|
Valid tool types:
|
||||||
knowledge-query - Query knowledge bases (fixed args)
|
knowledge-query - Query knowledge bases (fixed args)
|
||||||
structured-query - Query structured data using natural language (fixed args)
|
structured-query - Query structured data using natural language (fixed args)
|
||||||
text-completion - Text completion/generation (fixed args)
|
row-embeddings-query - Semantic search on structured data indexes (fixed args)
|
||||||
mcp-tool - Model Control Protocol tool (configurable args)
|
text-completion - Text completion/generation (fixed args)
|
||||||
prompt - Prompt template query (configurable args)
|
mcp-tool - Model Control Protocol tool (configurable args)
|
||||||
|
prompt - Prompt template query (configurable args)
|
||||||
Note: Tools marked "(fixed args)" have predefined arguments and don't need
|
|
||||||
|
Note: Tools marked "(fixed args)" have predefined arguments and don't need
|
||||||
--argument specified. Tools marked "(configurable args)" require --argument.
|
--argument specified. Tools marked "(configurable args)" require --argument.
|
||||||
|
|
||||||
Valid argument types:
|
Valid argument types:
|
||||||
string - String/text parameter
|
string - String/text parameter
|
||||||
number - Numeric parameter
|
number - Numeric parameter
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
%(prog)s --id weather_tool --name get_weather \\
|
%(prog)s --id weather_tool --name get_weather \\
|
||||||
--type knowledge-query \\
|
--type knowledge-query \\
|
||||||
--description "Get weather information for a location" \\
|
--description "Get weather information for a location" \\
|
||||||
--collection weather_data
|
--collection weather_data
|
||||||
|
|
||||||
%(prog)s --id data_query_tool --name query_data \\
|
%(prog)s --id data_query_tool --name query_data \\
|
||||||
--type structured-query \\
|
--type structured-query \\
|
||||||
--description "Query structured data using natural language" \\
|
--description "Query structured data using natural language" \\
|
||||||
--collection sales_data
|
--collection sales_data
|
||||||
|
|
||||||
|
%(prog)s --id customer_search --name find_customer \\
|
||||||
|
--type row-embeddings-query \\
|
||||||
|
--description "Find customers by name using semantic search" \\
|
||||||
|
--schema-name customers --collection sales \\
|
||||||
|
--index-name full_name --limit 20
|
||||||
|
|
||||||
%(prog)s --id calc_tool --name calculate --type mcp-tool \\
|
%(prog)s --id calc_tool --name calculate --type mcp-tool \\
|
||||||
--description "Perform mathematical calculations" \\
|
--description "Perform mathematical calculations" \\
|
||||||
--mcp-tool calculator \\
|
--mcp-tool calculator \\
|
||||||
|
|
@ -181,7 +198,7 @@ def main():
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--type',
|
'--type',
|
||||||
help=f'Tool type, one of: knowledge-query, structured-query, text-completion, mcp-tool, prompt',
|
help=f'Tool type, one of: knowledge-query, structured-query, row-embeddings-query, text-completion, mcp-tool, prompt',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -191,7 +208,23 @@ def main():
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--collection',
|
'--collection',
|
||||||
help=f'For knowledge-query and structured-query types: collection to query',
|
help=f'For knowledge-query, structured-query, and row-embeddings-query types: collection to query',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--schema-name',
|
||||||
|
help=f'For row-embeddings-query type: schema name to search within (required)',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--index-name',
|
||||||
|
help=f'For row-embeddings-query type: specific index to filter search (optional)',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--limit',
|
||||||
|
type=int,
|
||||||
|
help=f'For row-embeddings-query type: maximum results to return (default: 10)',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -227,7 +260,8 @@ def main():
|
||||||
try:
|
try:
|
||||||
|
|
||||||
valid_types = [
|
valid_types = [
|
||||||
"knowledge-query", "structured-query", "text-completion", "mcp-tool", "prompt"
|
"knowledge-query", "structured-query", "row-embeddings-query",
|
||||||
|
"text-completion", "mcp-tool", "prompt"
|
||||||
]
|
]
|
||||||
|
|
||||||
if args.id is None:
|
if args.id is None:
|
||||||
|
|
@ -261,6 +295,9 @@ def main():
|
||||||
mcp_tool=mcp_tool,
|
mcp_tool=mcp_tool,
|
||||||
collection=args.collection,
|
collection=args.collection,
|
||||||
template=args.template,
|
template=args.template,
|
||||||
|
schema_name=args.schema_name,
|
||||||
|
index_name=args.index_name,
|
||||||
|
limit=args.limit,
|
||||||
arguments=arguments,
|
arguments=arguments,
|
||||||
group=args.group,
|
group=args.group,
|
||||||
state=args.state,
|
state=args.state,
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,9 @@ Displays the current agent tool configurations
|
||||||
Shows all configured tools including their types:
|
Shows all configured tools including their types:
|
||||||
- knowledge-query: Tools that query knowledge bases
|
- knowledge-query: Tools that query knowledge bases
|
||||||
- structured-query: Tools that query structured data using natural language
|
- structured-query: Tools that query structured data using natural language
|
||||||
|
- row-embeddings-query: Tools for semantic search on structured data indexes
|
||||||
- text-completion: Tools for text generation
|
- text-completion: Tools for text generation
|
||||||
- mcp-tool: References to MCP (Model Context Protocol) tools
|
- mcp-tool: References to MCP (Model Context Protocol) tools
|
||||||
- prompt: Tools that execute prompt templates
|
- prompt: Tools that execute prompt templates
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -41,11 +42,19 @@ def show_config(url, token=None):
|
||||||
|
|
||||||
if tp == "mcp-tool":
|
if tp == "mcp-tool":
|
||||||
table.append(("mcp-tool", data["mcp-tool"]))
|
table.append(("mcp-tool", data["mcp-tool"]))
|
||||||
|
|
||||||
if tp == "knowledge-query" or tp == "structured-query":
|
if tp in ("knowledge-query", "structured-query", "row-embeddings-query"):
|
||||||
if "collection" in data:
|
if "collection" in data:
|
||||||
table.append(("collection", data["collection"]))
|
table.append(("collection", data["collection"]))
|
||||||
|
|
||||||
|
if tp == "row-embeddings-query":
|
||||||
|
if "schema-name" in data:
|
||||||
|
table.append(("schema-name", data["schema-name"]))
|
||||||
|
if "index-name" in data:
|
||||||
|
table.append(("index-name", data["index-name"]))
|
||||||
|
if "limit" in data:
|
||||||
|
table.append(("limit", data["limit"]))
|
||||||
|
|
||||||
if tp == "prompt":
|
if tp == "prompt":
|
||||||
table.append(("template", data["template"]))
|
table.append(("template", data["template"]))
|
||||||
for n, arg in enumerate(data["arguments"]):
|
for n, arg in enumerate(data["arguments"]):
|
||||||
|
|
|
||||||
|
|
@ -13,10 +13,11 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
||||||
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
|
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
|
||||||
|
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
|
||||||
|
|
||||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||||
|
|
||||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl
|
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl
|
||||||
from . agent_manager import AgentManager
|
from . agent_manager import AgentManager
|
||||||
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
|
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
|
||||||
|
|
||||||
|
|
@ -87,6 +88,20 @@ class Processor(AgentService):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.register_specification(
|
||||||
|
EmbeddingsClientSpec(
|
||||||
|
request_name = "embeddings-request",
|
||||||
|
response_name = "embeddings-response",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_specification(
|
||||||
|
RowEmbeddingsQueryClientSpec(
|
||||||
|
request_name = "row-embeddings-query-request",
|
||||||
|
response_name = "row-embeddings-query-response",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def on_tools_config(self, config, version):
|
async def on_tools_config(self, config, version):
|
||||||
|
|
||||||
logger.info(f"Loading configuration version {version}")
|
logger.info(f"Loading configuration version {version}")
|
||||||
|
|
@ -147,11 +162,21 @@ class Processor(AgentService):
|
||||||
)
|
)
|
||||||
elif impl_id == "structured-query":
|
elif impl_id == "structured-query":
|
||||||
impl = functools.partial(
|
impl = functools.partial(
|
||||||
StructuredQueryImpl,
|
StructuredQueryImpl,
|
||||||
collection=data.get("collection"),
|
collection=data.get("collection"),
|
||||||
user=None # User will be provided dynamically via context
|
user=None # User will be provided dynamically via context
|
||||||
)
|
)
|
||||||
arguments = StructuredQueryImpl.get_arguments()
|
arguments = StructuredQueryImpl.get_arguments()
|
||||||
|
elif impl_id == "row-embeddings-query":
|
||||||
|
impl = functools.partial(
|
||||||
|
RowEmbeddingsQueryImpl,
|
||||||
|
schema_name=data.get("schema-name"),
|
||||||
|
collection=data.get("collection"),
|
||||||
|
user=None, # User will be provided dynamically via context
|
||||||
|
index_name=data.get("index-name"), # Optional filter
|
||||||
|
limit=int(data.get("limit", 10)) # Max results
|
||||||
|
)
|
||||||
|
arguments = RowEmbeddingsQueryImpl.get_arguments()
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Tool type {impl_id} not known"
|
f"Tool type {impl_id} not known"
|
||||||
|
|
@ -327,11 +352,11 @@ class Processor(AgentService):
|
||||||
def __init__(self, flow, user):
|
def __init__(self, flow, user):
|
||||||
self._flow = flow
|
self._flow = flow
|
||||||
self._user = user
|
self._user = user
|
||||||
|
|
||||||
def __call__(self, service_name):
|
def __call__(self, service_name):
|
||||||
client = self._flow(service_name)
|
client = self._flow(service_name)
|
||||||
# For structured query clients, store user context
|
# For query clients that need user context, store it
|
||||||
if service_name == "structured-query-request":
|
if service_name in ("structured-query-request", "row-embeddings-query-request"):
|
||||||
client._current_user = self._user
|
client._current_user = self._user
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,62 @@ class StructuredQueryImpl:
|
||||||
return str(result)
|
return str(result)
|
||||||
|
|
||||||
|
|
||||||
|
# This tool implementation knows how to query row embeddings for semantic search
|
||||||
|
class RowEmbeddingsQueryImpl:
    """Agent tool that runs a semantic search over row embeddings.

    The query text is first turned into vectors via the embeddings
    client, then those vectors are sent to the row embeddings query
    client. Matches are rendered as a short text listing for the agent.
    """

    # Used when neither the client context nor the tool supplies a user
    DEFAULT_USER = "trustgraph"
    # Used when the tool was configured without a collection
    DEFAULT_COLLECTION = "default"

    def __init__(self, context, schema_name, collection=None, user=None, index_name=None, limit=10):
        """Store the tool configuration.

        Args:
            context: Callable mapping a service name to a client.
            schema_name: Schema whose row indexes are searched.
            collection: Collection to query; "default" is used if unset.
            user: Fallback user when the client carries no user context.
            index_name: Optional filter restricting the search to one index.
            limit: Maximum number of results to request.
        """
        self.context = context
        self.schema_name = schema_name
        self.collection = collection
        self.user = user
        self.index_name = index_name  # Optional: filter to specific index
        self.limit = limit  # Max results to return

    @staticmethod
    def get_arguments():
        """Return the fixed argument list this tool exposes to the agent."""
        return [
            Argument(
                name="query",
                type="string",
                description="Text to search for semantically similar values in the structured data index"
            )
        ]

    def _resolve_user(self, client):
        """Pick the user for the query: client context, tool config, default.

        A ``_current_user`` attribute that exists but is None or empty is
        treated as "not provided", so the fallbacks still apply instead of
        sending an empty user downstream.
        """
        return (
            getattr(client, "_current_user", None)
            or self.user
            or self.DEFAULT_USER
        )

    @staticmethod
    def _format_matches(matches):
        """Render matches as text for agent consumption.

        Index values are converted with ``str`` before joining so that
        non-string values do not raise ``TypeError``.
        """
        if not matches:
            return "No matching records found"

        lines = []
        for match in matches:
            values = ", ".join(map(str, match["index_value"]))
            lines.append(
                f"- {match['index_name']}: {values} (score: {match['score']:.3f})"
            )

        return "Matching records:\n" + "\n".join(lines)

    async def invoke(self, **arguments):
        """Embed the "query" argument, search row embeddings, format results.

        Returns:
            A text listing of matches, or "No matching records found".
        """
        # First get embeddings for the query text
        embeddings_client = self.context("embeddings-request")
        logger.debug("Getting embeddings for row query...")

        query_text = arguments.get("query")
        vectors = await embeddings_client.embed(query_text)

        # Now query row embeddings
        client = self.context("row-embeddings-query-request")
        logger.debug("Row embeddings query...")

        matches = await client.row_embeddings_query(
            vectors=vectors,
            schema_name=self.schema_name,
            user=self._resolve_user(client),
            collection=self.collection or self.DEFAULT_COLLECTION,
            index_name=self.index_name,
            limit=self.limit
        )

        return self._format_matches(matches)
|
||||||
|
|
||||||
|
|
||||||
# This tool implementation knows how to execute prompt templates
|
# This tool implementation knows how to execute prompt templates
|
||||||
class PromptImpl:
|
class PromptImpl:
|
||||||
def __init__(self, context, template_id, arguments=None):
|
def __init__(self, context, template_id, arguments=None):
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ from . structured_diag import StructuredDiagRequestor
|
||||||
from . embeddings import EmbeddingsRequestor
|
from . embeddings import EmbeddingsRequestor
|
||||||
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||||
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
|
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
|
||||||
|
from . row_embeddings_query import RowEmbeddingsQueryRequestor
|
||||||
from . mcp_tool import McpToolRequestor
|
from . mcp_tool import McpToolRequestor
|
||||||
from . text_load import TextLoad
|
from . text_load import TextLoad
|
||||||
from . document_load import DocumentLoad
|
from . document_load import DocumentLoad
|
||||||
|
|
@ -62,6 +63,7 @@ request_response_dispatchers = {
|
||||||
"nlp-query": NLPQueryRequestor,
|
"nlp-query": NLPQueryRequestor,
|
||||||
"structured-query": StructuredQueryRequestor,
|
"structured-query": StructuredQueryRequestor,
|
||||||
"structured-diag": StructuredDiagRequestor,
|
"structured-diag": StructuredDiagRequestor,
|
||||||
|
"row-embeddings": RowEmbeddingsQueryRequestor,
|
||||||
}
|
}
|
||||||
|
|
||||||
global_dispatchers = {
|
global_dispatchers = {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
|
||||||
|
from ... schema import RowEmbeddingsRequest, RowEmbeddingsResponse
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
|
class RowEmbeddingsQueryRequestor(ServiceRequestor):
    """Service requestor for row embeddings queries.

    Configures the base requestor with the RowEmbeddingsRequest /
    RowEmbeddingsResponse schemas and uses the translators registered
    under "row-embeddings-query" to convert request bodies and responses.
    """

    def __init__(
            self, backend, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):
        super().__init__(
            backend=backend,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=RowEmbeddingsRequest,
            response_schema=RowEmbeddingsResponse,
            subscription=subscriber,
            consumer_name=consumer,
            timeout=timeout,
        )

        # Both translators are looked up under the same service name
        service = "row-embeddings-query"
        self.request_translator = TranslatorRegistry.get_request_translator(service)
        self.response_translator = TranslatorRegistry.get_response_translator(service)

    def to_request(self, body):
        """Translate an incoming request body into the request schema object."""
        translator = self.request_translator
        return translator.to_pulsar(body)

    def from_response(self, message):
        """Translate a response message via the response translator's
        from_response_with_completion()."""
        translator = self.response_translator
        return translator.from_response_with_completion(message)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue