trustgraph/tests/integration/test_document_rag_integration.py

312 lines
13 KiB
Python
Raw Normal View History

Release/v1.2 (#457) * Bump setup.py versions for 1.1 * PoC MCP server (#419) * Very initial MCP server PoC for TrustGraph * Put service on port 8000 * Add MCP container and packages to buildout * Update docs for API/CLI changes in 1.0 (#421) * Update some API basics for the 0.23/1.0 API change * Add MCP container push (#425) * Add command args to the MCP server (#426) * Host and port parameters * Added websocket arg * More docs * MCP client support (#427) - MCP client service - Tool request/response schema - API gateway support for mcp-tool - Message translation for tool request & response - Make mcp-tool using configuration service for information about where the MCP services are. * Feature/react call mcp (#428) Key Features - MCP Tool Integration: Added core MCP tool support with ToolClientSpec and ToolClient classes - API Enhancement: New mcp_tool method for flow-specific tool invocation - CLI Tooling: New tg-invoke-mcp-tool command for testing MCP integration - React Agent Enhancement: Fixed and improved multi-tool invocation capabilities - Tool Management: Enhanced CLI for tool configuration and management Changes - Added MCP tool invocation to API with flow-specific integration - Implemented ToolClientSpec and ToolClient for tool call handling - Updated agent-manager-react to invoke MCP tools with configurable types - Enhanced CLI with new commands and improved help text - Added comprehensive documentation for new CLI commands - Improved tool configuration management Testing - Added tg-invoke-mcp-tool CLI command for isolated MCP integration testing - Enhanced agent capability to invoke multiple tools simultaneously * Test suite executed from CI pipeline (#433) * Test strategy & test cases * Unit tests * Integration tests * Extending test coverage (#434) * Contract tests * Testing embeddings * Agent unit tests * Knowledge pipeline tests * Turn on contract tests * Increase storage test coverage (#435) * Fixing storage and adding tests * PR pipeline only runs 
quick tests * Empty configuration is returned as empty list, previously was not in response (#436) * Update config util to take files as well as command-line text (#437) * Updated CLI invocation and config model for tools and mcp (#438) * Updated CLI invocation and config model for tools and mcp * CLI anomalies * Tweaked the MCP tool implementation for new model * Update agent implementation to match the new model * Fix agent tools, now all tested * Fixed integration tests * Fix MCP delete tool params * Update Python deps to 1.2 * Update to enable knowledge extraction using the agent framework (#439) * Implement KG extraction agent (kg-extract-agent) * Using ReAct framework (agent-manager-react) * ReAct manager had an issue when emitting JSON, which conflicts with ReAct manager's own JSON messages, so refactored ReAct manager to use traditional ReAct messages, non-JSON structure. * Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework. 
* Migrate from setup.py to pyproject.toml (#440) * Converted setup.py to pyproject.toml * Modern package infrastructure as recommended by py docs * Install missing build deps (#441) * Install missing build deps (#442) * Implement logging strategy (#444) * Logging strategy and convert all prints() to logging invocations * Fix/startup failure (#445) * Fix logging startup problems * Fix logging startup problems (#446) * Fix logging startup problems (#447) * Fixed Mistral OCR to use current API (#448) * Fixed Mistral OCR to use current API * Added PDF decoder tests * Fix Mistral OCR ident to be standard pdf-decoder (#450) * Fix Mistral OCR ident to be standard pdf-decoder * Correct test * Schema structure refactor (#451) * Write schema refactor spec * Implemented schema refactor spec * Structure data mvp (#452) * Structured data tech spec * Architecture principles * New schemas * Updated schemas and specs * Object extractor * Add .coveragerc * New tests * Cassandra object storage * Trying to get object extraction working, issues exist * Validate librarian collection (#453) * Fix token chunker, broken API invocation (#454) * Fix token chunker, broken API invocation (#455) * Knowledge load utility CLI (#456) * Knowledge loader * More tests
2025-08-18 20:56:09 +01:00
"""
Integration tests for DocumentRAG retrieval system
These tests verify the end-to-end functionality of the DocumentRAG system,
testing the coordination between embeddings, document retrieval, and prompt services.
Following the TEST_STRATEGY.md approach for integration testing.
"""
import logging
import time
from unittest.mock import AsyncMock, MagicMock

import pytest

from trustgraph.retrieval.document_rag.document_rag import DocumentRag
@pytest.mark.integration
class TestDocumentRagIntegration:
    """Integration tests for DocumentRAG system coordination.

    Every collaborator (embeddings, document-embeddings and prompt clients)
    is an ``AsyncMock``, so these tests check how ``DocumentRag`` wires the
    services together: which client is called, with which arguments, and
    how failures and log output surface to the caller.
    """

    @staticmethod
    def _build_rag(embeddings_client, doc_embeddings_client, prompt_client, verbose=False):
        """Construct a DocumentRag wired to the given (mocked) clients."""
        return DocumentRag(
            embeddings_client=embeddings_client,
            doc_embeddings_client=doc_embeddings_client,
            prompt_client=prompt_client,
            verbose=verbose,
        )

    @pytest.fixture
    def mock_embeddings_client(self):
        """Mock embeddings client that returns realistic vector embeddings."""
        client = AsyncMock()
        client.embed.return_value = [
            [0.1, 0.2, 0.3, 0.4, 0.5],  # Realistic 5-dimensional embedding
            [0.6, 0.7, 0.8, 0.9, 1.0],  # Second embedding for testing
        ]
        return client

    @pytest.fixture
    def mock_doc_embeddings_client(self):
        """Mock document embeddings client that returns realistic document chunks."""
        client = AsyncMock()
        client.query.return_value = [
            "Machine learning is a subset of artificial intelligence that focuses on algorithms that learn from data.",
            "Deep learning uses neural networks with multiple layers to model complex patterns in data.",
            "Supervised learning algorithms learn from labeled training data to make predictions on new data.",
        ]
        return client

    @pytest.fixture
    def mock_prompt_client(self):
        """Mock prompt client that generates realistic responses."""
        client = AsyncMock()
        client.document_prompt.return_value = (
            "Machine learning is a field of artificial intelligence that enables computers to learn "
            "and improve from experience without being explicitly programmed. It uses algorithms "
            "to find patterns in data and make predictions or decisions."
        )
        return client

    @pytest.fixture
    def document_rag(self, mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client):
        """Create a verbose DocumentRag instance wired to the mocked clients."""
        return self._build_rag(
            mock_embeddings_client,
            mock_doc_embeddings_client,
            mock_prompt_client,
            verbose=True,
        )

    @pytest.mark.asyncio
    async def test_document_rag_end_to_end_flow(self, document_rag, mock_embeddings_client,
                                                mock_doc_embeddings_client, mock_prompt_client):
        """Test complete DocumentRAG pipeline from query to response."""
        # Arrange
        query = "What is machine learning?"
        user = "test_user"
        collection = "ml_knowledge"
        doc_limit = 10
        # Snapshot the fixture payloads *before* acting: the assertions stay in
        # sync with the fixtures, and an in-place mutation by DocumentRag would
        # still be detected because these are independent copies.
        expected_vectors = [list(vector) for vector in mock_embeddings_client.embed.return_value]
        expected_documents = list(mock_doc_embeddings_client.query.return_value)
        expected_response = mock_prompt_client.document_prompt.return_value

        # Act
        result = await document_rag.query(
            query=query,
            user=user,
            collection=collection,
            doc_limit=doc_limit,
        )

        # Assert - verify service coordination
        mock_embeddings_client.embed.assert_called_once_with(query)
        mock_doc_embeddings_client.query.assert_called_once_with(
            expected_vectors,
            limit=doc_limit,
            user=user,
            collection=collection,
        )
        mock_prompt_client.document_prompt.assert_called_once_with(
            query=query,
            documents=expected_documents,
        )

        # Verify the final response is exactly what the prompt service produced
        assert isinstance(result, str)
        assert result == expected_response

    @pytest.mark.asyncio
    async def test_document_rag_with_no_documents_found(self, mock_embeddings_client,
                                                        mock_doc_embeddings_client, mock_prompt_client):
        """Test DocumentRAG behavior when no documents are retrieved."""
        # Arrange
        query = "very obscure query"
        mock_doc_embeddings_client.query.return_value = []  # No documents found
        mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
        document_rag = self._build_rag(
            mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client
        )

        # Act
        result = await document_rag.query(query)

        # Assert - the empty result is still forwarded to the prompt service
        mock_embeddings_client.embed.assert_called_once_with(query)
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_called_once_with(
            query=query,
            documents=[],
        )
        assert result == "I couldn't find any relevant documents for your query."

    @pytest.mark.asyncio
    async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
                                                           mock_doc_embeddings_client, mock_prompt_client):
        """Test DocumentRAG error handling when embeddings service fails."""
        # Arrange
        mock_embeddings_client.embed.side_effect = Exception("Embeddings service unavailable")
        document_rag = self._build_rag(
            mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client
        )

        # Act & Assert - the error propagates and no downstream service runs
        with pytest.raises(Exception, match="Embeddings service unavailable"):
            await document_rag.query("test query")

        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_not_called()
        mock_prompt_client.document_prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_document_rag_document_service_failure(self, mock_embeddings_client,
                                                         mock_doc_embeddings_client, mock_prompt_client):
        """Test DocumentRAG error handling when document service fails."""
        # Arrange
        mock_doc_embeddings_client.query.side_effect = Exception("Document service connection failed")
        document_rag = self._build_rag(
            mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client
        )

        # Act & Assert - embeddings ran, the prompt service was never reached
        with pytest.raises(Exception, match="Document service connection failed"):
            await document_rag.query("test query")

        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_not_called()

    @pytest.mark.asyncio
    async def test_document_rag_prompt_service_failure(self, mock_embeddings_client,
                                                       mock_doc_embeddings_client, mock_prompt_client):
        """Test DocumentRAG error handling when prompt service fails."""
        # Arrange
        mock_prompt_client.document_prompt.side_effect = Exception("LLM service rate limited")
        document_rag = self._build_rag(
            mock_embeddings_client, mock_doc_embeddings_client, mock_prompt_client
        )

        # Act & Assert - every service was reached, the final one raised
        with pytest.raises(Exception, match="LLM service rate limited"):
            await document_rag.query("test query")

        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()
        mock_prompt_client.document_prompt.assert_called_once()

    @pytest.mark.asyncio
    async def test_document_rag_with_different_document_limits(self, document_rag,
                                                               mock_doc_embeddings_client):
        """Test DocumentRAG with various document limit configurations.

        The same DocumentRag instance is reused on purpose, so the test also
        checks that one call's limit does not leak into the next.
        """
        for limit in (1, 5, 10, 25, 50):
            # Clear call history only; the mocked return values are preserved
            mock_doc_embeddings_client.reset_mock()

            # Act
            await document_rag.query(f"query with limit {limit}", doc_limit=limit)

            # Assert
            mock_doc_embeddings_client.query.assert_called_once()
            assert mock_doc_embeddings_client.query.call_args.kwargs['limit'] == limit

    @pytest.mark.asyncio
    async def test_document_rag_multi_user_isolation(self, document_rag, mock_doc_embeddings_client):
        """Test DocumentRAG properly isolates queries by user and collection.

        Scenarios run sequentially against one instance so any user or
        collection carried over between calls would be caught.
        """
        # Arrange
        test_scenarios = [
            ("user1", "collection1"),
            ("user2", "collection2"),
            ("user1", "collection2"),  # Same user, different collection
            ("user2", "collection1"),  # Different user, same collection
        ]

        for user, collection in test_scenarios:
            # Clear call history only; the mocked return values are preserved
            mock_doc_embeddings_client.reset_mock()

            # Act
            await document_rag.query(
                f"query from {user} in {collection}",
                user=user,
                collection=collection,
            )

            # Assert
            mock_doc_embeddings_client.query.assert_called_once()
            call_kwargs = mock_doc_embeddings_client.query.call_args.kwargs
            assert call_kwargs['user'] == user
            assert call_kwargs['collection'] == collection

    @pytest.mark.asyncio
    async def test_document_rag_verbose_logging(self, mock_embeddings_client,
                                                mock_doc_embeddings_client, mock_prompt_client,
                                                caplog):
        """Test DocumentRAG verbose logging functionality."""
        # Arrange - capture DEBUG records; the level is set before construction
        # so the initialisation message is captured as well.
        caplog.set_level(logging.DEBUG)
        document_rag = self._build_rag(
            mock_embeddings_client,
            mock_doc_embeddings_client,
            mock_prompt_client,
            verbose=True,
        )

        # Act
        await document_rag.query("test query for verbose logging")

        # Assert - each pipeline stage emitted its debug message
        log_messages = caplog.text
        for expected in (
            "DocumentRag initialized",
            "Constructing prompt...",
            "Computing embeddings...",
            "Getting documents...",
            "Invoking LLM...",
            "Query processing complete",
        ):
            assert expected in log_messages, f"missing log message: {expected!r}"

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_document_rag_performance_with_large_document_set(self, document_rag,
                                                                    mock_doc_embeddings_client,
                                                                    mock_prompt_client):
        """Test DocumentRAG performance with large document retrieval."""
        # Arrange - mock a large document set (100 documents)
        large_doc_set = [f"Document {i} content about machine learning and AI" for i in range(100)]
        mock_doc_embeddings_client.query.return_value = large_doc_set

        # Act - perf_counter is monotonic and high-resolution, so the measured
        # duration cannot be skewed by wall-clock adjustments like time.time().
        start_time = time.perf_counter()
        result = await document_rag.query("performance test query", doc_limit=100)
        execution_time = time.perf_counter() - start_time

        # Assert
        assert result is not None
        assert execution_time < 5.0  # Should complete within 5 seconds
        mock_doc_embeddings_client.query.assert_called_once()
        assert mock_doc_embeddings_client.query.call_args.kwargs['limit'] == 100
        # The whole retrieved set must reach the prompt service intact
        mock_prompt_client.document_prompt.assert_called_once()
        assert mock_prompt_client.document_prompt.call_args.kwargs['documents'] == large_doc_set

    @pytest.mark.asyncio
    async def test_document_rag_default_parameters(self, document_rag, mock_doc_embeddings_client):
        """Test DocumentRAG uses correct default parameters."""
        # Act
        await document_rag.query("test query with defaults")

        # Assert
        mock_doc_embeddings_client.query.assert_called_once()
        call_kwargs = mock_doc_embeddings_client.query.call_args.kwargs
        assert call_kwargs['user'] == "trustgraph"
        assert call_kwargs['collection'] == "default"
        assert call_kwargs['limit'] == 20