mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
261 lines
9.5 KiB
Python
261 lines
9.5 KiB
Python
|
|
"""
|
||
|
|
Unit tests for PromptClient streaming callback behavior.
|
||
|
|
|
||
|
|
These tests verify that the prompt client correctly passes the end_of_stream
|
||
|
|
flag to chunk callbacks, ensuring proper streaming protocol compliance.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||
|
|
from trustgraph.base.prompt_client import PromptClient
|
||
|
|
from trustgraph.schema import PromptResponse
|
||
|
|
|
||
|
|
|
||
|
|
class TestPromptClientStreamingCallback:
|
||
|
|
"""Test PromptClient streaming callback behavior"""
|
||
|
|
|
||
|
|
@pytest.fixture
|
||
|
|
def prompt_client(self):
|
||
|
|
"""Create a PromptClient with mocked dependencies"""
|
||
|
|
# Mock all the required initialization parameters
|
||
|
|
with patch.object(PromptClient, '__init__', lambda self: None):
|
||
|
|
client = PromptClient()
|
||
|
|
return client
|
||
|
|
|
||
|
|
@pytest.fixture
|
||
|
|
def mock_request_response(self):
|
||
|
|
"""Create a mock request/response handler"""
|
||
|
|
async def mock_request(request, recipient=None, timeout=600):
|
||
|
|
if recipient:
|
||
|
|
# Simulate streaming responses
|
||
|
|
responses = [
|
||
|
|
PromptResponse(text="Hello", object=None, error=None, end_of_stream=False),
|
||
|
|
PromptResponse(text=" world", object=None, error=None, end_of_stream=False),
|
||
|
|
PromptResponse(text="!", object=None, error=None, end_of_stream=False),
|
||
|
|
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||
|
|
]
|
||
|
|
for resp in responses:
|
||
|
|
should_stop = await recipient(resp)
|
||
|
|
if should_stop:
|
||
|
|
break
|
||
|
|
else:
|
||
|
|
# Non-streaming response
|
||
|
|
return PromptResponse(text="Hello world!", object=None, error=None)
|
||
|
|
|
||
|
|
return mock_request
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_callback_receives_chunk_and_end_of_stream(self, prompt_client, mock_request_response):
|
||
|
|
"""Test that callback receives both chunk text and end_of_stream flag"""
|
||
|
|
# Arrange
|
||
|
|
prompt_client.request = mock_request_response
|
||
|
|
|
||
|
|
callback = AsyncMock()
|
||
|
|
|
||
|
|
# Act
|
||
|
|
await prompt_client.prompt(
|
||
|
|
id="test-prompt",
|
||
|
|
variables={"query": "test"},
|
||
|
|
streaming=True,
|
||
|
|
chunk_callback=callback
|
||
|
|
)
|
||
|
|
|
||
|
|
# Assert - callback should be called with (chunk, end_of_stream) signature
|
||
|
|
assert callback.call_count == 4
|
||
|
|
|
||
|
|
# Verify first chunk: text + end_of_stream=False
|
||
|
|
assert callback.call_args_list[0] == call("Hello", False)
|
||
|
|
|
||
|
|
# Verify second chunk
|
||
|
|
assert callback.call_args_list[1] == call(" world", False)
|
||
|
|
|
||
|
|
# Verify third chunk
|
||
|
|
assert callback.call_args_list[2] == call("!", False)
|
||
|
|
|
||
|
|
# Verify final chunk: empty text + end_of_stream=True
|
||
|
|
assert callback.call_args_list[3] == call("", True)
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_callback_receives_empty_final_chunk(self, prompt_client, mock_request_response):
|
||
|
|
"""Test that empty final chunks are passed to callback"""
|
||
|
|
# Arrange
|
||
|
|
prompt_client.request = mock_request_response
|
||
|
|
|
||
|
|
chunks_received = []
|
||
|
|
|
||
|
|
async def collect_chunks(chunk, end_of_stream):
|
||
|
|
chunks_received.append((chunk, end_of_stream))
|
||
|
|
|
||
|
|
# Act
|
||
|
|
await prompt_client.prompt(
|
||
|
|
id="test-prompt",
|
||
|
|
variables={"query": "test"},
|
||
|
|
streaming=True,
|
||
|
|
chunk_callback=collect_chunks
|
||
|
|
)
|
||
|
|
|
||
|
|
# Assert - should receive the empty final chunk
|
||
|
|
final_chunk = chunks_received[-1]
|
||
|
|
assert final_chunk == ("", True), "Final chunk should be empty string with end_of_stream=True"
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_callback_signature_with_non_empty_final_chunk(self, prompt_client):
|
||
|
|
"""Test callback signature when LLM sends non-empty final chunk"""
|
||
|
|
# Arrange
|
||
|
|
async def mock_request_non_empty_final(request, recipient=None, timeout=600):
|
||
|
|
if recipient:
|
||
|
|
# Some LLMs send content in the final chunk
|
||
|
|
responses = [
|
||
|
|
PromptResponse(text="Hello", object=None, error=None, end_of_stream=False),
|
||
|
|
PromptResponse(text=" world!", object=None, error=None, end_of_stream=True),
|
||
|
|
]
|
||
|
|
for resp in responses:
|
||
|
|
should_stop = await recipient(resp)
|
||
|
|
if should_stop:
|
||
|
|
break
|
||
|
|
|
||
|
|
prompt_client.request = mock_request_non_empty_final
|
||
|
|
|
||
|
|
callback = AsyncMock()
|
||
|
|
|
||
|
|
# Act
|
||
|
|
await prompt_client.prompt(
|
||
|
|
id="test-prompt",
|
||
|
|
variables={"query": "test"},
|
||
|
|
streaming=True,
|
||
|
|
chunk_callback=callback
|
||
|
|
)
|
||
|
|
|
||
|
|
# Assert
|
||
|
|
assert callback.call_count == 2
|
||
|
|
assert callback.call_args_list[0] == call("Hello", False)
|
||
|
|
assert callback.call_args_list[1] == call(" world!", True)
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_callback_not_called_without_text(self, prompt_client):
|
||
|
|
"""Test that callback is not called for responses without text"""
|
||
|
|
# Arrange
|
||
|
|
async def mock_request_no_text(request, recipient=None, timeout=600):
|
||
|
|
if recipient:
|
||
|
|
# Response with only end_of_stream, no text
|
||
|
|
responses = [
|
||
|
|
PromptResponse(text="Content", object=None, error=None, end_of_stream=False),
|
||
|
|
PromptResponse(text=None, object=None, error=None, end_of_stream=True),
|
||
|
|
]
|
||
|
|
for resp in responses:
|
||
|
|
should_stop = await recipient(resp)
|
||
|
|
if should_stop:
|
||
|
|
break
|
||
|
|
|
||
|
|
prompt_client.request = mock_request_no_text
|
||
|
|
|
||
|
|
callback = AsyncMock()
|
||
|
|
|
||
|
|
# Act
|
||
|
|
await prompt_client.prompt(
|
||
|
|
id="test-prompt",
|
||
|
|
variables={"query": "test"},
|
||
|
|
streaming=True,
|
||
|
|
chunk_callback=callback
|
||
|
|
)
|
||
|
|
|
||
|
|
# Assert - callback should only be called once (for "Content")
|
||
|
|
assert callback.call_count == 1
|
||
|
|
assert callback.call_args_list[0] == call("Content", False)
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_synchronous_callback_also_receives_end_of_stream(self, prompt_client):
|
||
|
|
"""Test that synchronous callbacks also receive end_of_stream parameter"""
|
||
|
|
# Arrange
|
||
|
|
async def mock_request(request, recipient=None, timeout=600):
|
||
|
|
if recipient:
|
||
|
|
responses = [
|
||
|
|
PromptResponse(text="test", object=None, error=None, end_of_stream=False),
|
||
|
|
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||
|
|
]
|
||
|
|
for resp in responses:
|
||
|
|
should_stop = await recipient(resp)
|
||
|
|
if should_stop:
|
||
|
|
break
|
||
|
|
|
||
|
|
prompt_client.request = mock_request
|
||
|
|
|
||
|
|
callback = MagicMock() # Synchronous mock
|
||
|
|
|
||
|
|
# Act
|
||
|
|
await prompt_client.prompt(
|
||
|
|
id="test-prompt",
|
||
|
|
variables={"query": "test"},
|
||
|
|
streaming=True,
|
||
|
|
chunk_callback=callback
|
||
|
|
)
|
||
|
|
|
||
|
|
# Assert - synchronous callback should also get both parameters
|
||
|
|
assert callback.call_count == 2
|
||
|
|
assert callback.call_args_list[0] == call("test", False)
|
||
|
|
assert callback.call_args_list[1] == call("", True)
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
|
||
|
|
"""Test that kg_prompt correctly passes streaming parameters"""
|
||
|
|
# Arrange
|
||
|
|
async def mock_request(request, recipient=None, timeout=600):
|
||
|
|
if recipient:
|
||
|
|
responses = [
|
||
|
|
PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
|
||
|
|
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||
|
|
]
|
||
|
|
for resp in responses:
|
||
|
|
should_stop = await recipient(resp)
|
||
|
|
if should_stop:
|
||
|
|
break
|
||
|
|
|
||
|
|
prompt_client.request = mock_request
|
||
|
|
|
||
|
|
callback = AsyncMock()
|
||
|
|
|
||
|
|
# Act
|
||
|
|
await prompt_client.kg_prompt(
|
||
|
|
query="What is machine learning?",
|
||
|
|
kg=[("subject", "predicate", "object")],
|
||
|
|
streaming=True,
|
||
|
|
chunk_callback=callback
|
||
|
|
)
|
||
|
|
|
||
|
|
# Assert
|
||
|
|
assert callback.call_count == 2
|
||
|
|
assert callback.call_args_list[0] == call("Answer", False)
|
||
|
|
assert callback.call_args_list[1] == call("", True)
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
|
||
|
|
"""Test that document_prompt correctly passes streaming parameters"""
|
||
|
|
# Arrange
|
||
|
|
async def mock_request(request, recipient=None, timeout=600):
|
||
|
|
if recipient:
|
||
|
|
responses = [
|
||
|
|
PromptResponse(text="Summary", object=None, error=None, end_of_stream=False),
|
||
|
|
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||
|
|
]
|
||
|
|
for resp in responses:
|
||
|
|
should_stop = await recipient(resp)
|
||
|
|
if should_stop:
|
||
|
|
break
|
||
|
|
|
||
|
|
prompt_client.request = mock_request
|
||
|
|
|
||
|
|
callback = AsyncMock()
|
||
|
|
|
||
|
|
# Act
|
||
|
|
await prompt_client.document_prompt(
|
||
|
|
query="Summarize this",
|
||
|
|
documents=["doc1", "doc2"],
|
||
|
|
streaming=True,
|
||
|
|
chunk_callback=callback
|
||
|
|
)
|
||
|
|
|
||
|
|
# Assert
|
||
|
|
assert callback.call_count == 2
|
||
|
|
assert callback.call_args_list[0] == call("Summary", False)
|
||
|
|
assert callback.call_args_list[1] == call("", True)
|