mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Update to add streaming tests (#600)
This commit is contained in:
parent
f0c95a4c5e
commit
f79d0603f7
9 changed files with 1062 additions and 57 deletions
|
|
@ -480,11 +480,15 @@ def streaming_chunk_collector():
|
||||||
class ChunkCollector:
|
class ChunkCollector:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.chunks = []
|
self.chunks = []
|
||||||
|
self.end_of_stream_flags = []
|
||||||
self.complete = False
|
self.complete = False
|
||||||
|
|
||||||
async def collect(self, chunk):
|
async def collect(self, chunk, end_of_stream=False):
|
||||||
"""Async callback to collect chunks"""
|
"""Async callback to collect chunks with end_of_stream flag"""
|
||||||
self.chunks.append(chunk)
|
self.chunks.append(chunk)
|
||||||
|
self.end_of_stream_flags.append(end_of_stream)
|
||||||
|
if end_of_stream:
|
||||||
|
self.complete = True
|
||||||
|
|
||||||
def get_full_text(self):
|
def get_full_text(self):
|
||||||
"""Concatenate all chunk content"""
|
"""Concatenate all chunk content"""
|
||||||
|
|
@ -496,6 +500,14 @@ def streaming_chunk_collector():
|
||||||
return [c.get("chunk_type") for c in self.chunks]
|
return [c.get("chunk_type") for c in self.chunks]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def verify_streaming_protocol(self):
|
||||||
|
"""Verify that streaming protocol is correct"""
|
||||||
|
assert len(self.chunks) > 0, "Should have received at least one chunk"
|
||||||
|
assert len(self.chunks) == len(self.end_of_stream_flags), "Each chunk should have an end_of_stream flag"
|
||||||
|
assert self.end_of_stream_flags.count(True) == 1, "Exactly one chunk should have end_of_stream=True"
|
||||||
|
assert self.end_of_stream_flags[-1] is True, "Last chunk should have end_of_stream=True"
|
||||||
|
assert self.complete is True, "Should be marked complete after final chunk"
|
||||||
|
|
||||||
return ChunkCollector
|
return ChunkCollector
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -46,9 +46,16 @@ class TestDocumentRagStreaming:
|
||||||
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
||||||
|
|
||||||
if streaming and chunk_callback:
|
if streaming and chunk_callback:
|
||||||
# Simulate streaming chunks
|
# Simulate streaming chunks with end_of_stream flags
|
||||||
|
chunks = []
|
||||||
async for chunk in mock_streaming_llm_response():
|
async for chunk in mock_streaming_llm_response():
|
||||||
await chunk_callback(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
# Send all chunks with end_of_stream=False except the last
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
is_final = (i == len(chunks) - 1)
|
||||||
|
await chunk_callback(chunk, is_final)
|
||||||
|
|
||||||
return full_text
|
return full_text
|
||||||
else:
|
else:
|
||||||
# Non-streaming response - same text
|
# Non-streaming response - same text
|
||||||
|
|
@ -89,6 +96,9 @@ class TestDocumentRagStreaming:
|
||||||
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
|
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
|
||||||
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
|
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
|
||||||
|
|
||||||
|
# Verify streaming protocol compliance
|
||||||
|
collector.verify_streaming_protocol()
|
||||||
|
|
||||||
# Verify full response matches concatenated chunks
|
# Verify full response matches concatenated chunks
|
||||||
full_from_chunks = collector.get_full_text()
|
full_from_chunks = collector.get_full_text()
|
||||||
assert result == full_from_chunks
|
assert result == full_from_chunks
|
||||||
|
|
@ -117,7 +127,7 @@ class TestDocumentRagStreaming:
|
||||||
# Act - Streaming
|
# Act - Streaming
|
||||||
streaming_chunks = []
|
streaming_chunks = []
|
||||||
|
|
||||||
async def collect(chunk):
|
async def collect(chunk, end_of_stream):
|
||||||
streaming_chunks.append(chunk)
|
streaming_chunks.append(chunk)
|
||||||
|
|
||||||
streaming_result = await document_rag_streaming.query(
|
streaming_result = await document_rag_streaming.query(
|
||||||
|
|
|
||||||
|
|
@ -59,9 +59,16 @@ class TestGraphRagStreaming:
|
||||||
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
|
||||||
|
|
||||||
if streaming and chunk_callback:
|
if streaming and chunk_callback:
|
||||||
# Simulate streaming chunks
|
# Simulate streaming chunks with end_of_stream flags
|
||||||
|
chunks = []
|
||||||
async for chunk in mock_streaming_llm_response():
|
async for chunk in mock_streaming_llm_response():
|
||||||
await chunk_callback(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
# Send all chunks with end_of_stream=False except the last
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
is_final = (i == len(chunks) - 1)
|
||||||
|
await chunk_callback(chunk, is_final)
|
||||||
|
|
||||||
return full_text
|
return full_text
|
||||||
else:
|
else:
|
||||||
# Non-streaming response - same text
|
# Non-streaming response - same text
|
||||||
|
|
@ -102,6 +109,9 @@ class TestGraphRagStreaming:
|
||||||
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
|
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
|
||||||
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
|
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
|
||||||
|
|
||||||
|
# Verify streaming protocol compliance
|
||||||
|
collector.verify_streaming_protocol()
|
||||||
|
|
||||||
# Verify full response matches concatenated chunks
|
# Verify full response matches concatenated chunks
|
||||||
full_from_chunks = collector.get_full_text()
|
full_from_chunks = collector.get_full_text()
|
||||||
assert result == full_from_chunks
|
assert result == full_from_chunks
|
||||||
|
|
@ -128,7 +138,7 @@ class TestGraphRagStreaming:
|
||||||
# Act - Streaming
|
# Act - Streaming
|
||||||
streaming_chunks = []
|
streaming_chunks = []
|
||||||
|
|
||||||
async def collect(chunk):
|
async def collect(chunk, end_of_stream):
|
||||||
streaming_chunks.append(chunk)
|
streaming_chunks.append(chunk)
|
||||||
|
|
||||||
streaming_result = await graph_rag_streaming.query(
|
streaming_result = await graph_rag_streaming.query(
|
||||||
|
|
|
||||||
351
tests/integration/test_rag_streaming_protocol.py
Normal file
351
tests/integration/test_rag_streaming_protocol.py
Normal file
|
|
@ -0,0 +1,351 @@
|
||||||
|
"""
|
||||||
|
Integration tests for RAG service streaming protocol compliance.
|
||||||
|
|
||||||
|
These tests verify that RAG services correctly forward end_of_stream flags
|
||||||
|
and don't duplicate final chunks, ensuring proper streaming semantics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, call
|
||||||
|
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||||
|
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraphRagStreamingProtocol:
|
||||||
|
"""Integration tests for GraphRAG streaming protocol"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embeddings_client(self):
|
||||||
|
"""Mock embeddings client"""
|
||||||
|
client = AsyncMock()
|
||||||
|
client.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_graph_embeddings_client(self):
|
||||||
|
"""Mock graph embeddings client"""
|
||||||
|
client = AsyncMock()
|
||||||
|
client.query.return_value = ["entity1", "entity2"]
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_triples_client(self):
|
||||||
|
"""Mock triples client"""
|
||||||
|
client = AsyncMock()
|
||||||
|
client.query.return_value = []
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_streaming_prompt_client(self):
|
||||||
|
"""Mock prompt client that simulates realistic streaming with end_of_stream flags"""
|
||||||
|
client = AsyncMock()
|
||||||
|
|
||||||
|
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||||
|
if streaming and chunk_callback:
|
||||||
|
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||||
|
await chunk_callback("The", False)
|
||||||
|
await chunk_callback(" answer", False)
|
||||||
|
await chunk_callback(" is here.", False)
|
||||||
|
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||||
|
return "" # Return value not used since callback handles everything
|
||||||
|
else:
|
||||||
|
return "The answer is here."
|
||||||
|
|
||||||
|
client.kg_prompt.side_effect = kg_prompt_side_effect
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||||
|
mock_triples_client, mock_streaming_prompt_client):
|
||||||
|
"""Create GraphRag instance with mocked dependencies"""
|
||||||
|
return GraphRag(
|
||||||
|
embeddings_client=mock_embeddings_client,
|
||||||
|
graph_embeddings_client=mock_graph_embeddings_client,
|
||||||
|
triples_client=mock_triples_client,
|
||||||
|
prompt_client=mock_streaming_prompt_client,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_receives_end_of_stream_parameter(self, graph_rag):
|
||||||
|
"""Test that callback receives end_of_stream parameter"""
|
||||||
|
# Arrange
|
||||||
|
callback = AsyncMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await graph_rag.query(
|
||||||
|
query="test query",
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - callback should receive (chunk, end_of_stream) signature
|
||||||
|
assert callback.call_count == 4
|
||||||
|
# All calls should have 2 arguments
|
||||||
|
for call_args in callback.call_args_list:
|
||||||
|
assert len(call_args.args) == 2, "Callback should receive (chunk, end_of_stream)"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_of_stream_flag_forwarded_correctly(self, graph_rag):
|
||||||
|
"""Test that end_of_stream flags are forwarded correctly"""
|
||||||
|
# Arrange
|
||||||
|
chunks_with_flags = []
|
||||||
|
|
||||||
|
async def collect(chunk, end_of_stream):
|
||||||
|
chunks_with_flags.append((chunk, end_of_stream))
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await graph_rag.query(
|
||||||
|
query="test query",
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=collect
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(chunks_with_flags) == 4
|
||||||
|
|
||||||
|
# First three chunks should have end_of_stream=False
|
||||||
|
assert chunks_with_flags[0] == ("The", False)
|
||||||
|
assert chunks_with_flags[1] == (" answer", False)
|
||||||
|
assert chunks_with_flags[2] == (" is here.", False)
|
||||||
|
|
||||||
|
# Final chunk should have end_of_stream=True
|
||||||
|
assert chunks_with_flags[3] == ("", True)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_duplicate_final_chunk(self, graph_rag):
|
||||||
|
"""Test that final chunk is not duplicated"""
|
||||||
|
# Arrange
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
async def collect(chunk, end_of_stream):
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await graph_rag.query(
|
||||||
|
query="test query",
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=collect
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - should have exactly 4 chunks, no duplicates
|
||||||
|
assert len(chunks) == 4
|
||||||
|
assert chunks == ["The", " answer", " is here.", ""]
|
||||||
|
|
||||||
|
# The last chunk appears exactly once
|
||||||
|
assert chunks.count("") == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exactly_one_end_of_stream_true(self, graph_rag):
|
||||||
|
"""Test that exactly one message has end_of_stream=True"""
|
||||||
|
# Arrange
|
||||||
|
end_of_stream_flags = []
|
||||||
|
|
||||||
|
async def collect(chunk, end_of_stream):
|
||||||
|
end_of_stream_flags.append(end_of_stream)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await graph_rag.query(
|
||||||
|
query="test query",
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=collect
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - exactly one True
|
||||||
|
assert end_of_stream_flags.count(True) == 1
|
||||||
|
assert end_of_stream_flags.count(False) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_final_chunk_preserved(self, graph_rag):
|
||||||
|
"""Test that empty final chunks are preserved and forwarded"""
|
||||||
|
# Arrange
|
||||||
|
final_chunk = None
|
||||||
|
final_flag = None
|
||||||
|
|
||||||
|
async def collect(chunk, end_of_stream):
|
||||||
|
nonlocal final_chunk, final_flag
|
||||||
|
if end_of_stream:
|
||||||
|
final_chunk = chunk
|
||||||
|
final_flag = end_of_stream
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await graph_rag.query(
|
||||||
|
query="test query",
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=collect
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert final_flag is True
|
||||||
|
assert final_chunk == "", "Empty final chunk should be preserved"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentRagStreamingProtocol:
|
||||||
|
"""Integration tests for DocumentRAG streaming protocol"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embeddings_client(self):
|
||||||
|
"""Mock embeddings client"""
|
||||||
|
client = AsyncMock()
|
||||||
|
client.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_doc_embeddings_client(self):
|
||||||
|
"""Mock document embeddings client"""
|
||||||
|
client = AsyncMock()
|
||||||
|
client.query.return_value = ["doc1", "doc2"]
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_streaming_prompt_client(self):
|
||||||
|
"""Mock prompt client with streaming support"""
|
||||||
|
client = AsyncMock()
|
||||||
|
|
||||||
|
async def document_prompt_side_effect(query, documents, timeout=600, streaming=False, chunk_callback=None):
|
||||||
|
if streaming and chunk_callback:
|
||||||
|
# Simulate streaming with non-empty final chunk (some LLMs do this)
|
||||||
|
await chunk_callback("Document", False)
|
||||||
|
await chunk_callback(" summary", False)
|
||||||
|
await chunk_callback(".", True) # Non-empty final chunk
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
return "Document summary."
|
||||||
|
|
||||||
|
client.document_prompt.side_effect = document_prompt_side_effect
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
|
||||||
|
mock_streaming_prompt_client):
|
||||||
|
"""Create DocumentRag instance with mocked dependencies"""
|
||||||
|
return DocumentRag(
|
||||||
|
embeddings_client=mock_embeddings_client,
|
||||||
|
doc_embeddings_client=mock_doc_embeddings_client,
|
||||||
|
prompt_client=mock_streaming_prompt_client,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_receives_end_of_stream_parameter(self, document_rag):
|
||||||
|
"""Test that callback receives end_of_stream parameter"""
|
||||||
|
# Arrange
|
||||||
|
callback = AsyncMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await document_rag.query(
|
||||||
|
query="test query",
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert callback.call_count == 3
|
||||||
|
for call_args in callback.call_args_list:
|
||||||
|
assert len(call_args.args) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_empty_final_chunk_preserved(self, document_rag):
|
||||||
|
"""Test that non-empty final chunks are preserved with correct flag"""
|
||||||
|
# Arrange
|
||||||
|
chunks_with_flags = []
|
||||||
|
|
||||||
|
async def collect(chunk, end_of_stream):
|
||||||
|
chunks_with_flags.append((chunk, end_of_stream))
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await document_rag.query(
|
||||||
|
query="test query",
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=collect
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(chunks_with_flags) == 3
|
||||||
|
assert chunks_with_flags[0] == ("Document", False)
|
||||||
|
assert chunks_with_flags[1] == (" summary", False)
|
||||||
|
assert chunks_with_flags[2] == (".", True) # Non-empty final chunk
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_duplicate_final_chunk(self, document_rag):
|
||||||
|
"""Test that final chunk is not duplicated"""
|
||||||
|
# Arrange
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
async def collect(chunk, end_of_stream):
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await document_rag.query(
|
||||||
|
query="test query",
|
||||||
|
user="test_user",
|
||||||
|
collection="test_collection",
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=collect
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - final "." appears exactly once
|
||||||
|
assert chunks.count(".") == 1
|
||||||
|
assert chunks == ["Document", " summary", "."]
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamingProtocolEdgeCases:
|
||||||
|
"""Test edge cases in streaming protocol"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_empty_chunks_before_final(self):
|
||||||
|
"""Test handling of multiple empty chunks (edge case)"""
|
||||||
|
# Arrange
|
||||||
|
client = AsyncMock()
|
||||||
|
|
||||||
|
async def kg_prompt_with_empties(query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||||
|
if streaming and chunk_callback:
|
||||||
|
await chunk_callback("text", False)
|
||||||
|
await chunk_callback("", False) # Empty but not final
|
||||||
|
await chunk_callback("more", False)
|
||||||
|
await chunk_callback("", True) # Empty and final
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
return "textmore"
|
||||||
|
|
||||||
|
client.kg_prompt.side_effect = kg_prompt_with_empties
|
||||||
|
|
||||||
|
rag = GraphRag(
|
||||||
|
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[0.1]])),
|
||||||
|
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||||
|
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||||
|
prompt_client=client,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks_with_flags = []
|
||||||
|
|
||||||
|
async def collect(chunk, end_of_stream):
|
||||||
|
chunks_with_flags.append((chunk, end_of_stream))
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await rag.query(
|
||||||
|
query="test",
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=collect
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(chunks_with_flags) == 4
|
||||||
|
assert chunks_with_flags[-1] == ("", True) # Final empty chunk
|
||||||
|
end_of_stream_flags = [f for c, f in chunks_with_flags]
|
||||||
|
assert end_of_stream_flags.count(True) == 1
|
||||||
260
tests/unit/test_base/test_prompt_client_streaming.py
Normal file
260
tests/unit/test_base/test_prompt_client_streaming.py
Normal file
|
|
@ -0,0 +1,260 @@
|
||||||
|
"""
|
||||||
|
Unit tests for PromptClient streaming callback behavior.
|
||||||
|
|
||||||
|
These tests verify that the prompt client correctly passes the end_of_stream
|
||||||
|
flag to chunk callbacks, ensuring proper streaming protocol compliance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||||
|
from trustgraph.base.prompt_client import PromptClient
|
||||||
|
from trustgraph.schema import PromptResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptClientStreamingCallback:
|
||||||
|
"""Test PromptClient streaming callback behavior"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prompt_client(self):
|
||||||
|
"""Create a PromptClient with mocked dependencies"""
|
||||||
|
# Mock all the required initialization parameters
|
||||||
|
with patch.object(PromptClient, '__init__', lambda self: None):
|
||||||
|
client = PromptClient()
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_request_response(self):
|
||||||
|
"""Create a mock request/response handler"""
|
||||||
|
async def mock_request(request, recipient=None, timeout=600):
|
||||||
|
if recipient:
|
||||||
|
# Simulate streaming responses
|
||||||
|
responses = [
|
||||||
|
PromptResponse(text="Hello", object=None, error=None, end_of_stream=False),
|
||||||
|
PromptResponse(text=" world", object=None, error=None, end_of_stream=False),
|
||||||
|
PromptResponse(text="!", object=None, error=None, end_of_stream=False),
|
||||||
|
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||||
|
]
|
||||||
|
for resp in responses:
|
||||||
|
should_stop = await recipient(resp)
|
||||||
|
if should_stop:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Non-streaming response
|
||||||
|
return PromptResponse(text="Hello world!", object=None, error=None)
|
||||||
|
|
||||||
|
return mock_request
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_receives_chunk_and_end_of_stream(self, prompt_client, mock_request_response):
|
||||||
|
"""Test that callback receives both chunk text and end_of_stream flag"""
|
||||||
|
# Arrange
|
||||||
|
prompt_client.request = mock_request_response
|
||||||
|
|
||||||
|
callback = AsyncMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await prompt_client.prompt(
|
||||||
|
id="test-prompt",
|
||||||
|
variables={"query": "test"},
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - callback should be called with (chunk, end_of_stream) signature
|
||||||
|
assert callback.call_count == 4
|
||||||
|
|
||||||
|
# Verify first chunk: text + end_of_stream=False
|
||||||
|
assert callback.call_args_list[0] == call("Hello", False)
|
||||||
|
|
||||||
|
# Verify second chunk
|
||||||
|
assert callback.call_args_list[1] == call(" world", False)
|
||||||
|
|
||||||
|
# Verify third chunk
|
||||||
|
assert callback.call_args_list[2] == call("!", False)
|
||||||
|
|
||||||
|
# Verify final chunk: empty text + end_of_stream=True
|
||||||
|
assert callback.call_args_list[3] == call("", True)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_receives_empty_final_chunk(self, prompt_client, mock_request_response):
|
||||||
|
"""Test that empty final chunks are passed to callback"""
|
||||||
|
# Arrange
|
||||||
|
prompt_client.request = mock_request_response
|
||||||
|
|
||||||
|
chunks_received = []
|
||||||
|
|
||||||
|
async def collect_chunks(chunk, end_of_stream):
|
||||||
|
chunks_received.append((chunk, end_of_stream))
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await prompt_client.prompt(
|
||||||
|
id="test-prompt",
|
||||||
|
variables={"query": "test"},
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=collect_chunks
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - should receive the empty final chunk
|
||||||
|
final_chunk = chunks_received[-1]
|
||||||
|
assert final_chunk == ("", True), "Final chunk should be empty string with end_of_stream=True"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_signature_with_non_empty_final_chunk(self, prompt_client):
|
||||||
|
"""Test callback signature when LLM sends non-empty final chunk"""
|
||||||
|
# Arrange
|
||||||
|
async def mock_request_non_empty_final(request, recipient=None, timeout=600):
|
||||||
|
if recipient:
|
||||||
|
# Some LLMs send content in the final chunk
|
||||||
|
responses = [
|
||||||
|
PromptResponse(text="Hello", object=None, error=None, end_of_stream=False),
|
||||||
|
PromptResponse(text=" world!", object=None, error=None, end_of_stream=True),
|
||||||
|
]
|
||||||
|
for resp in responses:
|
||||||
|
should_stop = await recipient(resp)
|
||||||
|
if should_stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
prompt_client.request = mock_request_non_empty_final
|
||||||
|
|
||||||
|
callback = AsyncMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await prompt_client.prompt(
|
||||||
|
id="test-prompt",
|
||||||
|
variables={"query": "test"},
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert callback.call_count == 2
|
||||||
|
assert callback.call_args_list[0] == call("Hello", False)
|
||||||
|
assert callback.call_args_list[1] == call(" world!", True)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_not_called_without_text(self, prompt_client):
|
||||||
|
"""Test that callback is not called for responses without text"""
|
||||||
|
# Arrange
|
||||||
|
async def mock_request_no_text(request, recipient=None, timeout=600):
|
||||||
|
if recipient:
|
||||||
|
# Response with only end_of_stream, no text
|
||||||
|
responses = [
|
||||||
|
PromptResponse(text="Content", object=None, error=None, end_of_stream=False),
|
||||||
|
PromptResponse(text=None, object=None, error=None, end_of_stream=True),
|
||||||
|
]
|
||||||
|
for resp in responses:
|
||||||
|
should_stop = await recipient(resp)
|
||||||
|
if should_stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
prompt_client.request = mock_request_no_text
|
||||||
|
|
||||||
|
callback = AsyncMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await prompt_client.prompt(
|
||||||
|
id="test-prompt",
|
||||||
|
variables={"query": "test"},
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - callback should only be called once (for "Content")
|
||||||
|
assert callback.call_count == 1
|
||||||
|
assert callback.call_args_list[0] == call("Content", False)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_synchronous_callback_also_receives_end_of_stream(self, prompt_client):
|
||||||
|
"""Test that synchronous callbacks also receive end_of_stream parameter"""
|
||||||
|
# Arrange
|
||||||
|
async def mock_request(request, recipient=None, timeout=600):
|
||||||
|
if recipient:
|
||||||
|
responses = [
|
||||||
|
PromptResponse(text="test", object=None, error=None, end_of_stream=False),
|
||||||
|
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||||
|
]
|
||||||
|
for resp in responses:
|
||||||
|
should_stop = await recipient(resp)
|
||||||
|
if should_stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
prompt_client.request = mock_request
|
||||||
|
|
||||||
|
callback = MagicMock() # Synchronous mock
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await prompt_client.prompt(
|
||||||
|
id="test-prompt",
|
||||||
|
variables={"query": "test"},
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - synchronous callback should also get both parameters
|
||||||
|
assert callback.call_count == 2
|
||||||
|
assert callback.call_args_list[0] == call("test", False)
|
||||||
|
assert callback.call_args_list[1] == call("", True)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
|
||||||
|
"""Test that kg_prompt correctly passes streaming parameters"""
|
||||||
|
# Arrange
|
||||||
|
async def mock_request(request, recipient=None, timeout=600):
|
||||||
|
if recipient:
|
||||||
|
responses = [
|
||||||
|
PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
|
||||||
|
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||||
|
]
|
||||||
|
for resp in responses:
|
||||||
|
should_stop = await recipient(resp)
|
||||||
|
if should_stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
prompt_client.request = mock_request
|
||||||
|
|
||||||
|
callback = AsyncMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await prompt_client.kg_prompt(
|
||||||
|
query="What is machine learning?",
|
||||||
|
kg=[("subject", "predicate", "object")],
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert callback.call_count == 2
|
||||||
|
assert callback.call_args_list[0] == call("Answer", False)
|
||||||
|
assert callback.call_args_list[1] == call("", True)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
|
||||||
|
"""Test that document_prompt correctly passes streaming parameters"""
|
||||||
|
# Arrange
|
||||||
|
async def mock_request(request, recipient=None, timeout=600):
|
||||||
|
if recipient:
|
||||||
|
responses = [
|
||||||
|
PromptResponse(text="Summary", object=None, error=None, end_of_stream=False),
|
||||||
|
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||||
|
]
|
||||||
|
for resp in responses:
|
||||||
|
should_stop = await recipient(resp)
|
||||||
|
if should_stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
prompt_client.request = mock_request
|
||||||
|
|
||||||
|
callback = AsyncMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await prompt_client.document_prompt(
|
||||||
|
query="Summarize this",
|
||||||
|
documents=["doc1", "doc2"],
|
||||||
|
streaming=True,
|
||||||
|
chunk_callback=callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert callback.call_count == 2
|
||||||
|
assert callback.call_args_list[0] == call("Summary", False)
|
||||||
|
assert callback.call_args_list[1] == call("", True)
|
||||||
|
|
@ -102,7 +102,7 @@ async def test_handle_normal_flow():
|
||||||
"""Test normal websocket handling flow."""
|
"""Test normal websocket handling flow."""
|
||||||
mock_auth = MagicMock()
|
mock_auth = MagicMock()
|
||||||
mock_auth.permitted.return_value = True
|
mock_auth.permitted.return_value = True
|
||||||
|
|
||||||
dispatcher_created = False
|
dispatcher_created = False
|
||||||
async def mock_dispatcher_factory(ws, running, match_info):
|
async def mock_dispatcher_factory(ws, running, match_info):
|
||||||
nonlocal dispatcher_created
|
nonlocal dispatcher_created
|
||||||
|
|
@ -110,33 +110,41 @@ async def test_handle_normal_flow():
|
||||||
dispatcher = AsyncMock()
|
dispatcher = AsyncMock()
|
||||||
dispatcher.destroy = AsyncMock()
|
dispatcher.destroy = AsyncMock()
|
||||||
return dispatcher
|
return dispatcher
|
||||||
|
|
||||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||||
|
|
||||||
request = MagicMock()
|
request = MagicMock()
|
||||||
request.query = {"token": "valid-token"}
|
request.query = {"token": "valid-token"}
|
||||||
request.match_info = {}
|
request.match_info = {}
|
||||||
|
|
||||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_ws.prepare = AsyncMock()
|
mock_ws.prepare = AsyncMock()
|
||||||
mock_ws.close = AsyncMock()
|
mock_ws.close = AsyncMock()
|
||||||
mock_ws.closed = False
|
mock_ws.closed = False
|
||||||
mock_ws_class.return_value = mock_ws
|
mock_ws_class.return_value = mock_ws
|
||||||
|
|
||||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||||
# Mock task group context manager
|
# Mock task group context manager
|
||||||
mock_tg = AsyncMock()
|
mock_tg = AsyncMock()
|
||||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||||
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
||||||
mock_tg.create_task = MagicMock(return_value=AsyncMock())
|
|
||||||
|
# Create proper mock tasks that look like asyncio.Task objects
|
||||||
|
def create_task_mock(coro):
|
||||||
|
task = AsyncMock()
|
||||||
|
task.done = MagicMock(return_value=True)
|
||||||
|
task.cancelled = MagicMock(return_value=False)
|
||||||
|
return task
|
||||||
|
|
||||||
|
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
|
||||||
mock_task_group.return_value = mock_tg
|
mock_task_group.return_value = mock_tg
|
||||||
|
|
||||||
result = await socket_endpoint.handle(request)
|
result = await socket_endpoint.handle(request)
|
||||||
|
|
||||||
# Should have created dispatcher
|
# Should have created dispatcher
|
||||||
assert dispatcher_created is True
|
assert dispatcher_created is True
|
||||||
|
|
||||||
# Should return websocket
|
# Should return websocket
|
||||||
assert result == mock_ws
|
assert result == mock_ws
|
||||||
|
|
||||||
|
|
@ -146,50 +154,58 @@ async def test_handle_exception_group_cleanup():
|
||||||
"""Test exception group triggers dispatcher cleanup."""
|
"""Test exception group triggers dispatcher cleanup."""
|
||||||
mock_auth = MagicMock()
|
mock_auth = MagicMock()
|
||||||
mock_auth.permitted.return_value = True
|
mock_auth.permitted.return_value = True
|
||||||
|
|
||||||
mock_dispatcher = AsyncMock()
|
mock_dispatcher = AsyncMock()
|
||||||
mock_dispatcher.destroy = AsyncMock()
|
mock_dispatcher.destroy = AsyncMock()
|
||||||
|
|
||||||
async def mock_dispatcher_factory(ws, running, match_info):
|
async def mock_dispatcher_factory(ws, running, match_info):
|
||||||
return mock_dispatcher
|
return mock_dispatcher
|
||||||
|
|
||||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||||
|
|
||||||
request = MagicMock()
|
request = MagicMock()
|
||||||
request.query = {"token": "valid-token"}
|
request.query = {"token": "valid-token"}
|
||||||
request.match_info = {}
|
request.match_info = {}
|
||||||
|
|
||||||
# Mock TaskGroup to raise ExceptionGroup
|
# Mock TaskGroup to raise ExceptionGroup
|
||||||
class TestException(Exception):
|
class TestException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
exception_group = ExceptionGroup("Test exceptions", [TestException("test")])
|
exception_group = ExceptionGroup("Test exceptions", [TestException("test")])
|
||||||
|
|
||||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_ws.prepare = AsyncMock()
|
mock_ws.prepare = AsyncMock()
|
||||||
mock_ws.close = AsyncMock()
|
mock_ws.close = AsyncMock()
|
||||||
mock_ws.closed = False
|
mock_ws.closed = False
|
||||||
mock_ws_class.return_value = mock_ws
|
mock_ws_class.return_value = mock_ws
|
||||||
|
|
||||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||||
mock_tg = AsyncMock()
|
mock_tg = AsyncMock()
|
||||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||||
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
|
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
|
||||||
mock_tg.create_task = MagicMock(side_effect=TestException("test"))
|
|
||||||
|
# Create proper mock tasks that look like asyncio.Task objects
|
||||||
|
def create_task_mock(coro):
|
||||||
|
task = AsyncMock()
|
||||||
|
task.done = MagicMock(return_value=True)
|
||||||
|
task.cancelled = MagicMock(return_value=False)
|
||||||
|
return task
|
||||||
|
|
||||||
|
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
|
||||||
mock_task_group.return_value = mock_tg
|
mock_task_group.return_value = mock_tg
|
||||||
|
|
||||||
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
|
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
|
||||||
mock_wait_for.return_value = None
|
mock_wait_for.return_value = None
|
||||||
|
|
||||||
result = await socket_endpoint.handle(request)
|
result = await socket_endpoint.handle(request)
|
||||||
|
|
||||||
# Should have attempted graceful cleanup
|
# Should have attempted graceful cleanup
|
||||||
mock_wait_for.assert_called_once()
|
mock_wait_for.assert_called_once()
|
||||||
|
|
||||||
# Should have called destroy in finally block
|
# Should have called destroy in finally block
|
||||||
assert mock_dispatcher.destroy.call_count >= 1
|
assert mock_dispatcher.destroy.call_count >= 1
|
||||||
|
|
||||||
# Should have closed websocket
|
# Should have closed websocket
|
||||||
mock_ws.close.assert_called()
|
mock_ws.close.assert_called()
|
||||||
|
|
||||||
|
|
@ -199,48 +215,56 @@ async def test_handle_dispatcher_cleanup_timeout():
|
||||||
"""Test dispatcher cleanup with timeout."""
|
"""Test dispatcher cleanup with timeout."""
|
||||||
mock_auth = MagicMock()
|
mock_auth = MagicMock()
|
||||||
mock_auth.permitted.return_value = True
|
mock_auth.permitted.return_value = True
|
||||||
|
|
||||||
# Mock dispatcher that takes long to destroy
|
# Mock dispatcher that takes long to destroy
|
||||||
mock_dispatcher = AsyncMock()
|
mock_dispatcher = AsyncMock()
|
||||||
mock_dispatcher.destroy = AsyncMock()
|
mock_dispatcher.destroy = AsyncMock()
|
||||||
|
|
||||||
async def mock_dispatcher_factory(ws, running, match_info):
|
async def mock_dispatcher_factory(ws, running, match_info):
|
||||||
return mock_dispatcher
|
return mock_dispatcher
|
||||||
|
|
||||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||||
|
|
||||||
request = MagicMock()
|
request = MagicMock()
|
||||||
request.query = {"token": "valid-token"}
|
request.query = {"token": "valid-token"}
|
||||||
request.match_info = {}
|
request.match_info = {}
|
||||||
|
|
||||||
# Mock TaskGroup to raise exception
|
# Mock TaskGroup to raise exception
|
||||||
exception_group = ExceptionGroup("Test", [Exception("test")])
|
exception_group = ExceptionGroup("Test", [Exception("test")])
|
||||||
|
|
||||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_ws.prepare = AsyncMock()
|
mock_ws.prepare = AsyncMock()
|
||||||
mock_ws.close = AsyncMock()
|
mock_ws.close = AsyncMock()
|
||||||
mock_ws.closed = False
|
mock_ws.closed = False
|
||||||
mock_ws_class.return_value = mock_ws
|
mock_ws_class.return_value = mock_ws
|
||||||
|
|
||||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||||
mock_tg = AsyncMock()
|
mock_tg = AsyncMock()
|
||||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||||
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
|
mock_tg.__aexit__ = AsyncMock(side_effect=exception_group)
|
||||||
mock_tg.create_task = MagicMock(side_effect=Exception("test"))
|
|
||||||
|
# Create proper mock tasks that look like asyncio.Task objects
|
||||||
|
def create_task_mock(coro):
|
||||||
|
task = AsyncMock()
|
||||||
|
task.done = MagicMock(return_value=True)
|
||||||
|
task.cancelled = MagicMock(return_value=False)
|
||||||
|
return task
|
||||||
|
|
||||||
|
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
|
||||||
mock_task_group.return_value = mock_tg
|
mock_task_group.return_value = mock_tg
|
||||||
|
|
||||||
# Mock asyncio.wait_for to raise TimeoutError
|
# Mock asyncio.wait_for to raise TimeoutError
|
||||||
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
|
with patch('trustgraph.gateway.endpoint.socket.asyncio.wait_for') as mock_wait_for:
|
||||||
mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")
|
mock_wait_for.side_effect = asyncio.TimeoutError("Cleanup timeout")
|
||||||
|
|
||||||
result = await socket_endpoint.handle(request)
|
result = await socket_endpoint.handle(request)
|
||||||
|
|
||||||
# Should have attempted cleanup with timeout
|
# Should have attempted cleanup with timeout
|
||||||
mock_wait_for.assert_called_once()
|
mock_wait_for.assert_called_once()
|
||||||
# Check that timeout was passed correctly
|
# Check that timeout was passed correctly
|
||||||
assert mock_wait_for.call_args[1]['timeout'] == 5.0
|
assert mock_wait_for.call_args[1]['timeout'] == 5.0
|
||||||
|
|
||||||
# Should still call destroy in finally block
|
# Should still call destroy in finally block
|
||||||
assert mock_dispatcher.destroy.call_count >= 1
|
assert mock_dispatcher.destroy.call_count >= 1
|
||||||
|
|
||||||
|
|
@ -290,37 +314,45 @@ async def test_handle_websocket_already_closed():
|
||||||
"""Test handling when websocket is already closed."""
|
"""Test handling when websocket is already closed."""
|
||||||
mock_auth = MagicMock()
|
mock_auth = MagicMock()
|
||||||
mock_auth.permitted.return_value = True
|
mock_auth.permitted.return_value = True
|
||||||
|
|
||||||
mock_dispatcher = AsyncMock()
|
mock_dispatcher = AsyncMock()
|
||||||
mock_dispatcher.destroy = AsyncMock()
|
mock_dispatcher.destroy = AsyncMock()
|
||||||
|
|
||||||
async def mock_dispatcher_factory(ws, running, match_info):
|
async def mock_dispatcher_factory(ws, running, match_info):
|
||||||
return mock_dispatcher
|
return mock_dispatcher
|
||||||
|
|
||||||
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
|
||||||
|
|
||||||
request = MagicMock()
|
request = MagicMock()
|
||||||
request.query = {"token": "valid-token"}
|
request.query = {"token": "valid-token"}
|
||||||
request.match_info = {}
|
request.match_info = {}
|
||||||
|
|
||||||
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
with patch('aiohttp.web.WebSocketResponse') as mock_ws_class:
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
mock_ws.prepare = AsyncMock()
|
mock_ws.prepare = AsyncMock()
|
||||||
mock_ws.close = AsyncMock()
|
mock_ws.close = AsyncMock()
|
||||||
mock_ws.closed = True # Already closed
|
mock_ws.closed = True # Already closed
|
||||||
mock_ws_class.return_value = mock_ws
|
mock_ws_class.return_value = mock_ws
|
||||||
|
|
||||||
with patch('asyncio.TaskGroup') as mock_task_group:
|
with patch('asyncio.TaskGroup') as mock_task_group:
|
||||||
mock_tg = AsyncMock()
|
mock_tg = AsyncMock()
|
||||||
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
mock_tg.__aenter__ = AsyncMock(return_value=mock_tg)
|
||||||
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
mock_tg.__aexit__ = AsyncMock(return_value=None)
|
||||||
mock_tg.create_task = MagicMock(return_value=AsyncMock())
|
|
||||||
|
# Create proper mock tasks that look like asyncio.Task objects
|
||||||
|
def create_task_mock(coro):
|
||||||
|
task = AsyncMock()
|
||||||
|
task.done = MagicMock(return_value=True)
|
||||||
|
task.cancelled = MagicMock(return_value=False)
|
||||||
|
return task
|
||||||
|
|
||||||
|
mock_tg.create_task = MagicMock(side_effect=create_task_mock)
|
||||||
mock_task_group.return_value = mock_tg
|
mock_task_group.return_value = mock_tg
|
||||||
|
|
||||||
result = await socket_endpoint.handle(request)
|
result = await socket_endpoint.handle(request)
|
||||||
|
|
||||||
# Should still have called destroy
|
# Should still have called destroy
|
||||||
mock_dispatcher.destroy.assert_called()
|
mock_dispatcher.destroy.assert_called()
|
||||||
|
|
||||||
# Should not attempt to close already closed websocket
|
# Should not attempt to close already closed websocket
|
||||||
mock_ws.close.assert_not_called() # Not called in finally since ws.closed = True
|
mock_ws.close.assert_not_called() # Not called in finally since ws.closed = True
|
||||||
326
tests/unit/test_gateway/test_streaming_translators.py
Normal file
326
tests/unit/test_gateway/test_streaming_translators.py
Normal file
|
|
@ -0,0 +1,326 @@
|
||||||
|
"""
|
||||||
|
Unit tests for streaming behavior in message translators.
|
||||||
|
|
||||||
|
These tests verify that translators correctly handle empty strings and
|
||||||
|
end_of_stream flags in streaming responses, preventing bugs where empty
|
||||||
|
final chunks could be dropped due to falsy value checks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from trustgraph.messaging.translators.retrieval import (
|
||||||
|
GraphRagResponseTranslator,
|
||||||
|
DocumentRagResponseTranslator,
|
||||||
|
)
|
||||||
|
from trustgraph.messaging.translators.prompt import PromptResponseTranslator
|
||||||
|
from trustgraph.messaging.translators.text_completion import TextCompletionResponseTranslator
|
||||||
|
from trustgraph.schema import (
|
||||||
|
GraphRagResponse,
|
||||||
|
DocumentRagResponse,
|
||||||
|
PromptResponse,
|
||||||
|
TextCompletionResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraphRagResponseTranslator:
|
||||||
|
"""Test GraphRagResponseTranslator streaming behavior"""
|
||||||
|
|
||||||
|
def test_from_pulsar_with_empty_response(self):
|
||||||
|
"""Test that empty response strings are preserved"""
|
||||||
|
# Arrange
|
||||||
|
translator = GraphRagResponseTranslator()
|
||||||
|
response = GraphRagResponse(
|
||||||
|
response="",
|
||||||
|
end_of_stream=True,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert - Empty string should be included in result
|
||||||
|
assert "response" in result
|
||||||
|
assert result["response"] == ""
|
||||||
|
assert result["end_of_stream"] is True
|
||||||
|
|
||||||
|
def test_from_pulsar_with_non_empty_response(self):
|
||||||
|
"""Test that non-empty responses work correctly"""
|
||||||
|
# Arrange
|
||||||
|
translator = GraphRagResponseTranslator()
|
||||||
|
response = GraphRagResponse(
|
||||||
|
response="Some text",
|
||||||
|
end_of_stream=False,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result["response"] == "Some text"
|
||||||
|
assert result["end_of_stream"] is False
|
||||||
|
|
||||||
|
def test_from_pulsar_with_none_response(self):
|
||||||
|
"""Test that None response is handled correctly"""
|
||||||
|
# Arrange
|
||||||
|
translator = GraphRagResponseTranslator()
|
||||||
|
response = GraphRagResponse(
|
||||||
|
response=None,
|
||||||
|
end_of_stream=True,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert - None should not be included
|
||||||
|
assert "response" not in result
|
||||||
|
assert result["end_of_stream"] is True
|
||||||
|
|
||||||
|
def test_from_response_with_completion_returns_correct_flag(self):
|
||||||
|
"""Test that from_response_with_completion returns correct is_final flag"""
|
||||||
|
# Arrange
|
||||||
|
translator = GraphRagResponseTranslator()
|
||||||
|
|
||||||
|
# Test non-final chunk
|
||||||
|
response_chunk = GraphRagResponse(
|
||||||
|
response="chunk",
|
||||||
|
end_of_stream=False,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result, is_final = translator.from_response_with_completion(response_chunk)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert is_final is False
|
||||||
|
assert result["end_of_stream"] is False
|
||||||
|
|
||||||
|
# Test final chunk with empty content
|
||||||
|
final_response = GraphRagResponse(
|
||||||
|
response="",
|
||||||
|
end_of_stream=True,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result, is_final = translator.from_response_with_completion(final_response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert is_final is True
|
||||||
|
assert result["response"] == ""
|
||||||
|
assert result["end_of_stream"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentRagResponseTranslator:
|
||||||
|
"""Test DocumentRagResponseTranslator streaming behavior"""
|
||||||
|
|
||||||
|
def test_from_pulsar_with_empty_response(self):
|
||||||
|
"""Test that empty response strings are preserved"""
|
||||||
|
# Arrange
|
||||||
|
translator = DocumentRagResponseTranslator()
|
||||||
|
response = DocumentRagResponse(
|
||||||
|
response="",
|
||||||
|
end_of_stream=True,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert "response" in result
|
||||||
|
assert result["response"] == ""
|
||||||
|
assert result["end_of_stream"] is True
|
||||||
|
|
||||||
|
def test_from_pulsar_with_non_empty_response(self):
|
||||||
|
"""Test that non-empty responses work correctly"""
|
||||||
|
# Arrange
|
||||||
|
translator = DocumentRagResponseTranslator()
|
||||||
|
response = DocumentRagResponse(
|
||||||
|
response="Document content",
|
||||||
|
end_of_stream=False,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result["response"] == "Document content"
|
||||||
|
assert result["end_of_stream"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptResponseTranslator:
|
||||||
|
"""Test PromptResponseTranslator streaming behavior"""
|
||||||
|
|
||||||
|
def test_from_pulsar_with_empty_text(self):
|
||||||
|
"""Test that empty text strings are preserved"""
|
||||||
|
# Arrange
|
||||||
|
translator = PromptResponseTranslator()
|
||||||
|
response = PromptResponse(
|
||||||
|
text="",
|
||||||
|
object=None,
|
||||||
|
end_of_stream=True,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert "text" in result
|
||||||
|
assert result["text"] == ""
|
||||||
|
assert result["end_of_stream"] is True
|
||||||
|
|
||||||
|
def test_from_pulsar_with_non_empty_text(self):
|
||||||
|
"""Test that non-empty text works correctly"""
|
||||||
|
# Arrange
|
||||||
|
translator = PromptResponseTranslator()
|
||||||
|
response = PromptResponse(
|
||||||
|
text="Some prompt response",
|
||||||
|
object=None,
|
||||||
|
end_of_stream=False,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result["text"] == "Some prompt response"
|
||||||
|
assert result["end_of_stream"] is False
|
||||||
|
|
||||||
|
def test_from_pulsar_with_none_text(self):
|
||||||
|
"""Test that None text is handled correctly"""
|
||||||
|
# Arrange
|
||||||
|
translator = PromptResponseTranslator()
|
||||||
|
response = PromptResponse(
|
||||||
|
text=None,
|
||||||
|
object='{"result": "data"}',
|
||||||
|
end_of_stream=True,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert "text" not in result
|
||||||
|
assert "object" in result
|
||||||
|
assert result["end_of_stream"] is True
|
||||||
|
|
||||||
|
def test_from_pulsar_includes_end_of_stream(self):
|
||||||
|
"""Test that end_of_stream flag is always included"""
|
||||||
|
# Arrange
|
||||||
|
translator = PromptResponseTranslator()
|
||||||
|
|
||||||
|
# Test with end_of_stream=False
|
||||||
|
response = PromptResponse(
|
||||||
|
text="chunk",
|
||||||
|
object=None,
|
||||||
|
end_of_stream=False,
|
||||||
|
error=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert "end_of_stream" in result
|
||||||
|
assert result["end_of_stream"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextCompletionResponseTranslator:
|
||||||
|
"""Test TextCompletionResponseTranslator streaming behavior"""
|
||||||
|
|
||||||
|
def test_from_pulsar_always_includes_response(self):
|
||||||
|
"""Test that response field is always included, even if empty"""
|
||||||
|
# Arrange
|
||||||
|
translator = TextCompletionResponseTranslator()
|
||||||
|
response = TextCompletionResponse(
|
||||||
|
response="",
|
||||||
|
end_of_stream=True,
|
||||||
|
error=None,
|
||||||
|
in_token=100,
|
||||||
|
out_token=5,
|
||||||
|
model="test-model"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert - Response should always be present
|
||||||
|
assert "response" in result
|
||||||
|
assert result["response"] == ""
|
||||||
|
|
||||||
|
def test_from_response_with_completion_with_empty_final(self):
|
||||||
|
"""Test that empty final response is handled correctly"""
|
||||||
|
# Arrange
|
||||||
|
translator = TextCompletionResponseTranslator()
|
||||||
|
response = TextCompletionResponse(
|
||||||
|
response="",
|
||||||
|
end_of_stream=True,
|
||||||
|
error=None,
|
||||||
|
in_token=100,
|
||||||
|
out_token=5,
|
||||||
|
model="test-model"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result, is_final = translator.from_response_with_completion(response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert is_final is True
|
||||||
|
assert result["response"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamingProtocolCompliance:
|
||||||
|
"""Test that all translators follow streaming protocol conventions"""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("translator_class,response_class,field_name", [
|
||||||
|
(GraphRagResponseTranslator, GraphRagResponse, "response"),
|
||||||
|
(DocumentRagResponseTranslator, DocumentRagResponse, "response"),
|
||||||
|
(PromptResponseTranslator, PromptResponse, "text"),
|
||||||
|
(TextCompletionResponseTranslator, TextCompletionResponse, "response"),
|
||||||
|
])
|
||||||
|
def test_empty_final_chunk_preserved(self, translator_class, response_class, field_name):
|
||||||
|
"""Test that all translators preserve empty final chunks"""
|
||||||
|
# Arrange
|
||||||
|
translator = translator_class()
|
||||||
|
kwargs = {
|
||||||
|
field_name: "",
|
||||||
|
"end_of_stream": True,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
response = response_class(**kwargs)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert field_name in result, f"{translator_class.__name__} should include '{field_name}' field even when empty"
|
||||||
|
assert result[field_name] == "", f"{translator_class.__name__} should preserve empty string"
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("translator_class,response_class,field_name", [
|
||||||
|
(GraphRagResponseTranslator, GraphRagResponse, "response"),
|
||||||
|
(DocumentRagResponseTranslator, DocumentRagResponse, "response"),
|
||||||
|
(TextCompletionResponseTranslator, TextCompletionResponse, "response"),
|
||||||
|
])
|
||||||
|
def test_end_of_stream_flag_included(self, translator_class, response_class, field_name):
|
||||||
|
"""Test that end_of_stream flag is included in all response translators"""
|
||||||
|
# Arrange
|
||||||
|
translator = translator_class()
|
||||||
|
kwargs = {
|
||||||
|
field_name: "test content",
|
||||||
|
"end_of_stream": True,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
response = response_class(**kwargs)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = translator.from_pulsar(response)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert "end_of_stream" in result, f"{translator_class.__name__} should include 'end_of_stream' flag"
|
||||||
|
assert result["end_of_stream"] is True
|
||||||
|
|
@ -51,7 +51,8 @@ class PromptClient(RequestResponse):
|
||||||
|
|
||||||
end_stream = getattr(resp, 'end_of_stream', False)
|
end_stream = getattr(resp, 'end_of_stream', False)
|
||||||
|
|
||||||
if resp.text:
|
# Always call callback if there's text OR if it's the final message
|
||||||
|
if resp.text is not None:
|
||||||
last_text = resp.text
|
last_text = resp.text
|
||||||
# Call chunk callback if provided with both chunk and end_of_stream flag
|
# Call chunk callback if provided with both chunk and end_of_stream flag
|
||||||
if chunk_callback:
|
if chunk_callback:
|
||||||
|
|
|
||||||
|
|
@ -28,14 +28,17 @@ class TextCompletionResponseTranslator(MessageTranslator):
|
||||||
|
|
||||||
def from_pulsar(self, obj: TextCompletionResponse) -> Dict[str, Any]:
|
def from_pulsar(self, obj: TextCompletionResponse) -> Dict[str, Any]:
|
||||||
result = {"response": obj.response}
|
result = {"response": obj.response}
|
||||||
|
|
||||||
if obj.in_token:
|
if obj.in_token:
|
||||||
result["in_token"] = obj.in_token
|
result["in_token"] = obj.in_token
|
||||||
if obj.out_token:
|
if obj.out_token:
|
||||||
result["out_token"] = obj.out_token
|
result["out_token"] = obj.out_token
|
||||||
if obj.model:
|
if obj.model:
|
||||||
result["model"] = obj.model
|
result["model"] = obj.model
|
||||||
|
|
||||||
|
# Always include end_of_stream flag for streaming support
|
||||||
|
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]:
|
def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue