Streaming RAG responses (#568)

* Tech spec for streaming RAG

* Support for streaming Graph/Doc RAG
cybermaggedon 2025-11-26 19:47:39 +00:00 committed by GitHub
parent b1cc724f7d
commit 1948edaa50
20 changed files with 3087 additions and 94 deletions


@@ -0,0 +1,288 @@
# RAG Streaming Support Technical Specification
## Overview
This specification describes adding streaming support to GraphRAG and DocumentRAG services, enabling real-time token-by-token responses for knowledge graph and document retrieval queries. This extends the existing streaming architecture already implemented for LLM text-completion, prompt, and agent services.
## Goals
- **Consistent streaming UX**: Provide the same streaming experience across all TrustGraph services
- **Minimal API changes**: Add streaming support with a single `streaming` flag, following established patterns
- **Backward compatibility**: Maintain existing non-streaming behavior as default
- **Reuse existing infrastructure**: Leverage PromptClient streaming already implemented
- **Gateway support**: Enable streaming through websocket gateway for client applications
## Background
Currently implemented streaming services:
- **LLM text-completion service**: Phase 1 - streaming from LLM providers
- **Prompt service**: Phase 2 - streaming through prompt templates
- **Agent service**: Phase 3-4 - streaming ReAct responses with incremental thought/observation/answer chunks
Current limitations for RAG services:
- GraphRAG and DocumentRAG only support blocking responses
- Users must wait for the complete LLM response before seeing any output
- Poor UX for long responses from knowledge graph or document queries
- Inconsistent experience compared to other TrustGraph services
This specification addresses these gaps by adding streaming support to GraphRAG and DocumentRAG. By enabling token-by-token responses, TrustGraph can:
- Provide consistent streaming UX across all query types
- Reduce perceived latency for RAG queries
- Enable better progress feedback for long-running queries
- Support real-time display in client applications
## Technical Design
### Architecture
The RAG streaming implementation leverages existing infrastructure:
1. **PromptClient Streaming** (Already implemented)
- `kg_prompt()` and `document_prompt()` already accept `streaming` and `chunk_callback` parameters
- These call `prompt()` internally with streaming support
- No changes needed to PromptClient
Module: `trustgraph-base/trustgraph/base/prompt_client.py`
2. **GraphRAG Service** (Needs streaming parameter pass-through)
- Add `streaming` parameter to `query()` method
- Pass streaming flag and callbacks to `prompt_client.kg_prompt()`
- GraphRagRequest schema needs `streaming` field
Modules:
- `trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py`
- `trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py` (Processor)
- `trustgraph-base/trustgraph/schema/graph_rag.py` (Request schema)
- `trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py` (Gateway)
3. **DocumentRAG Service** (Needs streaming parameter pass-through)
- Add `streaming` parameter to `query()` method
- Pass streaming flag and callbacks to `prompt_client.document_prompt()`
- DocumentRagRequest schema needs `streaming` field
Modules:
- `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py`
- `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` (Processor)
- `trustgraph-base/trustgraph/schema/document_rag.py` (Request schema)
- `trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py` (Gateway)
### Data Flow
**Non-streaming (current)**:
```
Client → Gateway → RAG Service → PromptClient.kg_prompt(streaming=False)
Prompt Service → LLM
Complete response
Client ← Gateway ← RAG Service ← Response
```
**Streaming (proposed)**:
```
Client → Gateway → RAG Service → PromptClient.kg_prompt(streaming=True, chunk_callback=cb)
Prompt Service → LLM (streaming)
Chunk → callback → RAG Response (chunk)
↓ ↓
Client ← Gateway ← ────────────────────────────────── Response stream
```
### APIs
**GraphRAG Changes**:
1. **GraphRag.query()** - Add streaming parameters
```python
async def query(
self, query, user, collection,
verbose=False, streaming=False, chunk_callback=None # NEW
):
# ... existing entity/triple retrieval ...
if streaming and chunk_callback:
resp = await self.prompt_client.kg_prompt(
query, kg,
streaming=True,
chunk_callback=chunk_callback
)
else:
resp = await self.prompt_client.kg_prompt(query, kg)
return resp
```
2. **GraphRagRequest schema** - Add streaming field
```python
class GraphRagRequest(Record):
query = String()
user = String()
collection = String()
streaming = Boolean() # NEW
```
3. **GraphRagResponse schema** - Add streaming fields (follow Agent pattern)
```python
class GraphRagResponse(Record):
response = String() # Legacy: complete response
chunk = String() # NEW: streaming chunk
end_of_stream = Boolean() # NEW: indicates last chunk
```
4. **Processor** - Pass streaming through
```python
async def handle(self, msg):
# ... existing code ...
async def send_chunk(chunk):
await self.respond(GraphRagResponse(
chunk=chunk,
end_of_stream=False,
response=None
))
if request.streaming:
full_response = await self.rag.query(
query=request.query,
user=request.user,
collection=request.collection,
streaming=True,
chunk_callback=send_chunk
)
# Send final message
await self.respond(GraphRagResponse(
chunk=None,
end_of_stream=True,
response=full_response
))
else:
# Existing non-streaming path
response = await self.rag.query(...)
await self.respond(GraphRagResponse(response=response))
```
**DocumentRAG Changes**:
Identical pattern to GraphRAG:
1. Add `streaming` and `chunk_callback` parameters to `DocumentRag.query()`
2. Add `streaming` field to `DocumentRagRequest`
3. Add `chunk` and `end_of_stream` fields to `DocumentRagResponse`
4. Update Processor to handle streaming with callbacks
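Since the spec only shows code for GraphRAG, here is a minimal sketch of the mirrored `DocumentRag.query()` change. The retrieval steps are elided with placeholders, and a stub stands in for the real PromptClient; `document_prompt()`'s streaming signature is assumed to mirror `kg_prompt()`:

```python
import asyncio

class DocumentRag:
    """Sketch: streaming flag passed through to the prompt client."""

    def __init__(self, prompt_client):
        self.prompt_client = prompt_client

    async def query(self, query, user, collection, doc_limit=10,
                    streaming=False, chunk_callback=None):  # NEW parameters
        # ... existing embedding + vector retrieval elided ...
        docs = ["placeholder retrieved document"]
        if streaming and chunk_callback:
            return await self.prompt_client.document_prompt(
                query, docs, streaming=True, chunk_callback=chunk_callback
            )
        return await self.prompt_client.document_prompt(query, docs)

class StubPromptClient:
    """Illustration-only stand-in for the real PromptClient."""
    async def document_prompt(self, query, docs, streaming=False,
                              chunk_callback=None):
        if streaming and chunk_callback:
            # Deliver the answer chunk-by-chunk, then return the full text
            for chunk in ("ML is", " a subset", " of AI."):
                await chunk_callback(chunk)
        return "ML is a subset of AI."

async def main():
    chunks = []
    async def collect(chunk):
        chunks.append(chunk)
    rag = DocumentRag(StubPromptClient())
    full = await rag.query("What is ML?", "user", "collection",
                           streaming=True, chunk_callback=collect)
    return full, chunks

full, chunks = asyncio.run(main())
```

As in the GraphRAG case, the concatenated chunks equal the final `response`, which is what the equivalence tests below rely on.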
**Gateway Changes**:
Both `graph_rag.py` and `document_rag.py` in gateway/dispatch need updates to forward streaming chunks to websocket:
```python
async def handle(self, message, session, websocket):
# ... existing code ...
if request.streaming:
async def recipient(resp):
if resp.chunk:
await websocket.send(json.dumps({
"id": message["id"],
"response": {"chunk": resp.chunk},
"complete": resp.end_of_stream
}))
return resp.end_of_stream
await self.rag_client.request(request, recipient=recipient)
else:
# Existing non-streaming path
resp = await self.rag_client.request(request)
await websocket.send(...)
```
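On the client side of the websocket, chunks can be reassembled until `complete` is seen. A small sketch, using the message shape from the gateway example above (transport elided; messages shown as a pre-collected list):

```python
import json

def reassemble(messages):
    """Collect chunk messages until "complete" is set; return full text.

    Message shape follows the gateway example above:
    {"id": ..., "response": {"chunk": ...}, "complete": bool}
    """
    parts = []
    for raw in messages:
        msg = json.loads(raw)
        chunk = msg["response"].get("chunk")
        if chunk:
            parts.append(chunk)
        if msg.get("complete"):
            break
    return "".join(parts)

stream = [
    json.dumps({"id": "1", "response": {"chunk": "Based on"},
                "complete": False}),
    json.dumps({"id": "1", "response": {"chunk": " the graph."},
                "complete": True}),
]
answer = reassemble(stream)
```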
### Implementation Details
**Implementation order**:
1. Add schema fields (Request + Response for both RAG services)
2. Update GraphRag.query() and DocumentRag.query() methods
3. Update Processors to handle streaming
4. Update Gateway dispatch handlers
5. Add `--no-streaming` flags to `tg-invoke-graph-rag` and `tg-invoke-document-rag` (streaming enabled by default, following agent CLI pattern)
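The `--no-streaming` flag can be wired with a `store_false` action so streaming stays on by default. A sketch with argparse (the actual `tg-invoke-*` option plumbing may differ):

```python
import argparse

def build_parser():
    p = argparse.ArgumentParser(prog="tg-invoke-graph-rag")
    p.add_argument("-q", "--query", required=True)
    # Streaming on by default; --no-streaming opts out (agent CLI pattern)
    p.add_argument("--no-streaming", dest="streaming",
                   action="store_false", default=True)
    return p

args = build_parser().parse_args(["-q", "What is ML?", "--no-streaming"])
```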
**Callback pattern**:
Follow the same async callback pattern established in Agent streaming:
- Processor defines `async def send_chunk(chunk)` callback
- Passes callback to RAG service
- RAG service passes callback to PromptClient
- PromptClient invokes callback for each LLM chunk
- Processor sends streaming response message for each chunk
**Error handling**:
- Errors during streaming should send error response with `end_of_stream=True`
- Follow existing error propagation patterns from Agent streaming
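A sketch of that error path in the Processor, with the query wrapped so a mid-stream failure still terminates the stream with `end_of_stream=True` before re-raising. The `error` field name is an assumption; follow whatever the existing response schema uses:

```python
import asyncio

async def finish_stream(respond, run_query):
    """Run the RAG query; guarantee the stream is closed either way."""
    try:
        full = await run_query()
    except Exception as e:
        # Failure mid-stream: close the stream explicitly, then re-raise
        await respond({"error": str(e), "end_of_stream": True})
        raise
    # Success: final message carries the complete response
    await respond({"chunk": None, "end_of_stream": True, "response": full})
    return full

async def main():
    sent = []
    async def respond(msg):
        sent.append(msg)
    async def failing_query():
        raise RuntimeError("LLM timeout")
    try:
        await finish_stream(respond, failing_query)
    except RuntimeError:
        pass
    return sent

sent = asyncio.run(main())
```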
## Security Considerations
No new security considerations beyond existing RAG services:
- Streaming responses use same user/collection isolation
- No changes to authentication or authorization
- Chunk boundaries don't expose sensitive data
## Performance Considerations
**Benefits**:
- Reduced perceived latency (first tokens arrive faster)
- Better UX for long responses
- Lower memory usage (no need to buffer complete response)
**Potential concerns**:
- More Pulsar messages for streaming responses
- Slightly higher CPU for chunking/callback overhead
- Mitigated by: streaming is opt-in, default remains non-streaming
**Testing considerations**:
- Test with large knowledge graphs (many triples)
- Test with many retrieved documents
- Measure overhead of streaming vs non-streaming
## Testing Strategy
**Unit tests**:
- Test GraphRag.query() with streaming=True/False
- Test DocumentRag.query() with streaming=True/False
- Mock PromptClient to verify callback invocations
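The mock verification can follow the pattern below: a hypothetical stand-in for `GraphRag.query()`'s prompt call, with `AsyncMock` capturing the forwarded keyword arguments:

```python
import asyncio
from unittest.mock import AsyncMock

async def graph_rag_query(prompt_client, query, kg,
                          streaming=False, chunk_callback=None):
    """Hypothetical stand-in for GraphRag.query()'s prompt call."""
    if streaming and chunk_callback:
        return await prompt_client.kg_prompt(
            query, kg, streaming=True, chunk_callback=chunk_callback)
    return await prompt_client.kg_prompt(query, kg)

async def main():
    prompt_client = AsyncMock()
    prompt_client.kg_prompt.return_value = "full answer"
    cb = AsyncMock()
    resp = await graph_rag_query(prompt_client, "q", "kg-subgraph",
                                 streaming=True, chunk_callback=cb)
    # The mock records how the streaming flag and callback were forwarded
    return resp, prompt_client.kg_prompt.call_args.kwargs

resp, kwargs = asyncio.run(main())
```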
**Integration tests**:
- Test full GraphRAG streaming flow (similar to existing agent streaming tests)
- Test full DocumentRAG streaming flow
- Test Gateway streaming forwarding
- Test CLI streaming output
**Manual testing**:
- `tg-invoke-graph-rag -q "What is machine learning?"` (streaming by default)
- `tg-invoke-document-rag -q "Summarize the documents about AI"` (streaming by default)
- `tg-invoke-graph-rag --no-streaming -q "..."` (test non-streaming mode)
- Verify incremental output appears in streaming mode
## Migration Plan
No migration needed:
- Streaming is opt-in via `streaming` parameter (defaults to False)
- Existing clients continue to work unchanged
- New clients can opt into streaming
## Timeline
Estimated implementation: 4-6 hours
- Phase 1 (2 hours): GraphRAG streaming support
- Phase 2 (2 hours): DocumentRAG streaming support
- Phase 3 (1-2 hours): Gateway updates and CLI flags
- Testing: Built into each phase
## Open Questions
- Should we add streaming support to NLP Query service as well?
- Do we want to stream intermediate steps (e.g., "Retrieving entities...", "Querying graph...") or just LLM output?
- Should GraphRAG/DocumentRAG responses include chunk metadata (e.g., chunk number, total expected)?
## References
- Existing implementation: `docs/tech-specs/streaming-llm-responses.md`
- Agent streaming: `trustgraph-flow/trustgraph/agent/react/agent_manager.py`
- PromptClient streaming: `trustgraph-base/trustgraph/base/prompt_client.py`


@@ -382,6 +382,206 @@ def sample_kg_triples():
]
# Streaming test fixtures
@pytest.fixture
def mock_streaming_llm_response():
"""Mock streaming LLM response with realistic chunks"""
async def _generate_chunks():
"""Generate realistic streaming chunks"""
chunks = [
"Machine",
" learning",
" is",
" a",
" subset",
" of",
" artificial",
" intelligence",
" that",
" focuses",
" on",
" algorithms",
" that",
" learn",
" from",
" data",
"."
]
for chunk in chunks:
yield chunk
return _generate_chunks
@pytest.fixture
def sample_streaming_agent_response():
"""Sample streaming agent response chunks"""
return [
{
"chunk_type": "thought",
"content": "I need to search",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "thought",
"content": " for information",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "thought",
"content": " about machine learning.",
"end_of_message": True,
"end_of_dialog": False
},
{
"chunk_type": "action",
"content": "knowledge_query",
"end_of_message": True,
"end_of_dialog": False
},
{
"chunk_type": "observation",
"content": "Machine learning is",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "observation",
"content": " a subset of AI.",
"end_of_message": True,
"end_of_dialog": False
},
{
"chunk_type": "final-answer",
"content": "Machine learning",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "final-answer",
"content": " is a subset",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "final-answer",
"content": " of artificial intelligence.",
"end_of_message": True,
"end_of_dialog": True
}
]
@pytest.fixture
def streaming_chunk_collector():
"""Helper to collect streaming chunks for assertions"""
class ChunkCollector:
def __init__(self):
self.chunks = []
self.complete = False
async def collect(self, chunk):
"""Async callback to collect chunks"""
self.chunks.append(chunk)
def get_full_text(self):
"""Concatenate all chunk content"""
return "".join(self.chunks)
def get_chunk_types(self):
"""Get list of chunk types if chunks are dicts"""
if self.chunks and isinstance(self.chunks[0], dict):
return [c.get("chunk_type") for c in self.chunks]
return []
return ChunkCollector
@pytest.fixture
def mock_streaming_prompt_response():
"""Mock streaming prompt service response"""
async def _generate_prompt_chunks():
"""Generate streaming chunks for prompt responses"""
chunks = [
"Based on the",
" provided context,",
" here is",
" the answer:",
" Machine learning",
" enables computers",
" to learn",
" from data."
]
for chunk in chunks:
yield chunk
return _generate_prompt_chunks
@pytest.fixture
def sample_rag_streaming_chunks():
"""Sample RAG streaming response chunks"""
return [
{
"chunk": "Based on",
"end_of_stream": False
},
{
"chunk": " the knowledge",
"end_of_stream": False
},
{
"chunk": " graph,",
"end_of_stream": False
},
{
"chunk": " machine learning",
"end_of_stream": False
},
{
"chunk": " is a subset",
"end_of_stream": False
},
{
"chunk": " of AI.",
"end_of_stream": False
},
{
"chunk": None,
"end_of_stream": True,
"response": "Based on the knowledge graph, machine learning is a subset of AI."
}
]
@pytest.fixture
def streaming_error_scenarios():
"""Common error scenarios for streaming tests"""
return {
"connection_drop": {
"exception": ConnectionError,
"message": "Connection lost during streaming",
"chunks_before_error": 5
},
"timeout": {
"exception": TimeoutError,
"message": "Streaming timeout exceeded",
"chunks_before_error": 10
},
"rate_limit": {
"exception": Exception,
"message": "Rate limit exceeded",
"chunks_before_error": 3
},
"invalid_chunk": {
"exception": ValueError,
"message": "Invalid chunk format",
"chunks_before_error": 7
}
}
# Test markers for integration tests
pytestmark = pytest.mark.integration


@@ -0,0 +1,360 @@
"""
Integration tests for Agent Manager Streaming Functionality
These tests verify the streaming behavior of the Agent service, testing
chunk-by-chunk delivery of thoughts, actions, observations, and final answers.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl
from trustgraph.agent.react.types import Tool, Argument
from tests.utils.streaming_assertions import (
assert_agent_streaming_chunks,
assert_streaming_chunks_valid,
assert_callback_invoked,
assert_chunk_types_valid,
)
@pytest.mark.integration
class TestAgentStreaming:
"""Integration tests for Agent streaming functionality"""
@pytest.fixture
def mock_prompt_client_streaming(self):
"""Mock prompt client with streaming support"""
client = AsyncMock()
async def agent_react_streaming(variables, timeout=600, streaming=False, chunk_callback=None):
# Both modes return the same text for equivalence
full_text = """Thought: I need to search for information about machine learning.
Action: knowledge_query
Args: {
"question": "What is machine learning?"
}"""
if streaming and chunk_callback:
# Send realistic line-by-line chunks
# This tests that the parser properly handles "Args:" starting a new chunk
# (which previously caused a bug where action_buffer was overwritten)
chunks = [
"Thought: I need to search for information about machine learning.\n",
"Action: knowledge_query\n",
"Args: {\n", # This used to trigger bug - Args: at start of chunk
' "question": "What is machine learning?"\n',
"}"
]
for chunk in chunks:
await chunk_callback(chunk)
return full_text
else:
# Non-streaming response - same text
return full_text
client.agent_react.side_effect = agent_react_streaming
return client
@pytest.fixture
def mock_flow_context(self, mock_prompt_client_streaming):
"""Mock flow context with streaming prompt client"""
context = MagicMock()
# Mock graph RAG client
graph_rag_client = AsyncMock()
graph_rag_client.rag.return_value = "Machine learning is a subset of AI."
def context_router(service_name):
if service_name == "prompt-request":
return mock_prompt_client_streaming
elif service_name == "graph-rag-request":
return graph_rag_client
else:
return AsyncMock()
context.side_effect = context_router
return context
@pytest.fixture
def sample_tools(self):
"""Sample tool configuration"""
return {
"knowledge_query": Tool(
name="knowledge_query",
description="Query the knowledge graph",
arguments=[
Argument(
name="question",
type="string",
description="The question to ask"
)
],
implementation=KnowledgeQueryImpl,
config={}
)
}
@pytest.fixture
def agent_manager(self, sample_tools):
"""Create AgentManager instance with streaming support"""
return AgentManager(
tools=sample_tools,
additional_context="You are a helpful AI assistant."
)
@pytest.mark.asyncio
async def test_agent_streaming_thought_chunks(self, agent_manager, mock_flow_context):
"""Test that thought chunks are streamed correctly"""
# Arrange
thought_chunks = []
async def think(chunk):
thought_chunks.append(chunk)
# Act
await agent_manager.react(
question="What is machine learning?",
history=[],
think=think,
observe=AsyncMock(),
context=mock_flow_context,
streaming=True
)
# Assert
assert len(thought_chunks) > 0
assert_streaming_chunks_valid(thought_chunks, min_chunks=1)
# Verify thought content makes sense
full_thought = "".join(thought_chunks)
assert "search" in full_thought.lower() or "information" in full_thought.lower()
@pytest.mark.asyncio
async def test_agent_streaming_observation_chunks(self, agent_manager, mock_flow_context):
"""Test that observation chunks are streamed correctly"""
# Arrange
observation_chunks = []
async def observe(chunk):
observation_chunks.append(chunk)
# Act
await agent_manager.react(
question="What is machine learning?",
history=[],
think=AsyncMock(),
observe=observe,
context=mock_flow_context,
streaming=True
)
# Assert
# Note: Observations come from tool execution, which may or may not be streamed
# depending on the tool implementation
# For now, verify callback was set up
assert observe is not None
@pytest.mark.asyncio
async def test_agent_streaming_vs_non_streaming(self, agent_manager, mock_flow_context):
"""Test that streaming and non-streaming produce equivalent results"""
# Arrange
question = "What is machine learning?"
history = []
# Act - Non-streaming
non_streaming_result = await agent_manager.react(
question=question,
history=history,
think=AsyncMock(),
observe=AsyncMock(),
context=mock_flow_context,
streaming=False
)
# Act - Streaming
thought_chunks = []
observation_chunks = []
async def think(chunk):
thought_chunks.append(chunk)
async def observe(chunk):
observation_chunks.append(chunk)
streaming_result = await agent_manager.react(
question=question,
history=history,
think=think,
observe=observe,
context=mock_flow_context,
streaming=True
)
# Assert - Results should be equivalent (or both valid)
assert non_streaming_result is not None
assert streaming_result is not None
@pytest.mark.asyncio
async def test_agent_streaming_callback_invocation(self, agent_manager, mock_flow_context):
"""Test that callbacks are invoked with correct parameters"""
# Arrange
think = AsyncMock()
observe = AsyncMock()
# Act
await agent_manager.react(
question="What is machine learning?",
history=[],
think=think,
observe=observe,
context=mock_flow_context,
streaming=True
)
# Assert - Think callback should be invoked
assert think.call_count > 0
# Verify all callback invocations had string arguments
for call in think.call_args_list:
assert len(call.args) > 0
assert isinstance(call.args[0], str)
@pytest.mark.asyncio
async def test_agent_streaming_without_callbacks(self, agent_manager, mock_flow_context):
"""Test streaming parameter without callbacks (should work gracefully)"""
# Arrange & Act
result = await agent_manager.react(
question="What is machine learning?",
history=[],
think=AsyncMock(),
observe=AsyncMock(),
context=mock_flow_context,
streaming=True # Streaming enabled with mock callbacks
)
# Assert - Should complete without error
assert result is not None
@pytest.mark.asyncio
async def test_agent_streaming_with_conversation_history(self, agent_manager, mock_flow_context):
"""Test streaming with existing conversation history"""
# Arrange
# History should be a list of Action objects
from trustgraph.agent.react.types import Action
history = [
Action(
thought="I need to search for information about machine learning",
name="knowledge_query",
arguments={"question": "What is machine learning?"},
observation="Machine learning is a subset of AI that enables computers to learn from data."
)
]
think = AsyncMock()
# Act
result = await agent_manager.react(
question="Tell me more about neural networks",
history=history,
think=think,
observe=AsyncMock(),
context=mock_flow_context,
streaming=True
)
# Assert
assert result is not None
assert think.call_count > 0
@pytest.mark.asyncio
async def test_agent_streaming_error_propagation(self, agent_manager, mock_flow_context):
"""Test that errors during streaming are properly propagated"""
# Arrange
mock_prompt_client = mock_flow_context("prompt-request")
mock_prompt_client.agent_react.side_effect = Exception("Prompt service error")
think = AsyncMock()
observe = AsyncMock()
# Act & Assert
with pytest.raises(Exception) as exc_info:
await agent_manager.react(
question="test question",
history=[],
think=think,
observe=observe,
context=mock_flow_context,
streaming=True
)
assert "Prompt service error" in str(exc_info.value)
@pytest.mark.asyncio
async def test_agent_streaming_multi_step_reasoning(self, agent_manager, mock_flow_context,
mock_prompt_client_streaming):
"""Test streaming through multi-step reasoning process"""
# Arrange - Mock a multi-step response
step_responses = [
"""Thought: I need to search for basic information.
Action: knowledge_query
Args: {"question": "What is AI?"}""",
"""Thought: Now I can answer the question.
Final Answer: AI is the simulation of human intelligence in machines."""
]
call_count = 0
async def multi_step_agent_react(variables, timeout=600, streaming=False, chunk_callback=None):
nonlocal call_count
response = step_responses[min(call_count, len(step_responses) - 1)]
call_count += 1
if streaming and chunk_callback:
for chunk in response.split():
await chunk_callback(chunk + " ")
return response
return response
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react
think = AsyncMock()
observe = AsyncMock()
# Act
result = await agent_manager.react(
question="What is artificial intelligence?",
history=[],
think=think,
observe=observe,
context=mock_flow_context,
streaming=True
)
# Assert
assert result is not None
assert think.call_count > 0
@pytest.mark.asyncio
async def test_agent_streaming_preserves_tool_config(self, agent_manager, mock_flow_context):
"""Test that streaming preserves tool configuration and context"""
# Arrange
think = AsyncMock()
observe = AsyncMock()
# Act
await agent_manager.react(
question="What is machine learning?",
history=[],
think=think,
observe=observe,
context=mock_flow_context,
streaming=True
)
# Assert - Verify prompt client was called with streaming
mock_prompt_client = mock_flow_context("prompt-request")
call_args = mock_prompt_client.agent_react.call_args
assert call_args.kwargs['streaming'] is True
assert call_args.kwargs['chunk_callback'] is not None


@@ -0,0 +1,274 @@
"""
Integration tests for DocumentRAG streaming functionality
These tests verify the streaming behavior of DocumentRAG, testing token-by-token
response delivery through the complete pipeline.
"""
import pytest
from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
)
@pytest.mark.integration
class TestDocumentRagStreaming:
"""Integration tests for DocumentRAG streaming"""
@pytest.fixture
def mock_embeddings_client(self):
"""Mock embeddings client"""
client = AsyncMock()
client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
return client
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client"""
client = AsyncMock()
client.query.return_value = [
"Machine learning is a subset of AI.",
"Deep learning uses neural networks.",
"Supervised learning needs labeled data."
]
return client
@pytest.fixture
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
"""Mock prompt client with streaming support"""
client = AsyncMock()
async def document_prompt_side_effect(query, documents, timeout=600, streaming=False, chunk_callback=None):
# Both modes return the same text
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
if streaming and chunk_callback:
# Simulate streaming chunks
async for chunk in mock_streaming_llm_response():
await chunk_callback(chunk)
return full_text
else:
# Non-streaming response - same text
return full_text
client.document_prompt.side_effect = document_prompt_side_effect
return client
@pytest.fixture
def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_streaming_prompt_client):
"""Create DocumentRag instance with streaming support"""
return DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_streaming_prompt_client,
verbose=True
)
@pytest.mark.asyncio
async def test_document_rag_streaming_basic(self, document_rag_streaming, streaming_chunk_collector):
"""Test basic DocumentRAG streaming functionality"""
# Arrange
query = "What is machine learning?"
collector = streaming_chunk_collector()
# Act
result = await document_rag_streaming.query(
query=query,
user="test_user",
collection="test_collection",
doc_limit=10,
streaming=True,
chunk_callback=collector.collect
)
# Assert
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
        assert len(collector.chunks) >= 1  # callback was invoked at least once
# Verify full response matches concatenated chunks
full_from_chunks = collector.get_full_text()
assert result == full_from_chunks
# Verify content is reasonable
assert len(result) > 0
@pytest.mark.asyncio
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
"""Test that streaming and non-streaming produce equivalent results"""
# Arrange
query = "What is machine learning?"
user = "test_user"
collection = "test_collection"
doc_limit = 10
# Act - Non-streaming
non_streaming_result = await document_rag_streaming.query(
query=query,
user=user,
collection=collection,
doc_limit=doc_limit,
streaming=False
)
# Act - Streaming
streaming_chunks = []
async def collect(chunk):
streaming_chunks.append(chunk)
streaming_result = await document_rag_streaming.query(
query=query,
user=user,
collection=collection,
doc_limit=doc_limit,
streaming=True,
chunk_callback=collect
)
# Assert - Results should be equivalent
assert streaming_result == non_streaming_result
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_result
@pytest.mark.asyncio
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
"""Test that chunk callback is invoked correctly"""
# Arrange
callback = AsyncMock()
# Act
result = await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=5,
streaming=True,
chunk_callback=callback
)
# Assert
assert callback.call_count > 0
assert result is not None
# Verify all callback invocations had string arguments
for call in callback.call_args_list:
assert isinstance(call.args[0], str)
@pytest.mark.asyncio
async def test_document_rag_streaming_without_callback(self, document_rag_streaming):
"""Test streaming parameter without callback (should fall back to non-streaming)"""
# Arrange & Act
result = await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=5,
streaming=True,
chunk_callback=None # No callback provided
)
# Assert - Should complete without error
assert result is not None
assert isinstance(result, str)
@pytest.mark.asyncio
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
mock_doc_embeddings_client):
"""Test streaming with no documents found"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No documents
callback = AsyncMock()
# Act
result = await document_rag_streaming.query(
query="unknown topic",
user="test_user",
collection="test_collection",
doc_limit=10,
streaming=True,
chunk_callback=callback
)
# Assert - Should still produce streamed response
assert result is not None
assert callback.call_count > 0
@pytest.mark.asyncio
async def test_document_rag_streaming_error_propagation(self, document_rag_streaming,
mock_embeddings_client):
"""Test that errors during streaming are properly propagated"""
# Arrange
mock_embeddings_client.embed.side_effect = Exception("Embeddings error")
callback = AsyncMock()
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=5,
streaming=True,
chunk_callback=callback
)
assert "Embeddings error" in str(exc_info.value)
@pytest.mark.asyncio
async def test_document_rag_streaming_with_different_doc_limits(self, document_rag_streaming,
mock_doc_embeddings_client):
"""Test streaming with various document limits"""
# Arrange
callback = AsyncMock()
doc_limits = [1, 5, 10, 20]
for limit in doc_limits:
# Reset mocks
mock_doc_embeddings_client.reset_mock()
callback.reset_mock()
# Act
result = await document_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
doc_limit=limit,
streaming=True,
chunk_callback=callback
)
# Assert
assert result is not None
assert callback.call_count > 0
# Verify doc_limit was passed correctly
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['limit'] == limit
@pytest.mark.asyncio
async def test_document_rag_streaming_preserves_user_collection(self, document_rag_streaming,
mock_doc_embeddings_client):
"""Test that streaming preserves user/collection isolation"""
# Arrange
callback = AsyncMock()
user = "test_user_123"
collection = "test_collection_456"
# Act
await document_rag_streaming.query(
query="test query",
user=user,
collection=collection,
doc_limit=10,
streaming=True,
chunk_callback=callback
)
# Assert - Verify user/collection were passed to document embeddings client
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection


@@ -0,0 +1,269 @@
"""
Integration tests for GraphRAG retrieval system
These tests verify the end-to-end functionality of the GraphRAG system,
testing the coordination between embeddings, graph retrieval, triple querying, and prompt services.
Following the TEST_STRATEGY.md approach for integration testing.
NOTE: This is the first integration test file for GraphRAG (previously had only unit tests).
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
@pytest.mark.integration
class TestGraphRagIntegration:
"""Integration tests for GraphRAG system coordination"""
@pytest.fixture
def mock_embeddings_client(self):
"""Mock embeddings client that returns realistic vector embeddings"""
client = AsyncMock()
client.embed.return_value = [
[0.1, 0.2, 0.3, 0.4, 0.5], # Realistic 5-dimensional embedding
]
return client
@pytest.fixture
def mock_graph_embeddings_client(self):
"""Mock graph embeddings client that returns realistic entities"""
client = AsyncMock()
client.query.return_value = [
"http://trustgraph.ai/e/machine-learning",
"http://trustgraph.ai/e/artificial-intelligence",
"http://trustgraph.ai/e/neural-networks"
]
return client
@pytest.fixture
def mock_triples_client(self):
"""Mock triples client that returns realistic knowledge graph triples"""
client = AsyncMock()
# Mock different queries return different triples
async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None):
# Mock label queries
if p == "http://www.w3.org/2000/01/rdf-schema#label":
if s == "http://trustgraph.ai/e/machine-learning":
return [MagicMock(s=s, p=p, o="Machine Learning")]
elif s == "http://trustgraph.ai/e/artificial-intelligence":
return [MagicMock(s=s, p=p, o="Artificial Intelligence")]
elif s == "http://trustgraph.ai/e/neural-networks":
return [MagicMock(s=s, p=p, o="Neural Networks")]
return []
# Mock relationship queries
if s == "http://trustgraph.ai/e/machine-learning":
return [
MagicMock(
s="http://trustgraph.ai/e/machine-learning",
p="http://trustgraph.ai/is_subset_of",
o="http://trustgraph.ai/e/artificial-intelligence"
),
MagicMock(
s="http://trustgraph.ai/e/machine-learning",
p="http://www.w3.org/2000/01/rdf-schema#label",
o="Machine Learning"
)
]
return []
client.query.side_effect = query_side_effect
return client
@pytest.fixture
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
client = AsyncMock()
client.kg_prompt.return_value = (
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
)
return client
@pytest.fixture
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_prompt_client):
"""Create GraphRag instance with mocked dependencies"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
prompt_client=mock_prompt_client,
verbose=True
)
@pytest.mark.asyncio
async def test_graph_rag_end_to_end_flow(self, graph_rag, mock_embeddings_client,
mock_graph_embeddings_client, mock_triples_client,
mock_prompt_client):
"""Test complete GraphRAG pipeline from query to response"""
# Arrange
query = "What is machine learning?"
user = "test_user"
collection = "ml_knowledge"
entity_limit = 50
triple_limit = 30
# Act
result = await graph_rag.query(
query=query,
user=user,
collection=collection,
entity_limit=entity_limit,
triple_limit=triple_limit
)
# Assert - Verify service coordination
# 1. Should compute embeddings for query
mock_embeddings_client.embed.assert_called_once_with(query)
# 2. Should query graph embeddings to find relevant entities
mock_graph_embeddings_client.query.assert_called_once()
call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['vectors'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
assert call_args.kwargs['limit'] == entity_limit
assert call_args.kwargs['user'] == user
assert call_args.kwargs['collection'] == collection
# 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query.call_count > 0
# 4. Should call prompt with knowledge graph
mock_prompt_client.kg_prompt.assert_called_once()
call_args = mock_prompt_client.kg_prompt.call_args
assert call_args.args[0] == query # First arg is query
assert isinstance(call_args.args[1], list) # Second arg is kg (list of triples)
# Verify final response
assert result is not None
assert isinstance(result, str)
assert "machine learning" in result.lower()
@pytest.mark.asyncio
async def test_graph_rag_with_different_limits(self, graph_rag, mock_embeddings_client,
mock_graph_embeddings_client):
"""Test GraphRAG with various entity and triple limits"""
# Arrange
query = "Explain neural networks"
test_configs = [
{"entity_limit": 10, "triple_limit": 10},
{"entity_limit": 50, "triple_limit": 30},
{"entity_limit": 100, "triple_limit": 100},
]
for config in test_configs:
# Reset mocks
mock_embeddings_client.reset_mock()
mock_graph_embeddings_client.reset_mock()
# Act
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection",
entity_limit=config["entity_limit"],
triple_limit=config["triple_limit"]
)
# Assert
call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['limit'] == config["entity_limit"]
@pytest.mark.asyncio
async def test_graph_rag_error_propagation(self, graph_rag, mock_embeddings_client):
"""Test that errors from underlying services are properly propagated"""
# Arrange
mock_embeddings_client.embed.side_effect = Exception("Embeddings service error")
# Act & Assert
with pytest.raises(Exception) as exc_info:
await graph_rag.query(
query="test query",
user="test_user",
collection="test_collection"
)
assert "Embeddings service error" in str(exc_info.value)
@pytest.mark.asyncio
async def test_graph_rag_with_empty_knowledge_graph(self, graph_rag, mock_graph_embeddings_client,
mock_triples_client, mock_prompt_client):
"""Test GraphRAG handles empty knowledge graph gracefully"""
# Arrange
mock_graph_embeddings_client.query.return_value = [] # No entities found
mock_triples_client.query.return_value = [] # No triples found
# Act
result = await graph_rag.query(
query="unknown topic",
user="test_user",
collection="test_collection"
)
# Assert
# Should still call prompt client with empty knowledge graph
mock_prompt_client.kg_prompt.assert_called_once()
call_args = mock_prompt_client.kg_prompt.call_args
assert isinstance(call_args.args[1], list) # kg should be a list
assert result is not None
@pytest.mark.asyncio
async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):
"""Test that label lookups are cached to reduce redundant queries"""
# Arrange
query = "What is machine learning?"
# First query
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection"
)
first_call_count = mock_triples_client.query.call_count
mock_triples_client.reset_mock()
# Second identical query
await graph_rag.query(
query=query,
user="test_user",
collection="test_collection"
)
second_call_count = mock_triples_client.query.call_count
# Assert - Second query should not issue more triple queries than the first;
# with label caching it may issue fewer. This is a deliberately loose bound
# because caching behavior depends on implementation details.
assert second_call_count <= first_call_count
@pytest.mark.asyncio
async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client):
"""Test that different users/collections are properly isolated"""
# Arrange
query = "test query"
user1, collection1 = "user1", "collection1"
user2, collection2 = "user2", "collection2"
# Act
await graph_rag.query(query=query, user=user1, collection=collection1)
await graph_rag.query(query=query, user=user2, collection=collection2)
# Assert - Both users should have separate queries
assert mock_graph_embeddings_client.query.call_count == 2
# Verify first call
first_call = mock_graph_embeddings_client.query.call_args_list[0]
assert first_call.kwargs['user'] == user1
assert first_call.kwargs['collection'] == collection1
# Verify second call
second_call = mock_graph_embeddings_client.query.call_args_list[1]
assert second_call.kwargs['user'] == user2
assert second_call.kwargs['collection'] == collection2


@@ -0,0 +1,249 @@
"""
Integration tests for GraphRAG streaming functionality
These tests verify the streaming behavior of GraphRAG, testing token-by-token
response delivery through the complete pipeline.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_rag_streaming_chunks,
assert_streaming_content_matches,
assert_callback_invoked,
)
@pytest.mark.integration
class TestGraphRagStreaming:
"""Integration tests for GraphRAG streaming"""
@pytest.fixture
def mock_embeddings_client(self):
"""Mock embeddings client"""
client = AsyncMock()
client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
return client
@pytest.fixture
def mock_graph_embeddings_client(self):
"""Mock graph embeddings client"""
client = AsyncMock()
client.query.return_value = [
"http://trustgraph.ai/e/machine-learning",
]
return client
@pytest.fixture
def mock_triples_client(self):
"""Mock triples client with minimal responses"""
client = AsyncMock()
async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None):
if p == "http://www.w3.org/2000/01/rdf-schema#label":
return [MagicMock(s=s, p=p, o="Machine Learning")]
return []
client.query.side_effect = query_side_effect
return client
@pytest.fixture
def mock_streaming_prompt_client(self, mock_streaming_llm_response):
"""Mock prompt client with streaming support"""
client = AsyncMock()
async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
# Both modes return the same text
full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."
if streaming and chunk_callback:
# Simulate streaming chunks
async for chunk in mock_streaming_llm_response():
await chunk_callback(chunk)
return full_text
else:
# Non-streaming response - same text
return full_text
client.kg_prompt.side_effect = kg_prompt_side_effect
return client
@pytest.fixture
def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_streaming_prompt_client):
"""Create GraphRag instance with streaming support"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
prompt_client=mock_streaming_prompt_client,
verbose=True
)
@pytest.mark.asyncio
async def test_graph_rag_streaming_basic(self, graph_rag_streaming, streaming_chunk_collector):
"""Test basic GraphRAG streaming functionality"""
# Arrange
query = "What is machine learning?"
collector = streaming_chunk_collector()
# Act
result = await graph_rag_streaming.query(
query=query,
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=collector.collect
)
# Assert
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
# Verify full response matches concatenated chunks
full_from_chunks = collector.get_full_text()
assert result == full_from_chunks
# Verify content is reasonable
assert "machine" in result.lower() or "learning" in result.lower()
@pytest.mark.asyncio
async def test_graph_rag_streaming_vs_non_streaming(self, graph_rag_streaming):
"""Test that streaming and non-streaming produce equivalent results"""
# Arrange
query = "What is machine learning?"
user = "test_user"
collection = "test_collection"
# Act - Non-streaming
non_streaming_result = await graph_rag_streaming.query(
query=query,
user=user,
collection=collection,
streaming=False
)
# Act - Streaming
streaming_chunks = []
async def collect(chunk):
streaming_chunks.append(chunk)
streaming_result = await graph_rag_streaming.query(
query=query,
user=user,
collection=collection,
streaming=True,
chunk_callback=collect
)
# Assert - Results should be equivalent
assert streaming_result == non_streaming_result
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_result
@pytest.mark.asyncio
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
"""Test that chunk callback is invoked correctly"""
# Arrange
callback = AsyncMock()
# Act
result = await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
)
# Assert
assert callback.call_count > 0
assert result is not None
# Verify all callback invocations had string arguments
for call in callback.call_args_list:
assert isinstance(call.args[0], str)
@pytest.mark.asyncio
async def test_graph_rag_streaming_without_callback(self, graph_rag_streaming):
"""Test streaming parameter without callback (should fall back to non-streaming)"""
# Arrange & Act
result = await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=None # No callback provided
)
# Assert - Should complete without error
assert result is not None
assert isinstance(result, str)
@pytest.mark.asyncio
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
mock_graph_embeddings_client):
"""Test streaming with empty knowledge graph"""
# Arrange
mock_graph_embeddings_client.query.return_value = [] # No entities
callback = AsyncMock()
# Act
result = await graph_rag_streaming.query(
query="unknown topic",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
)
# Assert - Should still produce streamed response
assert result is not None
assert callback.call_count > 0
@pytest.mark.asyncio
async def test_graph_rag_streaming_error_propagation(self, graph_rag_streaming,
mock_embeddings_client):
"""Test that errors during streaming are properly propagated"""
# Arrange
mock_embeddings_client.embed.side_effect = Exception("Embeddings error")
callback = AsyncMock()
# Act & Assert
with pytest.raises(Exception) as exc_info:
await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
streaming=True,
chunk_callback=callback
)
assert "Embeddings error" in str(exc_info.value)
@pytest.mark.asyncio
async def test_graph_rag_streaming_preserves_parameters(self, graph_rag_streaming,
mock_graph_embeddings_client):
"""Test that streaming preserves all query parameters"""
# Arrange
callback = AsyncMock()
entity_limit = 25
triple_limit = 15
# Act
await graph_rag_streaming.query(
query="test query",
user="test_user",
collection="test_collection",
entity_limit=entity_limit,
triple_limit=triple_limit,
streaming=True,
chunk_callback=callback
)
# Assert - Verify parameters were passed to underlying services
call_args = mock_graph_embeddings_client.query.call_args
assert call_args.kwargs['limit'] == entity_limit
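The streaming tests above all exercise the same callback-driven contract: the caller passes `streaming=True` plus an async `chunk_callback`, each chunk is delivered as it arrives, and the full return value equals the concatenated chunks. A minimal, self-contained sketch of that contract (hypothetical names, not the actual TrustGraph implementation) might look like:

```python
import asyncio

async def stream_answer(generate_chunks, streaming=False, chunk_callback=None):
    # Accumulate chunks; deliver each via the callback when streaming is on.
    # With streaming=False (or no callback) this degrades to a blocking call.
    parts = []
    async for chunk in generate_chunks():
        parts.append(chunk)
        if streaming and chunk_callback:
            await chunk_callback(chunk)
    return "".join(parts)

async def fake_llm():
    # Stand-in for a token stream from the prompt/LLM service.
    for token in ["Machine", " learning", " is", " fun", "."]:
        yield token

async def main():
    received = []

    async def collect(chunk):
        received.append(chunk)

    full = await stream_answer(fake_llm, streaming=True, chunk_callback=collect)
    # The invariant the tests assert: chunks concatenate to the full response.
    assert full == "".join(received)
    return full

result = asyncio.run(main())
print(result)  # Machine learning is fun.
```

This is why `test_graph_rag_streaming_vs_non_streaming` can compare the two modes for equality: both paths return the same accumulated string.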


@@ -0,0 +1,404 @@
"""
Integration tests for Prompt Service Streaming Functionality
These tests verify the streaming behavior of the Prompt service,
testing how it coordinates between templates and text completion streaming.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.prompt.template.service import Processor
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
)
@pytest.mark.integration
class TestPromptStreaming:
"""Integration tests for Prompt service streaming"""
@pytest.fixture
def mock_flow_context_streaming(self):
"""Mock flow context with streaming text completion support"""
context = MagicMock()
# Mock text completion client with streaming
text_completion_client = AsyncMock()
async def streaming_request(request, recipient=None, timeout=600):
"""Simulate streaming text completion"""
if request.streaming and recipient:
# Simulate streaming chunks
chunks = [
"Machine", " learning", " is", " a", " field",
" of", " artificial", " intelligence", "."
]
for i, chunk_text in enumerate(chunks):
is_final = (i == len(chunks) - 1)
response = TextCompletionResponse(
response=chunk_text,
error=None,
end_of_stream=is_final
)
final = await recipient(response)
if final:
break
# Final empty chunk
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
text_completion_client.request = streaming_request
# Mock response producer
response_producer = AsyncMock()
def context_router(service_name):
if service_name == "text-completion-request":
return text_completion_client
elif service_name == "response":
return response_producer
else:
return AsyncMock()
context.side_effect = context_router
return context
@pytest.fixture
def mock_prompt_manager(self):
"""Mock PromptManager with simple template"""
manager = MagicMock()
async def invoke_template(kind, input_vars, llm_function):
"""Simulate template invocation"""
# Call the LLM function with simple prompts
system = "You are a helpful assistant."
prompt = f"Question: {input_vars.get('question', 'test')}"
result = await llm_function(system, prompt)
return result
manager.invoke = invoke_template
return manager
@pytest.fixture
def prompt_processor_streaming(self, mock_prompt_manager):
"""Create Prompt processor with streaming support"""
processor = MagicMock()
processor.manager = mock_prompt_manager
processor.config_key = "prompt"
# Bind the actual on_request method
processor.on_request = Processor.on_request.__get__(processor, Processor)
return processor
@pytest.mark.asyncio
async def test_prompt_streaming_basic(self, prompt_processor_streaming, mock_flow_context_streaming):
"""Test basic prompt streaming functionality"""
# Arrange
request = PromptRequest(
id="kg_prompt",
terms={"question": '"What is machine learning?"'},
streaming=True
)
message = MagicMock()
message.value.return_value = request
message.properties.return_value = {"id": "test-123"}
consumer = MagicMock()
# Act
await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming
)
# Assert
# Verify response producer was called multiple times (for streaming chunks)
response_producer = mock_flow_context_streaming("response")
assert response_producer.send.call_count > 0
# Verify streaming chunks were sent
calls = response_producer.send.call_args_list
assert len(calls) > 1 # Should have multiple chunks
# Check that responses have end_of_stream flag
for call in calls:
response = call.args[0]
assert isinstance(response, PromptResponse)
assert hasattr(response, 'end_of_stream')
# Last response should have end_of_stream=True
last_call = calls[-1]
last_response = last_call.args[0]
assert last_response.end_of_stream is True
@pytest.mark.asyncio
async def test_prompt_streaming_non_streaming_mode(self, prompt_processor_streaming,
mock_flow_context_streaming):
"""Test prompt service in non-streaming mode"""
# Arrange
request = PromptRequest(
id="kg_prompt",
terms={"question": '"What is AI?"'},
streaming=False # Non-streaming
)
message = MagicMock()
message.value.return_value = request
message.properties.return_value = {"id": "test-456"}
consumer = MagicMock()
# Mock non-streaming text completion
text_completion_client = mock_flow_context_streaming("text-completion-request")
async def non_streaming_text_completion(system, prompt, streaming=False):
return "AI is the simulation of human intelligence in machines."
text_completion_client.text_completion = non_streaming_text_completion
# Act
await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming
)
# Assert
# Verify response producer was called once (non-streaming)
response_producer = mock_flow_context_streaming("response")
# Note: In non-streaming mode, the service sends a single response
assert response_producer.send.call_count >= 1
@pytest.mark.asyncio
async def test_prompt_streaming_chunk_forwarding(self, prompt_processor_streaming,
mock_flow_context_streaming):
"""Test that prompt service forwards chunks immediately"""
# Arrange
request = PromptRequest(
id="test_prompt",
terms={"question": '"Test query"'},
streaming=True
)
message = MagicMock()
message.value.return_value = request
message.properties.return_value = {"id": "test-789"}
consumer = MagicMock()
# Act
await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming
)
# Assert
# Verify chunks were forwarded with proper structure
response_producer = mock_flow_context_streaming("response")
calls = response_producer.send.call_args_list
for call in calls:
response = call.args[0]
# Each response should have text and end_of_stream fields
assert hasattr(response, 'text')
assert hasattr(response, 'end_of_stream')
@pytest.mark.asyncio
async def test_prompt_streaming_error_handling(self, prompt_processor_streaming):
"""Test error handling during streaming"""
# Arrange
from trustgraph.schema import Error
context = MagicMock()
# Mock text completion client that raises an error
text_completion_client = AsyncMock()
async def failing_request(request, recipient=None, timeout=600):
if recipient:
# Send error response with proper Error schema
error_response = TextCompletionResponse(
response="",
error=Error(message="Text completion error", type="processing_error"),
end_of_stream=True
)
await recipient(error_response)
text_completion_client.request = failing_request
# Mock response producer to capture error response
response_producer = AsyncMock()
def context_router(service_name):
if service_name == "text-completion-request":
return text_completion_client
elif service_name == "response":
return response_producer
else:
return AsyncMock()
context.side_effect = context_router
request = PromptRequest(
id="test_prompt",
terms={"question": '"Test"'},
streaming=True
)
message = MagicMock()
message.value.return_value = request
message.properties.return_value = {"id": "test-error"}
consumer = MagicMock()
# Act - The service catches errors and sends error responses, doesn't raise
await prompt_processor_streaming.on_request(message, consumer, context)
# Assert - Verify error response was sent
assert response_producer.send.call_count > 0
# Check that at least one response contains an error
error_sent = False
for call in response_producer.send.call_args_list:
response = call.args[0]
if hasattr(response, 'error') and response.error:
error_sent = True
assert "Text completion error" in response.error.message
break
assert error_sent, "Expected error response to be sent"
@pytest.mark.asyncio
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
mock_flow_context_streaming):
"""Test that message IDs are preserved through streaming"""
# Arrange
message_id = "unique-test-id-12345"
request = PromptRequest(
id="test_prompt",
terms={"question": '"Test"'},
streaming=True
)
message = MagicMock()
message.value.return_value = request
message.properties.return_value = {"id": message_id}
consumer = MagicMock()
# Act
await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming
)
# Assert
# Verify all responses were sent with the correct message ID
response_producer = mock_flow_context_streaming("response")
calls = response_producer.send.call_args_list
for call in calls:
properties = call.kwargs.get('properties')
assert properties is not None
assert properties['id'] == message_id
@pytest.mark.asyncio
async def test_prompt_streaming_empty_response_handling(self, prompt_processor_streaming):
"""Test handling of empty responses during streaming"""
# Arrange
context = MagicMock()
# Mock text completion that sends empty chunks
text_completion_client = AsyncMock()
async def empty_streaming_request(request, recipient=None, timeout=600):
if request.streaming and recipient:
# Send empty chunk followed by final marker
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=False
))
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
text_completion_client.request = empty_streaming_request
response_producer = AsyncMock()
def context_router(service_name):
if service_name == "text-completion-request":
return text_completion_client
elif service_name == "response":
return response_producer
else:
return AsyncMock()
context.side_effect = context_router
request = PromptRequest(
id="test_prompt",
terms={"question": '"Test"'},
streaming=True
)
message = MagicMock()
message.value.return_value = request
message.properties.return_value = {"id": "test-empty"}
consumer = MagicMock()
# Act
await prompt_processor_streaming.on_request(message, consumer, context)
# Assert
# Should still send responses even if empty (including final marker)
assert response_producer.send.call_count > 0
# Last response should have end_of_stream=True
last_call = response_producer.send.call_args_list[-1]
last_response = last_call.args[0]
assert last_response.end_of_stream is True
@pytest.mark.asyncio
async def test_prompt_streaming_concatenation_matches_complete(self, prompt_processor_streaming,
mock_flow_context_streaming):
"""Test that streaming chunks concatenate to form complete response"""
# Arrange
request = PromptRequest(
id="test_prompt",
terms={"question": '"What is ML?"'},
streaming=True
)
message = MagicMock()
message.value.return_value = request
message.properties.return_value = {"id": "test-concat"}
consumer = MagicMock()
# Act
await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming
)
# Assert
# Collect all response texts
response_producer = mock_flow_context_streaming("response")
calls = response_producer.send.call_args_list
chunk_texts = []
for call in calls:
response = call.args[0]
if response.text and not response.end_of_stream:
chunk_texts.append(response.text)
# Verify chunks concatenate to expected result
full_text = "".join(chunk_texts)
assert full_text == "Machine learning is a field of artificial intelligence"


@@ -0,0 +1,366 @@
"""
Integration tests for Text Completion Streaming Functionality
These tests verify the streaming behavior of the Text Completion service,
testing token-by-token response delivery through the complete pipeline.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as StreamChoice, ChoiceDelta
from trustgraph.model.text_completion.openai.llm import Processor
from trustgraph.base import LlmChunk
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
)
@pytest.mark.integration
class TestTextCompletionStreaming:
"""Integration tests for Text Completion streaming"""
@pytest.fixture
def mock_streaming_openai_client(self, mock_streaming_llm_response):
"""Mock OpenAI client with streaming support"""
client = MagicMock()
def create_streaming_completion(**kwargs):
"""Generator that yields streaming chunks"""
# Check if streaming is enabled
if not kwargs.get('stream', False):
raise ValueError("Expected streaming mode")
# Simulate OpenAI streaming response
chunks_text = [
"Machine", " learning", " is", " a", " subset",
" of", " AI", " that", " enables", " computers",
" to", " learn", " from", " data", "."
]
for text in chunks_text:
delta = ChoiceDelta(content=text, role=None)
choice = StreamChoice(index=0, delta=delta, finish_reason=None)
chunk = ChatCompletionChunk(
id="chatcmpl-streaming",
choices=[choice],
created=1234567890,
model="gpt-3.5-turbo",
object="chat.completion.chunk"
)
yield chunk
# Return a new generator each time create is called
client.chat.completions.create.side_effect = lambda **kwargs: create_streaming_completion(**kwargs)
return client
@pytest.fixture
def text_completion_processor_streaming(self, mock_streaming_openai_client):
"""Create text completion processor with streaming support"""
processor = MagicMock()
processor.default_model = "gpt-3.5-turbo"
processor.temperature = 0.7
processor.max_output = 1024
processor.openai = mock_streaming_openai_client
# Bind the actual streaming method
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
return processor
@pytest.mark.asyncio
async def test_text_completion_streaming_basic(self, text_completion_processor_streaming,
streaming_chunk_collector):
"""Test basic text completion streaming functionality"""
# Arrange
system_prompt = "You are a helpful assistant."
user_prompt = "What is machine learning?"
collector = streaming_chunk_collector()
# Act - Collect all chunks
chunks = []
async for chunk in text_completion_processor_streaming.generate_content_stream(
system_prompt, user_prompt
):
chunks.append(chunk)
if chunk.text:
await collector.collect(chunk.text)
# Assert
assert len(chunks) > 1 # Should have multiple chunks
# Verify all chunks are LlmChunk objects
for chunk in chunks:
assert isinstance(chunk, LlmChunk)
assert chunk.model == "gpt-3.5-turbo"
# Verify last chunk has is_final=True
assert chunks[-1].is_final is True
# Verify we got meaningful content
full_text = collector.get_full_text()
assert "machine" in full_text.lower() or "learning" in full_text.lower()
@pytest.mark.asyncio
async def test_text_completion_streaming_chunk_structure(self, text_completion_processor_streaming):
"""Test that streaming chunks have correct structure"""
# Arrange
system_prompt = "You are a helpful assistant."
user_prompt = "Explain AI."
# Act
chunks = []
async for chunk in text_completion_processor_streaming.generate_content_stream(
system_prompt, user_prompt
):
chunks.append(chunk)
# Assert - Verify chunk structure
for i, chunk in enumerate(chunks[:-1]): # All except last
assert isinstance(chunk, LlmChunk)
assert chunk.text is not None
assert chunk.model == "gpt-3.5-turbo"
assert chunk.is_final is False
# Last chunk should be final marker
final_chunk = chunks[-1]
assert final_chunk.is_final is True
assert final_chunk.model == "gpt-3.5-turbo"
@pytest.mark.asyncio
async def test_text_completion_streaming_concatenation(self, text_completion_processor_streaming):
"""Test that chunks concatenate to form complete response"""
# Arrange
system_prompt = "You are a helpful assistant."
user_prompt = "What is AI?"
# Act - Collect all chunk texts
chunk_texts = []
async for chunk in text_completion_processor_streaming.generate_content_stream(
system_prompt, user_prompt
):
if chunk.text and not chunk.is_final:
chunk_texts.append(chunk.text)
# Assert
full_text = "".join(chunk_texts)
assert len(full_text) > 0
assert len(chunk_texts) > 1 # Should have multiple chunks
# Verify completeness - should be a coherent sentence
assert full_text == "Machine learning is a subset of AI that enables computers to learn from data."
@pytest.mark.asyncio
async def test_text_completion_streaming_final_marker(self, text_completion_processor_streaming):
"""Test that final chunk properly marks end of stream"""
# Arrange
system_prompt = "You are a helpful assistant."
user_prompt = "Test query"
# Act
chunks = []
async for chunk in text_completion_processor_streaming.generate_content_stream(
system_prompt, user_prompt
):
chunks.append(chunk)
# Assert
# Should have at least content chunks + final marker
assert len(chunks) >= 2
# Only the last chunk should have is_final=True
for chunk in chunks[:-1]:
assert chunk.is_final is False
assert chunks[-1].is_final is True
@pytest.mark.asyncio
async def test_text_completion_streaming_model_parameter(self, mock_streaming_openai_client):
"""Test that model parameter is preserved in streaming"""
# Arrange
processor = MagicMock()
processor.default_model = "gpt-4"
processor.temperature = 0.5
processor.max_output = 2048
processor.openai = mock_streaming_openai_client
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
# Act
chunks = []
async for chunk in processor.generate_content_stream("System", "Prompt"):
chunks.append(chunk)
# Assert
# Verify OpenAI was called with correct model
call_args = mock_streaming_openai_client.chat.completions.create.call_args
assert call_args.kwargs['model'] == "gpt-4"
assert call_args.kwargs['temperature'] == 0.5
assert call_args.kwargs['max_tokens'] == 2048
assert call_args.kwargs['stream'] is True
# Verify chunks have correct model
for chunk in chunks:
assert chunk.model == "gpt-4"
@pytest.mark.asyncio
async def test_text_completion_streaming_temperature_parameter(self, mock_streaming_openai_client):
"""Test that temperature parameter is applied in streaming"""
# Arrange
temperatures = [0.0, 0.5, 1.0, 1.5]
for temp in temperatures:
processor = MagicMock()
processor.default_model = "gpt-3.5-turbo"
processor.temperature = temp
processor.max_output = 1024
processor.openai = mock_streaming_openai_client
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
# Act
chunks = []
async for chunk in processor.generate_content_stream("System", "Prompt"):
chunks.append(chunk)
if chunk.is_final:
break
# Assert
call_args = mock_streaming_openai_client.chat.completions.create.call_args
assert call_args.kwargs['temperature'] == temp
# Reset mock for next iteration
mock_streaming_openai_client.reset_mock()
@pytest.mark.asyncio
async def test_text_completion_streaming_error_propagation(self):
"""Test that errors during streaming are properly propagated"""
# Arrange
mock_client = MagicMock()
def failing_stream(**kwargs):
yield from []
raise Exception("Streaming error")
mock_client.chat.completions.create.return_value = failing_stream()
processor = MagicMock()
processor.default_model = "gpt-3.5-turbo"
processor.temperature = 0.7
processor.max_output = 1024
processor.openai = mock_client
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
async for chunk in processor.generate_content_stream("System", "Prompt"):
pass
assert "Streaming error" in str(exc_info.value)
@pytest.mark.asyncio
async def test_text_completion_streaming_empty_chunks_filtered(self, mock_streaming_openai_client):
"""Test that empty chunks are handled correctly"""
# Arrange - Mock that returns some empty chunks
def create_streaming_with_empties(**kwargs):
chunks_text = ["Hello", "", " world", "", "!"]
for text in chunks_text:
delta = ChoiceDelta(content=text if text else None, role=None)
choice = StreamChoice(index=0, delta=delta, finish_reason=None)
chunk = ChatCompletionChunk(
id="chatcmpl-streaming",
choices=[choice],
created=1234567890,
model="gpt-3.5-turbo",
object="chat.completion.chunk"
)
yield chunk
mock_streaming_openai_client.chat.completions.create.side_effect = lambda **kwargs: create_streaming_with_empties(**kwargs)
processor = MagicMock()
processor.default_model = "gpt-3.5-turbo"
processor.temperature = 0.7
processor.max_output = 1024
processor.openai = mock_streaming_openai_client
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
# Act
chunks = []
async for chunk in processor.generate_content_stream("System", "Prompt"):
chunks.append(chunk)
# Assert - Only non-empty chunks should be yielded (plus final marker)
text_chunks = [c for c in chunks if not c.is_final]
assert len(text_chunks) == 3 # "Hello", " world", "!"
assert "".join(c.text for c in text_chunks) == "Hello world!"
@pytest.mark.asyncio
async def test_text_completion_streaming_prompt_construction(self, mock_streaming_openai_client):
"""Test that system and user prompts are correctly combined for streaming"""
# Arrange
processor = MagicMock()
processor.default_model = "gpt-3.5-turbo"
processor.temperature = 0.7
processor.max_output = 1024
processor.openai = mock_streaming_openai_client
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
system_prompt = "You are an expert."
user_prompt = "Explain quantum physics."
# Act
chunks = []
async for chunk in processor.generate_content_stream(system_prompt, user_prompt):
chunks.append(chunk)
if chunk.is_final:
break
# Assert - Verify prompts were combined correctly
call_args = mock_streaming_openai_client.chat.completions.create.call_args
messages = call_args.kwargs['messages']
assert len(messages) == 1
message_content = messages[0]['content'][0]['text']
assert system_prompt in message_content
assert user_prompt in message_content
assert message_content.startswith(system_prompt)
@pytest.mark.asyncio
async def test_text_completion_streaming_chunk_count(self, text_completion_processor_streaming):
"""Test that streaming produces expected number of chunks"""
# Arrange
system_prompt = "You are a helpful assistant."
user_prompt = "Test"
# Act
chunks = []
async for chunk in text_completion_processor_streaming.generate_content_stream(
system_prompt, user_prompt
):
chunks.append(chunk)
# Assert
# Should have 15 content chunks + 1 final marker = 16 total
assert len(chunks) == 16
# 15 content chunks
content_chunks = [c for c in chunks if not c.is_final]
assert len(content_chunks) == 15
# 1 final marker
final_chunks = [c for c in chunks if c.is_final]
assert len(final_chunks) == 1

tests/utils/__init__.py Normal file

@ -0,0 +1,29 @@
"""Test utilities for TrustGraph tests"""
from .streaming_assertions import (
assert_streaming_chunks_valid,
assert_streaming_sequence,
assert_agent_streaming_chunks,
assert_rag_streaming_chunks,
assert_streaming_completion,
assert_streaming_content_matches,
assert_no_empty_chunks,
assert_streaming_error_handled,
assert_chunk_types_valid,
assert_streaming_latency_acceptable,
assert_callback_invoked,
)
__all__ = [
"assert_streaming_chunks_valid",
"assert_streaming_sequence",
"assert_agent_streaming_chunks",
"assert_rag_streaming_chunks",
"assert_streaming_completion",
"assert_streaming_content_matches",
"assert_no_empty_chunks",
"assert_streaming_error_handled",
"assert_chunk_types_valid",
"assert_streaming_latency_acceptable",
"assert_callback_invoked",
]


@ -0,0 +1,218 @@
"""
Streaming test assertion helpers
Provides reusable assertion functions for validating streaming behavior
across different TrustGraph services.
"""
from typing import List, Dict, Any, Optional
def assert_streaming_chunks_valid(chunks: List[Any], min_chunks: int = 1):
"""
Assert that streaming chunks are valid and non-empty.
Args:
chunks: List of streaming chunks
min_chunks: Minimum number of expected chunks
"""
assert len(chunks) >= min_chunks, f"Expected at least {min_chunks} chunks, got {len(chunks)}"
assert all(chunk is not None for chunk in chunks), "All chunks should be non-None"
def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "chunk_type"):
"""
Assert that streaming chunks follow an expected sequence.
Args:
chunks: List of chunk dictionaries
expected_sequence: Expected sequence of chunk types/values
key: Dictionary key to check (default: "chunk_type")
"""
actual_sequence = [chunk.get(key) for chunk in chunks if key in chunk]
assert actual_sequence == expected_sequence, \
f"Expected sequence {expected_sequence}, got {actual_sequence}"
def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
"""
Assert that agent streaming chunks have valid structure.
Validates:
- All chunks have chunk_type field
- All chunks have content field
- All chunks have end_of_message field
- All chunks have end_of_dialog field
- Last chunk has end_of_dialog=True
Args:
chunks: List of agent streaming chunk dictionaries
"""
assert len(chunks) > 0, "Expected at least one chunk"
for i, chunk in enumerate(chunks):
assert "chunk_type" in chunk, f"Chunk {i} missing chunk_type"
assert "content" in chunk, f"Chunk {i} missing content"
assert "end_of_message" in chunk, f"Chunk {i} missing end_of_message"
assert "end_of_dialog" in chunk, f"Chunk {i} missing end_of_dialog"
# Validate chunk_type values
valid_types = ["thought", "action", "observation", "final-answer"]
assert chunk["chunk_type"] in valid_types, \
f"Invalid chunk_type '{chunk['chunk_type']}' at index {i}"
# Last chunk should signal end of dialog
assert chunks[-1]["end_of_dialog"] is True, \
"Last chunk should have end_of_dialog=True"
def assert_rag_streaming_chunks(chunks: List[Dict[str, Any]]):
"""
Assert that RAG streaming chunks have valid structure.
Validates:
- All chunks except last have chunk field
- All chunks have end_of_stream field
- Last chunk has end_of_stream=True
- Last chunk may have response field with complete text
Args:
chunks: List of RAG streaming chunk dictionaries
"""
assert len(chunks) > 0, "Expected at least one chunk"
for i, chunk in enumerate(chunks):
assert "end_of_stream" in chunk, f"Chunk {i} missing end_of_stream"
if i < len(chunks) - 1:
# Non-final chunks should have chunk content and end_of_stream=False
assert "chunk" in chunk, f"Chunk {i} missing chunk field"
assert chunk["end_of_stream"] is False, \
f"Non-final chunk {i} should have end_of_stream=False"
else:
# Final chunk should have end_of_stream=True
assert chunk["end_of_stream"] is True, \
"Last chunk should have end_of_stream=True"
def assert_streaming_completion(chunks: List[Dict[str, Any]], expected_complete_flag: str = "end_of_stream"):
"""
Assert that streaming completed properly.
Args:
chunks: List of streaming chunk dictionaries
expected_complete_flag: Name of the completion flag field
"""
assert len(chunks) > 0, "Expected at least one chunk"
# Check that all but last chunk have completion flag = False
for i, chunk in enumerate(chunks[:-1]):
assert chunk.get(expected_complete_flag) is False, \
f"Non-final chunk {i} should have {expected_complete_flag}=False"
# Check that last chunk has completion flag = True
assert chunks[-1].get(expected_complete_flag) is True, \
f"Final chunk should have {expected_complete_flag}=True"
def assert_streaming_content_matches(chunks: List, expected_content: str, content_key: str = "chunk"):
"""
Assert that concatenated streaming chunks match expected content.
Args:
chunks: List of streaming chunks (strings or dicts)
expected_content: Expected complete content after concatenation
content_key: Dictionary key for content (used if chunks are dicts)
"""
if isinstance(chunks[0], dict):
# Extract content from chunk dictionaries
content_chunks = [
chunk.get(content_key, "")
for chunk in chunks
if chunk.get(content_key) is not None
]
actual_content = "".join(content_chunks)
else:
# Chunks are already strings
actual_content = "".join(chunks)
assert actual_content == expected_content, \
f"Expected content '{expected_content}', got '{actual_content}'"
def assert_no_empty_chunks(chunks: List[Dict[str, Any]], content_key: str = "content"):
"""
Assert that no chunks have empty content (except final chunk if it's completion marker).
Args:
chunks: List of streaming chunk dictionaries
content_key: Dictionary key for content
"""
for i, chunk in enumerate(chunks[:-1]):
content = chunk.get(content_key)
assert content is not None and len(content) > 0, \
f"Chunk {i} has empty content"
def assert_streaming_error_handled(chunks: List[Dict[str, Any]], error_flag: str = "error"):
"""
Assert that streaming error was properly signaled.
Args:
chunks: List of streaming chunk dictionaries
error_flag: Name of the error flag field
"""
# Check that at least one chunk has error flag
has_error = any(chunk.get(error_flag) is not None for chunk in chunks)
assert has_error, "Expected error flag in at least one chunk"
# If last chunk has error, should also have completion flag
if chunks[-1].get(error_flag):
# Check for completion flags (either end_of_stream or end_of_dialog)
completion_flags = ["end_of_stream", "end_of_dialog"]
has_completion = any(chunks[-1].get(flag) is True for flag in completion_flags)
assert has_completion, \
"Error chunk should have completion flag set to True"
def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "chunk_type"):
"""
Assert that all chunk types are from a valid set.
Args:
chunks: List of streaming chunk dictionaries
valid_types: List of valid chunk type values
type_key: Dictionary key for chunk type
"""
for i, chunk in enumerate(chunks):
chunk_type = chunk.get(type_key)
assert chunk_type in valid_types, \
f"Chunk {i} has invalid type '{chunk_type}', expected one of {valid_types}"
def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0):
"""
Assert that streaming latency between chunks is acceptable.
Args:
chunk_timestamps: List of timestamps when chunks were received
max_gap_seconds: Maximum acceptable gap between chunks in seconds
"""
assert len(chunk_timestamps) > 1, "Need at least 2 timestamps to check latency"
for i in range(1, len(chunk_timestamps)):
gap = chunk_timestamps[i] - chunk_timestamps[i-1]
assert gap <= max_gap_seconds, \
f"Gap between chunks {i-1} and {i} is {gap:.2f}s, exceeds max {max_gap_seconds}s"
def assert_callback_invoked(mock_callback, min_calls: int = 1):
"""
Assert that a streaming callback was invoked minimum number of times.
Args:
mock_callback: AsyncMock callback object
min_calls: Minimum number of expected calls
"""
assert mock_callback.call_count >= min_calls, \
f"Expected callback to be called at least {min_calls} times, was called {mock_callback.call_count} times"


@ -112,7 +112,7 @@ class PromptClient(RequestResponse):
timeout = timeout,
)
async def kg_prompt(self, query, kg, timeout=600):
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "kg-prompt",
variables = {
@ -123,9 +123,11 @@ class PromptClient(RequestResponse):
]
},
timeout = timeout,
streaming = streaming,
chunk_callback = chunk_callback,
)
async def document_prompt(self, query, documents, timeout=600):
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "document-prompt",
variables = {
@ -133,6 +135,8 @@ class PromptClient(RequestResponse):
"documents": documents,
},
timeout = timeout,
streaming = streaming,
chunk_callback = chunk_callback,
)
async def agent_react(self, variables, timeout=600, streaming=False, chunk_callback=None):


@ -5,43 +5,65 @@ from .base import MessageTranslator
class DocumentRagRequestTranslator(MessageTranslator):
"""Translator for DocumentRagQuery schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagQuery:
return DocumentRagQuery(
query=data["query"],
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
doc_limit=int(data.get("doc-limit", 20))
doc_limit=int(data.get("doc-limit", 20)),
streaming=data.get("streaming", False)
)
def from_pulsar(self, obj: DocumentRagQuery) -> Dict[str, Any]:
return {
"query": obj.query,
"user": obj.user,
"collection": obj.collection,
"doc-limit": obj.doc_limit
"doc-limit": obj.doc_limit,
"streaming": getattr(obj, "streaming", False)
}
class DocumentRagResponseTranslator(MessageTranslator):
"""Translator for DocumentRagResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
return {
"response": obj.response
}
result = {}
# Check if this is a streaming response (has chunk)
if hasattr(obj, 'chunk') and obj.chunk:
result["chunk"] = obj.chunk
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
else:
# Non-streaming response
if obj.response:
result["response"] = obj.response
# Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
return result
def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
# For streaming responses, check end_of_stream
if hasattr(obj, 'chunk') and obj.chunk:
is_final = getattr(obj, 'end_of_stream', False)
else:
# For non-streaming responses, it's always final
is_final = True
return self.from_pulsar(obj), is_final
class GraphRagRequestTranslator(MessageTranslator):
"""Translator for GraphRagQuery schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> GraphRagQuery:
return GraphRagQuery(
query=data["query"],
@ -50,9 +72,10 @@ class GraphRagRequestTranslator(MessageTranslator):
entity_limit=int(data.get("entity-limit", 50)),
triple_limit=int(data.get("triple-limit", 30)),
max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
max_path_length=int(data.get("max-path-length", 2))
max_path_length=int(data.get("max-path-length", 2)),
streaming=data.get("streaming", False)
)
def from_pulsar(self, obj: GraphRagQuery) -> Dict[str, Any]:
return {
"query": obj.query,
@ -61,21 +84,42 @@ class GraphRagRequestTranslator(MessageTranslator):
"entity-limit": obj.entity_limit,
"triple-limit": obj.triple_limit,
"max-subgraph-size": obj.max_subgraph_size,
"max-path-length": obj.max_path_length
"max-path-length": obj.max_path_length,
"streaming": getattr(obj, "streaming", False)
}
class GraphRagResponseTranslator(MessageTranslator):
"""Translator for GraphRagResponse schema objects"""
def to_pulsar(self, data: Dict[str, Any]) -> GraphRagResponse:
raise NotImplementedError("Response translation to Pulsar not typically needed")
def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
return {
"response": obj.response
}
result = {}
# Check if this is a streaming response (has chunk)
if hasattr(obj, 'chunk') and obj.chunk:
result["chunk"] = obj.chunk
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
else:
# Non-streaming response
if obj.response:
result["response"] = obj.response
# Always include error if present
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
return result
def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
"""Returns (response_dict, is_final)"""
return self.from_pulsar(obj), True
# For streaming responses, check end_of_stream
if hasattr(obj, 'chunk') and obj.chunk:
is_final = getattr(obj, 'end_of_stream', False)
else:
# For non-streaming responses, it's always final
is_final = True
return self.from_pulsar(obj), is_final


@ -15,10 +15,13 @@ class GraphRagQuery(Record):
triple_limit = Integer()
max_subgraph_size = Integer()
max_path_length = Integer()
streaming = Boolean()
class GraphRagResponse(Record):
error = Error()
response = String()
chunk = String()
end_of_stream = Boolean()
############################################################################
@ -29,8 +32,11 @@ class DocumentRagQuery(Record):
user = String()
collection = String()
doc_limit = Integer()
streaming = Boolean()
class DocumentRagResponse(Record):
error = Error()
response = String()
chunk = String()
end_of_stream = Boolean()


@ -4,6 +4,10 @@ Uses the DocumentRAG service to answer a question
import argparse
import os
import asyncio
import json
import uuid
from websockets.asyncio.client import connect
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
@ -11,7 +15,69 @@ default_user = 'trustgraph'
default_collection = 'default'
default_doc_limit = 10
def question(url, flow_id, question, user, collection, doc_limit):
async def question_streaming(url, flow_id, question, user, collection, doc_limit):
"""Streaming version using websockets"""
# Convert http:// to ws://
if url.startswith('http://'):
url = 'ws://' + url[7:]
elif url.startswith('https://'):
url = 'wss://' + url[8:]
if not url.endswith("/"):
url += "/"
url = url + "api/v1/socket"
mid = str(uuid.uuid4())
async with connect(url) as ws:
req = {
"id": mid,
"service": "document-rag",
"flow": flow_id,
"request": {
"query": question,
"user": user,
"collection": collection,
"doc-limit": doc_limit,
"streaming": True
}
}
req = json.dumps(req)
await ws.send(req)
while True:
msg = await ws.recv()
obj = json.loads(msg)
if "error" in obj:
raise RuntimeError(obj["error"])
if obj["id"] != mid:
print("Ignore message")
continue
response = obj["response"]
# Handle streaming format (chunk)
if "chunk" in response:
chunk = response["chunk"]
print(chunk, end="", flush=True)
elif "response" in response:
# Final response with complete text
# Already printed via chunks, just add newline
pass
if obj["complete"]:
print() # Final newline
break
await ws.close()
def question_non_streaming(url, flow_id, question, user, collection, doc_limit):
"""Non-streaming version using HTTP API"""
api = Api(url).flow().id(flow_id)
@ -65,18 +131,36 @@ def main():
help=f'Document limit (default: {default_doc_limit})'
)
parser.add_argument(
'--no-streaming',
action='store_true',
help='Disable streaming (use non-streaming mode)'
)
args = parser.parse_args()
try:
question(
url=args.url,
flow_id = args.flow_id,
question=args.question,
user=args.user,
collection=args.collection,
doc_limit=args.doc_limit,
)
if not args.no_streaming:
asyncio.run(
question_streaming(
url=args.url,
flow_id=args.flow_id,
question=args.question,
user=args.user,
collection=args.collection,
doc_limit=args.doc_limit,
)
)
else:
question_non_streaming(
url=args.url,
flow_id=args.flow_id,
question=args.question,
user=args.user,
collection=args.collection,
doc_limit=args.doc_limit,
)
except Exception as e:


@ -4,6 +4,10 @@ Uses the GraphRAG service to answer a question
import argparse
import os
import asyncio
import json
import uuid
from websockets.asyncio.client import connect
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
@ -14,10 +18,78 @@ default_triple_limit = 30
default_max_subgraph_size = 150
default_max_path_length = 2
def question(
async def question_streaming(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length
):
"""Streaming version using websockets"""
# Convert http:// to ws://
if url.startswith('http://'):
url = 'ws://' + url[7:]
elif url.startswith('https://'):
url = 'wss://' + url[8:]
if not url.endswith("/"):
url += "/"
url = url + "api/v1/socket"
mid = str(uuid.uuid4())
async with connect(url) as ws:
req = {
"id": mid,
"service": "graph-rag",
"flow": flow_id,
"request": {
"query": question,
"user": user,
"collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-path-length": max_path_length,
"streaming": True
}
}
req = json.dumps(req)
await ws.send(req)
while True:
msg = await ws.recv()
obj = json.loads(msg)
if "error" in obj:
raise RuntimeError(obj["error"])
if obj["id"] != mid:
print("Ignore message")
continue
response = obj["response"]
# Handle streaming format (chunk)
if "chunk" in response:
chunk = response["chunk"]
print(chunk, end="", flush=True)
elif "response" in response:
# Final response with complete text
# Already printed via chunks, just add newline
pass
if obj["complete"]:
print() # Final newline
break
await ws.close()
def question_non_streaming(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length
):
"""Non-streaming version using HTTP API"""
api = Api(url).flow().id(flow_id)
@ -91,21 +163,42 @@ def main():
help=f'Max path length (default: {default_max_path_length})'
)
parser.add_argument(
'--no-streaming',
action='store_true',
help='Disable streaming (use non-streaming mode)'
)
args = parser.parse_args()
try:
question(
url=args.url,
flow_id = args.flow_id,
question=args.question,
user=args.user,
collection=args.collection,
entity_limit=args.entity_limit,
triple_limit=args.triple_limit,
max_subgraph_size=args.max_subgraph_size,
max_path_length=args.max_path_length,
)
if not args.no_streaming:
asyncio.run(
question_streaming(
url=args.url,
flow_id=args.flow_id,
question=args.question,
user=args.user,
collection=args.collection,
entity_limit=args.entity_limit,
triple_limit=args.triple_limit,
max_subgraph_size=args.max_subgraph_size,
max_path_length=args.max_path_length,
)
)
else:
question_non_streaming(
url=args.url,
flow_id=args.flow_id,
question=args.question,
user=args.user,
collection=args.collection,
entity_limit=args.entity_limit,
triple_limit=args.triple_limit,
max_subgraph_size=args.max_subgraph_size,
max_path_length=args.max_path_length,
)
except Exception as e:


@ -202,12 +202,16 @@ class StreamingReActParser:
# Find which comes first
if args_idx >= 0 and (newline_idx < 0 or args_idx < newline_idx):
# Args delimiter found first
self.action_buffer = self.line_buffer[:args_idx].strip().strip('"')
# Only set action_buffer if not already set (to avoid overwriting with empty string)
if not self.action_buffer:
self.action_buffer = self.line_buffer[:args_idx].strip().strip('"')
self.line_buffer = self.line_buffer[args_idx + len(self.ARGS_DELIMITER):].lstrip()
self.state = ParserState.ARGS
elif newline_idx >= 0:
# Newline found, action name complete
self.action_buffer = self.line_buffer[:newline_idx].strip().strip('"')
# Only set action_buffer if not already set
if not self.action_buffer:
self.action_buffer = self.line_buffer[:newline_idx].strip().strip('"')
self.line_buffer = self.line_buffer[newline_idx + 1:]
# Stay in ACTION state or move to ARGS if we find delimiter
# Actually, check if next line has Args:


@ -68,7 +68,7 @@ class DocumentRag:
async def query(
self, query, user="trustgraph", collection="default",
doc_limit=20,
doc_limit=20, streaming=False, chunk_callback=None,
):
if self.verbose:
@ -86,10 +86,18 @@ class DocumentRag:
logger.debug(f"Documents: {docs}")
logger.debug(f"Query: {query}")
resp = await self.prompt_client.document_prompt(
query = query,
documents = docs
)
if streaming and chunk_callback:
resp = await self.prompt_client.document_prompt(
query=query,
documents=docs,
streaming=True,
chunk_callback=chunk_callback
)
else:
resp = await self.prompt_client.document_prompt(
query=query,
documents=docs
)
if self.verbose:
logger.debug("Query processing complete")


@ -92,20 +92,56 @@ class Processor(FlowProcessor):
else:
doc_limit = self.doc_limit
response = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
doc_limit=doc_limit
)
# Check if streaming is requested
if v.streaming:
# Define async callback for streaming chunks
async def send_chunk(chunk):
await flow("response").send(
DocumentRagResponse(
chunk=chunk,
end_of_stream=False,
response=None,
error=None
),
properties={"id": id}
)
await flow("response").send(
DocumentRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
# Query with streaming enabled
full_response = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
doc_limit=doc_limit,
streaming=True,
chunk_callback=send_chunk,
)
# Send final message with complete response
await flow("response").send(
DocumentRagResponse(
chunk=None,
end_of_stream=True,
response=full_response,
error=None
),
properties={"id": id}
)
else:
# Non-streaming path (existing behavior)
response = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
doc_limit=doc_limit
)
await flow("response").send(
DocumentRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
logger.info("Request processing complete")
@ -115,14 +151,21 @@ class Processor(FlowProcessor):
logger.debug("Sending error response...")
await flow("response").send(
DocumentRagResponse(
response = None,
error = Error(
type = "document-rag-error",
message = str(e),
),
# Send error response with end_of_stream flag if streaming was requested
error_response = DocumentRagResponse(
response = None,
error = Error(
type = "document-rag-error",
message = str(e),
),
)
# If streaming was requested, indicate stream end
if v.streaming:
error_response.end_of_stream = True
await flow("response").send(
error_response,
properties = {"id": id}
)


@ -316,7 +316,7 @@ class GraphRag:
async def query(
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2,
max_path_length = 2, streaming = False, chunk_callback = None,
):
if self.verbose:
@ -337,7 +337,14 @@ class GraphRag:
logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}")
resp = await self.prompt_client.kg_prompt(query, kg)
if streaming and chunk_callback:
resp = await self.prompt_client.kg_prompt(
query, kg,
streaming=True,
chunk_callback=chunk_callback
)
else:
resp = await self.prompt_client.kg_prompt(query, kg)
if self.verbose:
logger.debug("Query processing complete")


@ -135,20 +135,56 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
)
# Check if streaming is requested
if v.streaming:
# Define async callback for streaming chunks
async def send_chunk(chunk):
await flow("response").send(
GraphRagResponse(
chunk=chunk,
end_of_stream=False,
response=None,
error=None
),
properties={"id": id}
)
await flow("response").send(
GraphRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
# Query with streaming enabled
full_response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
streaming = True,
chunk_callback = send_chunk,
)
# Send final message with complete response
await flow("response").send(
GraphRagResponse(
chunk=None,
end_of_stream=True,
response=full_response,
error=None
),
properties={"id": id}
)
else:
# Non-streaming path (existing behavior)
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
)
await flow("response").send(
GraphRagResponse(
response = response,
error = None
),
properties = {"id": id}
)
logger.info("Request processing complete")
@ -158,14 +194,21 @@ class Processor(FlowProcessor):
logger.debug("Sending error response...")
await flow("response").send(
GraphRagResponse(
response = None,
error = Error(
type = "graph-rag-error",
message = str(e),
),
# Send error response with end_of_stream flag if streaming was requested
error_response = GraphRagResponse(
response = None,
error = Error(
type = "graph-rag-error",
message = str(e),
),
)
# If streaming was requested, indicate stream end
if v.streaming:
error_response.end_of_stream = True
await flow("response").send(
error_response,
properties = {"id": id}
)