mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-07 05:45:13 +02:00
Fix/sys integration issues (#494)
* Fix integration issues * Fix query defaults * Fix tests
This commit is contained in:
parent
ed0e02791d
commit
50c37407c5
6 changed files with 341 additions and 133 deletions
|
|
@ -131,15 +131,18 @@ class TestNLPQueryServiceIntegration:
|
|||
)
|
||||
|
||||
# Set up mock to return different responses for each call
|
||||
integration_processor.client.return_value.request = AsyncMock(
|
||||
# Mock the flow context to return prompt service responses
|
||||
prompt_service = AsyncMock()
|
||||
prompt_service.request = AsyncMock(
|
||||
side_effect=[phase1_response, phase2_response]
|
||||
)
|
||||
flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock()
|
||||
|
||||
# Act - Process the message
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
||||
# Assert - Verify the complete pipeline
|
||||
assert integration_processor.client.return_value.request.call_count == 2
|
||||
assert prompt_service.request.call_count == 2
|
||||
flow_response.send.assert_called_once()
|
||||
|
||||
# Verify response structure and content
|
||||
|
|
@ -188,9 +191,12 @@ class TestNLPQueryServiceIntegration:
|
|||
error=None
|
||||
)
|
||||
|
||||
integration_processor.client.return_value.request = AsyncMock(
|
||||
# Mock the flow context to return prompt service responses
|
||||
prompt_service = AsyncMock()
|
||||
prompt_service.request = AsyncMock(
|
||||
side_effect=[phase1_response, phase2_response]
|
||||
)
|
||||
flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock()
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -255,9 +261,12 @@ class TestNLPQueryServiceIntegration:
|
|||
error=None
|
||||
)
|
||||
|
||||
integration_processor.client.return_value.request = AsyncMock(
|
||||
# Mock the flow context to return prompt service responses
|
||||
prompt_service = AsyncMock()
|
||||
prompt_service.request = AsyncMock(
|
||||
side_effect=[phase1_response, phase2_response]
|
||||
)
|
||||
flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock()
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -293,9 +302,12 @@ class TestNLPQueryServiceIntegration:
|
|||
error=Error(type="template-not-found", message="Schema selection template not available")
|
||||
)
|
||||
|
||||
integration_processor.client.return_value.request = AsyncMock(
|
||||
# Mock the flow context to return prompt service error response
|
||||
prompt_service = AsyncMock()
|
||||
prompt_service.request = AsyncMock(
|
||||
return_value=phase1_error_response
|
||||
)
|
||||
flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock()
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -350,9 +362,12 @@ class TestNLPQueryServiceIntegration:
|
|||
error=None
|
||||
)
|
||||
|
||||
custom_processor.client.return_value.request = AsyncMock(
|
||||
# Mock flow context to return prompt service responses
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(
|
||||
side_effect=[phase1_response, phase2_response]
|
||||
)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock()
|
||||
|
||||
# Act
|
||||
await custom_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -362,7 +377,7 @@ class TestNLPQueryServiceIntegration:
|
|||
assert custom_processor.graphql_generation_template == "custom-graphql-generator"
|
||||
|
||||
# Verify the calls were made
|
||||
assert custom_processor.client.return_value.request.call_count == 2
|
||||
assert mock_prompt_service.request.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_schema_set_integration(self, integration_processor):
|
||||
|
|
@ -410,9 +425,12 @@ class TestNLPQueryServiceIntegration:
|
|||
error=None
|
||||
)
|
||||
|
||||
integration_processor.client.return_value.request = AsyncMock(
|
||||
# Mock the flow context to return prompt service responses
|
||||
prompt_service = AsyncMock()
|
||||
prompt_service.request = AsyncMock(
|
||||
side_effect=[phase1_response, phase2_response]
|
||||
)
|
||||
flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock()
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -451,27 +469,36 @@ class TestNLPQueryServiceIntegration:
|
|||
messages.append(msg)
|
||||
flows.append(flow)
|
||||
|
||||
# Mock responses for all requests
|
||||
mock_responses = []
|
||||
for i in range(10): # 2 calls per request (phase1 + phase2)
|
||||
if i % 2 == 0: # Phase 1 responses
|
||||
mock_responses.append(PromptResponse(
|
||||
text=json.dumps(["customers"]),
|
||||
error=None
|
||||
))
|
||||
else: # Phase 2 responses
|
||||
mock_responses.append(PromptResponse(
|
||||
text=json.dumps({
|
||||
"query": f"query {{ customers {{ id name }} }}",
|
||||
"variables": {},
|
||||
"confidence": 0.9
|
||||
}),
|
||||
error=None
|
||||
))
|
||||
|
||||
integration_processor.client.return_value.request = AsyncMock(
|
||||
side_effect=mock_responses
|
||||
)
|
||||
# Mock responses for all requests - create individual prompt services for each flow
|
||||
prompt_services = []
|
||||
for i in range(5): # 5 concurrent requests
|
||||
phase1_response = PromptResponse(
|
||||
text=json.dumps(["customers"]),
|
||||
error=None
|
||||
)
|
||||
phase2_response = PromptResponse(
|
||||
text=json.dumps({
|
||||
"query": f"query {{ customers {{ id name }} }}",
|
||||
"variables": {},
|
||||
"confidence": 0.9
|
||||
}),
|
||||
error=None
|
||||
)
|
||||
|
||||
# Create a prompt service for this request
|
||||
prompt_service = AsyncMock()
|
||||
prompt_service.request = AsyncMock(
|
||||
side_effect=[phase1_response, phase2_response]
|
||||
)
|
||||
prompt_services.append(prompt_service)
|
||||
|
||||
# Set up the flow for this request
|
||||
flow_response = flows[i].return_value
|
||||
flows[i].side_effect = lambda service_name, ps=prompt_service, fr=flow_response: (
|
||||
ps if service_name == "prompt-request" else
|
||||
fr if service_name == "response" else
|
||||
AsyncMock()
|
||||
)
|
||||
|
||||
# Act - Process all messages concurrently
|
||||
import asyncio
|
||||
|
|
@ -485,7 +512,8 @@ class TestNLPQueryServiceIntegration:
|
|||
await asyncio.gather(*tasks)
|
||||
|
||||
# Assert - All requests should be processed
|
||||
assert integration_processor.client.return_value.request.call_count == 10
|
||||
total_calls = sum(ps.request.call_count for ps in prompt_services)
|
||||
assert total_calls == 10 # 2 calls per request (phase1 + phase2)
|
||||
for flow in flows:
|
||||
flow.return_value.send.assert_called_once()
|
||||
|
||||
|
|
@ -518,9 +546,12 @@ class TestNLPQueryServiceIntegration:
|
|||
error=None
|
||||
)
|
||||
|
||||
integration_processor.client.return_value.request = AsyncMock(
|
||||
# Mock the flow context to return prompt service responses
|
||||
prompt_service = AsyncMock()
|
||||
prompt_service.request = AsyncMock(
|
||||
side_effect=[phase1_response, phase2_response]
|
||||
)
|
||||
flow.side_effect = lambda service_name: prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock()
|
||||
|
||||
# Act
|
||||
import time
|
||||
|
|
|
|||
|
|
@ -93,9 +93,17 @@ class TestStructuredQueryServiceIntegration:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
integration_processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act - Process the message
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -116,7 +124,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
assert "orders" in objects_call_args.query
|
||||
assert objects_call_args.variables["minAmount"] == "500.0" # Converted to string
|
||||
assert objects_call_args.variables["state"] == "California"
|
||||
assert objects_call_args.user == "default"
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
assert objects_call_args.collection == "default"
|
||||
|
||||
# Verify response
|
||||
|
|
@ -159,7 +167,15 @@ class TestStructuredQueryServiceIntegration:
|
|||
mock_nlp_client = AsyncMock()
|
||||
mock_nlp_client.request.return_value = nlp_error_response
|
||||
|
||||
integration_processor.client.return_value = mock_nlp_client
|
||||
# Mock flow context to route to nlp service
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -215,9 +231,17 @@ class TestStructuredQueryServiceIntegration:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_error_response
|
||||
|
||||
integration_processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -285,9 +309,17 @@ class TestStructuredQueryServiceIntegration:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
integration_processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -405,9 +437,17 @@ class TestStructuredQueryServiceIntegration:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
integration_processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -474,9 +514,17 @@ class TestStructuredQueryServiceIntegration:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
integration_processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -514,34 +562,51 @@ class TestStructuredQueryServiceIntegration:
|
|||
messages.append(msg)
|
||||
flows.append(flow)
|
||||
|
||||
# Mock responses for all requests (6 total: 3 NLP + 3 Objects)
|
||||
mock_responses = []
|
||||
for i in range(6):
|
||||
if i % 2 == 0: # NLP responses
|
||||
mock_responses.append(QuestionToStructuredQueryResponse(
|
||||
error=None,
|
||||
graphql_query=f'query {{ test_{i//2} {{ id }} }}',
|
||||
variables={},
|
||||
detected_schemas=[f"test_{i//2}"],
|
||||
confidence=0.9
|
||||
))
|
||||
else: # Objects responses
|
||||
mock_responses.append(ObjectsQueryResponse(
|
||||
error=None,
|
||||
data=f'{{"test_{i//2}": [{{"id": "{i//2}"}}]}}',
|
||||
errors=None,
|
||||
extensions={}
|
||||
))
|
||||
# Set up individual flow routing for each concurrent request
|
||||
service_call_count = 0
|
||||
|
||||
call_count = 0
|
||||
def mock_client_side_effect(name):
|
||||
nonlocal call_count
|
||||
client = AsyncMock()
|
||||
client.request.return_value = mock_responses[call_count]
|
||||
call_count += 1
|
||||
return client
|
||||
|
||||
integration_processor.client.side_effect = mock_client_side_effect
|
||||
for i in range(3): # 3 concurrent requests
|
||||
# Create NLP and Objects responses for this request
|
||||
nlp_response = QuestionToStructuredQueryResponse(
|
||||
error=None,
|
||||
graphql_query=f'query {{ test_{i} {{ id }} }}',
|
||||
variables={},
|
||||
detected_schemas=[f"test_{i}"],
|
||||
confidence=0.9
|
||||
)
|
||||
|
||||
objects_response = ObjectsQueryResponse(
|
||||
error=None,
|
||||
data=f'{{"test_{i}": [{{"id": "{i}"}}]}}',
|
||||
errors=None,
|
||||
extensions={}
|
||||
)
|
||||
|
||||
# Create mock services for this request
|
||||
mock_nlp_client = AsyncMock()
|
||||
mock_nlp_client.request.return_value = nlp_response
|
||||
|
||||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
# Set up flow routing for this specific request
|
||||
flow_response = flows[i].return_value
|
||||
def create_flow_router(nlp_client, objects_client, response_producer):
|
||||
def flow_router(service_name):
|
||||
nonlocal service_call_count
|
||||
if service_name == "nlp-query-request":
|
||||
service_call_count += 1
|
||||
return nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
service_call_count += 1
|
||||
return objects_client
|
||||
elif service_name == "response":
|
||||
return response_producer
|
||||
else:
|
||||
return AsyncMock()
|
||||
return flow_router
|
||||
|
||||
flows[i].side_effect = create_flow_router(mock_nlp_client, mock_objects_client, flow_response)
|
||||
|
||||
# Act - Process all messages concurrently
|
||||
import asyncio
|
||||
|
|
@ -555,7 +620,7 @@ class TestStructuredQueryServiceIntegration:
|
|||
await asyncio.gather(*tasks)
|
||||
|
||||
# Assert - All requests should be processed
|
||||
assert call_count == 6 # 2 calls per request (NLP + Objects)
|
||||
assert service_call_count == 6 # 2 calls per request (NLP + Objects)
|
||||
for flow in flows:
|
||||
flow.return_value.send.assert_called_once()
|
||||
|
||||
|
|
@ -580,7 +645,15 @@ class TestStructuredQueryServiceIntegration:
|
|||
mock_nlp_client = AsyncMock()
|
||||
mock_nlp_client.request.side_effect = Exception("Service timeout: Request took longer than 30s")
|
||||
|
||||
integration_processor.client.return_value = mock_nlp_client
|
||||
# Mock flow context to route to nlp service
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -638,9 +711,17 @@ class TestStructuredQueryServiceIntegration:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
integration_processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await integration_processor.on_message(msg, consumer, flow)
|
||||
|
|
|
|||
|
|
@ -87,14 +87,18 @@ class TestNLPQueryProcessor:
|
|||
error=None
|
||||
)
|
||||
|
||||
processor.client.return_value.request = AsyncMock(return_value=mock_response)
|
||||
# Mock flow context
|
||||
flow = MagicMock()
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(return_value=mock_response)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await processor.phase1_select_schemas(question)
|
||||
result = await processor.phase1_select_schemas(question, flow)
|
||||
|
||||
# Assert
|
||||
assert result == expected_schemas
|
||||
processor.client.assert_called_once_with("prompt-request")
|
||||
mock_prompt_service.request.assert_called_once()
|
||||
|
||||
async def test_phase1_select_schemas_prompt_error(self, processor):
|
||||
"""Test schema selection with prompt service error"""
|
||||
|
|
@ -103,11 +107,15 @@ class TestNLPQueryProcessor:
|
|||
error = Error(type="prompt-error", message="Template not found")
|
||||
mock_response = PromptResponse(text="", error=error)
|
||||
|
||||
processor.client.return_value.request = AsyncMock(return_value=mock_response)
|
||||
# Mock flow context
|
||||
flow = MagicMock()
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(return_value=mock_response)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Prompt service error"):
|
||||
await processor.phase1_select_schemas(question)
|
||||
await processor.phase1_select_schemas(question, flow)
|
||||
|
||||
async def test_phase2_generate_graphql_success(self, processor):
|
||||
"""Test successful GraphQL generation (Phase 2)"""
|
||||
|
|
@ -125,14 +133,18 @@ class TestNLPQueryProcessor:
|
|||
error=None
|
||||
)
|
||||
|
||||
processor.client.return_value.request = AsyncMock(return_value=mock_response)
|
||||
# Mock flow context
|
||||
flow = MagicMock()
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(return_value=mock_response)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await processor.phase2_generate_graphql(question, selected_schemas)
|
||||
result = await processor.phase2_generate_graphql(question, selected_schemas, flow)
|
||||
|
||||
# Assert
|
||||
assert result == expected_result
|
||||
processor.client.assert_called_once_with("prompt-request")
|
||||
mock_prompt_service.request.assert_called_once()
|
||||
|
||||
async def test_phase2_generate_graphql_prompt_error(self, processor):
|
||||
"""Test GraphQL generation with prompt service error"""
|
||||
|
|
@ -142,11 +154,15 @@ class TestNLPQueryProcessor:
|
|||
error = Error(type="prompt-error", message="Generation failed")
|
||||
mock_response = PromptResponse(text="", error=error)
|
||||
|
||||
processor.client.return_value.request = AsyncMock(return_value=mock_response)
|
||||
# Mock flow context
|
||||
flow = MagicMock()
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(return_value=mock_response)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Prompt service error"):
|
||||
await processor.phase2_generate_graphql(question, selected_schemas)
|
||||
await processor.phase2_generate_graphql(question, selected_schemas, flow)
|
||||
|
||||
async def test_on_message_full_flow_success(self, processor):
|
||||
"""Test complete message processing flow"""
|
||||
|
|
@ -181,16 +197,18 @@ class TestNLPQueryProcessor:
|
|||
error=None
|
||||
)
|
||||
|
||||
# Set up mock to return different responses for each call
|
||||
processor.client.return_value.request = AsyncMock(
|
||||
# Mock flow context to return prompt service responses
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(
|
||||
side_effect=[phase1_response, phase2_response]
|
||||
)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock()
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
||||
# Assert
|
||||
assert processor.client.return_value.request.call_count == 2
|
||||
assert mock_prompt_service.request.call_count == 2
|
||||
flow_response.send.assert_called_once()
|
||||
|
||||
# Verify response structure
|
||||
|
|
|
|||
|
|
@ -80,9 +80,17 @@ class TestStructuredQueryProcessor:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -101,7 +109,7 @@ class TestStructuredQueryProcessor:
|
|||
assert isinstance(objects_call_args, ObjectsQueryRequest)
|
||||
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
|
||||
assert objects_call_args.variables == {"state": "NY"}
|
||||
assert objects_call_args.user == "default"
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
assert objects_call_args.collection == "default"
|
||||
|
||||
# Verify response
|
||||
|
|
@ -142,7 +150,15 @@ class TestStructuredQueryProcessor:
|
|||
mock_nlp_client = AsyncMock()
|
||||
mock_nlp_client.request.return_value = nlp_response
|
||||
|
||||
processor.client.return_value = mock_nlp_client
|
||||
# Mock flow context to route to nlp service
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -185,7 +201,15 @@ class TestStructuredQueryProcessor:
|
|||
mock_nlp_client = AsyncMock()
|
||||
mock_nlp_client.request.return_value = nlp_response
|
||||
|
||||
processor.client.return_value = mock_nlp_client
|
||||
# Mock flow context to route to nlp service
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -237,9 +261,17 @@ class TestStructuredQueryProcessor:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -300,9 +332,17 @@ class TestStructuredQueryProcessor:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -370,9 +410,17 @@ class TestStructuredQueryProcessor:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -427,9 +475,17 @@ class TestStructuredQueryProcessor:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -457,10 +513,18 @@ class TestStructuredQueryProcessor:
|
|||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
# Mock client to raise exception
|
||||
# Mock flow context to raise exception
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.side_effect = Exception("Network timeout")
|
||||
processor.client.return_value = mock_client
|
||||
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.info(f"Schema configuration loaded: {len(self.schemas)} schemas")
|
||||
|
||||
async def phase1_select_schemas(self, question: str) -> List[str]:
|
||||
async def phase1_select_schemas(self, question: str, flow) -> List[str]:
|
||||
"""Phase 1: Use prompt service to select relevant schemas for the question"""
|
||||
logger.info("Starting Phase 1: Schema selection")
|
||||
|
||||
|
|
@ -144,20 +144,27 @@ class Processor(FlowProcessor):
|
|||
}
|
||||
|
||||
# Call prompt service for schema selection
|
||||
# Convert variables to JSON-encoded terms
|
||||
terms = {k: json.dumps(v) for k, v in variables.items()}
|
||||
prompt_request = PromptRequest(
|
||||
template=self.schema_selection_template,
|
||||
variables=variables
|
||||
id=self.schema_selection_template,
|
||||
terms=terms
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client("prompt-request").request(prompt_request)
|
||||
response = await flow("prompt-request").request(prompt_request)
|
||||
|
||||
if response.error is not None:
|
||||
raise Exception(f"Prompt service error: {response.error}")
|
||||
|
||||
# Parse the response to get selected schema names
|
||||
# Expecting response.text to contain JSON array of schema names
|
||||
selected_schemas = json.loads(response.text)
|
||||
# Response could be in either text or object field
|
||||
response_data = response.text if response.text else response.object
|
||||
if response_data is None:
|
||||
raise Exception("Prompt service returned empty response")
|
||||
|
||||
# Parse JSON array of schema names
|
||||
selected_schemas = json.loads(response_data)
|
||||
|
||||
logger.info(f"Phase 1 selected schemas: {selected_schemas}")
|
||||
return selected_schemas
|
||||
|
|
@ -166,7 +173,7 @@ class Processor(FlowProcessor):
|
|||
logger.error(f"Phase 1 schema selection failed: {e}")
|
||||
raise
|
||||
|
||||
async def phase2_generate_graphql(self, question: str, selected_schemas: List[str]) -> Dict[str, Any]:
|
||||
async def phase2_generate_graphql(self, question: str, selected_schemas: List[str], flow) -> Dict[str, Any]:
|
||||
"""Phase 2: Generate GraphQL query using selected schemas"""
|
||||
logger.info(f"Starting Phase 2: GraphQL generation for schemas: {selected_schemas}")
|
||||
|
||||
|
|
@ -200,20 +207,27 @@ class Processor(FlowProcessor):
|
|||
}
|
||||
|
||||
# Call prompt service for GraphQL generation
|
||||
# Convert variables to JSON-encoded terms
|
||||
terms = {k: json.dumps(v) for k, v in variables.items()}
|
||||
prompt_request = PromptRequest(
|
||||
template=self.graphql_generation_template,
|
||||
variables=variables
|
||||
id=self.graphql_generation_template,
|
||||
terms=terms
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client("prompt-request").request(prompt_request)
|
||||
response = await flow("prompt-request").request(prompt_request)
|
||||
|
||||
if response.error is not None:
|
||||
raise Exception(f"Prompt service error: {response.error}")
|
||||
|
||||
# Parse the response to get GraphQL query and variables
|
||||
# Expecting response.text to contain JSON with "query" and "variables" fields
|
||||
result = json.loads(response.text)
|
||||
# Response could be in either text or object field
|
||||
response_data = response.text if response.text else response.object
|
||||
if response_data is None:
|
||||
raise Exception("Prompt service returned empty response")
|
||||
|
||||
# Parse JSON with "query" and "variables" fields
|
||||
result = json.loads(response_data)
|
||||
|
||||
logger.info(f"Phase 2 generated GraphQL: {result.get('query', '')[:100]}...")
|
||||
return result
|
||||
|
|
@ -234,10 +248,10 @@ class Processor(FlowProcessor):
|
|||
logger.info(f"Handling NLP query request {id}: {request.question[:100]}...")
|
||||
|
||||
# Phase 1: Select relevant schemas
|
||||
selected_schemas = await self.phase1_select_schemas(request.question)
|
||||
selected_schemas = await self.phase1_select_schemas(request.question, flow)
|
||||
|
||||
# Phase 2: Generate GraphQL query
|
||||
graphql_result = await self.phase2_generate_graphql(request.question, selected_schemas)
|
||||
graphql_result = await self.phase2_generate_graphql(request.question, selected_schemas, flow)
|
||||
|
||||
# Create response
|
||||
response = QuestionToStructuredQueryResponse(
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ class Processor(FlowProcessor):
|
|||
max_results=100 # Default limit
|
||||
)
|
||||
|
||||
nlp_response = await self.client("nlp-query-request").request(nlp_request)
|
||||
nlp_response = await flow("nlp-query-request").request(nlp_request)
|
||||
|
||||
if nlp_response.error is not None:
|
||||
raise Exception(f"NLP query service error: {nlp_response.error.message}")
|
||||
|
|
@ -111,17 +111,17 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
variables_as_strings[key] = str(value)
|
||||
|
||||
# For now, we'll use default user/collection values
|
||||
# In a real implementation, these would come from authentication/context
|
||||
# Use standard TrustGraph user/collection values
|
||||
# These should eventually come from authentication/context
|
||||
objects_request = ObjectsQueryRequest(
|
||||
user="default", # TODO: Get from authentication context
|
||||
collection="default", # TODO: Get from request context
|
||||
user="trustgraph", # Standard TrustGraph user
|
||||
collection="default", # Standard default collection
|
||||
query=nlp_response.graphql_query,
|
||||
variables=variables_as_strings,
|
||||
operation_name=None
|
||||
)
|
||||
|
||||
objects_response = await self.client("objects-query-request").request(objects_request)
|
||||
objects_response = await flow("objects-query-request").request(objects_request)
|
||||
|
||||
if objects_response.error is not None:
|
||||
raise Exception(f"Objects query service error: {objects_response.error.message}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue