Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding,

consistent PROV-O

GraphRAG:
- Split retrieval into 4 prompt stages: extract-concepts,
  kg-edge-scoring,
  kg-edge-reasoning, kg-synthesis (was single-stage)
- Add concept extraction (grounding) for per-concept embedding
- Filter main query to default graph, ignoring
  provenance/explainability edges
- Add source document edges to knowledge graph

DocumentRAG:
- Add grounding step with concept extraction, matching GraphRAG's
  pattern:
  Question → Grounding → Exploration → Synthesis
- Per-concept embedding and chunk retrieval with deduplication

Cross-pipeline:
- Make PROV-O derivation links consistent: wasGeneratedBy for first
  entity from Activity, wasDerivedFrom for entity-to-entity chains
- Update CLIs (tg-invoke-agent, tg-invoke-graph-rag,
  tg-invoke-document-rag) for new explainability structure
- Fix all affected unit and integration tests
This commit is contained in:
Cyber MacGeddon 2026-03-14 11:54:10 +00:00
parent 29b4300808
commit 20bb645b9a
25 changed files with 1537 additions and 1008 deletions

View file

@ -125,19 +125,15 @@ class TestQuery:
assert query.doc_limit == 50
@pytest.mark.asyncio
async def test_get_vector_method(self):
"""Test Query.get_vector method calls embeddings client correctly"""
# Create mock DocumentRag with embeddings client
async def test_extract_concepts(self):
"""Test Query.extract_concepts extracts concepts from query"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
# Mock the embed method to return test vectors in batch format
# New format: [[[vectors_for_text1]]] - returns first text's vector set
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = [expected_vectors]
# Mock the prompt response with concept lines
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -145,20 +141,62 @@ class TestQuery:
verbose=False
)
# Call get_vector
test_query = "What documents are relevant?"
result = await query.get_vector(test_query)
result = await query.extract_concepts("What is machine learning?")
# Verify embeddings client was called correctly (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": "What is machine learning?"}
)
assert result == ["machine learning", "artificial intelligence", "data patterns"]
# Verify result matches expected vectors (extracted from batch)
@pytest.mark.asyncio
async def test_extract_concepts_fallback_to_raw_query(self):
"""Test Query.extract_concepts falls back to raw query when no concepts extracted"""
mock_rag = MagicMock()
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
# Mock empty response
mock_prompt_client.prompt.return_value = ""
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
result = await query.extract_concepts("What is ML?")
assert result == ["What is ML?"]
@pytest.mark.asyncio
async def test_get_vectors_method(self):
"""Test Query.get_vectors method calls embeddings client correctly"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method - returns vectors for each concept
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
mock_embeddings_client.embed.return_value = expected_vectors
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False
)
concepts = ["machine learning", "data patterns"]
result = await query.get_vectors(concepts)
mock_embeddings_client.embed.assert_called_once_with(concepts)
assert result == expected_vectors
@pytest.mark.asyncio
async def test_get_docs_method(self):
"""Test Query.get_docs method retrieves documents correctly"""
# Create mock DocumentRag with clients
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
@ -170,10 +208,8 @@ class TestQuery:
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock the embedding and document query responses
# New batch format: [[[vectors]]] - get_vector extracts [0]
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock embeddings - one vector per concept
mock_embeddings_client.embed.return_value = [[0.1, 0.2, 0.3]]
# Mock document embeddings returns ChunkMatch objects
mock_match1 = MagicMock()
@ -184,7 +220,6 @@ class TestQuery:
mock_match2.score = 0.85
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -193,16 +228,16 @@ class TestQuery:
doc_limit=15
)
# Call get_docs
test_query = "Find relevant documents"
result = await query.get_docs(test_query)
# Call get_docs with concepts list
concepts = ["test concept"]
result = await query.get_docs(concepts)
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with([test_query])
# Verify embeddings client was called with concepts
mock_embeddings_client.embed.assert_called_once_with(concepts)
# Verify doc embeddings client was called correctly (with extracted vector)
# Verify doc embeddings client was called
mock_doc_embeddings_client.query.assert_called_once_with(
vector=test_vectors,
vector=[0.1, 0.2, 0.3],
limit=15,
user="test_user",
collection="test_collection"
@ -218,14 +253,17 @@ class TestQuery:
@pytest.mark.asyncio
async def test_document_rag_query_method(self, mock_fetch_chunk):
"""Test DocumentRag.query method orchestrates full document RAG pipeline"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock embeddings and document embeddings responses
# New batch format: [[[vectors]]] - get_vector extracts [0]
# Mock concept extraction
mock_prompt_client.prompt.return_value = "test concept"
# Mock embeddings - one vector per concept
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = test_vectors
mock_match1 = MagicMock()
mock_match1.chunk_id = "doc/c3"
mock_match1.score = 0.9
@ -234,11 +272,9 @@ class TestQuery:
mock_match2.score = 0.8
expected_response = "This is the document RAG response"
mock_embeddings_client.embed.return_value = [test_vectors]
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -247,7 +283,6 @@ class TestQuery:
verbose=False
)
# Call DocumentRag.query
result = await document_rag.query(
query="test query",
user="test_user",
@ -255,12 +290,18 @@ class TestQuery:
doc_limit=10
)
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["test query"])
# Verify concept extraction was called
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": "test query"}
)
# Verify doc embeddings client was called (with extracted vector)
# Verify embeddings called with extracted concepts
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
# Verify doc embeddings client was called
mock_doc_embeddings_client.query.assert_called_once_with(
vector=test_vectors,
vector=[0.1, 0.2, 0.3],
limit=10,
user="test_user",
collection="test_collection"
@ -270,23 +311,23 @@ class TestQuery:
mock_prompt_client.document_prompt.assert_called_once()
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == "test query"
# Documents should be fetched content, not chunk_ids
docs = call_args.kwargs["documents"]
assert "Relevant document content" in docs
assert "Another document" in docs
# Verify result
assert result == expected_response
@pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
"""Test DocumentRag.query method with default parameters"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses (batch format)
# Mock concept extraction fallback (empty → raw query)
mock_prompt_client.prompt.return_value = ""
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c5"
@ -294,7 +335,6 @@ class TestQuery:
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response"
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -302,10 +342,9 @@ class TestQuery:
fetch_chunk=mock_fetch_chunk
)
# Call DocumentRag.query with minimal parameters
result = await document_rag.query("simple query")
# Verify default parameters were used (vector extracted from batch)
# Verify default parameters were used
mock_doc_embeddings_client.query.assert_called_once_with(
vector=[[0.1, 0.2]],
limit=20, # Default doc_limit
@ -318,7 +357,6 @@ class TestQuery:
@pytest.mark.asyncio
async def test_get_docs_with_verbose_output(self):
"""Test Query.get_docs method with verbose logging"""
# Create mock DocumentRag with clients
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
@ -330,14 +368,13 @@ class TestQuery:
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Mock responses (batch format)
# Mock responses - one vector per concept
mock_embeddings_client.embed.return_value = [[[0.7, 0.8]]]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c6"
mock_match.score = 0.88
mock_doc_embeddings_client.query.return_value = [mock_match]
# Initialize Query with verbose=True
query = Query(
rag=mock_rag,
user="test_user",
@ -346,14 +383,12 @@ class TestQuery:
doc_limit=5
)
# Call get_docs
result = await query.get_docs("verbose test")
# Call get_docs with concepts
result = await query.get_docs(["verbose test"])
# Verify calls were made (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose test"])
mock_doc_embeddings_client.query.assert_called_once()
# Verify result is tuple of (docs, chunk_ids) with fetched content
docs, chunk_ids = result
assert "Verbose test doc" in docs
assert "doc/c6" in chunk_ids
@ -361,12 +396,14 @@ class TestQuery:
@pytest.mark.asyncio
async def test_document_rag_query_with_verbose(self, mock_fetch_chunk):
"""Test DocumentRag.query method with verbose logging enabled"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses (batch format)
# Mock concept extraction
mock_prompt_client.prompt.return_value = "verbose query test"
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
mock_match = MagicMock()
mock_match.chunk_id = "doc/c7"
@ -374,7 +411,6 @@ class TestQuery:
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
# Initialize DocumentRag with verbose=True
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -383,14 +419,11 @@ class TestQuery:
verbose=True
)
# Call DocumentRag.query
result = await document_rag.query("verbose query test")
# Verify all clients were called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose query test"])
mock_embeddings_client.embed.assert_called_once()
mock_doc_embeddings_client.query.assert_called_once()
# Verify prompt client was called with fetched content
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"]
@ -400,23 +433,20 @@ class TestQuery:
@pytest.mark.asyncio
async def test_get_docs_with_empty_results(self):
"""Test Query.get_docs method when no documents are found"""
# Create mock DocumentRag with clients
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock fetch_chunk (won't be called if no chunk_ids)
async def mock_fetch(chunk_id, user):
return f"Content for {chunk_id}"
mock_rag.fetch_chunk = mock_fetch
# Mock responses - empty chunk_id list (batch format)
# Mock responses - empty results
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_doc_embeddings_client.query.return_value = []
# Initialize Query
query = Query(
rag=mock_rag,
user="test_user",
@ -424,30 +454,27 @@ class TestQuery:
verbose=False
)
# Call get_docs
result = await query.get_docs("query with no results")
result = await query.get_docs(["query with no results"])
# Verify calls were made (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["query with no results"])
mock_doc_embeddings_client.query.assert_called_once()
# Verify empty result is returned (tuple of empty lists)
assert result == ([], [])
@pytest.mark.asyncio
async def test_document_rag_query_with_empty_documents(self, mock_fetch_chunk):
"""Test DocumentRag.query method when no documents are retrieved"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock responses - no chunk_ids found (batch format)
# Mock concept extraction
mock_prompt_client.prompt.return_value = "query with no matching docs"
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = [] # Empty chunk_id list
mock_doc_embeddings_client.query.return_value = []
mock_prompt_client.document_prompt.return_value = "No documents found response"
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -456,10 +483,8 @@ class TestQuery:
verbose=False
)
# Call DocumentRag.query
result = await document_rag.query("query with no matching docs")
# Verify prompt client was called with empty document list
mock_prompt_client.document_prompt.assert_called_once_with(
query="query with no matching docs",
documents=[]
@ -468,18 +493,15 @@ class TestQuery:
assert result == "No documents found response"
@pytest.mark.asyncio
async def test_get_vector_with_verbose(self):
"""Test Query.get_vector method with verbose logging"""
# Create mock DocumentRag with embeddings client
async def test_get_vectors_with_verbose(self):
"""Test Query.get_vectors method with verbose logging"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
# Mock the embed method (batch format)
expected_vectors = [[0.9, 1.0, 1.1]]
mock_embeddings_client.embed.return_value = [expected_vectors]
mock_embeddings_client.embed.return_value = expected_vectors
# Initialize Query with verbose=True
query = Query(
rag=mock_rag,
user="test_user",
@ -487,40 +509,40 @@ class TestQuery:
verbose=True
)
# Call get_vector
result = await query.get_vector("verbose vector test")
result = await query.get_vectors(["verbose vector test"])
# Verify embeddings client was called (now expects list)
mock_embeddings_client.embed.assert_called_once_with(["verbose vector test"])
# Verify result (extracted from batch)
assert result == expected_vectors
@pytest.mark.asyncio
async def test_document_rag_integration_flow(self, mock_fetch_chunk):
"""Test complete DocumentRag integration with realistic data flow"""
# Create mock clients
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
# Mock realistic responses (batch format)
query_text = "What is machine learning?"
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5]]
retrieved_chunk_ids = ["doc/ml1", "doc/ml2", "doc/ml3"]
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
mock_embeddings_client.embed.return_value = [query_vectors]
mock_matches = []
for chunk_id in retrieved_chunk_ids:
mock_match = MagicMock()
mock_match.chunk_id = chunk_id
mock_match.score = 0.9
mock_matches.append(mock_match)
mock_doc_embeddings_client.query.return_value = mock_matches
# Mock concept extraction
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
# Mock embeddings - one vector per concept
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
mock_embeddings_client.embed.return_value = query_vectors
# Each concept query returns some matches
mock_matches_1 = [
MagicMock(chunk_id="doc/ml1", score=0.9),
MagicMock(chunk_id="doc/ml2", score=0.85),
]
mock_matches_2 = [
MagicMock(chunk_id="doc/ml2", score=0.88), # duplicate
MagicMock(chunk_id="doc/ml3", score=0.82),
]
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
mock_prompt_client.document_prompt.return_value = final_response
# Initialize DocumentRag
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
@ -529,7 +551,6 @@ class TestQuery:
verbose=False
)
# Execute full pipeline
result = await document_rag.query(
query=query_text,
user="research_user",
@ -537,26 +558,69 @@ class TestQuery:
doc_limit=25
)
# Verify complete pipeline execution (now expects list)
mock_embeddings_client.embed.assert_called_once_with([query_text])
mock_doc_embeddings_client.query.assert_called_once_with(
vector=query_vectors,
limit=25,
user="research_user",
collection="ml_knowledge"
# Verify concept extraction
mock_prompt_client.prompt.assert_called_once_with(
"extract-concepts",
variables={"query": query_text}
)
# Verify embeddings called with concepts
mock_embeddings_client.embed.assert_called_once_with(
["machine learning", "artificial intelligence"]
)
# Verify two per-concept queries were made (25 // 2 = 12 per concept)
assert mock_doc_embeddings_client.query.call_count == 2
# Verify prompt client was called with fetched document content
mock_prompt_client.document_prompt.assert_called_once()
call_args = mock_prompt_client.document_prompt.call_args
assert call_args.kwargs["query"] == query_text
# Verify documents were fetched from chunk_ids
# Verify documents were fetched and deduplicated
docs = call_args.kwargs["documents"]
assert "Machine learning is a subset of artificial intelligence..." in docs
assert "ML algorithms learn patterns from data to make predictions..." in docs
assert "Common ML techniques include supervised and unsupervised learning..." in docs
assert len(docs) == 3 # doc/ml2 deduplicated
# Verify final result
assert result == final_response
@pytest.mark.asyncio
async def test_get_docs_deduplicates_across_concepts(self):
"""Test that get_docs deduplicates chunks across multiple concepts"""
mock_rag = MagicMock()
mock_embeddings_client = AsyncMock()
mock_doc_embeddings_client = AsyncMock()
mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
async def mock_fetch(chunk_id, user):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch
# Two concepts → two vectors
mock_embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
# Both queries return overlapping chunks
match_a = MagicMock(chunk_id="doc/c1", score=0.9)
match_b = MagicMock(chunk_id="doc/c2", score=0.8)
match_c = MagicMock(chunk_id="doc/c1", score=0.85) # duplicate
mock_doc_embeddings_client.query.side_effect = [
[match_a, match_b],
[match_c],
]
query = Query(
rag=mock_rag,
user="test_user",
collection="test_collection",
verbose=False,
doc_limit=10
)
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])
assert len(chunk_ids) == 2 # doc/c1 only counted once
assert "doc/c1" in chunk_ids
assert "doc/c2" in chunk_ids