Add semantic pre-filter for GraphRAG edge scoring (#702)

Embed edge descriptions and compute cosine similarity against grounding
concepts to reduce the number of edges sent to expensive LLM scoring.
Controlled by edge_score_limit parameter (default 30), skipped when edge
count is already below the limit.

Also plumbs edge_score_limit and edge_limit parameters end-to-end:
- CLI args (--edge-score-limit, --edge-limit) in both invoke and service
- Socket client: fix parameter mapping to use hyphenated wire-format keys
- Flow API, message translator, gateway all pass through correctly
- Explainable code path (_question_explainable_api) now forwards all params
- Default edge_score_limit changed from 50 to 30 based on typical subgraph
  sizes
This commit is contained in:
cybermaggedon 2026-03-21 20:06:29 +00:00 committed by GitHub
parent bc68738c37
commit 1a7b654bd3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 166 additions and 20 deletions

View file

@ -449,7 +449,7 @@ class FlowInstance:
def graph_rag(
self, query, user="trustgraph", collection="default",
entity_limit=50, triple_limit=30, max_subgraph_size=150,
max_path_length=2,
max_path_length=2, edge_score_limit=30, edge_limit=25,
):
"""
Execute graph-based Retrieval-Augmented Generation (RAG) query.
@ -465,6 +465,8 @@ class FlowInstance:
triple_limit: Maximum triples per entity (default: 30)
max_subgraph_size: Maximum total triples in subgraph (default: 150)
max_path_length: Maximum traversal depth (default: 2)
edge_score_limit: Max edges for semantic pre-filter (default: 30)
edge_limit: Max edges after LLM scoring (default: 25)
Returns:
str: Generated response incorporating graph context
@ -492,6 +494,8 @@ class FlowInstance:
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-path-length": max_path_length,
"edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
}
return self.request(

View file

@ -699,9 +699,12 @@ class SocketFlowInstance:
query: str,
user: str,
collection: str,
entity_limit: int = 50,
triple_limit: int = 30,
max_subgraph_size: int = 1000,
max_subgraph_count: int = 5,
max_entity_distance: int = 3,
max_path_length: int = 2,
edge_score_limit: int = 30,
edge_limit: int = 25,
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
@ -715,9 +718,12 @@ class SocketFlowInstance:
query: Natural language query
user: User/keyspace identifier
collection: Collection identifier
entity_limit: Maximum entities to retrieve (default: 50)
triple_limit: Maximum triples per entity (default: 30)
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
max_subgraph_count: Maximum number of subgraphs (default: 5)
max_entity_distance: Maximum traversal depth (default: 3)
max_path_length: Maximum traversal depth (default: 2)
edge_score_limit: Max edges for semantic pre-filter (default: 30)
edge_limit: Max edges after LLM scoring (default: 25)
streaming: Enable streaming mode (default: False)
**kwargs: Additional parameters passed to the service
@ -743,9 +749,12 @@ class SocketFlowInstance:
"query": query,
"user": user,
"collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-subgraph-count": max_subgraph_count,
"max-entity-distance": max_entity_distance,
"max-path-length": max_path_length,
"edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
"streaming": streaming
}
request.update(kwargs)
@ -762,9 +771,12 @@ class SocketFlowInstance:
query: str,
user: str,
collection: str,
entity_limit: int = 50,
triple_limit: int = 30,
max_subgraph_size: int = 1000,
max_subgraph_count: int = 5,
max_entity_distance: int = 3,
max_path_length: int = 2,
edge_score_limit: int = 30,
edge_limit: int = 25,
**kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""
@ -778,9 +790,12 @@ class SocketFlowInstance:
query: Natural language query
user: User/keyspace identifier
collection: Collection identifier
entity_limit: Maximum entities to retrieve (default: 50)
triple_limit: Maximum triples per entity (default: 30)
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
max_subgraph_count: Maximum number of subgraphs (default: 5)
max_entity_distance: Maximum traversal depth (default: 3)
max_path_length: Maximum traversal depth (default: 2)
edge_score_limit: Max edges for semantic pre-filter (default: 30)
edge_limit: Max edges after LLM scoring (default: 25)
**kwargs: Additional parameters passed to the service
Yields:
@ -823,11 +838,14 @@ class SocketFlowInstance:
"query": query,
"user": user,
"collection": collection,
"entity-limit": entity_limit,
"triple-limit": triple_limit,
"max-subgraph-size": max_subgraph_size,
"max-subgraph-count": max_subgraph_count,
"max-entity-distance": max_entity_distance,
"max-path-length": max_path_length,
"edge-score-limit": edge_score_limit,
"edge-limit": edge_limit,
"streaming": True,
"explainable": True, # Enable explainability mode
"explainable": True,
}
request.update(kwargs)

View file

@ -84,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator):
triple_limit=int(data.get("triple-limit", 30)),
max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
max_path_length=int(data.get("max-path-length", 2)),
edge_score_limit=int(data.get("edge-score-limit", 30)),
edge_limit=int(data.get("edge-limit", 25)),
streaming=data.get("streaming", False)
)
@ -97,6 +98,7 @@ class GraphRagRequestTranslator(MessageTranslator):
"triple-limit": obj.triple_limit,
"max-subgraph-size": obj.max_subgraph_size,
"max-path-length": obj.max_path_length,
"edge-score-limit": obj.edge_score_limit,
"edge-limit": obj.edge_limit,
"streaming": getattr(obj, "streaming", False)
}

View file

@ -15,6 +15,7 @@ class GraphRagQuery:
triple_limit: int = 0
max_subgraph_size: int = 0
max_path_length: int = 0
edge_score_limit: int = 0
edge_limit: int = 0
streaming: bool = False

View file

@ -28,6 +28,8 @@ default_entity_limit = 50
default_triple_limit = 30
default_max_subgraph_size = 150
default_max_path_length = 2
default_edge_score_limit = 30
default_edge_limit = 25
# Provenance predicates
TG = "https://trustgraph.ai/ns/"
@ -638,7 +640,8 @@ async def _question_explainable(
def _question_explainable_api(
url, flow_id, question_text, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, token=None, debug=False
max_subgraph_size, max_path_length, edge_score_limit=30,
edge_limit=25, token=None, debug=False
):
"""Execute graph RAG with explainability using the new API classes."""
api = Api(url=url, token=token)
@ -652,9 +655,12 @@ def _question_explainable_api(
query=question_text,
user=user,
collection=collection,
entity_limit=entity_limit,
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_subgraph_count=5,
max_entity_distance=max_path_length,
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
):
if isinstance(item, RAGChunk):
# Print response content
@ -743,7 +749,8 @@ def _question_explainable_api(
def question(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, streaming=True, token=None,
max_subgraph_size, max_path_length, edge_score_limit=50,
edge_limit=25, streaming=True, token=None,
explainable=False, debug=False
):
@ -759,6 +766,8 @@ def question(
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
token=token,
debug=debug
)
@ -781,6 +790,8 @@ def question(
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
streaming=True
)
@ -801,7 +812,9 @@ def question(
entity_limit=entity_limit,
triple_limit=triple_limit,
max_subgraph_size=max_subgraph_size,
max_path_length=max_path_length
max_path_length=max_path_length,
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
)
print(resp)
@ -876,6 +889,20 @@ def main():
help=f'Max path length (default: {default_max_path_length})'
)
parser.add_argument(
'--edge-score-limit',
type=int,
default=default_edge_score_limit,
help=f'Semantic pre-filter limit before LLM scoring (default: {default_edge_score_limit})'
)
parser.add_argument(
'--edge-limit',
type=int,
default=default_edge_limit,
help=f'Max edges after LLM scoring (default: {default_edge_limit})'
)
parser.add_argument(
'--no-streaming',
action='store_true',
@ -908,6 +935,8 @@ def main():
triple_limit=args.triple_limit,
max_subgraph_size=args.max_subgraph_size,
max_path_length=args.max_path_length,
edge_score_limit=args.edge_score_limit,
edge_limit=args.edge_limit,
streaming=not args.no_streaming,
token=args.token,
explainable=args.explainable,

View file

@ -3,6 +3,7 @@ import asyncio
import hashlib
import json
import logging
import math
import time
import uuid
from collections import OrderedDict
@ -550,7 +551,8 @@ class GraphRag:
async def query(
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, edge_limit = 25, streaming = False,
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
streaming = False,
chunk_callback = None,
explain_callback = None, save_answer_callback = None,
):
@ -565,6 +567,8 @@ class GraphRag:
triple_limit: Max triples per entity
max_subgraph_size: Max edges in subgraph
max_path_length: Max hops from seed entities
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
edge_limit: Max edges after LLM scoring
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for real-time explainability
@ -628,6 +632,70 @@ class GraphRag:
logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}")
# Semantic pre-filter: reduce edges before expensive LLM scoring
if edge_score_limit > 0 and len(kg) > edge_score_limit:
if self.verbose:
logger.debug(
f"Semantic pre-filter: {len(kg)} edges > "
f"limit {edge_score_limit}, filtering..."
)
# Embed edge descriptions: "subject, predicate, object"
edge_descriptions = [
f"{s}, {p}, {o}" for s, p, o in kg
]
# Embed concepts and edge descriptions concurrently
concept_embed_task = self.embeddings_client.embed(concepts)
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
concept_vectors, edge_vectors = await asyncio.gather(
concept_embed_task, edge_embed_task
)
# Score each edge by max cosine similarity to any concept
def cosine_similarity(a, b):
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
edge_scores = []
for i, edge_vec in enumerate(edge_vectors):
max_sim = max(
cosine_similarity(edge_vec, cv)
for cv in concept_vectors
)
edge_scores.append((max_sim, i))
# Sort by similarity descending and keep top edge_score_limit
edge_scores.sort(reverse=True)
keep_indices = set(
idx for _, idx in edge_scores[:edge_score_limit]
)
# Filter kg and rebuild uri_map
filtered_kg = []
filtered_uri_map = {}
for i, (s, p, o) in enumerate(kg):
if i in keep_indices:
filtered_kg.append((s, p, o))
eid = edge_id(s, p, o)
if eid in uri_map:
filtered_uri_map[eid] = uri_map[eid]
if self.verbose:
logger.debug(
f"Semantic pre-filter kept {len(filtered_kg)} "
f"of {len(kg)} edges"
)
kg = filtered_kg
uri_map = filtered_uri_map
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
edge_map = {}

View file

@ -39,6 +39,7 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
edge_score_limit = params.get("edge_score_limit", 30)
edge_limit = params.get("edge_limit", 25)
super(Processor, self).__init__(
@ -49,6 +50,7 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"edge_score_limit": edge_score_limit,
"edge_limit": edge_limit,
}
)
@ -57,6 +59,7 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.default_edge_score_limit = edge_score_limit
self.default_edge_limit = edge_limit
# CRITICAL SECURITY: NEVER share data between users or collections
@ -295,6 +298,11 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
if v.edge_score_limit:
edge_score_limit = v.edge_score_limit
else:
edge_score_limit = self.default_edge_score_limit
if v.edge_limit:
edge_limit = v.edge_limit
else:
@ -330,6 +338,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit,
streaming = True,
chunk_callback = send_chunk,
@ -344,6 +353,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit,
explain_callback = send_explainability,
save_answer_callback = save_answer,
@ -432,6 +442,20 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)'
)
parser.add_argument(
'--edge-score-limit',
type=int,
default=30,
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
)
parser.add_argument(
'--edge-limit',
type=int,
default=25,
help=f'Max edges after LLM scoring (default: 25)'
)
# Note: Explainability triples are now stored in the user's collection
# with the named graph urn:graph:retrieval (no separate collection needed)