diff --git a/tests/contract/test_translator_completion_flags.py b/tests/contract/test_translator_completion_flags.py new file mode 100644 index 00000000..c01156ae --- /dev/null +++ b/tests/contract/test_translator_completion_flags.py @@ -0,0 +1,242 @@ +""" +Contract tests for message translator completion flag behavior. + +These tests verify that translators correctly compute the is_final flag +based on message fields like end_of_stream and end_of_dialog. +""" + +import pytest + +from trustgraph.schema import ( + GraphRagResponse, DocumentRagResponse, AgentResponse, Error +) +from trustgraph.messaging import TranslatorRegistry + + +@pytest.mark.contract +class TestRAGTranslatorCompletionFlags: + """Contract tests for RAG response translator completion flags""" + + def test_graph_rag_translator_is_final_with_end_of_stream_true(self): + """ + Test that GraphRagResponseTranslator returns is_final=True + when end_of_stream=True. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("graph-rag") + response = GraphRagResponse( + response="A small domesticated mammal.", + end_of_stream=True, + error=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is True, "is_final must be True when end_of_stream=True" + assert response_dict["response"] == "A small domesticated mammal." + assert response_dict["end_of_stream"] is True + + def test_graph_rag_translator_is_final_with_end_of_stream_false(self): + """ + Test that GraphRagResponseTranslator returns is_final=False + when end_of_stream=False. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("graph-rag") + response = GraphRagResponse( + response="Chunk 1", + end_of_stream=False, + error=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is False, "is_final must be False when end_of_stream=False" + assert response_dict["response"] == "Chunk 1" + assert response_dict["end_of_stream"] is False + + def test_document_rag_translator_is_final_with_end_of_stream_true(self): + """ + Test that DocumentRagResponseTranslator returns is_final=True + when end_of_stream=True. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("document-rag") + response = DocumentRagResponse( + response="A document about cats.", + end_of_stream=True, + error=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is True, "is_final must be True when end_of_stream=True" + assert response_dict["response"] == "A document about cats." + assert response_dict["end_of_stream"] is True + + def test_document_rag_translator_is_final_with_end_of_stream_false(self): + """ + Test that DocumentRagResponseTranslator returns is_final=False + when end_of_stream=False. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("document-rag") + response = DocumentRagResponse( + response="Chunk 1", + end_of_stream=False, + error=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is False, "is_final must be False when end_of_stream=False" + assert response_dict["response"] == "Chunk 1" + assert response_dict["end_of_stream"] is False + + +@pytest.mark.contract +class TestAgentTranslatorCompletionFlags: + """Contract tests for Agent response translator completion flags""" + + def test_agent_translator_is_final_with_end_of_dialog_true(self): + """ + Test that AgentResponseTranslator returns is_final=True + when end_of_dialog=True. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("agent") + response = AgentResponse( + answer="4", + error=None, + thought=None, + observation=None, + end_of_message=True, + end_of_dialog=True + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is True, "is_final must be True when end_of_dialog=True" + assert response_dict["answer"] == "4" + assert response_dict["end_of_dialog"] is True + + def test_agent_translator_is_final_with_end_of_dialog_false(self): + """ + Test that AgentResponseTranslator returns is_final=False + when end_of_dialog=False. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("agent") + response = AgentResponse( + answer=None, + error=None, + thought="I need to solve this.", + observation=None, + end_of_message=True, + end_of_dialog=False + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is False, "is_final must be False when end_of_dialog=False" + assert response_dict["thought"] == "I need to solve this." + assert response_dict["end_of_dialog"] is False + + def test_agent_translator_is_final_fallback_with_answer(self): + """ + Test that AgentResponseTranslator returns is_final=True + when answer is present (fallback for legacy responses). + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("agent") + # Legacy response without end_of_dialog flag + response = AgentResponse( + answer="4", + error=None, + thought=None, + observation=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is True, "is_final must be True when answer is present (legacy fallback)" + assert response_dict["answer"] == "4" + + def test_agent_translator_intermediate_message_is_not_final(self): + """ + Test that intermediate messages (thought/observation) return is_final=False. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("agent") + + # Test thought message + thought_response = AgentResponse( + answer=None, + error=None, + thought="Processing...", + observation=None, + end_of_message=True, + end_of_dialog=False + ) + + # Act + thought_dict, thought_is_final = translator.from_response_with_completion(thought_response) + + # Assert + assert thought_is_final is False, "Thought message must not be final" + + # Test observation message + observation_response = AgentResponse( + answer=None, + error=None, + thought=None, + observation="Result found", + end_of_message=True, + end_of_dialog=False + ) + + # Act + obs_dict, obs_is_final = translator.from_response_with_completion(observation_response) + + # Assert + assert obs_is_final is False, "Observation message must not be final" + + def test_agent_translator_streaming_format_with_end_of_dialog(self): + """ + Test that streaming format messages use end_of_dialog for is_final. + """ + # Arrange + translator = TranslatorRegistry.get_response_translator("agent") + + # Streaming format with end_of_dialog=True + response = AgentResponse( + chunk_type="answer", + content="", + end_of_message=True, + end_of_dialog=True, + answer=None, + error=None, + thought=None, + observation=None + ) + + # Act + response_dict, is_final = translator.from_response_with_completion(response) + + # Assert + assert is_final is True, "Streaming format must use end_of_dialog for is_final" + assert response_dict["end_of_dialog"] is True diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py new file mode 100644 index 00000000..0fd2060d --- /dev/null +++ b/tests/unit/test_agent/test_agent_service_non_streaming.py @@ -0,0 +1,206 @@ +""" +Unit tests for Agent service non-streaming mode. +Tests that end_of_message and end_of_dialog flags are correctly set. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.agent.react.service import Processor +from trustgraph.schema import AgentRequest, AgentResponse +from trustgraph.agent.react.types import Final + + +class TestAgentServiceNonStreaming: + """Test Agent service non-streaming behavior""" + + @patch('trustgraph.agent.react.service.AgentManager') + @pytest.mark.asyncio + async def test_non_streaming_intermediate_messages_have_correct_flags(self, mock_agent_manager_class): + """ + Test that intermediate messages (thought/observation) in non-streaming mode + have end_of_message=True and end_of_dialog=False. + """ + # Setup processor + processor = Processor( + taskgroup=MagicMock(), + id="test-agent", + max_iterations=10 + ) + + # Track all responses sent + sent_responses = [] + + # Setup mock agent manager + mock_agent_instance = AsyncMock() + mock_agent_manager_class.return_value = mock_agent_instance + + # Mock react to call think and observe callbacks + async def mock_react(question, history, think, observe, answer, context, streaming): + await think("I need to solve this.", is_final=True) + await observe("The answer is 4.", is_final=True) + return Final(thought="Final answer", final="4") + + mock_agent_instance.react = mock_react + + # Setup message with non-streaming request + msg = MagicMock() + msg.value.return_value = AgentRequest( + question="What is 2 + 2?", + user="trustgraph", + streaming=False # Non-streaming mode + ) + msg.properties.return_value = {"id": "test-id"} + + # Setup flow mock + consumer = MagicMock() + flow = MagicMock() + + mock_producer = AsyncMock() + + async def capture_response(response, properties): + sent_responses.append(response) + + mock_producer.send = AsyncMock(side_effect=capture_response) + + def flow_router(service_name): + if service_name == "response": + return mock_producer + return AsyncMock() + + flow.side_effect = flow_router + + # Execute + await processor.on_request(msg, consumer, flow) + + # Verify: should have 3 responses (thought, observation, answer) + assert len(sent_responses) == 3, f"Expected 3 responses, got {len(sent_responses)}" + + # Check thought message + thought_response = sent_responses[0] + assert isinstance(thought_response, AgentResponse) + assert thought_response.thought == "I need to solve this." + assert thought_response.answer is None + assert thought_response.end_of_message is True, "Thought message must have end_of_message=True" + assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False" + + # Check observation message + observation_response = sent_responses[1] + assert isinstance(observation_response, AgentResponse) + assert observation_response.observation == "The answer is 4." + assert observation_response.answer is None + assert observation_response.end_of_message is True, "Observation message must have end_of_message=True" + assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False" + + @patch('trustgraph.agent.react.service.AgentManager') + @pytest.mark.asyncio + async def test_non_streaming_final_answer_has_correct_flags(self, mock_agent_manager_class): + """ + Test that final answer in non-streaming mode has + end_of_message=True and end_of_dialog=True. + """ + # Setup processor + processor = Processor( + taskgroup=MagicMock(), + id="test-agent", + max_iterations=10 + ) + + # Track all responses sent + sent_responses = [] + + # Setup mock agent manager + mock_agent_instance = AsyncMock() + mock_agent_manager_class.return_value = mock_agent_instance + + # Mock react to return Final directly + async def mock_react(question, history, think, observe, answer, context, streaming): + return Final(thought="Final answer", final="4") + + mock_agent_instance.react = mock_react + + # Setup message with non-streaming request + msg = MagicMock() + msg.value.return_value = AgentRequest( + question="What is 2 + 2?", + user="trustgraph", + streaming=False # Non-streaming mode + ) + msg.properties.return_value = {"id": "test-id"} + + # Setup flow mock + consumer = MagicMock() + flow = MagicMock() + + mock_producer = AsyncMock() + + async def capture_response(response, properties): + sent_responses.append(response) + + mock_producer.send = AsyncMock(side_effect=capture_response) + + def flow_router(service_name): + if service_name == "response": + return mock_producer + return AsyncMock() + + flow.side_effect = flow_router + + # Execute + await processor.on_request(msg, consumer, flow) + + # Verify: should have 1 response (final answer) + assert len(sent_responses) == 1, f"Expected 1 response, got {len(sent_responses)}" + + # Check final answer message + answer_response = sent_responses[0] + assert isinstance(answer_response, AgentResponse) + assert answer_response.answer == "4" + assert answer_response.thought is None + assert answer_response.observation is None + assert answer_response.end_of_message is True, "Final answer must have end_of_message=True" + assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True" + + @pytest.mark.asyncio + async def test_error_response_has_correct_flags(self): + """ + Test that error responses have end_of_message=True and end_of_dialog=True. + """ + # Setup processor that will error + processor = Processor( + taskgroup=MagicMock(), + id="test-agent", + max_iterations=10 + ) + + # Track all responses sent + sent_responses = [] + + # Setup message + msg = MagicMock() + msg.value.side_effect = Exception("Test error") + msg.properties.return_value = {"id": "test-id"} + + # Setup flow mock + consumer = MagicMock() + flow = MagicMock() + flow.producer = {"response": AsyncMock()} + + async def capture_response(response, properties): + sent_responses.append(response) + + flow.producer["response"].send = AsyncMock(side_effect=capture_response) + + # Execute + await processor.on_request(msg, consumer, flow) + + # Verify: should have 1 error response + assert len(sent_responses) == 1, f"Expected 1 error response, got {len(sent_responses)}" + + # Check error response + error_response = sent_responses[0] + assert isinstance(error_response, AgentResponse) + assert error_response.error is not None + assert "Test error" in error_response.error.message + assert error_response.end_of_message is True, "Error response must have end_of_message=True" + assert error_response.end_of_dialog is True, "Error response must have end_of_dialog=True" diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index 55b9b97f..041d29df 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -74,4 +74,58 @@ class TestDocumentRagService: sent_response = mock_producer.send.call_args[0][0] assert isinstance(sent_response, DocumentRagResponse) assert sent_response.response == "test response" + assert sent_response.error is None + + @patch('trustgraph.retrieval.document_rag.rag.DocumentRag') + @pytest.mark.asyncio + async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_document_rag_class): + """ + Test that non-streaming mode sets end_of_stream=True in response. + + This is a regression test for the bug where non-streaming responses + didn't set end_of_stream, causing clients to hang waiting for more data. + """ + # Setup processor + processor = Processor( + taskgroup=MagicMock(), + id="test-processor", + doc_limit=10 + ) + + # Setup mock DocumentRag instance + mock_rag_instance = AsyncMock() + mock_document_rag_class.return_value = mock_rag_instance + mock_rag_instance.query.return_value = "A document about cats." + + # Setup message with non-streaming request + msg = MagicMock() + msg.value.return_value = DocumentRagQuery( + query="What is a cat?", + user="trustgraph", + collection="default", + doc_limit=10, + streaming=False # Non-streaming mode + ) + msg.properties.return_value = {"id": "test-id"} + + # Setup flow mock + consumer = MagicMock() + flow = MagicMock() + + mock_producer = AsyncMock() + def flow_router(service_name): + if service_name == "response": + return mock_producer + return AsyncMock() + flow.side_effect = flow_router + + # Execute + await processor.on_request(msg, consumer, flow) + + # Verify: response was sent with end_of_stream=True + mock_producer.send.assert_called_once() + sent_response = mock_producer.send.call_args[0][0] + assert isinstance(sent_response, DocumentRagResponse) + assert sent_response.response == "A document about cats." + assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True" assert sent_response.error is None \ No newline at end of file diff --git a/tests/unit/test_retrieval/test_graph_rag_service.py b/tests/unit/test_retrieval/test_graph_rag_service.py new file mode 100644 index 00000000..ddfdfa75 --- /dev/null +++ b/tests/unit/test_retrieval/test_graph_rag_service.py @@ -0,0 +1,134 @@ +""" +Unit tests for GraphRAG service non-streaming mode. +Tests that end_of_stream flag is correctly set in non-streaming responses. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from trustgraph.retrieval.graph_rag.rag import Processor +from trustgraph.schema import GraphRagQuery, GraphRagResponse + + +class TestGraphRagService: + """Test GraphRAG service non-streaming behavior""" + + @patch('trustgraph.retrieval.graph_rag.rag.GraphRag') + @pytest.mark.asyncio + async def test_non_streaming_mode_sets_end_of_stream_true(self, mock_graph_rag_class): + """ + Test that non-streaming mode sets end_of_stream=True in response. + + This is a regression test for the bug where non-streaming responses + didn't set end_of_stream, causing clients to hang waiting for more data. + """ + # Setup processor + processor = Processor( + taskgroup=MagicMock(), + id="test-processor", + entity_limit=50, + triple_limit=30, + max_subgraph_size=150, + max_path_length=2 + ) + + # Setup mock GraphRag instance + mock_rag_instance = AsyncMock() + mock_graph_rag_class.return_value = mock_rag_instance + mock_rag_instance.query.return_value = "A small domesticated mammal." + + # Setup message with non-streaming request + msg = MagicMock() + msg.value.return_value = GraphRagQuery( + query="What is a cat?", + user="trustgraph", + collection="default", + entity_limit=50, + triple_limit=30, + max_subgraph_size=150, + max_path_length=2, + streaming=False # Non-streaming mode + ) + msg.properties.return_value = {"id": "test-id"} + + # Setup flow mock + consumer = MagicMock() + flow = MagicMock() + + # Mock flow to return AsyncMock for clients and response producer + mock_producer = AsyncMock() + def flow_router(service_name): + if service_name == "response": + return mock_producer + return AsyncMock() # embeddings, graph-embeddings, triples, prompt clients + flow.side_effect = flow_router + + # Execute + await processor.on_request(msg, consumer, flow) + + # Verify: response was sent with end_of_stream=True + mock_producer.send.assert_called_once() + sent_response = mock_producer.send.call_args[0][0] + assert isinstance(sent_response, GraphRagResponse) + assert sent_response.response == "A small domesticated mammal." + assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True" + assert sent_response.error is None + + @patch('trustgraph.retrieval.graph_rag.rag.GraphRag') + @pytest.mark.asyncio + async def test_error_response_in_non_streaming_mode(self, mock_graph_rag_class): + """ + Test that error responses in non-streaming mode set end_of_stream=True. + """ + # Setup processor + processor = Processor( + taskgroup=MagicMock(), + id="test-processor", + entity_limit=50, + triple_limit=30, + max_subgraph_size=150, + max_path_length=2 + ) + + # Setup mock GraphRag instance that raises an exception + mock_rag_instance = AsyncMock() + mock_graph_rag_class.return_value = mock_rag_instance + mock_rag_instance.query.side_effect = Exception("Test error") + + # Setup message with non-streaming request + msg = MagicMock() + msg.value.return_value = GraphRagQuery( + query="What is a cat?", + user="trustgraph", + collection="default", + entity_limit=50, + triple_limit=30, + max_subgraph_size=150, + max_path_length=2, + streaming=False # Non-streaming mode + ) + msg.properties.return_value = {"id": "test-id"} + + # Setup flow mock + consumer = MagicMock() + flow = MagicMock() + + mock_producer = AsyncMock() + def flow_router(service_name): + if service_name == "response": + return mock_producer + return AsyncMock() + flow.side_effect = flow_router + + # Execute + await processor.on_request(msg, consumer, flow) + + # Verify: error response was sent without end_of_stream (not streaming mode) + mock_producer.send.assert_called_once() + sent_response = mock_producer.send.call_args[0][0] + assert isinstance(sent_response, GraphRagResponse) + assert sent_response.response is None + assert sent_response.error is not None + assert sent_response.error.message == "Test error" + # Note: error responses in non-streaming mode don't set end_of_stream + # because streaming was never started diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index b1be0195..e1b8f705 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -275,13 +275,17 @@ class SocketFlowInstance: result = self.client._send_request_sync("text-completion", self.flow_id, request, streaming) if streaming: - # For text completion, yield just the content - for chunk in result: - if hasattr(chunk, 'content'): - yield chunk.content + # For text completion, return generator that yields content + return self._text_completion_generator(result) else: return result.get("response", "") + def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: + """Generator for text completion streaming""" + for chunk in result: + if hasattr(chunk, 'content'): + yield chunk.content + def graph_rag( self, query: str, @@ -308,9 +312,7 @@ class SocketFlowInstance: result = self.client._send_request_sync("graph-rag", self.flow_id, request, streaming) if streaming: - for chunk in result: - if hasattr(chunk, 'content'): - yield chunk.content + return self._rag_generator(result) else: return result.get("response", "") @@ -336,12 +338,16 @@ class SocketFlowInstance: result = self.client._send_request_sync("document-rag", self.flow_id, request, streaming) if streaming: - for chunk in result: - if hasattr(chunk, 'content'): - yield chunk.content + return self._rag_generator(result) else: return result.get("response", "") + def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: + """Generator for RAG streaming (graph-rag and document-rag)""" + for chunk in result: + if hasattr(chunk, 'content'): + yield chunk.content + def prompt( self, id: str, @@ -360,9 +366,7 @@ class SocketFlowInstance: result = self.client._send_request_sync("prompt", self.flow_id, request, streaming) if streaming: - for chunk in result: - if hasattr(chunk, 'content'): - yield chunk.content + return self._rag_generator(result) else: return result.get("response", "") diff --git a/trustgraph-base/trustgraph/base/agent_service.py b/trustgraph-base/trustgraph/base/agent_service.py index 0d38114b..0e5524fe 100644 --- a/trustgraph-base/trustgraph/base/agent_service.py +++ b/trustgraph-base/trustgraph/base/agent_service.py @@ -48,13 +48,13 @@ class AgentService(FlowProcessor): async def on_request(self, msg, consumer, flow): + # Get ID early so error handler can use it + id = msg.properties().get("id", "unknown") + try: request = msg.value() - # Sender-produced ID - id = msg.properties()["id"] - async def respond(resp): await flow("response").send( @@ -93,6 +93,8 @@ class AgentService(FlowProcessor): thought = None, observation = None, answer = None, + end_of_message = True, + end_of_dialog = True, ), properties={"id": id} ) diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index 4319fd16..4289df0a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -44,13 +44,16 @@ class AgentResponseTranslator(MessageTranslator): result["end_of_message"] = getattr(obj, "end_of_message", False) result["end_of_dialog"] = getattr(obj, "end_of_dialog", False) else: - # Legacy format + # Legacy format (non-streaming) if obj.answer: result["answer"] = obj.answer if obj.thought: result["thought"] = obj.thought if obj.observation: result["observation"] = obj.observation + # Include completion flags for legacy format too + result["end_of_message"] = getattr(obj, "end_of_message", False) + result["end_of_dialog"] = getattr(obj, "end_of_dialog", False) # Always include error if present if hasattr(obj, 'error') and obj.error and obj.error.message: diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index d4a4d72f..3af851d2 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -232,12 +232,14 @@ class Processor(AgentService): observation=None, ) else: - # Legacy format + # Non-streaming format r = AgentResponse( answer=None, error=None, thought=x, observation=None, + end_of_message=True, + end_of_dialog=False, ) await respond(r) @@ -260,12 +262,14 @@ class Processor(AgentService): observation=x, ) else: - # Legacy format + # Non-streaming format r = AgentResponse( answer=None, error=None, thought=None, observation=x, + end_of_message=True, + end_of_dialog=False, ) await respond(r) @@ -288,12 +292,14 @@ class Processor(AgentService): observation=None, ) else: - # Legacy format - shouldn't be called in non-streaming mode + # Non-streaming format - shouldn't normally be called r = AgentResponse( answer=x, error=None, thought=None, observation=None, + end_of_message=True, + end_of_dialog=False, ) await respond(r) @@ -364,11 +370,14 @@ class Processor(AgentService): thought=None, ) else: - # Legacy format - send complete answer + # Non-streaming format - send complete answer r = AgentResponse( answer=act.final, error=None, thought=None, + observation=None, + end_of_message=True, + end_of_dialog=True, ) await respond(r) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index 14d71d97..6490562a 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -128,6 +128,7 @@ class Processor(FlowProcessor): await flow("response").send( DocumentRagResponse( response = response, + end_of_stream = True, error = None ), properties = {"id": id} diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index d159dbae..d8bfbddb 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -171,6 +171,7 @@ class Processor(FlowProcessor): await flow("response").send( GraphRagResponse( response = response, + end_of_stream = True, error = None ), properties = {"id": id}