diff --git a/trustgraph-base/trustgraph/schema/services/nlp_query.py b/trustgraph-base/trustgraph/schema/services/nlp_query.py
index 4e7c20fe..a3e709a1 100644
--- a/trustgraph-base/trustgraph/schema/services/nlp_query.py
+++ b/trustgraph-base/trustgraph/schema/services/nlp_query.py
@@ -7,16 +7,15 @@ from ..core.topic import topic
 
 # NLP to Structured Query Service - converts natural language to GraphQL
 
-class NLPToStructuredQueryRequest(Record):
-    natural_language_query = String()
+class QuestionToStructuredQueryRequest(Record):
+    question = String()
     max_results = Integer()
-    context_hints = Map(String()) # Optional context for query generation
 
-class NLPToStructuredQueryResponse(Record):
+class QuestionToStructuredQueryResponse(Record):
     error = Error()
     graphql_query = String() # Generated GraphQL query
     variables = Map(String()) # GraphQL variables if any
     detected_schemas = Array(String()) # Which schemas the query targets
     confidence = Double()
 
-############################################################################
\ No newline at end of file
+############################################################################
diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml
index 3d0e7554..75428ff2 100644
--- a/trustgraph-flow/pyproject.toml
+++ b/trustgraph-flow/pyproject.toml
@@ -87,6 +87,7 @@ kg-store = "trustgraph.storage.knowledge:run"
 librarian = "trustgraph.librarian:run"
 mcp-tool = "trustgraph.agent.mcp_tool:run"
 metering = "trustgraph.metering:run"
+nlp-query = "trustgraph.retrieval.nlp_query:run"
 objects-write-cassandra = "trustgraph.storage.objects.cassandra:run"
 objects-query-cassandra = "trustgraph.query.objects.cassandra:run"
 oe-write-milvus = "trustgraph.storage.object_embeddings.milvus:run"
diff --git a/trustgraph-flow/trustgraph/retrieval/nlp_query/__init__.py b/trustgraph-flow/trustgraph/retrieval/nlp_query/__init__.py
new file mode 100644
index 00000000..974260f2
--- /dev/null
+++ b/trustgraph-flow/trustgraph/retrieval/nlp_query/__init__.py
@@ -0,0 +1 @@
+from . service import *
\ No newline at end of file
diff --git a/trustgraph-flow/trustgraph/retrieval/nlp_query/__main__.py b/trustgraph-flow/trustgraph/retrieval/nlp_query/__main__.py
new file mode 100644
index 00000000..0bec8f9d
--- /dev/null
+++ b/trustgraph-flow/trustgraph/retrieval/nlp_query/__main__.py
@@ -0,0 +1,5 @@
+#!/usr/bin/env python3
+
+from . service import run
+
+run()
\ No newline at end of file
diff --git a/trustgraph-flow/trustgraph/retrieval/nlp_query/pass1.txt b/trustgraph-flow/trustgraph/retrieval/nlp_query/pass1.txt
new file mode 100644
index 00000000..39b180e5
--- /dev/null
+++ b/trustgraph-flow/trustgraph/retrieval/nlp_query/pass1.txt
@@ -0,0 +1,25 @@
+You are a database schema selection expert. Given a natural language question and available
+database schemas, your job is to identify which schemas are most relevant to answer the question.
+
+## Available Schemas:
+{% for schema in schemas %}
+**{{ schema.name }}**: {{ schema.description }}
+Fields:
+{% for field in schema.fields %}
+- {{ field.name }} ({{ field.type }}): {{ field.description }}
+{% endfor %}
+
+{% endfor %}
+
+## Question:
+{{ question }}
+
+## Instructions:
+1. Analyze the question to understand what data is being requested
+2. Examine each schema to understand what data it contains
+3. Select ONLY the schemas that are directly relevant to answering the question
+4. Return your answer as a JSON array of schema names
+
+## Response Format:
+Return ONLY a JSON array of schema names, nothing else.
+Example: ["customers", "orders", "products"]
diff --git a/trustgraph-flow/trustgraph/retrieval/nlp_query/pass2.txt b/trustgraph-flow/trustgraph/retrieval/nlp_query/pass2.txt
new file mode 100644
index 00000000..4aa4f93a
--- /dev/null
+++ b/trustgraph-flow/trustgraph/retrieval/nlp_query/pass2.txt
@@ -0,0 +1,105 @@
+You are a GraphQL query generation expert. Given a natural language question and relevant database
+schemas, generate a precise GraphQL query to answer the question.
+
+## Question:
+{{ question }}
+
+## Relevant Schemas:
+{% for schema in schemas %}
+**{{ schema.name }}**: {{ schema.description }}
+Fields:
+{% for field in schema.fields %}
+- {{ field.name }} ({{ field.type }}){% if field.description %}: {{ field.description }}{% endif
+%}{% if field.primary_key %} [PRIMARY KEY]{% endif %}{% if field.required %} [REQUIRED]{% endif
+%}{% if field.indexed %} [INDEXED]{% endif %}{% if field.enum_values %} [OPTIONS: {{
+field.enum_values|join(', ') }}]{% endif %}
+{% endfor %}
+
+{% endfor %}
+
+## GraphQL Query Rules:
+1. Use the schema names as GraphQL query fields (e.g., `customers`, `orders`)
+2. Apply filters using the `where` parameter with nested filter objects
+3. Available filter operators per field type:
+   - String fields: `eq`, `contains`, `startsWith`, `endsWith`, `in`, `not`, `not_in`
+   - Integer/Float fields: `eq`, `gt`, `gte`, `lt`, `lte`, `in`, `not`, `not_in`
+4. Use `order_by` for sorting (field name as string)
+5. Use `direction` for sort direction: `ASC` or `DESC`
+6. Use `limit` to restrict number of results
+7. Select specific fields in the query body
+
+## Example GraphQL Queries:
+
+**Question**: "Show me customers from California"
+```graphql
+query {
+  customers(where: {state: {eq: "California"}}, limit: 100) {
+    customer_id
+    name
+    email
+    state
+  }
+}
+```
+
+**Question**: "Top 10 products by price"
+```graphql
+query {
+  products(order_by: "price", direction: DESC, limit: 10) {
+    product_id
+    name
+    price
+    category
+  }
+}
+```
+
+**Question**: "Recent orders over $100"
+```graphql
+query {
+  orders(
+    where: {
+      total_amount: {gt: 100}
+      order_date: {gte: "2024-01-01"}
+    }
+    order_by: "order_date"
+    direction: DESC
+    limit: 50
+  ) {
+    order_id
+    customer_id
+    total_amount
+    order_date
+    status
+  }
+}
+```
+
+## Instructions:
+
+1. Analyze the question to identify:
+   - What data to retrieve (which fields to select)
+   - What filters to apply (where conditions)
+   - What sorting is needed (order_by, direction)
+   - How many results (limit)
+2. Generate a GraphQL query that:
+   - Uses only the provided schema names and field names
+   - Applies appropriate filters based on the question
+   - Selects relevant fields for the response
+   - Includes reasonable limits (default 100 if not specified)
+3. If variables are needed, include them in the response
+
+## Response Format:
+
+Return a JSON object with:
+- "query": the GraphQL query string
+- "variables": object with any GraphQL variables (empty object if none)
+- "confidence": float between 0.0-1.0 indicating confidence in the query
+
+Example:
+{
+  "query": "query { customers(where: {state: {eq: \"California\"}}, limit: 100) { customer_id name email state } }",
+  "variables": {},
+  "confidence": 0.95
+}
+
diff --git a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py
new file mode 100644
index 00000000..1e962e0a
--- /dev/null
+++ b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py
@@ -0,0 +1,355 @@
+"""
+NLP to Structured Query Service - converts natural language questions to GraphQL queries.
+Two-phase approach: 1) Select relevant schemas, 2) Generate GraphQL query.
+"""
+
+import json
+import logging
+from typing import Dict, Any, Optional, List
+
+from ...schema import QuestionToStructuredQueryRequest, QuestionToStructuredQueryResponse
+from ...schema import PromptRequest, PromptResponse
+from ...schema import Error, RowSchema, Field as SchemaField
+
+from ...base import FlowProcessor, ConsumerSpec, ProducerSpec, ClientSpec
+
+# Module logger
+logger = logging.getLogger(__name__)
+
+default_ident = "nlp-query"
+default_schema_selection_template = "schema-selection"
+default_graphql_generation_template = "graphql-generation"
+
+class Processor(FlowProcessor):
+    """Flow processor that converts a natural language question into a GraphQL query.
+
+    Schemas arrive through the config handler; each request is answered with
+    two prompt-service calls (schema selection, then GraphQL generation).
+    """
+
+    def __init__(self, **params):
+        """Register the request consumer, response producer, prompt client and config handler."""
+
+        id = params.get("id", default_ident)
+
+        # Config key for schemas
+        self.config_key = params.get("config_type", "schema")
+
+        # Configurable prompt template names
+        self.schema_selection_template = params.get("schema_selection_template", default_schema_selection_template)
+        self.graphql_generation_template = params.get("graphql_generation_template", default_graphql_generation_template)
+
+        super(Processor, self).__init__(
+            **params | {
+                "id": id,
+                "config-type": self.config_key,
+            }
+        )
+
+        self.register_specification(
+            ConsumerSpec(
+                name = "request",
+                schema = QuestionToStructuredQueryRequest,
+                handler = self.on_message
+            )
+        )
+
+        self.register_specification(
+            ProducerSpec(
+                name = "response",
+                schema = QuestionToStructuredQueryResponse,
+            )
+        )
+
+        # Client spec for calling prompt service
+        self.register_specification(
+            ClientSpec(
+                request_name = "prompt-request",
+                response_name = "prompt-response",
+                request_schema = PromptRequest,
+                response_schema = PromptResponse
+            )
+        )
+
+        # Register config handler for schema updates
+        self.register_config_handler(self.on_schema_config)
+
+        # Schema storage: name -> RowSchema
+        self.schemas: Dict[str, RowSchema] = {}
+
+        logger.info("NLP Query service initialized")
+
+    async def on_schema_config(self, config, version):
+        """Handle schema configuration updates.
+
+        Rebuilds self.schemas from config[self.config_key]; a schema that
+        fails to parse is logged and skipped.
+        """
+        logger.info(f"Loading schema configuration version {version}")
+
+        # Clear existing schemas
+        self.schemas = {}
+
+        # Check if our config type exists
+        if self.config_key not in config:
+            logger.warning(f"No '{self.config_key}' type in configuration")
+            return
+
+        # Get the schemas dictionary for our type
+        schemas_config = config[self.config_key]
+
+        # Process each schema in the schemas config
+        for schema_name, schema_json in schemas_config.items():
+            try:
+                # Parse the JSON schema definition
+                schema_def = json.loads(schema_json)
+
+                # Create Field objects
+                fields = []
+                for field_def in schema_def.get("fields", []):
+                    field = SchemaField(
+                        name=field_def["name"],
+                        type=field_def["type"],
+                        size=field_def.get("size", 0),
+                        primary=field_def.get("primary_key", False),
+                        description=field_def.get("description", ""),
+                        required=field_def.get("required", False),
+                        enum_values=field_def.get("enum", []),
+                        indexed=field_def.get("indexed", False)
+                    )
+                    fields.append(field)
+
+                # Create RowSchema
+                row_schema = RowSchema(
+                    name=schema_def.get("name", schema_name),
+                    description=schema_def.get("description", ""),
+                    fields=fields
+                )
+
+                self.schemas[schema_name] = row_schema
+                logger.info(f"Loaded schema: {schema_name} with {len(fields)} fields")
+
+            except Exception as e:
+                logger.error(f"Failed to parse schema {schema_name}: {e}", exc_info=True)
+
+        logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
+
+    async def phase1_select_schemas(self, question: str) -> List[str]:
+        """Phase 1: Use prompt service to select relevant schemas for the question.
+
+        Returns the selected names that are loaded in self.schemas. Raises if
+        the prompt service reports an error or its reply is not a JSON array.
+        """
+        logger.info("Starting Phase 1: Schema selection")
+
+        # Prepare schema information for the prompt
+        schema_info = []
+        for name, schema in self.schemas.items():
+            schema_desc = {
+                "name": name,
+                "description": schema.description,
+                "fields": [{"name": f.name, "type": f.type, "description": f.description}
+                           for f in schema.fields]
+            }
+            schema_info.append(schema_desc)
+
+        # Create prompt variables
+        variables = {
+            "question": question,
+            "schemas": schema_info  # Pass structured data directly
+        }
+
+        # Call prompt service for schema selection
+        prompt_request = PromptRequest(
+            template=self.schema_selection_template,
+            variables=variables
+        )
+
+        try:
+            # NOTE(review): on_message reaches its producer through flow(...);
+            # confirm FlowProcessor provides self.client(...) for ClientSpec.
+            response = await self.client("prompt-request").request(prompt_request)
+
+            if response.error is not None:
+                raise Exception(f"Prompt service error: {response.error}")
+
+            # Parse the response to get selected schema names
+            # Expecting response.text to contain JSON array of schema names
+            selected_schemas = json.loads(response.text)
+
+            # The reply is model-generated, so check its shape before using it
+            if not isinstance(selected_schemas, list):
+                raise ValueError(
+                    f"Schema selection did not return a JSON array: {response.text!r}"
+                )
+
+            # Keep only names of loaded schemas; phase 2 cannot use any others
+            known_schemas = [
+                s for s in selected_schemas
+                if isinstance(s, str) and s in self.schemas
+            ]
+            if len(known_schemas) != len(selected_schemas):
+                logger.warning(f"Phase 1 dropped unknown schema names from: {selected_schemas}")
+
+            logger.info(f"Phase 1 selected schemas: {known_schemas}")
+            return known_schemas
+
+        except Exception as e:
+            logger.error(f"Phase 1 schema selection failed: {e}")
+            raise
+
+    async def phase2_generate_graphql(self, question: str, selected_schemas: List[str]) -> Dict[str, Any]:
+        """Phase 2: Generate GraphQL query using selected schemas.
+
+        Returns the parsed JSON object from the prompt service. Raises if the
+        prompt service reports an error or its reply is not a JSON object.
+        """
+        logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}")
+
+        # Get detailed schema information for selected schemas only
+        selected_schema_info = []
+        for schema_name in selected_schemas:
+            if schema_name in self.schemas:
+                schema = self.schemas[schema_name]
+                schema_desc = {
+                    "name": schema_name,
+                    "description": schema.description,
+                    "fields": [
+                        {
+                            "name": f.name,
+                            "type": f.type,
+                            "description": f.description,
+                            "required": f.required,
+                            "primary_key": f.primary,
+                            "indexed": f.indexed,
+                            "enum_values": f.enum_values if f.enum_values else []
+                        }
+                        for f in schema.fields
+                    ]
+                }
+                selected_schema_info.append(schema_desc)
+
+        # Create prompt variables for GraphQL generation
+        variables = {
+            "question": question,
+            "schemas": selected_schema_info  # Pass structured data directly
+        }
+
+        # Call prompt service for GraphQL generation
+        prompt_request = PromptRequest(
+            template=self.graphql_generation_template,
+            variables=variables
+        )
+
+        try:
+            response = await self.client("prompt-request").request(prompt_request)
+
+            if response.error is not None:
+                raise Exception(f"Prompt service error: {response.error}")
+
+            # Parse the response to get GraphQL query and variables
+            # Expecting response.text to contain JSON with "query" and "variables" fields
+            result = json.loads(response.text)
+
+            # The reply is model-generated, so check its shape before using it
+            if not isinstance(result, dict):
+                raise ValueError(
+                    f"GraphQL generation did not return a JSON object: {response.text!r}"
+                )
+
+            logger.info(f"Phase 2 generated GraphQL: {result.get('query', '')[:100]}...")
+            return result
+
+        except Exception as e:
+            logger.error(f"Phase 2 GraphQL generation failed: {e}")
+            raise
+
+    async def on_message(self, msg, consumer, flow):
+        """Handle incoming question to structured query request.
+
+        Runs both phases and sends the result on the "response" producer;
+        any failure is reported as an error response with the same id.
+        """
+
+        # Sender-produced ID. Read it before the try block so the error
+        # handler below can always address its reply.
+        id = msg.properties()["id"]
+
+        try:
+            request = msg.value()
+
+            logger.info(f"Handling NLP query request {id}: {request.question[:100]}...")
+
+            # Phase 1: Select relevant schemas
+            selected_schemas = await self.phase1_select_schemas(request.question)
+
+            # Phase 2: Generate GraphQL query
+            graphql_result = await self.phase2_generate_graphql(request.question, selected_schemas)
+
+            # The response schema declares variables as a map of strings, so
+            # JSON-encode any non-string values produced by the model
+            variables = {
+                str(k): v if isinstance(v, str) else json.dumps(v)
+                for k, v in (graphql_result.get("variables") or {}).items()
+            }
+
+            # Create response
+            response = QuestionToStructuredQueryResponse(
+                error=None,
+                graphql_query=graphql_result.get("query", ""),
+                variables=variables,
+                detected_schemas=selected_schemas,
+                confidence=graphql_result.get("confidence", 0.8)  # Default confidence
+            )
+
+            logger.info("Sending NLP query response...")
+            await flow("response").send(response, properties={"id": id})
+
+            logger.info("NLP query request completed")
+
+        except Exception as e:
+
+            logger.error(f"Exception in NLP query service: {e}", exc_info=True)
+
+            logger.info("Sending error response...")
+
+            response = QuestionToStructuredQueryResponse(
+                error = Error(
+                    type = "nlp-query-error",
+                    message = str(e),
+                ),
+                graphql_query = "",
+                variables = {},
+                detected_schemas = [],
+                confidence = 0.0
+            )
+
+            await flow("response").send(response, properties={"id": id})
+
+    @staticmethod
+    def add_args(parser):
+        """Add command-line arguments"""
+
+        FlowProcessor.add_args(parser)
+
+        parser.add_argument(
+            '--config-type',
+            default='schema',
+            help='Configuration type prefix for schemas (default: schema)'
+        )
+
+        parser.add_argument(
+            '--schema-selection-template',
+            default=default_schema_selection_template,
+            help=f'Prompt template name for schema selection (default: {default_schema_selection_template})'
+        )
+
+        parser.add_argument(
+            '--graphql-generation-template',
+            default=default_graphql_generation_template,
+            help=f'Prompt template name for GraphQL generation (default: {default_graphql_generation_template})'
+        )
+
+def run():
+    """Entry point for nlp-query command"""
+    Processor.launch(default_ident, __doc__)
\ No newline at end of file