mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
* Fix observer, thought streaming * Fix end of message indicators * Remove double-delivery of answer
395 lines
13 KiB
Python
395 lines
13 KiB
Python
"""
|
|
Integration tests for Agent Manager Streaming Functionality
|
|
|
|
These tests verify the streaming behavior of the Agent service, testing
|
|
chunk-by-chunk delivery of thoughts, actions, observations, and final answers.
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from trustgraph.agent.react.agent_manager import AgentManager
|
|
from trustgraph.agent.react.tools import KnowledgeQueryImpl
|
|
from trustgraph.agent.react.types import Tool, Argument
|
|
from tests.utils.streaming_assertions import (
|
|
assert_agent_streaming_chunks,
|
|
assert_streaming_chunks_valid,
|
|
assert_callback_invoked,
|
|
assert_chunk_types_valid,
|
|
)
|
|
|
|
|
|
@pytest.mark.integration
|
|
class TestAgentStreaming:
|
|
"""Integration tests for Agent streaming functionality"""
|
|
|
|
@pytest.fixture
def mock_prompt_client_streaming(self):
    """Build a prompt-client mock whose ``agent_react`` supports streaming.

    The mocked ``agent_react`` returns the same ReAct-formatted reply (a
    thought followed by a ``knowledge_query`` action) in both modes. When it
    is called with ``streaming=True`` and a ``chunk_callback``, it awaits the
    callback once per line-sized chunk before returning the full reply.
    """
    # Complete reply, returned unchanged in streaming and non-streaming mode.
    react_reply = """Thought: I need to search for information about machine learning.
Action: knowledge_query
Args: {
"question": "What is machine learning?"
}"""

    # Realistic line-by-line pieces. "Args:" deliberately opens its own
    # chunk: that layout used to trigger a parser bug in which
    # action_buffer was overwritten.
    streamed_lines = (
        "Thought: I need to search for information about machine learning.\n",
        "Action: knowledge_query\n",
        "Args: {\n",
        ' "question": "What is machine learning?"\n',
        "}",
    )

    async def agent_react_streaming(variables, timeout=600, streaming=False, chunk_callback=None):
        if streaming and chunk_callback:
            for piece in streamed_lines:
                await chunk_callback(piece)
        return react_reply

    prompt_client = AsyncMock()
    prompt_client.agent_react.side_effect = agent_react_streaming
    return prompt_client
|
|
|
|
@pytest.fixture
def mock_flow_context(self, mock_prompt_client_streaming):
    """Build a flow-context mock that hands out service clients by name.

    Calling the returned mock with ``"prompt-request"`` yields the streaming
    prompt client fixture, ``"graph-rag-request"`` yields a graph RAG mock
    whose ``rag`` returns a canned answer, and any other name yields a
    fresh ``AsyncMock``.
    """
    rag_client = AsyncMock()
    rag_client.rag.return_value = "Machine learning is a subset of AI."

    def context_router(service_name):
        if service_name == "prompt-request":
            return mock_prompt_client_streaming
        if service_name == "graph-rag-request":
            return rag_client
        # Unknown services get a throwaway mock so lookups never fail.
        return AsyncMock()

    flow_context = MagicMock()
    flow_context.side_effect = context_router
    return flow_context
|
|
|
|
@pytest.fixture
def sample_tools(self):
    """Provide the tool registry handed to the agent under test.

    Returns a dict with a single ``knowledge_query`` entry: a ``Tool``
    backed by ``KnowledgeQueryImpl`` that takes one string argument named
    ``question``.
    """
    question_argument = Argument(
        name="question",
        type="string",
        description="The question to ask",
    )
    knowledge_query_tool = Tool(
        name="knowledge_query",
        description="Query the knowledge graph",
        arguments=[question_argument],
        implementation=KnowledgeQueryImpl,
        config={},
    )
    return {"knowledge_query": knowledge_query_tool}
|
|
|
|
@pytest.fixture
def agent_manager(self, sample_tools):
    """Construct the ``AgentManager`` exercised by the streaming tests.

    The manager is given the ``sample_tools`` registry and a fixed
    additional-context string.
    """
    manager = AgentManager(
        tools=sample_tools,
        additional_context="You are a helpful AI assistant.",
    )
    return manager
|
|
|
|
@pytest.mark.asyncio
async def test_agent_streaming_thought_chunks(self, agent_manager, mock_flow_context):
    """Thought text must reach the ``think`` callback as valid chunks."""
    # Arrange: record every chunk passed to the think callback.
    received = []

    async def think(chunk, is_final=False):
        received.append(chunk)

    # Act
    await agent_manager.react(
        question="What is machine learning?",
        history=[],
        think=think,
        observe=AsyncMock(),
        context=mock_flow_context,
        streaming=True,
    )

    # Assert: at least one chunk arrived and the chunk list is well formed.
    assert received
    assert_streaming_chunks_valid(received, min_chunks=1)

    # The reassembled thought should mention what the mocked reply said.
    combined = "".join(received).lower()
    assert any(keyword in combined for keyword in ("search", "information"))
|
|
|
|
@pytest.mark.asyncio
async def test_agent_streaming_observation_chunks(self, agent_manager, mock_flow_context):
    """Test that the observe callback only ever receives text chunks.

    Observations come from tool execution, which may or may not be streamed
    depending on the tool implementation, so no minimum number of chunks is
    required here. Whatever is delivered to ``observe`` must be a string.
    """
    # Arrange: record every chunk passed to the observe callback.
    observation_chunks = []

    async def observe(chunk, is_final=False):
        observation_chunks.append(chunk)

    # Act
    await agent_manager.react(
        question="What is machine learning?",
        history=[],
        think=AsyncMock(),
        observe=observe,
        context=mock_flow_context,
        streaming=True,
    )

    # Assert on the collected data rather than on the callback object: a
    # locally defined function is never None, so checking that cannot fail.
    # An empty list is accepted because observations may not be streamed.
    assert all(isinstance(chunk, str) for chunk in observation_chunks), \
        f"Observation chunks must be strings, got {observation_chunks!r}"
|
|
|
|
@pytest.mark.asyncio
async def test_agent_streaming_vs_non_streaming(self, agent_manager, mock_flow_context):
    """Both streaming and non-streaming ``react`` calls must yield a result.

    The same question and history are run once with ``streaming=False`` and
    once with ``streaming=True``; each run has to return something other
    than ``None``.
    """
    # Arrange: inputs shared by both runs.
    shared_inputs = {
        "question": "What is machine learning?",
        "history": [],
        "context": mock_flow_context,
    }

    # Act - non-streaming run with throwaway callbacks.
    non_streaming_result = await agent_manager.react(
        think=AsyncMock(),
        observe=AsyncMock(),
        streaming=False,
        **shared_inputs,
    )

    # Act - streaming run with recording callbacks.
    thought_chunks, observation_chunks = [], []

    async def think(chunk, is_final=False):
        thought_chunks.append(chunk)

    async def observe(chunk, is_final=False):
        observation_chunks.append(chunk)

    streaming_result = await agent_manager.react(
        think=think,
        observe=observe,
        streaming=True,
        **shared_inputs,
    )

    # Assert - both modes produced a result.
    for outcome in (non_streaming_result, streaming_result):
        assert outcome is not None
|
|
|
|
@pytest.mark.asyncio
async def test_agent_streaming_callback_invocation(self, agent_manager, mock_flow_context):
    """The ``think`` callback must be called with a string as first argument."""
    # Arrange: mocks record how the agent invokes each callback.
    think_spy = AsyncMock()
    observe_spy = AsyncMock()

    # Act
    await agent_manager.react(
        question="What is machine learning?",
        history=[],
        think=think_spy,
        observe=observe_spy,
        context=mock_flow_context,
        streaming=True,
    )

    # Assert - the think callback fired at least once.
    assert think_spy.call_count > 0

    # Every recorded invocation carried a positional string chunk.
    for recorded in think_spy.call_args_list:
        positional = recorded.args
        assert len(positional) > 0
        assert isinstance(positional[0], str)
|
|
|
|
@pytest.mark.asyncio
async def test_agent_streaming_without_callbacks(self, agent_manager, mock_flow_context):
    """Streaming with no-op callbacks should complete and return a result.

    The callbacks are plain ``AsyncMock`` objects, so nothing consumes the
    streamed chunks; the call only has to finish without raising.
    """
    # Arrange: streaming enabled, callbacks are inert mocks.
    react_kwargs = dict(
        question="What is machine learning?",
        history=[],
        think=AsyncMock(),
        observe=AsyncMock(),
        context=mock_flow_context,
        streaming=True,
    )

    # Act
    outcome = await agent_manager.react(**react_kwargs)

    # Assert - completed without error and produced a result.
    assert outcome is not None
|
|
|
|
@pytest.mark.asyncio
async def test_agent_streaming_with_conversation_history(self, agent_manager, mock_flow_context):
    """Streaming must still work when prior ``Action`` history is supplied."""
    from trustgraph.agent.react.types import Action

    # Arrange: history is a list of Action objects from an earlier step.
    earlier_step = Action(
        thought="I need to search for information about machine learning",
        name="knowledge_query",
        arguments={"question": "What is machine learning?"},
        observation="Machine learning is a subset of AI that enables computers to learn from data.",
    )
    think_spy = AsyncMock()

    # Act
    outcome = await agent_manager.react(
        question="Tell me more about neural networks",
        history=[earlier_step],
        think=think_spy,
        observe=AsyncMock(),
        context=mock_flow_context,
        streaming=True,
    )

    # Assert - a result came back and thoughts were streamed.
    assert outcome is not None
    assert think_spy.call_count > 0
|
|
|
|
@pytest.mark.asyncio
async def test_agent_streaming_error_propagation(self, agent_manager, mock_flow_context):
    """A prompt-service failure must surface from ``react`` while streaming."""
    # Arrange: make the prompt client raise as soon as it is called.
    failing_client = mock_flow_context("prompt-request")
    failing_client.agent_react.side_effect = Exception("Prompt service error")

    think_spy = AsyncMock()
    observe_spy = AsyncMock()

    # Act & Assert: the exception escapes and still carries its message.
    with pytest.raises(Exception, match="Prompt service error"):
        await agent_manager.react(
            question="test question",
            history=[],
            think=think_spy,
            observe=observe_spy,
            context=mock_flow_context,
            streaming=True,
        )
|
|
|
|
@pytest.mark.asyncio
async def test_agent_streaming_multi_step_reasoning(self, agent_manager, mock_flow_context,
                                                    mock_prompt_client_streaming):
    """Streaming should work when the prompt client scripts several steps.

    The prompt client is re-programmed to return a tool-call reply first
    and a final-answer reply on every later call, streaming each reply to
    the chunk callback one whitespace-separated word at a time.
    """
    # Arrange - scripted replies for successive prompt-client calls.
    tool_call_reply = """Thought: I need to search for basic information.
Action: knowledge_query
Args: {"question": "What is AI?"}"""
    final_answer_reply = """Thought: Now I can answer the question.
Final Answer: AI is the simulation of human intelligence in machines."""
    scripted_replies = iter((tool_call_reply, final_answer_reply))

    async def multi_step_agent_react(variables, timeout=600, streaming=False, chunk_callback=None):
        # Once the script is exhausted, keep repeating the last reply.
        response = next(scripted_replies, final_answer_reply)

        if streaming and chunk_callback:
            for word in response.split():
                await chunk_callback(word + " ")

        return response

    mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react

    think_spy = AsyncMock()
    observe_spy = AsyncMock()

    # Act
    outcome = await agent_manager.react(
        question="What is artificial intelligence?",
        history=[],
        think=think_spy,
        observe=observe_spy,
        context=mock_flow_context,
        streaming=True,
    )

    # Assert - a result came back and thoughts were streamed.
    assert outcome is not None
    assert think_spy.call_count > 0
|
|
|
|
@pytest.mark.asyncio
async def test_agent_streaming_preserves_tool_config(self, agent_manager, mock_flow_context):
    """The prompt client must be invoked in streaming mode with a callback.

    After a streaming ``react`` call, the most recent ``agent_react``
    invocation is inspected: it must have been given ``streaming=True`` and
    a non-``None`` ``chunk_callback`` as keyword arguments.
    """
    # Arrange
    think_spy = AsyncMock()
    observe_spy = AsyncMock()

    # Act
    await agent_manager.react(
        question="What is machine learning?",
        history=[],
        think=think_spy,
        observe=observe_spy,
        context=mock_flow_context,
        streaming=True,
    )

    # Assert - inspect the keyword arguments of the last agent_react call.
    prompt_client = mock_flow_context("prompt-request")
    passed_kwargs = prompt_client.agent_react.call_args.kwargs
    assert passed_kwargs['streaming'] is True
    assert passed_kwargs['chunk_callback'] is not None
|
|
|
|
@pytest.mark.asyncio
async def test_agent_streaming_end_of_message_flags(self, agent_manager, mock_flow_context):
    """Only the last thought chunk may be flagged ``is_final=True``."""
    # Arrange: record each chunk together with the flag it arrived with.
    thought_calls = []

    async def think(chunk, is_final=False):
        thought_calls.append({'chunk': chunk, 'is_final': is_final})

    # Act
    await agent_manager.react(
        question="What is machine learning?",
        history=[],
        think=think,
        observe=AsyncMock(),
        context=mock_flow_context,
        streaming=True,
    )

    # Assert
    assert len(thought_calls) > 0, "Expected thought chunks to be sent"

    # Split the recorded calls into everything-but-last and the last one.
    *leading_calls, last_call = thought_calls

    # Every chunk before the last must be marked as not final.
    for i, call in enumerate(leading_calls):
        assert call['is_final'] is False, \
            f"Thought chunk {i} should have is_final=False, got {call['is_final']}"

    # The closing chunk carries the end-of-message marker.
    assert last_call['is_final'] is True, \
        f"Last thought chunk should have is_final=True, got {last_call['is_final']}"
|