mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-05 19:32:11 +02:00
feat: filter and cap GraphRAG reranker input across full stack (#1021)
- Filter out RDF/RDFS/OWL schema predicates (rdfs:domain, owl:inverseOf, etc.) from hop traversal, keeping rdf:type for data signal
- Skip edges where reranker-visible components are unlabeled IRIs, since the cross-encoder cannot meaningfully score raw URIs
- Add max-reranker-input safety cap (default 350) to prevent overloading the reranker, applied after filtering for maximum useful candidates
- Expose max-reranker-input as per-request parameter through schema, translator, REST API, socket client, CLI, and OpenAPI spec
- Update tests
- Update tech spec
This commit is contained in:
parent
76c4763b9b
commit
68e816e65c
10 changed files with 198 additions and 43 deletions
|
|
@ -18,15 +18,30 @@ from trustgraph.schema import Term, IRI, LITERAL
|
|||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_rag(reranker_results=None):
|
||||
"""Create a mock GraphRag with all clients stubbed."""
|
||||
LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
|
||||
def _make_rag(reranker_results=None, labels=None):
|
||||
"""Create a mock GraphRag with all clients stubbed.
|
||||
|
||||
labels is an optional dict mapping URI -> label string. When provided,
|
||||
the mock triples_client.query will return matching label triples so
|
||||
that hop_and_filter resolves labels instead of falling back to raw URIs
|
||||
(which are now filtered out by the IRI filter).
|
||||
"""
|
||||
rag = MagicMock()
|
||||
rag.label_cache = LRUCacheWithTTL()
|
||||
rag.triples_client = AsyncMock()
|
||||
rag.reranker_client = AsyncMock()
|
||||
|
||||
# Label lookups return empty (fall back to URI)
|
||||
rag.triples_client.query.return_value = []
|
||||
if labels:
|
||||
async def label_query(s=None, p=None, o=None, limit=1, **kwargs):
|
||||
if p == LABEL and s in labels:
|
||||
return [MagicMock(o=labels[s])]
|
||||
return []
|
||||
rag.triples_client.query.side_effect = label_query
|
||||
else:
|
||||
rag.triples_client.query.return_value = []
|
||||
|
||||
if reranker_results is not None:
|
||||
rag.reranker_client.rerank.return_value = reranker_results
|
||||
|
|
@ -147,8 +162,13 @@ class TestDirectionAwareRerankerText:
|
|||
"http://ex/likes",
|
||||
"http://ex/entity-B",
|
||||
)
|
||||
labels = {
|
||||
"http://ex/entity-A": "Alice",
|
||||
"http://ex/likes": "likes",
|
||||
"http://ex/entity-B": "Bob",
|
||||
}
|
||||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
rag = _make_rag(reranker_results=[reranker_result], labels=labels)
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if s is not None:
|
||||
|
|
@ -166,9 +186,8 @@ class TestDirectionAwareRerankerText:
|
|||
|
||||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
# Text should be "{p} {o}" — the URIs since no labels found
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["text"] == "http://ex/likes http://ex/entity-B"
|
||||
assert documents[0]["text"] == "likes Bob"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_o_uses_subject_predicate(self):
|
||||
|
|
@ -178,8 +197,13 @@ class TestDirectionAwareRerankerText:
|
|||
"http://ex/likes",
|
||||
"http://ex/entity-B",
|
||||
)
|
||||
labels = {
|
||||
"http://ex/entity-A": "Alice",
|
||||
"http://ex/likes": "likes",
|
||||
"http://ex/entity-B": "Bob",
|
||||
}
|
||||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
rag = _make_rag(reranker_results=[reranker_result], labels=labels)
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if o is not None:
|
||||
|
|
@ -198,7 +222,7 @@ class TestDirectionAwareRerankerText:
|
|||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["text"] == "http://ex/entity-A http://ex/likes"
|
||||
assert documents[0]["text"] == "Alice likes"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_p_uses_subject_object(self):
|
||||
|
|
@ -208,8 +232,13 @@ class TestDirectionAwareRerankerText:
|
|||
"http://ex/likes",
|
||||
"http://ex/entity-B",
|
||||
)
|
||||
labels = {
|
||||
"http://ex/entity-A": "Alice",
|
||||
"http://ex/likes": "likes",
|
||||
"http://ex/entity-B": "Bob",
|
||||
}
|
||||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
rag = _make_rag(reranker_results=[reranker_result], labels=labels)
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if p is not None:
|
||||
|
|
@ -228,7 +257,7 @@ class TestDirectionAwareRerankerText:
|
|||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B"
|
||||
assert documents[0]["text"] == "Alice Bob"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_directions_produce_different_text(self):
|
||||
|
|
@ -239,10 +268,18 @@ class TestDirectionAwareRerankerText:
|
|||
triple_from_o = _make_schema_triple(
|
||||
"http://ex/other", "http://ex/ref", "http://ex/seed",
|
||||
)
|
||||
labels = {
|
||||
"http://ex/seed": "Seed",
|
||||
"http://ex/rel": "relates to",
|
||||
"http://ex/target": "Target",
|
||||
"http://ex/other": "Other",
|
||||
"http://ex/ref": "references",
|
||||
}
|
||||
|
||||
rag = _make_rag(reranker_results=[
|
||||
_reranker_result(0), _reranker_result(1),
|
||||
])
|
||||
rag = _make_rag(
|
||||
reranker_results=[_reranker_result(0), _reranker_result(1)],
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if s == "http://ex/seed":
|
||||
|
|
@ -264,10 +301,10 @@ class TestDirectionAwareRerankerText:
|
|||
documents = call_args.kwargs["documents"]
|
||||
texts = {d["text"] for d in documents}
|
||||
|
||||
# From S: "{p} {o}" = "http://ex/rel http://ex/target"
|
||||
assert "http://ex/rel http://ex/target" in texts
|
||||
# From O: "{s} {p}" = "http://ex/other http://ex/ref"
|
||||
assert "http://ex/other http://ex/ref" in texts
|
||||
# From S: "{p} {o}" = "relates to Target"
|
||||
assert "relates to Target" in texts
|
||||
# From O: "{s} {p}" = "Other references"
|
||||
assert "Other references" in texts
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_labels_applied_to_direction_text(self):
|
||||
|
|
@ -280,8 +317,6 @@ class TestDirectionAwareRerankerText:
|
|||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
|
||||
LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if s is not None and p is None:
|
||||
return [triple]
|
||||
|
|
@ -323,10 +358,17 @@ class TestDirectionAwareRerankerText:
|
|||
triple_b = _make_schema_triple(
|
||||
"http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors",
|
||||
)
|
||||
labels = {
|
||||
"http://ex/cpu-A": "CPU Alpha",
|
||||
"http://ex/cpu-B": "CPU Beta",
|
||||
"http://ex/hasCategory": "has category",
|
||||
"http://ex/Processors": "Processors",
|
||||
}
|
||||
|
||||
rag = _make_rag(reranker_results=[
|
||||
_reranker_result(0), _reranker_result(1),
|
||||
])
|
||||
rag = _make_rag(
|
||||
reranker_results=[_reranker_result(0), _reranker_result(1)],
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if o == "http://ex/Processors":
|
||||
|
|
@ -349,5 +391,5 @@ class TestDirectionAwareRerankerText:
|
|||
assert len(texts) == 2
|
||||
# From O: "{s} {p}" — subjects differ, so texts differ
|
||||
assert texts[0] != texts[1]
|
||||
assert "http://ex/cpu-A" in texts[0]
|
||||
assert "http://ex/cpu-B" in texts[1]
|
||||
assert "CPU Alpha" in texts[0]
|
||||
assert "CPU Beta" in texts[1]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue