mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 01:19:38 +02:00
Replace the three-prompt LLM scoring pipeline (kg-edge-scoring, kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker service backed by FlashRank. The new hop_and_filter() method performs iterative graph traversal with semantic scoring at each hop, replacing the previous follow_edges/get_subgraph approach. - Add reranker service (trustgraph-base client/service, FlashRank processor) - Add gateway dispatch for reranker via API and WebSocket - Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring - Remove kg_prompt() and edge_score_limit from prompt client - Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates - Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata - Add tg-invoke-reranker CLI tool - Add tech spec and UX developer guidance - Update all unit and integration tests
576 lines
20 KiB
Python
576 lines
20 KiB
Python
"""
|
|
Tests for GraphRAG retrieval implementation
|
|
"""
|
|
|
|
import pytest
|
|
import unittest.mock
|
|
from unittest.mock import MagicMock, AsyncMock
|
|
|
|
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
|
|
from trustgraph.base import PromptResult
|
|
|
|
|
|
class TestGraphRag:
|
|
"""Test cases for GraphRag class"""
|
|
|
|
def test_graph_rag_initialization_with_defaults(self):
|
|
"""Test GraphRag initialization with default verbose setting"""
|
|
mock_prompt_client = MagicMock()
|
|
mock_embeddings_client = MagicMock()
|
|
mock_graph_embeddings_client = MagicMock()
|
|
mock_triples_client = MagicMock()
|
|
mock_reranker_client = MagicMock()
|
|
|
|
graph_rag = GraphRag(
|
|
prompt_client=mock_prompt_client,
|
|
embeddings_client=mock_embeddings_client,
|
|
graph_embeddings_client=mock_graph_embeddings_client,
|
|
triples_client=mock_triples_client,
|
|
reranker_client=mock_reranker_client,
|
|
)
|
|
|
|
assert graph_rag.prompt_client == mock_prompt_client
|
|
assert graph_rag.embeddings_client == mock_embeddings_client
|
|
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
|
assert graph_rag.triples_client == mock_triples_client
|
|
assert graph_rag.reranker_client == mock_reranker_client
|
|
assert graph_rag.verbose is False
|
|
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
|
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
|
|
|
def test_graph_rag_initialization_with_verbose(self):
|
|
"""Test GraphRag initialization with verbose enabled"""
|
|
mock_prompt_client = MagicMock()
|
|
mock_embeddings_client = MagicMock()
|
|
mock_graph_embeddings_client = MagicMock()
|
|
mock_triples_client = MagicMock()
|
|
mock_reranker_client = MagicMock()
|
|
|
|
graph_rag = GraphRag(
|
|
prompt_client=mock_prompt_client,
|
|
embeddings_client=mock_embeddings_client,
|
|
graph_embeddings_client=mock_graph_embeddings_client,
|
|
triples_client=mock_triples_client,
|
|
reranker_client=mock_reranker_client,
|
|
verbose=True,
|
|
)
|
|
|
|
assert graph_rag.prompt_client == mock_prompt_client
|
|
assert graph_rag.embeddings_client == mock_embeddings_client
|
|
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
|
assert graph_rag.triples_client == mock_triples_client
|
|
assert graph_rag.reranker_client == mock_reranker_client
|
|
assert graph_rag.verbose is True
|
|
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
|
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
|
|
|
|
|
class TestQuery:
|
|
"""Test cases for Query class"""
|
|
|
|
def test_query_initialization_with_defaults(self):
|
|
"""Test Query initialization with default parameters"""
|
|
# Create mock GraphRag
|
|
mock_rag = MagicMock()
|
|
|
|
# Initialize Query with defaults
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False
|
|
)
|
|
|
|
# Verify initialization
|
|
assert query.rag == mock_rag
|
|
assert query.collection == "test_collection"
|
|
assert query.verbose is False
|
|
assert query.entity_limit == 50 # Default value
|
|
assert query.triple_limit == 30 # Default value
|
|
assert query.max_subgraph_size == 1000 # Default value
|
|
assert query.max_path_length == 2 # Default value
|
|
|
|
def test_query_initialization_with_custom_params(self):
|
|
"""Test Query initialization with custom parameters"""
|
|
# Create mock GraphRag
|
|
mock_rag = MagicMock()
|
|
|
|
# Initialize Query with custom parameters
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="custom_collection",
|
|
verbose=True,
|
|
entity_limit=100,
|
|
triple_limit=60,
|
|
max_subgraph_size=2000,
|
|
max_path_length=3
|
|
)
|
|
|
|
# Verify initialization
|
|
assert query.rag == mock_rag
|
|
assert query.collection == "custom_collection"
|
|
assert query.verbose is True
|
|
assert query.entity_limit == 100
|
|
assert query.triple_limit == 60
|
|
assert query.max_subgraph_size == 2000
|
|
assert query.max_path_length == 3
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_vectors_method(self):
|
|
"""Test Query.get_vectors method calls embeddings client correctly"""
|
|
mock_rag = MagicMock()
|
|
mock_embeddings_client = AsyncMock()
|
|
mock_rag.embeddings_client = mock_embeddings_client
|
|
|
|
# Mock embed to return vectors for a list of concepts
|
|
expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
|
mock_embeddings_client.embed.return_value = expected_vectors
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False
|
|
)
|
|
|
|
concepts = ["machine learning", "neural networks"]
|
|
result = await query.get_vectors(concepts)
|
|
|
|
mock_embeddings_client.embed.assert_called_once_with(concepts)
|
|
assert result == expected_vectors
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_vectors_method_with_verbose(self):
|
|
"""Test Query.get_vectors method with verbose output"""
|
|
mock_rag = MagicMock()
|
|
mock_embeddings_client = AsyncMock()
|
|
mock_rag.embeddings_client = mock_embeddings_client
|
|
|
|
expected_vectors = [[0.7, 0.8, 0.9]]
|
|
mock_embeddings_client.embed.return_value = expected_vectors
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=True
|
|
)
|
|
|
|
result = await query.get_vectors(["test concept"])
|
|
|
|
mock_embeddings_client.embed.assert_called_once_with(["test concept"])
|
|
assert result == expected_vectors
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_concepts(self):
|
|
"""Test Query.extract_concepts parses LLM response into concept list"""
|
|
mock_rag = MagicMock()
|
|
mock_prompt_client = AsyncMock()
|
|
mock_rag.prompt_client = mock_prompt_client
|
|
|
|
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False
|
|
)
|
|
|
|
result = await query.extract_concepts("What is machine learning?")
|
|
|
|
mock_prompt_client.prompt.assert_called_once_with(
|
|
"extract-concepts",
|
|
variables={"query": "What is machine learning?"}
|
|
)
|
|
assert result == ["machine learning", "neural networks"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_concepts_fallback_to_raw_query(self):
|
|
"""Test extract_concepts falls back to raw query when LLM returns empty"""
|
|
mock_rag = MagicMock()
|
|
mock_prompt_client = AsyncMock()
|
|
mock_rag.prompt_client = mock_prompt_client
|
|
|
|
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False
|
|
)
|
|
|
|
result = await query.extract_concepts("test query")
|
|
assert result == ["test query"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_entities_method(self):
|
|
"""Test Query.get_entities extracts concepts, embeds, and retrieves entities"""
|
|
mock_rag = MagicMock()
|
|
mock_prompt_client = AsyncMock()
|
|
mock_embeddings_client = AsyncMock()
|
|
mock_graph_embeddings_client = AsyncMock()
|
|
mock_rag.prompt_client = mock_prompt_client
|
|
mock_rag.embeddings_client = mock_embeddings_client
|
|
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
|
|
|
# extract_concepts returns empty -> falls back to [query]
|
|
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
|
|
|
# embed returns one vector set for the single concept
|
|
test_vectors = [[0.1, 0.2, 0.3]]
|
|
mock_embeddings_client.embed.return_value = test_vectors
|
|
|
|
# Mock entity matches
|
|
mock_entity1 = MagicMock()
|
|
mock_entity1.type = "i"
|
|
mock_entity1.iri = "entity1"
|
|
mock_match1 = MagicMock()
|
|
mock_match1.entity = mock_entity1
|
|
|
|
mock_entity2 = MagicMock()
|
|
mock_entity2.type = "i"
|
|
mock_entity2.iri = "entity2"
|
|
mock_match2 = MagicMock()
|
|
mock_match2.entity = mock_entity2
|
|
|
|
mock_graph_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False,
|
|
entity_limit=25
|
|
)
|
|
|
|
entities, concepts = await query.get_entities("Find related entities")
|
|
|
|
# Verify embeddings client was called with the fallback concept
|
|
mock_embeddings_client.embed.assert_called_once_with(["Find related entities"])
|
|
|
|
# Verify result
|
|
assert entities == ["entity1", "entity2"]
|
|
assert concepts == ["Find related entities"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_maybe_label_with_cached_label(self):
|
|
"""Test Query.maybe_label method with cached label"""
|
|
mock_rag = MagicMock()
|
|
mock_cache = MagicMock()
|
|
mock_cache.get.return_value = "Entity One Label"
|
|
mock_rag.label_cache = mock_cache
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False
|
|
)
|
|
|
|
result = await query.maybe_label("entity1")
|
|
|
|
assert result == "Entity One Label"
|
|
mock_cache.get.assert_called_once_with("test_collection:entity1")
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_maybe_label_with_label_lookup(self):
|
|
"""Test Query.maybe_label method with database label lookup"""
|
|
mock_rag = MagicMock()
|
|
mock_cache = MagicMock()
|
|
mock_cache.get.return_value = None
|
|
mock_rag.label_cache = mock_cache
|
|
mock_triples_client = AsyncMock()
|
|
mock_rag.triples_client = mock_triples_client
|
|
|
|
mock_triple = MagicMock()
|
|
mock_triple.o = "Human Readable Label"
|
|
mock_triples_client.query.return_value = [mock_triple]
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False
|
|
)
|
|
|
|
result = await query.maybe_label("http://example.com/entity")
|
|
|
|
mock_triples_client.query.assert_called_once_with(
|
|
s="http://example.com/entity",
|
|
p="http://www.w3.org/2000/01/rdf-schema#label",
|
|
o=None,
|
|
limit=1,
|
|
collection="test_collection",
|
|
g=""
|
|
)
|
|
|
|
assert result == "Human Readable Label"
|
|
cache_key = "test_collection:http://example.com/entity"
|
|
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_maybe_label_with_no_label_found(self):
|
|
"""Test Query.maybe_label method when no label is found"""
|
|
mock_rag = MagicMock()
|
|
mock_cache = MagicMock()
|
|
mock_cache.get.return_value = None
|
|
mock_rag.label_cache = mock_cache
|
|
mock_triples_client = AsyncMock()
|
|
mock_rag.triples_client = mock_triples_client
|
|
|
|
mock_triples_client.query.return_value = []
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False
|
|
)
|
|
|
|
result = await query.maybe_label("unlabeled_entity")
|
|
|
|
mock_triples_client.query.assert_called_once_with(
|
|
s="unlabeled_entity",
|
|
p="http://www.w3.org/2000/01/rdf-schema#label",
|
|
o=None,
|
|
limit=1,
|
|
collection="test_collection",
|
|
g=""
|
|
)
|
|
|
|
assert result == "unlabeled_entity"
|
|
cache_key = "test_collection:unlabeled_entity"
|
|
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_triples_query_never_passes_workspace(self):
|
|
"""Workspace isolation is handled by pub/sub topic routing, not
|
|
by passing workspace to TriplesClient.query(). Verify that
|
|
GraphRAG never passes workspace as a keyword argument."""
|
|
mock_rag = MagicMock()
|
|
mock_cache = MagicMock()
|
|
mock_cache.get.return_value = None
|
|
mock_rag.label_cache = mock_cache
|
|
mock_triples_client = AsyncMock()
|
|
mock_rag.triples_client = mock_triples_client
|
|
|
|
mock_triple = MagicMock()
|
|
mock_triple.o = "Label"
|
|
mock_triples_client.query.return_value = [mock_triple]
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False
|
|
)
|
|
|
|
await query.maybe_label("http://example.com/entity")
|
|
|
|
for c in mock_triples_client.query.call_args_list:
|
|
assert "workspace" not in c.kwargs
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_hop_and_filter_never_passes_workspace(self):
|
|
"""Verify hop_and_filter never passes workspace to query_stream."""
|
|
mock_rag = MagicMock()
|
|
mock_triples_client = AsyncMock()
|
|
mock_reranker_client = AsyncMock()
|
|
mock_rag.triples_client = mock_triples_client
|
|
mock_rag.reranker_client = mock_reranker_client
|
|
mock_rag.label_cache = MagicMock()
|
|
mock_rag.label_cache.get.return_value = None
|
|
|
|
mock_triple = MagicMock()
|
|
mock_triple.s = "e1"
|
|
mock_triple.p = "p1"
|
|
mock_triple.o = "o1"
|
|
mock_triples_client.query_stream.return_value = [mock_triple]
|
|
mock_triples_client.query.return_value = []
|
|
|
|
result = MagicMock()
|
|
result.document_id = "0"
|
|
result.query_id = "0"
|
|
result.score = 0.9
|
|
mock_reranker_client.rerank.return_value = [result]
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False,
|
|
triple_limit=10,
|
|
)
|
|
|
|
await query.hop_and_filter(["e1"], ["concept"])
|
|
|
|
for c in mock_triples_client.query_stream.call_args_list:
|
|
assert "workspace" not in c.kwargs
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_hop_and_filter_basic_functionality(self):
|
|
"""Test hop_and_filter retrieves edges and scores them with reranker."""
|
|
mock_rag = MagicMock()
|
|
mock_triples_client = AsyncMock()
|
|
mock_reranker_client = AsyncMock()
|
|
mock_rag.triples_client = mock_triples_client
|
|
mock_rag.reranker_client = mock_reranker_client
|
|
mock_rag.label_cache = MagicMock()
|
|
mock_rag.label_cache.get.return_value = None
|
|
|
|
mock_triple = MagicMock()
|
|
mock_triple.s = "entity1"
|
|
mock_triple.p = "predicate1"
|
|
mock_triple.o = "object1"
|
|
mock_triples_client.query_stream.return_value = [mock_triple]
|
|
mock_triples_client.query.return_value = []
|
|
|
|
result = MagicMock()
|
|
result.document_id = "0"
|
|
result.query_id = "0"
|
|
result.score = 0.95
|
|
mock_reranker_client.rerank.return_value = [result]
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False,
|
|
triple_limit=10,
|
|
edge_limit=25,
|
|
)
|
|
|
|
selected, uri_map, edge_meta = await query.hop_and_filter(
|
|
["entity1"], ["test concept"],
|
|
)
|
|
|
|
assert len(selected) == 1
|
|
assert len(uri_map) == 1
|
|
assert len(edge_meta) == 1
|
|
|
|
mock_reranker_client.rerank.assert_called_once()
|
|
call_kwargs = mock_reranker_client.rerank.call_args
|
|
assert call_kwargs.kwargs["limit"] == 25
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_hop_and_filter_with_empty_frontier(self):
|
|
"""Test hop_and_filter with no seed entities returns empty."""
|
|
mock_rag = MagicMock()
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False,
|
|
)
|
|
|
|
selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"])
|
|
|
|
assert selected == []
|
|
assert uri_map == {}
|
|
assert edge_meta == {}
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_hop_and_filter_filters_label_triples(self):
|
|
"""Test hop_and_filter skips rdfs:label edges."""
|
|
mock_rag = MagicMock()
|
|
mock_triples_client = AsyncMock()
|
|
mock_reranker_client = AsyncMock()
|
|
mock_rag.triples_client = mock_triples_client
|
|
mock_rag.reranker_client = mock_reranker_client
|
|
mock_rag.label_cache = MagicMock()
|
|
mock_rag.label_cache.get.return_value = None
|
|
|
|
label_triple = MagicMock()
|
|
label_triple.s = "entity1"
|
|
label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label"
|
|
label_triple.o = "Entity One"
|
|
|
|
mock_triples_client.query_stream.return_value = [label_triple]
|
|
mock_triples_client.query.return_value = []
|
|
|
|
query = Query(
|
|
rag=mock_rag,
|
|
collection="test_collection",
|
|
verbose=False,
|
|
triple_limit=10,
|
|
)
|
|
|
|
selected, uri_map, edge_meta = await query.hop_and_filter(
|
|
["entity1"], ["concept"],
|
|
)
|
|
|
|
assert selected == []
|
|
mock_reranker_client.rerank.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_graph_rag_query_method(self):
|
|
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
|
|
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
|
|
|
|
mock_prompt_client = AsyncMock()
|
|
mock_embeddings_client = AsyncMock()
|
|
mock_graph_embeddings_client = AsyncMock()
|
|
mock_triples_client = AsyncMock()
|
|
mock_reranker_client = AsyncMock()
|
|
|
|
expected_response = "This is the RAG response"
|
|
test_selected_edges = [("Subject", "Predicate", "Object")]
|
|
test_eid = edge_id("Subject", "Predicate", "Object")
|
|
test_uri_map = {
|
|
test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
|
}
|
|
test_edge_metadata = {
|
|
test_eid: {"concept": "test concept", "score": 0.95}
|
|
}
|
|
|
|
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
|
mock_graph_embeddings_client.query.return_value = []
|
|
|
|
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
|
if prompt_name == "extract-concepts":
|
|
return PromptResult(response_type="text", text="test concept")
|
|
elif prompt_name == "kg-synthesis":
|
|
return PromptResult(response_type="text", text=expected_response)
|
|
return PromptResult(response_type="text", text="")
|
|
|
|
mock_prompt_client.prompt = mock_prompt
|
|
|
|
graph_rag = GraphRag(
|
|
prompt_client=mock_prompt_client,
|
|
embeddings_client=mock_embeddings_client,
|
|
graph_embeddings_client=mock_graph_embeddings_client,
|
|
triples_client=mock_triples_client,
|
|
reranker_client=mock_reranker_client,
|
|
verbose=False,
|
|
)
|
|
|
|
original_hop_and_filter = Query.hop_and_filter
|
|
|
|
async def mock_hop_and_filter(self, seed_entities, concepts):
|
|
return test_selected_edges, test_uri_map, test_edge_metadata
|
|
|
|
Query.hop_and_filter = mock_hop_and_filter
|
|
|
|
provenance_events = []
|
|
|
|
async def collect_provenance(triples, prov_id):
|
|
provenance_events.append((triples, prov_id))
|
|
|
|
try:
|
|
response = await graph_rag.query(
|
|
query="test query",
|
|
collection="test_collection",
|
|
entity_limit=25,
|
|
triple_limit=15,
|
|
explain_callback=collect_provenance,
|
|
)
|
|
|
|
response_text, usage = response
|
|
assert response_text == expected_response
|
|
|
|
# 5 events: question, grounding, exploration, focus, synthesis
|
|
assert len(provenance_events) == 5
|
|
|
|
for triples, prov_id in provenance_events:
|
|
assert isinstance(triples, list)
|
|
assert len(triples) > 0
|
|
assert prov_id.startswith("urn:trustgraph:")
|
|
|
|
assert "question" in provenance_events[0][1]
|
|
assert "grounding" in provenance_events[1][1]
|
|
assert "exploration" in provenance_events[2][1]
|
|
assert "focus" in provenance_events[3][1]
|
|
assert "synthesis" in provenance_events[4][1]
|
|
|
|
finally:
|
|
Query.hop_and_filter = original_hop_and_filter
|