mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix non streaming RAG problems (#607)
* Fix non-streaming failure in RAG services * Fix non-streaming failure in API * Fix agent non-streaming messaging * Agent messaging unit & contract tests
This commit is contained in:
parent
30ca1d2e8b
commit
807f6cc4e2
10 changed files with 677 additions and 21 deletions
242
tests/contract/test_translator_completion_flags.py
Normal file
242
tests/contract/test_translator_completion_flags.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""
|
||||
Contract tests for message translator completion flag behavior.
|
||||
|
||||
These tests verify that translators correctly compute the is_final flag
|
||||
based on message fields like end_of_stream and end_of_dialog.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import (
|
||||
GraphRagResponse, DocumentRagResponse, AgentResponse, Error
|
||||
)
|
||||
from trustgraph.messaging import TranslatorRegistry
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestRAGTranslatorCompletionFlags:
    """Contract tests for the completion flag reported by RAG response translators.

    Each test hands a response object to the translator registered for its
    service and checks the reported is_final flag together with the
    translated dictionary.
    """

    @staticmethod
    def _translate(service, message):
        """Translate `message` using the response translator for `service`.

        Returns the result of from_response_with_completion(), which the
        tests unpack as a (response_dict, is_final) pair.
        """
        chosen = TranslatorRegistry.get_response_translator(service)
        return chosen.from_response_with_completion(message)

    def test_graph_rag_translator_is_final_with_end_of_stream_true(self):
        """A graph-rag response with end_of_stream=True must be reported as final."""
        # Arrange
        message = GraphRagResponse(
            response="A small domesticated mammal.",
            end_of_stream=True,
            error=None,
        )

        # Act
        payload, final_flag = self._translate("graph-rag", message)

        # Assert
        assert final_flag is True, "is_final must be True when end_of_stream=True"
        assert payload["response"] == "A small domesticated mammal."
        assert payload["end_of_stream"] is True

    def test_graph_rag_translator_is_final_with_end_of_stream_false(self):
        """A graph-rag response with end_of_stream=False must not be reported as final."""
        # Arrange
        message = GraphRagResponse(
            response="Chunk 1",
            end_of_stream=False,
            error=None,
        )

        # Act
        payload, final_flag = self._translate("graph-rag", message)

        # Assert
        assert final_flag is False, "is_final must be False when end_of_stream=False"
        assert payload["response"] == "Chunk 1"
        assert payload["end_of_stream"] is False

    def test_document_rag_translator_is_final_with_end_of_stream_true(self):
        """A document-rag response with end_of_stream=True must be reported as final."""
        # Arrange
        message = DocumentRagResponse(
            response="A document about cats.",
            end_of_stream=True,
            error=None,
        )

        # Act
        payload, final_flag = self._translate("document-rag", message)

        # Assert
        assert final_flag is True, "is_final must be True when end_of_stream=True"
        assert payload["response"] == "A document about cats."
        assert payload["end_of_stream"] is True

    def test_document_rag_translator_is_final_with_end_of_stream_false(self):
        """A document-rag response with end_of_stream=False must not be reported as final."""
        # Arrange
        message = DocumentRagResponse(
            response="Chunk 1",
            end_of_stream=False,
            error=None,
        )

        # Act
        payload, final_flag = self._translate("document-rag", message)

        # Assert
        assert final_flag is False, "is_final must be False when end_of_stream=False"
        assert payload["response"] == "Chunk 1"
        assert payload["end_of_stream"] is False
|
||||
|
||||
|
||||
@pytest.mark.contract
class TestAgentTranslatorCompletionFlags:
    """Contract tests for the completion flag reported by the agent response translator."""

    @staticmethod
    def _translate(message):
        """Translate `message` with the response translator registered as "agent".

        Returns the result of from_response_with_completion(), which the
        tests unpack as a (response_dict, is_final) pair.
        """
        agent_translator = TranslatorRegistry.get_response_translator("agent")
        return agent_translator.from_response_with_completion(message)

    def test_agent_translator_is_final_with_end_of_dialog_true(self):
        """An answer carrying end_of_dialog=True must be reported as final."""
        # Arrange
        message = AgentResponse(
            answer="4",
            error=None,
            thought=None,
            observation=None,
            end_of_message=True,
            end_of_dialog=True,
        )

        # Act
        payload, final_flag = self._translate(message)

        # Assert
        assert final_flag is True, "is_final must be True when end_of_dialog=True"
        assert payload["answer"] == "4"
        assert payload["end_of_dialog"] is True

    def test_agent_translator_is_final_with_end_of_dialog_false(self):
        """A thought carrying end_of_dialog=False must not be reported as final."""
        # Arrange
        message = AgentResponse(
            answer=None,
            error=None,
            thought="I need to solve this.",
            observation=None,
            end_of_message=True,
            end_of_dialog=False,
        )

        # Act
        payload, final_flag = self._translate(message)

        # Assert
        assert final_flag is False, "is_final must be False when end_of_dialog=False"
        assert payload["thought"] == "I need to solve this."
        assert payload["end_of_dialog"] is False

    def test_agent_translator_is_final_fallback_with_answer(self):
        """An answer built without completion flags must still be reported as final.

        This exercises the fallback for legacy responses: the message is
        constructed without passing end_of_message / end_of_dialog.
        """
        # Arrange: legacy response without end_of_dialog flag
        message = AgentResponse(
            answer="4",
            error=None,
            thought=None,
            observation=None,
        )

        # Act
        payload, final_flag = self._translate(message)

        # Assert
        assert final_flag is True, "is_final must be True when answer is present (legacy fallback)"
        assert payload["answer"] == "4"

    def test_agent_translator_intermediate_message_is_not_final(self):
        """Thought and observation messages with end_of_dialog=False are not final."""
        # Thought message: arrange, act, assert
        thought_message = AgentResponse(
            answer=None,
            error=None,
            thought="Processing...",
            observation=None,
            end_of_message=True,
            end_of_dialog=False,
        )
        _thought_payload, thought_final = self._translate(thought_message)
        assert thought_final is False, "Thought message must not be final"

        # Observation message: arrange, act, assert
        observation_message = AgentResponse(
            answer=None,
            error=None,
            thought=None,
            observation="Result found",
            end_of_message=True,
            end_of_dialog=False,
        )
        _observation_payload, observation_final = self._translate(observation_message)
        assert observation_final is False, "Observation message must not be final"

    def test_agent_translator_streaming_format_with_end_of_dialog(self):
        """A streaming-format chunk with end_of_dialog=True must be reported as final."""
        # Arrange: streaming format (chunk_type/content) with end_of_dialog=True
        message = AgentResponse(
            chunk_type="answer",
            content="",
            end_of_message=True,
            end_of_dialog=True,
            answer=None,
            error=None,
            thought=None,
            observation=None,
        )

        # Act
        payload, final_flag = self._translate(message)

        # Assert
        assert final_flag is True, "Streaming format must use end_of_dialog for is_final"
        assert payload["end_of_dialog"] is True
|
||||
206
tests/unit/test_agent/test_agent_service_non_streaming.py
Normal file
206
tests/unit/test_agent/test_agent_service_non_streaming.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
"""
|
||||
Unit tests for Agent service non-streaming mode.
|
||||
Tests that end_of_message and end_of_dialog flags are correctly set.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.agent.react.service import Processor
|
||||
from trustgraph.schema import AgentRequest, AgentResponse
|
||||
from trustgraph.agent.react.types import Final
|
||||
|
||||
|
||||
class TestAgentServiceNonStreaming:
    """Unit tests for the completion flags the agent service sends when streaming is off."""

    @staticmethod
    def _build_processor():
        """Create the Processor under test with a mocked task group."""
        return Processor(
            taskgroup=MagicMock(),
            id="test-agent",
            max_iterations=10,
        )

    @staticmethod
    def _non_streaming_message():
        """Build a mocked incoming message wrapping a non-streaming AgentRequest."""
        incoming = MagicMock()
        incoming.value.return_value = AgentRequest(
            question="What is 2 + 2?",
            user="trustgraph",
            streaming=False,  # Non-streaming mode
        )
        incoming.properties.return_value = {"id": "test-id"}
        return incoming

    @staticmethod
    def _recording_flow(captured):
        """Build a flow mock whose "response" producer appends sent messages to `captured`.

        Any other service name resolves to a fresh AsyncMock.
        """
        async def capture_response(response, properties):
            captured.append(response)

        response_producer = AsyncMock()
        response_producer.send = AsyncMock(side_effect=capture_response)

        def flow_router(service_name):
            return response_producer if service_name == "response" else AsyncMock()

        flow = MagicMock()
        flow.side_effect = flow_router
        return flow

    @patch('trustgraph.agent.react.service.AgentManager')
    @pytest.mark.asyncio
    async def test_non_streaming_intermediate_messages_have_correct_flags(self, mock_agent_manager_class):
        """Thought/observation messages must have end_of_message=True and end_of_dialog=False."""
        processor = self._build_processor()
        captured = []

        # Fake react loop: emit one thought and one observation, then finish.
        async def mock_react(question, history, think, observe, answer, context, streaming):
            await think("I need to solve this.", is_final=True)
            await observe("The answer is 4.", is_final=True)
            return Final(thought="Final answer", final="4")

        manager = AsyncMock()
        mock_agent_manager_class.return_value = manager
        manager.react = mock_react

        request_msg = self._non_streaming_message()
        consumer = MagicMock()
        flow = self._recording_flow(captured)

        # Execute
        await processor.on_request(request_msg, consumer, flow)

        # Verify: should have 3 responses (thought, observation, answer)
        assert len(captured) == 3, f"Expected 3 responses, got {len(captured)}"

        # Check thought message
        first = captured[0]
        assert isinstance(first, AgentResponse)
        assert first.thought == "I need to solve this."
        assert first.answer is None
        assert first.end_of_message is True, "Thought message must have end_of_message=True"
        assert first.end_of_dialog is False, "Thought message must have end_of_dialog=False"

        # Check observation message
        second = captured[1]
        assert isinstance(second, AgentResponse)
        assert second.observation == "The answer is 4."
        assert second.answer is None
        assert second.end_of_message is True, "Observation message must have end_of_message=True"
        assert second.end_of_dialog is False, "Observation message must have end_of_dialog=False"

    @patch('trustgraph.agent.react.service.AgentManager')
    @pytest.mark.asyncio
    async def test_non_streaming_final_answer_has_correct_flags(self, mock_agent_manager_class):
        """The final answer must have end_of_message=True and end_of_dialog=True."""
        processor = self._build_processor()
        captured = []

        # Fake react loop: return the final result without intermediate callbacks.
        async def mock_react(question, history, think, observe, answer, context, streaming):
            return Final(thought="Final answer", final="4")

        manager = AsyncMock()
        mock_agent_manager_class.return_value = manager
        manager.react = mock_react

        request_msg = self._non_streaming_message()
        consumer = MagicMock()
        flow = self._recording_flow(captured)

        # Execute
        await processor.on_request(request_msg, consumer, flow)

        # Verify: should have 1 response (final answer)
        assert len(captured) == 1, f"Expected 1 response, got {len(captured)}"

        # Check final answer message
        final_msg = captured[0]
        assert isinstance(final_msg, AgentResponse)
        assert final_msg.answer == "4"
        assert final_msg.thought is None
        assert final_msg.observation is None
        assert final_msg.end_of_message is True, "Final answer must have end_of_message=True"
        assert final_msg.end_of_dialog is True, "Final answer must have end_of_dialog=True"

    @pytest.mark.asyncio
    async def test_error_response_has_correct_flags(self):
        """An error response must have end_of_message=True and end_of_dialog=True."""
        processor = self._build_processor()
        captured = []

        # Reading the message payload raises, which forces the error path.
        failing_msg = MagicMock()
        failing_msg.value.side_effect = Exception("Test error")
        failing_msg.properties.return_value = {"id": "test-id"}

        async def capture_response(response, properties):
            captured.append(response)

        # Here the capturing producer is wired through flow.producer["response"]
        # rather than through flow("response").
        error_producer = AsyncMock()
        consumer = MagicMock()
        flow = MagicMock()
        flow.producer = {"response": error_producer}
        flow.producer["response"].send = AsyncMock(side_effect=capture_response)

        # Execute
        await processor.on_request(failing_msg, consumer, flow)

        # Verify: should have 1 error response
        assert len(captured) == 1, f"Expected 1 error response, got {len(captured)}"

        # Check error response
        failure = captured[0]
        assert isinstance(failure, AgentResponse)
        assert failure.error is not None
        assert "Test error" in failure.error.message
        assert failure.end_of_message is True, "Error response must have end_of_message=True"
        assert failure.end_of_dialog is True, "Error response must have end_of_dialog=True"
|
||||
|
|
@ -74,4 +74,58 @@ class TestDocumentRagService:
|
|||
sent_response = mock_producer.send.call_args[0][0]
|
||||
assert isinstance(sent_response, DocumentRagResponse)
|
||||
assert sent_response.response == "test response"
|
||||
assert sent_response.error is None
|
||||
|
||||
@patch('trustgraph.retrieval.document_rag.rag.DocumentRag')
@pytest.mark.asyncio
async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_document_rag_class):
    """
    Verify a non-streaming request produces a response with end_of_stream=True.

    Regression test: non-streaming responses used to omit end_of_stream,
    which left clients hanging while they waited for more data.
    """
    # Processor under test
    processor = Processor(
        taskgroup=MagicMock(),
        id="test-processor",
        doc_limit=10,
    )

    # The patched DocumentRag class hands back a stub whose query() resolves
    # to a fixed answer.
    rag_stub = AsyncMock()
    rag_stub.query.return_value = "A document about cats."
    mock_document_rag_class.return_value = rag_stub

    # Incoming message wrapping a non-streaming query
    incoming = MagicMock()
    incoming.value.return_value = DocumentRagQuery(
        query="What is a cat?",
        user="trustgraph",
        collection="default",
        doc_limit=10,
        streaming=False,  # Non-streaming mode
    )
    incoming.properties.return_value = {"id": "test-id"}

    # flow("response") yields the producer we inspect; anything else gets a
    # throwaway AsyncMock.
    response_producer = AsyncMock()

    def flow_router(service_name):
        return response_producer if service_name == "response" else AsyncMock()

    consumer = MagicMock()
    flow = MagicMock()
    flow.side_effect = flow_router

    # Execute
    await processor.on_request(incoming, consumer, flow)

    # Verify: exactly one response was sent, flagged end_of_stream=True
    response_producer.send.assert_called_once()
    delivered = response_producer.send.call_args[0][0]
    assert isinstance(delivered, DocumentRagResponse)
    assert delivered.response == "A document about cats."
    assert delivered.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
    assert delivered.error is None
|
||||
134
tests/unit/test_retrieval/test_graph_rag_service.py
Normal file
134
tests/unit/test_retrieval/test_graph_rag_service.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
Unit tests for GraphRAG service non-streaming mode.
|
||||
Tests that end_of_stream flag is correctly set in non-streaming responses.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from trustgraph.retrieval.graph_rag.rag import Processor
|
||||
from trustgraph.schema import GraphRagQuery, GraphRagResponse
|
||||
|
||||
|
||||
class TestGraphRagService:
    """Test GraphRAG service non-streaming behavior"""

    @staticmethod
    def _build_processor():
        """Create the Processor under test with a mocked task group."""
        return Processor(
            taskgroup=MagicMock(),
            id="test-processor",
            entity_limit=50,
            triple_limit=30,
            max_subgraph_size=150,
            max_path_length=2,
        )

    @staticmethod
    def _non_streaming_message():
        """Build a mocked incoming message wrapping a non-streaming GraphRagQuery."""
        msg = MagicMock()
        msg.value.return_value = GraphRagQuery(
            query="What is a cat?",
            user="trustgraph",
            collection="default",
            entity_limit=50,
            triple_limit=30,
            max_subgraph_size=150,
            max_path_length=2,
            streaming=False,  # Non-streaming mode
        )
        msg.properties.return_value = {"id": "test-id"}
        return msg

    @staticmethod
    def _flow_with_producer():
        """Return (flow, producer) where flow("response") resolves to `producer`.

        Every other service name (embeddings, graph-embeddings, triples,
        prompt clients) resolves to a fresh AsyncMock.
        """
        mock_producer = AsyncMock()

        def flow_router(service_name):
            if service_name == "response":
                return mock_producer
            return AsyncMock()

        flow = MagicMock()
        flow.side_effect = flow_router
        return flow, mock_producer

    @patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
    @pytest.mark.asyncio
    async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_graph_rag_class):
        """
        Test that non-streaming mode sets end_of_stream=True in response.

        This is a regression test for the bug where non-streaming responses
        didn't set end_of_stream, causing clients to hang waiting for more data.
        """
        # Setup processor
        processor = self._build_processor()

        # Setup mock GraphRag instance returning a fixed answer
        mock_rag_instance = AsyncMock()
        mock_graph_rag_class.return_value = mock_rag_instance
        mock_rag_instance.query.return_value = "A small domesticated mammal."

        # Setup message and flow mocks
        msg = self._non_streaming_message()
        consumer = MagicMock()
        flow, mock_producer = self._flow_with_producer()

        # Execute
        await processor.on_request(msg, consumer, flow)

        # Verify: response was sent with end_of_stream=True
        mock_producer.send.assert_called_once()
        sent_response = mock_producer.send.call_args[0][0]
        assert isinstance(sent_response, GraphRagResponse)
        assert sent_response.response == "A small domesticated mammal."
        assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
        assert sent_response.error is None

    @patch('trustgraph.retrieval.graph_rag.rag.GraphRag')
    @pytest.mark.asyncio
    async def test_error_response_in_non_streaming_mode(self, mock_graph_rag_class):
        """
        Test that a failing query in non-streaming mode sends one error response.

        The response must carry the exception message and no response text.
        end_of_stream is deliberately NOT asserted here (see the note at the
        end of the test).
        """
        # Setup processor
        processor = self._build_processor()

        # Setup mock GraphRag instance that raises an exception
        mock_rag_instance = AsyncMock()
        mock_graph_rag_class.return_value = mock_rag_instance
        mock_rag_instance.query.side_effect = Exception("Test error")

        # Setup message and flow mocks
        msg = self._non_streaming_message()
        consumer = MagicMock()
        flow, mock_producer = self._flow_with_producer()

        # Execute
        await processor.on_request(msg, consumer, flow)

        # Verify: error response was sent without end_of_stream (not streaming mode)
        mock_producer.send.assert_called_once()
        sent_response = mock_producer.send.call_args[0][0]
        assert isinstance(sent_response, GraphRagResponse)
        assert sent_response.response is None
        assert sent_response.error is not None
        assert sent_response.error.message == "Test error"
        # Note: error responses in non-streaming mode don't set end_of_stream
        # because streaming was never started
|
||||
|
|
@ -275,13 +275,17 @@ class SocketFlowInstance:
|
|||
result = self.client._send_request_sync("text-completion", self.flow_id, request, streaming)
|
||||
|
||||
if streaming:
|
||||
# For text completion, yield just the content
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
# For text completion, return generator that yields content
|
||||
return self._text_completion_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
|
||||
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
|
||||
"""Generator for text completion streaming"""
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
|
||||
def graph_rag(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -308,9 +312,7 @@ class SocketFlowInstance:
|
|||
result = self.client._send_request_sync("graph-rag", self.flow_id, request, streaming)
|
||||
|
||||
if streaming:
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
|
||||
|
|
@ -336,12 +338,16 @@ class SocketFlowInstance:
|
|||
result = self.client._send_request_sync("document-rag", self.flow_id, request, streaming)
|
||||
|
||||
if streaming:
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
|
||||
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
|
||||
"""Generator for RAG streaming (graph-rag and document-rag)"""
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
id: str,
|
||||
|
|
@ -360,9 +366,7 @@ class SocketFlowInstance:
|
|||
result = self.client._send_request_sync("prompt", self.flow_id, request, streaming)
|
||||
|
||||
if streaming:
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
|
||||
|
|
|
|||
|
|
@ -48,13 +48,13 @@ class AgentService(FlowProcessor):
|
|||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
# Get ID early so error handler can use it
|
||||
id = msg.properties().get("id", "unknown")
|
||||
|
||||
try:
|
||||
|
||||
request = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
async def respond(resp):
|
||||
|
||||
await flow("response").send(
|
||||
|
|
@ -93,6 +93,8 @@ class AgentService(FlowProcessor):
|
|||
thought = None,
|
||||
observation = None,
|
||||
answer = None,
|
||||
end_of_message = True,
|
||||
end_of_dialog = True,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -44,13 +44,16 @@ class AgentResponseTranslator(MessageTranslator):
|
|||
result["end_of_message"] = getattr(obj, "end_of_message", False)
|
||||
result["end_of_dialog"] = getattr(obj, "end_of_dialog", False)
|
||||
else:
|
||||
# Legacy format
|
||||
# Legacy format (non-streaming)
|
||||
if obj.answer:
|
||||
result["answer"] = obj.answer
|
||||
if obj.thought:
|
||||
result["thought"] = obj.thought
|
||||
if obj.observation:
|
||||
result["observation"] = obj.observation
|
||||
# Include completion flags for legacy format too
|
||||
result["end_of_message"] = getattr(obj, "end_of_message", False)
|
||||
result["end_of_dialog"] = getattr(obj, "end_of_dialog", False)
|
||||
|
||||
# Always include error if present
|
||||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
|
|
|
|||
|
|
@ -232,12 +232,14 @@ class Processor(AgentService):
|
|||
observation=None,
|
||||
)
|
||||
else:
|
||||
# Legacy format
|
||||
# Non-streaming format
|
||||
r = AgentResponse(
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=x,
|
||||
observation=None,
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
await respond(r)
|
||||
|
|
@ -260,12 +262,14 @@ class Processor(AgentService):
|
|||
observation=x,
|
||||
)
|
||||
else:
|
||||
# Legacy format
|
||||
# Non-streaming format
|
||||
r = AgentResponse(
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=x,
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
await respond(r)
|
||||
|
|
@ -288,12 +292,14 @@ class Processor(AgentService):
|
|||
observation=None,
|
||||
)
|
||||
else:
|
||||
# Legacy format - shouldn't be called in non-streaming mode
|
||||
# Non-streaming format - shouldn't normally be called
|
||||
r = AgentResponse(
|
||||
answer=x,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None,
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
)
|
||||
|
||||
await respond(r)
|
||||
|
|
@ -364,11 +370,14 @@ class Processor(AgentService):
|
|||
thought=None,
|
||||
)
|
||||
else:
|
||||
# Legacy format - send complete answer
|
||||
# Non-streaming format - send complete answer
|
||||
r = AgentResponse(
|
||||
answer=act.final,
|
||||
error=None,
|
||||
thought=None,
|
||||
observation=None,
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
)
|
||||
|
||||
await respond(r)
|
||||
|
|
|
|||
|
|
@ -128,6 +128,7 @@ class Processor(FlowProcessor):
|
|||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response = response,
|
||||
end_of_stream = True,
|
||||
error = None
|
||||
),
|
||||
properties = {"id": id}
|
||||
|
|
|
|||
|
|
@ -171,6 +171,7 @@ class Processor(FlowProcessor):
|
|||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
response = response,
|
||||
end_of_stream = True,
|
||||
error = None
|
||||
),
|
||||
properties = {"id": id}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue