mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 17:39:39 +02:00
feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG
Replace the three-prompt LLM scoring pipeline (kg-edge-scoring, kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker service backed by FlashRank. The new hop_and_filter() method performs iterative graph traversal with semantic scoring at each hop, replacing the previous follow_edges/get_subgraph approach. - Add reranker service (trustgraph-base client/service, FlashRank processor) - Add gateway dispatch for reranker via API and WebSocket - Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring - Remove kg_prompt() and edge_score_limit from prompt client - Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates - Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata - Add tg-invoke-reranker CLI tool - Add tech spec and UX developer guidance - Update all unit and integration tests
This commit is contained in:
parent
1aa9549912
commit
1346cbebb4
43 changed files with 1613 additions and 792 deletions
|
|
@ -95,10 +95,6 @@ class TestGraphRagIntegration:
|
|||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
|
|
@ -113,14 +109,22 @@ class TestGraphRagIntegration:
|
|||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_reranker_client(self):
|
||||
"""Mock reranker client for cross-encoder edge filtering"""
|
||||
client = AsyncMock()
|
||||
client.rerank.return_value = []
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||
mock_triples_client, mock_prompt_client):
|
||||
mock_triples_client, mock_reranker_client, mock_prompt_client):
|
||||
"""Create GraphRag instance with mocked dependencies"""
|
||||
return GraphRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
verbose=True
|
||||
)
|
||||
|
|
@ -167,8 +171,8 @@ class TestGraphRagIntegration:
|
|||
# 3. Should query triples to build knowledge subgraph
|
||||
assert mock_triples_client.query_stream.call_count > 0
|
||||
|
||||
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 4
|
||||
# 4. Should call prompt twice (extract-concepts + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 2
|
||||
|
||||
# Verify final response
|
||||
response, usage = response
|
||||
|
|
|
|||
|
|
@ -63,11 +63,6 @@ class TestGraphRagStreaming:
|
|||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_id == "kg-edge-scoring":
|
||||
# Edge scoring returns JSONL with IDs and scores
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
|
||||
elif prompt_id == "kg-edge-reasoning":
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
|
|
@ -88,14 +83,23 @@ class TestGraphRagStreaming:
|
|||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_reranker_client(self):
|
||||
"""Mock reranker client for cross-encoder edge filtering"""
|
||||
client = AsyncMock()
|
||||
client.rerank.return_value = []
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||
mock_triples_client, mock_streaming_prompt_client):
|
||||
mock_triples_client, mock_reranker_client,
|
||||
mock_streaming_prompt_client):
|
||||
"""Create GraphRag instance with streaming support"""
|
||||
return GraphRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
prompt_client=mock_streaming_prompt_client,
|
||||
verbose=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class TestGraphRagStreamingProtocol:
|
|||
client = AsyncMock()
|
||||
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
|
|
@ -63,14 +63,23 @@ class TestGraphRagStreamingProtocol:
|
|||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_reranker_client(self):
|
||||
"""Mock reranker client for cross-encoder edge filtering"""
|
||||
client = AsyncMock()
|
||||
client.rerank.return_value = []
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||
mock_triples_client, mock_streaming_prompt_client):
|
||||
mock_triples_client, mock_reranker_client,
|
||||
mock_streaming_prompt_client):
|
||||
"""Create GraphRag instance with mocked dependencies"""
|
||||
return GraphRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
prompt_client=mock_streaming_prompt_client,
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -327,7 +336,7 @@ class TestStreamingProtocolEdgeCases:
|
|||
client = AsyncMock()
|
||||
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
|
|
@ -342,10 +351,14 @@ class TestStreamingProtocolEdgeCases:
|
|||
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
mock_reranker = AsyncMock()
|
||||
mock_reranker.rerank.return_value = []
|
||||
|
||||
rag = GraphRag(
|
||||
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
|
||||
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
reranker_client=mock_reranker,
|
||||
prompt_client=client,
|
||||
verbose=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -195,38 +195,6 @@ class TestPromptClientStreamingCallback:
|
|||
assert callback.call_args_list[0] == call("test", False)
|
||||
assert callback.call_args_list[1] == call("", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
|
||||
"""Test that kg_prompt correctly passes streaming parameters"""
|
||||
# Arrange
|
||||
async def mock_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
responses = [
|
||||
PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||
]
|
||||
for resp in responses:
|
||||
should_stop = await recipient(resp)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
prompt_client.request = mock_request
|
||||
|
||||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
await prompt_client.kg_prompt(
|
||||
query="What is machine learning?",
|
||||
kg=[("subject", "predicate", "object")],
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert callback.call_count == 2
|
||||
assert callback.call_args_list[0] == call("Answer", False)
|
||||
assert callback.call_args_list[1] == call("", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
|
||||
"""Test that document_prompt correctly passes streaming parameters"""
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ class TestGraphRagDagStructure:
|
|||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
reranker_client = AsyncMock()
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
graph_embeddings_client.query.return_value = [
|
||||
|
|
@ -121,27 +122,22 @@ class TestGraphRagDagStructure:
|
|||
]
|
||||
triples_client.query.return_value = []
|
||||
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.95
|
||||
reranker_client.rerank.return_value = [result]
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="concept")
|
||||
elif template_id == "kg-edge-scoring":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "score": 10} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text="Answer.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
return (prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, reranker_client)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self, mock_clients):
|
||||
|
|
@ -152,7 +148,7 @@ class TestGraphRagDagStructure:
|
|||
events.append({"explain_id": explain_id, "triples": triples})
|
||||
|
||||
await rag.query(
|
||||
query="test", explain_callback=explain_cb, edge_score_limit=0,
|
||||
query="test", explain_callback=explain_cb,
|
||||
)
|
||||
|
||||
dag = _collect_events(events)
|
||||
|
|
|
|||
|
|
@ -15,54 +15,52 @@ class TestGraphRag:
|
|||
|
||||
def test_graph_rag_initialization_with_defaults(self):
|
||||
"""Test GraphRag initialization with default verbose setting"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
mock_reranker_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is False # Default value
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag with verbose=True
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=True
|
||||
reranker_client=mock_reranker_client,
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.reranker_client == mock_reranker_client
|
||||
assert graph_rag.verbose is False
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
mock_reranker_client = MagicMock()
|
||||
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.reranker_client == mock_reranker_client
|
||||
assert graph_rag.verbose is True
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
|
|
@ -365,244 +363,162 @@ class TestQuery:
|
|||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_never_passes_workspace(self):
|
||||
"""Verify follow_edges never passes workspace to query_stream."""
|
||||
async def test_hop_and_filter_never_passes_workspace(self):
|
||||
"""Verify hop_and_filter never passes workspace to query_stream."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
|
||||
mock_triple.s = "e1"
|
||||
mock_triple.p = "p1"
|
||||
mock_triple.o = "o1"
|
||||
mock_triples_client.query_stream.return_value = [mock_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.9
|
||||
mock_reranker_client.rerank.return_value = [result]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
triple_limit=10,
|
||||
)
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("e1", subgraph, path_length=1)
|
||||
await query.hop_and_filter(["e1"], ["concept"])
|
||||
|
||||
for c in mock_triples_client.query_stream.call_args_list:
|
||||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
"""Test Query.follow_edges method basic triple discovery"""
|
||||
async def test_hop_and_filter_basic_functionality(self):
|
||||
"""Test hop_and_filter retrieves edges and scores them with reranker."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
mock_triple1 = MagicMock()
|
||||
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.s = "entity1"
|
||||
mock_triple.p = "predicate1"
|
||||
mock_triple.o = "object1"
|
||||
mock_triples_client.query_stream.return_value = [mock_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
mock_triple2 = MagicMock()
|
||||
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
|
||||
|
||||
mock_triple3 = MagicMock()
|
||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||
|
||||
mock_triples_client.query_stream.side_effect = [
|
||||
[mock_triple1], # s=ent
|
||||
[mock_triple2], # p=ent
|
||||
[mock_triple3], # o=ent
|
||||
]
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.95
|
||||
mock_reranker_client.rerank.return_value = [result]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
triple_limit=10,
|
||||
edge_limit=25,
|
||||
)
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
assert mock_triples_client.query_stream.call_count == 3
|
||||
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter(
|
||||
["entity1"], ["test concept"],
|
||||
)
|
||||
|
||||
expected_subgraph = {
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "entity1", "object2"),
|
||||
("subject3", "predicate3", "entity1")
|
||||
}
|
||||
assert subgraph == expected_subgraph
|
||||
assert len(selected) == 1
|
||||
assert len(uri_map) == 1
|
||||
assert len(edge_meta) == 1
|
||||
|
||||
mock_reranker_client.rerank.assert_called_once()
|
||||
call_kwargs = mock_reranker_client.rerank.call_args
|
||||
assert call_kwargs.kwargs["limit"] == 25
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_path_length_zero(self):
|
||||
"""Test Query.follow_edges method with path_length=0"""
|
||||
async def test_hop_and_filter_with_empty_frontier(self):
|
||||
"""Test hop_and_filter with no seed entities returns empty."""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"])
|
||||
|
||||
assert selected == []
|
||||
assert uri_map == {}
|
||||
assert edge_meta == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hop_and_filter_filters_label_triples(self):
|
||||
"""Test hop_and_filter skips rdfs:label edges."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
label_triple = MagicMock()
|
||||
label_triple.s = "entity1"
|
||||
label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
label_triple.o = "Entity One"
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
assert subgraph == set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_max_subgraph_size_limit(self):
|
||||
"""Test Query.follow_edges method respects max_subgraph_size"""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_triples_client.query_stream.return_value = [label_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=2
|
||||
triple_limit=10,
|
||||
)
|
||||
|
||||
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
|
||||
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
assert len(subgraph) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter(
|
||||
["entity1"], ["concept"],
|
||||
)
|
||||
|
||||
# Mock get_entities to return (entities, concepts) tuple
|
||||
query.get_entities = AsyncMock(
|
||||
return_value=(["entity1", "entity2"], ["concept1"])
|
||||
)
|
||||
|
||||
query.follow_edges_batch = AsyncMock(return_value=(
|
||||
{
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
},
|
||||
{}
|
||||
))
|
||||
|
||||
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
|
||||
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
||||
assert isinstance(subgraph, list)
|
||||
assert len(subgraph) == 2
|
||||
assert ("entity1", "predicate1", "object1") in subgraph
|
||||
assert ("entity2", "predicate2", "object2") in subgraph
|
||||
assert entities == ["entity1", "entity2"]
|
||||
assert concepts == ["concept1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=100
|
||||
)
|
||||
|
||||
test_subgraph = [
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
|
||||
("entity3", "predicate3", "object3")
|
||||
]
|
||||
test_entities = ["entity1", "entity3"]
|
||||
test_concepts = ["concept1"]
|
||||
query.get_subgraph = AsyncMock(
|
||||
return_value=(test_subgraph, {}, test_entities, test_concepts)
|
||||
)
|
||||
|
||||
async def mock_maybe_label(entity):
|
||||
label_map = {
|
||||
"entity1": "Human Entity One",
|
||||
"predicate1": "Human Predicate One",
|
||||
"object1": "Human Object One",
|
||||
"entity3": "Human Entity Three",
|
||||
"predicate3": "Human Predicate Three",
|
||||
"object3": "Human Object Three"
|
||||
}
|
||||
return label_map.get(entity, entity)
|
||||
|
||||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
|
||||
|
||||
query.get_subgraph.assert_called_once_with("test query")
|
||||
|
||||
# Label triples filtered out
|
||||
assert len(labeled_edges) == 2
|
||||
|
||||
# maybe_label called for non-label triples
|
||||
assert query.maybe_label.call_count == 6
|
||||
|
||||
expected_edges = [
|
||||
("Human Entity One", "Human Predicate One", "Human Object One"),
|
||||
("Human Entity Three", "Human Predicate Three", "Human Object Three")
|
||||
]
|
||||
assert labeled_edges == expected_edges
|
||||
|
||||
assert len(uri_map) == 2
|
||||
assert entities == test_entities
|
||||
assert concepts == test_concepts
|
||||
assert selected == []
|
||||
mock_reranker_client.rerank.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_query_method(self):
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
|
||||
import json
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
|
||||
expected_response = "This is the RAG response"
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
test_edge_id = edge_id("Subject", "Predicate", "Object")
|
||||
test_selected_edges = [("Subject", "Predicate", "Object")]
|
||||
test_eid = edge_id("Subject", "Predicate", "Object")
|
||||
test_uri_map = {
|
||||
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
}
|
||||
test_edge_metadata = {
|
||||
test_eid: {"concept": "test concept", "score": 0.95}
|
||||
}
|
||||
test_entities = ["http://example.org/subject"]
|
||||
test_concepts = ["test concept"]
|
||||
|
||||
# Mock prompt responses for the multi-step process
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
mock_graph_embeddings_client.query.return_value = []
|
||||
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
|
||||
return PromptResult(response_type="text", text="test concept")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text=expected_response)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
|
@ -614,16 +530,16 @@ class TestQuery:
|
|||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=False
|
||||
reranker_client=mock_reranker_client,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Patch Query.get_labelgraph to return test data
|
||||
original_get_labelgraph = Query.get_labelgraph
|
||||
original_hop_and_filter = Query.hop_and_filter
|
||||
|
||||
async def mock_get_labelgraph(self, query_text):
|
||||
return test_labelgraph, test_uri_map, test_entities, test_concepts
|
||||
async def mock_hop_and_filter(self, seed_entities, concepts):
|
||||
return test_selected_edges, test_uri_map, test_edge_metadata
|
||||
|
||||
Query.get_labelgraph = mock_get_labelgraph
|
||||
Query.hop_and_filter = mock_hop_and_filter
|
||||
|
||||
provenance_events = []
|
||||
|
||||
|
|
@ -636,7 +552,7 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
entity_limit=25,
|
||||
triple_limit=15,
|
||||
explain_callback=collect_provenance
|
||||
explain_callback=collect_provenance,
|
||||
)
|
||||
|
||||
response_text, usage = response
|
||||
|
|
@ -650,7 +566,6 @@ class TestQuery:
|
|||
assert len(triples) > 0
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
# Verify order
|
||||
assert "question" in provenance_events[0][1]
|
||||
assert "grounding" in provenance_events[1][1]
|
||||
assert "exploration" in provenance_events[2][1]
|
||||
|
|
@ -658,4 +573,4 @@ class TestQuery:
|
|||
assert "synthesis" in provenance_events[4][1]
|
||||
|
||||
finally:
|
||||
Query.get_labelgraph = original_get_labelgraph
|
||||
Query.hop_and_filter = original_hop_and_filter
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from trustgraph.provenance.namespaces import (
|
|||
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_SCORE, TG_EDGE_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -91,17 +91,17 @@ def build_mock_clients():
|
|||
1. prompt_client.prompt("extract-concepts", ...) -> concepts
|
||||
2. embeddings_client.embed(concepts) -> vectors
|
||||
3. graph_embeddings_client.query(vector, ...) -> entity matches
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (hop_and_filter)
|
||||
5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
|
||||
6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
|
||||
7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
|
||||
8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
9. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
6. reranker_client.rerank(queries, documents, limit) -> scored edges
|
||||
7. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
8. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
reranker_client = AsyncMock()
|
||||
|
||||
# 1. Concept extraction
|
||||
prompt_responses = {}
|
||||
|
|
@ -116,7 +116,7 @@ def build_mock_clients():
|
|||
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
|
||||
]
|
||||
|
||||
# 4. Triple queries (follow_edges_batch) - return our edges
|
||||
# 4. Triple queries (hop_and_filter) - return our edges
|
||||
kg_triples = [
|
||||
make_schema_triple(*EDGE_1),
|
||||
make_schema_triple(*EDGE_2),
|
||||
|
|
@ -130,9 +130,18 @@ def build_mock_clients():
|
|||
return [] # No labels found, will fall back to URI
|
||||
triples_client.query.side_effect = mock_label_query
|
||||
|
||||
# 6+7. Edge scoring and reasoning: dynamically score/reason about
|
||||
# whatever edges the query method sends us, since edge IDs are computed
|
||||
# from str(Term) representations which include the full dataclass repr.
|
||||
# 6. Reranker: select all documents with high scores
|
||||
async def mock_rerank(queries, documents, limit):
|
||||
results = []
|
||||
for i, doc in enumerate(documents):
|
||||
result = MagicMock()
|
||||
result.document_id = doc["id"]
|
||||
result.query_id = queries[0]["id"] if queries else "0"
|
||||
result.score = 0.9 - (i * 0.1)
|
||||
results.append(result)
|
||||
return results[:limit]
|
||||
reranker_client.rerank.side_effect = mock_rerank
|
||||
|
||||
synthesis_answer = "Quantum computing applies physics principles to computation."
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
|
|
@ -141,26 +150,6 @@ def build_mock_clients():
|
|||
response_type="text",
|
||||
text=prompt_responses["extract-concepts"],
|
||||
)
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
|
|
@ -170,7 +159,8 @@ def build_mock_clients():
|
|||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
return (prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, reranker_client)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -197,7 +187,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0, # skip semantic pre-filter for simplicity
|
||||
|
||||
)
|
||||
|
||||
assert len(events) == 5, (
|
||||
|
|
@ -222,7 +212,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
expected_types = [
|
||||
|
|
@ -260,7 +250,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
|
|
@ -297,7 +287,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
|
|
@ -320,7 +310,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
gnd_uri = events[1]["explain_id"]
|
||||
|
|
@ -344,7 +334,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
|
|
@ -355,10 +345,10 @@ class TestGraphRagQueryProvenance:
|
|||
assert int(t.o.value) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_focus_has_selected_edges_with_reasoning(self):
|
||||
async def test_focus_has_selected_edges_with_concept_and_score(self):
|
||||
"""
|
||||
The focus event should carry selected edges as quoted triples
|
||||
with reasoning text.
|
||||
with cross-encoder concept and score metadata.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
|
@ -371,7 +361,6 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
foc_uri = events[3]["explain_id"]
|
||||
|
|
@ -387,11 +376,19 @@ class TestGraphRagQueryProvenance:
|
|||
for t in edge_t:
|
||||
assert t.o.triple is not None, "tg:edge object must be a quoted triple"
|
||||
|
||||
# Should have reasoning
|
||||
reasoning = find_triples(foc_triples, TG_REASONING)
|
||||
assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
|
||||
reasoning_texts = {t.o.value for t in reasoning}
|
||||
assert any(r for r in reasoning_texts), "Reasoning should not be empty"
|
||||
# Edge selections should be typed as EdgeSelection
|
||||
edge_sel_uris = [t.o.iri for t in selected]
|
||||
for uri in edge_sel_uris:
|
||||
assert has_type(foc_triples, uri, TG_EDGE_SELECTION)
|
||||
|
||||
# Should have concept and score
|
||||
concepts = find_triples(foc_triples, TG_CONCEPT)
|
||||
assert len(concepts) > 0, "Focus should have tg:concept for selected edges"
|
||||
|
||||
scores = find_triples(foc_triples, TG_SCORE)
|
||||
assert len(scores) > 0, "Focus should have tg:score for selected edges"
|
||||
for t in scores:
|
||||
float(t.o.value) # Should be parseable as float
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_is_answer_type(self):
|
||||
|
|
@ -407,7 +404,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
syn_uri = events[4]["explain_id"]
|
||||
|
|
@ -429,7 +426,7 @@ class TestGraphRagQueryProvenance:
|
|||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
|
@ -449,7 +446,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
parent_uri=parent,
|
||||
)
|
||||
|
||||
|
|
@ -465,7 +462,7 @@ class TestGraphRagQueryProvenance:
|
|||
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
|
@ -484,7 +481,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
for event in events:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue