diff --git a/docs/tech-specs/graph-rag-semantic-filter.md b/docs/tech-specs/graph-rag-semantic-filter.md index 58497d10..cb2fd24a 100644 --- a/docs/tech-specs/graph-rag-semantic-filter.md +++ b/docs/tech-specs/graph-rag-semantic-filter.md @@ -404,10 +404,33 @@ no LLM call. These fields are dropped from the Focus entity. a. Retrieve all edges one hop from the current frontier nodes. - b. Represent each edge using direction-aware text: from a - subject node use `"{predicate} {object}"`, from an object - node use `"{subject} {predicate}"`, from a predicate node - use `"{subject} {object}"`. + b. Filter and represent edges for scoring: + + - **Schema predicate filter.** Edges with RDF/RDFS/OWL + schema predicates (`rdfs:domain`, `owl:inverseOf`, etc.) + are removed. `rdf:type` is kept as it carries useful + data signal. + + - **IRI filter.** Edges where the reranker-visible text + components (after label resolution) are still raw IRIs + are removed — the cross-encoder cannot meaningfully score + unresolved URIs. Only the components that would appear + in the reranker text are checked, based on traversal + direction. + + - **Direction-aware text.** Each surviving edge is + represented using direction-aware text: from a subject + node use `"{predicate} {object}"`, from an object node + use `"{subject} {predicate}"`, from a predicate node + use `"{subject} {object}"`. + + - **Reranker input cap.** The candidate set is truncated + to `max_reranker_input` (default 350) edges. This is a + safety measure, not an accuracy optimisation — there is + no point in producing a perfectly ranked edge set if the + reranker crashes or times out because it was handed + thousands of candidates. The cap is applied after + filtering so that the most useful edges fill the budget. c. Score edges against the extracted concepts using the cross-encoder service. diff --git a/specs/api/components/schemas/rag/GraphRagRequest.yaml b/specs/api/components/schemas/rag/GraphRagRequest.yaml index 754dcc92..f1899fc4 100644 --- a/specs/api/components/schemas/rag/GraphRagRequest.yaml +++ b/specs/api/components/schemas/rag/GraphRagRequest.yaml @@ -42,6 +42,13 @@ properties: minimum: 1 maximum: 5 example: 3 + max-reranker-input: + type: integer + description: Maximum candidate edges sent to the reranker per hop + default: 350 + minimum: 1 + maximum: 1000 + example: 350 streaming: type: boolean description: Enable streaming response delivery diff --git a/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py b/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py index cc95228a..c58ac3b6 100644 --- a/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py +++ b/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py @@ -18,15 +18,30 @@ from trustgraph.schema import Term, IRI, LITERAL # Helpers # --------------------------------------------------------------------------- -def _make_rag(reranker_results=None): - """Create a mock GraphRag with all clients stubbed.""" +LABEL = "http://www.w3.org/2000/01/rdf-schema#label" + + +def _make_rag(reranker_results=None, labels=None): + """Create a mock GraphRag with all clients stubbed. + + labels is an optional dict mapping URI -> label string. When provided, + the mock triples_client.query will return matching label triples so + that hop_and_filter resolves labels instead of falling back to raw URIs + (which are now filtered out by the IRI filter). + """ rag = MagicMock() rag.label_cache = LRUCacheWithTTL() rag.triples_client = AsyncMock() rag.reranker_client = AsyncMock() - # Label lookups return empty (fall back to URI) - rag.triples_client.query.return_value = [] + if labels: + async def label_query(s=None, p=None, o=None, limit=1, **kwargs): + if p == LABEL and s in labels: + return [MagicMock(o=labels[s])] + return [] + rag.triples_client.query.side_effect = label_query + else: + rag.triples_client.query.return_value = [] if reranker_results is not None: rag.reranker_client.rerank.return_value = reranker_results @@ -147,8 +162,13 @@ class TestDirectionAwareRerankerText: "http://ex/likes", "http://ex/entity-B", ) + labels = { + "http://ex/entity-A": "Alice", + "http://ex/likes": "likes", + "http://ex/entity-B": "Bob", + } reranker_result = _reranker_result(0) - rag = _make_rag(reranker_results=[reranker_result]) + rag = _make_rag(reranker_results=[reranker_result], labels=labels) async def query_stream(s=None, p=None, o=None, **kwargs): if s is not None: @@ -166,9 +186,8 @@ class TestDirectionAwareRerankerText: call_args = rag.reranker_client.rerank.call_args documents = call_args.kwargs["documents"] - # Text should be "{p} {o}" — the URIs since no labels found assert len(documents) == 1 - assert documents[0]["text"] == "http://ex/likes http://ex/entity-B" + assert documents[0]["text"] == "likes Bob" @pytest.mark.asyncio async def test_from_o_uses_subject_predicate(self): @@ -178,8 +197,13 @@ class TestDirectionAwareRerankerText: "http://ex/likes", "http://ex/entity-B", ) + labels = { + "http://ex/entity-A": "Alice", + "http://ex/likes": "likes", + "http://ex/entity-B": "Bob", + } reranker_result = _reranker_result(0) - rag = _make_rag(reranker_results=[reranker_result]) + rag = _make_rag(reranker_results=[reranker_result], labels=labels) async def query_stream(s=None, p=None, o=None, **kwargs): if o is not None: @@ -198,7 +222,7 @@ class TestDirectionAwareRerankerText: call_args = rag.reranker_client.rerank.call_args documents = call_args.kwargs["documents"] assert len(documents) == 1 - assert documents[0]["text"] == "http://ex/entity-A http://ex/likes" + assert documents[0]["text"] == "Alice likes" @pytest.mark.asyncio async def test_from_p_uses_subject_object(self): @@ -208,8 +232,13 @@ class TestDirectionAwareRerankerText: "http://ex/likes", "http://ex/entity-B", ) + labels = { + "http://ex/entity-A": "Alice", + "http://ex/likes": "likes", + "http://ex/entity-B": "Bob", + } reranker_result = _reranker_result(0) - rag = _make_rag(reranker_results=[reranker_result]) + rag = _make_rag(reranker_results=[reranker_result], labels=labels) async def query_stream(s=None, p=None, o=None, **kwargs): if p is not None: @@ -228,7 +257,7 @@ class TestDirectionAwareRerankerText: call_args = rag.reranker_client.rerank.call_args documents = call_args.kwargs["documents"] assert len(documents) == 1 - assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B" + assert documents[0]["text"] == "Alice Bob" @pytest.mark.asyncio async def test_mixed_directions_produce_different_text(self): @@ -239,10 +268,18 @@ class TestDirectionAwareRerankerText: triple_from_o = _make_schema_triple( "http://ex/other", "http://ex/ref", "http://ex/seed", ) + labels = { + "http://ex/seed": "Seed", + "http://ex/rel": "relates to", + "http://ex/target": "Target", + "http://ex/other": "Other", + "http://ex/ref": "references", + } - rag = _make_rag(reranker_results=[ - _reranker_result(0), _reranker_result(1), - ]) + rag = _make_rag( + reranker_results=[_reranker_result(0), _reranker_result(1)], + labels=labels, + ) async def query_stream(s=None, p=None, o=None, **kwargs): if s == "http://ex/seed": @@ -264,10 +301,10 @@ class TestDirectionAwareRerankerText: documents = call_args.kwargs["documents"] texts = {d["text"] for d in documents} - # From S: "{p} {o}" = "http://ex/rel http://ex/target" - assert "http://ex/rel http://ex/target" in texts - # From O: "{s} {p}" = "http://ex/other http://ex/ref" - assert "http://ex/other http://ex/ref" in texts + # From S: "{p} {o}" = "relates to Target" + assert "relates to Target" in texts + # From O: "{s} {p}" = "Other references" + assert "Other references" in texts @pytest.mark.asyncio async def test_labels_applied_to_direction_text(self): @@ -280,8 +317,6 @@ class TestDirectionAwareRerankerText: reranker_result = _reranker_result(0) rag = _make_rag(reranker_results=[reranker_result]) - LABEL = "http://www.w3.org/2000/01/rdf-schema#label" - async def query_stream(s=None, p=None, o=None, **kwargs): if s is not None and p is None: return [triple] @@ -323,10 +358,17 @@ class TestDirectionAwareRerankerText: triple_b = _make_schema_triple( "http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors", ) + labels = { + "http://ex/cpu-A": "CPU Alpha", + "http://ex/cpu-B": "CPU Beta", + "http://ex/hasCategory": "has category", + "http://ex/Processors": "Processors", + } - rag = _make_rag(reranker_results=[ - _reranker_result(0), _reranker_result(1), - ]) + rag = _make_rag( + reranker_results=[_reranker_result(0), _reranker_result(1)], + labels=labels, + ) async def query_stream(s=None, p=None, o=None, **kwargs): if o == "http://ex/Processors": @@ -349,5 +391,5 @@ class TestDirectionAwareRerankerText: assert len(texts) == 2 # From O: "{s} {p}" — subjects differ, so texts differ assert texts[0] != texts[1] - assert "http://ex/cpu-A" in texts[0] - assert "http://ex/cpu-B" in texts[1] + assert "CPU Alpha" in texts[0] + assert "CPU Beta" in texts[1] diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index b9e9487b..95fc009b 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -357,6 +357,7 @@ class FlowInstance: self, query,collection="default", entity_limit=50, triple_limit=30, max_subgraph_size=150, max_path_length=2, edge_score_limit=30, edge_limit=25, + max_reranker_input=350, ): """ Execute graph-based Retrieval-Augmented Generation (RAG) query. @@ -373,6 +374,7 @@ class FlowInstance: max_path_length: Maximum traversal depth (default: 2) edge_score_limit: Max edges for semantic pre-filter (default: 50) edge_limit: Max edges after LLM scoring (default: 25) + max_reranker_input: Max candidate edges sent to reranker per hop (default: 350) Returns: str: Generated response incorporating graph context @@ -399,6 +401,7 @@ class FlowInstance: "max-path-length": max_path_length, "edge-score-limit": edge_score_limit, "edge-limit": edge_limit, + "max-reranker-input": max_reranker_input, } result = self.request( diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index efa887a1..6c51210f 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -682,6 +682,7 @@ class SocketFlowInstance: max_path_length: int = 2, edge_score_limit: int = 30, edge_limit: int = 25, + max_reranker_input: int = 350, streaming: bool = False, **kwargs: Any ) -> Union[TextCompletionResult, Iterator[RAGChunk]]: @@ -699,6 +700,7 @@ class SocketFlowInstance: "max-path-length": max_path_length, "edge-score-limit": edge_score_limit, "edge-limit": edge_limit, + "max-reranker-input": max_reranker_input, "streaming": streaming } request.update(kwargs) @@ -725,6 +727,7 @@ class SocketFlowInstance: max_path_length: int = 2, edge_score_limit: int = 30, edge_limit: int = 25, + max_reranker_input: int = 350, **kwargs: Any ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: """Execute graph-based RAG query with explainability support.""" @@ -737,6 +740,7 @@ class SocketFlowInstance: "max-path-length": max_path_length, "edge-score-limit": edge_score_limit, "edge-limit": edge_limit, + "max-reranker-input": max_reranker_input, "streaming": True, "explainable": True, } diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index f2a0b29a..556ad758 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -103,6 +103,7 @@ class GraphRagRequestTranslator(MessageTranslator): max_path_length=int(data.get("max-path-length", 2)), edge_score_limit=int(data.get("edge-score-limit", 30)), edge_limit=int(data.get("edge-limit", 25)), + max_reranker_input=int(data.get("max-reranker-input", 350)), streaming=data.get("streaming", False) ) @@ -116,6 +117,7 @@ class GraphRagRequestTranslator(MessageTranslator): "max-path-length": obj.max_path_length, "edge-score-limit": obj.edge_score_limit, "edge-limit": obj.edge_limit, + "max-reranker-input": obj.max_reranker_input, "streaming": getattr(obj, "streaming", False) } diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index 2d4e01e1..47ced73d 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -15,6 +15,7 @@ class GraphRagQuery: max_path_length: int = 0 edge_score_limit: int = 0 edge_limit: int = 0 + max_reranker_input: int = 0 streaming: bool = False parent_uri: str = "" diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 892d2d35..97bb8db9 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -27,11 +27,13 @@ default_max_subgraph_size = 150 default_max_path_length = 2 default_edge_score_limit = 30 default_edge_limit = 25 +default_max_reranker_input = 350 def _question_explainable_api( url, flow_id, question_text, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=30, - edge_limit=25, token=None, debug=False, workspace="default", + edge_limit=25, max_reranker_input=350, token=None, debug=False, + workspace="default", ): """Execute graph RAG with explainability using the new API classes.""" api = Api(url=url, token=token, workspace=workspace) @@ -50,6 +52,7 @@ def _question_explainable_api( max_path_length=max_path_length, edge_score_limit=edge_score_limit, edge_limit=edge_limit, + max_reranker_input=max_reranker_input, ): if isinstance(item, RAGChunk): # Print response content @@ -138,7 +141,7 @@ def _question_explainable_api( def question( url, flow_id, question, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=50, - edge_limit=25, streaming=True, token=None, + edge_limit=25, max_reranker_input=350, streaming=True, token=None, explainable=False, debug=False, show_usage=False, workspace="default", ): @@ -156,6 +159,7 @@ def question( max_path_length=max_path_length, edge_score_limit=edge_score_limit, edge_limit=edge_limit, + max_reranker_input=max_reranker_input, token=token, debug=debug, workspace=workspace, @@ -180,6 +184,7 @@ def question( max_path_length=max_path_length, edge_score_limit=edge_score_limit, edge_limit=edge_limit, + max_reranker_input=max_reranker_input, streaming=True ) @@ -212,6 +217,7 @@ def question( max_path_length=max_path_length, edge_score_limit=edge_score_limit, edge_limit=edge_limit, + max_reranker_input=max_reranker_input, ) print(result.text) @@ -308,6 +314,13 @@ def main(): help=f'Max edges after LLM scoring (default: {default_edge_limit})' ) + parser.add_argument( + '--max-reranker-input', + type=int, + default=default_max_reranker_input, + help=f'Max candidate edges sent to reranker per hop (default: {default_max_reranker_input})' + ) + parser.add_argument( '--no-streaming', action='store_true', @@ -347,6 +360,7 @@ def main(): max_path_length=args.max_path_length, edge_score_limit=args.edge_score_limit, edge_limit=args.edge_limit, + max_reranker_input=args.max_reranker_input, streaming=not args.no_streaming, token=args.token, explainable=args.explainable, diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 2054cb0f..c094e395 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -34,6 +34,22 @@ logger = logging.getLogger(__name__) LABEL="http://www.w3.org/2000/01/rdf-schema#label" +RDF_NS = "http://www.w3.org/1999/02/22-rdf-syntax-ns#" +RDFS_NS = "http://www.w3.org/2000/01/rdf-schema#" +OWL_NS = "http://www.w3.org/2002/07/owl#" +RDF_TYPE = RDF_NS + "type" +SCHEMA_NAMESPACES = (RDF_NS, RDFS_NS, OWL_NS) + + +def is_schema_predicate(predicate): + """Return True if the predicate is an RDF/RDFS/OWL schema predicate. + + rdf:type is excluded from filtering as it carries useful data signal. + """ + if predicate == RDF_TYPE: + return False + return predicate.startswith(SCHEMA_NAMESPACES) + def term_to_string(term): """Extract string value from a Term object.""" @@ -120,7 +136,8 @@ class Query: def __init__( self, rag, collection, verbose, entity_limit=50, triple_limit=30, max_subgraph_size=1000, - max_path_length=2, edge_limit=25, track_usage=None, + max_path_length=2, edge_limit=25, max_reranker_input=350, + track_usage=None, ): self.rag = rag self.collection = collection @@ -130,6 +147,7 @@ class Query: self.max_subgraph_size = max_subgraph_size self.max_path_length = max_path_length self.edge_limit = edge_limit + self.max_reranker_input = max_reranker_input self.track_usage = track_usage async def extract_concepts(self, query): @@ -346,7 +364,7 @@ class Query: hop_directions = {} for triple, direction in triples: triple_tuple = (str(triple.s), str(triple.p), str(triple.o)) - if triple_tuple[1] == LABEL: + if is_schema_predicate(triple_tuple[1]): continue if triple_tuple in seen_edges: continue @@ -385,25 +403,50 @@ class Query: # The reranker text highlights the NEW information relative # to the traversal direction: arriving from S means p,o are # new; from O means s,p are new; from P means s,o are new. + # Edges where the reranker-visible components are unlabeled + # IRIs are skipped — the cross-encoder can't score them. + def is_iri(val): + return val.startswith(("http://", "https://", "urn:")) + + filtered_triples = [] labeled_hop = [] + documents = [] for s, p, o in hop_triples: ls = label_map.get(s, s) lp = label_map.get(p, p) lo = label_map.get(o, o) - labeled_hop.append((ls, lp, lo)) - documents = [] - for i, (triple_tuple, (ls, lp, lo)) in enumerate( - zip(hop_triples, labeled_hop) - ): - direction = hop_directions[triple_tuple] + direction = hop_directions[(s, p, o)] if direction == self.FROM_S: + if is_iri(lp) or is_iri(lo): + continue text = f"{lp} {lo}" elif direction == self.FROM_O: + if is_iri(ls) or is_iri(lp): + continue text = f"{ls} {lp}" else: + if is_iri(ls) or is_iri(lo): + continue text = f"{ls} {lo}" - documents.append({"id": str(i), "text": text}) + + idx = len(filtered_triples) + filtered_triples.append((s, p, o)) + labeled_hop.append((ls, lp, lo)) + documents.append({"id": str(idx), "text": text}) + + hop_triples = filtered_triples + + # Cap the number of candidates sent to the reranker + if len(hop_triples) > self.max_reranker_input: + if self.verbose: + logger.debug( + f"Hop {hop + 1}: truncating {len(hop_triples)} " + f"candidates to {self.max_reranker_input}" + ) + hop_triples = hop_triples[:self.max_reranker_input] + labeled_hop = labeled_hop[:self.max_reranker_input] + documents = documents[:self.max_reranker_input] queries = [ {"id": str(i), "text": c} @@ -588,7 +631,7 @@ class GraphRag: async def query( self, query, collection = "default", entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, - max_path_length = 2, edge_limit = 25, + max_path_length = 2, edge_limit = 25, max_reranker_input = 350, streaming = False, chunk_callback = None, explain_callback = None, save_answer_callback = None, @@ -642,6 +685,7 @@ class GraphRag: max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, edge_limit = edge_limit, + max_reranker_input = max_reranker_input, track_usage = track_usage, ) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 27ec4937..9ed802f4 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -34,6 +34,7 @@ class Processor(FlowProcessor): max_subgraph_size = params.get("max_subgraph_size", 150) max_path_length = params.get("max_path_length", 2) edge_limit = params.get("edge_limit", 25) + max_reranker_input = params.get("max_reranker_input", 350) super(Processor, self).__init__( **params | { @@ -44,6 +45,7 @@ class Processor(FlowProcessor): "max_subgraph_size": max_subgraph_size, "max_path_length": max_path_length, "edge_limit": edge_limit, + "max_reranker_input": max_reranker_input, } ) @@ -52,6 +54,7 @@ class Processor(FlowProcessor): self.default_max_subgraph_size = max_subgraph_size self.default_max_path_length = max_path_length self.default_edge_limit = edge_limit + self.default_max_reranker_input = max_reranker_input # Workspace isolation is enforced by the flow layer (flow.workspace). # Per-request caching (see GraphRag) keeps within-request state @@ -197,6 +200,11 @@ class Processor(FlowProcessor): else: edge_limit = self.default_edge_limit + if v.max_reranker_input: + max_reranker_input = v.max_reranker_input + else: + max_reranker_input = self.default_max_reranker_input + async def save_answer(doc_id, answer_text): await flow.librarian.save_document( doc_id=doc_id, @@ -226,8 +234,8 @@ class Processor(FlowProcessor): entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, - edge_limit = edge_limit, + max_reranker_input = max_reranker_input, streaming = True, chunk_callback = send_chunk, explain_callback = send_explainability, @@ -242,8 +250,8 @@ class Processor(FlowProcessor): entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, - edge_limit = edge_limit, + max_reranker_input = max_reranker_input, explain_callback = send_explainability, save_answer_callback = save_answer, parent_uri = v.parent_uri, @@ -346,6 +354,13 @@ class Processor(FlowProcessor): help=f'Max edges selected per hop by cross-encoder (default: 25)' ) + parser.add_argument( + '--max-reranker-input', + type=int, + default=350, + help=f'Max candidate edges sent to the reranker per hop (default: 350)' + ) + # Note: Explainability triples are now stored in the request's collection # with the named graph urn:graph:retrieval (no separate collection needed)