mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
218 lines
8.1 KiB
Python
218 lines
8.1 KiB
Python
"""
|
|
Streaming test assertion helpers
|
|
|
|
Provides reusable assertion functions for validating streaming behavior
|
|
across different TrustGraph services.
|
|
"""
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
def assert_streaming_chunks_valid(chunks: List[Any], min_chunks: int = 1):
    """
    Check that a stream produced enough chunks and that none of them is None.

    Args:
        chunks: Streaming chunks collected from a service
        min_chunks: Smallest acceptable number of chunks

    Raises:
        AssertionError: If fewer than ``min_chunks`` chunks were received,
            or if any chunk is None.
    """
    received = len(chunks)
    assert received >= min_chunks, f"Expected at least {min_chunks} chunks, got {received}"
    assert not any(item is None for item in chunks), "All chunks should be non-None"
|
|
|
|
|
|
def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "chunk_type"):
    """
    Check the ordered values of ``key`` across a stream of chunks.

    Chunks that do not contain ``key`` at all are skipped. The values taken
    from the remaining chunks, in stream order, must equal
    ``expected_sequence`` exactly.

    Args:
        chunks: Chunk dictionaries in the order they were received
        expected_sequence: Values that ``key`` should take, in order
        key: Dictionary key whose values are compared (default: "chunk_type")

    Raises:
        AssertionError: If the observed sequence differs from the expected one.
    """
    actual_sequence = []
    for entry in chunks:
        if key in entry:
            actual_sequence.append(entry[key])

    assert actual_sequence == expected_sequence, \
        f"Expected sequence {expected_sequence}, got {actual_sequence}"
|
|
|
|
|
|
def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
    """
    Check the structure of a stream of agent chunks.

    Every chunk must contain the keys ``chunk_type``, ``content``,
    ``end_of_message`` and ``end_of_dialog`` (checked in that order).
    ``chunk_type`` must be one of "thought", "action", "observation" or
    "final-answer". The last chunk's ``end_of_dialog`` must be exactly True.

    Args:
        chunks: Agent streaming chunk dictionaries, in received order

    Raises:
        AssertionError: On an empty stream, a missing key, an unknown
            chunk_type, or a final chunk that does not end the dialog.
    """
    assert len(chunks) > 0, "Expected at least one chunk"

    # Loop-invariant: built once instead of on every iteration.
    allowed_types = ("thought", "action", "observation", "final-answer")
    required_fields = ("chunk_type", "content", "end_of_message", "end_of_dialog")

    for idx, entry in enumerate(chunks):
        for field in required_fields:
            assert field in entry, f"Chunk {idx} missing {field}"

        kind = entry["chunk_type"]
        assert kind in allowed_types, \
            f"Invalid chunk_type '{kind}' at index {idx}"

    # The stream must be terminated by a chunk that closes the dialog.
    assert chunks[-1]["end_of_dialog"] is True, \
        "Last chunk should have end_of_dialog=True"
|
|
|
|
|
|
def assert_rag_streaming_chunks(chunks: List[Dict[str, Any]]):
    """
    Check the structure of a stream of RAG chunks.

    Every chunk must contain ``end_of_stream``. Every chunk before the last
    must contain a ``chunk`` key and have ``end_of_stream`` exactly False.
    The last chunk must have ``end_of_stream`` exactly True and is not
    required to contain ``chunk``. Other keys, such as a complete
    ``response``, are not inspected.

    Args:
        chunks: RAG streaming chunk dictionaries, in received order

    Raises:
        AssertionError: If any of the conditions above is violated.
    """
    assert len(chunks) > 0, "Expected at least one chunk"

    last_index = len(chunks) - 1
    for position, entry in enumerate(chunks):
        assert "end_of_stream" in entry, f"Chunk {position} missing end_of_stream"

        if position == last_index:
            # Terminal chunk: must close the stream.
            assert entry["end_of_stream"] is True, \
                "Last chunk should have end_of_stream=True"
            continue

        # Intermediate chunk: carries text and keeps the stream open.
        assert "chunk" in entry, f"Chunk {position} missing chunk field"
        assert entry["end_of_stream"] is False, \
            f"Non-final chunk {position} should have end_of_stream=False"
|
|
|
|
|
|
def assert_streaming_completion(chunks: List[Dict[str, Any]], expected_complete_flag: str = "end_of_stream"):
    """
    Check that only the final chunk marks the stream as complete.

    Every chunk before the last must have ``expected_complete_flag`` exactly
    False, and the last chunk must have it exactly True. A missing flag is
    read as None and therefore fails either check.

    Args:
        chunks: Streaming chunk dictionaries, in received order
        expected_complete_flag: Key of the completion flag to inspect

    Raises:
        AssertionError: On an empty stream or a misplaced or missing flag.
    """
    assert len(chunks) > 0, "Expected at least one chunk"

    *body, final = chunks

    for idx, entry in enumerate(body):
        assert entry.get(expected_complete_flag) is False, \
            f"Non-final chunk {idx} should have {expected_complete_flag}=False"

    assert final.get(expected_complete_flag) is True, \
        f"Final chunk should have {expected_complete_flag}=True"
|
|
|
|
|
|
def assert_streaming_content_matches(chunks: List, expected_content: str, content_key: str = "chunk"):
    """
    Assert that concatenated streaming chunks match expected content.

    Chunks may be plain strings or dictionaries; the type of the first chunk
    decides which interpretation is used. For dictionaries, the value under
    ``content_key`` is taken from each chunk, and chunks where that value is
    missing or None are skipped (e.g. a trailing completion marker with no
    text).

    An empty ``chunks`` list concatenates to "" rather than raising
    IndexError, so it passes only when ``expected_content`` is "".

    Args:
        chunks: List of streaming chunks (strings or dicts)
        expected_content: Expected complete content after concatenation
        content_key: Dictionary key for content (used if chunks are dicts)

    Raises:
        AssertionError: If the concatenated content differs from
            ``expected_content``.
    """
    # Guard on emptiness first: indexing chunks[0] on [] would raise
    # IndexError instead of producing a readable assertion failure.
    if chunks and isinstance(chunks[0], dict):
        # Look each value up once, then drop the absent/None ones.
        pieces = (chunk.get(content_key) for chunk in chunks)
        actual_content = "".join(piece for piece in pieces if piece is not None)
    else:
        # Chunks are already strings (or there are none at all)
        actual_content = "".join(chunks)

    assert actual_content == expected_content, \
        f"Expected content '{expected_content}', got '{actual_content}'"
|
|
|
|
|
|
def assert_no_empty_chunks(chunks: List[Dict[str, Any]], content_key: str = "content"):
    """
    Assert that every chunk except the last has non-empty content.

    The final chunk is always exempt, whether or not it is a completion
    marker, so a trailing marker with no text never fails this check.
    As a consequence, a list of zero or one chunks always passes.

    Args:
        chunks: List of streaming chunk dictionaries
        content_key: Dictionary key for content

    Raises:
        AssertionError: If a non-final chunk's content is missing, None,
            or has length zero.
    """
    for position, entry in enumerate(chunks[:-1]):
        text = entry.get(content_key)
        assert not (text is None or len(text) == 0), \
            f"Chunk {position} has empty content"
|
|
|
|
|
|
def assert_streaming_error_handled(chunks: List[Dict[str, Any]], error_flag: str = "error"):
    """
    Assert that a stream signalled an error and, if it ended on one, closed.

    At least one chunk must have a non-None value under ``error_flag``.
    If the last chunk's ``error_flag`` value is truthy, that chunk must also
    have ``end_of_stream`` or ``end_of_dialog`` set to exactly True.
    Note the two tests differ: presence uses ``is not None``, while the
    final-chunk check only fires for a truthy error value.

    Args:
        chunks: List of streaming chunk dictionaries
        error_flag: Name of the error flag field

    Raises:
        AssertionError: If no chunk carries an error, or the final error
            chunk lacks a True completion flag.
    """
    error_seen = False
    for entry in chunks:
        if entry.get(error_flag) is not None:
            error_seen = True
            break
    assert error_seen, "Expected error flag in at least one chunk"

    final = chunks[-1]
    if not final.get(error_flag):
        return

    # Accept either the RAG-style or the agent-style completion flag.
    closed = any(final.get(name) is True for name in ("end_of_stream", "end_of_dialog"))
    assert closed, \
        "Error chunk should have completion flag set to True"
|
|
|
|
|
|
def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "chunk_type"):
    """
    Check that each chunk's type value belongs to an allowed set.

    A chunk without ``type_key`` yields None, which fails unless None is
    itself listed in ``valid_types``.

    Args:
        chunks: List of streaming chunk dictionaries
        valid_types: Allowed values for the type field
        type_key: Dictionary key holding the chunk type

    Raises:
        AssertionError: On the first chunk whose type is not allowed.
    """
    observed = (entry.get(type_key) for entry in chunks)
    for position, kind in enumerate(observed):
        assert kind in valid_types, \
            f"Chunk {position} has invalid type '{kind}', expected one of {valid_types}"
|
|
|
|
|
|
def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0):
    """
    Check that consecutive chunk arrivals are never too far apart.

    Each difference between adjacent timestamps must be at most
    ``max_gap_seconds``. Timestamps are compared in the order given.

    Args:
        chunk_timestamps: Arrival times of the chunks, in received order
        max_gap_seconds: Largest permitted gap between neighbours, in seconds

    Raises:
        AssertionError: If fewer than two timestamps are supplied, or on the
            first gap that exceeds the limit.
    """
    assert len(chunk_timestamps) > 1, "Need at least 2 timestamps to check latency"

    neighbours = zip(chunk_timestamps, chunk_timestamps[1:])
    for later, (previous, current) in enumerate(neighbours, start=1):
        gap = current - previous
        assert gap <= max_gap_seconds, \
            f"Gap between chunks {later - 1} and {later} is {gap:.2f}s, exceeds max {max_gap_seconds}s"
|
|
|
|
|
|
def assert_callback_invoked(mock_callback, min_calls: int = 1):
    """
    Check that a streaming callback fired at least ``min_calls`` times.

    Only the ``call_count`` attribute is read. Any object exposing it works,
    e.g. a ``unittest.mock`` Mock or AsyncMock.

    Args:
        mock_callback: Callback double exposing ``call_count``
        min_calls: Minimum number of expected calls

    Raises:
        AssertionError: If the callback was called fewer than ``min_calls``
            times.
    """
    observed_calls = mock_callback.call_count
    assert observed_calls >= min_calls, \
        f"Expected callback to be called at least {min_calls} times, was called {observed_calls} times"
|