trustgraph/tests/unit/test_retrieval/test_document_rag.py
cybermaggedon a115ec06ab
Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding (#697)
Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding,
consistent PROV-O

GraphRAG:
- Split retrieval into 4 prompt stages: extract-concepts,
  kg-edge-scoring,
  kg-edge-reasoning, kg-synthesis (was single-stage)
- Add concept extraction (grounding) for per-concept embedding
- Filter main query to default graph, ignoring
  provenance/explainability edges
- Add source document edges to knowledge graph

DocumentRAG:
- Add grounding step with concept extraction, matching GraphRAG's
  pattern:
  Question → Grounding → Exploration → Synthesis
- Per-concept embedding and chunk retrieval with deduplication

Cross-pipeline:
- Make PROV-O derivation links consistent: wasGeneratedBy for first
  entity from Activity, wasDerivedFrom for entity-to-entity chains
- Update CLIs (tg-invoke-agent, tg-invoke-graph-rag,
  tg-invoke-document-rag) for new explainability structure
- Fix all affected unit and integration tests
2026-03-16 12:12:13 +00:00

626 lines
23 KiB
Python

"""
Tests for DocumentRAG retrieval implementation
"""
import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
# Canned chunk store shared by the tests: chunk id -> chunk text.
# Ids that are absent here fall through to each fetch stub's own default.
CHUNK_CONTENT = {
    # Generic chunks used by the individual Query / DocumentRag tests
    "doc/c1": "Document 1 content",
    "doc/c2": "Document 2 content",
    "doc/c3": "Relevant document content",
    "doc/c4": "Another document",
    "doc/c5": "Default doc",
    "doc/c6": "Verbose test doc",
    "doc/c7": "Verbose doc content",
    # Machine-learning themed chunks used by the integration-flow test
    "doc/ml1": "Machine learning is a subset of artificial intelligence...",
    "doc/ml2": "ML algorithms learn patterns from data to make predictions...",
    "doc/ml3": "Common ML techniques include supervised and unsupervised learning...",
}
@pytest.fixture
def mock_fetch_chunk():
    """Provide an async stand-in for the ``fetch_chunk`` callable.

    The returned coroutine function takes ``(chunk_id, user)`` and resolves
    the chunk text from ``CHUNK_CONTENT``; unknown ids yield a synthetic
    ``"Content for <chunk_id>"`` string. ``user`` is accepted but ignored.
    """

    async def _lookup(chunk_id, user):
        if chunk_id in CHUNK_CONTENT:
            return CHUNK_CONTENT[chunk_id]
        return f"Content for {chunk_id}"

    return _lookup
class TestDocumentRag:
    """Constructor tests for DocumentRag: collaborators and verbose flag."""

    def test_document_rag_initialization_with_defaults(self, mock_fetch_chunk):
        """DocumentRag keeps its collaborators and defaults verbose to False."""
        # One mock per client, keyed by the constructor keyword / attribute name
        clients = {
            "prompt_client": MagicMock(),
            "embeddings_client": MagicMock(),
            "doc_embeddings_client": MagicMock(),
        }

        rag = DocumentRag(fetch_chunk=mock_fetch_chunk, **clients)

        # Every client must be exposed under the attribute of the same name
        for attr_name, client in clients.items():
            assert getattr(rag, attr_name) == client
        assert rag.fetch_chunk == mock_fetch_chunk
        assert rag.verbose is False  # Default value

    def test_document_rag_initialization_with_verbose(self, mock_fetch_chunk):
        """DocumentRag keeps its collaborators and honours verbose=True."""
        clients = {
            "prompt_client": MagicMock(),
            "embeddings_client": MagicMock(),
            "doc_embeddings_client": MagicMock(),
        }

        rag = DocumentRag(
            fetch_chunk=mock_fetch_chunk,
            verbose=True,
            **clients,
        )

        for attr_name, client in clients.items():
            assert getattr(rag, attr_name) == client
        assert rag.fetch_chunk == mock_fetch_chunk
        assert rag.verbose is True
class TestQuery:
    """Test cases for the Query class and the DocumentRag.query pipeline.

    All collaborators are mocks. The embeddings mock always returns a list
    holding one flat vector per concept, and the doc-embeddings mock returns
    match objects exposing ``chunk_id`` and ``score``.
    """

    @staticmethod
    def _chunk_match(chunk_id, score):
        """Build a stand-in match object carrying ``chunk_id`` and ``score``."""
        match = MagicMock()
        match.chunk_id = chunk_id
        match.score = score
        return match

    def test_query_initialization_with_defaults(self):
        """Test Query initialization with default parameters"""
        mock_rag = MagicMock()

        query = Query(
            rag=mock_rag,
            user="test_user",
            collection="test_collection",
            verbose=False,
        )

        assert query.rag == mock_rag
        assert query.user == "test_user"
        assert query.collection == "test_collection"
        assert query.verbose is False
        assert query.doc_limit == 20  # Default value

    def test_query_initialization_with_custom_doc_limit(self):
        """Test Query initialization with custom doc_limit"""
        mock_rag = MagicMock()

        query = Query(
            rag=mock_rag,
            user="custom_user",
            collection="custom_collection",
            verbose=True,
            doc_limit=50,
        )

        assert query.rag == mock_rag
        assert query.user == "custom_user"
        assert query.collection == "custom_collection"
        assert query.verbose is True
        assert query.doc_limit == 50

    @pytest.mark.asyncio
    async def test_extract_concepts(self):
        """Test Query.extract_concepts extracts concepts from query"""
        mock_rag = MagicMock()
        mock_prompt_client = AsyncMock()
        mock_rag.prompt_client = mock_prompt_client

        # The prompt answers with one concept per line
        mock_prompt_client.prompt.return_value = (
            "machine learning\nartificial intelligence\ndata patterns"
        )

        query = Query(
            rag=mock_rag,
            user="test_user",
            collection="test_collection",
            verbose=False,
        )

        result = await query.extract_concepts("What is machine learning?")

        mock_prompt_client.prompt.assert_called_once_with(
            "extract-concepts",
            variables={"query": "What is machine learning?"},
        )
        assert result == [
            "machine learning",
            "artificial intelligence",
            "data patterns",
        ]

    @pytest.mark.asyncio
    async def test_extract_concepts_fallback_to_raw_query(self):
        """Test Query.extract_concepts falls back to raw query when no concepts extracted"""
        mock_rag = MagicMock()
        mock_prompt_client = AsyncMock()
        mock_rag.prompt_client = mock_prompt_client

        # Empty prompt response: nothing to split into concepts
        mock_prompt_client.prompt.return_value = ""

        query = Query(
            rag=mock_rag,
            user="test_user",
            collection="test_collection",
            verbose=False,
        )

        result = await query.extract_concepts("What is ML?")

        assert result == ["What is ML?"]

    @pytest.mark.asyncio
    async def test_get_vectors_method(self):
        """Test Query.get_vectors method calls embeddings client correctly"""
        mock_rag = MagicMock()
        mock_embeddings_client = AsyncMock()
        mock_rag.embeddings_client = mock_embeddings_client

        # One vector per concept
        expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        mock_embeddings_client.embed.return_value = expected_vectors

        query = Query(
            rag=mock_rag,
            user="test_user",
            collection="test_collection",
            verbose=False,
        )

        concepts = ["machine learning", "data patterns"]
        result = await query.get_vectors(concepts)

        mock_embeddings_client.embed.assert_called_once_with(concepts)
        assert result == expected_vectors

    @pytest.mark.asyncio
    async def test_get_docs_method(self, mock_fetch_chunk):
        """Test Query.get_docs method retrieves documents correctly"""
        mock_rag = MagicMock()
        mock_embeddings_client = AsyncMock()
        mock_doc_embeddings_client = AsyncMock()
        mock_rag.embeddings_client = mock_embeddings_client
        mock_rag.doc_embeddings_client = mock_doc_embeddings_client
        mock_rag.fetch_chunk = mock_fetch_chunk

        # One vector per concept
        mock_embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]]
        mock_doc_embeddings_client.query.return_value = [
            self._chunk_match("doc/c1", 0.95),
            self._chunk_match("doc/c2", 0.85),
        ]

        query = Query(
            rag=mock_rag,
            user="test_user",
            collection="test_collection",
            verbose=False,
            doc_limit=15,
        )

        concepts = ["test concept"]
        result = await query.get_docs(concepts)

        # Concepts are embedded in a single call
        mock_embeddings_client.embed.assert_called_once_with(concepts)

        # The single concept vector is used for the doc-embeddings lookup
        mock_doc_embeddings_client.query.assert_called_once_with(
            vector=[0.1, 0.2, 0.3],
            limit=15,
            user="test_user",
            collection="test_collection",
        )

        # Result is a (docs, chunk_ids) pair
        docs, chunk_ids = result
        assert "Document 1 content" in docs
        assert "Document 2 content" in docs
        assert "doc/c1" in chunk_ids
        assert "doc/c2" in chunk_ids

    @pytest.mark.asyncio
    async def test_document_rag_query_method(self, mock_fetch_chunk):
        """Test DocumentRag.query method orchestrates full document RAG pipeline"""
        mock_prompt_client = AsyncMock()
        mock_embeddings_client = AsyncMock()
        mock_doc_embeddings_client = AsyncMock()

        expected_response = "This is the document RAG response"

        # Concept extraction yields a single concept
        mock_prompt_client.prompt.return_value = "test concept"
        # One vector per concept
        mock_embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]]
        mock_doc_embeddings_client.query.return_value = [
            self._chunk_match("doc/c3", 0.9),
            self._chunk_match("doc/c4", 0.8),
        ]
        mock_prompt_client.document_prompt.return_value = expected_response

        document_rag = DocumentRag(
            prompt_client=mock_prompt_client,
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            fetch_chunk=mock_fetch_chunk,
            verbose=False,
        )

        result = await document_rag.query(
            query="test query",
            user="test_user",
            collection="test_collection",
            doc_limit=10,
        )

        # Grounding step: concept extraction from the raw query
        mock_prompt_client.prompt.assert_called_once_with(
            "extract-concepts",
            variables={"query": "test query"},
        )

        # Embeddings requested for the extracted concepts
        mock_embeddings_client.embed.assert_called_once_with(["test concept"])

        mock_doc_embeddings_client.query.assert_called_once_with(
            vector=[0.1, 0.2, 0.3],
            limit=10,
            user="test_user",
            collection="test_collection",
        )

        # Synthesis step receives the fetched chunk text and the query
        mock_prompt_client.document_prompt.assert_called_once()
        call_args = mock_prompt_client.document_prompt.call_args
        assert call_args.kwargs["query"] == "test query"
        docs = call_args.kwargs["documents"]
        assert "Relevant document content" in docs
        assert "Another document" in docs

        assert result == expected_response

    @pytest.mark.asyncio
    async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
        """Test DocumentRag.query method with default parameters"""
        mock_prompt_client = AsyncMock()
        mock_embeddings_client = AsyncMock()
        mock_doc_embeddings_client = AsyncMock()

        # Concept extraction fallback (empty -> raw query)
        mock_prompt_client.prompt.return_value = ""
        # One flat vector per concept, the same shape the other tests use
        mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
        mock_doc_embeddings_client.query.return_value = [
            self._chunk_match("doc/c5", 0.9),
        ]
        mock_prompt_client.document_prompt.return_value = "Default response"

        document_rag = DocumentRag(
            prompt_client=mock_prompt_client,
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            fetch_chunk=mock_fetch_chunk,
        )

        result = await document_rag.query("simple query")

        # Default parameters must reach the doc-embeddings lookup
        mock_doc_embeddings_client.query.assert_called_once_with(
            vector=[0.1, 0.2],
            limit=20,  # Default doc_limit
            user="trustgraph",  # Default user
            collection="default",  # Default collection
        )

        assert result == "Default response"

    @pytest.mark.asyncio
    async def test_get_docs_with_verbose_output(self, mock_fetch_chunk):
        """Test Query.get_docs method with verbose logging"""
        mock_rag = MagicMock()
        mock_embeddings_client = AsyncMock()
        mock_doc_embeddings_client = AsyncMock()
        mock_rag.embeddings_client = mock_embeddings_client
        mock_rag.doc_embeddings_client = mock_doc_embeddings_client
        mock_rag.fetch_chunk = mock_fetch_chunk

        # One vector per concept
        mock_embeddings_client.embed.return_value = [[0.7, 0.8]]
        mock_doc_embeddings_client.query.return_value = [
            self._chunk_match("doc/c6", 0.88),
        ]

        query = Query(
            rag=mock_rag,
            user="test_user",
            collection="test_collection",
            verbose=True,
            doc_limit=5,
        )

        result = await query.get_docs(["verbose test"])

        mock_embeddings_client.embed.assert_called_once_with(["verbose test"])
        mock_doc_embeddings_client.query.assert_called_once()

        docs, chunk_ids = result
        assert "Verbose test doc" in docs
        assert "doc/c6" in chunk_ids

    @pytest.mark.asyncio
    async def test_document_rag_query_with_verbose(self, mock_fetch_chunk):
        """Test DocumentRag.query method with verbose logging enabled"""
        mock_prompt_client = AsyncMock()
        mock_embeddings_client = AsyncMock()
        mock_doc_embeddings_client = AsyncMock()

        # Concept extraction echoes the query as the only concept
        mock_prompt_client.prompt.return_value = "verbose query test"
        # One vector per concept
        mock_embeddings_client.embed.return_value = [[0.3, 0.4]]
        mock_doc_embeddings_client.query.return_value = [
            self._chunk_match("doc/c7", 0.92),
        ]
        mock_prompt_client.document_prompt.return_value = "Verbose RAG response"

        document_rag = DocumentRag(
            prompt_client=mock_prompt_client,
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            fetch_chunk=mock_fetch_chunk,
            verbose=True,
        )

        result = await document_rag.query("verbose query test")

        mock_embeddings_client.embed.assert_called_once()
        mock_doc_embeddings_client.query.assert_called_once()

        call_args = mock_prompt_client.document_prompt.call_args
        assert call_args.kwargs["query"] == "verbose query test"
        assert "Verbose doc content" in call_args.kwargs["documents"]

        assert result == "Verbose RAG response"

    @pytest.mark.asyncio
    async def test_get_docs_with_empty_results(self, mock_fetch_chunk):
        """Test Query.get_docs method when no documents are found"""
        mock_rag = MagicMock()
        mock_embeddings_client = AsyncMock()
        mock_doc_embeddings_client = AsyncMock()
        mock_rag.embeddings_client = mock_embeddings_client
        mock_rag.doc_embeddings_client = mock_doc_embeddings_client
        mock_rag.fetch_chunk = mock_fetch_chunk

        # One vector per concept, but the lookup finds nothing
        mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
        mock_doc_embeddings_client.query.return_value = []

        query = Query(
            rag=mock_rag,
            user="test_user",
            collection="test_collection",
            verbose=False,
        )

        result = await query.get_docs(["query with no results"])

        mock_embeddings_client.embed.assert_called_once_with(
            ["query with no results"]
        )
        mock_doc_embeddings_client.query.assert_called_once()
        assert result == ([], [])

    @pytest.mark.asyncio
    async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk):
        """Test DocumentRag.query method when no documents are retrieved"""
        mock_prompt_client = AsyncMock()
        mock_embeddings_client = AsyncMock()
        mock_doc_embeddings_client = AsyncMock()

        # Concept extraction echoes the query as the only concept
        mock_prompt_client.prompt.return_value = "query with no matching docs"
        # One vector per concept, but the lookup finds nothing
        mock_embeddings_client.embed.return_value = [[0.5, 0.6]]
        mock_doc_embeddings_client.query.return_value = []
        mock_prompt_client.document_prompt.return_value = (
            "No documents found response"
        )

        document_rag = DocumentRag(
            prompt_client=mock_prompt_client,
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            fetch_chunk=mock_fetch_chunk,
            verbose=False,
        )

        result = await document_rag.query("query with no matching docs")

        # Synthesis still runs, with an empty document list
        mock_prompt_client.document_prompt.assert_called_once_with(
            query="query with no matching docs",
            documents=[],
        )

        assert result == "No documents found response"

    @pytest.mark.asyncio
    async def test_get_vectors_with_verbose(self):
        """Test Query.get_vectors method with verbose logging"""
        mock_rag = MagicMock()
        mock_embeddings_client = AsyncMock()
        mock_rag.embeddings_client = mock_embeddings_client

        expected_vectors = [[0.9, 1.0, 1.1]]
        mock_embeddings_client.embed.return_value = expected_vectors

        query = Query(
            rag=mock_rag,
            user="test_user",
            collection="test_collection",
            verbose=True,
        )

        result = await query.get_vectors(["verbose vector test"])

        mock_embeddings_client.embed.assert_called_once_with(
            ["verbose vector test"]
        )
        assert result == expected_vectors

    @pytest.mark.asyncio
    async def test_document_rag_integration_flow(self, mock_fetch_chunk):
        """Test complete DocumentRag integration with realistic data flow"""
        mock_prompt_client = AsyncMock()
        mock_embeddings_client = AsyncMock()
        mock_doc_embeddings_client = AsyncMock()

        query_text = "What is machine learning?"
        final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."

        # Concept extraction yields two concepts, one per line
        mock_prompt_client.prompt.return_value = (
            "machine learning\nartificial intelligence"
        )

        # One vector per concept
        query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
        mock_embeddings_client.embed.return_value = query_vectors

        # Successive lookups return overlapping matches (doc/ml2 twice)
        mock_doc_embeddings_client.query.side_effect = [
            [
                self._chunk_match("doc/ml1", 0.9),
                self._chunk_match("doc/ml2", 0.85),
            ],
            [
                self._chunk_match("doc/ml2", 0.88),  # duplicate
                self._chunk_match("doc/ml3", 0.82),
            ],
        ]
        mock_prompt_client.document_prompt.return_value = final_response

        document_rag = DocumentRag(
            prompt_client=mock_prompt_client,
            embeddings_client=mock_embeddings_client,
            doc_embeddings_client=mock_doc_embeddings_client,
            fetch_chunk=mock_fetch_chunk,
            verbose=False,
        )

        result = await document_rag.query(
            query=query_text,
            user="research_user",
            collection="ml_knowledge",
            doc_limit=25,
        )

        # Grounding step: concept extraction from the raw query
        mock_prompt_client.prompt.assert_called_once_with(
            "extract-concepts",
            variables={"query": query_text},
        )

        # Embeddings requested for both extracted concepts
        mock_embeddings_client.embed.assert_called_once_with(
            ["machine learning", "artificial intelligence"]
        )

        # One doc-embeddings lookup per concept.
        # NOTE(review): an earlier comment here stated the limit is split per
        # concept (25 // 2 = 12); that value is not asserted - TODO confirm
        # against Query.get_docs before relying on it.
        assert mock_doc_embeddings_client.query.call_count == 2

        # Synthesis step receives the fetched chunk text and the query
        mock_prompt_client.document_prompt.assert_called_once()
        call_args = mock_prompt_client.document_prompt.call_args
        assert call_args.kwargs["query"] == query_text

        # Documents were fetched and deduplicated
        docs = call_args.kwargs["documents"]
        assert "Machine learning is a subset of artificial intelligence..." in docs
        assert "ML algorithms learn patterns from data to make predictions..." in docs
        assert "Common ML techniques include supervised and unsupervised learning..." in docs
        assert len(docs) == 3  # doc/ml2 deduplicated

        assert result == final_response

    @pytest.mark.asyncio
    async def test_get_docs_deduplicates_across_concepts(self, mock_fetch_chunk):
        """Test that get_docs deduplicates chunks across multiple concepts"""
        mock_rag = MagicMock()
        mock_embeddings_client = AsyncMock()
        mock_doc_embeddings_client = AsyncMock()
        mock_rag.embeddings_client = mock_embeddings_client
        mock_rag.doc_embeddings_client = mock_doc_embeddings_client
        mock_rag.fetch_chunk = mock_fetch_chunk

        # Two concepts -> two vectors
        mock_embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]

        # Both lookups return overlapping chunks (doc/c1 twice)
        mock_doc_embeddings_client.query.side_effect = [
            [
                self._chunk_match("doc/c1", 0.9),
                self._chunk_match("doc/c2", 0.8),
            ],
            [
                self._chunk_match("doc/c1", 0.85),  # duplicate
            ],
        ]

        query = Query(
            rag=mock_rag,
            user="test_user",
            collection="test_collection",
            verbose=False,
            doc_limit=10,
        )

        docs, chunk_ids = await query.get_docs(["concept A", "concept B"])

        assert len(chunk_ids) == 2  # doc/c1 only counted once
        assert "doc/c1" in chunk_ids
        assert "doc/c2" in chunk_ids