mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-05 13:22:37 +02:00
NLP query to GraphQL service (#491)
This commit is contained in:
parent
c078ca45cd
commit
8d4aa0069c
7 changed files with 440 additions and 5 deletions
|
|
@ -7,16 +7,15 @@ from ..core.topic import topic
|
||||||
|
|
||||||
# NLP to Structured Query Service - converts natural language to GraphQL
|
# NLP to Structured Query Service - converts natural language to GraphQL
|
||||||
|
|
||||||
class NLPToStructuredQueryRequest(Record):
|
class QuestionToStructuredQueryRequest(Record):
|
||||||
natural_language_query = String()
|
question = String()
|
||||||
max_results = Integer()
|
max_results = Integer()
|
||||||
context_hints = Map(String()) # Optional context for query generation
|
|
||||||
|
|
||||||
class NLPToStructuredQueryResponse(Record):
|
class QuestionToStructuredQueryResponse(Record):
|
||||||
error = Error()
|
error = Error()
|
||||||
graphql_query = String() # Generated GraphQL query
|
graphql_query = String() # Generated GraphQL query
|
||||||
variables = Map(String()) # GraphQL variables if any
|
variables = Map(String()) # GraphQL variables if any
|
||||||
detected_schemas = Array(String()) # Which schemas the query targets
|
detected_schemas = Array(String()) # Which schemas the query targets
|
||||||
confidence = Double()
|
confidence = Double()
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,7 @@ kg-store = "trustgraph.storage.knowledge:run"
|
||||||
librarian = "trustgraph.librarian:run"
|
librarian = "trustgraph.librarian:run"
|
||||||
mcp-tool = "trustgraph.agent.mcp_tool:run"
|
mcp-tool = "trustgraph.agent.mcp_tool:run"
|
||||||
metering = "trustgraph.metering:run"
|
metering = "trustgraph.metering:run"
|
||||||
|
nlp-query = "trustgraph.retrieval.nlp_query:run"
|
||||||
objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
|
objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
|
||||||
objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
|
objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
|
||||||
oe-write-milvus = "trustgraph.storage.object_embeddings.milvus:run"
|
oe-write-milvus = "trustgraph.storage.object_embeddings.milvus:run"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from . service import *
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
from . service import run
|
||||||
|
|
||||||
|
run()
|
||||||
25
trustgraph-flow/trustgraph/retrieval/nlp_query/pass1.txt
Normal file
25
trustgraph-flow/trustgraph/retrieval/nlp_query/pass1.txt
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
You are a database schema selection expert. Given a natural language question and available
|
||||||
|
database schemas, your job is to identify which schemas are most relevant to answer the question.
|
||||||
|
|
||||||
|
## Available Schemas:
|
||||||
|
{% for schema in schemas %}
|
||||||
|
**{{ schema.name }}**: {{ schema.description }}
|
||||||
|
Fields:
|
||||||
|
{% for field in schema.fields %}
|
||||||
|
- {{ field.name }} ({{ field.type }}): {{ field.description }}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
## Question:
|
||||||
|
{{ question }}
|
||||||
|
|
||||||
|
## Instructions:
|
||||||
|
1. Analyze the question to understand what data is being requested
|
||||||
|
2. Examine each schema to understand what data it contains
|
||||||
|
3. Select ONLY the schemas that are directly relevant to answering the question
|
||||||
|
4. Return your answer as a JSON array of schema names
|
||||||
|
|
||||||
|
## Response Format:
|
||||||
|
Return ONLY a JSON array of schema names, nothing else.
|
||||||
|
Example: ["customers", "orders", "products"]
|
||||||
101
trustgraph-flow/trustgraph/retrieval/nlp_query/pass2.txt
Normal file
101
trustgraph-flow/trustgraph/retrieval/nlp_query/pass2.txt
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
You are a GraphQL query generation expert. Given a natural language question and relevant database
|
||||||
|
schemas, generate a precise GraphQL query to answer the question.
|
||||||
|
|
||||||
|
## Question:
|
||||||
|
{{ question }}
|
||||||
|
|
||||||
|
## Relevant Schemas:
|
||||||
|
{% for schema in schemas %}
|
||||||
|
**{{ schema.name }}**: {{ schema.description }}
|
||||||
|
Fields:
|
||||||
|
{% for field in schema.fields %}
|
||||||
|
- {{ field.name }} ({{ field.type }}){% if field.description %}: {{ field.description }}{% endif
|
||||||
|
%}{% if field.primary_key %} [PRIMARY KEY]{% endif %}{% if field.required %} [REQUIRED]{% endif
|
||||||
|
%}{% if field.indexed %} [INDEXED]{% endif %}{% if field.enum_values %} [OPTIONS: {{
|
||||||
|
field.enum_values|join(', ') }}]{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
## GraphQL Query Rules:
|
||||||
|
1. Use the schema names as GraphQL query fields (e.g., `customers`, `orders`)
|
||||||
|
2. Apply filters using the `where` parameter with nested filter objects
|
||||||
|
3. Available filter operators per field type:
|
||||||
|
- String fields: `eq`, `contains`, `startsWith`, `endsWith`, `in`, `not`, `not_in`
|
||||||
|
- Integer/Float fields: `eq`, `gt`, `gte`, `lt`, `lte`, `in`, `not`, `not_in`
|
||||||
|
4. Use `order_by` for sorting (field name as string)
|
||||||
|
5. Use `direction` for sort direction: `ASC` or `DESC`
|
||||||
|
6. Use `limit` to restrict number of results
|
||||||
|
7. Select specific fields in the query body
|
||||||
|
|
||||||
|
## Example GraphQL Queries:
|
||||||
|
|
||||||
|
**Question**: "Show me customers from California"
|
||||||
|
```graphql
|
||||||
|
query {
|
||||||
|
customers(where: {state: {eq: "California"}}, limit: 100) {
|
||||||
|
customer_id
|
||||||
|
name
|
||||||
|
email
|
||||||
|
state
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Question: "Top 10 products by price"
|
||||||
|
query {
|
||||||
|
products(order_by: "price", direction: DESC, limit: 10) {
|
||||||
|
product_id
|
||||||
|
name
|
||||||
|
price
|
||||||
|
category
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Question: "Recent orders over $100"
|
||||||
|
query {
|
||||||
|
orders(
|
||||||
|
where: {
|
||||||
|
total_amount: {gt: 100}
|
||||||
|
order_date: {gte: "2024-01-01"}
|
||||||
|
}
|
||||||
|
order_by: "order_date"
|
||||||
|
direction: DESC
|
||||||
|
limit: 50
|
||||||
|
) {
|
||||||
|
order_id
|
||||||
|
customer_id
|
||||||
|
total_amount
|
||||||
|
order_date
|
||||||
|
status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Instructions:
|
||||||
|
|
||||||
|
1. Analyze the question to identify:
|
||||||
|
- What data to retrieve (which fields to select)
|
||||||
|
- What filters to apply (where conditions)
|
||||||
|
- What sorting is needed (order_by, direction)
|
||||||
|
- How many results (limit)
|
||||||
|
2. Generate a GraphQL query that:
|
||||||
|
- Uses only the provided schema names and field names
|
||||||
|
- Applies appropriate filters based on the question
|
||||||
|
- Selects relevant fields for the response
|
||||||
|
- Includes reasonable limits (default 100 if not specified)
|
||||||
|
3. If variables are needed, include them in the response
|
||||||
|
|
||||||
|
Response Format:
|
||||||
|
|
||||||
|
Return a JSON object with:
|
||||||
|
- "query": the GraphQL query string
|
||||||
|
- "variables": object with any GraphQL variables (empty object if none)
|
||||||
|
- "confidence": float between 0.0-1.0 indicating confidence in the query
|
||||||
|
|
||||||
|
Example:
|
||||||
|
{
|
||||||
|
"query": "query { customers(where: {state: {eq: \"California\"}}, limit: 100) { customer_id name
|
||||||
|
email state } }",
|
||||||
|
"variables": {},
|
||||||
|
"confidence": 0.95
|
||||||
|
}
|
||||||
|
|
||||||
303
trustgraph-flow/trustgraph/retrieval/nlp_query/service.py
Normal file
303
trustgraph-flow/trustgraph/retrieval/nlp_query/service.py
Normal file
|
|
@ -0,0 +1,303 @@
|
||||||
|
"""
|
||||||
|
NLP to Structured Query Service - converts natural language questions to GraphQL queries.
|
||||||
|
Two-phase approach: 1) Select relevant schemas, 2) Generate GraphQL query.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
|
||||||
|
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
|
||||||
|
from ...schema import PromptRequest, PromptResponse
|
||||||
|
from ...schema import Error, RowSchema, Field as SchemaField
|
||||||
|
|
||||||
|
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, ClientSpec
|
||||||
|
|
||||||
|
# Module logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
default_ident = "nlp-query"
|
||||||
|
default_schema_selection_template = "schema-selection"
|
||||||
|
default_graphql_generation_template = "graphql-generation"
|
||||||
|
|
||||||
|
class Processor(FlowProcessor):
|
||||||
|
|
||||||
|
def __init__(self, **params):
|
||||||
|
|
||||||
|
id = params.get("id", default_ident)
|
||||||
|
|
||||||
|
# Config key for schemas
|
||||||
|
self.config_key = params.get("config_type", "schema")
|
||||||
|
|
||||||
|
# Configurable prompt template names
|
||||||
|
self.schema_selection_template = params.get("schema_selection_template", default_schema_selection_template)
|
||||||
|
self.graphql_generation_template = params.get("graphql_generation_template", default_graphql_generation_template)
|
||||||
|
|
||||||
|
super(Processor, self).__init__(
|
||||||
|
**params | {
|
||||||
|
"id": id,
|
||||||
|
"config-type": self.config_key,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_specification(
|
||||||
|
ConsumerSpec(
|
||||||
|
name = "request",
|
||||||
|
schema = QuestionToStructuredQueryRequest,
|
||||||
|
handler = self.on_message
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_specification(
|
||||||
|
ProducerSpec(
|
||||||
|
name = "response",
|
||||||
|
schema = QuestionToStructuredQueryResponse,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Client spec for calling prompt service
|
||||||
|
self.register_specification(
|
||||||
|
ClientSpec(
|
||||||
|
request_name = "prompt-request",
|
||||||
|
response_name = "prompt-response",
|
||||||
|
request_schema = PromptRequest,
|
||||||
|
response_schema = PromptResponse
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register config handler for schema updates
|
||||||
|
self.register_config_handler(self.on_schema_config)
|
||||||
|
|
||||||
|
# Schema storage: name -> RowSchema
|
||||||
|
self.schemas: Dict[str, RowSchema] = {}
|
||||||
|
|
||||||
|
logger.info("NLP Query service initialized")
|
||||||
|
|
||||||
|
async def on_schema_config(self, config, version):
|
||||||
|
"""Handle schema configuration updates"""
|
||||||
|
logger.info(f"Loading schema configuration version {version}")
|
||||||
|
|
||||||
|
# Clear existing schemas
|
||||||
|
self.schemas = {}
|
||||||
|
|
||||||
|
# Check if our config type exists
|
||||||
|
if self.config_key not in config:
|
||||||
|
logger.warning(f"No '{self.config_key}' type in configuration")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get the schemas dictionary for our type
|
||||||
|
schemas_config = config[self.config_key]
|
||||||
|
|
||||||
|
# Process each schema in the schemas config
|
||||||
|
for schema_name, schema_json in schemas_config.items():
|
||||||
|
try:
|
||||||
|
# Parse the JSON schema definition
|
||||||
|
schema_def = json.loads(schema_json)
|
||||||
|
|
||||||
|
# Create Field objects
|
||||||
|
fields = []
|
||||||
|
for field_def in schema_def.get("fields", []):
|
||||||
|
field = SchemaField(
|
||||||
|
name=field_def["name"],
|
||||||
|
type=field_def["type"],
|
||||||
|
size=field_def.get("size", 0),
|
||||||
|
primary=field_def.get("primary_key", False),
|
||||||
|
description=field_def.get("description", ""),
|
||||||
|
required=field_def.get("required", False),
|
||||||
|
enum_values=field_def.get("enum", []),
|
||||||
|
indexed=field_def.get("indexed", False)
|
||||||
|
)
|
||||||
|
fields.append(field)
|
||||||
|
|
||||||
|
# Create RowSchema
|
||||||
|
row_schema = RowSchema(
|
||||||
|
name=schema_def.get("name", schema_name),
|
||||||
|
description=schema_def.get("description", ""),
|
||||||
|
fields=fields
|
||||||
|
)
|
||||||
|
|
||||||
|
self.schemas[schema_name] = row_schema
|
||||||
|
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
|
||||||
|
|
||||||
|
async def phase1_select_schemas(self, question: str) -> List[str]:
|
||||||
|
"""Phase 1: Use prompt service to select relevant schemas for the question"""
|
||||||
|
logger.info("Starting Phase 1: Schema selection")
|
||||||
|
|
||||||
|
# Prepare schema information for the prompt
|
||||||
|
schema_info = []
|
||||||
|
for name, schema in self.schemas.items():
|
||||||
|
schema_desc = {
|
||||||
|
"name": name,
|
||||||
|
"description": schema.description,
|
||||||
|
"fields": [{"name": f.name, "type": f.type, "description": f.description}
|
||||||
|
for f in schema.fields]
|
||||||
|
}
|
||||||
|
schema_info.append(schema_desc)
|
||||||
|
|
||||||
|
# Create prompt variables
|
||||||
|
variables = {
|
||||||
|
"question": question,
|
||||||
|
"schemas": schema_info # Pass structured data directly
|
||||||
|
}
|
||||||
|
|
||||||
|
# Call prompt service for schema selection
|
||||||
|
prompt_request = PromptRequest(
|
||||||
|
template=self.schema_selection_template,
|
||||||
|
variables=variables
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self.client("prompt-request").request(prompt_request)
|
||||||
|
|
||||||
|
if response.error is not None:
|
||||||
|
raise Exception(f"Prompt service error: {response.error}")
|
||||||
|
|
||||||
|
# Parse the response to get selected schema names
|
||||||
|
# Expecting response.text to contain JSON array of schema names
|
||||||
|
selected_schemas = json.loads(response.text)
|
||||||
|
|
||||||
|
logger.info(f"Phase 1 selected schemas: {selected_schemas}")
|
||||||
|
return selected_schemas
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Phase 1 schema selection failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def phase2_generate_graphql(self, question: str, selected_schemas: List[str]) -> Dict[str, Any]:
|
||||||
|
"""Phase 2: Generate GraphQL query using selected schemas"""
|
||||||
|
logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}")
|
||||||
|
|
||||||
|
# Get detailed schema information for selected schemas only
|
||||||
|
selected_schema_info = []
|
||||||
|
for schema_name in selected_schemas:
|
||||||
|
if schema_name in self.schemas:
|
||||||
|
schema = self.schemas[schema_name]
|
||||||
|
schema_desc = {
|
||||||
|
"name": schema_name,
|
||||||
|
"description": schema.description,
|
||||||
|
"fields": [
|
||||||
|
{
|
||||||
|
"name": f.name,
|
||||||
|
"type": f.type,
|
||||||
|
"description": f.description,
|
||||||
|
"required": f.required,
|
||||||
|
"primary_key": f.primary,
|
||||||
|
"indexed": f.indexed,
|
||||||
|
"enum_values": f.enum_values if f.enum_values else []
|
||||||
|
}
|
||||||
|
for f in schema.fields
|
||||||
|
]
|
||||||
|
}
|
||||||
|
selected_schema_info.append(schema_desc)
|
||||||
|
|
||||||
|
# Create prompt variables for GraphQL generation
|
||||||
|
variables = {
|
||||||
|
"question": question,
|
||||||
|
"schemas": selected_schema_info # Pass structured data directly
|
||||||
|
}
|
||||||
|
|
||||||
|
# Call prompt service for GraphQL generation
|
||||||
|
prompt_request = PromptRequest(
|
||||||
|
template=self.graphql_generation_template,
|
||||||
|
variables=variables
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self.client("prompt-request").request(prompt_request)
|
||||||
|
|
||||||
|
if response.error is not None:
|
||||||
|
raise Exception(f"Prompt service error: {response.error}")
|
||||||
|
|
||||||
|
# Parse the response to get GraphQL query and variables
|
||||||
|
# Expecting response.text to contain JSON with "query" and "variables" fields
|
||||||
|
result = json.loads(response.text)
|
||||||
|
|
||||||
|
logger.info(f"Phase 2 generated GraphQL: {result.get('query', '')[:100]}...")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Phase 2 GraphQL generation failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def on_message(self, msg, consumer, flow):
|
||||||
|
"""Handle incoming question to structured query request"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
request = msg.value()
|
||||||
|
|
||||||
|
# Sender-produced ID
|
||||||
|
id = msg.properties()["id"]
|
||||||
|
|
||||||
|
logger.info(f"Handling NLP query request {id}: {request.question[:100]}...")
|
||||||
|
|
||||||
|
# Phase 1: Select relevant schemas
|
||||||
|
selected_schemas = await self.phase1_select_schemas(request.question)
|
||||||
|
|
||||||
|
# Phase 2: Generate GraphQL query
|
||||||
|
graphql_result = await self.phase2_generate_graphql(request.question, selected_schemas)
|
||||||
|
|
||||||
|
# Create response
|
||||||
|
response = QuestionToStructuredQueryResponse(
|
||||||
|
error=None,
|
||||||
|
graphql_query=graphql_result.get("query", ""),
|
||||||
|
variables=graphql_result.get("variables", {}),
|
||||||
|
detected_schemas=selected_schemas,
|
||||||
|
confidence=graphql_result.get("confidence", 0.8) # Default confidence
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Sending NLP query response...")
|
||||||
|
await flow("response").send(response, properties={"id": id})
|
||||||
|
|
||||||
|
logger.info("NLP query request completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
|
||||||
|
logger.error(f"Exception in NLP query service: {e}", exc_info=True)
|
||||||
|
|
||||||
|
logger.info("Sending error response...")
|
||||||
|
|
||||||
|
response = QuestionToStructuredQueryResponse(
|
||||||
|
error = Error(
|
||||||
|
type = "nlp-query-error",
|
||||||
|
message = str(e),
|
||||||
|
),
|
||||||
|
graphql_query = "",
|
||||||
|
variables = {},
|
||||||
|
detected_schemas = [],
|
||||||
|
confidence = 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
await flow("response").send(response, properties={"id": id})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_args(parser):
|
||||||
|
"""Add command-line arguments"""
|
||||||
|
|
||||||
|
FlowProcessor.add_args(parser)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--config-type',
|
||||||
|
default='schema',
|
||||||
|
help='Configuration type prefix for schemas (default: schema)'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--schema-selection-template',
|
||||||
|
default=default_schema_selection_template,
|
||||||
|
help=f'Prompt template name for schema selection (default: {default_schema_selection_template})'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--graphql-generation-template',
|
||||||
|
default=default_graphql_generation_template,
|
||||||
|
help=f'Prompt template name for GraphQL generation (default: {default_graphql_generation_template})'
|
||||||
|
)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
"""Entry point for nlp-query command"""
|
||||||
|
Processor.launch(default_ident, __doc__)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue