mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 09:56:22 +02:00
Streaming rag responses (#568)
* Tech spec for streaming RAG * Support for streaming Graph/Doc RAG
This commit is contained in:
parent
b1cc724f7d
commit
1948edaa50
20 changed files with 3087 additions and 94 deletions
288
docs/tech-specs/rag-streaming-support.md
Normal file
@@ -0,0 +1,288 @@
# RAG Streaming Support Technical Specification

## Overview

This specification describes adding streaming support to the GraphRAG and DocumentRAG services, enabling real-time token-by-token responses for knowledge graph and document retrieval queries. It extends the streaming architecture already implemented for the LLM text-completion, prompt, and agent services.

## Goals

- **Consistent streaming UX**: Provide the same streaming experience across all TrustGraph services
- **Minimal API changes**: Add streaming support with a single `streaming` flag, following established patterns
- **Backward compatibility**: Maintain existing non-streaming behavior as the default
- **Reuse existing infrastructure**: Leverage the PromptClient streaming already implemented
- **Gateway support**: Enable streaming through the websocket gateway for client applications

## Background

Currently implemented streaming services:

- **LLM text-completion service**: Phase 1 - streaming from LLM providers
- **Prompt service**: Phase 2 - streaming through prompt templates
- **Agent service**: Phases 3-4 - streaming ReAct responses with incremental thought/observation/answer chunks

Current limitations of the RAG services:

- GraphRAG and DocumentRAG only support blocking responses
- Users must wait for the complete LLM response before seeing any output
- Poor UX for long responses to knowledge graph or document queries
- Inconsistent experience compared to other TrustGraph services

This specification addresses these gaps by adding streaming support to GraphRAG and DocumentRAG. By enabling token-by-token responses, TrustGraph can:

- Provide a consistent streaming UX across all query types
- Reduce perceived latency for RAG queries
- Enable better progress feedback for long-running queries
- Support real-time display in client applications

## Technical Design

### Architecture

The RAG streaming implementation leverages existing infrastructure:

1. **PromptClient streaming** (already implemented)
   - `kg_prompt()` and `document_prompt()` already accept `streaming` and `chunk_callback` parameters
   - These call `prompt()` internally with streaming support
   - No changes needed to PromptClient

   Module: `trustgraph-base/trustgraph/base/prompt_client.py`

2. **GraphRAG service** (needs streaming parameter pass-through)
   - Add a `streaming` parameter to the `query()` method
   - Pass the streaming flag and callback through to `prompt_client.kg_prompt()`
   - Add a `streaming` field to the GraphRagRequest schema

   Modules:
   - `trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py`
   - `trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py` (Processor)
   - `trustgraph-base/trustgraph/schema/graph_rag.py` (Request schema)
   - `trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py` (Gateway)

3. **DocumentRAG service** (needs streaming parameter pass-through)
   - Add a `streaming` parameter to the `query()` method
   - Pass the streaming flag and callback through to `prompt_client.document_prompt()`
   - Add a `streaming` field to the DocumentRagRequest schema

   Modules:
   - `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py`
   - `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` (Processor)
   - `trustgraph-base/trustgraph/schema/document_rag.py` (Request schema)
   - `trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py` (Gateway)

### Data Flow

**Non-streaming (current)**:

```
Client → Gateway → RAG Service → PromptClient.kg_prompt(streaming=False)
                                        ↓
                                 Prompt Service → LLM
                                        ↓
                                 Complete response
                                        ↓
Client ← Gateway ← RAG Service ← Response
```

**Streaming (proposed)**:

```
Client → Gateway → RAG Service → PromptClient.kg_prompt(streaming=True, chunk_callback=cb)
                                        ↓
                                 Prompt Service → LLM (streaming)
                                        ↓
                                 Chunk → callback → RAG Response (chunk)
                                        ↓
Client ← Gateway ← ──────────────────── Response stream
```

### APIs

**GraphRAG Changes**:

1. **GraphRag.query()** - add streaming parameters

```python
async def query(
    self, query, user, collection,
    verbose=False, streaming=False, chunk_callback=None  # NEW
):
    # ... existing entity/triple retrieval ...

    if streaming and chunk_callback:
        resp = await self.prompt_client.kg_prompt(
            query, kg,
            streaming=True,
            chunk_callback=chunk_callback
        )
    else:
        resp = await self.prompt_client.kg_prompt(query, kg)

    return resp
```

2. **GraphRagRequest schema** - add a streaming field

```python
class GraphRagRequest(Record):
    query = String()
    user = String()
    collection = String()
    streaming = Boolean()  # NEW
```

3. **GraphRagResponse schema** - add streaming fields (following the Agent pattern)

```python
class GraphRagResponse(Record):
    response = String()         # Legacy: complete response
    chunk = String()            # NEW: streaming chunk
    end_of_stream = Boolean()   # NEW: indicates the last chunk
```

4. **Processor** - pass streaming through

```python
async def handle(self, msg):
    # ... existing code ...

    async def send_chunk(chunk):
        await self.respond(GraphRagResponse(
            chunk=chunk,
            end_of_stream=False,
            response=None
        ))

    if request.streaming:
        full_response = await self.rag.query(
            query=request.query,
            user=request.user,
            collection=request.collection,
            streaming=True,
            chunk_callback=send_chunk
        )
        # Send the final message
        await self.respond(GraphRagResponse(
            chunk=None,
            end_of_stream=True,
            response=full_response
        ))
    else:
        # Existing non-streaming path
        response = await self.rag.query(...)
        await self.respond(GraphRagResponse(response=response))
```

**DocumentRAG Changes**:

Identical pattern to GraphRAG:

1. Add `streaming` and `chunk_callback` parameters to `DocumentRag.query()`
2. Add a `streaming` field to `DocumentRagRequest`
3. Add `chunk` and `end_of_stream` fields to `DocumentRagResponse`
4. Update the Processor to handle streaming with callbacks
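
A minimal runnable sketch of the `DocumentRag.query()` pass-through, mirroring the GraphRAG version; the fake prompt client and the placeholder retrieval step are illustrative assumptions, not the real implementation:

```python
import asyncio

class DocumentRag:
    """Sketch: streaming pass-through to document_prompt()."""

    def __init__(self, prompt_client):
        self.prompt_client = prompt_client

    async def query(
        self, query, user, collection,
        verbose=False, streaming=False, chunk_callback=None  # NEW
    ):
        documents = ["..."]  # placeholder for the existing retrieval step

        if streaming and chunk_callback:
            return await self.prompt_client.document_prompt(
                query, documents,
                streaming=True,
                chunk_callback=chunk_callback,
            )
        return await self.prompt_client.document_prompt(query, documents)

class FakePromptClient:
    """Stand-in that drives the callback like the real streaming client."""

    async def document_prompt(self, query, docs,
                              streaming=False, chunk_callback=None):
        if streaming and chunk_callback:
            for part in ["ML is a", " subset of AI."]:
                await chunk_callback(part)
        return "ML is a subset of AI."

async def main():
    chunks = []

    async def collect(chunk):
        chunks.append(chunk)

    rag = DocumentRag(FakePromptClient())
    resp = await rag.query(
        "What is ML?", "user", "default",
        streaming=True, chunk_callback=collect,
    )
    # The joined chunks equal the complete response
    assert "".join(chunks) == resp
    return resp

print(asyncio.run(main()))  # → ML is a subset of AI.
```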

**Gateway Changes**:

Both `graph_rag.py` and `document_rag.py` in `gateway/dispatch` need updates to forward streaming chunks to the websocket. Note that the final message from the Processor carries `chunk=None`, so completion must be signalled outside the chunk branch:

```python
async def handle(self, message, session, websocket):
    # ... existing code ...

    if request.streaming:
        async def recipient(resp):
            if resp.chunk:
                await websocket.send(json.dumps({
                    "id": message["id"],
                    "response": {"chunk": resp.chunk},
                    "complete": False
                }))
            if resp.end_of_stream:
                # Final message has chunk=None; signal completion
                # and carry the complete response explicitly
                await websocket.send(json.dumps({
                    "id": message["id"],
                    "response": {"response": resp.response},
                    "complete": True
                }))
            return resp.end_of_stream

        await self.rag_client.request(request, recipient=recipient)
    else:
        # Existing non-streaming path
        resp = await self.rag_client.request(request)
        await websocket.send(...)
```
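
From the client side, the gateway delivers one JSON frame per chunk until `complete` is true. A sketch of the accumulation logic, as a pure function over decoded frames (the frame shape follows the gateway handler above; no real websocket involved):

```python
def accumulate(frames):
    """Collect gateway frames into the full response text.

    Each frame: {"id": ..., "response": {"chunk": ...}, "complete": bool}
    """
    parts = []
    for frame in frames:
        chunk = frame.get("response", {}).get("chunk")
        if chunk:
            parts.append(chunk)
        if frame.get("complete"):
            break
    return "".join(parts)

# Frames as the gateway would send them
frames = [
    {"id": "q1", "response": {"chunk": "Machine learning"}, "complete": False},
    {"id": "q1", "response": {"chunk": " is a subset of AI."}, "complete": False},
    {"id": "q1", "response": {}, "complete": True},
]
print(accumulate(frames))  # → Machine learning is a subset of AI.
```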

### Implementation Details

**Implementation order**:

1. Add schema fields (Request + Response, for both RAG services)
2. Update the GraphRag.query() and DocumentRag.query() methods
3. Update the Processors to handle streaming
4. Update the Gateway dispatch handlers
5. Add `--no-streaming` flags to `tg-invoke-graph-rag` and `tg-invoke-document-rag` (streaming enabled by default, following the agent CLI pattern)

**Callback pattern**:

Follow the same async callback pattern established in Agent streaming:

- The Processor defines an `async def send_chunk(chunk)` callback
- The Processor passes the callback to the RAG service
- The RAG service passes the callback to PromptClient
- PromptClient invokes the callback for each LLM chunk
- The Processor sends a streaming response message for each chunk

**Error handling**:

- Errors during streaming should send an error response with `end_of_stream=True`
- Follow the existing error propagation patterns from Agent streaming
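
A runnable sketch of the error path, with minimal stand-ins for the Processor and the response record (the `error` field name is an assumption; the real shape should follow the Agent streaming error pattern):

```python
import asyncio

class GraphRagResponse(dict):
    """Stand-in for the Pulsar record; a dict is enough for the sketch."""
    def __init__(self, **fields):
        super().__init__(**fields)

async def failing_query(**kwargs):
    raise RuntimeError("LLM backend unavailable")

async def handle_streaming(query_fn, respond):
    try:
        full = await query_fn(streaming=True)
    except Exception as e:
        # Terminate the stream so clients stop waiting:
        # the error response carries end_of_stream=True
        await respond(GraphRagResponse(
            error=str(e),  # assumed error field
            chunk=None,
            end_of_stream=True,
            response=None,
        ))
        return
    await respond(GraphRagResponse(
        chunk=None, end_of_stream=True, response=full,
    ))

sent = []

async def respond(resp):
    sent.append(resp)

asyncio.run(handle_streaming(failing_query, respond))
print(sent[0]["error"], sent[0]["end_of_stream"])  # → LLM backend unavailable True
```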

## Security Considerations

No new security considerations beyond the existing RAG services:

- Streaming responses use the same user/collection isolation
- No changes to authentication or authorization
- Chunk boundaries don't expose sensitive data

## Performance Considerations

**Benefits**:

- Reduced perceived latency (first tokens arrive sooner)
- Better UX for long responses
- Lower memory usage (no need to buffer the complete response)

**Potential concerns**:

- More Pulsar messages for streaming responses
- Slightly higher CPU cost from chunking/callback overhead
- Mitigated by: streaming is opt-in; the default remains non-streaming

**Testing considerations**:

- Test with large knowledge graphs (many triples)
- Test with many retrieved documents
- Measure the overhead of streaming vs non-streaming

## Testing Strategy

**Unit tests**:

- Test GraphRag.query() with streaming=True/False
- Test DocumentRag.query() with streaming=True/False
- Mock PromptClient to verify callback invocations
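
A sketch of the PromptClient mocking pattern (a real test would construct GraphRag with this mock and call `query()`; here only the mock and callback wiring are shown, and the chunk texts are arbitrary):

```python
import asyncio
from unittest.mock import AsyncMock

def make_mock_prompt_client():
    """PromptClient mock whose kg_prompt drives the chunk callback."""
    client = AsyncMock()

    async def kg_prompt(query, kg, streaming=False, chunk_callback=None):
        if streaming and chunk_callback:
            for part in ["Hello", " world"]:
                await chunk_callback(part)
        return "Hello world"

    client.kg_prompt.side_effect = kg_prompt
    return client

def test_streaming_invokes_chunk_callback():
    client = make_mock_prompt_client()
    chunks = []

    async def collect(chunk):
        chunks.append(chunk)

    resp = asyncio.run(client.kg_prompt(
        "What is ML?", "kg", streaming=True, chunk_callback=collect,
    ))

    # Callback received every chunk, and the full response matches
    assert resp == "Hello world"
    assert chunks == ["Hello", " world"]
    assert client.kg_prompt.await_count == 1

test_streaming_invokes_chunk_callback()
```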

**Integration tests**:

- Test the full GraphRAG streaming flow (similar to the existing agent streaming tests)
- Test the full DocumentRAG streaming flow
- Test Gateway streaming forwarding
- Test CLI streaming output

**Manual testing**:

- `tg-invoke-graph-rag -q "What is machine learning?"` (streaming by default)
- `tg-invoke-document-rag -q "Summarize the documents about AI"` (streaming by default)
- `tg-invoke-graph-rag --no-streaming -q "..."` (tests non-streaming mode)
- Verify that incremental output appears in streaming mode

## Migration Plan

No migration needed:

- Streaming is opt-in via the `streaming` parameter (defaults to False)
- Existing clients continue to work unchanged
- New clients can opt into streaming

## Timeline

Estimated implementation: 4-6 hours

- Phase 1 (2 hours): GraphRAG streaming support
- Phase 2 (2 hours): DocumentRAG streaming support
- Phase 3 (1-2 hours): Gateway updates and CLI flags
- Testing: built into each phase

## Open Questions

- Should we add streaming support to the NLP Query service as well?
- Do we want to stream intermediate steps (e.g., "Retrieving entities...", "Querying graph...") or just LLM output?
- Should GraphRAG/DocumentRAG responses include chunk metadata (e.g., chunk number, total expected)?

## References

- Existing implementation: `docs/tech-specs/streaming-llm-responses.md`
- Agent streaming: `trustgraph-flow/trustgraph/agent/react/agent_manager.py`
- PromptClient streaming: `trustgraph-base/trustgraph/base/prompt_client.py`
@@ -382,6 +382,206 @@ def sample_kg_triples():
    ]


# Streaming test fixtures

@pytest.fixture
def mock_streaming_llm_response():
    """Mock streaming LLM response with realistic chunks"""
    async def _generate_chunks():
        """Generate realistic streaming chunks"""
        chunks = [
            "Machine", " learning", " is", " a", " subset", " of",
            " artificial", " intelligence", " that", " focuses", " on",
            " algorithms", " that", " learn", " from", " data", ".",
        ]
        for chunk in chunks:
            yield chunk
    return _generate_chunks


@pytest.fixture
def sample_streaming_agent_response():
    """Sample streaming agent response chunks"""
    return [
        {"chunk_type": "thought", "content": "I need to search",
         "end_of_message": False, "end_of_dialog": False},
        {"chunk_type": "thought", "content": " for information",
         "end_of_message": False, "end_of_dialog": False},
        {"chunk_type": "thought", "content": " about machine learning.",
         "end_of_message": True, "end_of_dialog": False},
        {"chunk_type": "action", "content": "knowledge_query",
         "end_of_message": True, "end_of_dialog": False},
        {"chunk_type": "observation", "content": "Machine learning is",
         "end_of_message": False, "end_of_dialog": False},
        {"chunk_type": "observation", "content": " a subset of AI.",
         "end_of_message": True, "end_of_dialog": False},
        {"chunk_type": "final-answer", "content": "Machine learning",
         "end_of_message": False, "end_of_dialog": False},
        {"chunk_type": "final-answer", "content": " is a subset",
         "end_of_message": False, "end_of_dialog": False},
        {"chunk_type": "final-answer", "content": " of artificial intelligence.",
         "end_of_message": True, "end_of_dialog": True},
    ]


@pytest.fixture
def streaming_chunk_collector():
    """Helper to collect streaming chunks for assertions"""
    class ChunkCollector:
        def __init__(self):
            self.chunks = []
            self.complete = False

        async def collect(self, chunk):
            """Async callback to collect chunks"""
            self.chunks.append(chunk)

        def get_full_text(self):
            """Concatenate all chunk content"""
            return "".join(self.chunks)

        def get_chunk_types(self):
            """Get the list of chunk types if chunks are dicts"""
            if self.chunks and isinstance(self.chunks[0], dict):
                return [c.get("chunk_type") for c in self.chunks]
            return []

    return ChunkCollector


@pytest.fixture
def mock_streaming_prompt_response():
    """Mock streaming prompt service response"""
    async def _generate_prompt_chunks():
        """Generate streaming chunks for prompt responses"""
        chunks = [
            "Based on the", " provided context,", " here is", " the answer:",
            " Machine learning", " enables computers", " to learn", " from data.",
        ]
        for chunk in chunks:
            yield chunk
    return _generate_prompt_chunks


@pytest.fixture
def sample_rag_streaming_chunks():
    """Sample RAG streaming response chunks"""
    return [
        {"chunk": "Based on", "end_of_stream": False},
        {"chunk": " the knowledge", "end_of_stream": False},
        {"chunk": " graph,", "end_of_stream": False},
        {"chunk": " machine learning", "end_of_stream": False},
        {"chunk": " is a subset", "end_of_stream": False},
        {"chunk": " of AI.", "end_of_stream": False},
        {"chunk": None, "end_of_stream": True,
         "response": "Based on the knowledge graph, machine learning is a subset of AI."},
    ]


@pytest.fixture
def streaming_error_scenarios():
    """Common error scenarios for streaming tests"""
    return {
        "connection_drop": {
            "exception": ConnectionError,
            "message": "Connection lost during streaming",
            "chunks_before_error": 5,
        },
        "timeout": {
            "exception": TimeoutError,
            "message": "Streaming timeout exceeded",
            "chunks_before_error": 10,
        },
        "rate_limit": {
            "exception": Exception,
            "message": "Rate limit exceeded",
            "chunks_before_error": 3,
        },
        "invalid_chunk": {
            "exception": ValueError,
            "message": "Invalid chunk format",
            "chunks_before_error": 7,
        },
    }


# Test markers for integration tests
pytestmark = pytest.mark.integration
360
tests/integration/test_agent_streaming_integration.py
Normal file
@@ -0,0 +1,360 @@
"""
|
||||||
|
Integration tests for Agent Manager Streaming Functionality
|
||||||
|
|
||||||
|
These tests verify the streaming behavior of the Agent service, testing
|
||||||
|
chunk-by-chunk delivery of thoughts, actions, observations, and final answers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from trustgraph.agent.react.agent_manager import AgentManager
|
||||||
|
from trustgraph.agent.react.tools import KnowledgeQueryImpl
|
||||||
|
from trustgraph.agent.react.types import Tool, Argument
|
||||||
|
from tests.utils.streaming_assertions import (
|
||||||
|
assert_agent_streaming_chunks,
|
||||||
|
assert_streaming_chunks_valid,
|
||||||
|
assert_callback_invoked,
|
||||||
|
assert_chunk_types_valid,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestAgentStreaming:
|
||||||
|
"""Integration tests for Agent streaming functionality"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_prompt_client_streaming(self):
|
||||||
|
"""Mock prompt client with streaming support"""
|
||||||
|
client = AsyncMock()
|
||||||
|
|
||||||
|
async def agent_react_streaming(variables, timeout=600, streaming=False, chunk_callback=None):
|
||||||
|
# Both modes return the same text for equivalence
|
||||||
|
full_text = """Thought: I need to search for information about machine learning.
|
||||||
|
Action: knowledge_query
|
||||||
|
Args: {
|
||||||
|
"question": "What is machine learning?"
|
||||||
|
}"""
|
||||||
|
|
||||||
|
if streaming and chunk_callback:
|
||||||
|
# Send realistic line-by-line chunks
|
||||||
|
# This tests that the parser properly handles "Args:" starting a new chunk
|
||||||
|
# (which previously caused a bug where action_buffer was overwritten)
|
||||||
|
chunks = [
|
||||||
|
"Thought: I need to search for information about machine learning.\n",
|
||||||
|
"Action: knowledge_query\n",
|
||||||
|
"Args: {\n", # This used to trigger bug - Args: at start of chunk
|
||||||
|
' "question": "What is machine learning?"\n',
|
||||||
|
"}"
|
||||||
|
]
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
await chunk_callback(chunk)
|
||||||
|
|
||||||
|
return full_text
|
||||||
|
else:
|
||||||
|
# Non-streaming response - same text
|
||||||
|
return full_text
|
||||||
|
|
||||||
|
client.agent_react.side_effect = agent_react_streaming
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_flow_context(self, mock_prompt_client_streaming):
|
||||||
|
"""Mock flow context with streaming prompt client"""
|
||||||
|
context = MagicMock()
|
||||||
|
|
||||||
|
# Mock graph RAG client
|
||||||
|
graph_rag_client = AsyncMock()
|
||||||
|
graph_rag_client.rag.return_value = "Machine learning is a subset of AI."
|
||||||
|
|
||||||
|
def context_router(service_name):
|
||||||
|
if service_name == "prompt-request":
|
||||||
|
return mock_prompt_client_streaming
|
||||||
|
elif service_name == "graph-rag-request":
|
||||||
|
return graph_rag_client
|
||||||
|
else:
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
context.side_effect = context_router
|
||||||
|
return context
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_tools(self):
|
||||||
|
"""Sample tool configuration"""
|
||||||
|
return {
|
||||||
|
"knowledge_query": Tool(
|
||||||
|
name="knowledge_query",
|
||||||
|
description="Query the knowledge graph",
|
||||||
|
arguments=[
|
||||||
|
Argument(
|
||||||
|
name="question",
|
||||||
|
type="string",
|
||||||
|
description="The question to ask"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
implementation=KnowledgeQueryImpl,
|
||||||
|
config={}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_manager(self, sample_tools):
|
||||||
|
"""Create AgentManager instance with streaming support"""
|
||||||
|
return AgentManager(
|
||||||
|
tools=sample_tools,
|
||||||
|
additional_context="You are a helpful AI assistant."
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_streaming_thought_chunks(self, agent_manager, mock_flow_context):
|
||||||
|
"""Test that thought chunks are streamed correctly"""
|
||||||
|
# Arrange
|
||||||
|
thought_chunks = []
|
||||||
|
|
||||||
|
async def think(chunk):
|
||||||
|
thought_chunks.append(chunk)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await agent_manager.react(
|
||||||
|
question="What is machine learning?",
|
||||||
|
history=[],
|
||||||
|
think=think,
|
||||||
|
observe=AsyncMock(),
|
||||||
|
context=mock_flow_context,
|
||||||
|
streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert len(thought_chunks) > 0
|
||||||
|
assert_streaming_chunks_valid(thought_chunks, min_chunks=1)
|
||||||
|
|
||||||
|
# Verify thought content makes sense
|
||||||
|
full_thought = "".join(thought_chunks)
|
||||||
|
assert "search" in full_thought.lower() or "information" in full_thought.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_streaming_observation_chunks(self, agent_manager, mock_flow_context):
|
||||||
|
"""Test that observation chunks are streamed correctly"""
|
||||||
|
# Arrange
|
||||||
|
observation_chunks = []
|
||||||
|
|
||||||
|
async def observe(chunk):
|
||||||
|
observation_chunks.append(chunk)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await agent_manager.react(
|
||||||
|
question="What is machine learning?",
|
||||||
|
history=[],
|
||||||
|
think=AsyncMock(),
|
||||||
|
observe=observe,
|
||||||
|
context=mock_flow_context,
|
||||||
|
streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Note: Observations come from tool execution, which may or may not be streamed
|
||||||
|
# depending on the tool implementation
|
||||||
|
# For now, verify callback was set up
|
||||||
|
assert observe is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_streaming_vs_non_streaming(self, agent_manager, mock_flow_context):
|
||||||
|
"""Test that streaming and non-streaming produce equivalent results"""
|
||||||
|
# Arrange
|
||||||
|
question = "What is machine learning?"
|
||||||
|
history = []
|
||||||
|
|
||||||
|
# Act - Non-streaming
|
||||||
|
non_streaming_result = await agent_manager.react(
|
||||||
|
question=question,
|
||||||
|
history=history,
|
||||||
|
think=AsyncMock(),
|
||||||
|
observe=AsyncMock(),
|
||||||
|
context=mock_flow_context,
|
||||||
|
streaming=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Act - Streaming
|
||||||
|
thought_chunks = []
|
||||||
|
observation_chunks = []
|
||||||
|
|
||||||
|
async def think(chunk):
|
||||||
|
thought_chunks.append(chunk)
|
||||||
|
|
||||||
|
async def observe(chunk):
|
||||||
|
observation_chunks.append(chunk)
|
||||||
|
|
||||||
|
streaming_result = await agent_manager.react(
|
||||||
|
question=question,
|
||||||
|
history=history,
|
||||||
|
think=think,
|
||||||
|
observe=observe,
|
||||||
|
context=mock_flow_context,
|
||||||
|
streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - Results should be equivalent (or both valid)
|
||||||
|
assert non_streaming_result is not None
|
||||||
|
assert streaming_result is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_streaming_callback_invocation(self, agent_manager, mock_flow_context):
|
||||||
|
"""Test that callbacks are invoked with correct parameters"""
|
||||||
|
# Arrange
|
||||||
|
think = AsyncMock()
|
||||||
|
observe = AsyncMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await agent_manager.react(
|
||||||
|
question="What is machine learning?",
|
||||||
|
history=[],
|
||||||
|
think=think,
|
||||||
|
observe=observe,
|
||||||
|
context=mock_flow_context,
|
||||||
|
streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - Think callback should be invoked
|
||||||
|
assert think.call_count > 0
|
||||||
|
|
||||||
|
# Verify all callback invocations had string arguments
|
||||||
|
for call in think.call_args_list:
|
||||||
|
assert len(call.args) > 0
|
||||||
|
assert isinstance(call.args[0], str)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_streaming_without_callbacks(self, agent_manager, mock_flow_context):
|
||||||
|
"""Test streaming parameter without callbacks (should work gracefully)"""
|
||||||
|
# Arrange & Act
|
||||||
|
result = await agent_manager.react(
|
||||||
|
question="What is machine learning?",
|
||||||
|
history=[],
|
||||||
|
think=AsyncMock(),
|
||||||
|
observe=AsyncMock(),
|
||||||
|
context=mock_flow_context,
|
||||||
|
streaming=True # Streaming enabled with mock callbacks
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert - Should complete without error
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_streaming_with_conversation_history(self, agent_manager, mock_flow_context):
|
||||||
|
"""Test streaming with existing conversation history"""
|
||||||
|
# Arrange
|
||||||
|
# History should be a list of Action objects
|
||||||
|
from trustgraph.agent.react.types import Action
|
||||||
|
history = [
|
||||||
|
Action(
|
||||||
|
thought="I need to search for information about machine learning",
|
||||||
|
name="knowledge_query",
|
||||||
|
arguments={"question": "What is machine learning?"},
|
||||||
|
observation="Machine learning is a subset of AI that enables computers to learn from data."
|
||||||
|
)
|
||||||
|
]
|
||||||
|
think = AsyncMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = await agent_manager.react(
|
||||||
|
question="Tell me more about neural networks",
|
||||||
|
history=history,
|
||||||
|
think=think,
|
||||||
|
observe=AsyncMock(),
|
||||||
|
context=mock_flow_context,
|
||||||
|
streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert result is not None
|
||||||
|
assert think.call_count > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_streaming_error_propagation(self, agent_manager, mock_flow_context):
|
||||||
|
"""Test that errors during streaming are properly propagated"""
|
||||||
|
# Arrange
|
||||||
|
mock_prompt_client = mock_flow_context("prompt-request")
|
||||||
|
mock_prompt_client.agent_react.side_effect = Exception("Prompt service error")
|
||||||
|
|
||||||
|
think = AsyncMock()
|
||||||
|
observe = AsyncMock()
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await agent_manager.react(
|
||||||
|
question="test question",
|
||||||
|
history=[],
|
||||||
|
think=think,
|
||||||
|
                observe=observe,
                context=mock_flow_context,
                streaming=True
            )

        assert "Prompt service error" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_agent_streaming_multi_step_reasoning(self, agent_manager, mock_flow_context,
                                                        mock_prompt_client_streaming):
        """Test streaming through multi-step reasoning process"""
        # Arrange - Mock a multi-step response
        step_responses = [
            """Thought: I need to search for basic information.
Action: knowledge_query
Args: {"question": "What is AI?"}""",
            """Thought: Now I can answer the question.
Final Answer: AI is the simulation of human intelligence in machines."""
        ]

        call_count = 0

        async def multi_step_agent_react(variables, timeout=600, streaming=False, chunk_callback=None):
            nonlocal call_count
            response = step_responses[min(call_count, len(step_responses) - 1)]
            call_count += 1

            if streaming and chunk_callback:
                for chunk in response.split():
                    await chunk_callback(chunk + " ")
                return response
            return response

        mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react

        think = AsyncMock()
        observe = AsyncMock()

        # Act
        result = await agent_manager.react(
            question="What is artificial intelligence?",
            history=[],
            think=think,
            observe=observe,
            context=mock_flow_context,
            streaming=True
        )

        # Assert
        assert result is not None
        assert think.call_count > 0

    @pytest.mark.asyncio
    async def test_agent_streaming_preserves_tool_config(self, agent_manager, mock_flow_context):
        """Test that streaming preserves tool configuration and context"""
        # Arrange
        think = AsyncMock()
        observe = AsyncMock()

        # Act
        await agent_manager.react(
            question="What is machine learning?",
            history=[],
            think=think,
            observe=observe,
            context=mock_flow_context,
            streaming=True
        )

        # Assert - Verify prompt client was called with streaming
        mock_prompt_client = mock_flow_context("prompt-request")
        call_args = mock_prompt_client.agent_react.call_args
        assert call_args.kwargs['streaming'] is True
        assert call_args.kwargs['chunk_callback'] is not None
tests/integration/test_document_rag_streaming_integration.py (new file, 274 lines)
@@ -0,0 +1,274 @@
"""
Integration tests for DocumentRAG streaming functionality

These tests verify the streaming behavior of DocumentRAG, testing token-by-token
response delivery through the complete pipeline.
"""

import pytest
from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from tests.utils.streaming_assertions import (
    assert_streaming_chunks_valid,
    assert_callback_invoked,
)


@pytest.mark.integration
class TestDocumentRagStreaming:
    """Integration tests for DocumentRAG streaming"""

    @pytest.fixture
    def mock_embeddings_client(self):
        """Mock embeddings client"""
        client = AsyncMock()
        client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        return client

    @pytest.fixture
    def mock_doc_embeddings_client(self):
        """Mock document embeddings client"""
        client = AsyncMock()
        client.query.return_value = [
            "Machine learning is a subset of AI.",
            "Deep learning uses neural networks.",
            "Supervised learning needs labeled data."
        ]
        return client

    @pytest.fixture
    def mock_streaming_prompt_client(self, mock_streaming_llm_response):
        """Mock prompt client with streaming support"""
        client = AsyncMock()

        async def document_prompt_side_effect(query, documents, timeout=600, streaming=False, chunk_callback=None):
            # Both modes return the same text
            full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."

            if streaming and chunk_callback:
                # Simulate streaming chunks
                async for chunk in mock_streaming_llm_response():
                    await chunk_callback(chunk)
                return full_text
            else:
                # Non-streaming response - same text
                return full_text

        client.document_prompt.side_effect = document_prompt_side_effect
        return client

    @pytest.fixture
    def document_rag_streaming(self, mock_embeddings_client, mock_doc_embeddings_client,
                               mock_streaming_prompt_client):
        """Create DocumentRag instance with streaming support"""
        return DocumentRag(
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            prompt_client=mock_streaming_prompt_client,
            verbose=True
        )

    @pytest.mark.asyncio
    async def test_document_rag_streaming_basic(self, document_rag_streaming, streaming_chunk_collector):
        """Test basic DocumentRAG streaming functionality"""
        # Arrange
        query = "What is machine learning?"
        collector = streaming_chunk_collector()

        # Act
        result = await document_rag_streaming.query(
            query=query,
            user="test_user",
            collection="test_collection",
            doc_limit=10,
            streaming=True,
            chunk_callback=collector.collect
        )

        # Assert
        assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
        assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)

        # Verify full response matches concatenated chunks
        full_from_chunks = collector.get_full_text()
        assert result == full_from_chunks

        # Verify content is reasonable
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
        """Test that streaming and non-streaming produce equivalent results"""
        # Arrange
        query = "What is machine learning?"
        user = "test_user"
        collection = "test_collection"
        doc_limit = 10

        # Act - Non-streaming
        non_streaming_result = await document_rag_streaming.query(
            query=query,
            user=user,
            collection=collection,
            doc_limit=doc_limit,
            streaming=False
        )

        # Act - Streaming
        streaming_chunks = []

        async def collect(chunk):
            streaming_chunks.append(chunk)

        streaming_result = await document_rag_streaming.query(
            query=query,
            user=user,
            collection=collection,
            doc_limit=doc_limit,
            streaming=True,
            chunk_callback=collect
        )

        # Assert - Results should be equivalent
        assert streaming_result == non_streaming_result
        assert len(streaming_chunks) > 0
        assert "".join(streaming_chunks) == streaming_result

    @pytest.mark.asyncio
    async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
        """Test that chunk callback is invoked correctly"""
        # Arrange
        callback = AsyncMock()

        # Act
        result = await document_rag_streaming.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            doc_limit=5,
            streaming=True,
            chunk_callback=callback
        )

        # Assert
        assert callback.call_count > 0
        assert result is not None

        # Verify all callback invocations had string arguments
        for call in callback.call_args_list:
            assert isinstance(call.args[0], str)

    @pytest.mark.asyncio
    async def test_document_rag_streaming_without_callback(self, document_rag_streaming):
        """Test streaming parameter without callback (should fall back to non-streaming)"""
        # Arrange & Act
        result = await document_rag_streaming.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            doc_limit=5,
            streaming=True,
            chunk_callback=None  # No callback provided
        )

        # Assert - Should complete without error
        assert result is not None
        assert isinstance(result, str)

    @pytest.mark.asyncio
    async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
                                                            mock_doc_embeddings_client):
        """Test streaming with no documents found"""
        # Arrange
        mock_doc_embeddings_client.query.return_value = []  # No documents
        callback = AsyncMock()

        # Act
        result = await document_rag_streaming.query(
            query="unknown topic",
            user="test_user",
            collection="test_collection",
            doc_limit=10,
            streaming=True,
            chunk_callback=callback
        )

        # Assert - Should still produce streamed response
        assert result is not None
        assert callback.call_count > 0

    @pytest.mark.asyncio
    async def test_document_rag_streaming_error_propagation(self, document_rag_streaming,
                                                            mock_embeddings_client):
        """Test that errors during streaming are properly propagated"""
        # Arrange
        mock_embeddings_client.embed.side_effect = Exception("Embeddings error")
        callback = AsyncMock()

        # Act & Assert
        with pytest.raises(Exception) as exc_info:
            await document_rag_streaming.query(
                query="test query",
                user="test_user",
                collection="test_collection",
                doc_limit=5,
                streaming=True,
                chunk_callback=callback
            )

        assert "Embeddings error" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_document_rag_streaming_with_different_doc_limits(self, document_rag_streaming,
                                                                    mock_doc_embeddings_client):
        """Test streaming with various document limits"""
        # Arrange
        callback = AsyncMock()
        doc_limits = [1, 5, 10, 20]

        for limit in doc_limits:
            # Reset mocks
            mock_doc_embeddings_client.reset_mock()
            callback.reset_mock()

            # Act
            result = await document_rag_streaming.query(
                query="test query",
                user="test_user",
                collection="test_collection",
                doc_limit=limit,
                streaming=True,
                chunk_callback=callback
            )

            # Assert
            assert result is not None
            assert callback.call_count > 0

            # Verify doc_limit was passed correctly
            call_args = mock_doc_embeddings_client.query.call_args
            assert call_args.kwargs['limit'] == limit

    @pytest.mark.asyncio
    async def test_document_rag_streaming_preserves_user_collection(self, document_rag_streaming,
                                                                    mock_doc_embeddings_client):
        """Test that streaming preserves user/collection isolation"""
        # Arrange
        callback = AsyncMock()
        user = "test_user_123"
        collection = "test_collection_456"

        # Act
        await document_rag_streaming.query(
            query="test query",
            user=user,
            collection=collection,
            doc_limit=10,
            streaming=True,
            chunk_callback=callback
        )

        # Assert - Verify user/collection were passed to document embeddings client
        call_args = mock_doc_embeddings_client.query.call_args
        assert call_args.kwargs['user'] == user
        assert call_args.kwargs['collection'] == collection
tests/integration/test_graph_rag_integration.py (new file, 269 lines)
@@ -0,0 +1,269 @@
"""
Integration tests for GraphRAG retrieval system

These tests verify the end-to-end functionality of the GraphRAG system,
testing the coordination between embeddings, graph retrieval, triple querying, and prompt services.
Following the TEST_STRATEGY.md approach for integration testing.

NOTE: This is the first integration test file for GraphRAG (previously had only unit tests).
"""

import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag


@pytest.mark.integration
class TestGraphRagIntegration:
    """Integration tests for GraphRAG system coordination"""

    @pytest.fixture
    def mock_embeddings_client(self):
        """Mock embeddings client that returns realistic vector embeddings"""
        client = AsyncMock()
        client.embed.return_value = [
            [0.1, 0.2, 0.3, 0.4, 0.5],  # Realistic 5-dimensional embedding
        ]
        return client

    @pytest.fixture
    def mock_graph_embeddings_client(self):
        """Mock graph embeddings client that returns realistic entities"""
        client = AsyncMock()
        client.query.return_value = [
            "http://trustgraph.ai/e/machine-learning",
            "http://trustgraph.ai/e/artificial-intelligence",
            "http://trustgraph.ai/e/neural-networks"
        ]
        return client

    @pytest.fixture
    def mock_triples_client(self):
        """Mock triples client that returns realistic knowledge graph triples"""
        client = AsyncMock()

        # Mock different queries return different triples
        async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None):
            # Mock label queries
            if p == "http://www.w3.org/2000/01/rdf-schema#label":
                if s == "http://trustgraph.ai/e/machine-learning":
                    return [MagicMock(s=s, p=p, o="Machine Learning")]
                elif s == "http://trustgraph.ai/e/artificial-intelligence":
                    return [MagicMock(s=s, p=p, o="Artificial Intelligence")]
                elif s == "http://trustgraph.ai/e/neural-networks":
                    return [MagicMock(s=s, p=p, o="Neural Networks")]
                return []

            # Mock relationship queries
            if s == "http://trustgraph.ai/e/machine-learning":
                return [
                    MagicMock(
                        s="http://trustgraph.ai/e/machine-learning",
                        p="http://trustgraph.ai/is_subset_of",
                        o="http://trustgraph.ai/e/artificial-intelligence"
                    ),
                    MagicMock(
                        s="http://trustgraph.ai/e/machine-learning",
                        p="http://www.w3.org/2000/01/rdf-schema#label",
                        o="Machine Learning"
                    )
                ]

            return []

        client.query.side_effect = query_side_effect
        return client

    @pytest.fixture
    def mock_prompt_client(self):
        """Mock prompt client that generates realistic responses"""
        client = AsyncMock()
        client.kg_prompt.return_value = (
            "Machine learning is a subset of artificial intelligence that enables computers "
            "to learn from data without being explicitly programmed. It uses algorithms "
            "and statistical models to find patterns in data."
        )
        return client

    @pytest.fixture
    def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
                  mock_triples_client, mock_prompt_client):
        """Create GraphRag instance with mocked dependencies"""
        return GraphRag(
            embeddings_client=mock_embeddings_client,
            graph_embeddings_client=mock_graph_embeddings_client,
            triples_client=mock_triples_client,
            prompt_client=mock_prompt_client,
            verbose=True
        )

    @pytest.mark.asyncio
    async def test_graph_rag_end_to_end_flow(self, graph_rag, mock_embeddings_client,
                                             mock_graph_embeddings_client, mock_triples_client,
                                             mock_prompt_client):
        """Test complete GraphRAG pipeline from query to response"""
        # Arrange
        query = "What is machine learning?"
        user = "test_user"
        collection = "ml_knowledge"
        entity_limit = 50
        triple_limit = 30

        # Act
        result = await graph_rag.query(
            query=query,
            user=user,
            collection=collection,
            entity_limit=entity_limit,
            triple_limit=triple_limit
        )

        # Assert - Verify service coordination

        # 1. Should compute embeddings for query
        mock_embeddings_client.embed.assert_called_once_with(query)

        # 2. Should query graph embeddings to find relevant entities
        mock_graph_embeddings_client.query.assert_called_once()
        call_args = mock_graph_embeddings_client.query.call_args
        assert call_args.kwargs['vectors'] == [[0.1, 0.2, 0.3, 0.4, 0.5]]
        assert call_args.kwargs['limit'] == entity_limit
        assert call_args.kwargs['user'] == user
        assert call_args.kwargs['collection'] == collection

        # 3. Should query triples to build knowledge subgraph
        assert mock_triples_client.query.call_count > 0

        # 4. Should call prompt with knowledge graph
        mock_prompt_client.kg_prompt.assert_called_once()
        call_args = mock_prompt_client.kg_prompt.call_args
        assert call_args.args[0] == query  # First arg is query
        assert isinstance(call_args.args[1], list)  # Second arg is kg (list of triples)

        # Verify final response
        assert result is not None
        assert isinstance(result, str)
        assert "machine learning" in result.lower()

    @pytest.mark.asyncio
    async def test_graph_rag_with_different_limits(self, graph_rag, mock_embeddings_client,
                                                   mock_graph_embeddings_client):
        """Test GraphRAG with various entity and triple limits"""
        # Arrange
        query = "Explain neural networks"
        test_configs = [
            {"entity_limit": 10, "triple_limit": 10},
            {"entity_limit": 50, "triple_limit": 30},
            {"entity_limit": 100, "triple_limit": 100},
        ]

        for config in test_configs:
            # Reset mocks
            mock_embeddings_client.reset_mock()
            mock_graph_embeddings_client.reset_mock()

            # Act
            await graph_rag.query(
                query=query,
                user="test_user",
                collection="test_collection",
                entity_limit=config["entity_limit"],
                triple_limit=config["triple_limit"]
            )

            # Assert
            call_args = mock_graph_embeddings_client.query.call_args
            assert call_args.kwargs['limit'] == config["entity_limit"]

    @pytest.mark.asyncio
    async def test_graph_rag_error_propagation(self, graph_rag, mock_embeddings_client):
        """Test that errors from underlying services are properly propagated"""
        # Arrange
        mock_embeddings_client.embed.side_effect = Exception("Embeddings service error")

        # Act & Assert
        with pytest.raises(Exception) as exc_info:
            await graph_rag.query(
                query="test query",
                user="test_user",
                collection="test_collection"
            )

        assert "Embeddings service error" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_graph_rag_with_empty_knowledge_graph(self, graph_rag, mock_graph_embeddings_client,
                                                        mock_triples_client, mock_prompt_client):
        """Test GraphRAG handles empty knowledge graph gracefully"""
        # Arrange
        mock_graph_embeddings_client.query.return_value = []  # No entities found
        mock_triples_client.query.return_value = []  # No triples found

        # Act
        result = await graph_rag.query(
            query="unknown topic",
            user="test_user",
            collection="test_collection"
        )

        # Assert
        # Should still call prompt client with empty knowledge graph
        mock_prompt_client.kg_prompt.assert_called_once()
        call_args = mock_prompt_client.kg_prompt.call_args
        assert isinstance(call_args.args[1], list)  # kg should be a list
        assert result is not None

    @pytest.mark.asyncio
    async def test_graph_rag_label_caching(self, graph_rag, mock_triples_client):
        """Test that label lookups are cached to reduce redundant queries"""
        # Arrange
        query = "What is machine learning?"

        # First query
        await graph_rag.query(
            query=query,
            user="test_user",
            collection="test_collection"
        )

        first_call_count = mock_triples_client.query.call_count
        mock_triples_client.reset_mock()

        # Second identical query
        await graph_rag.query(
            query=query,
            user="test_user",
            collection="test_collection"
        )

        second_call_count = mock_triples_client.query.call_count

        # Assert - Second query should make fewer triple queries due to caching
        # Note: This is a weak assertion because caching behavior depends on
        # implementation details, but it verifies the concept
        assert second_call_count >= 0  # Should complete without errors

    @pytest.mark.asyncio
    async def test_graph_rag_multi_user_isolation(self, graph_rag, mock_graph_embeddings_client):
        """Test that different users/collections are properly isolated"""
        # Arrange
        query = "test query"
        user1, collection1 = "user1", "collection1"
        user2, collection2 = "user2", "collection2"

        # Act
        await graph_rag.query(query=query, user=user1, collection=collection1)
        await graph_rag.query(query=query, user=user2, collection=collection2)

        # Assert - Both users should have separate queries
        assert mock_graph_embeddings_client.query.call_count == 2

        # Verify first call
        first_call = mock_graph_embeddings_client.query.call_args_list[0]
        assert first_call.kwargs['user'] == user1
        assert first_call.kwargs['collection'] == collection1

        # Verify second call
        second_call = mock_graph_embeddings_client.query.call_args_list[1]
        assert second_call.kwargs['user'] == user2
        assert second_call.kwargs['collection'] == collection2
tests/integration/test_graph_rag_streaming_integration.py (new file, 249 lines)
@@ -0,0 +1,249 @@
"""
Integration tests for GraphRAG streaming functionality

These tests verify the streaming behavior of GraphRAG, testing token-by-token
response delivery through the complete pipeline.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from tests.utils.streaming_assertions import (
    assert_streaming_chunks_valid,
    assert_rag_streaming_chunks,
    assert_streaming_content_matches,
    assert_callback_invoked,
)


@pytest.mark.integration
class TestGraphRagStreaming:
    """Integration tests for GraphRAG streaming"""

    @pytest.fixture
    def mock_embeddings_client(self):
        """Mock embeddings client"""
        client = AsyncMock()
        client.embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5]]
        return client

    @pytest.fixture
    def mock_graph_embeddings_client(self):
        """Mock graph embeddings client"""
        client = AsyncMock()
        client.query.return_value = [
            "http://trustgraph.ai/e/machine-learning",
        ]
        return client

    @pytest.fixture
    def mock_triples_client(self):
        """Mock triples client with minimal responses"""
        client = AsyncMock()

        async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None):
            if p == "http://www.w3.org/2000/01/rdf-schema#label":
                return [MagicMock(s=s, p=p, o="Machine Learning")]
            return []

        client.query.side_effect = query_side_effect
        return client

    @pytest.fixture
    def mock_streaming_prompt_client(self, mock_streaming_llm_response):
        """Mock prompt client with streaming support"""
        client = AsyncMock()

        async def kg_prompt_side_effect(query, kg, timeout=600, streaming=False, chunk_callback=None):
            # Both modes return the same text
            full_text = "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data."

            if streaming and chunk_callback:
                # Simulate streaming chunks
                async for chunk in mock_streaming_llm_response():
                    await chunk_callback(chunk)
                return full_text
            else:
                # Non-streaming response - same text
                return full_text

        client.kg_prompt.side_effect = kg_prompt_side_effect
        return client

    @pytest.fixture
    def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client,
                            mock_triples_client, mock_streaming_prompt_client):
        """Create GraphRag instance with streaming support"""
        return GraphRag(
            embeddings_client=mock_embeddings_client,
            graph_embeddings_client=mock_graph_embeddings_client,
            triples_client=mock_triples_client,
            prompt_client=mock_streaming_prompt_client,
            verbose=True
        )

    @pytest.mark.asyncio
    async def test_graph_rag_streaming_basic(self, graph_rag_streaming, streaming_chunk_collector):
        """Test basic GraphRAG streaming functionality"""
        # Arrange
        query = "What is machine learning?"
        collector = streaming_chunk_collector()

        # Act
        result = await graph_rag_streaming.query(
            query=query,
            user="test_user",
            collection="test_collection",
            streaming=True,
            chunk_callback=collector.collect
        )

        # Assert
        assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
        assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)

        # Verify full response matches concatenated chunks
        full_from_chunks = collector.get_full_text()
        assert result == full_from_chunks

        # Verify content is reasonable
        assert "machine" in result.lower() or "learning" in result.lower()

    @pytest.mark.asyncio
    async def test_graph_rag_streaming_vs_non_streaming(self, graph_rag_streaming):
        """Test that streaming and non-streaming produce equivalent results"""
        # Arrange
        query = "What is machine learning?"
        user = "test_user"
        collection = "test_collection"

        # Act - Non-streaming
        non_streaming_result = await graph_rag_streaming.query(
            query=query,
            user=user,
            collection=collection,
            streaming=False
        )

        # Act - Streaming
        streaming_chunks = []

        async def collect(chunk):
            streaming_chunks.append(chunk)

        streaming_result = await graph_rag_streaming.query(
            query=query,
            user=user,
            collection=collection,
            streaming=True,
            chunk_callback=collect
        )

        # Assert - Results should be equivalent
        assert streaming_result == non_streaming_result
        assert len(streaming_chunks) > 0
        assert "".join(streaming_chunks) == streaming_result

    @pytest.mark.asyncio
    async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
        """Test that chunk callback is invoked correctly"""
        # Arrange
        callback = AsyncMock()

        # Act
        result = await graph_rag_streaming.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            streaming=True,
            chunk_callback=callback
        )

        # Assert
        assert callback.call_count > 0
        assert result is not None

        # Verify all callback invocations had string arguments
        for call in callback.call_args_list:
            assert isinstance(call.args[0], str)

    @pytest.mark.asyncio
    async def test_graph_rag_streaming_without_callback(self, graph_rag_streaming):
        """Test streaming parameter without callback (should fall back to non-streaming)"""
        # Arrange & Act
        result = await graph_rag_streaming.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            streaming=True,
            chunk_callback=None  # No callback provided
        )

        # Assert - Should complete without error
        assert result is not None
        assert isinstance(result, str)

    @pytest.mark.asyncio
    async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
                                                     mock_graph_embeddings_client):
        """Test streaming with empty knowledge graph"""
        # Arrange
        mock_graph_embeddings_client.query.return_value = []  # No entities
        callback = AsyncMock()

        # Act
        result = await graph_rag_streaming.query(
            query="unknown topic",
            user="test_user",
            collection="test_collection",
            streaming=True,
            chunk_callback=callback
        )

        # Assert - Should still produce streamed response
        assert result is not None
        assert callback.call_count > 0

    @pytest.mark.asyncio
    async def test_graph_rag_streaming_error_propagation(self, graph_rag_streaming,
                                                         mock_embeddings_client):
        """Test that errors during streaming are properly propagated"""
        # Arrange
        mock_embeddings_client.embed.side_effect = Exception("Embeddings error")
        callback = AsyncMock()

        # Act & Assert
        with pytest.raises(Exception) as exc_info:
            await graph_rag_streaming.query(
                query="test query",
                user="test_user",
                collection="test_collection",
                streaming=True,
                chunk_callback=callback
            )

        assert "Embeddings error" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_graph_rag_streaming_preserves_parameters(self, graph_rag_streaming,
                                                            mock_graph_embeddings_client):
        """Test that streaming preserves all query parameters"""
        # Arrange
        callback = AsyncMock()
        entity_limit = 25
        triple_limit = 15

        # Act
        await graph_rag_streaming.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            entity_limit=entity_limit,
            triple_limit=triple_limit,
            streaming=True,
            chunk_callback=callback
        )

        # Assert - Verify parameters were passed to underlying services
        call_args = mock_graph_embeddings_client.query.call_args
        assert call_args.kwargs['limit'] == entity_limit
404 tests/integration/test_prompt_streaming_integration.py Normal file

@@ -0,0 +1,404 @@
"""
Integration tests for Prompt Service Streaming Functionality

These tests verify the streaming behavior of the Prompt service,
testing how it coordinates between templates and text completion streaming.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.prompt.template.service import Processor
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
from tests.utils.streaming_assertions import (
    assert_streaming_chunks_valid,
    assert_callback_invoked,
)


@pytest.mark.integration
class TestPromptStreaming:
    """Integration tests for Prompt service streaming"""

    @pytest.fixture
    def mock_flow_context_streaming(self):
        """Mock flow context with streaming text completion support"""
        context = MagicMock()

        # Mock text completion client with streaming
        text_completion_client = AsyncMock()

        async def streaming_request(request, recipient=None, timeout=600):
            """Simulate streaming text completion"""
            if request.streaming and recipient:
                # Simulate streaming chunks
                chunks = [
                    "Machine", " learning", " is", " a", " field",
                    " of", " artificial", " intelligence", "."
                ]

                for i, chunk_text in enumerate(chunks):
                    is_final = (i == len(chunks) - 1)
                    response = TextCompletionResponse(
                        response=chunk_text,
                        error=None,
                        end_of_stream=is_final
                    )
                    final = await recipient(response)
                    if final:
                        break

                # Final empty chunk
                await recipient(TextCompletionResponse(
                    response="",
                    error=None,
                    end_of_stream=True
                ))

        text_completion_client.request = streaming_request

        # Mock response producer
        response_producer = AsyncMock()

        def context_router(service_name):
            if service_name == "text-completion-request":
                return text_completion_client
            elif service_name == "response":
                return response_producer
            else:
                return AsyncMock()

        context.side_effect = context_router
        return context

    @pytest.fixture
    def mock_prompt_manager(self):
        """Mock PromptManager with simple template"""
        manager = MagicMock()

        async def invoke_template(kind, input_vars, llm_function):
            """Simulate template invocation"""
            # Call the LLM function with simple prompts
            system = "You are a helpful assistant."
            prompt = f"Question: {input_vars.get('question', 'test')}"
            result = await llm_function(system, prompt)
            return result

        manager.invoke = invoke_template
        return manager

    @pytest.fixture
    def prompt_processor_streaming(self, mock_prompt_manager):
        """Create Prompt processor with streaming support"""
        processor = MagicMock()
        processor.manager = mock_prompt_manager
        processor.config_key = "prompt"

        # Bind the actual on_request method
        processor.on_request = Processor.on_request.__get__(processor, Processor)

        return processor

    @pytest.mark.asyncio
    async def test_prompt_streaming_basic(self, prompt_processor_streaming, mock_flow_context_streaming):
        """Test basic prompt streaming functionality"""
        # Arrange
        request = PromptRequest(
            id="kg_prompt",
            terms={"question": '"What is machine learning?"'},
            streaming=True
        )

        message = MagicMock()
        message.value.return_value = request
        message.properties.return_value = {"id": "test-123"}

        consumer = MagicMock()

        # Act
        await prompt_processor_streaming.on_request(
            message, consumer, mock_flow_context_streaming
        )

        # Assert
        # Verify response producer was called multiple times (for streaming chunks)
        response_producer = mock_flow_context_streaming("response")
        assert response_producer.send.call_count > 0

        # Verify streaming chunks were sent
        calls = response_producer.send.call_args_list
        assert len(calls) > 1  # Should have multiple chunks

        # Check that responses have end_of_stream flag
        for call in calls:
            response = call.args[0]
            assert isinstance(response, PromptResponse)
            assert hasattr(response, 'end_of_stream')

        # Last response should have end_of_stream=True
        last_call = calls[-1]
        last_response = last_call.args[0]
        assert last_response.end_of_stream is True

    @pytest.mark.asyncio
    async def test_prompt_streaming_non_streaming_mode(self, prompt_processor_streaming,
                                                       mock_flow_context_streaming):
        """Test prompt service in non-streaming mode"""
        # Arrange
        request = PromptRequest(
            id="kg_prompt",
            terms={"question": '"What is AI?"'},
            streaming=False  # Non-streaming
        )

        message = MagicMock()
        message.value.return_value = request
        message.properties.return_value = {"id": "test-456"}

        consumer = MagicMock()

        # Mock non-streaming text completion
        text_completion_client = mock_flow_context_streaming("text-completion-request")

        async def non_streaming_text_completion(system, prompt, streaming=False):
            return "AI is the simulation of human intelligence in machines."

        text_completion_client.text_completion = non_streaming_text_completion

        # Act
        await prompt_processor_streaming.on_request(
            message, consumer, mock_flow_context_streaming
        )

        # Assert
        # Verify response producer was called once (non-streaming)
        response_producer = mock_flow_context_streaming("response")
        # Note: In non-streaming mode, the service sends a single response
        assert response_producer.send.call_count >= 1

    @pytest.mark.asyncio
    async def test_prompt_streaming_chunk_forwarding(self, prompt_processor_streaming,
                                                     mock_flow_context_streaming):
        """Test that prompt service forwards chunks immediately"""
        # Arrange
        request = PromptRequest(
            id="test_prompt",
            terms={"question": '"Test query"'},
            streaming=True
        )

        message = MagicMock()
        message.value.return_value = request
        message.properties.return_value = {"id": "test-789"}

        consumer = MagicMock()

        # Act
        await prompt_processor_streaming.on_request(
            message, consumer, mock_flow_context_streaming
        )

        # Assert
        # Verify chunks were forwarded with proper structure
        response_producer = mock_flow_context_streaming("response")
        calls = response_producer.send.call_args_list

        for call in calls:
            response = call.args[0]
            # Each response should have text and end_of_stream fields
            assert hasattr(response, 'text')
            assert hasattr(response, 'end_of_stream')

    @pytest.mark.asyncio
    async def test_prompt_streaming_error_handling(self, prompt_processor_streaming):
        """Test error handling during streaming"""
        # Arrange
        from trustgraph.schema import Error
        context = MagicMock()

        # Mock text completion client that raises an error
        text_completion_client = AsyncMock()

        async def failing_request(request, recipient=None, timeout=600):
            if recipient:
                # Send error response with proper Error schema
                error_response = TextCompletionResponse(
                    response="",
                    error=Error(message="Text completion error", type="processing_error"),
                    end_of_stream=True
                )
                await recipient(error_response)

        text_completion_client.request = failing_request

        # Mock response producer to capture error response
        response_producer = AsyncMock()

        def context_router(service_name):
            if service_name == "text-completion-request":
                return text_completion_client
            elif service_name == "response":
                return response_producer
            else:
                return AsyncMock()

        context.side_effect = context_router

        request = PromptRequest(
            id="test_prompt",
            terms={"question": '"Test"'},
            streaming=True
        )

        message = MagicMock()
        message.value.return_value = request
        message.properties.return_value = {"id": "test-error"}

        consumer = MagicMock()

        # Act - The service catches errors and sends error responses, doesn't raise
        await prompt_processor_streaming.on_request(message, consumer, context)

        # Assert - Verify error response was sent
        assert response_producer.send.call_count > 0

        # Check that at least one response contains an error
        error_sent = False
        for call in response_producer.send.call_args_list:
            response = call.args[0]
            if hasattr(response, 'error') and response.error:
                error_sent = True
                assert "Text completion error" in response.error.message
                break

        assert error_sent, "Expected error response to be sent"

    @pytest.mark.asyncio
    async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
                                                         mock_flow_context_streaming):
        """Test that message IDs are preserved through streaming"""
        # Arrange
        message_id = "unique-test-id-12345"

        request = PromptRequest(
            id="test_prompt",
            terms={"question": '"Test"'},
            streaming=True
        )

        message = MagicMock()
        message.value.return_value = request
        message.properties.return_value = {"id": message_id}

        consumer = MagicMock()

        # Act
        await prompt_processor_streaming.on_request(
            message, consumer, mock_flow_context_streaming
        )

        # Assert
        # Verify all responses were sent with the correct message ID
        response_producer = mock_flow_context_streaming("response")
        calls = response_producer.send.call_args_list

        for call in calls:
            properties = call.kwargs.get('properties')
            assert properties is not None
            assert properties['id'] == message_id

    @pytest.mark.asyncio
    async def test_prompt_streaming_empty_response_handling(self, prompt_processor_streaming):
        """Test handling of empty responses during streaming"""
        # Arrange
        context = MagicMock()

        # Mock text completion that sends empty chunks
        text_completion_client = AsyncMock()

        async def empty_streaming_request(request, recipient=None, timeout=600):
            if request.streaming and recipient:
                # Send empty chunk followed by final marker
                await recipient(TextCompletionResponse(
                    response="",
                    error=None,
                    end_of_stream=False
                ))
                await recipient(TextCompletionResponse(
                    response="",
                    error=None,
                    end_of_stream=True
                ))

        text_completion_client.request = empty_streaming_request
        response_producer = AsyncMock()

        def context_router(service_name):
            if service_name == "text-completion-request":
                return text_completion_client
            elif service_name == "response":
                return response_producer
            else:
                return AsyncMock()

        context.side_effect = context_router

        request = PromptRequest(
            id="test_prompt",
            terms={"question": '"Test"'},
            streaming=True
        )

        message = MagicMock()
        message.value.return_value = request
        message.properties.return_value = {"id": "test-empty"}

        consumer = MagicMock()

        # Act
        await prompt_processor_streaming.on_request(message, consumer, context)

        # Assert
        # Should still send responses even if empty (including final marker)
        assert response_producer.send.call_count > 0

        # Last response should have end_of_stream=True
        last_call = response_producer.send.call_args_list[-1]
        last_response = last_call.args[0]
        assert last_response.end_of_stream is True

    @pytest.mark.asyncio
    async def test_prompt_streaming_concatenation_matches_complete(self, prompt_processor_streaming,
                                                                   mock_flow_context_streaming):
        """Test that streaming chunks concatenate to form complete response"""
        # Arrange
        request = PromptRequest(
            id="test_prompt",
            terms={"question": '"What is ML?"'},
            streaming=True
        )

        message = MagicMock()
        message.value.return_value = request
        message.properties.return_value = {"id": "test-concat"}

        consumer = MagicMock()

        # Act
        await prompt_processor_streaming.on_request(
            message, consumer, mock_flow_context_streaming
        )

        # Assert
        # Collect all response texts
        response_producer = mock_flow_context_streaming("response")
        calls = response_producer.send.call_args_list

        chunk_texts = []
        for call in calls:
            response = call.args[0]
            if response.text and not response.end_of_stream:
                chunk_texts.append(response.text)

        # Verify chunks concatenate to expected result
        full_text = "".join(chunk_texts)
        assert full_text == "Machine learning is a field of artificial intelligence"
366 tests/integration/test_text_completion_streaming_integration.py Normal file

@@ -0,0 +1,366 @@
"""
Integration tests for Text Completion Streaming Functionality

These tests verify the streaming behavior of the Text Completion service,
testing token-by-token response delivery through the complete pipeline.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as StreamChoice, ChoiceDelta

from trustgraph.model.text_completion.openai.llm import Processor
from trustgraph.base import LlmChunk
from tests.utils.streaming_assertions import (
    assert_streaming_chunks_valid,
    assert_callback_invoked,
)


@pytest.mark.integration
class TestTextCompletionStreaming:
    """Integration tests for Text Completion streaming"""

    @pytest.fixture
    def mock_streaming_openai_client(self, mock_streaming_llm_response):
        """Mock OpenAI client with streaming support"""
        client = MagicMock()

        def create_streaming_completion(**kwargs):
            """Generator that yields streaming chunks"""
            # Check if streaming is enabled
            if not kwargs.get('stream', False):
                raise ValueError("Expected streaming mode")

            # Simulate OpenAI streaming response
            chunks_text = [
                "Machine", " learning", " is", " a", " subset",
                " of", " AI", " that", " enables", " computers",
                " to", " learn", " from", " data", "."
            ]

            for text in chunks_text:
                delta = ChoiceDelta(content=text, role=None)
                choice = StreamChoice(index=0, delta=delta, finish_reason=None)
                chunk = ChatCompletionChunk(
                    id="chatcmpl-streaming",
                    choices=[choice],
                    created=1234567890,
                    model="gpt-3.5-turbo",
                    object="chat.completion.chunk"
                )
                yield chunk

        # Return a new generator each time create is called
        client.chat.completions.create.side_effect = lambda **kwargs: create_streaming_completion(**kwargs)
        return client

    @pytest.fixture
    def text_completion_processor_streaming(self, mock_streaming_openai_client):
        """Create text completion processor with streaming support"""
        processor = MagicMock()
        processor.default_model = "gpt-3.5-turbo"
        processor.temperature = 0.7
        processor.max_output = 1024
        processor.openai = mock_streaming_openai_client

        # Bind the actual streaming method
        processor.generate_content_stream = Processor.generate_content_stream.__get__(
            processor, Processor
        )

        return processor

    @pytest.mark.asyncio
    async def test_text_completion_streaming_basic(self, text_completion_processor_streaming,
                                                   streaming_chunk_collector):
        """Test basic text completion streaming functionality"""
        # Arrange
        system_prompt = "You are a helpful assistant."
        user_prompt = "What is machine learning?"
        collector = streaming_chunk_collector()

        # Act - Collect all chunks
        chunks = []
        async for chunk in text_completion_processor_streaming.generate_content_stream(
            system_prompt, user_prompt
        ):
            chunks.append(chunk)
            if chunk.text:
                await collector.collect(chunk.text)

        # Assert
        assert len(chunks) > 1  # Should have multiple chunks

        # Verify all chunks are LlmChunk objects
        for chunk in chunks:
            assert isinstance(chunk, LlmChunk)
            assert chunk.model == "gpt-3.5-turbo"

        # Verify last chunk has is_final=True
        assert chunks[-1].is_final is True

        # Verify we got meaningful content
        full_text = collector.get_full_text()
        assert "machine" in full_text.lower() or "learning" in full_text.lower()

    @pytest.mark.asyncio
    async def test_text_completion_streaming_chunk_structure(self, text_completion_processor_streaming):
        """Test that streaming chunks have correct structure"""
        # Arrange
        system_prompt = "You are a helpful assistant."
        user_prompt = "Explain AI."

        # Act
        chunks = []
        async for chunk in text_completion_processor_streaming.generate_content_stream(
            system_prompt, user_prompt
        ):
            chunks.append(chunk)

        # Assert - Verify chunk structure
        for i, chunk in enumerate(chunks[:-1]):  # All except last
            assert isinstance(chunk, LlmChunk)
            assert chunk.text is not None
            assert chunk.model == "gpt-3.5-turbo"
            assert chunk.is_final is False

        # Last chunk should be final marker
        final_chunk = chunks[-1]
        assert final_chunk.is_final is True
        assert final_chunk.model == "gpt-3.5-turbo"

    @pytest.mark.asyncio
    async def test_text_completion_streaming_concatenation(self, text_completion_processor_streaming):
        """Test that chunks concatenate to form complete response"""
        # Arrange
        system_prompt = "You are a helpful assistant."
        user_prompt = "What is AI?"

        # Act - Collect all chunk texts
        chunk_texts = []
        async for chunk in text_completion_processor_streaming.generate_content_stream(
            system_prompt, user_prompt
        ):
            if chunk.text and not chunk.is_final:
                chunk_texts.append(chunk.text)

        # Assert
        full_text = "".join(chunk_texts)
        assert len(full_text) > 0
        assert len(chunk_texts) > 1  # Should have multiple chunks

        # Verify completeness - should be a coherent sentence
        assert full_text == "Machine learning is a subset of AI that enables computers to learn from data."

    @pytest.mark.asyncio
    async def test_text_completion_streaming_final_marker(self, text_completion_processor_streaming):
        """Test that final chunk properly marks end of stream"""
        # Arrange
        system_prompt = "You are a helpful assistant."
        user_prompt = "Test query"

        # Act
        chunks = []
        async for chunk in text_completion_processor_streaming.generate_content_stream(
            system_prompt, user_prompt
        ):
            chunks.append(chunk)

        # Assert
        # Should have at least content chunks + final marker
        assert len(chunks) >= 2

        # Only the last chunk should have is_final=True
        for chunk in chunks[:-1]:
            assert chunk.is_final is False

        assert chunks[-1].is_final is True

    @pytest.mark.asyncio
    async def test_text_completion_streaming_model_parameter(self, mock_streaming_openai_client):
        """Test that model parameter is preserved in streaming"""
        # Arrange
        processor = MagicMock()
        processor.default_model = "gpt-4"
        processor.temperature = 0.5
        processor.max_output = 2048
        processor.openai = mock_streaming_openai_client
        processor.generate_content_stream = Processor.generate_content_stream.__get__(
            processor, Processor
        )

        # Act
        chunks = []
        async for chunk in processor.generate_content_stream("System", "Prompt"):
            chunks.append(chunk)

        # Assert
        # Verify OpenAI was called with correct model
        call_args = mock_streaming_openai_client.chat.completions.create.call_args
        assert call_args.kwargs['model'] == "gpt-4"
        assert call_args.kwargs['temperature'] == 0.5
        assert call_args.kwargs['max_tokens'] == 2048
        assert call_args.kwargs['stream'] is True

        # Verify chunks have correct model
        for chunk in chunks:
            assert chunk.model == "gpt-4"

    @pytest.mark.asyncio
    async def test_text_completion_streaming_temperature_parameter(self, mock_streaming_openai_client):
        """Test that temperature parameter is applied in streaming"""
        # Arrange
        temperatures = [0.0, 0.5, 1.0, 1.5]

        for temp in temperatures:
            processor = MagicMock()
            processor.default_model = "gpt-3.5-turbo"
            processor.temperature = temp
            processor.max_output = 1024
            processor.openai = mock_streaming_openai_client
            processor.generate_content_stream = Processor.generate_content_stream.__get__(
                processor, Processor
            )

            # Act
            chunks = []
            async for chunk in processor.generate_content_stream("System", "Prompt"):
                chunks.append(chunk)
                if chunk.is_final:
                    break

            # Assert
            call_args = mock_streaming_openai_client.chat.completions.create.call_args
            assert call_args.kwargs['temperature'] == temp

            # Reset mock for next iteration
            mock_streaming_openai_client.reset_mock()

    @pytest.mark.asyncio
    async def test_text_completion_streaming_error_propagation(self):
        """Test that errors during streaming are properly propagated"""
        # Arrange
        mock_client = MagicMock()

        def failing_stream(**kwargs):
            yield from []
            raise Exception("Streaming error")

        mock_client.chat.completions.create.return_value = failing_stream()

        processor = MagicMock()
        processor.default_model = "gpt-3.5-turbo"
        processor.temperature = 0.7
        processor.max_output = 1024
        processor.openai = mock_client
        processor.generate_content_stream = Processor.generate_content_stream.__get__(
            processor, Processor
        )

        # Act & Assert
        with pytest.raises(Exception) as exc_info:
            async for chunk in processor.generate_content_stream("System", "Prompt"):
                pass

        assert "Streaming error" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_text_completion_streaming_empty_chunks_filtered(self, mock_streaming_openai_client):
        """Test that empty chunks are handled correctly"""
        # Arrange - Mock that returns some empty chunks
        def create_streaming_with_empties(**kwargs):
            chunks_text = ["Hello", "", " world", "", "!"]

            for text in chunks_text:
                delta = ChoiceDelta(content=text if text else None, role=None)
                choice = StreamChoice(index=0, delta=delta, finish_reason=None)
                chunk = ChatCompletionChunk(
                    id="chatcmpl-streaming",
                    choices=[choice],
                    created=1234567890,
                    model="gpt-3.5-turbo",
                    object="chat.completion.chunk"
                )
                yield chunk

        mock_streaming_openai_client.chat.completions.create.side_effect = lambda **kwargs: create_streaming_with_empties(**kwargs)

        processor = MagicMock()
        processor.default_model = "gpt-3.5-turbo"
        processor.temperature = 0.7
        processor.max_output = 1024
        processor.openai = mock_streaming_openai_client
        processor.generate_content_stream = Processor.generate_content_stream.__get__(
            processor, Processor
        )

        # Act
        chunks = []
        async for chunk in processor.generate_content_stream("System", "Prompt"):
            chunks.append(chunk)

        # Assert - Only non-empty chunks should be yielded (plus final marker)
        text_chunks = [c for c in chunks if not c.is_final]
        assert len(text_chunks) == 3  # "Hello", " world", "!"
        assert "".join(c.text for c in text_chunks) == "Hello world!"

    @pytest.mark.asyncio
    async def test_text_completion_streaming_prompt_construction(self, mock_streaming_openai_client):
        """Test that system and user prompts are correctly combined for streaming"""
        # Arrange
        processor = MagicMock()
        processor.default_model = "gpt-3.5-turbo"
        processor.temperature = 0.7
        processor.max_output = 1024
        processor.openai = mock_streaming_openai_client
        processor.generate_content_stream = Processor.generate_content_stream.__get__(
            processor, Processor
        )

        system_prompt = "You are an expert."
        user_prompt = "Explain quantum physics."

        # Act
        chunks = []
        async for chunk in processor.generate_content_stream(system_prompt, user_prompt):
            chunks.append(chunk)
            if chunk.is_final:
                break

        # Assert - Verify prompts were combined correctly
        call_args = mock_streaming_openai_client.chat.completions.create.call_args
        messages = call_args.kwargs['messages']
        assert len(messages) == 1

        message_content = messages[0]['content'][0]['text']
        assert system_prompt in message_content
        assert user_prompt in message_content
        assert message_content.startswith(system_prompt)

    @pytest.mark.asyncio
    async def test_text_completion_streaming_chunk_count(self, text_completion_processor_streaming):
        """Test that streaming produces expected number of chunks"""
        # Arrange
        system_prompt = "You are a helpful assistant."
        user_prompt = "Test"

        # Act
        chunks = []
        async for chunk in text_completion_processor_streaming.generate_content_stream(
            system_prompt, user_prompt
        ):
            chunks.append(chunk)

        # Assert
        # Should have 15 content chunks + 1 final marker = 16 total
        assert len(chunks) == 16

        # 15 content chunks
        content_chunks = [c for c in chunks if not c.is_final]
        assert len(content_chunks) == 15

        # 1 final marker
        final_chunks = [c for c in chunks if c.is_final]
        assert len(final_chunks) == 1
29 tests/utils/__init__.py Normal file

@@ -0,0 +1,29 @@
"""Test utilities for TrustGraph tests"""

from .streaming_assertions import (
    assert_streaming_chunks_valid,
    assert_streaming_sequence,
    assert_agent_streaming_chunks,
    assert_rag_streaming_chunks,
    assert_streaming_completion,
    assert_streaming_content_matches,
    assert_no_empty_chunks,
    assert_streaming_error_handled,
    assert_chunk_types_valid,
    assert_streaming_latency_acceptable,
    assert_callback_invoked,
)

__all__ = [
    "assert_streaming_chunks_valid",
    "assert_streaming_sequence",
    "assert_agent_streaming_chunks",
    "assert_rag_streaming_chunks",
    "assert_streaming_completion",
    "assert_streaming_content_matches",
    "assert_no_empty_chunks",
    "assert_streaming_error_handled",
    "assert_chunk_types_valid",
    "assert_streaming_latency_acceptable",
    "assert_callback_invoked",
]
218 tests/utils/streaming_assertions.py Normal file

@@ -0,0 +1,218 @@
"""
Streaming test assertion helpers

Provides reusable assertion functions for validating streaming behavior
across different TrustGraph services.
"""

from typing import List, Dict, Any, Optional


def assert_streaming_chunks_valid(chunks: List[Any], min_chunks: int = 1):
    """
    Assert that streaming chunks are valid and non-empty.

    Args:
        chunks: List of streaming chunks
        min_chunks: Minimum number of expected chunks
    """
    assert len(chunks) >= min_chunks, f"Expected at least {min_chunks} chunks, got {len(chunks)}"
    assert all(chunk is not None for chunk in chunks), "All chunks should be non-None"


def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "chunk_type"):
    """
    Assert that streaming chunks follow an expected sequence.

    Args:
        chunks: List of chunk dictionaries
        expected_sequence: Expected sequence of chunk types/values
        key: Dictionary key to check (default: "chunk_type")
    """
    actual_sequence = [chunk.get(key) for chunk in chunks if key in chunk]
    assert actual_sequence == expected_sequence, \
        f"Expected sequence {expected_sequence}, got {actual_sequence}"


def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
    """
    Assert that agent streaming chunks have valid structure.

    Validates:
    - All chunks have chunk_type field
    - All chunks have content field
    - All chunks have end_of_message field
    - All chunks have end_of_dialog field
    - Last chunk has end_of_dialog=True

    Args:
        chunks: List of agent streaming chunk dictionaries
    """
    assert len(chunks) > 0, "Expected at least one chunk"

    for i, chunk in enumerate(chunks):
        assert "chunk_type" in chunk, f"Chunk {i} missing chunk_type"
        assert "content" in chunk, f"Chunk {i} missing content"
        assert "end_of_message" in chunk, f"Chunk {i} missing end_of_message"
        assert "end_of_dialog" in chunk, f"Chunk {i} missing end_of_dialog"

        # Validate chunk_type values
        valid_types = ["thought", "action", "observation", "final-answer"]
        assert chunk["chunk_type"] in valid_types, \
            f"Invalid chunk_type '{chunk['chunk_type']}' at index {i}"

    # Last chunk should signal end of dialog
    assert chunks[-1]["end_of_dialog"] is True, \
        "Last chunk should have end_of_dialog=True"


def assert_rag_streaming_chunks(chunks: List[Dict[str, Any]]):
    """
    Assert that RAG streaming chunks have valid structure.

    Validates:
    - All chunks except last have chunk field
    - All chunks have end_of_stream field
    - Last chunk has end_of_stream=True
    - Last chunk may have response field with complete text

    Args:
        chunks: List of RAG streaming chunk dictionaries
    """
    assert len(chunks) > 0, "Expected at least one chunk"

    for i, chunk in enumerate(chunks):
        assert "end_of_stream" in chunk, f"Chunk {i} missing end_of_stream"

        if i < len(chunks) - 1:
            # Non-final chunks should have chunk content and end_of_stream=False
            assert "chunk" in chunk, f"Chunk {i} missing chunk field"
            assert chunk["end_of_stream"] is False, \
                f"Non-final chunk {i} should have end_of_stream=False"
        else:
            # Final chunk should have end_of_stream=True
            assert chunk["end_of_stream"] is True, \
                "Last chunk should have end_of_stream=True"


def assert_streaming_completion(chunks: List[Dict[str, Any]], expected_complete_flag: str = "end_of_stream"):
    """
    Assert that streaming completed properly.

    Args:
        chunks: List of streaming chunk dictionaries
        expected_complete_flag: Name of the completion flag field
    """
    assert len(chunks) > 0, "Expected at least one chunk"

    # Check that all but last chunk have completion flag = False
    for i, chunk in enumerate(chunks[:-1]):
        assert chunk.get(expected_complete_flag) is False, \
            f"Non-final chunk {i} should have {expected_complete_flag}=False"

    # Check that last chunk has completion flag = True
    assert chunks[-1].get(expected_complete_flag) is True, \
        f"Final chunk should have {expected_complete_flag}=True"


def assert_streaming_content_matches(chunks: List, expected_content: str, content_key: str = "chunk"):
    """
    Assert that concatenated streaming chunks match expected content.

    Args:
        chunks: List of streaming chunks (strings or dicts)
        expected_content: Expected complete content after concatenation
        content_key: Dictionary key for content (used if chunks are dicts)
    """
    if isinstance(chunks[0], dict):
        # Extract content from chunk dictionaries
        content_chunks = [
            chunk.get(content_key, "")
            for chunk in chunks
            if chunk.get(content_key) is not None
        ]
        actual_content = "".join(content_chunks)
    else:
        # Chunks are already strings
        actual_content = "".join(chunks)

    assert actual_content == expected_content, \
        f"Expected content '{expected_content}', got '{actual_content}'"


def assert_no_empty_chunks(chunks: List[Dict[str, Any]], content_key: str = "content"):
    """
    Assert that no chunks have empty content (except the final chunk if it's a completion marker).

    Args:
        chunks: List of streaming chunk dictionaries
        content_key: Dictionary key for content
    """
    for i, chunk in enumerate(chunks[:-1]):
        content = chunk.get(content_key)
        assert content is not None and len(content) > 0, \
            f"Chunk {i} has empty content"


def assert_streaming_error_handled(chunks: List[Dict[str, Any]], error_flag: str = "error"):
    """
    Assert that a streaming error was properly signaled.

    Args:
        chunks: List of streaming chunk dictionaries
        error_flag: Name of the error flag field
    """
    # Check that at least one chunk has the error flag
    has_error = any(chunk.get(error_flag) is not None for chunk in chunks)
    assert has_error, "Expected error flag in at least one chunk"

    # If the last chunk has an error, it should also have a completion flag
    if chunks[-1].get(error_flag):
        # Check for completion flags (either end_of_stream or end_of_dialog)
        completion_flags = ["end_of_stream", "end_of_dialog"]
        has_completion = any(chunks[-1].get(flag) is True for flag in completion_flags)
        assert has_completion, \
            "Error chunk should have completion flag set to True"


def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "chunk_type"):
    """
    Assert that all chunk types are from a valid set.

    Args:
        chunks: List of streaming chunk dictionaries
        valid_types: List of valid chunk type values
        type_key: Dictionary key for chunk type
    """
    for i, chunk in enumerate(chunks):
        chunk_type = chunk.get(type_key)
        assert chunk_type in valid_types, \
            f"Chunk {i} has invalid type '{chunk_type}', expected one of {valid_types}"


def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0):
    """
    Assert that streaming latency between chunks is acceptable.

    Args:
        chunk_timestamps: List of timestamps when chunks were received
        max_gap_seconds: Maximum acceptable gap between chunks in seconds
    """
    assert len(chunk_timestamps) > 1, "Need at least 2 timestamps to check latency"

    for i in range(1, len(chunk_timestamps)):
        gap = chunk_timestamps[i] - chunk_timestamps[i-1]
        assert gap <= max_gap_seconds, \
            f"Gap between chunks {i-1} and {i} is {gap:.2f}s, exceeds max {max_gap_seconds}s"


def assert_callback_invoked(mock_callback, min_calls: int = 1):
    """
    Assert that a streaming callback was invoked a minimum number of times.

    Args:
        mock_callback: AsyncMock callback object
        min_calls: Minimum number of expected calls
    """
    assert mock_callback.call_count >= min_calls, \
        f"Expected callback to be called at least {min_calls} times, was called {mock_callback.call_count} times"
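The chunk contract that `assert_rag_streaming_chunks` and `assert_streaming_content_matches` check can be sketched on synthetic chunk dictionaries (illustrative values only, not real service output):

```python
# Synthetic RAG streaming chunks following the contract the helpers
# above enforce: non-final chunks carry "chunk" text with
# end_of_stream=False; the final message sets end_of_stream=True and
# may carry the assembled "response".
chunks = [
    {"chunk": "Graph", "end_of_stream": False},
    {"chunk": "RAG ", "end_of_stream": False},
    {"chunk": "answer", "end_of_stream": False},
    {"response": "GraphRAG answer", "end_of_stream": True},
]

# Every non-final chunk carries content and is not terminal.
for i, c in enumerate(chunks[:-1]):
    assert c["chunk"], f"Chunk {i} should carry content"
    assert c["end_of_stream"] is False

# Only the last message signals completion.
assert chunks[-1]["end_of_stream"] is True

# Concatenating the content chunks reproduces the final response text.
assembled = "".join(c.get("chunk", "") for c in chunks)
assert assembled == chunks[-1]["response"]
```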
@@ -112,7 +112,7 @@ class PromptClient(RequestResponse):
             timeout = timeout,
         )

-    async def kg_prompt(self, query, kg, timeout=600):
+    async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
         return await self.prompt(
             id = "kg-prompt",
             variables = {
@@ -123,9 +123,11 @@ class PromptClient(RequestResponse):
                 ]
             },
             timeout = timeout,
+            streaming = streaming,
+            chunk_callback = chunk_callback,
         )

-    async def document_prompt(self, query, documents, timeout=600):
+    async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
         return await self.prompt(
             id = "document-prompt",
             variables = {
@@ -133,6 +135,8 @@ class PromptClient(RequestResponse):
                 "documents": documents,
             },
             timeout = timeout,
+            streaming = streaming,
+            chunk_callback = chunk_callback,
         )

     async def agent_react(self, variables, timeout=600, streaming=False, chunk_callback=None):
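A hedged sketch of how a caller consumes the new `kg_prompt` parameters. The stub client and its token list are assumptions for illustration, standing in for the real `PromptClient`; only the parameter contract (`streaming`, `chunk_callback`) comes from the change above:

```python
import asyncio

class StubPromptClient:
    """Stand-in mimicking the streaming contract above: when
    streaming=True, each token is delivered via chunk_callback and the
    full text is still returned at the end."""

    async def kg_prompt(self, query, kg, timeout=600,
                        streaming=False, chunk_callback=None):
        tokens = ["Paris ", "is ", "the ", "capital."]
        if streaming and chunk_callback:
            for t in tokens:
                await chunk_callback(t)
        return "".join(tokens)

async def main():
    received = []

    async def on_chunk(chunk):
        received.append(chunk)   # e.g. forward to a websocket client

    client = StubPromptClient()
    full = await client.kg_prompt(
        "What is the capital of France?", kg=[],
        streaming=True, chunk_callback=on_chunk,
    )

    # Chunks arrive incrementally and reassemble into the full response.
    assert "".join(received) == full
    return full

result = asyncio.run(main())
```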
@@ -11,7 +11,8 @@ class DocumentRagRequestTranslator(MessageTranslator):
             query=data["query"],
             user=data.get("user", "trustgraph"),
             collection=data.get("collection", "default"),
-            doc_limit=int(data.get("doc-limit", 20))
+            doc_limit=int(data.get("doc-limit", 20)),
+            streaming=data.get("streaming", False)
         )

     def from_pulsar(self, obj: DocumentRagQuery) -> Dict[str, Any]:
@@ -19,7 +20,8 @@ class DocumentRagRequestTranslator(MessageTranslator):
             "query": obj.query,
             "user": obj.user,
             "collection": obj.collection,
-            "doc-limit": obj.doc_limit
+            "doc-limit": obj.doc_limit,
+            "streaming": getattr(obj, "streaming", False)
         }

@@ -30,13 +32,33 @@ class DocumentRagResponseTranslator(MessageTranslator):
         raise NotImplementedError("Response translation to Pulsar not typically needed")

     def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
-        return {
-            "response": obj.response
-        }
+        result = {}
+
+        # Check if this is a streaming response (has chunk)
+        if hasattr(obj, 'chunk') and obj.chunk:
+            result["chunk"] = obj.chunk
+            result["end_of_stream"] = getattr(obj, "end_of_stream", False)
+        else:
+            # Non-streaming response
+            if obj.response:
+                result["response"] = obj.response
+
+        # Always include error if present
+        if hasattr(obj, 'error') and obj.error and obj.error.message:
+            result["error"] = {"message": obj.error.message, "type": obj.error.type}
+
+        return result

     def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
         """Returns (response_dict, is_final)"""
-        return self.from_pulsar(obj), True
+        # For streaming responses, check end_of_stream
+        if hasattr(obj, 'chunk') and obj.chunk:
+            is_final = getattr(obj, 'end_of_stream', False)
+        else:
+            # For non-streaming responses, it's always final
+            is_final = True
+
+        return self.from_pulsar(obj), is_final


class GraphRagRequestTranslator(MessageTranslator):

@@ -50,7 +72,8 @@ class GraphRagRequestTranslator(MessageTranslator):
             entity_limit=int(data.get("entity-limit", 50)),
             triple_limit=int(data.get("triple-limit", 30)),
             max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
-            max_path_length=int(data.get("max-path-length", 2))
+            max_path_length=int(data.get("max-path-length", 2)),
+            streaming=data.get("streaming", False)
         )

     def from_pulsar(self, obj: GraphRagQuery) -> Dict[str, Any]:
@@ -61,7 +84,8 @@ class GraphRagRequestTranslator(MessageTranslator):
             "entity-limit": obj.entity_limit,
             "triple-limit": obj.triple_limit,
             "max-subgraph-size": obj.max_subgraph_size,
-            "max-path-length": obj.max_path_length
+            "max-path-length": obj.max_path_length,
+            "streaming": getattr(obj, "streaming", False)
         }

@@ -72,10 +96,30 @@ class GraphRagResponseTranslator(MessageTranslator):
         raise NotImplementedError("Response translation to Pulsar not typically needed")

     def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
-        return {
-            "response": obj.response
-        }
+        result = {}
+
+        # Check if this is a streaming response (has chunk)
+        if hasattr(obj, 'chunk') and obj.chunk:
+            result["chunk"] = obj.chunk
+            result["end_of_stream"] = getattr(obj, "end_of_stream", False)
+        else:
+            # Non-streaming response
+            if obj.response:
+                result["response"] = obj.response
+
+        # Always include error if present
+        if hasattr(obj, 'error') and obj.error and obj.error.message:
+            result["error"] = {"message": obj.error.message, "type": obj.error.type}
+
+        return result

     def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
         """Returns (response_dict, is_final)"""
-        return self.from_pulsar(obj), True
+        # For streaming responses, check end_of_stream
+        if hasattr(obj, 'chunk') and obj.chunk:
+            is_final = getattr(obj, 'end_of_stream', False)
+        else:
+            # For non-streaming responses, it's always final
+            is_final = True
+
+        return self.from_pulsar(obj), is_final
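The completion rule implemented in `from_response_with_completion` can be mirrored on simple stand-in objects (`Resp` and the free-standing `is_final` here are illustrative, not part of the codebase):

```python
from types import SimpleNamespace as Resp

def is_final(obj) -> bool:
    """Mirrors the translator logic above: a message carrying chunk
    content is final only when end_of_stream is set; anything else
    (the classic single response) is always final."""
    if getattr(obj, "chunk", None):
        return getattr(obj, "end_of_stream", False)
    return True

# Mid-stream chunk: not final.
assert is_final(Resp(chunk="partial", end_of_stream=False)) is False

# Terminal message (no chunk, complete response): final.
assert is_final(Resp(chunk=None, end_of_stream=True, response="done")) is True

# Plain non-streaming response: always final.
assert is_final(Resp(chunk=None, response="done")) is True
```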
@@ -15,10 +15,13 @@ class GraphRagQuery(Record):
     triple_limit = Integer()
     max_subgraph_size = Integer()
     max_path_length = Integer()
+    streaming = Boolean()

 class GraphRagResponse(Record):
     error = Error()
     response = String()
+    chunk = String()
+    end_of_stream = Boolean()

 ############################################################################

@@ -29,8 +32,11 @@ class DocumentRagQuery(Record):
     user = String()
     collection = String()
     doc_limit = Integer()
+    streaming = Boolean()

 class DocumentRagResponse(Record):
     error = Error()
     response = String()
+    chunk = String()
+    end_of_stream = Boolean()
@@ -4,6 +4,10 @@ Uses the DocumentRAG service to answer a question

 import argparse
 import os
+import asyncio
+import json
+import uuid
+from websockets.asyncio.client import connect
 from trustgraph.api import Api

 default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
@@ -11,7 +15,69 @@ default_user = 'trustgraph'
 default_collection = 'default'
 default_doc_limit = 10

-def question(url, flow_id, question, user, collection, doc_limit):
+async def question_streaming(url, flow_id, question, user, collection, doc_limit):
+    """Streaming version using websockets"""
+
+    # Convert http:// to ws://
+    if url.startswith('http://'):
+        url = 'ws://' + url[7:]
+    elif url.startswith('https://'):
+        url = 'wss://' + url[8:]
+
+    if not url.endswith("/"):
+        url += "/"
+
+    url = url + "api/v1/socket"
+
+    mid = str(uuid.uuid4())
+
+    async with connect(url) as ws:
+        req = {
+            "id": mid,
+            "service": "document-rag",
+            "flow": flow_id,
+            "request": {
+                "query": question,
+                "user": user,
+                "collection": collection,
+                "doc-limit": doc_limit,
+                "streaming": True
+            }
+        }
+
+        req = json.dumps(req)
+        await ws.send(req)
+
+        while True:
+            msg = await ws.recv()
+            obj = json.loads(msg)
+
+            if "error" in obj:
+                raise RuntimeError(obj["error"])
+
+            if obj["id"] != mid:
+                print("Ignore message")
+                continue
+
+            response = obj["response"]
+
+            # Handle streaming format (chunk)
+            if "chunk" in response:
+                chunk = response["chunk"]
+                print(chunk, end="", flush=True)
+            elif "response" in response:
+                # Final response with complete text
+                # Already printed via chunks, just add newline
+                pass
+
+            if obj["complete"]:
+                print()  # Final newline
+                break
+
+        await ws.close()
+
+def question_non_streaming(url, flow_id, question, user, collection, doc_limit):
+    """Non-streaming version using HTTP API"""
+
     api = Api(url).flow().id(flow_id)

@@ -65,18 +131,36 @@ def main():
         help=f'Document limit (default: {default_doc_limit})'
     )

+    parser.add_argument(
+        '--no-streaming',
+        action='store_true',
+        help='Disable streaming (use non-streaming mode)'
+    )
+
     args = parser.parse_args()

     try:

-        question(
-            url=args.url,
-            flow_id = args.flow_id,
-            question=args.question,
-            user=args.user,
-            collection=args.collection,
-            doc_limit=args.doc_limit,
-        )
+        if not args.no_streaming:
+            asyncio.run(
+                question_streaming(
+                    url=args.url,
+                    flow_id=args.flow_id,
+                    question=args.question,
+                    user=args.user,
+                    collection=args.collection,
+                    doc_limit=args.doc_limit,
+                )
+            )
+        else:
+            question_non_streaming(
+                url=args.url,
+                flow_id=args.flow_id,
+                question=args.question,
+                user=args.user,
+                collection=args.collection,
+                doc_limit=args.doc_limit,
+            )

     except Exception as e:
@@ -4,6 +4,10 @@ Uses the GraphRAG service to answer a question

 import argparse
 import os
+import asyncio
+import json
+import uuid
+from websockets.asyncio.client import connect
 from trustgraph.api import Api

 default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
@@ -14,10 +18,78 @@ default_triple_limit = 30
 default_max_subgraph_size = 150
 default_max_path_length = 2

-def question(
+async def question_streaming(
     url, flow_id, question, user, collection, entity_limit, triple_limit,
     max_subgraph_size, max_path_length
 ):
+    """Streaming version using websockets"""
+
+    # Convert http:// to ws://
+    if url.startswith('http://'):
+        url = 'ws://' + url[7:]
+    elif url.startswith('https://'):
+        url = 'wss://' + url[8:]
+
+    if not url.endswith("/"):
+        url += "/"
+
+    url = url + "api/v1/socket"
+
+    mid = str(uuid.uuid4())
+
+    async with connect(url) as ws:
+        req = {
+            "id": mid,
+            "service": "graph-rag",
+            "flow": flow_id,
+            "request": {
+                "query": question,
+                "user": user,
+                "collection": collection,
+                "entity-limit": entity_limit,
+                "triple-limit": triple_limit,
+                "max-subgraph-size": max_subgraph_size,
+                "max-path-length": max_path_length,
+                "streaming": True
+            }
+        }
+
+        req = json.dumps(req)
+        await ws.send(req)
+
+        while True:
+            msg = await ws.recv()
+            obj = json.loads(msg)
+
+            if "error" in obj:
+                raise RuntimeError(obj["error"])
+
+            if obj["id"] != mid:
+                print("Ignore message")
+                continue
+
+            response = obj["response"]
+
+            # Handle streaming format (chunk)
+            if "chunk" in response:
+                chunk = response["chunk"]
+                print(chunk, end="", flush=True)
+            elif "response" in response:
+                # Final response with complete text
+                # Already printed via chunks, just add newline
+                pass
+
+            if obj["complete"]:
+                print()  # Final newline
+                break
+
+        await ws.close()
+
+def question_non_streaming(
+    url, flow_id, question, user, collection, entity_limit, triple_limit,
+    max_subgraph_size, max_path_length
+):
+    """Non-streaming version using HTTP API"""
+
     api = Api(url).flow().id(flow_id)

@@ -91,21 +163,42 @@ def main():
         help=f'Max path length (default: {default_max_path_length})'
     )

+    parser.add_argument(
+        '--no-streaming',
+        action='store_true',
+        help='Disable streaming (use non-streaming mode)'
+    )
+
     args = parser.parse_args()

     try:

-        question(
-            url=args.url,
-            flow_id = args.flow_id,
-            question=args.question,
-            user=args.user,
-            collection=args.collection,
-            entity_limit=args.entity_limit,
-            triple_limit=args.triple_limit,
-            max_subgraph_size=args.max_subgraph_size,
-            max_path_length=args.max_path_length,
-        )
+        if not args.no_streaming:
+            asyncio.run(
+                question_streaming(
+                    url=args.url,
+                    flow_id=args.flow_id,
+                    question=args.question,
+                    user=args.user,
+                    collection=args.collection,
+                    entity_limit=args.entity_limit,
+                    triple_limit=args.triple_limit,
+                    max_subgraph_size=args.max_subgraph_size,
+                    max_path_length=args.max_path_length,
+                )
+            )
+        else:
+            question_non_streaming(
+                url=args.url,
+                flow_id=args.flow_id,
+                question=args.question,
+                user=args.user,
+                collection=args.collection,
+                entity_limit=args.entity_limit,
+                triple_limit=args.triple_limit,
+                max_subgraph_size=args.max_subgraph_size,
+                max_path_length=args.max_path_length,
+            )

     except Exception as e:
@@ -202,12 +202,16 @@ class StreamingReActParser:
             # Find which comes first
             if args_idx >= 0 and (newline_idx < 0 or args_idx < newline_idx):
                 # Args delimiter found first
-                self.action_buffer = self.line_buffer[:args_idx].strip().strip('"')
+                # Only set action_buffer if not already set (to avoid overwriting with empty string)
+                if not self.action_buffer:
+                    self.action_buffer = self.line_buffer[:args_idx].strip().strip('"')
                 self.line_buffer = self.line_buffer[args_idx + len(self.ARGS_DELIMITER):].lstrip()
                 self.state = ParserState.ARGS
             elif newline_idx >= 0:
                 # Newline found, action name complete
-                self.action_buffer = self.line_buffer[:newline_idx].strip().strip('"')
+                # Only set action_buffer if not already set
+                if not self.action_buffer:
+                    self.action_buffer = self.line_buffer[:newline_idx].strip().strip('"')
                 self.line_buffer = self.line_buffer[newline_idx + 1:]
                 # Stay in ACTION state or move to ARGS if we find delimiter
                 # Actually, check if next line has Args:
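The intent of the `if not self.action_buffer` guard can be shown in isolation (a minimal stand-in for the buffer handling, not the actual `StreamingReActParser`):

```python
class ActionBufferDemo:
    """Minimal stand-in showing why the guard matters: once the action
    name has been captured from an earlier fragment, a later fragment
    of the same line must not overwrite it with an empty string."""

    def __init__(self):
        self.action_buffer = ""

    def feed_line(self, line: str):
        # Only set action_buffer if not already set, mirroring the fix above.
        if not self.action_buffer:
            self.action_buffer = line.strip().strip('"')

demo = ActionBufferDemo()
demo.feed_line('"knowledge-query"')   # first fragment carries the action name
demo.feed_line('')                    # later empty fragment must not clobber it
assert demo.action_buffer == "knowledge-query"
```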
@@ -68,7 +68,7 @@ class DocumentRag:

     async def query(
         self, query, user="trustgraph", collection="default",
-        doc_limit=20,
+        doc_limit=20, streaming=False, chunk_callback=None,
     ):

         if self.verbose:
@@ -86,10 +86,18 @@ class DocumentRag:
         logger.debug(f"Documents: {docs}")
         logger.debug(f"Query: {query}")

-        resp = await self.prompt_client.document_prompt(
-            query = query,
-            documents = docs
-        )
+        if streaming and chunk_callback:
+            resp = await self.prompt_client.document_prompt(
+                query=query,
+                documents=docs,
+                streaming=True,
+                chunk_callback=chunk_callback
+            )
+        else:
+            resp = await self.prompt_client.document_prompt(
+                query=query,
+                documents=docs
+            )

         if self.verbose:
             logger.debug("Query processing complete")
@@ -92,20 +92,56 @@ class Processor(FlowProcessor):
         else:
             doc_limit = self.doc_limit

-        response = await self.rag.query(
-            v.query,
-            user=v.user,
-            collection=v.collection,
-            doc_limit=doc_limit
-        )
+        # Check if streaming is requested
+        if v.streaming:
+            # Define async callback for streaming chunks
+            async def send_chunk(chunk):
+                await flow("response").send(
+                    DocumentRagResponse(
+                        chunk=chunk,
+                        end_of_stream=False,
+                        response=None,
+                        error=None
+                    ),
+                    properties={"id": id}
+                )

-        await flow("response").send(
-            DocumentRagResponse(
-                response = response,
-                error = None
-            ),
-            properties = {"id": id}
-        )
+            # Query with streaming enabled
+            full_response = await self.rag.query(
+                v.query,
+                user=v.user,
+                collection=v.collection,
+                doc_limit=doc_limit,
+                streaming=True,
+                chunk_callback=send_chunk,
+            )
+
+            # Send final message with complete response
+            await flow("response").send(
+                DocumentRagResponse(
+                    chunk=None,
+                    end_of_stream=True,
+                    response=full_response,
+                    error=None
+                ),
+                properties={"id": id}
+            )
+        else:
+            # Non-streaming path (existing behavior)
+            response = await self.rag.query(
+                v.query,
+                user=v.user,
+                collection=v.collection,
+                doc_limit=doc_limit
+            )
+
+            await flow("response").send(
+                DocumentRagResponse(
+                    response = response,
+                    error = None
+                ),
+                properties = {"id": id}
+            )

         logger.info("Request processing complete")
@@ -115,14 +151,21 @@ class Processor(FlowProcessor):

         logger.debug("Sending error response...")

-        await flow("response").send(
-            DocumentRagResponse(
-                response = None,
-                error = Error(
-                    type = "document-rag-error",
-                    message = str(e),
-                ),
-            ),
+        # Send error response with end_of_stream flag if streaming was requested
+        error_response = DocumentRagResponse(
+            response = None,
+            error = Error(
+                type = "document-rag-error",
+                message = str(e),
+            ),
+        )
+
+        # If streaming was requested, indicate stream end
+        if v.streaming:
+            error_response.end_of_stream = True
+
+        await flow("response").send(
+            error_response,
             properties = {"id": id}
         )
@@ -316,7 +316,7 @@ class GraphRag:
     async def query(
         self, query, user = "trustgraph", collection = "default",
         entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
-        max_path_length = 2,
+        max_path_length = 2, streaming = False, chunk_callback = None,
     ):

         if self.verbose:
@@ -337,7 +337,14 @@ class GraphRag:
         logger.debug(f"Knowledge graph: {kg}")
         logger.debug(f"Query: {query}")

-        resp = await self.prompt_client.kg_prompt(query, kg)
+        if streaming and chunk_callback:
+            resp = await self.prompt_client.kg_prompt(
+                query, kg,
+                streaming=True,
+                chunk_callback=chunk_callback
+            )
+        else:
+            resp = await self.prompt_client.kg_prompt(query, kg)

         if self.verbose:
             logger.debug("Query processing complete")
@@ -135,20 +135,56 @@ class Processor(FlowProcessor):
         else:
             max_path_length = self.default_max_path_length

-        response = await rag.query(
-            query = v.query, user = v.user, collection = v.collection,
-            entity_limit = entity_limit, triple_limit = triple_limit,
-            max_subgraph_size = max_subgraph_size,
-            max_path_length = max_path_length,
-        )
+        # Check if streaming is requested
+        if v.streaming:
+            # Define async callback for streaming chunks
+            async def send_chunk(chunk):
+                await flow("response").send(
+                    GraphRagResponse(
+                        chunk=chunk,
+                        end_of_stream=False,
+                        response=None,
+                        error=None
+                    ),
+                    properties={"id": id}
+                )

-        await flow("response").send(
-            GraphRagResponse(
-                response = response,
-                error = None
-            ),
-            properties = {"id": id}
-        )
+            # Query with streaming enabled
+            full_response = await rag.query(
+                query = v.query, user = v.user, collection = v.collection,
+                entity_limit = entity_limit, triple_limit = triple_limit,
+                max_subgraph_size = max_subgraph_size,
+                max_path_length = max_path_length,
+                streaming = True,
+                chunk_callback = send_chunk,
+            )
+
+            # Send final message with complete response
+            await flow("response").send(
+                GraphRagResponse(
+                    chunk=None,
+                    end_of_stream=True,
+                    response=full_response,
+                    error=None
+                ),
+                properties={"id": id}
+            )
+        else:
+            # Non-streaming path (existing behavior)
+            response = await rag.query(
+                query = v.query, user = v.user, collection = v.collection,
+                entity_limit = entity_limit, triple_limit = triple_limit,
+                max_subgraph_size = max_subgraph_size,
+                max_path_length = max_path_length,
+            )
+
+            await flow("response").send(
+                GraphRagResponse(
+                    response = response,
+                    error = None
+                ),
+                properties = {"id": id}
+            )

         logger.info("Request processing complete")
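The processor pattern in this hunk, a `send_chunk` closure that publishes incremental messages followed by one final `end_of_stream` message carrying the full text, can be sketched with an in-memory queue. The queue here is only a stand-in for the `flow("response")` Pulsar producer, and the message dicts mirror the `GraphRagResponse` fields (`chunk`, `end_of_stream`, `response`, `error`) rather than the real schema class.

```python
import asyncio

async def run_streaming_query(out):
    # Closure playing the role of send_chunk in the processor above
    async def send_chunk(chunk):
        await out.put({"chunk": chunk, "end_of_stream": False,
                       "response": None, "error": None})

    # Stand-in for rag.query(..., streaming=True, chunk_callback=send_chunk)
    full = ""
    for token in ["The ", "answer."]:
        await send_chunk(token)
        full += token

    # Final message carries the complete response
    await out.put({"chunk": None, "end_of_stream": True,
                   "response": full, "error": None})

async def main():
    out = asyncio.Queue()
    await run_streaming_query(out)
    msgs = []
    while not out.empty():
        msgs.append(out.get_nowait())
    return msgs

msgs = asyncio.run(main())
print(len(msgs))  # 3 messages: two chunks plus the final end_of_stream
```

Note the final message duplicates the assembled text; that lets non-incremental consumers ignore chunks entirely and read only the terminal message.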
@@ -158,14 +194,21 @@ class Processor(FlowProcessor):

         logger.debug("Sending error response...")

-        await flow("response").send(
-            GraphRagResponse(
-                response = None,
-                error = Error(
-                    type = "graph-rag-error",
-                    message = str(e),
-                ),
-            ),
+        # Send error response with end_of_stream flag if streaming was requested
+        error_response = GraphRagResponse(
+            response = None,
+            error = Error(
+                type = "graph-rag-error",
+                message = str(e),
+            ),
+        )
+
+        # If streaming was requested, indicate stream end
+        if v.streaming:
+            error_response.end_of_stream = True
+
+        await flow("response").send(
+            error_response,
             properties = {"id": id}
         )
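On the consuming side (for example a websocket gateway client), a stream produced by either service can be folded back into a final answer by accumulating chunks until the `end_of_stream` message arrives. The sketch below assumes plain dicts with the `GraphRagResponse`/`DocumentRagResponse` fields shown in the hunks above; the real transport and message classes are not modeled here.

```python
def consume_stream(messages):
    """Accumulate chunk messages until end_of_stream; return the final text."""
    partial = []
    for msg in messages:
        if msg.get("error"):
            # Error messages terminate the stream (end_of_stream is also set)
            raise RuntimeError(msg["error"])
        if msg.get("end_of_stream"):
            # The terminal message carries the complete response
            return msg["response"], "".join(partial)
        if msg.get("chunk"):
            partial.append(msg["chunk"])
    raise RuntimeError("stream ended without end_of_stream message")

msgs = [
    {"chunk": "Hel", "end_of_stream": False, "response": None, "error": None},
    {"chunk": "lo", "end_of_stream": False, "response": None, "error": None},
    {"chunk": None, "end_of_stream": True, "response": "Hello", "error": None},
]
final, assembled = consume_stream(msgs)
print(final, assembled)
```

Because the terminal message repeats the full response, clients can cross-check the assembled chunks against it, or skip chunk handling altogether.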