Add semantic pre-filter for GraphRAG edge scoring (#702)

Embed edge descriptions and compute cosine similarity against grounding
concepts to reduce the number of edges sent to expensive LLM scoring.
Controlled by edge_score_limit parameter (default 30), skipped when edge
count is already below the limit.

Also plumbs edge_score_limit and edge_limit parameters end-to-end:
- CLI args (--edge-score-limit, --edge-limit) in both invoke and service
- Socket client: fix parameter mapping to use hyphenated wire-format keys
- Flow API, message translator, gateway all pass through correctly
- Explainable code path (_question_explainable_api) now forwards all params
- Default edge_score_limit changed from 50 to 30 based on typical subgraph
  sizes
This commit is contained in:
cybermaggedon 2026-03-21 20:06:29 +00:00 committed by GitHub
parent bc68738c37
commit 1a7b654bd3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 166 additions and 20 deletions

View file

@ -449,7 +449,7 @@ class FlowInstance:
def graph_rag( def graph_rag(
self, query, user="trustgraph", collection="default", self, query, user="trustgraph", collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=150, entity_limit=50, triple_limit=30, max_subgraph_size=150,
max_path_length=2, max_path_length=2, edge_score_limit=30, edge_limit=25,
): ):
""" """
Execute graph-based Retrieval-Augmented Generation (RAG) query. Execute graph-based Retrieval-Augmented Generation (RAG) query.
@ -465,6 +465,8 @@ class FlowInstance:
triple_limit: Maximum triples per entity (default: 30) triple_limit: Maximum triples per entity (default: 30)
max_subgraph_size: Maximum total triples in subgraph (default: 150) max_subgraph_size: Maximum total triples in subgraph (default: 150)
max_path_length: Maximum traversal depth (default: 2) max_path_length: Maximum traversal depth (default: 2)
edge_score_limit: Max edges for semantic pre-filter (default: 30)
edge_limit: Max edges after LLM scoring (default: 25)
Returns: Returns:
str: Generated response incorporating graph context str: Generated response incorporating graph context
@ -492,6 +494,8 @@ class FlowInstance:
"triple-limit": triple_limit, "triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size, "max-subgraph-size": max_subgraph_size,
"max-path-length": max_path_length, "max-path-length": max_path_length,
"edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
} }
return self.request( return self.request(

View file

@ -699,9 +699,12 @@ class SocketFlowInstance:
query: str, query: str,
user: str, user: str,
collection: str, collection: str,
entity_limit: int = 50,
triple_limit: int = 30,
max_subgraph_size: int = 1000, max_subgraph_size: int = 1000,
max_subgraph_count: int = 5, max_path_length: int = 2,
max_entity_distance: int = 3, edge_score_limit: int = 30,
edge_limit: int = 25,
streaming: bool = False, streaming: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[str, Iterator[str]]: ) -> Union[str, Iterator[str]]:
@ -715,9 +718,12 @@ class SocketFlowInstance:
query: Natural language query query: Natural language query
user: User/keyspace identifier user: User/keyspace identifier
collection: Collection identifier collection: Collection identifier
entity_limit: Maximum entities to retrieve (default: 50)
triple_limit: Maximum triples per entity (default: 30)
max_subgraph_size: Maximum total triples in subgraph (default: 1000) max_subgraph_size: Maximum total triples in subgraph (default: 1000)
max_subgraph_count: Maximum number of subgraphs (default: 5) max_path_length: Maximum traversal depth (default: 2)
max_entity_distance: Maximum traversal depth (default: 3) edge_score_limit: Max edges for semantic pre-filter (default: 30)
edge_limit: Max edges after LLM scoring (default: 25)
streaming: Enable streaming mode (default: False) streaming: Enable streaming mode (default: False)
**kwargs: Additional parameters passed to the service **kwargs: Additional parameters passed to the service
@ -743,9 +749,12 @@ class SocketFlowInstance:
"query": query, "query": query,
"user": user, "user": user,
"collection": collection, "collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size, "max-subgraph-size": max_subgraph_size,
"max-subgraph-count": max_subgraph_count, "max-path-length": max_path_length,
"max-entity-distance": max_entity_distance, "edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
"streaming": streaming "streaming": streaming
} }
request.update(kwargs) request.update(kwargs)
@ -762,9 +771,12 @@ class SocketFlowInstance:
query: str, query: str,
user: str, user: str,
collection: str, collection: str,
entity_limit: int = 50,
triple_limit: int = 30,
max_subgraph_size: int = 1000, max_subgraph_size: int = 1000,
max_subgraph_count: int = 5, max_path_length: int = 2,
max_entity_distance: int = 3, edge_score_limit: int = 30,
edge_limit: int = 25,
**kwargs: Any **kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
""" """
@ -778,9 +790,12 @@ class SocketFlowInstance:
query: Natural language query query: Natural language query
user: User/keyspace identifier user: User/keyspace identifier
collection: Collection identifier collection: Collection identifier
entity_limit: Maximum entities to retrieve (default: 50)
triple_limit: Maximum triples per entity (default: 30)
max_subgraph_size: Maximum total triples in subgraph (default: 1000) max_subgraph_size: Maximum total triples in subgraph (default: 1000)
max_subgraph_count: Maximum number of subgraphs (default: 5) max_path_length: Maximum traversal depth (default: 2)
max_entity_distance: Maximum traversal depth (default: 3) edge_score_limit: Max edges for semantic pre-filter (default: 30)
edge_limit: Max edges after LLM scoring (default: 25)
**kwargs: Additional parameters passed to the service **kwargs: Additional parameters passed to the service
Yields: Yields:
@ -823,11 +838,14 @@ class SocketFlowInstance:
"query": query, "query": query,
"user": user, "user": user,
"collection": collection, "collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size, "max-subgraph-size": max_subgraph_size,
"max-subgraph-count": max_subgraph_count, "max-path-length": max_path_length,
"max-entity-distance": max_entity_distance, "edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
"streaming": True, "streaming": True,
"explainable": True, # Enable explainability mode "explainable": True,
} }
request.update(kwargs) request.update(kwargs)

View file

@ -84,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator):
triple_limit=int(data.get("triple-limit", 30)), triple_limit=int(data.get("triple-limit", 30)),
max_subgraph_size=int(data.get("max-subgraph-size", 1000)), max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
max_path_length=int(data.get("max-path-length", 2)), max_path_length=int(data.get("max-path-length", 2)),
edge_score_limit=int(data.get("edge-score-limit", 30)),
edge_limit=int(data.get("edge-limit", 25)), edge_limit=int(data.get("edge-limit", 25)),
streaming=data.get("streaming", False) streaming=data.get("streaming", False)
) )
@ -97,6 +98,7 @@ class GraphRagRequestTranslator(MessageTranslator):
"triple-limit": obj.triple_limit, "triple-limit": obj.triple_limit,
"max-subgraph-size": obj.max_subgraph_size, "max-subgraph-size": obj.max_subgraph_size,
"max-path-length": obj.max_path_length, "max-path-length": obj.max_path_length,
"edge-score-limit": obj.edge_score_limit,
"edge-limit": obj.edge_limit, "edge-limit": obj.edge_limit,
"streaming": getattr(obj, "streaming", False) "streaming": getattr(obj, "streaming", False)
} }

View file

@ -15,6 +15,7 @@ class GraphRagQuery:
triple_limit: int = 0 triple_limit: int = 0
max_subgraph_size: int = 0 max_subgraph_size: int = 0
max_path_length: int = 0 max_path_length: int = 0
edge_score_limit: int = 0
edge_limit: int = 0 edge_limit: int = 0
streaming: bool = False streaming: bool = False

View file

@ -28,6 +28,8 @@ default_entity_limit = 50
default_triple_limit = 30 default_triple_limit = 30
default_max_subgraph_size = 150 default_max_subgraph_size = 150
default_max_path_length = 2 default_max_path_length = 2
default_edge_score_limit = 30
default_edge_limit = 25
# Provenance predicates # Provenance predicates
TG = "https://trustgraph.ai/ns/" TG = "https://trustgraph.ai/ns/"
@ -638,7 +640,8 @@ async def _question_explainable(
def _question_explainable_api( def _question_explainable_api(
url, flow_id, question_text, user, collection, entity_limit, triple_limit, url, flow_id, question_text, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, token=None, debug=False max_subgraph_size, max_path_length, edge_score_limit=30,
edge_limit=25, token=None, debug=False
): ):
"""Execute graph RAG with explainability using the new API classes.""" """Execute graph RAG with explainability using the new API classes."""
api = Api(url=url, token=token) api = Api(url=url, token=token)
@ -652,9 +655,12 @@ def _question_explainable_api(
query=question_text, query=question_text,
user=user, user=user,
collection=collection, collection=collection,
entity_limit=entity_limit,
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size, max_subgraph_size=max_subgraph_size,
max_subgraph_count=5, max_path_length=max_path_length,
max_entity_distance=max_path_length, edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
): ):
if isinstance(item, RAGChunk): if isinstance(item, RAGChunk):
# Print response content # Print response content
@ -743,7 +749,8 @@ def _question_explainable_api(
def question( def question(
url, flow_id, question, user, collection, entity_limit, triple_limit, url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, streaming=True, token=None, max_subgraph_size, max_path_length, edge_score_limit=30,
edge_limit=25, streaming=True, token=None,
explainable=False, debug=False explainable=False, debug=False
): ):
@ -759,6 +766,8 @@ def question(
triple_limit=triple_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size, max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length, max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
token=token, token=token,
debug=debug debug=debug
) )
@ -781,6 +790,8 @@ def question(
triple_limit=triple_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size, max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length, max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
streaming=True streaming=True
) )
@ -801,7 +812,9 @@ def question(
entity_limit=entity_limit, entity_limit=entity_limit,
triple_limit=triple_limit, triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size, max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
) )
print(resp) print(resp)
@ -876,6 +889,20 @@ def main():
help=f'Max path length (default: {default_max_path_length})' help=f'Max path length (default: {default_max_path_length})'
) )
parser.add_argument(
'--edge-score-limit',
type=int,
default=default_edge_score_limit,
help=f'Semantic pre-filter limit before LLM scoring (default: {default_edge_score_limit})'
)
parser.add_argument(
'--edge-limit',
type=int,
default=default_edge_limit,
help=f'Max edges after LLM scoring (default: {default_edge_limit})'
)
parser.add_argument( parser.add_argument(
'--no-streaming', '--no-streaming',
action='store_true', action='store_true',
@ -908,6 +935,8 @@ def main():
triple_limit=args.triple_limit, triple_limit=args.triple_limit,
max_subgraph_size=args.max_subgraph_size, max_subgraph_size=args.max_subgraph_size,
max_path_length=args.max_path_length, max_path_length=args.max_path_length,
edge_score_limit=args.edge_score_limit,
edge_limit=args.edge_limit,
streaming=not args.no_streaming, streaming=not args.no_streaming,
token=args.token, token=args.token,
explainable=args.explainable, explainable=args.explainable,

View file

@ -3,6 +3,7 @@ import asyncio
import hashlib import hashlib
import json import json
import logging import logging
import math
import time import time
import uuid import uuid
from collections import OrderedDict from collections import OrderedDict
@ -550,7 +551,8 @@ class GraphRag:
async def query( async def query(
self, query, user = "trustgraph", collection = "default", self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, edge_limit = 25, streaming = False, max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
streaming = False,
chunk_callback = None, chunk_callback = None,
explain_callback = None, save_answer_callback = None, explain_callback = None, save_answer_callback = None,
): ):
@ -565,6 +567,8 @@ class GraphRag:
triple_limit: Max triples per entity triple_limit: Max triples per entity
max_subgraph_size: Max edges in subgraph max_subgraph_size: Max edges in subgraph
max_path_length: Max hops from seed entities max_path_length: Max hops from seed entities
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
edge_limit: Max edges after LLM scoring
streaming: Enable streaming LLM response streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for real-time explainability explain_callback: async def callback(triples, explain_id) for real-time explainability
@ -628,6 +632,70 @@ class GraphRag:
logger.debug(f"Knowledge graph: {kg}") logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}") logger.debug(f"Query: {query}")
# Semantic pre-filter: reduce edges before expensive LLM scoring
if edge_score_limit > 0 and len(kg) > edge_score_limit:
if self.verbose:
logger.debug(
f"Semantic pre-filter: {len(kg)} edges > "
f"limit {edge_score_limit}, filtering..."
)
# Embed edge descriptions: "subject, predicate, object"
edge_descriptions = [
f"{s}, {p}, {o}" for s, p, o in kg
]
# Embed concepts and edge descriptions concurrently
concept_embed_task = self.embeddings_client.embed(concepts)
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
concept_vectors, edge_vectors = await asyncio.gather(
concept_embed_task, edge_embed_task
)
# Score each edge by max cosine similarity to any concept
def cosine_similarity(a, b):
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
edge_scores = []
for i, edge_vec in enumerate(edge_vectors):
max_sim = max(
cosine_similarity(edge_vec, cv)
for cv in concept_vectors
)
edge_scores.append((max_sim, i))
# Sort by similarity descending and keep top edge_score_limit
edge_scores.sort(reverse=True)
keep_indices = set(
idx for _, idx in edge_scores[:edge_score_limit]
)
# Filter kg and rebuild uri_map
filtered_kg = []
filtered_uri_map = {}
for i, (s, p, o) in enumerate(kg):
if i in keep_indices:
filtered_kg.append((s, p, o))
eid = edge_id(s, p, o)
if eid in uri_map:
filtered_uri_map[eid] = uri_map[eid]
if self.verbose:
logger.debug(
f"Semantic pre-filter kept {len(filtered_kg)} "
f"of {len(kg)} edges"
)
kg = filtered_kg
uri_map = filtered_uri_map
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)} # Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o) # uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
edge_map = {} edge_map = {}

View file

@ -39,6 +39,7 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30) triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150) max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2) max_path_length = params.get("max_path_length", 2)
edge_score_limit = params.get("edge_score_limit", 30)
edge_limit = params.get("edge_limit", 25) edge_limit = params.get("edge_limit", 25)
super(Processor, self).__init__( super(Processor, self).__init__(
@ -49,6 +50,7 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit, "triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size, "max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length, "max_path_length": max_path_length,
"edge_score_limit": edge_score_limit,
"edge_limit": edge_limit, "edge_limit": edge_limit,
} }
) )
@ -57,6 +59,7 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length self.default_max_path_length = max_path_length
self.default_edge_score_limit = edge_score_limit
self.default_edge_limit = edge_limit self.default_edge_limit = edge_limit
# CRITICAL SECURITY: NEVER share data between users or collections # CRITICAL SECURITY: NEVER share data between users or collections
@ -295,6 +298,11 @@ class Processor(FlowProcessor):
else: else:
max_path_length = self.default_max_path_length max_path_length = self.default_max_path_length
if v.edge_score_limit:
edge_score_limit = v.edge_score_limit
else:
edge_score_limit = self.default_edge_score_limit
if v.edge_limit: if v.edge_limit:
edge_limit = v.edge_limit edge_limit = v.edge_limit
else: else:
@ -330,6 +338,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit, entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size, max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length, max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit, edge_limit = edge_limit,
streaming = True, streaming = True,
chunk_callback = send_chunk, chunk_callback = send_chunk,
@ -344,6 +353,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit, entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size, max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length, max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit, edge_limit = edge_limit,
explain_callback = send_explainability, explain_callback = send_explainability,
save_answer_callback = save_answer, save_answer_callback = save_answer,
@ -432,6 +442,20 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)' help=f'Default max path length (default: 2)'
) )
parser.add_argument(
'--edge-score-limit',
type=int,
default=30,
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
)
parser.add_argument(
'--edge-limit',
type=int,
default=25,
help=f'Max edges after LLM scoring (default: 25)'
)
# Note: Explainability triples are now stored in the user's collection # Note: Explainability triples are now stored in the user's collection
# with the named graph urn:graph:retrieval (no separate collection needed) # with the named graph urn:graph:retrieval (no separate collection needed)