trustgraph/tests/unit/test_retrieval/test_graph_rag.py
cybermaggedon 01cc8dbc64
feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG (#1005)
Replace the three-prompt LLM scoring pipeline (kg-edge-scoring,
kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker
service backed by FlashRank. The new hop_and_filter() method performs
iterative graph traversal with semantic scoring at each hop, replacing
the previous follow_edges/get_subgraph approach.

- Add reranker service (trustgraph-base client/service, FlashRank processor)
- Add gateway dispatch for reranker via API and WebSocket
- Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring
- Remove kg_prompt() and edge_score_limit from prompt client
- Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates
- Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata
- Add tg-invoke-reranker CLI tool
- Add tech spec and UX developer guidance
- Update all unit and integration tests
2026-06-30 14:36:37 +01:00

576 lines
20 KiB
Python

"""
Tests for GraphRAG retrieval implementation
"""
import pytest
import unittest.mock
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
from trustgraph.base import PromptResult
class TestGraphRag:
    """Test cases for GraphRag class"""

    # Constructor keyword names of the client dependencies. The tests expect
    # GraphRag to expose each one as an attribute with the same name.
    _CLIENT_KWARGS = (
        "prompt_client",
        "embeddings_client",
        "graph_embeddings_client",
        "triples_client",
        "reranker_client",
    )

    @classmethod
    def _make_clients(cls):
        """Return a fresh MagicMock per client dependency, keyed by kwarg name."""
        return {name: MagicMock() for name in cls._CLIENT_KWARGS}

    @staticmethod
    def _assert_initialised(graph_rag, clients, expected_verbose):
        """Check that ``graph_rag`` stored every client, the verbose flag and a label cache.

        Args:
            graph_rag: The GraphRag instance under test.
            clients: Mapping of constructor kwarg name to the mock passed in.
            expected_verbose: The value ``graph_rag.verbose`` must hold.
        """
        # Imported here, as in the original tests, to keep the module-level
        # import list unchanged.
        from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL

        for name, client in clients.items():
            # Identity, not equality: the very object passed in must be stored.
            assert getattr(graph_rag, name) is client
        assert graph_rag.verbose is expected_verbose
        assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)

    def test_graph_rag_initialization_with_defaults(self):
        """Test GraphRag initialization with default verbose setting"""
        clients = self._make_clients()

        graph_rag = GraphRag(**clients)

        self._assert_initialised(graph_rag, clients, expected_verbose=False)

    def test_graph_rag_initialization_with_verbose(self):
        """Test GraphRag initialization with verbose enabled"""
        clients = self._make_clients()

        graph_rag = GraphRag(**clients, verbose=True)

        self._assert_initialised(graph_rag, clients, expected_verbose=True)
class TestQuery:
"""Test cases for Query class"""
def test_query_initialization_with_defaults(self):
    """Query built with only rag/collection/verbose exposes the default limits."""
    rag_stub = MagicMock()
    expected_defaults = {
        "entity_limit": 50,
        "triple_limit": 30,
        "max_subgraph_size": 1000,
        "max_path_length": 2,
    }

    query = Query(rag=rag_stub, collection="test_collection", verbose=False)

    # Arguments that were passed explicitly are stored unchanged.
    assert query.rag == rag_stub
    assert query.collection == "test_collection"
    assert query.verbose is False

    # Limits that were not passed take their default values.
    for attr_name, default_value in expected_defaults.items():
        assert getattr(query, attr_name) == default_value
def test_query_initialization_with_custom_params(self):
    """Query stores every explicitly supplied limit instead of its default."""
    rag_stub = MagicMock()
    custom_limits = {
        "entity_limit": 100,
        "triple_limit": 60,
        "max_subgraph_size": 2000,
        "max_path_length": 3,
    }

    query = Query(
        rag=rag_stub,
        collection="custom_collection",
        verbose=True,
        **custom_limits,
    )

    assert query.rag == rag_stub
    assert query.collection == "custom_collection"
    assert query.verbose is True
    # Each limit passed to the constructor must be readable back unchanged.
    for attr_name, supplied_value in custom_limits.items():
        assert getattr(query, attr_name) == supplied_value
@pytest.mark.asyncio
async def test_get_vectors_method(self):
    """get_vectors passes the concept list to embed() and returns its result."""
    concepts = ["machine learning", "neural networks"]
    # One vector per concept, as returned by the mocked embed() call.
    expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

    embeddings = AsyncMock()
    embeddings.embed.return_value = expected_vectors
    rag_stub = MagicMock()
    rag_stub.embeddings_client = embeddings

    query = Query(rag=rag_stub, collection="test_collection", verbose=False)
    vectors = await query.get_vectors(concepts)

    embeddings.embed.assert_called_once_with(concepts)
    assert vectors == expected_vectors
@pytest.mark.asyncio
async def test_get_vectors_method_with_verbose(self):
    """get_vectors returns the embed() result when the Query is verbose."""
    concepts = ["test concept"]
    expected_vectors = [[0.7, 0.8, 0.9]]

    embeddings = AsyncMock()
    embeddings.embed.return_value = expected_vectors
    rag_stub = MagicMock()
    rag_stub.embeddings_client = embeddings

    query = Query(rag=rag_stub, collection="test_collection", verbose=True)
    vectors = await query.get_vectors(concepts)

    embeddings.embed.assert_called_once_with(["test concept"])
    assert vectors == expected_vectors
@pytest.mark.asyncio
async def test_extract_concepts(self):
    """extract_concepts turns a newline-separated prompt reply into a list of concepts."""
    question = "What is machine learning?"
    # The mocked reply carries two concepts, one per line, plus a trailing newline.
    prompt_reply = PromptResult(
        response_type="text",
        text="machine learning\nneural networks\n",
    )

    prompt_client = AsyncMock()
    prompt_client.prompt.return_value = prompt_reply
    rag_stub = MagicMock()
    rag_stub.prompt_client = prompt_client

    query = Query(rag=rag_stub, collection="test_collection", verbose=False)
    concepts = await query.extract_concepts(question)

    prompt_client.prompt.assert_called_once_with(
        "extract-concepts",
        variables={"query": "What is machine learning?"},
    )
    assert concepts == ["machine learning", "neural networks"]
@pytest.mark.asyncio
async def test_extract_concepts_fallback_to_raw_query(self):
    """extract_concepts returns [query] when the prompt reply text is empty."""
    empty_reply = PromptResult(response_type="text", text="")

    prompt_client = AsyncMock()
    prompt_client.prompt.return_value = empty_reply
    rag_stub = MagicMock()
    rag_stub.prompt_client = prompt_client

    query = Query(rag=rag_stub, collection="test_collection", verbose=False)
    concepts = await query.extract_concepts("test query")

    # With nothing extracted, the raw query is the only concept.
    assert concepts == ["test query"]
@pytest.mark.asyncio
async def test_get_entities_method(self):
    """get_entities uses the raw query as concept, embeds it, and returns the matched entity IRIs."""
    question = "Find related entities"

    def make_match(iri):
        # A graph-embeddings match wrapping an entity with type "i" and the given iri.
        entity = MagicMock()
        entity.type = "i"
        entity.iri = iri
        match = MagicMock()
        match.entity = entity
        return match

    prompt_client = AsyncMock()
    # Empty extraction result, so the concept list falls back to [question].
    prompt_client.prompt.return_value = PromptResult(response_type="text", text="")

    embeddings = AsyncMock()
    # A single vector for the single fallback concept.
    embeddings.embed.return_value = [[0.1, 0.2, 0.3]]

    graph_embeddings = AsyncMock()
    graph_embeddings.query.return_value = [
        make_match("entity1"),
        make_match("entity2"),
    ]

    rag_stub = MagicMock()
    rag_stub.prompt_client = prompt_client
    rag_stub.embeddings_client = embeddings
    rag_stub.graph_embeddings_client = graph_embeddings

    query = Query(
        rag=rag_stub,
        collection="test_collection",
        verbose=False,
        entity_limit=25,
    )
    entities, concepts = await query.get_entities(question)

    # The embeddings client must have been given the fallback concept.
    embeddings.embed.assert_called_once_with(["Find related entities"])
    assert entities == ["entity1", "entity2"]
    assert concepts == ["Find related entities"]
@pytest.mark.asyncio
async def test_maybe_label_with_cached_label(self):
    """maybe_label returns the cached label, looked up under "<collection>:<entity>"."""
    label_cache = MagicMock()
    label_cache.get.return_value = "Entity One Label"
    rag_stub = MagicMock()
    rag_stub.label_cache = label_cache

    query = Query(rag=rag_stub, collection="test_collection", verbose=False)
    label = await query.maybe_label("entity1")

    assert label == "Entity One Label"
    label_cache.get.assert_called_once_with("test_collection:entity1")
@pytest.mark.asyncio
async def test_maybe_label_with_label_lookup(self):
    """On a cache miss, maybe_label queries for an rdfs:label triple and caches the result."""
    entity_uri = "http://example.com/entity"

    # Cache miss forces the triples lookup.
    label_cache = MagicMock()
    label_cache.get.return_value = None

    label_triple = MagicMock()
    label_triple.o = "Human Readable Label"
    triples = AsyncMock()
    triples.query.return_value = [label_triple]

    rag_stub = MagicMock()
    rag_stub.label_cache = label_cache
    rag_stub.triples_client = triples

    query = Query(rag=rag_stub, collection="test_collection", verbose=False)
    label = await query.maybe_label(entity_uri)

    triples.query.assert_called_once_with(
        s="http://example.com/entity",
        p="http://www.w3.org/2000/01/rdf-schema#label",
        o=None,
        limit=1,
        collection="test_collection",
        g="",
    )
    assert label == "Human Readable Label"
    # The resolved label is stored under the collection-scoped cache key.
    label_cache.put.assert_called_once_with(
        "test_collection:http://example.com/entity",
        "Human Readable Label",
    )
@pytest.mark.asyncio
async def test_maybe_label_with_no_label_found(self):
    """When no label triple exists, maybe_label returns the entity itself and caches that."""
    entity = "unlabeled_entity"

    label_cache = MagicMock()
    label_cache.get.return_value = None

    # The label query finds nothing.
    triples = AsyncMock()
    triples.query.return_value = []

    rag_stub = MagicMock()
    rag_stub.label_cache = label_cache
    rag_stub.triples_client = triples

    query = Query(rag=rag_stub, collection="test_collection", verbose=False)
    label = await query.maybe_label(entity)

    triples.query.assert_called_once_with(
        s="unlabeled_entity",
        p="http://www.w3.org/2000/01/rdf-schema#label",
        o=None,
        limit=1,
        collection="test_collection",
        g="",
    )
    assert label == "unlabeled_entity"
    # The entity is cached as its own label under the collection-scoped key.
    label_cache.put.assert_called_once_with(
        "test_collection:unlabeled_entity",
        "unlabeled_entity",
    )
@pytest.mark.asyncio
async def test_triples_query_never_passes_workspace(self):
    """Workspace isolation is handled by pub/sub topic routing, not
    by passing workspace to TriplesClient.query(). Verify that
    GraphRAG never passes workspace as a keyword argument."""
    mock_rag = MagicMock()

    # Cache miss, so maybe_label has to go through the triples client.
    mock_cache = MagicMock()
    mock_cache.get.return_value = None
    mock_rag.label_cache = mock_cache

    mock_triple = MagicMock()
    mock_triple.o = "Label"
    mock_triples_client = AsyncMock()
    mock_triples_client.query.return_value = [mock_triple]
    mock_rag.triples_client = mock_triples_client

    query = Query(
        rag=mock_rag,
        collection="test_collection",
        verbose=False,
    )

    await query.maybe_label("http://example.com/entity")

    # Without this guard the loop below passes vacuously when query() is
    # never called, and the test would verify nothing.
    mock_triples_client.query.assert_called()

    for recorded_call in mock_triples_client.query.call_args_list:
        assert "workspace" not in recorded_call.kwargs
@pytest.mark.asyncio
async def test_hop_and_filter_never_passes_workspace(self):
    """Verify hop_and_filter never passes workspace to query_stream."""
    mock_rag = MagicMock()
    mock_triples_client = AsyncMock()
    mock_reranker_client = AsyncMock()
    mock_rag.triples_client = mock_triples_client
    mock_rag.reranker_client = mock_reranker_client
    mock_rag.label_cache = MagicMock()
    mock_rag.label_cache.get.return_value = None

    # One edge is streamed back for the seed entity; label lookups find nothing.
    mock_triple = MagicMock()
    mock_triple.s = "e1"
    mock_triple.p = "p1"
    mock_triple.o = "o1"
    mock_triples_client.query_stream.return_value = [mock_triple]
    mock_triples_client.query.return_value = []

    # A single reranker hit with ids "0"/"0" and a score of 0.9.
    result = MagicMock()
    result.document_id = "0"
    result.query_id = "0"
    result.score = 0.9
    mock_reranker_client.rerank.return_value = [result]

    query = Query(
        rag=mock_rag,
        collection="test_collection",
        verbose=False,
        triple_limit=10,
    )

    await query.hop_and_filter(["e1"], ["concept"])

    # Without this guard the loop below passes vacuously when query_stream()
    # is never called, and the test would verify nothing.
    mock_triples_client.query_stream.assert_called()

    for recorded_call in mock_triples_client.query_stream.call_args_list:
        assert "workspace" not in recorded_call.kwargs
@pytest.mark.asyncio
async def test_hop_and_filter_basic_functionality(self):
    """hop_and_filter returns one scored edge and calls the reranker once with limit=edge_limit."""
    # The only edge streamed back for the seed entity.
    edge = MagicMock()
    edge.s = "entity1"
    edge.p = "predicate1"
    edge.o = "object1"

    triples = AsyncMock()
    triples.query_stream.return_value = [edge]
    triples.query.return_value = []

    # The reranker reports a single hit with ids "0"/"0".
    rerank_hit = MagicMock()
    rerank_hit.document_id = "0"
    rerank_hit.query_id = "0"
    rerank_hit.score = 0.95
    reranker = AsyncMock()
    reranker.rerank.return_value = [rerank_hit]

    rag_stub = MagicMock()
    rag_stub.triples_client = triples
    rag_stub.reranker_client = reranker
    rag_stub.label_cache = MagicMock()
    rag_stub.label_cache.get.return_value = None

    query = Query(
        rag=rag_stub,
        collection="test_collection",
        verbose=False,
        triple_limit=10,
        edge_limit=25,
    )
    selected, uri_map, edge_meta = await query.hop_and_filter(
        ["entity1"], ["test concept"],
    )

    # Exactly one edge, with one entry in each of the two accompanying maps.
    assert len(selected) == 1
    assert len(uri_map) == 1
    assert len(edge_meta) == 1

    reranker.rerank.assert_called_once()
    rerank_call = reranker.rerank.call_args
    assert rerank_call.kwargs["limit"] == 25
@pytest.mark.asyncio
async def test_hop_and_filter_with_empty_frontier(self):
    """hop_and_filter with no seed entities yields an empty list and two empty dicts."""
    query = Query(
        rag=MagicMock(),
        collection="test_collection",
        verbose=False,
    )

    outcome = await query.hop_and_filter([], ["concept"])

    # Unpack to keep the three-element shape of the result part of the check.
    selected, uri_map, edge_meta = outcome
    assert selected == []
    assert uri_map == {}
    assert edge_meta == {}
@pytest.mark.asyncio
async def test_hop_and_filter_filters_label_triples(self):
    """hop_and_filter drops rdfs:label edges, leaving nothing to select or rerank."""
    # The only streamed edge is an rdfs:label triple.
    label_edge = MagicMock()
    label_edge.s = "entity1"
    label_edge.p = "http://www.w3.org/2000/01/rdf-schema#label"
    label_edge.o = "Entity One"

    triples = AsyncMock()
    triples.query_stream.return_value = [label_edge]
    triples.query.return_value = []
    reranker = AsyncMock()

    rag_stub = MagicMock()
    rag_stub.triples_client = triples
    rag_stub.reranker_client = reranker
    rag_stub.label_cache = MagicMock()
    rag_stub.label_cache.get.return_value = None

    query = Query(
        rag=rag_stub,
        collection="test_collection",
        verbose=False,
        triple_limit=10,
    )
    selected, uri_map, edge_meta = await query.hop_and_filter(
        ["entity1"], ["concept"],
    )

    assert selected == []
    # No candidate edges remain, so the reranker must not be called.
    reranker.rerank.assert_not_called()
@pytest.mark.asyncio
async def test_graph_rag_query_method(self):
    """Test GraphRag.query method orchestrates full RAG pipeline with provenance.

    Query.hop_and_filter is patched to return a fixed edge selection, so the
    test checks only what GraphRag.query does around it: the answer text is
    the "kg-synthesis" prompt reply, and five provenance events are reported
    to explain_callback in the order question, grounding, exploration,
    focus, synthesis.
    """
    from trustgraph.retrieval.graph_rag.graph_rag import edge_id

    expected_response = "This is the RAG response"

    # Fixed output of the patched hop_and_filter: one edge plus the two maps
    # keyed by that edge's id.
    test_selected_edges = [("Subject", "Predicate", "Object")]
    test_eid = edge_id("Subject", "Predicate", "Object")
    test_uri_map = {
        test_eid: (
            "http://example.org/subject",
            "http://example.org/predicate",
            "http://example.org/object",
        ),
    }
    test_edge_metadata = {
        test_eid: {"concept": "test concept", "score": 0.95},
    }

    mock_prompt_client = AsyncMock()
    mock_embeddings_client = AsyncMock()
    mock_graph_embeddings_client = AsyncMock()
    mock_triples_client = AsyncMock()
    mock_reranker_client = AsyncMock()

    mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
    mock_graph_embeddings_client.query.return_value = []

    # Reply text per prompt name; any other prompt gets an empty reply.
    canned_replies = {
        "extract-concepts": "test concept",
        "kg-synthesis": expected_response,
    }

    async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
        return PromptResult(
            response_type="text",
            text=canned_replies.get(prompt_name, ""),
        )

    mock_prompt_client.prompt = mock_prompt

    graph_rag = GraphRag(
        prompt_client=mock_prompt_client,
        embeddings_client=mock_embeddings_client,
        graph_embeddings_client=mock_graph_embeddings_client,
        triples_client=mock_triples_client,
        reranker_client=mock_reranker_client,
        verbose=False,
    )

    async def mock_hop_and_filter(self, seed_entities, concepts):
        return test_selected_edges, test_uri_map, test_edge_metadata

    provenance_events = []

    async def collect_provenance(triples, prov_id):
        provenance_events.append((triples, prov_id))

    # patch.object restores the real Query.hop_and_filter on exit, even if
    # the call raises, so the class-level patch cannot leak into other tests.
    with unittest.mock.patch.object(Query, "hop_and_filter", new=mock_hop_and_filter):
        response = await graph_rag.query(
            query="test query",
            collection="test_collection",
            entity_limit=25,
            triple_limit=15,
            explain_callback=collect_provenance,
        )

    # The result must unpack into exactly (text, usage); usage is not inspected.
    response_text, _usage = response
    assert response_text == expected_response

    # 5 events: question, grounding, exploration, focus, synthesis
    expected_stages = ["question", "grounding", "exploration", "focus", "synthesis"]
    assert len(provenance_events) == 5

    for triples, prov_id in provenance_events:
        assert isinstance(triples, list)
        assert len(triples) > 0
        assert prov_id.startswith("urn:trustgraph:")

    # Each event id names its pipeline stage, in pipeline order.
    for stage, (_triples, prov_id) in zip(expected_stages, provenance_events):
        assert stage in prov_id