Streaming rag responses (#568)

* Tech spec for streaming RAG

* Support for streaming Graph/Doc RAG
This commit is contained in:
cybermaggedon 2025-11-26 19:47:39 +00:00 committed by GitHub
parent b1cc724f7d
commit 1948edaa50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 3087 additions and 94 deletions

29
tests/utils/__init__.py Normal file
View file

@ -0,0 +1,29 @@
"""Test utilities for TrustGraph tests"""
from .streaming_assertions import (
assert_streaming_chunks_valid,
assert_streaming_sequence,
assert_agent_streaming_chunks,
assert_rag_streaming_chunks,
assert_streaming_completion,
assert_streaming_content_matches,
assert_no_empty_chunks,
assert_streaming_error_handled,
assert_chunk_types_valid,
assert_streaming_latency_acceptable,
assert_callback_invoked,
)
__all__ = [
"assert_streaming_chunks_valid",
"assert_streaming_sequence",
"assert_agent_streaming_chunks",
"assert_rag_streaming_chunks",
"assert_streaming_completion",
"assert_streaming_content_matches",
"assert_no_empty_chunks",
"assert_streaming_error_handled",
"assert_chunk_types_valid",
"assert_streaming_latency_acceptable",
"assert_callback_invoked",
]

View file

@ -0,0 +1,218 @@
"""
Streaming test assertion helpers
Provides reusable assertion functions for validating streaming behavior
across different TrustGraph services.
"""
from typing import List, Dict, Any, Optional
def assert_streaming_chunks_valid(chunks: List[Any], min_chunks: int = 1):
"""
Assert that streaming chunks are valid and non-empty.
Args:
chunks: List of streaming chunks
min_chunks: Minimum number of expected chunks
"""
assert len(chunks) >= min_chunks, f"Expected at least {min_chunks} chunks, got {len(chunks)}"
assert all(chunk is not None for chunk in chunks), "All chunks should be non-None"
def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "chunk_type"):
"""
Assert that streaming chunks follow an expected sequence.
Args:
chunks: List of chunk dictionaries
expected_sequence: Expected sequence of chunk types/values
key: Dictionary key to check (default: "chunk_type")
"""
actual_sequence = [chunk.get(key) for chunk in chunks if key in chunk]
assert actual_sequence == expected_sequence, \
f"Expected sequence {expected_sequence}, got {actual_sequence}"
def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
"""
Assert that agent streaming chunks have valid structure.
Validates:
- All chunks have chunk_type field
- All chunks have content field
- All chunks have end_of_message field
- All chunks have end_of_dialog field
- Last chunk has end_of_dialog=True
Args:
chunks: List of agent streaming chunk dictionaries
"""
assert len(chunks) > 0, "Expected at least one chunk"
for i, chunk in enumerate(chunks):
assert "chunk_type" in chunk, f"Chunk {i} missing chunk_type"
assert "content" in chunk, f"Chunk {i} missing content"
assert "end_of_message" in chunk, f"Chunk {i} missing end_of_message"
assert "end_of_dialog" in chunk, f"Chunk {i} missing end_of_dialog"
# Validate chunk_type values
valid_types = ["thought", "action", "observation", "final-answer"]
assert chunk["chunk_type"] in valid_types, \
f"Invalid chunk_type '{chunk['chunk_type']}' at index {i}"
# Last chunk should signal end of dialog
assert chunks[-1]["end_of_dialog"] is True, \
"Last chunk should have end_of_dialog=True"
def assert_rag_streaming_chunks(chunks: List[Dict[str, Any]]):
"""
Assert that RAG streaming chunks have valid structure.
Validates:
- All chunks except last have chunk field
- All chunks have end_of_stream field
- Last chunk has end_of_stream=True
- Last chunk may have response field with complete text
Args:
chunks: List of RAG streaming chunk dictionaries
"""
assert len(chunks) > 0, "Expected at least one chunk"
for i, chunk in enumerate(chunks):
assert "end_of_stream" in chunk, f"Chunk {i} missing end_of_stream"
if i < len(chunks) - 1:
# Non-final chunks should have chunk content and end_of_stream=False
assert "chunk" in chunk, f"Chunk {i} missing chunk field"
assert chunk["end_of_stream"] is False, \
f"Non-final chunk {i} should have end_of_stream=False"
else:
# Final chunk should have end_of_stream=True
assert chunk["end_of_stream"] is True, \
"Last chunk should have end_of_stream=True"
def assert_streaming_completion(chunks: List[Dict[str, Any]], expected_complete_flag: str = "end_of_stream"):
"""
Assert that streaming completed properly.
Args:
chunks: List of streaming chunk dictionaries
expected_complete_flag: Name of the completion flag field
"""
assert len(chunks) > 0, "Expected at least one chunk"
# Check that all but last chunk have completion flag = False
for i, chunk in enumerate(chunks[:-1]):
assert chunk.get(expected_complete_flag) is False, \
f"Non-final chunk {i} should have {expected_complete_flag}=False"
# Check that last chunk has completion flag = True
assert chunks[-1].get(expected_complete_flag) is True, \
f"Final chunk should have {expected_complete_flag}=True"
def assert_streaming_content_matches(chunks: List, expected_content: str, content_key: str = "chunk"):
"""
Assert that concatenated streaming chunks match expected content.
Args:
chunks: List of streaming chunks (strings or dicts)
expected_content: Expected complete content after concatenation
content_key: Dictionary key for content (used if chunks are dicts)
"""
if isinstance(chunks[0], dict):
# Extract content from chunk dictionaries
content_chunks = [
chunk.get(content_key, "")
for chunk in chunks
if chunk.get(content_key) is not None
]
actual_content = "".join(content_chunks)
else:
# Chunks are already strings
actual_content = "".join(chunks)
assert actual_content == expected_content, \
f"Expected content '{expected_content}', got '{actual_content}'"
def assert_no_empty_chunks(chunks: List[Dict[str, Any]], content_key: str = "content"):
"""
Assert that no chunks have empty content (except final chunk if it's completion marker).
Args:
chunks: List of streaming chunk dictionaries
content_key: Dictionary key for content
"""
for i, chunk in enumerate(chunks[:-1]):
content = chunk.get(content_key)
assert content is not None and len(content) > 0, \
f"Chunk {i} has empty content"
def assert_streaming_error_handled(chunks: List[Dict[str, Any]], error_flag: str = "error"):
"""
Assert that streaming error was properly signaled.
Args:
chunks: List of streaming chunk dictionaries
error_flag: Name of the error flag field
"""
# Check that at least one chunk has error flag
has_error = any(chunk.get(error_flag) is not None for chunk in chunks)
assert has_error, "Expected error flag in at least one chunk"
# If last chunk has error, should also have completion flag
if chunks[-1].get(error_flag):
# Check for completion flags (either end_of_stream or end_of_dialog)
completion_flags = ["end_of_stream", "end_of_dialog"]
has_completion = any(chunks[-1].get(flag) is True for flag in completion_flags)
assert has_completion, \
"Error chunk should have completion flag set to True"
def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "chunk_type"):
"""
Assert that all chunk types are from a valid set.
Args:
chunks: List of streaming chunk dictionaries
valid_types: List of valid chunk type values
type_key: Dictionary key for chunk type
"""
for i, chunk in enumerate(chunks):
chunk_type = chunk.get(type_key)
assert chunk_type in valid_types, \
f"Chunk {i} has invalid type '{chunk_type}', expected one of {valid_types}"
def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0):
"""
Assert that streaming latency between chunks is acceptable.
Args:
chunk_timestamps: List of timestamps when chunks were received
max_gap_seconds: Maximum acceptable gap between chunks in seconds
"""
assert len(chunk_timestamps) > 1, "Need at least 2 timestamps to check latency"
for i in range(1, len(chunk_timestamps)):
gap = chunk_timestamps[i] - chunk_timestamps[i-1]
assert gap <= max_gap_seconds, \
f"Gap between chunks {i-1} and {i} is {gap:.2f}s, exceeds max {max_gap_seconds}s"
def assert_callback_invoked(mock_callback, min_calls: int = 1):
"""
Assert that a streaming callback was invoked minimum number of times.
Args:
mock_callback: AsyncMock callback object
min_calls: Minimum number of expected calls
"""
assert mock_callback.call_count >= min_calls, \
f"Expected callback to be called at least {min_calls} times, was called {mock_callback.call_count} times"