trustgraph/tests/integration/test_document_rag_integration.py
cybermaggedon d35473f7f7
feat: workspace-based multi-tenancy, replacing user as tenancy axis (#840)
Introduces `workspace` as the isolation boundary for config, flows,
library, and knowledge data. Removes `user` as a schema-level field
throughout the code, API specs, and tests; workspace provides the
same separation more cleanly at the trusted flow.workspace layer
rather than through client-supplied message fields.

Design
------
- IAM tech spec (docs/tech-specs/iam.md) documents current state,
  proposed auth/access model, and migration direction.
- Data ownership model (docs/tech-specs/data-ownership-model.md)
  captures the workspace/collection/flow hierarchy.

Schema + messaging
------------------
- Drop `user` field from AgentRequest/Step, GraphRagQuery,
  DocumentRagQuery, Triples/Graph/Document/Row EmbeddingsRequest,
  Sparql/Rows/Structured QueryRequest, ToolServiceRequest.
- Keep collection/workspace routing via flow.workspace at the
  service layer.
- Translators updated to not serialise/deserialise user.

API specs
---------
- OpenAPI schemas and path examples cleaned of user fields.
- Websocket async-api messages updated.
- Removed the unused parameters/User.yaml.

Services + base
---------------
- Librarian, collection manager, knowledge, config: all operations
  scoped by workspace. Config client API takes workspace as first
  positional arg.
- `flow.workspace` set at flow start time by the infrastructure;
  no longer pass-through from clients.
- Tool service drops user-personalisation passthrough.

CLI + SDK
---------
- tg-init-workspace and workspace-aware import/export.
- All tg-* commands drop user args; accept --workspace.
- Python API/SDK (flow, socket_client, async_*, explainability,
  library) drop user kwargs from every method signature.

MCP server
----------
- All tool endpoints drop user parameters; socket_manager no longer
  keyed per user.

Flow service
------------
- Closure-based topic cleanup on flow stop: only delete topics
  whose blueprint template was parameterised AND no remaining
  live flow (across all workspaces) still resolves to that topic.
  Three scopes fall out naturally from template analysis:
    * {id} -> per-flow, deleted on stop
    * {blueprint} -> per-blueprint, kept while any flow of the
      same blueprint exists
    * {workspace} -> per-workspace, kept while any flow in the
      workspace exists
    * literal -> global, never deleted (e.g. tg.request.librarian)
  Fixes a bug where stopping a flow silently destroyed the global
  librarian exchange, wedging all library operations until manual
  restart.

RabbitMQ backend
----------------
- heartbeat=60, blocked_connection_timeout=300. Catches silently
  dead connections (broker restart, orphaned channels, network
  partitions) within ~2 heartbeat windows, so the consumer
  reconnects and re-binds its queue rather than sitting forever
  on a zombie connection.

Tests
-----
- Full test refresh: unit, integration, contract, provenance.
- Dropped user-field assertions and constructor kwargs across
  ~100 test files.
- Renamed user-collection isolation tests to workspace-collection.
2026-04-21 23:23:01 +01:00

353 lines
14 KiB
Python

"""
Integration tests for DocumentRAG retrieval system
These tests verify the end-to-end functionality of the DocumentRAG system,
testing the coordination between embeddings, document retrieval, and prompt services.
Following the TEST_STRATEGY.md approach for integration testing.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
# Sample chunk content for testing - maps chunk_id to content
CHUNK_CONTENT = {
"doc/c1": "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
"doc/c2": "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
"doc/c3": "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
}
@pytest.mark.integration
class TestDocumentRagIntegration:
"""Integration tests for DocumentRAG system coordination"""
@pytest.fixture
def mock_embeddings_client(self):
"""Mock embeddings client that returns realistic vector embeddings"""
client = AsyncMock()
# New batch format: [[[vectors_for_text1], ...]]
# One text input returns one vector set containing two vectors
client.embed.return_value = [
[
[0.1, 0.2, 0.3, 0.4, 0.5], # First vector for text
[0.6, 0.7, 0.8, 0.9, 1.0] # Second vector for text
]
]
return client
@pytest.fixture
def mock_doc_embeddings_client(self):
"""Mock document embeddings client that returns chunk matches"""
client = AsyncMock()
# Returns ChunkMatch objects with chunk_id and score
client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.95),
ChunkMatch(chunk_id="doc/c2", score=0.90),
ChunkMatch(chunk_id="doc/c3", score=0.85)
]
return client
@pytest.fixture
def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch
@pytest.fixture
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
client = AsyncMock()
client.document_prompt.return_value = PromptResult(
response_type="text",
text=(
"Machine learning is a field of artificial intelligence that enables computers to learn "
"and improve from experience without being explicitly programmed. It uses algorithms "
"to find patterns in data and make predictions or decisions."
)
)
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client,
mock_prompt_client, mock_fetch_chunk):
"""Create DocumentRag instance with mocked dependencies"""
return DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
@pytest.mark.asyncio
async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client):
"""Test complete DocumentRAG pipeline from query to response"""
# Arrange
query = "What is machine learning?"
user = "test_user"
collection = "ml_knowledge"
doc_limit = 10
# Act
result = await document_rag.query(
query=query,
collection=collection,
doc_limit=doc_limit
)
# Assert - Verify service coordination
mock_embeddings_client.embed.assert_called_once_with([query])
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]],
limit=doc_limit,
collection=collection
)
# Documents are fetched from librarian using chunk_ids
mock_prompt_client.document_prompt.assert_called_once_with(
query=query,
documents=[
"Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
"Deep learning uses neural networks with multiple layers to model complex patterns in data.",
"Supervised learning algorithms learn from labeled training data to make predictions on new data."
]
)
# Verify final response
result, usage = result
assert result is not None
assert isinstance(result, str)
assert "machine learning" in result.lower()
assert "artificial intelligence" in result.lower()
@pytest.mark.asyncio
async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG behavior when no documents are retrieved"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="I couldn't find any relevant documents for your query."
)
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act
result = await document_rag.query("very obscure query")
# Assert
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_called_once_with(
query="very obscure query",
documents=[]
)
result_text, usage = result
assert result_text == "I couldn't find any relevant documents for your query."
@pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when embeddings service fails"""
# Arrange
mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "Embeddings service unavailable" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_not_called()
mock_prompt_client.document_prompt.assert_not_called()
@pytest.mark.asyncio
async def test_document_rag_document_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when document service fails"""
# Arrange
mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "Document service connection failed" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_not_called()
@pytest.mark.asyncio
async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk):
"""Test DocumentRAG error handling when prompt service fails"""
# Arrange
mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=False
)
# Act & Assert
with pytest.raises(Exception) as exc_info:
await document_rag.query("test query")
assert "LLM service rate limited" in str(exc_info.value)
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
mock_prompt_client.document_prompt.assert_called_once()
@pytest.mark.asyncio
async def test_document_rag_with_different_document_limits(self, document_rag,
mock_doc_embeddings_client):
"""Test DocumentRAG with various document limit configurations"""
# Test different document limits
test_cases = [1, 5, 10, 25, 50]
for limit in test_cases:
# Reset mock call history
mock_doc_embeddings_client.reset_mock()
# Act
await document_rag.query(f"query with limit {limit}", doc_limit=limit)
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['limit'] == limit
@pytest.mark.asyncio
async def test_document_rag_multi_user_isolation(self, document_rag, mock_doc_embeddings_client):
"""Test DocumentRAG properly isolates queries by user and collection"""
# Arrange
test_scenarios = [
("user1", "collection1"),
("user2", "collection2"),
("user1", "collection2"), # Same user, different collection
("user2", "collection1"), # Different user, same collection
]
for user, collection in test_scenarios:
# Reset mock call history
mock_doc_embeddings_client.reset_mock()
# Act
await document_rag.query(
f"query from {user} in {collection}",
collection=collection
)
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['collection'] == collection
@pytest.mark.asyncio
async def test_document_rag_verbose_logging(self, mock_embeddings_client,
mock_doc_embeddings_client, mock_prompt_client,
mock_fetch_chunk,
caplog):
"""Test DocumentRAG verbose logging functionality"""
import logging
# Arrange - Configure logging to capture debug messages
caplog.set_level(logging.DEBUG)
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
doc_embeddings_client=mock_doc_embeddings_client,
prompt_client=mock_prompt_client,
fetch_chunk=mock_fetch_chunk,
verbose=True
)
# Act
await document_rag.query("test query for verbose logging")
# Assert - Check for new logging messages
log_messages = caplog.text
assert "DocumentRag initialized" in log_messages
assert "Constructing prompt..." in log_messages
assert "Computing embeddings..." in log_messages
assert "chunks" in log_messages.lower()
assert "Invoking LLM..." in log_messages
assert "Query processing complete" in log_messages
@pytest.mark.asyncio
@pytest.mark.slow
async def test_document_rag_performance_with_large_document_set(self, document_rag,
mock_doc_embeddings_client):
"""Test DocumentRAG performance with large document retrieval"""
# Arrange - Mock large chunk match set (100 chunks)
large_chunk_matches = [ChunkMatch(chunk_id=f"doc/c{i}", score=0.9 - i*0.001) for i in range(100)]
mock_doc_embeddings_client.query.return_value = large_chunk_matches
# Act
import time
start_time = time.time()
result = await document_rag.query("performance test query", doc_limit=100)
end_time = time.time()
execution_time = end_time - start_time
# Assert
assert result is not None
assert execution_time < 5.0 # Should complete within 5 seconds
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['limit'] == 100
@pytest.mark.asyncio
async def test_document_rag_default_parameters(self, document_rag, mock_doc_embeddings_client):
"""Test DocumentRAG uses correct default parameters"""
# Act
await document_rag.query("test query with defaults")
# Assert
mock_doc_embeddings_client.query.assert_called_once()
call_args = mock_doc_embeddings_client.query.call_args
assert call_args.kwargs['collection'] == "default"
assert call_args.kwargs['limit'] == 20