mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 01:16:22 +02:00
Fix/sys integration issues (#494)
* Fix integration issues * Fix query defaults * Fix tests
This commit is contained in:
parent
ed0e02791d
commit
50c37407c5
6 changed files with 341 additions and 133 deletions
|
|
@ -87,14 +87,18 @@ class TestNLPQueryProcessor:
|
|||
error=None
|
||||
)
|
||||
|
||||
processor.client.return_value.request = AsyncMock(return_value=mock_response)
|
||||
# Mock flow context
|
||||
flow = MagicMock()
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(return_value=mock_response)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await processor.phase1_select_schemas(question)
|
||||
result = await processor.phase1_select_schemas(question, flow)
|
||||
|
||||
# Assert
|
||||
assert result == expected_schemas
|
||||
processor.client.assert_called_once_with("prompt-request")
|
||||
mock_prompt_service.request.assert_called_once()
|
||||
|
||||
async def test_phase1_select_schemas_prompt_error(self, processor):
|
||||
"""Test schema selection with prompt service error"""
|
||||
|
|
@ -103,11 +107,15 @@ class TestNLPQueryProcessor:
|
|||
error = Error(type="prompt-error", message="Template not found")
|
||||
mock_response = PromptResponse(text="", error=error)
|
||||
|
||||
processor.client.return_value.request = AsyncMock(return_value=mock_response)
|
||||
# Mock flow context
|
||||
flow = MagicMock()
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(return_value=mock_response)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Prompt service error"):
|
||||
await processor.phase1_select_schemas(question)
|
||||
await processor.phase1_select_schemas(question, flow)
|
||||
|
||||
async def test_phase2_generate_graphql_success(self, processor):
|
||||
"""Test successful GraphQL generation (Phase 2)"""
|
||||
|
|
@ -125,14 +133,18 @@ class TestNLPQueryProcessor:
|
|||
error=None
|
||||
)
|
||||
|
||||
processor.client.return_value.request = AsyncMock(return_value=mock_response)
|
||||
# Mock flow context
|
||||
flow = MagicMock()
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(return_value=mock_response)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await processor.phase2_generate_graphql(question, selected_schemas)
|
||||
result = await processor.phase2_generate_graphql(question, selected_schemas, flow)
|
||||
|
||||
# Assert
|
||||
assert result == expected_result
|
||||
processor.client.assert_called_once_with("prompt-request")
|
||||
mock_prompt_service.request.assert_called_once()
|
||||
|
||||
async def test_phase2_generate_graphql_prompt_error(self, processor):
|
||||
"""Test GraphQL generation with prompt service error"""
|
||||
|
|
@ -142,11 +154,15 @@ class TestNLPQueryProcessor:
|
|||
error = Error(type="prompt-error", message="Generation failed")
|
||||
mock_response = PromptResponse(text="", error=error)
|
||||
|
||||
processor.client.return_value.request = AsyncMock(return_value=mock_response)
|
||||
# Mock flow context
|
||||
flow = MagicMock()
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(return_value=mock_response)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else AsyncMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Prompt service error"):
|
||||
await processor.phase2_generate_graphql(question, selected_schemas)
|
||||
await processor.phase2_generate_graphql(question, selected_schemas, flow)
|
||||
|
||||
async def test_on_message_full_flow_success(self, processor):
|
||||
"""Test complete message processing flow"""
|
||||
|
|
@ -181,16 +197,18 @@ class TestNLPQueryProcessor:
|
|||
error=None
|
||||
)
|
||||
|
||||
# Set up mock to return different responses for each call
|
||||
processor.client.return_value.request = AsyncMock(
|
||||
# Mock flow context to return prompt service responses
|
||||
mock_prompt_service = AsyncMock()
|
||||
mock_prompt_service.request = AsyncMock(
|
||||
side_effect=[phase1_response, phase2_response]
|
||||
)
|
||||
flow.side_effect = lambda service_name: mock_prompt_service if service_name == "prompt-request" else flow_response if service_name == "response" else AsyncMock()
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
||||
# Assert
|
||||
assert processor.client.return_value.request.call_count == 2
|
||||
assert mock_prompt_service.request.call_count == 2
|
||||
flow_response.send.assert_called_once()
|
||||
|
||||
# Verify response structure
|
||||
|
|
|
|||
|
|
@ -80,9 +80,17 @@ class TestStructuredQueryProcessor:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -101,7 +109,7 @@ class TestStructuredQueryProcessor:
|
|||
assert isinstance(objects_call_args, ObjectsQueryRequest)
|
||||
assert objects_call_args.query == 'query { customers(where: {state: {eq: "NY"}}) { id name email } }'
|
||||
assert objects_call_args.variables == {"state": "NY"}
|
||||
assert objects_call_args.user == "default"
|
||||
assert objects_call_args.user == "trustgraph"
|
||||
assert objects_call_args.collection == "default"
|
||||
|
||||
# Verify response
|
||||
|
|
@ -142,7 +150,15 @@ class TestStructuredQueryProcessor:
|
|||
mock_nlp_client = AsyncMock()
|
||||
mock_nlp_client.request.return_value = nlp_response
|
||||
|
||||
processor.client.return_value = mock_nlp_client
|
||||
# Mock flow context to route to nlp service
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -185,7 +201,15 @@ class TestStructuredQueryProcessor:
|
|||
mock_nlp_client = AsyncMock()
|
||||
mock_nlp_client.request.return_value = nlp_response
|
||||
|
||||
processor.client.return_value = mock_nlp_client
|
||||
# Mock flow context to route to nlp service
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -237,9 +261,17 @@ class TestStructuredQueryProcessor:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -300,9 +332,17 @@ class TestStructuredQueryProcessor:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -370,9 +410,17 @@ class TestStructuredQueryProcessor:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -427,9 +475,17 @@ class TestStructuredQueryProcessor:
|
|||
mock_objects_client = AsyncMock()
|
||||
mock_objects_client.request.return_value = objects_response
|
||||
|
||||
processor.client.side_effect = lambda name: (
|
||||
mock_nlp_client if name == "nlp-query-request" else mock_objects_client
|
||||
)
|
||||
# Mock flow context to route to appropriate services
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_nlp_client
|
||||
elif service_name == "objects-query-request":
|
||||
return mock_objects_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
@ -457,10 +513,18 @@ class TestStructuredQueryProcessor:
|
|||
flow_response = AsyncMock()
|
||||
flow.return_value = flow_response
|
||||
|
||||
# Mock client to raise exception
|
||||
# Mock flow context to raise exception
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request.side_effect = Exception("Network timeout")
|
||||
processor.client.return_value = mock_client
|
||||
|
||||
def flow_router(service_name):
|
||||
if service_name == "nlp-query-request":
|
||||
return mock_client
|
||||
elif service_name == "response":
|
||||
return flow_response
|
||||
else:
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_router
|
||||
|
||||
# Act
|
||||
await processor.on_message(msg, consumer, flow)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue