mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 17:36:23 +02:00
NLP query to GraphQL service (#491)
This commit is contained in:
parent
c078ca45cd
commit
8d4aa0069c
7 changed files with 440 additions and 5 deletions
|
|
@ -0,0 +1 @@
|
|||
from . service import *
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
run()
|
||||
25
trustgraph-flow/trustgraph/retrieval/nlp_query/pass1.txt
Normal file
25
trustgraph-flow/trustgraph/retrieval/nlp_query/pass1.txt
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
You are a database schema selection expert. Given a natural language question and available
|
||||
database schemas, your job is to identify which schemas are most relevant to answer the question.
|
||||
|
||||
## Available Schemas:
|
||||
{% for schema in schemas %}
|
||||
**{{ schema.name }}**: {{ schema.description }}
|
||||
Fields:
|
||||
{% for field in schema.fields %}
|
||||
- {{ field.name }} ({{ field.type }}): {{ field.description }}
|
||||
{% endfor %}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
## Question:
|
||||
{{ question }}
|
||||
|
||||
## Instructions:
|
||||
1. Analyze the question to understand what data is being requested
|
||||
2. Examine each schema to understand what data it contains
|
||||
3. Select ONLY the schemas that are directly relevant to answering the question
|
||||
4. Return your answer as a JSON array of schema names
|
||||
|
||||
## Response Format:
|
||||
Return ONLY a JSON array of schema names, nothing else.
|
||||
Example: ["customers", "orders", "products"]
|
||||
101
trustgraph-flow/trustgraph/retrieval/nlp_query/pass2.txt
Normal file
101
trustgraph-flow/trustgraph/retrieval/nlp_query/pass2.txt
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
You are a GraphQL query generation expert. Given a natural language question and relevant database
|
||||
schemas, generate a precise GraphQL query to answer the question.
|
||||
|
||||
## Question:
|
||||
{{ question }}
|
||||
|
||||
## Relevant Schemas:
|
||||
{% for schema in schemas %}
|
||||
**{{ schema.name }}**: {{ schema.description }}
|
||||
Fields:
|
||||
{% for field in schema.fields %}
|
||||
- {{ field.name }} ({{ field.type }}){% if field.description %}: {{ field.description }}{% endif
|
||||
%}{% if field.primary_key %} [PRIMARY KEY]{% endif %}{% if field.required %} [REQUIRED]{% endif
|
||||
%}{% if field.indexed %} [INDEXED]{% endif %}{% if field.enum_values %} [OPTIONS: {{
|
||||
field.enum_values|join(', ') }}]{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
## GraphQL Query Rules:
|
||||
1. Use the schema names as GraphQL query fields (e.g., `customers`, `orders`)
|
||||
2. Apply filters using the `where` parameter with nested filter objects
|
||||
3. Available filter operators per field type:
|
||||
- String fields: `eq`, `contains`, `startsWith`, `endsWith`, `in`, `not`, `not_in`
|
||||
- Integer/Float fields: `eq`, `gt`, `gte`, `lt`, `lte`, `in`, `not`, `not_in`
|
||||
4. Use `order_by` for sorting (field name as string)
|
||||
5. Use `direction` for sort direction: `ASC` or `DESC`
|
||||
6. Use `limit` to restrict number of results
|
||||
7. Select specific fields in the query body
|
||||
|
||||
## Example GraphQL Queries:
|
||||
|
||||
**Question**: "Show me customers from California"
|
||||
```graphql
|
||||
query {
|
||||
customers(where: {state: {eq: "California"}}, limit: 100) {
|
||||
customer_id
|
||||
name
|
||||
email
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
Question: "Top 10 products by price"
|
||||
query {
|
||||
products(order_by: "price", direction: DESC, limit: 10) {
|
||||
product_id
|
||||
name
|
||||
price
|
||||
category
|
||||
}
|
||||
}
|
||||
|
||||
Question: "Recent orders over $100"
|
||||
query {
|
||||
orders(
|
||||
where: {
|
||||
total_amount: {gt: 100}
|
||||
order_date: {gte: "2024-01-01"}
|
||||
}
|
||||
order_by: "order_date"
|
||||
direction: DESC
|
||||
limit: 50
|
||||
) {
|
||||
order_id
|
||||
customer_id
|
||||
total_amount
|
||||
order_date
|
||||
status
|
||||
}
|
||||
}
|
||||
|
||||
Instructions:
|
||||
|
||||
1. Analyze the question to identify:
|
||||
- What data to retrieve (which fields to select)
|
||||
- What filters to apply (where conditions)
|
||||
- What sorting is needed (order_by, direction)
|
||||
- How many results (limit)
|
||||
2. Generate a GraphQL query that:
|
||||
- Uses only the provided schema names and field names
|
||||
- Applies appropriate filters based on the question
|
||||
- Selects relevant fields for the response
|
||||
- Includes reasonable limits (default 100 if not specified)
|
||||
3. If variables are needed, include them in the response
|
||||
|
||||
Response Format:
|
||||
|
||||
Return a JSON object with:
|
||||
- "query": the GraphQL query string
|
||||
- "variables": object with any GraphQL variables (empty object if none)
|
||||
- "confidence": float between 0.0-1.0 indicating confidence in the query
|
||||
|
||||
Example:
|
||||
{
|
||||
"query": "query { customers(where: {state: {eq: \"California\"}}, limit: 100) { customer_id name
|
||||
email state } }",
|
||||
"variables": {},
|
||||
"confidence": 0.95
|
||||
}
|
||||
|
||||
303
trustgraph-flow/trustgraph/retrieval/nlp_query/service.py
Normal file
303
trustgraph-flow/trustgraph/retrieval/nlp_query/service.py
Normal file
|
|
@ -0,0 +1,303 @@
|
|||
"""
|
||||
NLP to Structured Query Service - converts natural language questions to GraphQL queries.
|
||||
Two-phase approach: 1) Select relevant schemas, 2) Generate GraphQL query.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
|
||||
from ...schema import PromptRequest, PromptResponse
|
||||
from ...schema import Error, RowSchema, Field as SchemaField
|
||||
|
||||
from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, ClientSpec
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "nlp-query"
|
||||
default_schema_selection_template = "schema-selection"
|
||||
default_graphql_generation_template = "graphql-generation"
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
# Config key for schemas
|
||||
self.config_key = params.get("config_type", "schema")
|
||||
|
||||
# Configurable prompt template names
|
||||
self.schema_selection_template = params.get("schema_selection_template", default_schema_selection_template)
|
||||
self.graphql_generation_template = params.get("graphql_generation_template", default_graphql_generation_template)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"config-type": self.config_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = QuestionToStructuredQueryRequest,
|
||||
handler = self.on_message
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
schema = QuestionToStructuredQueryResponse,
|
||||
)
|
||||
)
|
||||
|
||||
# Client spec for calling prompt service
|
||||
self.register_specification(
|
||||
ClientSpec(
|
||||
request_name = "prompt-request",
|
||||
response_name = "prompt-response",
|
||||
request_schema = PromptRequest,
|
||||
response_schema = PromptResponse
|
||||
)
|
||||
)
|
||||
|
||||
# Register config handler for schema updates
|
||||
self.register_config_handler(self.on_schema_config)
|
||||
|
||||
# Schema storage: name -> RowSchema
|
||||
self.schemas: Dict[str, RowSchema] = {}
|
||||
|
||||
logger.info("NLP Query service initialized")
|
||||
|
||||
async def on_schema_config(self, config, version):
|
||||
"""Handle schema configuration updates"""
|
||||
logger.info(f"Loading schema configuration version {version}")
|
||||
|
||||
# Clear existing schemas
|
||||
self.schemas = {}
|
||||
|
||||
# Check if our config type exists
|
||||
if self.config_key not in config:
|
||||
logger.warning(f"No '{self.config_key}' type in configuration")
|
||||
return
|
||||
|
||||
# Get the schemas dictionary for our type
|
||||
schemas_config = config[self.config_key]
|
||||
|
||||
# Process each schema in the schemas config
|
||||
for schema_name, schema_json in schemas_config.items():
|
||||
try:
|
||||
# Parse the JSON schema definition
|
||||
schema_def = json.loads(schema_json)
|
||||
|
||||
# Create Field objects
|
||||
fields = []
|
||||
for field_def in schema_def.get("fields", []):
|
||||
field = SchemaField(
|
||||
name=field_def["name"],
|
||||
type=field_def["type"],
|
||||
size=field_def.get("size", 0),
|
||||
primary=field_def.get("primary_key", False),
|
||||
description=field_def.get("description", ""),
|
||||
required=field_def.get("required", False),
|
||||
enum_values=field_def.get("enum", []),
|
||||
indexed=field_def.get("indexed", False)
|
||||
)
|
||||
fields.append(field)
|
||||
|
||||
# Create RowSchema
|
||||
row_schema = RowSchema(
|
||||
name=schema_def.get("name", schema_name),
|
||||
description=schema_def.get("description", ""),
|
||||
fields=fields
|
||||
)
|
||||
|
||||
self.schemas[schema_name] = row_schema
|
||||
logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
|
||||
|
||||
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
|
||||
|
||||
async def phase1_select_schemas(self, question: str) -> List[str]:
|
||||
"""Phase 1: Use prompt service to select relevant schemas for the question"""
|
||||
logger.info("Starting Phase 1: Schema selection")
|
||||
|
||||
# Prepare schema information for the prompt
|
||||
schema_info = []
|
||||
for name, schema in self.schemas.items():
|
||||
schema_desc = {
|
||||
"name": name,
|
||||
"description": schema.description,
|
||||
"fields": [{"name": f.name, "type": f.type, "description": f.description}
|
||||
for f in schema.fields]
|
||||
}
|
||||
schema_info.append(schema_desc)
|
||||
|
||||
# Create prompt variables
|
||||
variables = {
|
||||
"question": question,
|
||||
"schemas": schema_info # Pass structured data directly
|
||||
}
|
||||
|
||||
# Call prompt service for schema selection
|
||||
prompt_request = PromptRequest(
|
||||
template=self.schema_selection_template,
|
||||
variables=variables
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client("prompt-request").request(prompt_request)
|
||||
|
||||
if response.error is not None:
|
||||
raise Exception(f"Prompt service error: {response.error}")
|
||||
|
||||
# Parse the response to get selected schema names
|
||||
# Expecting response.text to contain JSON array of schema names
|
||||
selected_schemas = json.loads(response.text)
|
||||
|
||||
logger.info(f"Phase 1 selected schemas: {selected_schemas}")
|
||||
return selected_schemas
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Phase 1 schema selection failed: {e}")
|
||||
raise
|
||||
|
||||
async def phase2_generate_graphql(self, question: str, selected_schemas: List[str]) -> Dict[str, Any]:
|
||||
"""Phase 2: Generate GraphQL query using selected schemas"""
|
||||
logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}")
|
||||
|
||||
# Get detailed schema information for selected schemas only
|
||||
selected_schema_info = []
|
||||
for schema_name in selected_schemas:
|
||||
if schema_name in self.schemas:
|
||||
schema = self.schemas[schema_name]
|
||||
schema_desc = {
|
||||
"name": schema_name,
|
||||
"description": schema.description,
|
||||
"fields": [
|
||||
{
|
||||
"name": f.name,
|
||||
"type": f.type,
|
||||
"description": f.description,
|
||||
"required": f.required,
|
||||
"primary_key": f.primary,
|
||||
"indexed": f.indexed,
|
||||
"enum_values": f.enum_values if f.enum_values else []
|
||||
}
|
||||
for f in schema.fields
|
||||
]
|
||||
}
|
||||
selected_schema_info.append(schema_desc)
|
||||
|
||||
# Create prompt variables for GraphQL generation
|
||||
variables = {
|
||||
"question": question,
|
||||
"schemas": selected_schema_info # Pass structured data directly
|
||||
}
|
||||
|
||||
# Call prompt service for GraphQL generation
|
||||
prompt_request = PromptRequest(
|
||||
template=self.graphql_generation_template,
|
||||
variables=variables
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client("prompt-request").request(prompt_request)
|
||||
|
||||
if response.error is not None:
|
||||
raise Exception(f"Prompt service error: {response.error}")
|
||||
|
||||
# Parse the response to get GraphQL query and variables
|
||||
# Expecting response.text to contain JSON with "query" and "variables" fields
|
||||
result = json.loads(response.text)
|
||||
|
||||
logger.info(f"Phase 2 generated GraphQL: {result.get('query', '')[:100]}...")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Phase 2 GraphQL generation failed: {e}")
|
||||
raise
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
"""Handle incoming question to structured query request"""
|
||||
|
||||
try:
|
||||
request = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
logger.info(f"Handling NLP query request {id}: {request.question[:100]}...")
|
||||
|
||||
# Phase 1: Select relevant schemas
|
||||
selected_schemas = await self.phase1_select_schemas(request.question)
|
||||
|
||||
# Phase 2: Generate GraphQL query
|
||||
graphql_result = await self.phase2_generate_graphql(request.question, selected_schemas)
|
||||
|
||||
# Create response
|
||||
response = QuestionToStructuredQueryResponse(
|
||||
error=None,
|
||||
graphql_query=graphql_result.get("query", ""),
|
||||
variables=graphql_result.get("variables", {}),
|
||||
detected_schemas=selected_schemas,
|
||||
confidence=graphql_result.get("confidence", 0.8) # Default confidence
|
||||
)
|
||||
|
||||
logger.info("Sending NLP query response...")
|
||||
await flow("response").send(response, properties={"id": id})
|
||||
|
||||
logger.info("NLP query request completed")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"Exception in NLP query service: {e}", exc_info=True)
|
||||
|
||||
logger.info("Sending error response...")
|
||||
|
||||
response = QuestionToStructuredQueryResponse(
|
||||
error = Error(
|
||||
type = "nlp-query-error",
|
||||
message = str(e),
|
||||
),
|
||||
graphql_query = "",
|
||||
variables = {},
|
||||
detected_schemas = [],
|
||||
confidence = 0.0
|
||||
)
|
||||
|
||||
await flow("response").send(response, properties={"id": id})
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add command-line arguments"""
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--config-type',
|
||||
default='schema',
|
||||
help='Configuration type prefix for schemas (default: schema)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--schema-selection-template',
|
||||
default=default_schema_selection_template,
|
||||
help=f'Prompt template name for schema selection (default: {default_schema_selection_template})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--graphql-generation-template',
|
||||
default=default_graphql_generation_template,
|
||||
help=f'Prompt template name for GraphQL generation (default: {default_graphql_generation_template})'
|
||||
)
|
||||
|
||||
def run():
|
||||
"""Entry point for nlp-query command"""
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Loading…
Add table
Add a link
Reference in a new issue