feat: filter and cap GraphRAG reranker input across full stack (#1021)

- Filter out RDF/RDFS/OWL schema predicates (rdfs:domain, owl:inverseOf,
  etc.) from hop traversal, keeping rdf:type for data signal
- Skip edges where reranker-visible components are unlabeled IRIs, since
  the cross-encoder cannot meaningfully score raw URIs
- Add max-reranker-input safety cap (default 350) to prevent overloading
  the reranker, applied after filtering for maximum useful candidates
- Expose max-reranker-input as per-request parameter through schema,
  translator, REST API, socket client, CLI, and OpenAPI spec
- Update tests
- Update tech spec
This commit is contained in:
cybermaggedon 2026-07-03 15:51:04 +01:00 committed by GitHub
parent 76c4763b9b
commit 68e816e65c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 198 additions and 43 deletions

View file

@ -404,10 +404,33 @@ no LLM call. These fields are dropped from the Focus entity.
a. Retrieve all edges one hop from the current frontier nodes.
b. Represent each edge using direction-aware text: from a
subject node use `"{predicate} {object}"`, from an object
node use `"{subject} {predicate}"`, from a predicate node
use `"{subject} {object}"`.
b. Filter and represent edges for scoring:
- **Schema predicate filter.** Edges with RDF/RDFS/OWL
schema predicates (`rdfs:domain`, `owl:inverseOf`, etc.)
are removed. `rdf:type` is kept as it carries useful
data signal.
- **IRI filter.** Edges where the reranker-visible text
components (after label resolution) are still raw IRIs
are removed — the cross-encoder cannot meaningfully score
unresolved URIs. Only the components that would appear
in the reranker text are checked, based on traversal
direction.
- **Direction-aware text.** Each surviving edge is
represented using direction-aware text: from a subject
node use `"{predicate} {object}"`, from an object node
use `"{subject} {predicate}"`, from a predicate node
use `"{subject} {object}"`.
- **Reranker input cap.** The candidate set is truncated
to `max_reranker_input` (default 350) edges. This is a
safety measure, not an accuracy optimisation — there is
no point in producing a perfectly ranked edge set if the
reranker crashes or times out because it was handed
thousands of candidates. The cap is applied after
filtering so that the budget is spent only on edges that
survived the filters, not on edges that would have been
discarded anyway.
c. Score edges against the extracted concepts using the
cross-encoder service.

View file

@ -42,6 +42,13 @@ properties:
minimum: 1
maximum: 5
example: 3
max-reranker-input:
type: integer
description: Maximum candidate edges sent to the reranker per hop
default: 350
minimum: 1
maximum: 1000
example: 350
streaming:
type: boolean
description: Enable streaming response delivery

View file

@ -18,15 +18,30 @@ from trustgraph.schema import Term, IRI, LITERAL
# Helpers
# ---------------------------------------------------------------------------
def _make_rag(reranker_results=None):
"""Create a mock GraphRag with all clients stubbed."""
LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
def _make_rag(reranker_results=None, labels=None):
"""Create a mock GraphRag with all clients stubbed.
labels is an optional dict mapping URI -> label string. When provided,
the mock triples_client.query will return matching label triples so
that hop_and_filter resolves labels instead of falling back to raw URIs
(which are now filtered out by the IRI filter).
"""
rag = MagicMock()
rag.label_cache = LRUCacheWithTTL()
rag.triples_client = AsyncMock()
rag.reranker_client = AsyncMock()
# Label lookups return empty (fall back to URI)
rag.triples_client.query.return_value = []
if labels:
async def label_query(s=None, p=None, o=None, limit=1, **kwargs):
if p == LABEL and s in labels:
return [MagicMock(o=labels[s])]
return []
rag.triples_client.query.side_effect = label_query
else:
rag.triples_client.query.return_value = []
if reranker_results is not None:
rag.reranker_client.rerank.return_value = reranker_results
@ -147,8 +162,13 @@ class TestDirectionAwareRerankerText:
"http://ex/likes",
"http://ex/entity-B",
)
labels = {
"http://ex/entity-A": "Alice",
"http://ex/likes": "likes",
"http://ex/entity-B": "Bob",
}
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
rag = _make_rag(reranker_results=[reranker_result], labels=labels)
async def query_stream(s=None, p=None, o=None, **kwargs):
if s is not None:
@ -166,9 +186,8 @@ class TestDirectionAwareRerankerText:
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
# Text should be "{p} {o}" — the URIs since no labels found
assert len(documents) == 1
assert documents[0]["text"] == "http://ex/likes http://ex/entity-B"
assert documents[0]["text"] == "likes Bob"
@pytest.mark.asyncio
async def test_from_o_uses_subject_predicate(self):
@ -178,8 +197,13 @@ class TestDirectionAwareRerankerText:
"http://ex/likes",
"http://ex/entity-B",
)
labels = {
"http://ex/entity-A": "Alice",
"http://ex/likes": "likes",
"http://ex/entity-B": "Bob",
}
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
rag = _make_rag(reranker_results=[reranker_result], labels=labels)
async def query_stream(s=None, p=None, o=None, **kwargs):
if o is not None:
@ -198,7 +222,7 @@ class TestDirectionAwareRerankerText:
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
assert len(documents) == 1
assert documents[0]["text"] == "http://ex/entity-A http://ex/likes"
assert documents[0]["text"] == "Alice likes"
@pytest.mark.asyncio
async def test_from_p_uses_subject_object(self):
@ -208,8 +232,13 @@ class TestDirectionAwareRerankerText:
"http://ex/likes",
"http://ex/entity-B",
)
labels = {
"http://ex/entity-A": "Alice",
"http://ex/likes": "likes",
"http://ex/entity-B": "Bob",
}
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
rag = _make_rag(reranker_results=[reranker_result], labels=labels)
async def query_stream(s=None, p=None, o=None, **kwargs):
if p is not None:
@ -228,7 +257,7 @@ class TestDirectionAwareRerankerText:
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
assert len(documents) == 1
assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B"
assert documents[0]["text"] == "Alice Bob"
@pytest.mark.asyncio
async def test_mixed_directions_produce_different_text(self):
@ -239,10 +268,18 @@ class TestDirectionAwareRerankerText:
triple_from_o = _make_schema_triple(
"http://ex/other", "http://ex/ref", "http://ex/seed",
)
labels = {
"http://ex/seed": "Seed",
"http://ex/rel": "relates to",
"http://ex/target": "Target",
"http://ex/other": "Other",
"http://ex/ref": "references",
}
rag = _make_rag(reranker_results=[
_reranker_result(0), _reranker_result(1),
])
rag = _make_rag(
reranker_results=[_reranker_result(0), _reranker_result(1)],
labels=labels,
)
async def query_stream(s=None, p=None, o=None, **kwargs):
if s == "http://ex/seed":
@ -264,10 +301,10 @@ class TestDirectionAwareRerankerText:
documents = call_args.kwargs["documents"]
texts = {d["text"] for d in documents}
# From S: "{p} {o}" = "http://ex/rel http://ex/target"
assert "http://ex/rel http://ex/target" in texts
# From O: "{s} {p}" = "http://ex/other http://ex/ref"
assert "http://ex/other http://ex/ref" in texts
# From S: "{p} {o}" = "relates to Target"
assert "relates to Target" in texts
# From O: "{s} {p}" = "Other references"
assert "Other references" in texts
@pytest.mark.asyncio
async def test_labels_applied_to_direction_text(self):
@ -280,8 +317,6 @@ class TestDirectionAwareRerankerText:
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
async def query_stream(s=None, p=None, o=None, **kwargs):
if s is not None and p is None:
return [triple]
@ -323,10 +358,17 @@ class TestDirectionAwareRerankerText:
triple_b = _make_schema_triple(
"http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors",
)
labels = {
"http://ex/cpu-A": "CPU Alpha",
"http://ex/cpu-B": "CPU Beta",
"http://ex/hasCategory": "has category",
"http://ex/Processors": "Processors",
}
rag = _make_rag(reranker_results=[
_reranker_result(0), _reranker_result(1),
])
rag = _make_rag(
reranker_results=[_reranker_result(0), _reranker_result(1)],
labels=labels,
)
async def query_stream(s=None, p=None, o=None, **kwargs):
if o == "http://ex/Processors":
@ -349,5 +391,5 @@ class TestDirectionAwareRerankerText:
assert len(texts) == 2
# From O: "{s} {p}" — subjects differ, so texts differ
assert texts[0] != texts[1]
assert "http://ex/cpu-A" in texts[0]
assert "http://ex/cpu-B" in texts[1]
assert "CPU Alpha" in texts[0]
assert "CPU Beta" in texts[1]

View file

@ -357,6 +357,7 @@ class FlowInstance:
self, query,collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=150,
max_path_length=2, edge_score_limit=30, edge_limit=25,
max_reranker_input=350,
):
"""
Execute graph-based Retrieval-Augmented Generation (RAG) query.
@ -373,6 +374,7 @@ class FlowInstance:
max_path_length: Maximum traversal depth (default: 2)
edge_score_limit: Max edges for semantic pre-filter (default: 30)
edge_limit: Max edges after LLM scoring (default: 25)
max_reranker_input: Max candidate edges sent to reranker per hop (default: 350)
Returns:
str: Generated response incorporating graph context
@ -399,6 +401,7 @@ class FlowInstance:
"max-path-length": max_path_length,
"edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
"max-reranker-input": max_reranker_input,
}
result = self.request(

View file

@ -682,6 +682,7 @@ class SocketFlowInstance:
max_path_length: int = 2,
edge_score_limit: int = 30,
edge_limit: int = 25,
max_reranker_input: int = 350,
streaming: bool = False,
**kwargs: Any
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
@ -699,6 +700,7 @@ class SocketFlowInstance:
"max-path-length": max_path_length,
"edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
"max-reranker-input": max_reranker_input,
"streaming": streaming
}
request.update(kwargs)
@ -725,6 +727,7 @@ class SocketFlowInstance:
max_path_length: int = 2,
edge_score_limit: int = 30,
edge_limit: int = 25,
max_reranker_input: int = 350,
**kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""Execute graph-based RAG query with explainability support."""
@ -737,6 +740,7 @@ class SocketFlowInstance:
"max-path-length": max_path_length,
"edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
"max-reranker-input": max_reranker_input,
"streaming": True,
"explainable": True,
}

View file

@ -103,6 +103,7 @@ class GraphRagRequestTranslator(MessageTranslator):
max_path_length=int(data.get("max-path-length", 2)),
edge_score_limit=int(data.get("edge-score-limit", 30)),
edge_limit=int(data.get("edge-limit", 25)),
max_reranker_input=int(data.get("max-reranker-input", 350)),
streaming=data.get("streaming", False)
)
@ -116,6 +117,7 @@ class GraphRagRequestTranslator(MessageTranslator):
"max-path-length": obj.max_path_length,
"edge-score-limit": obj.edge_score_limit,
"edge-limit": obj.edge_limit,
"max-reranker-input": obj.max_reranker_input,
"streaming": getattr(obj, "streaming", False)
}

View file

@ -15,6 +15,7 @@ class GraphRagQuery:
max_path_length: int = 0
edge_score_limit: int = 0
edge_limit: int = 0
max_reranker_input: int = 0
streaming: bool = False
parent_uri: str = ""

View file

@ -27,11 +27,13 @@ default_max_subgraph_size = 150
default_max_path_length = 2
default_edge_score_limit = 30
default_edge_limit = 25
default_max_reranker_input = 350
def _question_explainable_api(
url, flow_id, question_text, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, edge_score_limit=30,
edge_limit=25, token=None, debug=False, workspace="default",
edge_limit=25, max_reranker_input=350, token=None, debug=False,
workspace="default",
):
"""Execute graph RAG with explainability using the new API classes."""
api = Api(url=url, token=token, workspace=workspace)
@ -50,6 +52,7 @@ def _question_explainable_api(
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
max_reranker_input=max_reranker_input,
):
if isinstance(item, RAGChunk):
# Print response content
@ -138,7 +141,7 @@ def _question_explainable_api(
def question(
url, flow_id, question, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, edge_score_limit=50,
edge_limit=25, streaming=True, token=None,
edge_limit=25, max_reranker_input=350, streaming=True, token=None,
explainable=False, debug=False, show_usage=False,
workspace="default",
):
@ -156,6 +159,7 @@ def question(
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
max_reranker_input=max_reranker_input,
token=token,
debug=debug,
workspace=workspace,
@ -180,6 +184,7 @@ def question(
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
max_reranker_input=max_reranker_input,
streaming=True
)
@ -212,6 +217,7 @@ def question(
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
max_reranker_input=max_reranker_input,
)
print(result.text)
@ -308,6 +314,13 @@ def main():
help=f'Max edges after LLM scoring (default: {default_edge_limit})'
)
parser.add_argument(
'--max-reranker-input',
type=int,
default=default_max_reranker_input,
help=f'Max candidate edges sent to reranker per hop (default: {default_max_reranker_input})'
)
parser.add_argument(
'--no-streaming',
action='store_true',
@ -347,6 +360,7 @@ def main():
max_path_length=args.max_path_length,
edge_score_limit=args.edge_score_limit,
edge_limit=args.edge_limit,
max_reranker_input=args.max_reranker_input,
streaming=not args.no_streaming,
token=args.token,
explainable=args.explainable,

View file

@ -34,6 +34,22 @@ logger = logging.getLogger(__name__)
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
# Namespace prefixes of the RDF, RDFS and OWL vocabularies.
RDF_NS = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
RDFS_NS = "http://www.w3.org/2000/01/rdf-schema#"
OWL_NS = "http://www.w3.org/2002/07/owl#"

# rdf:type is the one predicate in these namespaces that is never filtered.
RDF_TYPE = RDF_NS + "type"

# Kept as a tuple so it can be handed straight to str.startswith().
SCHEMA_NAMESPACES = (RDF_NS, RDFS_NS, OWL_NS)


def is_schema_predicate(predicate):
    """Tell whether *predicate* belongs to the RDF, RDFS or OWL vocabulary.

    rdf:type is deliberately reported as a non-schema predicate because it
    carries useful data signal; every other predicate whose IRI starts with
    one of the SCHEMA_NAMESPACES prefixes yields True.
    """
    return predicate != RDF_TYPE and predicate.startswith(SCHEMA_NAMESPACES)
def term_to_string(term):
"""Extract string value from a Term object."""
@ -120,7 +136,8 @@ class Query:
def __init__(
self, rag, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2, edge_limit=25, track_usage=None,
max_path_length=2, edge_limit=25, max_reranker_input=350,
track_usage=None,
):
self.rag = rag
self.collection = collection
@ -130,6 +147,7 @@ class Query:
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
self.edge_limit = edge_limit
self.max_reranker_input = max_reranker_input
self.track_usage = track_usage
async def extract_concepts(self, query):
@ -346,7 +364,7 @@ class Query:
hop_directions = {}
for triple, direction in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
if triple_tuple[1] == LABEL:
if is_schema_predicate(triple_tuple[1]):
continue
if triple_tuple in seen_edges:
continue
@ -385,25 +403,50 @@ class Query:
# The reranker text highlights the NEW information relative
# to the traversal direction: arriving from S means p,o are
# new; from O means s,p are new; from P means s,o are new.
# Edges where the reranker-visible components are unlabeled
# IRIs are skipped — the cross-encoder can't score them.
def is_iri(val):
return val.startswith(("http://", "https://", "urn:"))
filtered_triples = []
labeled_hop = []
documents = []
for s, p, o in hop_triples:
ls = label_map.get(s, s)
lp = label_map.get(p, p)
lo = label_map.get(o, o)
labeled_hop.append((ls, lp, lo))
documents = []
for i, (triple_tuple, (ls, lp, lo)) in enumerate(
zip(hop_triples, labeled_hop)
):
direction = hop_directions[triple_tuple]
direction = hop_directions[(s, p, o)]
if direction == self.FROM_S:
if is_iri(lp) or is_iri(lo):
continue
text = f"{lp} {lo}"
elif direction == self.FROM_O:
if is_iri(ls) or is_iri(lp):
continue
text = f"{ls} {lp}"
else:
if is_iri(ls) or is_iri(lo):
continue
text = f"{ls} {lo}"
documents.append({"id": str(i), "text": text})
idx = len(filtered_triples)
filtered_triples.append((s, p, o))
labeled_hop.append((ls, lp, lo))
documents.append({"id": str(idx), "text": text})
hop_triples = filtered_triples
# Cap the number of candidates sent to the reranker
if len(hop_triples) > self.max_reranker_input:
if self.verbose:
logger.debug(
f"Hop {hop + 1}: truncating {len(hop_triples)} "
f"candidates to {self.max_reranker_input}"
)
hop_triples = hop_triples[:self.max_reranker_input]
labeled_hop = labeled_hop[:self.max_reranker_input]
documents = documents[:self.max_reranker_input]
queries = [
{"id": str(i), "text": c}
@ -588,7 +631,7 @@ class GraphRag:
async def query(
self, query, collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, edge_limit = 25,
max_path_length = 2, edge_limit = 25, max_reranker_input = 350,
streaming = False,
chunk_callback = None,
explain_callback = None, save_answer_callback = None,
@ -642,6 +685,7 @@ class GraphRag:
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
max_reranker_input = max_reranker_input,
track_usage = track_usage,
)

View file

@ -34,6 +34,7 @@ class Processor(FlowProcessor):
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
edge_limit = params.get("edge_limit", 25)
max_reranker_input = params.get("max_reranker_input", 350)
super(Processor, self).__init__(
**params | {
@ -44,6 +45,7 @@ class Processor(FlowProcessor):
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"edge_limit": edge_limit,
"max_reranker_input": max_reranker_input,
}
)
@ -52,6 +54,7 @@ class Processor(FlowProcessor):
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.default_edge_limit = edge_limit
self.default_max_reranker_input = max_reranker_input
# Workspace isolation is enforced by the flow layer (flow.workspace).
# Per-request caching (see GraphRag) keeps within-request state
@ -197,6 +200,11 @@ class Processor(FlowProcessor):
else:
edge_limit = self.default_edge_limit
if v.max_reranker_input:
max_reranker_input = v.max_reranker_input
else:
max_reranker_input = self.default_max_reranker_input
async def save_answer(doc_id, answer_text):
await flow.librarian.save_document(
doc_id=doc_id,
@ -226,8 +234,8 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
max_reranker_input = max_reranker_input,
streaming = True,
chunk_callback = send_chunk,
explain_callback = send_explainability,
@ -242,8 +250,8 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
max_reranker_input = max_reranker_input,
explain_callback = send_explainability,
save_answer_callback = save_answer,
parent_uri = v.parent_uri,
@ -346,6 +354,13 @@ class Processor(FlowProcessor):
help=f'Max edges selected per hop by cross-encoder (default: 25)'
)
parser.add_argument(
'--max-reranker-input',
type=int,
default=350,
help=f'Max candidate edges sent to the reranker per hop (default: 350)'
)
# Note: Explainability triples are now stored in the request's collection
# with the named graph urn:graph:retrieval (no separate collection needed)