diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index f20d4d56..d89e16f6 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -449,7 +449,7 @@ class FlowInstance: def graph_rag( self, query, user="trustgraph", collection="default", entity_limit=50, triple_limit=30, max_subgraph_size=150, - max_path_length=2, + max_path_length=2, edge_score_limit=30, edge_limit=25, ): """ Execute graph-based Retrieval-Augmented Generation (RAG) query. @@ -465,6 +465,8 @@ class FlowInstance: triple_limit: Maximum triples per entity (default: 30) max_subgraph_size: Maximum total triples in subgraph (default: 150) max_path_length: Maximum traversal depth (default: 2) + edge_score_limit: Max edges for semantic pre-filter (default: 30) + edge_limit: Max edges after LLM scoring (default: 25) Returns: str: Generated response incorporating graph context @@ -492,6 +494,8 @@ class FlowInstance: "triple-limit": triple_limit, "max-subgraph-size": max_subgraph_size, "max-path-length": max_path_length, + "edge-score-limit": edge_score_limit, + "edge-limit": edge_limit, } return self.request( diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 4e09351a..91db8b69 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -699,9 +699,12 @@ class SocketFlowInstance: query: str, user: str, collection: str, + entity_limit: int = 50, + triple_limit: int = 30, max_subgraph_size: int = 1000, - max_subgraph_count: int = 5, - max_entity_distance: int = 3, + max_path_length: int = 2, + edge_score_limit: int = 30, + edge_limit: int = 25, streaming: bool = False, **kwargs: Any ) -> Union[str, Iterator[str]]: @@ -715,9 +718,12 @@ class SocketFlowInstance: query: Natural language query user: User/keyspace identifier collection: Collection identifier + entity_limit: Maximum entities to retrieve 
(default: 50) + triple_limit: Maximum triples per entity (default: 30) max_subgraph_size: Maximum total triples in subgraph (default: 1000) - max_subgraph_count: Maximum number of subgraphs (default: 5) - max_entity_distance: Maximum traversal depth (default: 3) + max_path_length: Maximum traversal depth (default: 2) + edge_score_limit: Max edges for semantic pre-filter (default: 30) + edge_limit: Max edges after LLM scoring (default: 25) streaming: Enable streaming mode (default: False) **kwargs: Additional parameters passed to the service @@ -743,9 +749,12 @@ class SocketFlowInstance: "query": query, "user": user, "collection": collection, + "entity-limit": entity_limit, + "triple-limit": triple_limit, "max-subgraph-size": max_subgraph_size, - "max-subgraph-count": max_subgraph_count, - "max-entity-distance": max_entity_distance, + "max-path-length": max_path_length, + "edge-score-limit": edge_score_limit, + "edge-limit": edge_limit, "streaming": streaming } request.update(kwargs) @@ -762,9 +771,12 @@ class SocketFlowInstance: query: str, user: str, collection: str, + entity_limit: int = 50, + triple_limit: int = 30, max_subgraph_size: int = 1000, - max_subgraph_count: int = 5, - max_entity_distance: int = 3, + max_path_length: int = 2, + edge_score_limit: int = 30, + edge_limit: int = 25, **kwargs: Any ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: """ @@ -778,9 +790,12 @@ class SocketFlowInstance: query: Natural language query user: User/keyspace identifier collection: Collection identifier + entity_limit: Maximum entities to retrieve (default: 50) + triple_limit: Maximum triples per entity (default: 30) max_subgraph_size: Maximum total triples in subgraph (default: 1000) - max_subgraph_count: Maximum number of subgraphs (default: 5) - max_entity_distance: Maximum traversal depth (default: 3) + max_path_length: Maximum traversal depth (default: 2) + edge_score_limit: Max edges for semantic pre-filter (default: 30) + edge_limit: Max edges after LLM scoring 
(default: 25) **kwargs: Additional parameters passed to the service Yields: @@ -823,11 +838,14 @@ class SocketFlowInstance: "query": query, "user": user, "collection": collection, + "entity-limit": entity_limit, + "triple-limit": triple_limit, "max-subgraph-size": max_subgraph_size, - "max-subgraph-count": max_subgraph_count, - "max-entity-distance": max_entity_distance, + "max-path-length": max_path_length, + "edge-score-limit": edge_score_limit, + "edge-limit": edge_limit, "streaming": True, - "explainable": True, # Enable explainability mode + "explainable": True, } request.update(kwargs) diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index b7ff818c..98473db2 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -84,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator): triple_limit=int(data.get("triple-limit", 30)), max_subgraph_size=int(data.get("max-subgraph-size", 1000)), max_path_length=int(data.get("max-path-length", 2)), + edge_score_limit=int(data.get("edge-score-limit", 30)), edge_limit=int(data.get("edge-limit", 25)), streaming=data.get("streaming", False) ) @@ -97,6 +98,7 @@ class GraphRagRequestTranslator(MessageTranslator): "triple-limit": obj.triple_limit, "max-subgraph-size": obj.max_subgraph_size, "max-path-length": obj.max_path_length, + "edge-score-limit": obj.edge_score_limit, "edge-limit": obj.edge_limit, "streaming": getattr(obj, "streaming", False) } diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index b3a9d58d..d4f76655 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -15,6 +15,7 @@ class GraphRagQuery: triple_limit: int = 0 max_subgraph_size: int = 0 max_path_length: int = 0 + 
edge_score_limit: int = 0 edge_limit: int = 0 streaming: bool = False diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 1e530c03..76b8b158 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -28,6 +28,8 @@ default_entity_limit = 50 default_triple_limit = 30 default_max_subgraph_size = 150 default_max_path_length = 2 +default_edge_score_limit = 30 +default_edge_limit = 25 # Provenance predicates TG = "https://trustgraph.ai/ns/" @@ -638,7 +640,8 @@ async def _question_explainable( def _question_explainable_api( url, flow_id, question_text, user, collection, entity_limit, triple_limit, - max_subgraph_size, max_path_length, token=None, debug=False + max_subgraph_size, max_path_length, edge_score_limit=30, + edge_limit=25, token=None, debug=False ): """Execute graph RAG with explainability using the new API classes.""" api = Api(url=url, token=token) @@ -652,9 +655,12 @@ def _question_explainable_api( query=question_text, user=user, collection=collection, + entity_limit=entity_limit, + triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, - max_subgraph_count=5, - max_entity_distance=max_path_length, + max_path_length=max_path_length, + edge_score_limit=edge_score_limit, + edge_limit=edge_limit, ): if isinstance(item, RAGChunk): # Print response content @@ -743,7 +749,8 @@ def _question_explainable_api( def question( url, flow_id, question, user, collection, entity_limit, triple_limit, - max_subgraph_size, max_path_length, streaming=True, token=None, + max_subgraph_size, max_path_length, edge_score_limit=50, + edge_limit=25, streaming=True, token=None, explainable=False, debug=False ): @@ -759,6 +766,8 @@ def question( triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, max_path_length=max_path_length, + edge_score_limit=edge_score_limit, + edge_limit=edge_limit, token=token, debug=debug ) @@ -781,6 +790,8 @@ 
def question( triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, max_path_length=max_path_length, + edge_score_limit=edge_score_limit, + edge_limit=edge_limit, streaming=True ) @@ -801,7 +812,9 @@ def question( entity_limit=entity_limit, triple_limit=triple_limit, max_subgraph_size=max_subgraph_size, - max_path_length=max_path_length + max_path_length=max_path_length, + edge_score_limit=edge_score_limit, + edge_limit=edge_limit, ) print(resp) @@ -876,6 +889,20 @@ def main(): help=f'Max path length (default: {default_max_path_length})' ) + parser.add_argument( + '--edge-score-limit', + type=int, + default=default_edge_score_limit, + help=f'Semantic pre-filter limit before LLM scoring (default: {default_edge_score_limit})' + ) + + parser.add_argument( + '--edge-limit', + type=int, + default=default_edge_limit, + help=f'Max edges after LLM scoring (default: {default_edge_limit})' + ) + parser.add_argument( '--no-streaming', action='store_true', @@ -908,6 +935,8 @@ def main(): triple_limit=args.triple_limit, max_subgraph_size=args.max_subgraph_size, max_path_length=args.max_path_length, + edge_score_limit=args.edge_score_limit, + edge_limit=args.edge_limit, streaming=not args.no_streaming, token=args.token, explainable=args.explainable, diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 22d4fc1b..ea9326a4 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -3,6 +3,7 @@ import asyncio import hashlib import json import logging +import math import time import uuid from collections import OrderedDict @@ -550,7 +551,8 @@ class GraphRag: async def query( self, query, user = "trustgraph", collection = "default", entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, - max_path_length = 2, edge_limit = 25, streaming = False, + max_path_length = 2, edge_score_limit = 30, 
edge_limit = 25, + streaming = False, chunk_callback = None, explain_callback = None, save_answer_callback = None, ): @@ -565,6 +567,8 @@ class GraphRag: triple_limit: Max triples per entity max_subgraph_size: Max edges in subgraph max_path_length: Max hops from seed entities + edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter) + edge_limit: Max edges after LLM scoring streaming: Enable streaming LLM response chunk_callback: async def callback(chunk, end_of_stream) for streaming explain_callback: async def callback(triples, explain_id) for real-time explainability @@ -628,6 +632,70 @@ class GraphRag: logger.debug(f"Knowledge graph: {kg}") logger.debug(f"Query: {query}") + # Semantic pre-filter: reduce edges before expensive LLM scoring + if edge_score_limit > 0 and len(kg) > edge_score_limit: + + if self.verbose: + logger.debug( + f"Semantic pre-filter: {len(kg)} edges > " + f"limit {edge_score_limit}, filtering..." + ) + + # Embed edge descriptions: "subject, predicate, object" + edge_descriptions = [ + f"{s}, {p}, {o}" for s, p, o in kg + ] + + # Embed concepts and edge descriptions concurrently + concept_embed_task = self.embeddings_client.embed(concepts) + edge_embed_task = self.embeddings_client.embed(edge_descriptions) + + concept_vectors, edge_vectors = await asyncio.gather( + concept_embed_task, edge_embed_task + ) + + # Score each edge by max cosine similarity to any concept + def cosine_similarity(a, b): + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = math.sqrt(sum(x * x for x in b)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + edge_scores = [] + for i, edge_vec in enumerate(edge_vectors): + max_sim = max( + cosine_similarity(edge_vec, cv) + for cv in concept_vectors + ) + edge_scores.append((max_sim, i)) + + # Sort by similarity descending and keep top edge_score_limit + edge_scores.sort(reverse=True) + keep_indices = set( + idx for _, idx in 
edge_scores[:edge_score_limit] + ) + + # Filter kg and rebuild uri_map + filtered_kg = [] + filtered_uri_map = {} + for i, (s, p, o) in enumerate(kg): + if i in keep_indices: + filtered_kg.append((s, p, o)) + eid = edge_id(s, p, o) + if eid in uri_map: + filtered_uri_map[eid] = uri_map[eid] + + if self.verbose: + logger.debug( + f"Semantic pre-filter kept {len(filtered_kg)} " + f"of {len(kg)} edges" + ) + + kg = filtered_kg + uri_map = filtered_uri_map + # Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)} # uri_map already maps edge_id -> (uri_s, uri_p, uri_o) edge_map = {} diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index ec4a806c..efcc51ef 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -39,6 +39,7 @@ class Processor(FlowProcessor): triple_limit = params.get("triple_limit", 30) max_subgraph_size = params.get("max_subgraph_size", 150) max_path_length = params.get("max_path_length", 2) + edge_score_limit = params.get("edge_score_limit", 30) edge_limit = params.get("edge_limit", 25) super(Processor, self).__init__( @@ -49,6 +50,7 @@ class Processor(FlowProcessor): "triple_limit": triple_limit, "max_subgraph_size": max_subgraph_size, "max_path_length": max_path_length, + "edge_score_limit": edge_score_limit, "edge_limit": edge_limit, } ) @@ -57,6 +59,7 @@ class Processor(FlowProcessor): self.default_triple_limit = triple_limit self.default_max_subgraph_size = max_subgraph_size self.default_max_path_length = max_path_length + self.default_edge_score_limit = edge_score_limit self.default_edge_limit = edge_limit # CRITICAL SECURITY: NEVER share data between users or collections @@ -295,6 +298,11 @@ class Processor(FlowProcessor): else: max_path_length = self.default_max_path_length + if v.edge_score_limit: + edge_score_limit = v.edge_score_limit + else: + edge_score_limit = 
self.default_edge_score_limit + if v.edge_limit: edge_limit = v.edge_limit else: @@ -330,6 +338,7 @@ class Processor(FlowProcessor): entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, + edge_score_limit = edge_score_limit, edge_limit = edge_limit, streaming = True, chunk_callback = send_chunk, @@ -344,6 +353,7 @@ class Processor(FlowProcessor): entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, + edge_score_limit = edge_score_limit, edge_limit = edge_limit, explain_callback = send_explainability, save_answer_callback = save_answer, @@ -432,6 +442,20 @@ class Processor(FlowProcessor): help=f'Default max path length (default: 2)' ) + parser.add_argument( + '--edge-score-limit', + type=int, + default=30, + help=f'Semantic pre-filter limit before LLM scoring (default: 30)' + ) + + parser.add_argument( + '--edge-limit', + type=int, + default=25, + help=f'Max edges after LLM scoring (default: 25)' + ) + # Note: Explainability triples are now stored in the user's collection # with the named graph urn:graph:retrieval (no separate collection needed)