mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 23:11:00 +02:00
Merge branch 'release/v2.6'
This commit is contained in:
commit
508d0bb5c1
9 changed files with 892 additions and 31 deletions
|
|
@ -224,12 +224,27 @@ The current embedding pre-filter represents each edge as
|
|||
- **Drop commas.** Commas add tokenisation noise without semantic
|
||||
value.
|
||||
|
||||
- **Drop the subject.** The subject identifies which entity the
|
||||
edge belongs to, but it does not contribute to whether the
|
||||
edge's content is relevant to the query. The predicate and
|
||||
object carry the semantic meaning — what relationship exists
|
||||
and what it connects to. Representing edges as `"{p} {o}"`
|
||||
produces cleaner cross-encoder matches.
|
||||
- **Direction-aware text.** The reranker text should highlight
|
||||
the *new* information relative to the traversal direction.
|
||||
The frontier entity is already known context — repeating it
|
||||
adds noise and, when traversing from an object node, causes
|
||||
many edges to produce identical reranker text (e.g. 18
products sharing the same `hasSubcategory Processors`
predicate–object pair all collapse to the same string when
the subject is dropped).
|
||||
|
||||
The text is constructed based on which position the frontier
|
||||
entity occupied in the triple:
|
||||
|
||||
- **From subject** (s=entity): `"{predicate} {object}"` —
|
||||
the subject is known, predicate and object are new.
|
||||
- **From object** (o=entity): `"{subject} {predicate}"` —
|
||||
the object is known, subject and predicate are new.
|
||||
- **From predicate** (p=entity): `"{subject} {object}"` —
|
||||
the predicate is known, subject and object are new.
|
||||
|
||||
This eliminates the duplicate-text problem that arises when
|
||||
traversing inward from a shared object node, and gives the
|
||||
cross-encoder a more informative signal at every hop.
|
||||
|
||||
#### Remove the embedding pre-filter (step 3)
|
||||
|
||||
|
|
@ -389,7 +404,10 @@ no LLM call. These fields are dropped from the Focus entity.
|
|||
|
||||
a. Retrieve all edges one hop from the current frontier nodes.
|
||||
|
||||
b. Represent each edge as `"{predicate} {object}"`.
|
||||
b. Represent each edge using direction-aware text: from a
|
||||
subject node use `"{predicate} {object}"`, from an object
|
||||
node use `"{subject} {predicate}"`, from a predicate node
|
||||
use `"{subject} {object}"`.
|
||||
|
||||
c. Score edges against the extracted concepts using the
|
||||
cross-encoder service.
|
||||
|
|
|
|||
|
|
@ -129,6 +129,9 @@ class TestBatchTripleQueries:
|
|||
|
||||
# 3 queries, alternating results
|
||||
assert len(result) == 3
|
||||
# Each result is a (triple, direction) tuple
|
||||
for triple, direction in result:
|
||||
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_in_one_query_does_not_block_others(self):
|
||||
|
|
@ -153,6 +156,8 @@ class TestBatchTripleQueries:
|
|||
|
||||
# 3 queries: 2 succeed, 1 fails → 2 triples
|
||||
assert len(result) == 2
|
||||
for triple, direction in result:
|
||||
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_results_filtered(self):
|
||||
|
|
@ -176,6 +181,8 @@ class TestBatchTripleQueries:
|
|||
|
||||
# 3 queries: 1 returns None, 2 return triples
|
||||
assert len(result) == 2
|
||||
for triple, direction in result:
|
||||
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_entities_no_queries(self):
|
||||
|
|
@ -220,6 +227,80 @@ class TestBatchTripleQueries:
|
|||
assert calls[2].kwargs["p"] is None
|
||||
assert calls[2].kwargs["o"] == "ent-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directions_assigned_correctly(self):
    """Each query position should produce the correct direction tag.

    The stubbed triples client answers every query with one triple, so a
    single entity is expected to yield three (triple, direction) results.
    """
    triple = _make_triple("s", "p", "o")

    # Every query returns a fresh one-element list. The previous revision
    # also kept a call counter here that was never read; it has been removed.
    async def one_triple_each(**kwargs):
        return [triple]

    client = AsyncMock()
    client.query_stream = one_triple_each
    query = _make_query(triples_client=client)

    result = await query.execute_batch_triple_queries(
        ["e1"], limit_per_entity=10
    )

    assert len(result) == 3
    # Order matches query order: s-position, p-position, o-position
    assert result[0][1] == Query.FROM_S
    assert result[1][1] == Query.FROM_P
    assert result[2][1] == Query.FROM_O
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directions_correct_for_multiple_entities(self):
    """Direction tags repeat S, P, O for each entity in the batch."""
    edge = _make_triple("s", "p", "o")
    client = AsyncMock()
    client.query_stream = AsyncMock(return_value=[edge])
    query = _make_query(triples_client=client)

    result = await query.execute_batch_triple_queries(
        ["e1", "e2"], limit_per_entity=10
    )

    # Two entities x three query positions, one triple per query.
    assert len(result) == 6

    per_entity = [Query.FROM_S, Query.FROM_P, Query.FROM_O]
    observed = [direction for _, direction in result]
    assert observed == per_entity + per_entity
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direction_preserved_with_multiple_triples(self):
    """Every triple returned by one query carries that query's direction."""
    t1 = _make_triple("a", "p1", "b")
    t2 = _make_triple("a", "p2", "c")

    invocations = 0

    async def multi_results(**kwargs):
        # Only the first query issued yields triples; later ones are empty.
        nonlocal invocations
        invocations += 1
        return [t1, t2] if invocations == 1 else []

    client = AsyncMock()
    client.query_stream = multi_results
    query = _make_query(triples_client=client)

    result = await query.execute_batch_triple_queries(
        ["e1"], limit_per_entity=10
    )

    # First query (FROM_S) returns 2 triples, both should be FROM_S
    assert len(result) == 2
    for tagged, expected_triple in zip(result, (t1, t2)):
        assert tagged == (expected_triple, Query.FROM_S)
|
||||
|
||||
|
||||
class TestLRUCacheWithTTL:
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,92 @@
|
|||
from trustgraph.retrieval.document_rag.rerank import (
|
||||
RerankCandidate, normalize_candidate_scores, mmr_select,
|
||||
_pair_diversity_penalty
|
||||
)
|
||||
|
||||
def candidate(index, chunk_id, text, score):
    """Build a RerankCandidate whose raw reranker score is *score*."""
    fields = {
        "index": index,
        "chunk_id": chunk_id,
        "text": text,
        "reranker_score": score,
    }
    return RerankCandidate(**fields)
|
||||
|
||||
|
||||
def test_normalize_candidate_scores_min_max_scales_raw_scores():
    """Raw scores -2, 0 and 4 are min-max scaled to 0, 1/3 and 1."""
    # Local import: only this test needs a tolerant float comparison.
    import math

    candidates = [
        candidate(0, "a", "alpha", -2.0),
        candidate(1, "b", "beta", 0.0),
        candidate(2, "c", "gamma", 4.0),
    ]

    normalized = normalize_candidate_scores(candidates)

    # The endpoints come out exactly: 0/6 and 6/6.
    assert normalized[0].normalized_score == 0.0
    assert normalized[2].normalized_score == 1.0
    # The midpoint is the result of a float division, so compare with a
    # tolerance instead of relying on bit-identical rounding.
    assert math.isclose(normalized[1].normalized_score, 1.0 / 3.0)
|
||||
|
||||
|
||||
def test_normalize_candidate_scores_handles_equal_scores():
    """When all raw scores are equal, every candidate normalizes to 0.5."""
    labels = [("a", "alpha"), ("b", "beta"), ("c", "gamma")]
    tied = [
        candidate(position, chunk_id, text, 3.0)
        for position, (chunk_id, text) in enumerate(labels)
    ]

    normalized = normalize_candidate_scores(tied)

    observed = [item.normalized_score for item in normalized]
    assert observed == [0.5] * 3
|
||||
|
||||
|
||||
def test_mmr_select_limits_results():
    """mmr_select returns no more than *limit* candidates."""
    rows = [
        ("a", "alpha policy", 0.9),
        ("b", "beta refund", 0.8),
        ("c", "gamma shipping", 0.7),
    ]
    pool = [
        candidate(position, chunk_id, text, score)
        for position, (chunk_id, text, score) in enumerate(rows)
    ]

    picked = mmr_select(pool, limit=2)

    assert len(picked) == 2
|
||||
|
||||
|
||||
def test_mmr_select_prefers_highest_reranker_score_first():
    """With limit=1 the candidate with the highest raw score is chosen."""
    rows = [
        ("a", "weakly relevant text", 0.1),
        ("b", "strongly relevant answer", 10.0),
        ("c", "medium relevant text", 5.0),
    ]
    pool = [
        candidate(position, chunk_id, text, score)
        for position, (chunk_id, text, score) in enumerate(rows)
    ]

    picked = mmr_select(pool, limit=1)

    winner = picked[0]
    assert winner.chunk_id == "b"
|
||||
|
||||
|
||||
def test_mmr_select_penalizes_near_duplicate_chunks():
    """A near-duplicate of the top chunk loses to a lower-scored distinct one."""
    pool = [
        candidate(0, "a", "apple banana fruit return policy", 1.00),
        candidate(1, "b", "apple banana fruit return policy duplicate", 0.95),
        candidate(2, "c", "engine motor vehicle warranty", 0.90),
    ]

    # A low lambda_mult weights the diversity term heavily.
    mmr_options = {
        "limit": 2,
        "lambda_mult": 0.2,
        "token_overlap_weight": 1.0,
    }
    picked = mmr_select(pool, **mmr_options)

    chosen_ids = [item.chunk_id for item in picked]
    assert chosen_ids == ["a", "c"]
|
||||
|
||||
|
||||
def test_pair_diversity_penalty_is_clamped():
    """A weight of 10 on identical texts still yields a penalty of 1.0."""
    identical_text = "same same same"
    first = candidate(0, "a", identical_text, 1.0)
    second = candidate(1, "b", identical_text, 0.9)

    result = _pair_diversity_penalty(first, second, token_overlap_weight=10.0)

    assert result == 1.0
|
||||
|
|
@ -476,3 +476,75 @@ class TestRerankActive:
|
|||
await rag.query(query="What is the return policy?")
|
||||
|
||||
assert reranker.calls == []
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Diversity selection: optional MMR after cross-encoder scoring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diversity_mode_scores_full_candidate_pool_before_selecting(self):
    """
    In "mmr" diversity mode the reranker is asked to score every fetched
    candidate; only afterwards is the set narrowed down to doc_limit.
    """
    clients = build_mock_clients()
    scored = [
        RerankerResult(document_id="0", query_id="0", score=1.00),
        RerankerResult(document_id="1", query_id="0", score=0.95),
        RerankerResult(document_id="2", query_id="0", score=0.90),
    ]
    reranker = StubReranker(scored)
    rag = DocumentRag(
        *clients,
        reranker_client=reranker,
        rerank_diversity_mode="mmr",
    )

    await rag.query(query="What is the return policy?", doc_limit=2)

    # The rerank request covers the whole candidate pool, not just doc_limit.
    first_rerank_call = reranker.calls[0]
    assert first_rerank_call["limit"] == len(ORDERED_CONTENT)

    # The prompt still receives exactly doc_limit documents.
    prompt_kwargs = rag.prompt_client.document_prompt.call_args.kwargs
    assert len(prompt_kwargs["documents"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diversity_mode_selects_less_redundant_context_set(self):
    """
    With MMR enabled, a slightly lower-scored chunk that overlaps less with
    the already selected text is preferred over a near-duplicate.
    """
    clients = build_mock_clients()
    _prompt, _embeddings, _doc_embeddings, fetch_chunk = clients

    duplicate_a = "apple banana fruit return policy"
    duplicate_b = "apple banana fruit return policy duplicate"
    diverse_c = "engine motor vehicle warranty"

    chunk_text = {
        CHUNK_A: duplicate_a,
        CHUNK_B: duplicate_b,
        CHUNK_C: diverse_c,
    }

    async def mock_fetch(chunk_id):
        return chunk_text[chunk_id]

    fetch_chunk.side_effect = mock_fetch

    scored = [
        RerankerResult(document_id="0", query_id="0", score=1.00),
        RerankerResult(document_id="1", query_id="0", score=0.95),
        RerankerResult(document_id="2", query_id="0", score=0.90),
    ]
    rag = DocumentRag(
        *clients,
        reranker_client=StubReranker(scored),
        rerank_diversity_mode="mmr",
        rerank_diversity_lambda=0.2,
    )

    await rag.query(query="What is the return policy?", doc_limit=2)

    prompt_kwargs = rag.prompt_client.document_prompt.call_args.kwargs
    assert prompt_kwargs["documents"] == [duplicate_a, diverse_c]
|
||||
353
tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py
Normal file
353
tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
"""
|
||||
Tests for direction-aware reranker text in GraphRAG hop-and-filter.
|
||||
|
||||
The reranker document text varies by traversal direction:
|
||||
- From S (subject is the frontier entity): text = "{p} {o}"
|
||||
- From O (object is the frontier entity): text = "{s} {p}"
|
||||
- From P (predicate is the frontier entity): text = "{s} {o}"
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_rag(reranker_results=None):
    """Return a MagicMock GraphRag with stubbed triples and reranker clients.

    Label lookups (``triples_client.query``) return an empty list. The
    reranker returns *reranker_results*, or an empty list when it is None.
    """
    rag = MagicMock()
    rag.label_cache = LRUCacheWithTTL()
    rag.triples_client = AsyncMock()
    rag.reranker_client = AsyncMock()

    # No labels are found, so callers fall back to the raw URI.
    rag.triples_client.query.return_value = []

    rag.reranker_client.rerank.return_value = (
        [] if reranker_results is None else reranker_results
    )
    return rag
|
||||
|
||||
|
||||
def _make_query(rag, max_path_length=1, edge_limit=25):
    """Build a Query bound to *rag* using fixed test limits."""
    settings = {
        "rag": rag,
        "collection": "test",
        "verbose": False,
        "entity_limit": 50,
        "triple_limit": 30,
        "max_subgraph_size": 1000,
        "max_path_length": max_path_length,
        "edge_limit": edge_limit,
    }
    return Query(**settings)
|
||||
|
||||
|
||||
def _make_schema_triple(s, p, o):
|
||||
"""Create a mock triple matching the schema interface."""
|
||||
t = MagicMock()
|
||||
t.s = s
|
||||
t.p = p
|
||||
t.o = o
|
||||
return t
|
||||
|
||||
|
||||
def _reranker_result(document_id, query_id="0", score=0.9):
|
||||
r = MagicMock()
|
||||
r.document_id = str(document_id)
|
||||
r.query_id = str(query_id)
|
||||
r.score = score
|
||||
return r
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: execute_batch_triple_queries direction tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDirectionTracking:
    """execute_batch_triple_queries tags each triple with its query position."""

    @pytest.mark.asyncio
    async def test_from_s_direction(self):
        """Triples from s=entity queries are tagged FROM_S."""
        edge = _make_schema_triple("ent1", "pred", "obj")
        rag = _make_rag()

        async def query_stream(s=None, p=None, o=None, **kwargs):
            # Only the subject-position query yields the edge.
            return [edge] if s is not None else []

        rag.triples_client.query_stream.side_effect = query_stream
        q = _make_query(rag)

        result = await q.execute_batch_triple_queries(["ent1"], 10)

        tagged = [t for t, d in result if d == Query.FROM_S]
        assert len(tagged) == 1
        assert tagged[0] is edge

    @pytest.mark.asyncio
    async def test_from_o_direction(self):
        """Triples from o=entity queries are tagged FROM_O."""
        edge = _make_schema_triple("subj", "pred", "ent1")
        rag = _make_rag()

        async def query_stream(s=None, p=None, o=None, **kwargs):
            # Only the object-position query yields the edge.
            return [edge] if o is not None else []

        rag.triples_client.query_stream.side_effect = query_stream
        q = _make_query(rag)

        result = await q.execute_batch_triple_queries(["ent1"], 10)

        tagged = [t for t, d in result if d == Query.FROM_O]
        assert len(tagged) == 1
        assert tagged[0] is edge

    @pytest.mark.asyncio
    async def test_from_p_direction(self):
        """Triples from p=entity queries are tagged FROM_P."""
        edge = _make_schema_triple("subj", "ent1", "obj")
        rag = _make_rag()

        async def query_stream(s=None, p=None, o=None, **kwargs):
            # Only the predicate-position query yields the edge.
            return [edge] if p is not None else []

        rag.triples_client.query_stream.side_effect = query_stream
        q = _make_query(rag)

        result = await q.execute_batch_triple_queries(["ent1"], 10)

        tagged = [t for t, d in result if d == Query.FROM_P]
        assert len(tagged) == 1
        assert tagged[0] is edge
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: hop_and_filter reranker document text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDirectionAwareRerankerText:
    """hop_and_filter sends direction-dependent text to the reranker."""

    @pytest.mark.asyncio
    async def test_from_s_uses_predicate_object(self):
        """From-S traversal: reranker text should be '{p} {o}'."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        async def query_stream(s=None, p=None, o=None, **kwargs):
            return [edge] if s is not None else []

        rag.triples_client.query_stream.side_effect = query_stream

        q = _make_query(rag, max_path_length=1, edge_limit=10)

        await q.hop_and_filter(
            seed_entities=["http://ex/entity-A"],
            concepts=["likes"],
        )

        documents = rag.reranker_client.rerank.call_args.kwargs["documents"]
        # Text should be "{p} {o}" — the URIs since no labels found
        assert len(documents) == 1
        assert documents[0]["text"] == "http://ex/likes http://ex/entity-B"

    @pytest.mark.asyncio
    async def test_from_o_uses_subject_predicate(self):
        """From-O traversal: reranker text should be '{s} {p}'."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        async def query_stream(s=None, p=None, o=None, **kwargs):
            return [edge] if o is not None else []

        rag.triples_client.query_stream.side_effect = query_stream

        q = _make_query(rag, max_path_length=1, edge_limit=10)

        await q.hop_and_filter(
            seed_entities=["http://ex/entity-B"],
            concepts=["likes"],
        )

        documents = rag.reranker_client.rerank.call_args.kwargs["documents"]
        assert len(documents) == 1
        assert documents[0]["text"] == "http://ex/entity-A http://ex/likes"

    @pytest.mark.asyncio
    async def test_from_p_uses_subject_object(self):
        """From-P traversal: reranker text should be '{s} {o}'."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        async def query_stream(s=None, p=None, o=None, **kwargs):
            return [edge] if p is not None else []

        rag.triples_client.query_stream.side_effect = query_stream

        q = _make_query(rag, max_path_length=1, edge_limit=10)

        await q.hop_and_filter(
            seed_entities=["http://ex/likes"],
            concepts=["entity"],
        )

        documents = rag.reranker_client.rerank.call_args.kwargs["documents"]
        assert len(documents) == 1
        assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B"

    @pytest.mark.asyncio
    async def test_mixed_directions_produce_different_text(self):
        """Edges from different directions use different text formats."""
        seed = "http://ex/seed"
        outgoing = _make_schema_triple(
            "http://ex/seed", "http://ex/rel", "http://ex/target",
        )
        incoming = _make_schema_triple(
            "http://ex/other", "http://ex/ref", "http://ex/seed",
        )

        rag = _make_rag(reranker_results=[
            _reranker_result(0), _reranker_result(1),
        ])

        async def query_stream(s=None, p=None, o=None, **kwargs):
            # Seed as subject -> outgoing edge; seed as object -> incoming.
            if s == seed:
                return [outgoing]
            if o == seed:
                return [incoming]
            return []

        rag.triples_client.query_stream.side_effect = query_stream

        q = _make_query(rag, max_path_length=1, edge_limit=10)

        await q.hop_and_filter(
            seed_entities=["http://ex/seed"],
            concepts=["test"],
        )

        documents = rag.reranker_client.rerank.call_args.kwargs["documents"]
        texts = {doc["text"] for doc in documents}

        # From S: "{p} {o}" = "http://ex/rel http://ex/target"
        assert "http://ex/rel http://ex/target" in texts
        # From O: "{s} {p}" = "http://ex/other http://ex/ref"
        assert "http://ex/other http://ex/ref" in texts

    @pytest.mark.asyncio
    async def test_labels_applied_to_direction_text(self):
        """Labels should be resolved and used in the direction-aware text."""
        edge = _make_schema_triple(
            "http://ex/entity-A",
            "http://ex/likes",
            "http://ex/entity-B",
        )
        rag = _make_rag(reranker_results=[_reranker_result(0)])

        LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
        names = {
            "http://ex/entity-A": "Alice",
            "http://ex/likes": "likes",
            "http://ex/entity-B": "Bob",
        }

        async def query_stream(s=None, p=None, o=None, **kwargs):
            # Answer only the subject-position hop query.
            return [edge] if (s is not None and p is None) else []

        async def label_query(s=None, p=None, o=None, limit=1, **kwargs):
            # Answer rdfs:label lookups for the three known URIs only.
            if p != LABEL or s not in names:
                return []
            return [MagicMock(o=names[s])]

        rag.triples_client.query_stream.side_effect = query_stream
        rag.triples_client.query.side_effect = label_query

        q = _make_query(rag, max_path_length=1, edge_limit=10)

        await q.hop_and_filter(
            seed_entities=["http://ex/entity-A"],
            concepts=["friendship"],
        )

        documents = rag.reranker_client.rerank.call_args.kwargs["documents"]
        assert len(documents) == 1
        # From S with labels: "{p_label} {o_label}"
        assert documents[0]["text"] == "likes Bob"

    @pytest.mark.asyncio
    async def test_no_duplicate_text_from_shared_object(self):
        """Multiple edges sharing an object should produce distinct texts."""
        shared_object = "http://ex/Processors"
        edge_a = _make_schema_triple(
            "http://ex/cpu-A", "http://ex/hasCategory", "http://ex/Processors",
        )
        edge_b = _make_schema_triple(
            "http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors",
        )

        rag = _make_rag(reranker_results=[
            _reranker_result(0), _reranker_result(1),
        ])

        async def query_stream(s=None, p=None, o=None, **kwargs):
            return [edge_a, edge_b] if o == shared_object else []

        rag.triples_client.query_stream.side_effect = query_stream

        q = _make_query(rag, max_path_length=1, edge_limit=10)

        await q.hop_and_filter(
            seed_entities=["http://ex/Processors"],
            concepts=["CPUs"],
        )

        documents = rag.reranker_client.rerank.call_args.kwargs["documents"]
        texts = [doc["text"] for doc in documents]

        assert len(texts) == 2
        # From O: "{s} {p}" — subjects differ, so texts differ
        first_text, second_text = texts[0], texts[1]
        assert first_text != second_text
        assert "http://ex/cpu-A" in first_text
        assert "http://ex/cpu-B" in second_text
|
||||
|
|
@ -20,6 +20,8 @@ from trustgraph.provenance import (
|
|||
GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
from .rerank import RerankCandidate, mmr_select
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -150,6 +152,8 @@ class DocumentRag:
|
|||
fetch_chunk,
|
||||
reranker_client=None,
|
||||
verbose=False,
|
||||
rerank_diversity_mode="none",
|
||||
rerank_diversity_lambda=0.7,
|
||||
):
|
||||
|
||||
self.verbose = verbose
|
||||
|
|
@ -162,6 +166,8 @@ class DocumentRag:
|
|||
# Optional cross-encoder reranker. When None, the retrieval path is
|
||||
# byte-identical to the pre-reranker behaviour.
|
||||
self.reranker_client = reranker_client
|
||||
self.rerank_diversity_mode = rerank_diversity_mode
|
||||
self.rerank_diversity_lambda = rerank_diversity_lambda
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("DocumentRag initialized")
|
||||
|
|
@ -277,30 +283,74 @@ class DocumentRag:
|
|||
# skipped entirely and behaviour is byte-identical to before.
|
||||
reranked = False
|
||||
if self.reranker_client is not None and docs:
|
||||
use_diversity = self.rerank_diversity_mode == "mmr"
|
||||
|
||||
# Without diversity selection, preserve the existing #1011
|
||||
# behavior: ask the reranker for exactly doc_limit results.
|
||||
#
|
||||
# With diversity selection enabled, ask the reranker to score the
|
||||
# full fetched candidate pool first, then let MMR choose the final
|
||||
# doc_limit context set.
|
||||
rerank_limit = len(docs) if use_diversity else doc_limit
|
||||
|
||||
results = await self.reranker_client.rerank(
|
||||
queries=[{"id": "0", "text": query}],
|
||||
documents=[
|
||||
{"id": str(i), "text": d} for i, d in enumerate(docs)
|
||||
],
|
||||
# Narrow the over-fetched candidate pool down to the final
|
||||
# doc_limit requested for synthesis.
|
||||
limit=doc_limit,
|
||||
limit=rerank_limit,
|
||||
)
|
||||
|
||||
# results are sorted desc by score and truncated to limit by the
|
||||
# reranker service, so order gives the surviving top-N directly.
|
||||
order = [int(r.document_id) for r in results]
|
||||
docs = [docs[i] for i in order]
|
||||
chunk_ids = [chunk_ids[i] for i in order]
|
||||
source_docs = docs
|
||||
source_chunk_ids = chunk_ids
|
||||
|
||||
if use_diversity:
|
||||
candidates = [
|
||||
RerankCandidate(
|
||||
index=int(r.document_id),
|
||||
chunk_id=source_chunk_ids[int(r.document_id)],
|
||||
text=source_docs[int(r.document_id)],
|
||||
reranker_score=r.score,
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
selected_candidates = mmr_select(
|
||||
candidates,
|
||||
limit=doc_limit,
|
||||
lambda_mult=self.rerank_diversity_lambda,
|
||||
)
|
||||
|
||||
docs = [candidate.text for candidate in selected_candidates]
|
||||
chunk_ids = [
|
||||
candidate.chunk_id for candidate in selected_candidates
|
||||
]
|
||||
|
||||
selected_chunks_with_scores = [
|
||||
{
|
||||
"chunk_id": candidate.chunk_id,
|
||||
"score": candidate.reranker_score,
|
||||
}
|
||||
for candidate in selected_candidates
|
||||
]
|
||||
|
||||
else:
|
||||
# results are sorted desc by score and truncated to limit by the
|
||||
# reranker service, so order gives the surviving top-N directly.
|
||||
order = [int(r.document_id) for r in results]
|
||||
docs = [source_docs[i] for i in order]
|
||||
chunk_ids = [source_chunk_ids[i] for i in order]
|
||||
|
||||
selected_chunks_with_scores = [
|
||||
{"chunk_id": chunk_ids[i], "score": r.score}
|
||||
for i, r in enumerate(results)
|
||||
]
|
||||
|
||||
reranked = True
|
||||
|
||||
# Emit chunk-selection (focus) explainability: surviving chunks
|
||||
# with their cross-encoder scores, derived from exploration.
|
||||
if explain_callback:
|
||||
selected_chunks_with_scores = [
|
||||
{"chunk_id": chunk_ids[i], "score": r.score}
|
||||
for i, r in enumerate(results)
|
||||
]
|
||||
foc_triples = set_graph(
|
||||
docrag_chunk_selection_triples(
|
||||
foc_uri, exp_uri,
|
||||
|
|
|
|||
|
|
@ -33,17 +33,23 @@ class Processor(FlowProcessor):
|
|||
# reranking; the rerank step narrows it back down to doc_limit for the
|
||||
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
|
||||
fetch_limit = params.get("fetch_limit", 0)
|
||||
rerank_diversity_mode = params.get("rerank_diversity_mode", "none")
|
||||
rerank_diversity_lambda = params.get("rerank_diversity_lambda", 0.7)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"doc_limit": doc_limit,
|
||||
"fetch_limit": fetch_limit,
|
||||
"rerank_diversity_mode": rerank_diversity_mode,
|
||||
"rerank_diversity_lambda": rerank_diversity_lambda,
|
||||
}
|
||||
)
|
||||
|
||||
self.doc_limit = doc_limit
|
||||
self.fetch_limit = fetch_limit
|
||||
self.rerank_diversity_mode = rerank_diversity_mode
|
||||
self.rerank_diversity_lambda = rerank_diversity_lambda
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
|
|
@ -122,6 +128,8 @@ class Processor(FlowProcessor):
|
|||
fetch_chunk = fetch_chunk,
|
||||
reranker_client = flow("reranker-request"),
|
||||
verbose=True,
|
||||
rerank_diversity_mode=self.rerank_diversity_mode,
|
||||
rerank_diversity_lambda=self.rerank_diversity_lambda,
|
||||
)
|
||||
|
||||
if v.doc_limit:
|
||||
|
|
@ -277,6 +285,20 @@ class Processor(FlowProcessor):
|
|||
'(default: derive from doc-limit)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rerank-diversity-mode',
|
||||
choices=['none', 'mmr'],
|
||||
default='none',
|
||||
help='Optional diversity-aware selection after reranking (default: none)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rerank-diversity-lambda',
|
||||
type=float,
|
||||
default=0.7,
|
||||
help='MMR relevance/diversity tradeoff, higher values prefer relevance'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
142
trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py
Normal file
142
trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
import re
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import List, Sequence, Set
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RerankCandidate:
    """
    Candidate chunk after cross-encoder reranking.

    reranker_score is the raw score returned by the reranker backend. It may
    not be normalized, so MMR should use normalized_score instead.

    The dataclass is frozen, so fields cannot be reassigned after
    construction; updated copies are produced with dataclasses.replace.
    """
    # Candidate index; presumably the position in the list sent to the
    # reranker — confirm against the caller.
    index: int
    # Identifier of the chunk this candidate represents.
    chunk_id: str
    # Chunk text; used for the token-overlap diversity penalty.
    text: str
    # Raw, backend-specific relevance score.
    reranker_score: float
    # Per-candidate-set min-max scaled score; 0.0 until normalization runs.
    normalized_score: float = 0.0
|
||||
|
||||
|
||||
_TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")
|
||||
|
||||
|
||||
def _clamp01(value: float) -> float:
|
||||
return max(0.0, min(1.0, value))
|
||||
|
||||
|
||||
def _token_set(text: str) -> Set[str]:
|
||||
return set(token.lower() for token in _TOKEN_RE.findall(text or ""))
|
||||
|
||||
|
||||
def _jaccard(a: str, b: str) -> float:
    """
    Jaccard similarity of the token sets of *a* and *b*.

    Returns 0.0 when either text has no tokens; otherwise the size of the
    intersection divided by the size of the union.
    """
    left, right = _token_set(a), _token_set(b)

    if left and right:
        shared = left & right
        combined = left | right
        return len(shared) / len(combined)

    return 0.0
|
||||
|
||||
|
||||
def normalize_candidate_scores(
    candidates: Sequence[RerankCandidate],
) -> List[RerankCandidate]:
    """
    Rescale reranker scores to [0, 1] relative to the given candidates.

    Different reranker backends report scores on different scales
    (probabilities, logits, prompt-defined scores), and MMR needs a stable
    [0, 1] relevance signal. Scores are therefore min-max normalized within
    this candidate set rather than against an assumed global range.

    Returns new candidates in the input order with ``normalized_score`` set;
    the inputs are not modified. When every score is identical, each
    candidate gets 0.5. An empty input gives an empty list.
    """
    if not candidates:
        return []

    raw = [float(item.reranker_score) for item in candidates]
    lowest = min(raw)
    highest = max(raw)

    if highest == lowest:
        # No spread to scale by: place everything at the midpoint.
        normalized = [0.5] * len(raw)
    else:
        spread = highest - lowest
        normalized = [(value - lowest) / spread for value in raw]

    return [
        replace(item, normalized_score=score)
        for item, score in zip(candidates, normalized)
    ]
|
||||
|
||||
|
||||
def _pair_diversity_penalty(
    candidate: RerankCandidate,
    selected: RerankCandidate,
    token_overlap_weight: float,
) -> float:
    """
    Penalty in [0, 1] for how much ``candidate`` repeats ``selected``.

    Token overlap is the only signal in this first revision: the current
    Document-RAG reranker document_id is the candidate index, not a source
    document id.
    """
    overlap = _jaccard(candidate.text, selected.text)
    return _clamp01(token_overlap_weight * overlap)
|
||||
|
||||
|
||||
def mmr_select(
    candidates: Sequence[RerankCandidate],
    limit: int,
    lambda_mult: float = 0.7,
    token_overlap_weight: float = 1.0,
) -> List[RerankCandidate]:
    """
    Select a diverse final context set using MMR.

    Relevance comes from normalized cross-encoder reranker scores.
    Diversity comes from token overlap against already selected chunks.

    Each round picks the remaining candidate that maximizes
    ``lambda_mult * relevance - (1 - lambda_mult) * penalty``, where
    ``penalty`` is the candidate's largest pairwise penalty against any
    chunk chosen so far (0.0 while nothing is chosen). Ties go to the
    candidate that appears first.

    Args:
        candidates: Reranked chunks; they are normalized internally and the
            input sequence is not modified.
        limit: Maximum number of chunks to return; ``<= 0`` returns ``[]``.
        lambda_mult: Relevance/diversity trade-off, clamped to [0, 1].
            Higher values favor relevance.
        token_overlap_weight: Multiplier on token overlap; negative values
            are treated as 0.0.

    Returns:
        Selected candidates in selection order, with ``normalized_score``
        populated.
    """
    if limit <= 0:
        return []

    lambda_mult = _clamp01(lambda_mult)
    token_overlap_weight = max(0.0, token_overlap_weight)

    remaining = normalize_candidate_scores(candidates)
    selected: List[RerankCandidate] = []

    # penalties[i] is the largest pairwise penalty of remaining[i] against
    # anything selected so far. Keeping it as a running maximum means each
    # (candidate, selected) pair is scored once instead of once per round.
    penalties = [0.0] * len(remaining)

    while remaining and len(selected) < limit:
        best_idx = 0
        best_score = None

        for idx, candidate in enumerate(remaining):
            mmr_score = (
                lambda_mult * candidate.normalized_score
                - (1.0 - lambda_mult) * penalties[idx]
            )

            # Strict comparison keeps the earliest candidate on ties.
            if best_score is None or mmr_score > best_score:
                best_score = mmr_score
                best_idx = idx

        chosen = remaining.pop(best_idx)
        penalties.pop(best_idx)
        selected.append(chosen)

        # Only the newest pick can raise a candidate's maximum penalty, and
        # the update is pointless once the quota has been reached.
        if len(selected) < limit:
            penalties = [
                max(
                    known,
                    _pair_diversity_penalty(
                        candidate,
                        chosen,
                        token_overlap_weight=token_overlap_weight,
                    ),
                )
                for known, candidate in zip(penalties, remaining)
            ]

    return selected
|
||||
|
|
@ -241,38 +241,56 @@ class Query:
|
|||
self.rag.label_cache.put(cache_key, label)
|
||||
return label
|
||||
|
||||
FROM_S = "from_s"
|
||||
FROM_P = "from_p"
|
||||
FROM_O = "from_o"
|
||||
|
||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
    """Execute triple queries for multiple entities concurrently.

    Every entity is looked up three times -- as subject, as predicate and
    as object -- with the other two positions left as wildcards, and all
    lookups are awaited together.

    Returns a list of (triple, direction) tuples where direction
    (FROM_S, FROM_P or FROM_O) indicates which position the frontier
    entity occupied. Lookups that fail, are cancelled, or return None
    contribute nothing to the result; failures are logged.
    """
    import logging

    tasks = []
    directions = []

    for entity in entities:
        for direction, position in (
            (self.FROM_S, "s"),
            (self.FROM_P, "p"),
            (self.FROM_O, "o"),
        ):
            # Wildcard every position except the one held by the entity.
            pattern = {"s": None, "p": None, "o": None}
            pattern[position] = entity

            tasks.append(
                self.rag.triples_client.query_stream(
                    **pattern,
                    limit=limit_per_entity,
                    collection=self.collection,
                    batch_size=20, g="",
                )
            )
            directions.append(direction)

    results = await asyncio.gather(*tasks, return_exceptions=True)

    all_triples = []
    for direction, result in zip(directions, results):
        # Check BaseException, not Exception: a cancelled lookup comes back
        # from gather() as asyncio.CancelledError, which is not an
        # Exception subclass and would otherwise be iterated as a result.
        if isinstance(result, BaseException):
            logging.getLogger(__name__).warning(
                "Triple query (%s) failed: %r", direction, result,
            )
            continue

        if result is None:
            continue

        all_triples.extend((triple, direction) for triple in result)

    return all_triples
|
||||
|
||||
|
|
@ -325,7 +343,8 @@ class Query:
|
|||
# Deduplicate and filter already-seen edges
|
||||
hop_triples = []
|
||||
hop_term_map = {}
|
||||
for triple in triples:
|
||||
hop_directions = {}
|
||||
for triple, direction in triples:
|
||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||
if triple_tuple[1] == LABEL:
|
||||
continue
|
||||
|
|
@ -336,6 +355,7 @@ class Query:
|
|||
hop_term_map[triple_tuple] = (
|
||||
to_term(triple.s), to_term(triple.p), to_term(triple.o),
|
||||
)
|
||||
hop_directions[triple_tuple] = direction
|
||||
|
||||
if not hop_triples:
|
||||
visited_entities.update(frontier)
|
||||
|
|
@ -361,7 +381,10 @@ class Query:
|
|||
else:
|
||||
label_map[entity] = entity
|
||||
|
||||
# Build labeled edges and documents for cross-encoder
|
||||
# Build labeled edges and documents for cross-encoder.
|
||||
# The reranker text highlights the NEW information relative
|
||||
# to the traversal direction: arriving from S means p,o are
|
||||
# new; from O means s,p are new; from P means s,o are new.
|
||||
labeled_hop = []
|
||||
for s, p, o in hop_triples:
|
||||
ls = label_map.get(s, s)
|
||||
|
|
@ -369,10 +392,18 @@ class Query:
|
|||
lo = label_map.get(o, o)
|
||||
labeled_hop.append((ls, lp, lo))
|
||||
|
||||
documents = [
|
||||
{"id": str(i), "text": f"{lp} {lo}"}
|
||||
for i, (ls, lp, lo) in enumerate(labeled_hop)
|
||||
]
|
||||
documents = []
|
||||
for i, (triple_tuple, (ls, lp, lo)) in enumerate(
|
||||
zip(hop_triples, labeled_hop)
|
||||
):
|
||||
direction = hop_directions[triple_tuple]
|
||||
if direction == self.FROM_S:
|
||||
text = f"{lp} {lo}"
|
||||
elif direction == self.FROM_O:
|
||||
text = f"{ls} {lp}"
|
||||
else:
|
||||
text = f"{ls} {lo}"
|
||||
documents.append({"id": str(i), "text": text})
|
||||
|
||||
queries = [
|
||||
{"id": str(i), "text": c}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue