mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix streaming agent interactions (#570)
* Fix observer, thought streaming * Fix end of message indicators * Remove double-delivery of answer
This commit is contained in:
parent
1948edaa50
commit
e24de6081f
3 changed files with 54 additions and 15 deletions
|
|
@ -111,7 +111,7 @@ Args: {
|
|||
# Arrange
|
||||
thought_chunks = []
|
||||
|
||||
async def think(chunk):
|
||||
async def think(chunk, is_final=False):
|
||||
thought_chunks.append(chunk)
|
||||
|
||||
# Act
|
||||
|
|
@ -138,7 +138,7 @@ Args: {
|
|||
# Arrange
|
||||
observation_chunks = []
|
||||
|
||||
async def observe(chunk):
|
||||
async def observe(chunk, is_final=False):
|
||||
observation_chunks.append(chunk)
|
||||
|
||||
# Act
|
||||
|
|
@ -178,10 +178,10 @@ Args: {
|
|||
thought_chunks = []
|
||||
observation_chunks = []
|
||||
|
||||
async def think(chunk):
|
||||
async def think(chunk, is_final=False):
|
||||
thought_chunks.append(chunk)
|
||||
|
||||
async def observe(chunk):
|
||||
async def observe(chunk, is_final=False):
|
||||
observation_chunks.append(chunk)
|
||||
|
||||
streaming_result = await agent_manager.react(
|
||||
|
|
@ -358,3 +358,38 @@ Final Answer: AI is the simulation of human intelligence in machines."""
|
|||
call_args = mock_prompt_client.agent_react.call_args
|
||||
assert call_args.kwargs['streaming'] is True
|
||||
assert call_args.kwargs['chunk_callback'] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_streaming_end_of_message_flags(self, agent_manager, mock_flow_context):
|
||||
"""Test that end_of_message flags are correctly set for thought chunks"""
|
||||
# Arrange
|
||||
thought_calls = []
|
||||
|
||||
async def think(chunk, is_final=False):
|
||||
thought_calls.append({
|
||||
'chunk': chunk,
|
||||
'is_final': is_final
|
||||
})
|
||||
|
||||
# Act
|
||||
await agent_manager.react(
|
||||
question="What is machine learning?",
|
||||
history=[],
|
||||
think=think,
|
||||
observe=AsyncMock(),
|
||||
context=mock_flow_context,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(thought_calls) > 0, "Expected thought chunks to be sent"
|
||||
|
||||
# All chunks except the last should have is_final=False
|
||||
for i, call in enumerate(thought_calls[:-1]):
|
||||
assert call['is_final'] is False, \
|
||||
f"Thought chunk {i} should have is_final=False, got {call['is_final']}"
|
||||
|
||||
# Last chunk should have is_final=True
|
||||
last_call = thought_calls[-1]
|
||||
assert last_call['is_final'] is True, \
|
||||
f"Last thought chunk should have is_final=True, got {last_call['is_final']}"
|
||||
|
|
|
|||
|
|
@ -256,7 +256,11 @@ class AgentManager:
|
|||
# Send any new thought chunks
|
||||
for i in range(prev_thought_count, len(thought_chunks)):
|
||||
logger.info(f"DEBUG: Sending thought chunk {i}")
|
||||
await think(thought_chunks[i])
|
||||
# Mark last chunk as final if parser has moved out of THOUGHT state
|
||||
is_last = (i == len(thought_chunks) - 1)
|
||||
is_thought_complete = parser.state.value != "thought"
|
||||
is_final = is_last and is_thought_complete
|
||||
await think(thought_chunks[i], is_final=is_final)
|
||||
|
||||
# Send any new answer chunks
|
||||
for i in range(prev_answer_count, len(answer_chunks)):
|
||||
|
|
@ -376,7 +380,7 @@ class AgentManager:
|
|||
|
||||
logger.info(f"resp: {resp}")
|
||||
|
||||
await observe(resp)
|
||||
await observe(resp, is_final=True)
|
||||
|
||||
act.observation = resp
|
||||
|
||||
|
|
|
|||
|
|
@ -214,16 +214,16 @@ class Processor(AgentService):
|
|||
|
||||
logger.debug(f"History: {history}")
|
||||
|
||||
async def think(x):
|
||||
async def think(x, is_final=False):
|
||||
|
||||
logger.debug(f"Think: {x}")
|
||||
logger.debug(f"Think: {x} (is_final={is_final})")
|
||||
|
||||
if streaming:
|
||||
# Streaming format
|
||||
r = AgentResponse(
|
||||
chunk_type="thought",
|
||||
content=x,
|
||||
end_of_message=True,
|
||||
end_of_message=is_final,
|
||||
end_of_dialog=False,
|
||||
# Legacy fields for backward compatibility
|
||||
answer=None,
|
||||
|
|
@ -242,16 +242,16 @@ class Processor(AgentService):
|
|||
|
||||
await respond(r)
|
||||
|
||||
async def observe(x):
|
||||
async def observe(x, is_final=False):
|
||||
|
||||
logger.debug(f"Observe: {x}")
|
||||
logger.debug(f"Observe: {x} (is_final={is_final})")
|
||||
|
||||
if streaming:
|
||||
# Streaming format
|
||||
r = AgentResponse(
|
||||
chunk_type="observation",
|
||||
content=x,
|
||||
end_of_message=True,
|
||||
end_of_message=is_final,
|
||||
end_of_dialog=False,
|
||||
# Legacy fields for backward compatibility
|
||||
answer=None,
|
||||
|
|
@ -352,14 +352,14 @@ class Processor(AgentService):
|
|||
|
||||
if streaming:
|
||||
# Streaming format - send end-of-dialog marker
|
||||
# Answer chunks were already sent via think() callback during parsing
|
||||
# Answer chunks were already sent via answer() callback during parsing
|
||||
r = AgentResponse(
|
||||
chunk_type="answer",
|
||||
content="", # Empty content, just marking end of dialog
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
# Legacy fields for backward compatibility
|
||||
answer=act.final,
|
||||
# Legacy fields set to None - answer already sent via streaming chunks
|
||||
answer=None,
|
||||
error=None,
|
||||
thought=None,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue