diff --git a/tests/integration/test_nlp_query_integration.py b/tests/integration/test_nlp_query_integration.py index 83e56ac4..16c4543e 100644 --- a/tests/integration/test_nlp_query_integration.py +++ b/tests/integration/test_nlp_query_integration.py @@ -131,15 +131,18 @@ class TestNLPQueryServiceIntegration: ) # Set up mock to return different responses for each call - integration_processor.client.return_value.request = AsyncMock( + # Mock the flow context to return prompt service responses + prompt_service = AsyncMock() + prompt_service.request = AsyncMock( side_effect=[phase1_response, phase2_response] ) + flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock() # Act - Process the message await integration_processor.on_message(msg, consumer, flow) # Assert - Verify the complete pipeline - assert integration_processor.client.return_value.request.call_count == 2 + assert prompt_service.request.call_count == 2 flow_response.send.assert_called_once() # Verify response structure and content @@ -188,9 +191,12 @@ class TestNLPQueryServiceIntegration: error=None ) - integration_processor.client.return_value.request = AsyncMock( + # Mock the flow context to return prompt service responses + prompt_service = AsyncMock() + prompt_service.request = AsyncMock( side_effect=[phase1_response, phase2_response] ) + flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock() # Act await integration_processor.on_message(msg, consumer, flow) @@ -255,9 +261,12 @@ class TestNLPQueryServiceIntegration: error=None ) - integration_processor.client.return_value.request = AsyncMock( + # Mock the flow context to return prompt service responses + prompt_service = AsyncMock() + prompt_service.request = AsyncMock( side_effect=[phase1_response, phase2_response] ) + flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock() # Act await integration_processor.on_message(msg, consumer, flow) @@ -293,9 +302,12 @@ class TestNLPQueryServiceIntegration: error=Error(type="template-not-found", message="Schema selection template not available") ) - integration_processor.client.return_value.request = AsyncMock( + # Mock the flow context to return prompt service error response + prompt_service = AsyncMock() + prompt_service.request = AsyncMock( return_value=phase1_error_response ) + flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock() # Act await integration_processor.on_message(msg, consumer, flow) @@ -350,9 +362,12 @@ class TestNLPQueryServiceIntegration: error=None ) - custom_processor.client.return_value.request = AsyncMock( + # Mock flow context to return prompt service responses + mock_prompt_service = AsyncMock() + mock_prompt_service.request = AsyncMock( side_effect=[phase1_response, phase2_response] ) + flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock() # Act await custom_processor.on_message(msg, consumer, flow) @@ -362,7 +377,7 @@ class TestNLPQueryServiceIntegration: assert custom_processor.graphql_generation_template == "custom-graphql-generator" # Verify the calls were made - assert custom_processor.client.return_value.request.call_count == 2 + assert mock_prompt_service.request.call_count == 2 @pytest.mark.asyncio async def test_large_schema_set_integration(self, integration_processor): @@ -410,9 +425,12 @@ class TestNLPQueryServiceIntegration: error=None ) - integration_processor.client.return_value.request = AsyncMock( + # Mock the flow context to return prompt service responses + prompt_service = AsyncMock() + prompt_service.request = AsyncMock( side_effect=[phase1_response, phase2_response] ) + flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock() # Act await integration_processor.on_message(msg, consumer, flow) @@ -451,27 +469,36 @@ class TestNLPQueryServiceIntegration: messages.append(msg) flows.append(flow) - # Mock responses for all requests - mock_responses = [] - for i in range(10): # 2 calls per request (phase1 + phase2) - if i % 2 == 0: # Phase 1 responses - mock_responses.append(PromptResponse( - text=json.dumps(["customers"]), - error=None - )) - else: # Phase 2 responses - mock_responses.append(PromptResponse( - text=json.dumps({ - "query": f"query {{ customers {{ id name }} }}", - "variables": {}, - "confidence": 0.9 - }), - error=None - )) - - integration_processor.client.return_value.request = AsyncMock( - side_effect=mock_responses - ) + # Mock responses for all requests - create individual prompt services for each flow + prompt_services = [] + for i in range(5): # 5 concurrent requests + phase1_response = PromptResponse( + text=json.dumps(["customers"]), + error=None + ) + phase2_response = PromptResponse( + text=json.dumps({ + "query": f"query {{ customers {{ id name }} }}", + "variables": {}, + "confidence": 0.9 + }), + error=None + ) + + # Create a prompt service for this request + prompt_service = AsyncMock() + prompt_service.request = AsyncMock( + side_effect=[phase1_response, phase2_response] + ) + prompt_services.append(prompt_service) + + # Set up the flow for this request + flow_response = flows[i].return_value + flows[i].side_effect = lambda service_name, ps=prompt_service, fr=flow_response: ( + ps if service_name == "prompt-request" else + fr if service_name == "response" else + AsyncMock() + ) # Act - Process all messages concurrently import asyncio @@ -485,7 +512,8 @@ class TestNLPQueryServiceIntegration: await asyncio.gather(*tasks) # Assert - All requests should be processed - assert integration_processor.client.return_value.request.call_count == 10 + total_calls = sum(ps.request.call_count for ps in prompt_services) + assert total_calls == 10 # 2 calls per request (phase1 + phase2) for flow in flows: flow.return_value.send.assert_called_once() @@ -518,9 +546,12 @@ class TestNLPQueryServiceIntegration: error=None ) - integration_processor.client.return_value.request = AsyncMock( + # Mock the flow context to return prompt service responses + prompt_service = AsyncMock() + prompt_service.request = AsyncMock( side_effect=[phase1_response, phase2_response] ) + flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock() # Act import time diff --git a/tests/integration/test_structured_query_integration.py b/tests/integration/test_structured_query_integration.py index 72ae1e77..2e836b33 100644 --- a/tests/integration/test_structured_query_integration.py +++ b/tests/integration/test_structured_query_integration.py @@ -93,9 +93,17 @@ class TestStructuredQueryServiceIntegration: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_response - integration_processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act - Process the message await integration_processor.on_message(msg, consumer, flow) @@ -116,7 +124,7 @@ class TestStructuredQueryServiceIntegration: assert "orders" in objects_call_args.query assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string assert objects_call_args.variables["state"] == "California" - assert objects_call_args.user == "default" + assert objects_call_args.user == "trustgraph" assert objects_call_args.collection == "default" # Verify response @@ -159,7 +167,15 @@ class TestStructuredQueryServiceIntegration: mock_nlp_client = AsyncMock() mock_nlp_client.request.return_value = nlp_error_response - integration_processor.client.return_value = mock_nlp_client + # Mock flow context to route to nlp service + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await integration_processor.on_message(msg, consumer, flow) @@ -215,9 +231,17 @@ class TestStructuredQueryServiceIntegration: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_error_response - integration_processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await integration_processor.on_message(msg, consumer, flow) @@ -285,9 +309,17 @@ class TestStructuredQueryServiceIntegration: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_response - integration_processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await integration_processor.on_message(msg, consumer, flow) @@ -405,9 +437,17 @@ class TestStructuredQueryServiceIntegration: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_response - integration_processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await integration_processor.on_message(msg, consumer, flow) @@ -474,9 +514,17 @@ class TestStructuredQueryServiceIntegration: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_response - integration_processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await integration_processor.on_message(msg, consumer, flow) @@ -514,34 +562,51 @@ class TestStructuredQueryServiceIntegration: messages.append(msg) flows.append(flow) - # Mock responses for all requests (6 total: 3 NLP + 3 Objects) - mock_responses = [] - for i in range(6): - if i % 2 == 0: # NLP responses - mock_responses.append(QuestionToStructuredQueryResponse( - error=None, - graphql_query=f'query {{ test_{i//2} {{ id }} }}', - variables={}, - detected_schemas=[f"test_{i//2}"], - confidence=0.9 - )) - else: # Objects responses - mock_responses.append(ObjectsQueryResponse( - error=None, - data=f'{{"test_{i//2}": [{{"id": "{i//2}"}}]}}', - errors=None, - extensions={} - )) + # Set up individual flow routing for each concurrent request + service_call_count = 0 - call_count = 0 - def mock_client_side_effect(name): - nonlocal call_count - client = AsyncMock() - client.request.return_value = mock_responses[call_count] - call_count += 1 - return client - - integration_processor.client.side_effect = mock_client_side_effect + for i in range(3): # 3 concurrent requests + # Create NLP and Objects responses for this request + nlp_response = QuestionToStructuredQueryResponse( + error=None, + graphql_query=f'query {{ test_{i} {{ id }} }}', + variables={}, + detected_schemas=[f"test_{i}"], + confidence=0.9 + ) + + objects_response = ObjectsQueryResponse( + error=None, + data=f'{{"test_{i}": [{{"id": "{i}"}}]}}', + errors=None, + extensions={} + ) + + # Create mock services for this request + mock_nlp_client = AsyncMock() + mock_nlp_client.request.return_value = nlp_response + + mock_objects_client = AsyncMock() + mock_objects_client.request.return_value = objects_response + + # Set up flow routing for this specific request + flow_response = flows[i].return_value + def create_flow_router(nlp_client, objects_client, response_producer): + def flow_router(service_name): + nonlocal service_call_count + if service_name == "nlp-query-request": + service_call_count += 1 + return nlp_client + elif service_name == "objects-query-request": + service_call_count += 1 + return objects_client + elif service_name == "response": + return response_producer + else: + return AsyncMock() + return flow_router + + flows[i].side_effect = create_flow_router(mock_nlp_client, mock_objects_client, flow_response) # Act - Process all messages concurrently import asyncio @@ -555,7 +620,7 @@ class TestStructuredQueryServiceIntegration: await asyncio.gather(*tasks) # Assert - All requests should be processed - assert call_count == 6 # 2 calls per request (NLP + Objects) + assert service_call_count == 6 # 2 calls per request (NLP + Objects) for flow in flows: flow.return_value.send.assert_called_once() @@ -580,7 +645,15 @@ class TestStructuredQueryServiceIntegration: mock_nlp_client = AsyncMock() mock_nlp_client.request.side_effect = Exception("Service timeout: Request took longer than 30s") - integration_processor.client.return_value = mock_nlp_client + # Mock flow context to route to nlp service + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await integration_processor.on_message(msg, consumer, flow) @@ -638,9 +711,17 @@ class TestStructuredQueryServiceIntegration: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_response - integration_processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await integration_processor.on_message(msg, consumer, flow) diff --git a/tests/unit/test_retrieval/test_nlp_query.py b/tests/unit/test_retrieval/test_nlp_query.py index c783c8f0..5141f2b2 100644 --- a/tests/unit/test_retrieval/test_nlp_query.py +++ b/tests/unit/test_retrieval/test_nlp_query.py @@ -87,14 +87,18 @@ class TestNLPQueryProcessor: error=None ) - processor.client.return_value.request = AsyncMock(return_value=mock_response) + # Mock flow context + flow = MagicMock() + mock_prompt_service = AsyncMock() + mock_prompt_service.request = AsyncMock(return_value=mock_response) + flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock() # Act - result = await processor.phase1_select_schemas(question) + result = await processor.phase1_select_schemas(question, flow) # Assert assert result == expected_schemas - processor.client.assert_called_once_with("prompt-request") + mock_prompt_service.request.assert_called_once() async def test_phase1_select_schemas_prompt_error(self, processor): """Test schema selection with prompt service error""" @@ -103,11 +107,15 @@ class TestNLPQueryProcessor: error = Error(type="prompt-error", message="Template not found") mock_response = PromptResponse(text="", error=error) - processor.client.return_value.request = AsyncMock(return_value=mock_response) + # Mock flow context + flow = MagicMock() + mock_prompt_service = AsyncMock() + mock_prompt_service.request = AsyncMock(return_value=mock_response) + flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock() # Act & Assert with pytest.raises(Exception, match="Prompt service error"): - await processor.phase1_select_schemas(question) + await processor.phase1_select_schemas(question, flow) async def test_phase2_generate_graphql_success(self, processor): """Test successful GraphQL generation (Phase 2)""" @@ -125,14 +133,18 @@ class TestNLPQueryProcessor: error=None ) - processor.client.return_value.request = AsyncMock(return_value=mock_response) + # Mock flow context + flow = MagicMock() + mock_prompt_service = AsyncMock() + mock_prompt_service.request = AsyncMock(return_value=mock_response) + flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock() # Act - result = await processor.phase2_generate_graphql(question, selected_schemas) + result = await processor.phase2_generate_graphql(question, selected_schemas, flow) # Assert assert result == expected_result - processor.client.assert_called_once_with("prompt-request") + mock_prompt_service.request.assert_called_once() async def test_phase2_generate_graphql_prompt_error(self, processor): """Test GraphQL generation with prompt service error""" @@ -142,11 +154,15 @@ class TestNLPQueryProcessor: error = Error(type="prompt-error", message="Generation failed") mock_response = PromptResponse(text="", error=error) - processor.client.return_value.request = AsyncMock(return_value=mock_response) + # Mock flow context + flow = MagicMock() + mock_prompt_service = AsyncMock() + mock_prompt_service.request = AsyncMock(return_value=mock_response) + flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock() # Act & Assert with pytest.raises(Exception, match="Prompt service error"): - await processor.phase2_generate_graphql(question, selected_schemas) + await processor.phase2_generate_graphql(question, selected_schemas, flow) async def test_on_message_full_flow_success(self, processor): """Test complete message processing flow""" @@ -181,16 +197,18 @@ class TestNLPQueryProcessor: error=None ) - # Set up mock to return different responses for each call - processor.client.return_value.request = AsyncMock( + # Mock flow context to return prompt service responses + mock_prompt_service = AsyncMock() + mock_prompt_service.request = AsyncMock( side_effect=[phase1_response, phase2_response] ) + flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock() # Act await processor.on_message(msg, consumer, flow) # Assert - assert processor.client.return_value.request.call_count == 2 + assert mock_prompt_service.request.call_count == 2 flow_response.send.assert_called_once() # Verify response structure diff --git a/tests/unit/test_retrieval/test_structured_query.py b/tests/unit/test_retrieval/test_structured_query.py index f8b157eb..1e78fa97 100644 --- a/tests/unit/test_retrieval/test_structured_query.py +++ b/tests/unit/test_retrieval/test_structured_query.py @@ -80,9 +80,17 @@ class TestStructuredQueryProcessor: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_response - processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await processor.on_message(msg, consumer, flow) @@ -101,7 +109,7 @@ class TestStructuredQueryProcessor: assert isinstance(objects_call_args, ObjectsQueryRequest) assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }' assert objects_call_args.variables == {"state": "NY"} - assert objects_call_args.user == "default" + assert objects_call_args.user == "trustgraph" assert objects_call_args.collection == "default" # Verify response @@ -142,7 +150,15 @@ class TestStructuredQueryProcessor: mock_nlp_client = AsyncMock() mock_nlp_client.request.return_value = nlp_response - processor.client.return_value = mock_nlp_client + # Mock flow context to route to nlp service + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await processor.on_message(msg, consumer, flow) @@ -185,7 +201,15 @@ class TestStructuredQueryProcessor: mock_nlp_client = AsyncMock() mock_nlp_client.request.return_value = nlp_response - processor.client.return_value = mock_nlp_client + # Mock flow context to route to nlp service + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await processor.on_message(msg, consumer, flow) @@ -237,9 +261,17 @@ class TestStructuredQueryProcessor: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_response - processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await processor.on_message(msg, consumer, flow) @@ -300,9 +332,17 @@ class TestStructuredQueryProcessor: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_response - processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await processor.on_message(msg, consumer, flow) @@ -370,9 +410,17 @@ class TestStructuredQueryProcessor: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_response - processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await processor.on_message(msg, consumer, flow) @@ -427,9 +475,17 @@ class TestStructuredQueryProcessor: mock_objects_client = AsyncMock() mock_objects_client.request.return_value = objects_response - processor.client.side_effect = lambda name: ( - mock_nlp_client if name == "nlp-query-request" else mock_objects_client - ) + # Mock flow context to route to appropriate services + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_nlp_client + elif service_name == "objects-query-request": + return mock_objects_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await processor.on_message(msg, consumer, flow) @@ -457,10 +513,18 @@ class TestStructuredQueryProcessor: flow_response = AsyncMock() flow.return_value = flow_response - # Mock client to raise exception + # Mock flow context to raise exception mock_client = AsyncMock() mock_client.request.side_effect = Exception("Network timeout") - processor.client.return_value = mock_client + + def flow_router(service_name): + if service_name == "nlp-query-request": + return mock_client + elif service_name == "response": + return flow_response + else: + return AsyncMock() + flow.side_effect = flow_router # Act await processor.on_message(msg, consumer, flow) diff --git a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py index f5100be3..67eaeaec 100644 --- a/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/nlp_query/service.py @@ -122,7 +122,7 @@ class Processor(FlowProcessor): logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas") - async def phase1_select_schemas(self, question: str) -> List[str]: + async def phase1_select_schemas(self, question: str, flow) -> List[str]: """Phase 1: Use prompt service to select relevant schemas for the question""" logger.info("Starting Phase 1: Schema selection") @@ -144,20 +144,27 @@ class Processor(FlowProcessor): } # Call prompt service for schema selection + # Convert variables to JSON-encoded terms + terms = {k: json.dumps(v) for k, v in variables.items()} prompt_request = PromptRequest( - template=self.schema_selection_template, - variables=variables + id=self.schema_selection_template, + terms=terms ) try: - response = await self.client("prompt-request").request(prompt_request) + response = await flow("prompt-request").request(prompt_request) if response.error is not None: raise Exception(f"Prompt service error: {response.error}") # Parse the response to get selected schema names - # Expecting response.text to contain JSON array of schema names - selected_schemas = json.loads(response.text) + # Response could be in either text or object field + response_data = response.text if response.text else response.object + if response_data is None: + raise Exception("Prompt service returned empty response") + + # Parse JSON array of schema names + selected_schemas = json.loads(response_data) logger.info(f"Phase 1 selected schemas: {selected_schemas}") return selected_schemas @@ -166,7 +173,7 @@ class Processor(FlowProcessor): logger.error(f"Phase 1 schema selection failed: {e}") raise - async def phase2_generate_graphql(self, question: str, selected_schemas: List[str]) -> Dict[str, Any]: + async def phase2_generate_graphql(self, question: str, selected_schemas: List[str], flow) -> Dict[str, Any]: """Phase 2: Generate GraphQL query using selected schemas""" logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}") @@ -200,20 +207,27 @@ class Processor(FlowProcessor): } # Call prompt service for GraphQL generation + # Convert variables to JSON-encoded terms + terms = {k: json.dumps(v) for k, v in variables.items()} prompt_request = PromptRequest( - template=self.graphql_generation_template, - variables=variables + id=self.graphql_generation_template, + terms=terms ) try: - response = await self.client("prompt-request").request(prompt_request) + response = await flow("prompt-request").request(prompt_request) if response.error is not None: raise Exception(f"Prompt service error: {response.error}") # Parse the response to get GraphQL query and variables - # Expecting response.text to contain JSON with "query" and "variables" fields - result = json.loads(response.text) + # Response could be in either text or object field + response_data = response.text if response.text else response.object + if response_data is None: + raise Exception("Prompt service returned empty response") + + # Parse JSON with "query" and "variables" fields + result = json.loads(response_data) logger.info(f"Phase 2 generated GraphQL: {result.get('query', '')[:100]}...") return result @@ -234,10 +248,10 @@ class Processor(FlowProcessor): logger.info(f"Handling NLP query request {id}: {request.question[:100]}...") # Phase 1: Select relevant schemas - selected_schemas = await self.phase1_select_schemas(request.question) + selected_schemas = await self.phase1_select_schemas(request.question, flow) # Phase 2: Generate GraphQL query - graphql_result = await self.phase2_generate_graphql(request.question, selected_schemas) + graphql_result = await self.phase2_generate_graphql(request.question, selected_schemas, flow) # Create response response = QuestionToStructuredQueryResponse( diff --git a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py index 42817d91..28327f82 100644 --- a/trustgraph-flow/trustgraph/retrieval/structured_query/service.py +++ b/trustgraph-flow/trustgraph/retrieval/structured_query/service.py @@ -87,7 +87,7 @@ class Processor(FlowProcessor): max_results=100 # Default limit ) - nlp_response = await self.client("nlp-query-request").request(nlp_request) + nlp_response = await flow("nlp-query-request").request(nlp_request) if nlp_response.error is not None: raise Exception(f"NLP query service error: {nlp_response.error.message}") @@ -111,17 +111,17 @@ class Processor(FlowProcessor): else: variables_as_strings[key] = str(value) - # For now, we'll use default user/collection values - # In a real implementation, these would come from authentication/context + # Use standard TrustGraph user/collection values + # These should eventually come from authentication/context objects_request = ObjectsQueryRequest( - user="default", # TODO: Get from authentication context - collection="default", # TODO: Get from request context + user="trustgraph", # Standard TrustGraph user + collection="default", # Standard default collection query=nlp_response.graphql_query, variables=variables_as_strings, operation_name=None ) - objects_response = await self.client("objects-query-request").request(objects_request) + objects_response = await flow("objects-query-request").request(objects_request) if objects_response.error is not None: raise Exception(f"Objects query service error: {objects_response.error.message}")