mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 23:11:00 +02:00
feat: filter and cap GraphRAG reranker input across full stack (#1021)
- Filter out RDF/RDFS/OWL schema predicates (rdfs:domain, owl:inverseOf, etc.) from hop traversal, keeping rdf:type for data signal - Skip edges where reranker-visible components are unlabeled IRIs, since the cross-encoder cannot meaningfully score raw URIs - Add max-reranker-input safety cap (default 350) to prevent overloading the reranker, applied after filtering for maximum useful candidates - Expose max-reranker-input as per-request parameter through schema, translator, REST API, socket client, CLI, and OpenAPI spec - Update tests - Update tech spec
This commit is contained in:
parent
76c4763b9b
commit
68e816e65c
10 changed files with 198 additions and 43 deletions
|
|
@ -404,10 +404,33 @@ no LLM call. These fields are dropped from the Focus entity.
|
|||
|
||||
a. Retrieve all edges one hop from the current frontier nodes.
|
||||
|
||||
b. Represent each edge using direction-aware text: from a
|
||||
subject node use `"{predicate} {object}"`, from an object
|
||||
node use `"{subject} {predicate}"`, from a predicate node
|
||||
use `"{subject} {object}"`.
|
||||
b. Filter and represent edges for scoring:
|
||||
|
||||
- **Schema predicate filter.** Edges with RDF/RDFS/OWL
|
||||
schema predicates (`rdfs:domain`, `owl:inverseOf`, etc.)
|
||||
are removed. `rdf:type` is kept as it carries useful
|
||||
data signal.
|
||||
|
||||
- **IRI filter.** Edges where the reranker-visible text
|
||||
components (after label resolution) are still raw IRIs
|
||||
are removed — the cross-encoder cannot meaningfully score
|
||||
unresolved URIs. Only the components that would appear
|
||||
in the reranker text are checked, based on traversal
|
||||
direction.
|
||||
|
||||
- **Direction-aware text.** Each surviving edge is
|
||||
represented using direction-aware text: from a subject
|
||||
node use `"{predicate} {object}"`, from an object node
|
||||
use `"{subject} {predicate}"`, from a predicate node
|
||||
use `"{subject} {object}"`.
|
||||
|
||||
- **Reranker input cap.** The candidate set is truncated
|
||||
to `max_reranker_input` (default 350) edges. This is a
|
||||
safety measure, not an accuracy optimisation — there is
|
||||
no point in producing a perfectly ranked edge set if the
|
||||
reranker crashes or times out because it was handed
|
||||
thousands of candidates. The cap is applied after
|
||||
filtering so that the most useful edges fill the budget.
|
||||
|
||||
c. Score edges against the extracted concepts using the
|
||||
cross-encoder service.
|
||||
|
|
|
|||
|
|
@ -42,6 +42,13 @@ properties:
|
|||
minimum: 1
|
||||
maximum: 5
|
||||
example: 3
|
||||
max-reranker-input:
|
||||
type: integer
|
||||
description: Maximum candidate edges sent to the reranker per hop
|
||||
default: 350
|
||||
minimum: 1
|
||||
maximum: 1000
|
||||
example: 350
|
||||
streaming:
|
||||
type: boolean
|
||||
description: Enable streaming response delivery
|
||||
|
|
|
|||
|
|
@ -18,15 +18,30 @@ from trustgraph.schema import Term, IRI, LITERAL
|
|||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_rag(reranker_results=None):
|
||||
"""Create a mock GraphRag with all clients stubbed."""
|
||||
LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
|
||||
def _make_rag(reranker_results=None, labels=None):
|
||||
"""Create a mock GraphRag with all clients stubbed.
|
||||
|
||||
labels is an optional dict mapping URI -> label string. When provided,
|
||||
the mock triples_client.query will return matching label triples so
|
||||
that hop_and_filter resolves labels instead of falling back to raw URIs
|
||||
(which are now filtered out by the IRI filter).
|
||||
"""
|
||||
rag = MagicMock()
|
||||
rag.label_cache = LRUCacheWithTTL()
|
||||
rag.triples_client = AsyncMock()
|
||||
rag.reranker_client = AsyncMock()
|
||||
|
||||
# Label lookups return empty (fall back to URI)
|
||||
rag.triples_client.query.return_value = []
|
||||
if labels:
|
||||
async def label_query(s=None, p=None, o=None, limit=1, **kwargs):
|
||||
if p == LABEL and s in labels:
|
||||
return [MagicMock(o=labels[s])]
|
||||
return []
|
||||
rag.triples_client.query.side_effect = label_query
|
||||
else:
|
||||
rag.triples_client.query.return_value = []
|
||||
|
||||
if reranker_results is not None:
|
||||
rag.reranker_client.rerank.return_value = reranker_results
|
||||
|
|
@ -147,8 +162,13 @@ class TestDirectionAwareRerankerText:
|
|||
"http://ex/likes",
|
||||
"http://ex/entity-B",
|
||||
)
|
||||
labels = {
|
||||
"http://ex/entity-A": "Alice",
|
||||
"http://ex/likes": "likes",
|
||||
"http://ex/entity-B": "Bob",
|
||||
}
|
||||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
rag = _make_rag(reranker_results=[reranker_result], labels=labels)
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if s is not None:
|
||||
|
|
@ -166,9 +186,8 @@ class TestDirectionAwareRerankerText:
|
|||
|
||||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
# Text should be "{p} {o}" — the URIs since no labels found
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["text"] == "http://ex/likes http://ex/entity-B"
|
||||
assert documents[0]["text"] == "likes Bob"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_o_uses_subject_predicate(self):
|
||||
|
|
@ -178,8 +197,13 @@ class TestDirectionAwareRerankerText:
|
|||
"http://ex/likes",
|
||||
"http://ex/entity-B",
|
||||
)
|
||||
labels = {
|
||||
"http://ex/entity-A": "Alice",
|
||||
"http://ex/likes": "likes",
|
||||
"http://ex/entity-B": "Bob",
|
||||
}
|
||||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
rag = _make_rag(reranker_results=[reranker_result], labels=labels)
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if o is not None:
|
||||
|
|
@ -198,7 +222,7 @@ class TestDirectionAwareRerankerText:
|
|||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["text"] == "http://ex/entity-A http://ex/likes"
|
||||
assert documents[0]["text"] == "Alice likes"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_p_uses_subject_object(self):
|
||||
|
|
@ -208,8 +232,13 @@ class TestDirectionAwareRerankerText:
|
|||
"http://ex/likes",
|
||||
"http://ex/entity-B",
|
||||
)
|
||||
labels = {
|
||||
"http://ex/entity-A": "Alice",
|
||||
"http://ex/likes": "likes",
|
||||
"http://ex/entity-B": "Bob",
|
||||
}
|
||||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
rag = _make_rag(reranker_results=[reranker_result], labels=labels)
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if p is not None:
|
||||
|
|
@ -228,7 +257,7 @@ class TestDirectionAwareRerankerText:
|
|||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B"
|
||||
assert documents[0]["text"] == "Alice Bob"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_directions_produce_different_text(self):
|
||||
|
|
@ -239,10 +268,18 @@ class TestDirectionAwareRerankerText:
|
|||
triple_from_o = _make_schema_triple(
|
||||
"http://ex/other", "http://ex/ref", "http://ex/seed",
|
||||
)
|
||||
labels = {
|
||||
"http://ex/seed": "Seed",
|
||||
"http://ex/rel": "relates to",
|
||||
"http://ex/target": "Target",
|
||||
"http://ex/other": "Other",
|
||||
"http://ex/ref": "references",
|
||||
}
|
||||
|
||||
rag = _make_rag(reranker_results=[
|
||||
_reranker_result(0), _reranker_result(1),
|
||||
])
|
||||
rag = _make_rag(
|
||||
reranker_results=[_reranker_result(0), _reranker_result(1)],
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if s == "http://ex/seed":
|
||||
|
|
@ -264,10 +301,10 @@ class TestDirectionAwareRerankerText:
|
|||
documents = call_args.kwargs["documents"]
|
||||
texts = {d["text"] for d in documents}
|
||||
|
||||
# From S: "{p} {o}" = "http://ex/rel http://ex/target"
|
||||
assert "http://ex/rel http://ex/target" in texts
|
||||
# From O: "{s} {p}" = "http://ex/other http://ex/ref"
|
||||
assert "http://ex/other http://ex/ref" in texts
|
||||
# From S: "{p} {o}" = "relates to Target"
|
||||
assert "relates to Target" in texts
|
||||
# From O: "{s} {p}" = "Other references"
|
||||
assert "Other references" in texts
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_labels_applied_to_direction_text(self):
|
||||
|
|
@ -280,8 +317,6 @@ class TestDirectionAwareRerankerText:
|
|||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
|
||||
LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if s is not None and p is None:
|
||||
return [triple]
|
||||
|
|
@ -323,10 +358,17 @@ class TestDirectionAwareRerankerText:
|
|||
triple_b = _make_schema_triple(
|
||||
"http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors",
|
||||
)
|
||||
labels = {
|
||||
"http://ex/cpu-A": "CPU Alpha",
|
||||
"http://ex/cpu-B": "CPU Beta",
|
||||
"http://ex/hasCategory": "has category",
|
||||
"http://ex/Processors": "Processors",
|
||||
}
|
||||
|
||||
rag = _make_rag(reranker_results=[
|
||||
_reranker_result(0), _reranker_result(1),
|
||||
])
|
||||
rag = _make_rag(
|
||||
reranker_results=[_reranker_result(0), _reranker_result(1)],
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if o == "http://ex/Processors":
|
||||
|
|
@ -349,5 +391,5 @@ class TestDirectionAwareRerankerText:
|
|||
assert len(texts) == 2
|
||||
# From O: "{s} {p}" — subjects differ, so texts differ
|
||||
assert texts[0] != texts[1]
|
||||
assert "http://ex/cpu-A" in texts[0]
|
||||
assert "http://ex/cpu-B" in texts[1]
|
||||
assert "CPU Alpha" in texts[0]
|
||||
assert "CPU Beta" in texts[1]
|
||||
|
|
|
|||
|
|
@ -357,6 +357,7 @@ class FlowInstance:
|
|||
self, query,collection="default",
|
||||
entity_limit=50, triple_limit=30, max_subgraph_size=150,
|
||||
max_path_length=2, edge_score_limit=30, edge_limit=25,
|
||||
max_reranker_input=350,
|
||||
):
|
||||
"""
|
||||
Execute graph-based Retrieval-Augmented Generation (RAG) query.
|
||||
|
|
@ -373,6 +374,7 @@ class FlowInstance:
|
|||
max_path_length: Maximum traversal depth (default: 2)
|
||||
edge_score_limit: Max edges for semantic pre-filter (default: 50)
|
||||
edge_limit: Max edges after LLM scoring (default: 25)
|
||||
max_reranker_input: Max candidate edges sent to reranker per hop (default: 350)
|
||||
|
||||
Returns:
|
||||
str: Generated response incorporating graph context
|
||||
|
|
@ -399,6 +401,7 @@ class FlowInstance:
|
|||
"max-path-length": max_path_length,
|
||||
"edge-score-limit": edge_score_limit,
|
||||
"edge-limit": edge_limit,
|
||||
"max-reranker-input": max_reranker_input,
|
||||
}
|
||||
|
||||
result = self.request(
|
||||
|
|
|
|||
|
|
@ -682,6 +682,7 @@ class SocketFlowInstance:
|
|||
max_path_length: int = 2,
|
||||
edge_score_limit: int = 30,
|
||||
edge_limit: int = 25,
|
||||
max_reranker_input: int = 350,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
|
|
@ -699,6 +700,7 @@ class SocketFlowInstance:
|
|||
"max-path-length": max_path_length,
|
||||
"edge-score-limit": edge_score_limit,
|
||||
"edge-limit": edge_limit,
|
||||
"max-reranker-input": max_reranker_input,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
@ -725,6 +727,7 @@ class SocketFlowInstance:
|
|||
max_path_length: int = 2,
|
||||
edge_score_limit: int = 30,
|
||||
edge_limit: int = 25,
|
||||
max_reranker_input: int = 350,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
||||
"""Execute graph-based RAG query with explainability support."""
|
||||
|
|
@ -737,6 +740,7 @@ class SocketFlowInstance:
|
|||
"max-path-length": max_path_length,
|
||||
"edge-score-limit": edge_score_limit,
|
||||
"edge-limit": edge_limit,
|
||||
"max-reranker-input": max_reranker_input,
|
||||
"streaming": True,
|
||||
"explainable": True,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
max_path_length=int(data.get("max-path-length", 2)),
|
||||
edge_score_limit=int(data.get("edge-score-limit", 30)),
|
||||
edge_limit=int(data.get("edge-limit", 25)),
|
||||
max_reranker_input=int(data.get("max-reranker-input", 350)),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
||||
|
|
@ -116,6 +117,7 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
"max-path-length": obj.max_path_length,
|
||||
"edge-score-limit": obj.edge_score_limit,
|
||||
"edge-limit": obj.edge_limit,
|
||||
"max-reranker-input": obj.max_reranker_input,
|
||||
"streaming": getattr(obj, "streaming", False)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class GraphRagQuery:
|
|||
max_path_length: int = 0
|
||||
edge_score_limit: int = 0
|
||||
edge_limit: int = 0
|
||||
max_reranker_input: int = 0
|
||||
streaming: bool = False
|
||||
parent_uri: str = ""
|
||||
|
||||
|
|
|
|||
|
|
@ -27,11 +27,13 @@ default_max_subgraph_size = 150
|
|||
default_max_path_length = 2
|
||||
default_edge_score_limit = 30
|
||||
default_edge_limit = 25
|
||||
default_max_reranker_input = 350
|
||||
|
||||
def _question_explainable_api(
|
||||
url, flow_id, question_text, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, edge_score_limit=30,
|
||||
edge_limit=25, token=None, debug=False, workspace="default",
|
||||
edge_limit=25, max_reranker_input=350, token=None, debug=False,
|
||||
workspace="default",
|
||||
):
|
||||
"""Execute graph RAG with explainability using the new API classes."""
|
||||
api = Api(url=url, token=token, workspace=workspace)
|
||||
|
|
@ -50,6 +52,7 @@ def _question_explainable_api(
|
|||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
max_reranker_input=max_reranker_input,
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
# Print response content
|
||||
|
|
@ -138,7 +141,7 @@ def _question_explainable_api(
|
|||
def question(
|
||||
url, flow_id, question, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, edge_score_limit=50,
|
||||
edge_limit=25, streaming=True, token=None,
|
||||
edge_limit=25, max_reranker_input=350, streaming=True, token=None,
|
||||
explainable=False, debug=False, show_usage=False,
|
||||
workspace="default",
|
||||
):
|
||||
|
|
@ -156,6 +159,7 @@ def question(
|
|||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
max_reranker_input=max_reranker_input,
|
||||
token=token,
|
||||
debug=debug,
|
||||
workspace=workspace,
|
||||
|
|
@ -180,6 +184,7 @@ def question(
|
|||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
max_reranker_input=max_reranker_input,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
|
|
@ -212,6 +217,7 @@ def question(
|
|||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
max_reranker_input=max_reranker_input,
|
||||
)
|
||||
print(result.text)
|
||||
|
||||
|
|
@ -308,6 +314,13 @@ def main():
|
|||
help=f'Max edges after LLM scoring (default: {default_edge_limit})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--max-reranker-input',
|
||||
type=int,
|
||||
default=default_max_reranker_input,
|
||||
help=f'Max candidate edges sent to reranker per hop (default: {default_max_reranker_input})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--no-streaming',
|
||||
action='store_true',
|
||||
|
|
@ -347,6 +360,7 @@ def main():
|
|||
max_path_length=args.max_path_length,
|
||||
edge_score_limit=args.edge_score_limit,
|
||||
edge_limit=args.edge_limit,
|
||||
max_reranker_input=args.max_reranker_input,
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,22 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
RDF_NS = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||
RDFS_NS = "http://www.w3.org/2000/01/rdf-schema#"
|
||||
OWL_NS = "http://www.w3.org/2002/07/owl#"
|
||||
RDF_TYPE = RDF_NS + "type"
|
||||
SCHEMA_NAMESPACES = (RDF_NS, RDFS_NS, OWL_NS)
|
||||
|
||||
|
||||
def is_schema_predicate(predicate):
|
||||
"""Return True if the predicate is an RDF/RDFS/OWL schema predicate.
|
||||
|
||||
rdf:type is excluded from filtering as it carries useful data signal.
|
||||
"""
|
||||
if predicate == RDF_TYPE:
|
||||
return False
|
||||
return predicate.startswith(SCHEMA_NAMESPACES)
|
||||
|
||||
|
||||
def term_to_string(term):
|
||||
"""Extract string value from a Term object."""
|
||||
|
|
@ -120,7 +136,8 @@ class Query:
|
|||
def __init__(
|
||||
self, rag, collection, verbose,
|
||||
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
|
||||
max_path_length=2, edge_limit=25, track_usage=None,
|
||||
max_path_length=2, edge_limit=25, max_reranker_input=350,
|
||||
track_usage=None,
|
||||
):
|
||||
self.rag = rag
|
||||
self.collection = collection
|
||||
|
|
@ -130,6 +147,7 @@ class Query:
|
|||
self.max_subgraph_size = max_subgraph_size
|
||||
self.max_path_length = max_path_length
|
||||
self.edge_limit = edge_limit
|
||||
self.max_reranker_input = max_reranker_input
|
||||
self.track_usage = track_usage
|
||||
|
||||
async def extract_concepts(self, query):
|
||||
|
|
@ -346,7 +364,7 @@ class Query:
|
|||
hop_directions = {}
|
||||
for triple, direction in triples:
|
||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||
if triple_tuple[1] == LABEL:
|
||||
if is_schema_predicate(triple_tuple[1]):
|
||||
continue
|
||||
if triple_tuple in seen_edges:
|
||||
continue
|
||||
|
|
@ -385,25 +403,50 @@ class Query:
|
|||
# The reranker text highlights the NEW information relative
|
||||
# to the traversal direction: arriving from S means p,o are
|
||||
# new; from O means s,p are new; from P means s,o are new.
|
||||
# Edges where the reranker-visible components are unlabeled
|
||||
# IRIs are skipped — the cross-encoder can't score them.
|
||||
def is_iri(val):
|
||||
return val.startswith(("http://", "https://", "urn:"))
|
||||
|
||||
filtered_triples = []
|
||||
labeled_hop = []
|
||||
documents = []
|
||||
for s, p, o in hop_triples:
|
||||
ls = label_map.get(s, s)
|
||||
lp = label_map.get(p, p)
|
||||
lo = label_map.get(o, o)
|
||||
labeled_hop.append((ls, lp, lo))
|
||||
|
||||
documents = []
|
||||
for i, (triple_tuple, (ls, lp, lo)) in enumerate(
|
||||
zip(hop_triples, labeled_hop)
|
||||
):
|
||||
direction = hop_directions[triple_tuple]
|
||||
direction = hop_directions[(s, p, o)]
|
||||
if direction == self.FROM_S:
|
||||
if is_iri(lp) or is_iri(lo):
|
||||
continue
|
||||
text = f"{lp} {lo}"
|
||||
elif direction == self.FROM_O:
|
||||
if is_iri(ls) or is_iri(lp):
|
||||
continue
|
||||
text = f"{ls} {lp}"
|
||||
else:
|
||||
if is_iri(ls) or is_iri(lo):
|
||||
continue
|
||||
text = f"{ls} {lo}"
|
||||
documents.append({"id": str(i), "text": text})
|
||||
|
||||
idx = len(filtered_triples)
|
||||
filtered_triples.append((s, p, o))
|
||||
labeled_hop.append((ls, lp, lo))
|
||||
documents.append({"id": str(idx), "text": text})
|
||||
|
||||
hop_triples = filtered_triples
|
||||
|
||||
# Cap the number of candidates sent to the reranker
|
||||
if len(hop_triples) > self.max_reranker_input:
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Hop {hop + 1}: truncating {len(hop_triples)} "
|
||||
f"candidates to {self.max_reranker_input}"
|
||||
)
|
||||
hop_triples = hop_triples[:self.max_reranker_input]
|
||||
labeled_hop = labeled_hop[:self.max_reranker_input]
|
||||
documents = documents[:self.max_reranker_input]
|
||||
|
||||
queries = [
|
||||
{"id": str(i), "text": c}
|
||||
|
|
@ -588,7 +631,7 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, edge_limit = 25,
|
||||
max_path_length = 2, edge_limit = 25, max_reranker_input = 350,
|
||||
streaming = False,
|
||||
chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
|
|
@ -642,6 +685,7 @@ class GraphRag:
|
|||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_limit = edge_limit,
|
||||
max_reranker_input = max_reranker_input,
|
||||
track_usage = track_usage,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class Processor(FlowProcessor):
|
|||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
edge_limit = params.get("edge_limit", 25)
|
||||
max_reranker_input = params.get("max_reranker_input", 350)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
|
|
@ -44,6 +45,7 @@ class Processor(FlowProcessor):
|
|||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"edge_limit": edge_limit,
|
||||
"max_reranker_input": max_reranker_input,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -52,6 +54,7 @@ class Processor(FlowProcessor):
|
|||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.default_edge_limit = edge_limit
|
||||
self.default_max_reranker_input = max_reranker_input
|
||||
|
||||
# Workspace isolation is enforced by the flow layer (flow.workspace).
|
||||
# Per-request caching (see GraphRag) keeps within-request state
|
||||
|
|
@ -197,6 +200,11 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
edge_limit = self.default_edge_limit
|
||||
|
||||
if v.max_reranker_input:
|
||||
max_reranker_input = v.max_reranker_input
|
||||
else:
|
||||
max_reranker_input = self.default_max_reranker_input
|
||||
|
||||
async def save_answer(doc_id, answer_text):
|
||||
await flow.librarian.save_document(
|
||||
doc_id=doc_id,
|
||||
|
|
@ -226,8 +234,8 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
|
||||
edge_limit = edge_limit,
|
||||
max_reranker_input = max_reranker_input,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
explain_callback = send_explainability,
|
||||
|
|
@ -242,8 +250,8 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
|
||||
edge_limit = edge_limit,
|
||||
max_reranker_input = max_reranker_input,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
parent_uri = v.parent_uri,
|
||||
|
|
@ -346,6 +354,13 @@ class Processor(FlowProcessor):
|
|||
help=f'Max edges selected per hop by cross-encoder (default: 25)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--max-reranker-input',
|
||||
type=int,
|
||||
default=350,
|
||||
help=f'Max candidate edges sent to the reranker per hop (default: 350)'
|
||||
)
|
||||
|
||||
# Note: Explainability triples are now stored in the request's collection
|
||||
# with the named graph urn:graph:retrieval (no separate collection needed)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue