mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 06:51:00 +02:00
feat: direction-aware reranker text in GraphRAG hop-and-filter (#1016)
The reranker document text now reflects the traversal direction,
showing only the new information relative to the frontier entity:
- From S (subject is frontier): text = "{predicate} {object}"
- From O (object is frontier): text = "{subject} {predicate}"
- From P (predicate is frontier): text = "{subject} {object}"
This eliminates duplicate reranker texts when traversing inward
from shared object nodes (e.g. 18 CPUs all producing identical
"hasSubcategory Processors" text when the subject was dropped).
execute_batch_triple_queries now returns (triple, direction)
tuples so hop_and_filter can select the appropriate text format.
Updates tech spec to document the direction-aware approach.
Adds unit tests for direction tracking and reranker text
construction.
This commit is contained in:
parent
9cf7dcb578
commit
db7fdbc652
4 changed files with 502 additions and 19 deletions
|
|
@ -224,12 +224,27 @@ The current embedding pre-filter represents each edge as
|
||||||
- **Drop commas.** Commas add tokenisation noise without semantic
|
- **Drop commas.** Commas add tokenisation noise without semantic
|
||||||
value.
|
value.
|
||||||
|
|
||||||
- **Drop the subject.** The subject identifies which entity the
|
- **Direction-aware text.** The reranker text should highlight
|
||||||
edge belongs to, but it does not contribute to whether the
|
the *new* information relative to the traversal direction.
|
||||||
edge's content is relevant to the query. The predicate and
|
The frontier entity is already known context — repeating it
|
||||||
object carry the semantic meaning — what relationship exists
|
adds noise and, when traversing from an object node, causes
|
||||||
and what it connects to. Representing edges as `"{p} {o}"`
|
many edges to produce identical reranker text (e.g. 18
|
||||||
produces cleaner cross-encoder matches.
|
products sharing the same `hasSubcategory Processors` triple
|
||||||
|
all collapse to the same string when the subject is dropped).
|
||||||
|
|
||||||
|
The text is constructed based on which position the frontier
|
||||||
|
entity occupied in the triple:
|
||||||
|
|
||||||
|
- **From subject** (s=entity): `"{predicate} {object}"` —
|
||||||
|
the subject is known, predicate and object are new.
|
||||||
|
- **From object** (o=entity): `"{subject} {predicate}"` —
|
||||||
|
the object is known, subject and predicate are new.
|
||||||
|
- **From predicate** (p=entity): `"{subject} {object}"` —
|
||||||
|
the predicate is known, subject and object are new.
|
||||||
|
|
||||||
|
This eliminates the duplicate-text problem that arises when
|
||||||
|
traversing inward from a shared object node, and gives the
|
||||||
|
cross-encoder a more informative signal at every hop.
|
||||||
|
|
||||||
#### Remove the embedding pre-filter (step 3)
|
#### Remove the embedding pre-filter (step 3)
|
||||||
|
|
||||||
|
|
@ -389,7 +404,10 @@ no LLM call. These fields are dropped from the Focus entity.
|
||||||
|
|
||||||
a. Retrieve all edges one hop from the current frontier nodes.
|
a. Retrieve all edges one hop from the current frontier nodes.
|
||||||
|
|
||||||
b. Represent each edge as `"{predicate} {object}"`.
|
b. Represent each edge using direction-aware text: from a
|
||||||
|
subject node use `"{predicate} {object}"`, from an object
|
||||||
|
node use `"{subject} {predicate}"`, from a predicate node
|
||||||
|
use `"{subject} {object}"`.
|
||||||
|
|
||||||
c. Score edges against the extracted concepts using the
|
c. Score edges against the extracted concepts using the
|
||||||
cross-encoder service.
|
cross-encoder service.
|
||||||
|
|
|
||||||
|
|
@ -129,6 +129,9 @@ class TestBatchTripleQueries:
|
||||||
|
|
||||||
# 3 queries, alternating results
|
# 3 queries, alternating results
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
|
# Each result is a (triple, direction) tuple
|
||||||
|
for triple, direction in result:
|
||||||
|
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_exception_in_one_query_does_not_block_others(self):
|
async def test_exception_in_one_query_does_not_block_others(self):
|
||||||
|
|
@ -153,6 +156,8 @@ class TestBatchTripleQueries:
|
||||||
|
|
||||||
# 3 queries: 2 succeed, 1 fails → 2 triples
|
# 3 queries: 2 succeed, 1 fails → 2 triples
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
|
for triple, direction in result:
|
||||||
|
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_none_results_filtered(self):
|
async def test_none_results_filtered(self):
|
||||||
|
|
@ -176,6 +181,8 @@ class TestBatchTripleQueries:
|
||||||
|
|
||||||
# 3 queries: 1 returns None, 2 return triples
|
# 3 queries: 1 returns None, 2 return triples
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
|
for triple, direction in result:
|
||||||
|
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_entities_no_queries(self):
|
async def test_empty_entities_no_queries(self):
|
||||||
|
|
@ -220,6 +227,80 @@ class TestBatchTripleQueries:
|
||||||
assert calls[2].kwargs["p"] is None
|
assert calls[2].kwargs["p"] is None
|
||||||
assert calls[2].kwargs["o"] == "ent-1"
|
assert calls[2].kwargs["o"] == "ent-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_directions_assigned_correctly(self):
|
||||||
|
"""Each query position should produce the correct direction tag."""
|
||||||
|
triple = _make_triple("s", "p", "o")
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def one_triple_each(**kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return [triple]
|
||||||
|
|
||||||
|
client = AsyncMock()
|
||||||
|
client.query_stream = one_triple_each
|
||||||
|
query = _make_query(triples_client=client)
|
||||||
|
|
||||||
|
result = await query.execute_batch_triple_queries(
|
||||||
|
["e1"], limit_per_entity=10
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
# Order matches query order: s-position, p-position, o-position
|
||||||
|
assert result[0][1] == Query.FROM_S
|
||||||
|
assert result[1][1] == Query.FROM_P
|
||||||
|
assert result[2][1] == Query.FROM_O
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_directions_correct_for_multiple_entities(self):
|
||||||
|
"""Direction tags cycle correctly across multiple entities."""
|
||||||
|
triple = _make_triple("s", "p", "o")
|
||||||
|
client = AsyncMock()
|
||||||
|
client.query_stream = AsyncMock(return_value=[triple])
|
||||||
|
query = _make_query(triples_client=client)
|
||||||
|
|
||||||
|
result = await query.execute_batch_triple_queries(
|
||||||
|
["e1", "e2"], limit_per_entity=10
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 6
|
||||||
|
expected_directions = [
|
||||||
|
Query.FROM_S, Query.FROM_P, Query.FROM_O,
|
||||||
|
Query.FROM_S, Query.FROM_P, Query.FROM_O,
|
||||||
|
]
|
||||||
|
for (_, direction), expected in zip(result, expected_directions):
|
||||||
|
assert direction == expected
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_direction_preserved_with_multiple_triples(self):
|
||||||
|
"""All triples from one query share the same direction."""
|
||||||
|
t1 = _make_triple("a", "p1", "b")
|
||||||
|
t2 = _make_triple("a", "p2", "c")
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def multi_results(**kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return [t1, t2]
|
||||||
|
return []
|
||||||
|
|
||||||
|
client = AsyncMock()
|
||||||
|
client.query_stream = multi_results
|
||||||
|
query = _make_query(triples_client=client)
|
||||||
|
|
||||||
|
result = await query.execute_batch_triple_queries(
|
||||||
|
["e1"], limit_per_entity=10
|
||||||
|
)
|
||||||
|
|
||||||
|
# First query (FROM_S) returns 2 triples, both should be FROM_S
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0] == (t1, Query.FROM_S)
|
||||||
|
assert result[1] == (t2, Query.FROM_S)
|
||||||
|
|
||||||
|
|
||||||
class TestLRUCacheWithTTL:
|
class TestLRUCacheWithTTL:
|
||||||
|
|
||||||
|
|
|
||||||
353
tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py
Normal file
353
tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py
Normal file
|
|
@ -0,0 +1,353 @@
|
||||||
|
"""
|
||||||
|
Tests for direction-aware reranker text in GraphRAG hop-and-filter.
|
||||||
|
|
||||||
|
The reranker document text varies by traversal direction:
|
||||||
|
- From S (subject is the frontier entity): text = "{p} {o}"
|
||||||
|
- From O (object is the frontier entity): text = "{s} {p}"
|
||||||
|
- From P (predicate is the frontier entity): text = "{s} {o}"
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock
|
||||||
|
|
||||||
|
from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
|
||||||
|
from trustgraph.schema import Term, IRI, LITERAL
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_rag(reranker_results=None):
|
||||||
|
"""Create a mock GraphRag with all clients stubbed."""
|
||||||
|
rag = MagicMock()
|
||||||
|
rag.label_cache = LRUCacheWithTTL()
|
||||||
|
rag.triples_client = AsyncMock()
|
||||||
|
rag.reranker_client = AsyncMock()
|
||||||
|
|
||||||
|
# Label lookups return empty (fall back to URI)
|
||||||
|
rag.triples_client.query.return_value = []
|
||||||
|
|
||||||
|
if reranker_results is not None:
|
||||||
|
rag.reranker_client.rerank.return_value = reranker_results
|
||||||
|
else:
|
||||||
|
rag.reranker_client.rerank.return_value = []
|
||||||
|
|
||||||
|
return rag
|
||||||
|
|
||||||
|
|
||||||
|
def _make_query(rag, max_path_length=1, edge_limit=25):
|
||||||
|
return Query(
|
||||||
|
rag=rag,
|
||||||
|
collection="test",
|
||||||
|
verbose=False,
|
||||||
|
entity_limit=50,
|
||||||
|
triple_limit=30,
|
||||||
|
max_subgraph_size=1000,
|
||||||
|
max_path_length=max_path_length,
|
||||||
|
edge_limit=edge_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_schema_triple(s, p, o):
|
||||||
|
"""Create a mock triple matching the schema interface."""
|
||||||
|
t = MagicMock()
|
||||||
|
t.s = s
|
||||||
|
t.p = p
|
||||||
|
t.o = o
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def _reranker_result(document_id, query_id="0", score=0.9):
|
||||||
|
r = MagicMock()
|
||||||
|
r.document_id = str(document_id)
|
||||||
|
r.query_id = str(query_id)
|
||||||
|
r.score = score
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: execute_batch_triple_queries direction tracking
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDirectionTracking:
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_from_s_direction(self):
|
||||||
|
"""Triples from s=entity queries are tagged FROM_S."""
|
||||||
|
triple = _make_schema_triple("ent1", "pred", "obj")
|
||||||
|
rag = _make_rag()
|
||||||
|
|
||||||
|
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||||
|
if s is not None:
|
||||||
|
return [triple]
|
||||||
|
return []
|
||||||
|
|
||||||
|
rag.triples_client.query_stream.side_effect = query_stream
|
||||||
|
q = _make_query(rag)
|
||||||
|
|
||||||
|
result = await q.execute_batch_triple_queries(["ent1"], 10)
|
||||||
|
|
||||||
|
from_s = [(t, d) for t, d in result if d == Query.FROM_S]
|
||||||
|
assert len(from_s) == 1
|
||||||
|
assert from_s[0][0] is triple
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_from_o_direction(self):
|
||||||
|
"""Triples from o=entity queries are tagged FROM_O."""
|
||||||
|
triple = _make_schema_triple("subj", "pred", "ent1")
|
||||||
|
rag = _make_rag()
|
||||||
|
|
||||||
|
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||||
|
if o is not None:
|
||||||
|
return [triple]
|
||||||
|
return []
|
||||||
|
|
||||||
|
rag.triples_client.query_stream.side_effect = query_stream
|
||||||
|
q = _make_query(rag)
|
||||||
|
|
||||||
|
result = await q.execute_batch_triple_queries(["ent1"], 10)
|
||||||
|
|
||||||
|
from_o = [(t, d) for t, d in result if d == Query.FROM_O]
|
||||||
|
assert len(from_o) == 1
|
||||||
|
assert from_o[0][0] is triple
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_from_p_direction(self):
|
||||||
|
"""Triples from p=entity queries are tagged FROM_P."""
|
||||||
|
triple = _make_schema_triple("subj", "ent1", "obj")
|
||||||
|
rag = _make_rag()
|
||||||
|
|
||||||
|
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||||
|
if p is not None:
|
||||||
|
return [triple]
|
||||||
|
return []
|
||||||
|
|
||||||
|
rag.triples_client.query_stream.side_effect = query_stream
|
||||||
|
q = _make_query(rag)
|
||||||
|
|
||||||
|
result = await q.execute_batch_triple_queries(["ent1"], 10)
|
||||||
|
|
||||||
|
from_p = [(t, d) for t, d in result if d == Query.FROM_P]
|
||||||
|
assert len(from_p) == 1
|
||||||
|
assert from_p[0][0] is triple
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: hop_and_filter reranker document text
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDirectionAwareRerankerText:
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_from_s_uses_predicate_object(self):
|
||||||
|
"""From-S traversal: reranker text should be '{p} {o}'."""
|
||||||
|
triple = _make_schema_triple(
|
||||||
|
"http://ex/entity-A",
|
||||||
|
"http://ex/likes",
|
||||||
|
"http://ex/entity-B",
|
||||||
|
)
|
||||||
|
reranker_result = _reranker_result(0)
|
||||||
|
rag = _make_rag(reranker_results=[reranker_result])
|
||||||
|
|
||||||
|
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||||
|
if s is not None:
|
||||||
|
return [triple]
|
||||||
|
return []
|
||||||
|
|
||||||
|
rag.triples_client.query_stream.side_effect = query_stream
|
||||||
|
|
||||||
|
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||||
|
|
||||||
|
await q.hop_and_filter(
|
||||||
|
seed_entities=["http://ex/entity-A"],
|
||||||
|
concepts=["likes"],
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = rag.reranker_client.rerank.call_args
|
||||||
|
documents = call_args.kwargs["documents"]
|
||||||
|
# Text should be "{p} {o}" — the URIs since no labels found
|
||||||
|
assert len(documents) == 1
|
||||||
|
assert documents[0]["text"] == "http://ex/likes http://ex/entity-B"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_from_o_uses_subject_predicate(self):
|
||||||
|
"""From-O traversal: reranker text should be '{s} {p}'."""
|
||||||
|
triple = _make_schema_triple(
|
||||||
|
"http://ex/entity-A",
|
||||||
|
"http://ex/likes",
|
||||||
|
"http://ex/entity-B",
|
||||||
|
)
|
||||||
|
reranker_result = _reranker_result(0)
|
||||||
|
rag = _make_rag(reranker_results=[reranker_result])
|
||||||
|
|
||||||
|
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||||
|
if o is not None:
|
||||||
|
return [triple]
|
||||||
|
return []
|
||||||
|
|
||||||
|
rag.triples_client.query_stream.side_effect = query_stream
|
||||||
|
|
||||||
|
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||||
|
|
||||||
|
await q.hop_and_filter(
|
||||||
|
seed_entities=["http://ex/entity-B"],
|
||||||
|
concepts=["likes"],
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = rag.reranker_client.rerank.call_args
|
||||||
|
documents = call_args.kwargs["documents"]
|
||||||
|
assert len(documents) == 1
|
||||||
|
assert documents[0]["text"] == "http://ex/entity-A http://ex/likes"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_from_p_uses_subject_object(self):
|
||||||
|
"""From-P traversal: reranker text should be '{s} {o}'."""
|
||||||
|
triple = _make_schema_triple(
|
||||||
|
"http://ex/entity-A",
|
||||||
|
"http://ex/likes",
|
||||||
|
"http://ex/entity-B",
|
||||||
|
)
|
||||||
|
reranker_result = _reranker_result(0)
|
||||||
|
rag = _make_rag(reranker_results=[reranker_result])
|
||||||
|
|
||||||
|
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||||
|
if p is not None:
|
||||||
|
return [triple]
|
||||||
|
return []
|
||||||
|
|
||||||
|
rag.triples_client.query_stream.side_effect = query_stream
|
||||||
|
|
||||||
|
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||||
|
|
||||||
|
await q.hop_and_filter(
|
||||||
|
seed_entities=["http://ex/likes"],
|
||||||
|
concepts=["entity"],
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = rag.reranker_client.rerank.call_args
|
||||||
|
documents = call_args.kwargs["documents"]
|
||||||
|
assert len(documents) == 1
|
||||||
|
assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mixed_directions_produce_different_text(self):
|
||||||
|
"""Edges from different directions use different text formats."""
|
||||||
|
triple_from_s = _make_schema_triple(
|
||||||
|
"http://ex/seed", "http://ex/rel", "http://ex/target",
|
||||||
|
)
|
||||||
|
triple_from_o = _make_schema_triple(
|
||||||
|
"http://ex/other", "http://ex/ref", "http://ex/seed",
|
||||||
|
)
|
||||||
|
|
||||||
|
rag = _make_rag(reranker_results=[
|
||||||
|
_reranker_result(0), _reranker_result(1),
|
||||||
|
])
|
||||||
|
|
||||||
|
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||||
|
if s == "http://ex/seed":
|
||||||
|
return [triple_from_s]
|
||||||
|
if o == "http://ex/seed":
|
||||||
|
return [triple_from_o]
|
||||||
|
return []
|
||||||
|
|
||||||
|
rag.triples_client.query_stream.side_effect = query_stream
|
||||||
|
|
||||||
|
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||||
|
|
||||||
|
await q.hop_and_filter(
|
||||||
|
seed_entities=["http://ex/seed"],
|
||||||
|
concepts=["test"],
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = rag.reranker_client.rerank.call_args
|
||||||
|
documents = call_args.kwargs["documents"]
|
||||||
|
texts = {d["text"] for d in documents}
|
||||||
|
|
||||||
|
# From S: "{p} {o}" = "http://ex/rel http://ex/target"
|
||||||
|
assert "http://ex/rel http://ex/target" in texts
|
||||||
|
# From O: "{s} {p}" = "http://ex/other http://ex/ref"
|
||||||
|
assert "http://ex/other http://ex/ref" in texts
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_labels_applied_to_direction_text(self):
|
||||||
|
"""Labels should be resolved and used in the direction-aware text."""
|
||||||
|
triple = _make_schema_triple(
|
||||||
|
"http://ex/entity-A",
|
||||||
|
"http://ex/likes",
|
||||||
|
"http://ex/entity-B",
|
||||||
|
)
|
||||||
|
reranker_result = _reranker_result(0)
|
||||||
|
rag = _make_rag(reranker_results=[reranker_result])
|
||||||
|
|
||||||
|
LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||||
|
|
||||||
|
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||||
|
if s is not None and p is None:
|
||||||
|
return [triple]
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def label_query(s=None, p=None, o=None, limit=1, **kwargs):
|
||||||
|
if p == LABEL:
|
||||||
|
labels = {
|
||||||
|
"http://ex/entity-A": "Alice",
|
||||||
|
"http://ex/likes": "likes",
|
||||||
|
"http://ex/entity-B": "Bob",
|
||||||
|
}
|
||||||
|
if s in labels:
|
||||||
|
return [MagicMock(o=labels[s])]
|
||||||
|
return []
|
||||||
|
|
||||||
|
rag.triples_client.query_stream.side_effect = query_stream
|
||||||
|
rag.triples_client.query.side_effect = label_query
|
||||||
|
|
||||||
|
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||||
|
|
||||||
|
await q.hop_and_filter(
|
||||||
|
seed_entities=["http://ex/entity-A"],
|
||||||
|
concepts=["friendship"],
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = rag.reranker_client.rerank.call_args
|
||||||
|
documents = call_args.kwargs["documents"]
|
||||||
|
assert len(documents) == 1
|
||||||
|
# From S with labels: "{p_label} {o_label}"
|
||||||
|
assert documents[0]["text"] == "likes Bob"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_duplicate_text_from_shared_object(self):
|
||||||
|
"""Multiple edges sharing an object should produce distinct texts."""
|
||||||
|
triple_a = _make_schema_triple(
|
||||||
|
"http://ex/cpu-A", "http://ex/hasCategory", "http://ex/Processors",
|
||||||
|
)
|
||||||
|
triple_b = _make_schema_triple(
|
||||||
|
"http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors",
|
||||||
|
)
|
||||||
|
|
||||||
|
rag = _make_rag(reranker_results=[
|
||||||
|
_reranker_result(0), _reranker_result(1),
|
||||||
|
])
|
||||||
|
|
||||||
|
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||||
|
if o == "http://ex/Processors":
|
||||||
|
return [triple_a, triple_b]
|
||||||
|
return []
|
||||||
|
|
||||||
|
rag.triples_client.query_stream.side_effect = query_stream
|
||||||
|
|
||||||
|
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||||
|
|
||||||
|
await q.hop_and_filter(
|
||||||
|
seed_entities=["http://ex/Processors"],
|
||||||
|
concepts=["CPUs"],
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = rag.reranker_client.rerank.call_args
|
||||||
|
documents = call_args.kwargs["documents"]
|
||||||
|
texts = [d["text"] for d in documents]
|
||||||
|
|
||||||
|
assert len(texts) == 2
|
||||||
|
# From O: "{s} {p}" — subjects differ, so texts differ
|
||||||
|
assert texts[0] != texts[1]
|
||||||
|
assert "http://ex/cpu-A" in texts[0]
|
||||||
|
assert "http://ex/cpu-B" in texts[1]
|
||||||
|
|
@ -241,38 +241,56 @@ class Query:
|
||||||
self.rag.label_cache.put(cache_key, label)
|
self.rag.label_cache.put(cache_key, label)
|
||||||
return label
|
return label
|
||||||
|
|
||||||
|
FROM_S = "from_s"
|
||||||
|
FROM_P = "from_p"
|
||||||
|
FROM_O = "from_o"
|
||||||
|
|
||||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
||||||
"""Execute triple queries for multiple entities concurrently."""
|
"""Execute triple queries for multiple entities concurrently.
|
||||||
|
|
||||||
|
Returns a list of (triple, direction) tuples where direction
|
||||||
|
indicates which position the frontier entity occupied.
|
||||||
|
"""
|
||||||
tasks = []
|
tasks = []
|
||||||
|
directions = []
|
||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
tasks.extend([
|
tasks.append(
|
||||||
self.rag.triples_client.query_stream(
|
self.rag.triples_client.query_stream(
|
||||||
s=entity, p=None, o=None,
|
s=entity, p=None, o=None,
|
||||||
limit=limit_per_entity,
|
limit=limit_per_entity,
|
||||||
collection=self.collection,
|
collection=self.collection,
|
||||||
batch_size=20, g="",
|
batch_size=20, g="",
|
||||||
),
|
),
|
||||||
|
)
|
||||||
|
directions.append(self.FROM_S)
|
||||||
|
|
||||||
|
tasks.append(
|
||||||
self.rag.triples_client.query_stream(
|
self.rag.triples_client.query_stream(
|
||||||
s=None, p=entity, o=None,
|
s=None, p=entity, o=None,
|
||||||
limit=limit_per_entity,
|
limit=limit_per_entity,
|
||||||
collection=self.collection,
|
collection=self.collection,
|
||||||
batch_size=20, g="",
|
batch_size=20, g="",
|
||||||
),
|
),
|
||||||
|
)
|
||||||
|
directions.append(self.FROM_P)
|
||||||
|
|
||||||
|
tasks.append(
|
||||||
self.rag.triples_client.query_stream(
|
self.rag.triples_client.query_stream(
|
||||||
s=None, p=None, o=entity,
|
s=None, p=None, o=entity,
|
||||||
limit=limit_per_entity,
|
limit=limit_per_entity,
|
||||||
collection=self.collection,
|
collection=self.collection,
|
||||||
batch_size=20, g="",
|
batch_size=20, g="",
|
||||||
)
|
),
|
||||||
])
|
)
|
||||||
|
directions.append(self.FROM_O)
|
||||||
|
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
all_triples = []
|
all_triples = []
|
||||||
for result in results:
|
for direction, result in zip(directions, results):
|
||||||
if not isinstance(result, Exception) and result is not None:
|
if not isinstance(result, Exception) and result is not None:
|
||||||
all_triples.extend(result)
|
all_triples.extend((triple, direction) for triple in result)
|
||||||
|
|
||||||
return all_triples
|
return all_triples
|
||||||
|
|
||||||
|
|
@ -325,7 +343,8 @@ class Query:
|
||||||
# Deduplicate and filter already-seen edges
|
# Deduplicate and filter already-seen edges
|
||||||
hop_triples = []
|
hop_triples = []
|
||||||
hop_term_map = {}
|
hop_term_map = {}
|
||||||
for triple in triples:
|
hop_directions = {}
|
||||||
|
for triple, direction in triples:
|
||||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||||
if triple_tuple[1] == LABEL:
|
if triple_tuple[1] == LABEL:
|
||||||
continue
|
continue
|
||||||
|
|
@ -336,6 +355,7 @@ class Query:
|
||||||
hop_term_map[triple_tuple] = (
|
hop_term_map[triple_tuple] = (
|
||||||
to_term(triple.s), to_term(triple.p), to_term(triple.o),
|
to_term(triple.s), to_term(triple.p), to_term(triple.o),
|
||||||
)
|
)
|
||||||
|
hop_directions[triple_tuple] = direction
|
||||||
|
|
||||||
if not hop_triples:
|
if not hop_triples:
|
||||||
visited_entities.update(frontier)
|
visited_entities.update(frontier)
|
||||||
|
|
@ -361,7 +381,10 @@ class Query:
|
||||||
else:
|
else:
|
||||||
label_map[entity] = entity
|
label_map[entity] = entity
|
||||||
|
|
||||||
# Build labeled edges and documents for cross-encoder
|
# Build labeled edges and documents for cross-encoder.
|
||||||
|
# The reranker text highlights the NEW information relative
|
||||||
|
# to the traversal direction: arriving from S means p,o are
|
||||||
|
# new; from O means s,p are new; from P means s,o are new.
|
||||||
labeled_hop = []
|
labeled_hop = []
|
||||||
for s, p, o in hop_triples:
|
for s, p, o in hop_triples:
|
||||||
ls = label_map.get(s, s)
|
ls = label_map.get(s, s)
|
||||||
|
|
@ -369,10 +392,18 @@ class Query:
|
||||||
lo = label_map.get(o, o)
|
lo = label_map.get(o, o)
|
||||||
labeled_hop.append((ls, lp, lo))
|
labeled_hop.append((ls, lp, lo))
|
||||||
|
|
||||||
documents = [
|
documents = []
|
||||||
{"id": str(i), "text": f"{lp} {lo}"}
|
for i, (triple_tuple, (ls, lp, lo)) in enumerate(
|
||||||
for i, (ls, lp, lo) in enumerate(labeled_hop)
|
zip(hop_triples, labeled_hop)
|
||||||
]
|
):
|
||||||
|
direction = hop_directions[triple_tuple]
|
||||||
|
if direction == self.FROM_S:
|
||||||
|
text = f"{lp} {lo}"
|
||||||
|
elif direction == self.FROM_O:
|
||||||
|
text = f"{ls} {lp}"
|
||||||
|
else:
|
||||||
|
text = f"{ls} {lo}"
|
||||||
|
documents.append({"id": str(i), "text": text})
|
||||||
|
|
||||||
queries = [
|
queries = [
|
||||||
{"id": str(i), "text": c}
|
{"id": str(i), "text": c}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue