diff --git a/tests/integration/test_agent_streaming_integration.py b/tests/integration/test_agent_streaming_integration.py index 2b619098..0971d30c 100644 --- a/tests/integration/test_agent_streaming_integration.py +++ b/tests/integration/test_agent_streaming_integration.py @@ -111,7 +111,7 @@ Args: { # Arrange thought_chunks = [] - async def think(chunk): + async def think(chunk, is_final=False): thought_chunks.append(chunk) # Act @@ -138,7 +138,7 @@ Args: { # Arrange observation_chunks = [] - async def observe(chunk): + async def observe(chunk, is_final=False): observation_chunks.append(chunk) # Act @@ -178,10 +178,10 @@ Args: { thought_chunks = [] observation_chunks = [] - async def think(chunk): + async def think(chunk, is_final=False): thought_chunks.append(chunk) - async def observe(chunk): + async def observe(chunk, is_final=False): observation_chunks.append(chunk) streaming_result = await agent_manager.react( @@ -358,3 +358,38 @@ Final Answer: AI is the simulation of human intelligence in machines.""" call_args = mock_prompt_client.agent_react.call_args assert call_args.kwargs['streaming'] is True assert call_args.kwargs['chunk_callback'] is not None + + @pytest.mark.asyncio + async def test_agent_streaming_end_of_message_flags(self, agent_manager, mock_flow_context): + """Test that end_of_message flags are correctly set for thought chunks""" + # Arrange + thought_calls = [] + + async def think(chunk, is_final=False): + thought_calls.append({ + 'chunk': chunk, + 'is_final': is_final + }) + + # Act + await agent_manager.react( + question="What is machine learning?", + history=[], + think=think, + observe=AsyncMock(), + context=mock_flow_context, + streaming=True + ) + + # Assert + assert len(thought_calls) > 0, "Expected thought chunks to be sent" + + # All chunks except the last should have is_final=False + for i, call in enumerate(thought_calls[:-1]): + assert call['is_final'] is False, \ + f"Thought chunk {i} should have is_final=False, got {call['is_final']}" + + # Last chunk should have is_final=True + last_call = thought_calls[-1] + assert last_call['is_final'] is True, \ + f"Last thought chunk should have is_final=True, got {last_call['is_final']}" diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py index 9311a57b..bf1edbfa 100644 --- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -256,7 +256,11 @@ class AgentManager: # Send any new thought chunks for i in range(prev_thought_count, len(thought_chunks)): logger.info(f"DEBUG: Sending thought chunk {i}") - await think(thought_chunks[i]) + # Mark last chunk as final if parser has moved out of THOUGHT state + is_last = (i == len(thought_chunks) - 1) + is_thought_complete = parser.state.value != "thought" + is_final = is_last and is_thought_complete + await think(thought_chunks[i], is_final=is_final) # Send any new answer chunks for i in range(prev_answer_count, len(answer_chunks)): @@ -376,7 +380,7 @@ class AgentManager: logger.info(f"resp: {resp}") - await observe(resp) + await observe(resp, is_final=True) act.observation = resp diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 2fb5b9c9..a4238e36 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -214,16 +214,16 @@ class Processor(AgentService): logger.debug(f"History: {history}") - async def think(x): + async def think(x, is_final=False): - logger.debug(f"Think: {x}") + logger.debug(f"Think: {x} (is_final={is_final})") if streaming: # Streaming format r = AgentResponse( chunk_type="thought", content=x, - end_of_message=True, + end_of_message=is_final, end_of_dialog=False, # Legacy fields for backward compatibility answer=None, @@ -242,16 +242,16 @@ class Processor(AgentService): await respond(r) - async def observe(x): + async def observe(x, is_final=False): - logger.debug(f"Observe: {x}") + logger.debug(f"Observe: {x} (is_final={is_final})") if streaming: # Streaming format r = AgentResponse( chunk_type="observation", content=x, - end_of_message=True, + end_of_message=is_final, end_of_dialog=False, # Legacy fields for backward compatibility answer=None, @@ -352,14 +352,14 @@ class Processor(AgentService): if streaming: # Streaming format - send end-of-dialog marker - # Answer chunks were already sent via think() callback during parsing + # Answer chunks were already sent via answer() callback during parsing r = AgentResponse( chunk_type="answer", content="", # Empty content, just marking end of dialog end_of_message=True, end_of_dialog=True, - # Legacy fields for backward compatibility - answer=act.final, + # Legacy fields set to None - answer already sent via streaming chunks + answer=None, error=None, thought=None, )