feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG

Replace the three-prompt LLM scoring pipeline (kg-edge-scoring,
kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker
service backed by FlashRank. The new hop_and_filter() method performs
iterative graph traversal with semantic scoring at each hop, replacing
the previous follow_edges/get_subgraph approach.

- Add reranker service (trustgraph-base client/service, FlashRank processor)
- Add gateway dispatch for reranker via API and WebSocket
- Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring
- Remove kg_prompt() and edge_score_limit from prompt client
- Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates
- Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata
- Add tg-invoke-reranker CLI tool
- Add tech spec and UX developer guidance
- Update all unit and integration tests
This commit is contained in:
Cyber MacGeddon 2026-06-30 09:39:35 +01:00
parent 1aa9549912
commit 1346cbebb4
43 changed files with 1613 additions and 792 deletions

View file

@ -15,54 +15,52 @@ class TestGraphRag:
def test_graph_rag_initialization_with_defaults(self):
"""Test GraphRag initialization with default verbose setting"""
# Create mock clients
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
mock_reranker_client = MagicMock()
# Initialize GraphRag
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.verbose is False # Default value
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
# Create mock clients
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
# Initialize GraphRag with verbose=True
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
verbose=True
reranker_client=mock_reranker_client,
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.reranker_client == mock_reranker_client
assert graph_rag.verbose is False
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
mock_reranker_client = MagicMock()
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
verbose=True,
)
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.reranker_client == mock_reranker_client
assert graph_rag.verbose is True
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
@ -365,244 +363,162 @@ class TestQuery:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_never_passes_workspace(self):
"""Verify follow_edges never passes workspace to query_stream."""
async def test_hop_and_filter_never_passes_workspace(self):
"""Verify hop_and_filter never passes workspace to query_stream."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
mock_triple = MagicMock()
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
mock_triple.s = "e1"
mock_triple.p = "p1"
mock_triple.o = "o1"
mock_triples_client.query_stream.return_value = [mock_triple]
mock_triples_client.query.return_value = []
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.9
mock_reranker_client.rerank.return_value = [result]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
triple_limit=10,
)
subgraph = set()
await query.follow_edges("e1", subgraph, path_length=1)
await query.hop_and_filter(["e1"], ["concept"])
for c in mock_triples_client.query_stream.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
"""Test Query.follow_edges method basic triple discovery"""
async def test_hop_and_filter_basic_functionality(self):
"""Test hop_and_filter retrieves edges and scores them with reranker."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
mock_triple1 = MagicMock()
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
mock_triple = MagicMock()
mock_triple.s = "entity1"
mock_triple.p = "predicate1"
mock_triple.o = "object1"
mock_triples_client.query_stream.return_value = [mock_triple]
mock_triples_client.query.return_value = []
mock_triple2 = MagicMock()
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
mock_triple3 = MagicMock()
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
mock_triples_client.query_stream.side_effect = [
[mock_triple1], # s=ent
[mock_triple2], # p=ent
[mock_triple3], # o=ent
]
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.95
mock_reranker_client.rerank.return_value = [result]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
triple_limit=10,
edge_limit=25,
)
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=1)
assert mock_triples_client.query_stream.call_count == 3
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
collection="test_collection", batch_size=20, g=""
selected, uri_map, edge_meta = await query.hop_and_filter(
["entity1"], ["test concept"],
)
expected_subgraph = {
("entity1", "predicate1", "object1"),
("subject2", "entity1", "object2"),
("subject3", "predicate3", "entity1")
}
assert subgraph == expected_subgraph
assert len(selected) == 1
assert len(uri_map) == 1
assert len(edge_meta) == 1
mock_reranker_client.rerank.assert_called_once()
call_kwargs = mock_reranker_client.rerank.call_args
assert call_kwargs.kwargs["limit"] == 25
@pytest.mark.asyncio
async def test_follow_edges_with_path_length_zero(self):
"""Test Query.follow_edges method with path_length=0"""
async def test_hop_and_filter_with_empty_frontier(self):
"""Test hop_and_filter with no seed entities returns empty."""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
)
selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"])
assert selected == []
assert uri_map == {}
assert edge_meta == {}
@pytest.mark.asyncio
async def test_hop_and_filter_filters_label_triples(self):
"""Test hop_and_filter skips rdfs:label edges."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False
)
label_triple = MagicMock()
label_triple.s = "entity1"
label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label"
label_triple.o = "Entity One"
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=0)
mock_triples_client.query_stream.assert_not_called()
assert subgraph == set()
@pytest.mark.asyncio
async def test_follow_edges_with_max_subgraph_size_limit(self):
"""Test Query.follow_edges method respects max_subgraph_size"""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_triples_client.query_stream.return_value = [label_triple]
mock_triples_client.query.return_value = []
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_subgraph_size=2
triple_limit=10,
)
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
await query.follow_edges("entity1", subgraph, path_length=1)
mock_triples_client.query_stream.assert_not_called()
assert len(subgraph) == 3
@pytest.mark.asyncio
async def test_get_subgraph_method(self):
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_path_length=1
selected, uri_map, edge_meta = await query.hop_and_filter(
["entity1"], ["concept"],
)
# Mock get_entities to return (entities, concepts) tuple
query.get_entities = AsyncMock(
return_value=(["entity1", "entity2"], ["concept1"])
)
query.follow_edges_batch = AsyncMock(return_value=(
{
("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2")
},
{}
))
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
query.get_entities.assert_called_once_with("test query")
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
assert isinstance(subgraph, list)
assert len(subgraph) == 2
assert ("entity1", "predicate1", "object1") in subgraph
assert ("entity2", "predicate2", "object2") in subgraph
assert entities == ["entity1", "entity2"]
assert concepts == ["concept1"]
@pytest.mark.asyncio
async def test_get_labelgraph_method(self):
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_subgraph_size=100
)
test_subgraph = [
("entity1", "predicate1", "object1"),
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
("entity3", "predicate3", "object3")
]
test_entities = ["entity1", "entity3"]
test_concepts = ["concept1"]
query.get_subgraph = AsyncMock(
return_value=(test_subgraph, {}, test_entities, test_concepts)
)
async def mock_maybe_label(entity):
label_map = {
"entity1": "Human Entity One",
"predicate1": "Human Predicate One",
"object1": "Human Object One",
"entity3": "Human Entity Three",
"predicate3": "Human Predicate Three",
"object3": "Human Object Three"
}
return label_map.get(entity, entity)
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
query.get_subgraph.assert_called_once_with("test query")
# Label triples filtered out
assert len(labeled_edges) == 2
# maybe_label called for non-label triples
assert query.maybe_label.call_count == 6
expected_edges = [
("Human Entity One", "Human Predicate One", "Human Object One"),
("Human Entity Three", "Human Predicate Three", "Human Object Three")
]
assert labeled_edges == expected_edges
assert len(uri_map) == 2
assert entities == test_entities
assert concepts == test_concepts
assert selected == []
mock_reranker_client.rerank.assert_not_called()
@pytest.mark.asyncio
async def test_graph_rag_query_method(self):
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
import json
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_graph_embeddings_client = AsyncMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
expected_response = "This is the RAG response"
test_labelgraph = [("Subject", "Predicate", "Object")]
test_edge_id = edge_id("Subject", "Predicate", "Object")
test_selected_edges = [("Subject", "Predicate", "Object")]
test_eid = edge_id("Subject", "Predicate", "Object")
test_uri_map = {
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
}
test_edge_metadata = {
test_eid: {"concept": "test concept", "score": 0.95}
}
test_entities = ["http://example.org/subject"]
test_concepts = ["test concept"]
# Mock prompt responses for the multi-step process
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
mock_graph_embeddings_client.query.return_value = []
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
elif prompt_name == "kg-edge-reasoning":
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
return PromptResult(response_type="text", text="test concept")
elif prompt_name == "kg-synthesis":
return PromptResult(response_type="text", text=expected_response)
return PromptResult(response_type="text", text="")
@ -614,16 +530,16 @@ class TestQuery:
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
verbose=False
reranker_client=mock_reranker_client,
verbose=False,
)
# Patch Query.get_labelgraph to return test data
original_get_labelgraph = Query.get_labelgraph
original_hop_and_filter = Query.hop_and_filter
async def mock_get_labelgraph(self, query_text):
return test_labelgraph, test_uri_map, test_entities, test_concepts
async def mock_hop_and_filter(self, seed_entities, concepts):
return test_selected_edges, test_uri_map, test_edge_metadata
Query.get_labelgraph = mock_get_labelgraph
Query.hop_and_filter = mock_hop_and_filter
provenance_events = []
@ -636,7 +552,7 @@ class TestQuery:
collection="test_collection",
entity_limit=25,
triple_limit=15,
explain_callback=collect_provenance
explain_callback=collect_provenance,
)
response_text, usage = response
@ -650,7 +566,6 @@ class TestQuery:
assert len(triples) > 0
assert prov_id.startswith("urn:trustgraph:")
# Verify order
assert "question" in provenance_events[0][1]
assert "grounding" in provenance_events[1][1]
assert "exploration" in provenance_events[2][1]
@ -658,4 +573,4 @@ class TestQuery:
assert "synthesis" in provenance_events[4][1]
finally:
Query.get_labelgraph = original_get_labelgraph
Query.hop_and_filter = original_hop_and_filter