mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Add semantic pre-filter for GraphRAG edge scoring (#702)
Embed edge descriptions and compute cosine similarity against grounding concepts to reduce the number of edges sent to expensive LLM scoring. Controlled by edge_score_limit parameter (default 30), skipped when edge count is already below the limit. Also plumbs edge_score_limit and edge_limit parameters end-to-end: - CLI args (--edge-score-limit, --edge-limit) in both invoke and service - Socket client: fix parameter mapping to use hyphenated wire-format keys - Flow API, message translator, gateway all pass through correctly - Explainable code path (_question_explainable_api) now forwards all params - Default edge_score_limit changed from 50 to 30 based on typical subgraph sizes
This commit is contained in:
parent
bc68738c37
commit
1a7b654bd3
7 changed files with 166 additions and 20 deletions
|
|
@ -449,7 +449,7 @@ class FlowInstance:
|
|||
def graph_rag(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
entity_limit=50, triple_limit=30, max_subgraph_size=150,
|
||||
max_path_length=2,
|
||||
max_path_length=2, edge_score_limit=30, edge_limit=25,
|
||||
):
|
||||
"""
|
||||
Execute graph-based Retrieval-Augmented Generation (RAG) query.
|
||||
|
|
@ -465,6 +465,8 @@ class FlowInstance:
|
|||
triple_limit: Maximum triples per entity (default: 30)
|
||||
max_subgraph_size: Maximum total triples in subgraph (default: 150)
|
||||
max_path_length: Maximum traversal depth (default: 2)
|
||||
edge_score_limit: Max edges for semantic pre-filter (default: 30)
|
||||
edge_limit: Max edges after LLM scoring (default: 25)
|
||||
|
||||
Returns:
|
||||
str: Generated response incorporating graph context
|
||||
|
|
@ -492,6 +494,8 @@ class FlowInstance:
|
|||
"triple-limit": triple_limit,
|
||||
"max-subgraph-size": max_subgraph_size,
|
||||
"max-path-length": max_path_length,
|
||||
"edge-score-limit": edge_score_limit,
|
||||
"edge-limit": edge_limit,
|
||||
}
|
||||
|
||||
return self.request(
|
||||
|
|
|
|||
|
|
@ -699,9 +699,12 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
user: str,
|
||||
collection: str,
|
||||
entity_limit: int = 50,
|
||||
triple_limit: int = 30,
|
||||
max_subgraph_size: int = 1000,
|
||||
max_subgraph_count: int = 5,
|
||||
max_entity_distance: int = 3,
|
||||
max_path_length: int = 2,
|
||||
edge_score_limit: int = 30,
|
||||
edge_limit: int = 25,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[str, Iterator[str]]:
|
||||
|
|
@ -715,9 +718,12 @@ class SocketFlowInstance:
|
|||
query: Natural language query
|
||||
user: User/keyspace identifier
|
||||
collection: Collection identifier
|
||||
entity_limit: Maximum entities to retrieve (default: 50)
|
||||
triple_limit: Maximum triples per entity (default: 30)
|
||||
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
|
||||
max_subgraph_count: Maximum number of subgraphs (default: 5)
|
||||
max_entity_distance: Maximum traversal depth (default: 3)
|
||||
max_path_length: Maximum traversal depth (default: 2)
|
||||
edge_score_limit: Max edges for semantic pre-filter (default: 30)
|
||||
edge_limit: Max edges after LLM scoring (default: 25)
|
||||
streaming: Enable streaming mode (default: False)
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
|
|
@ -743,9 +749,12 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"entity-limit": entity_limit,
|
||||
"triple-limit": triple_limit,
|
||||
"max-subgraph-size": max_subgraph_size,
|
||||
"max-subgraph-count": max_subgraph_count,
|
||||
"max-entity-distance": max_entity_distance,
|
||||
"max-path-length": max_path_length,
|
||||
"edge-score-limit": edge_score_limit,
|
||||
"edge-limit": edge_limit,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
@ -762,9 +771,12 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
user: str,
|
||||
collection: str,
|
||||
entity_limit: int = 50,
|
||||
triple_limit: int = 30,
|
||||
max_subgraph_size: int = 1000,
|
||||
max_subgraph_count: int = 5,
|
||||
max_entity_distance: int = 3,
|
||||
max_path_length: int = 2,
|
||||
edge_score_limit: int = 30,
|
||||
edge_limit: int = 25,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
||||
"""
|
||||
|
|
@ -778,9 +790,12 @@ class SocketFlowInstance:
|
|||
query: Natural language query
|
||||
user: User/keyspace identifier
|
||||
collection: Collection identifier
|
||||
entity_limit: Maximum entities to retrieve (default: 50)
|
||||
triple_limit: Maximum triples per entity (default: 30)
|
||||
max_subgraph_size: Maximum total triples in subgraph (default: 1000)
|
||||
max_subgraph_count: Maximum number of subgraphs (default: 5)
|
||||
max_entity_distance: Maximum traversal depth (default: 3)
|
||||
max_path_length: Maximum traversal depth (default: 2)
|
||||
edge_score_limit: Max edges for semantic pre-filter (default: 30)
|
||||
edge_limit: Max edges after LLM scoring (default: 25)
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
Yields:
|
||||
|
|
@ -823,11 +838,14 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"user": user,
|
||||
"collection": collection,
|
||||
"entity-limit": entity_limit,
|
||||
"triple-limit": triple_limit,
|
||||
"max-subgraph-size": max_subgraph_size,
|
||||
"max-subgraph-count": max_subgraph_count,
|
||||
"max-entity-distance": max_entity_distance,
|
||||
"max-path-length": max_path_length,
|
||||
"edge-score-limit": edge_score_limit,
|
||||
"edge-limit": edge_limit,
|
||||
"streaming": True,
|
||||
"explainable": True, # Enable explainability mode
|
||||
"explainable": True,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
triple_limit=int(data.get("triple-limit", 30)),
|
||||
max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
|
||||
max_path_length=int(data.get("max-path-length", 2)),
|
||||
edge_score_limit=int(data.get("edge-score-limit", 30)),
|
||||
edge_limit=int(data.get("edge-limit", 25)),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
|
@ -97,6 +98,7 @@ class GraphRagRequestTranslator(MessageTranslator):
|
|||
"triple-limit": obj.triple_limit,
|
||||
"max-subgraph-size": obj.max_subgraph_size,
|
||||
"max-path-length": obj.max_path_length,
|
||||
"edge-score-limit": obj.edge_score_limit,
|
||||
"edge-limit": obj.edge_limit,
|
||||
"streaming": getattr(obj, "streaming", False)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ class GraphRagQuery:
|
|||
triple_limit: int = 0
|
||||
max_subgraph_size: int = 0
|
||||
max_path_length: int = 0
|
||||
edge_score_limit: int = 0
|
||||
edge_limit: int = 0
|
||||
streaming: bool = False
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ default_entity_limit = 50
|
|||
default_triple_limit = 30
|
||||
default_max_subgraph_size = 150
|
||||
default_max_path_length = 2
|
||||
default_edge_score_limit = 30
|
||||
default_edge_limit = 25
|
||||
|
||||
# Provenance predicates
|
||||
TG = "https://trustgraph.ai/ns/"
|
||||
|
|
@ -638,7 +640,8 @@ async def _question_explainable(
|
|||
|
||||
def _question_explainable_api(
|
||||
url, flow_id, question_text, user, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, token=None, debug=False
|
||||
max_subgraph_size, max_path_length, edge_score_limit=30,
|
||||
edge_limit=25, token=None, debug=False
|
||||
):
|
||||
"""Execute graph RAG with explainability using the new API classes."""
|
||||
api = Api(url=url, token=token)
|
||||
|
|
@ -652,9 +655,12 @@ def _question_explainable_api(
|
|||
query=question_text,
|
||||
user=user,
|
||||
collection=collection,
|
||||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_subgraph_count=5,
|
||||
max_entity_distance=max_path_length,
|
||||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
# Print response content
|
||||
|
|
@ -743,7 +749,8 @@ def _question_explainable_api(
|
|||
|
||||
def question(
|
||||
url, flow_id, question, user, collection, entity_limit, triple_limit,
|
||||
max_subgraph_size, max_path_length, streaming=True, token=None,
|
||||
max_subgraph_size, max_path_length, edge_score_limit=30,
|
||||
edge_limit=25, streaming=True, token=None,
|
||||
explainable=False, debug=False
|
||||
):
|
||||
|
||||
|
|
@ -759,6 +766,8 @@ def question(
|
|||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
token=token,
|
||||
debug=debug
|
||||
)
|
||||
|
|
@ -781,6 +790,8 @@ def question(
|
|||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
|
|
@ -801,7 +812,9 @@ def question(
|
|||
entity_limit=entity_limit,
|
||||
triple_limit=triple_limit,
|
||||
max_subgraph_size=max_subgraph_size,
|
||||
max_path_length=max_path_length
|
||||
max_path_length=max_path_length,
|
||||
edge_score_limit=edge_score_limit,
|
||||
edge_limit=edge_limit,
|
||||
)
|
||||
print(resp)
|
||||
|
||||
|
|
@ -876,6 +889,20 @@ def main():
|
|||
help=f'Max path length (default: {default_max_path_length})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-score-limit',
|
||||
type=int,
|
||||
default=default_edge_score_limit,
|
||||
help=f'Semantic pre-filter limit before LLM scoring (default: {default_edge_score_limit})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-limit',
|
||||
type=int,
|
||||
default=default_edge_limit,
|
||||
help=f'Max edges after LLM scoring (default: {default_edge_limit})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--no-streaming',
|
||||
action='store_true',
|
||||
|
|
@ -908,6 +935,8 @@ def main():
|
|||
triple_limit=args.triple_limit,
|
||||
max_subgraph_size=args.max_subgraph_size,
|
||||
max_path_length=args.max_path_length,
|
||||
edge_score_limit=args.edge_score_limit,
|
||||
edge_limit=args.edge_limit,
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import asyncio
|
|||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
|
|
@ -550,7 +551,8 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, user = "trustgraph", collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, edge_limit = 25, streaming = False,
|
||||
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
|
||||
streaming = False,
|
||||
chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
):
|
||||
|
|
@ -565,6 +567,8 @@ class GraphRag:
|
|||
triple_limit: Max triples per entity
|
||||
max_subgraph_size: Max edges in subgraph
|
||||
max_path_length: Max hops from seed entities
|
||||
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
|
||||
edge_limit: Max edges after LLM scoring
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for real-time explainability
|
||||
|
|
@ -628,6 +632,70 @@ class GraphRag:
|
|||
logger.debug(f"Knowledge graph: {kg}")
|
||||
logger.debug(f"Query: {query}")
|
||||
|
||||
# Semantic pre-filter: reduce edges before expensive LLM scoring
|
||||
if edge_score_limit > 0 and len(kg) > edge_score_limit:
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Semantic pre-filter: {len(kg)} edges > "
|
||||
f"limit {edge_score_limit}, filtering..."
|
||||
)
|
||||
|
||||
# Embed edge descriptions: "subject, predicate, object"
|
||||
edge_descriptions = [
|
||||
f"{s}, {p}, {o}" for s, p, o in kg
|
||||
]
|
||||
|
||||
# Embed concepts and edge descriptions concurrently
|
||||
concept_embed_task = self.embeddings_client.embed(concepts)
|
||||
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
|
||||
|
||||
concept_vectors, edge_vectors = await asyncio.gather(
|
||||
concept_embed_task, edge_embed_task
|
||||
)
|
||||
|
||||
# Score each edge by max cosine similarity to any concept
|
||||
def cosine_similarity(a, b):
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = math.sqrt(sum(x * x for x in a))
|
||||
norm_b = math.sqrt(sum(x * x for x in b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
edge_scores = []
|
||||
for i, edge_vec in enumerate(edge_vectors):
|
||||
max_sim = max(
|
||||
cosine_similarity(edge_vec, cv)
|
||||
for cv in concept_vectors
|
||||
)
|
||||
edge_scores.append((max_sim, i))
|
||||
|
||||
# Sort by similarity descending and keep top edge_score_limit
|
||||
edge_scores.sort(reverse=True)
|
||||
keep_indices = set(
|
||||
idx for _, idx in edge_scores[:edge_score_limit]
|
||||
)
|
||||
|
||||
# Filter kg and rebuild uri_map
|
||||
filtered_kg = []
|
||||
filtered_uri_map = {}
|
||||
for i, (s, p, o) in enumerate(kg):
|
||||
if i in keep_indices:
|
||||
filtered_kg.append((s, p, o))
|
||||
eid = edge_id(s, p, o)
|
||||
if eid in uri_map:
|
||||
filtered_uri_map[eid] = uri_map[eid]
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Semantic pre-filter kept {len(filtered_kg)} "
|
||||
f"of {len(kg)} edges"
|
||||
)
|
||||
|
||||
kg = filtered_kg
|
||||
uri_map = filtered_uri_map
|
||||
|
||||
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
|
||||
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
|
||||
edge_map = {}
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ class Processor(FlowProcessor):
|
|||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
edge_score_limit = params.get("edge_score_limit", 30)
|
||||
edge_limit = params.get("edge_limit", 25)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
|
|
@ -49,6 +50,7 @@ class Processor(FlowProcessor):
|
|||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"edge_score_limit": edge_score_limit,
|
||||
"edge_limit": edge_limit,
|
||||
}
|
||||
)
|
||||
|
|
@ -57,6 +59,7 @@ class Processor(FlowProcessor):
|
|||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.default_edge_score_limit = edge_score_limit
|
||||
self.default_edge_limit = edge_limit
|
||||
|
||||
# CRITICAL SECURITY: NEVER share data between users or collections
|
||||
|
|
@ -295,6 +298,11 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
if v.edge_score_limit:
|
||||
edge_score_limit = v.edge_score_limit
|
||||
else:
|
||||
edge_score_limit = self.default_edge_score_limit
|
||||
|
||||
if v.edge_limit:
|
||||
edge_limit = v.edge_limit
|
||||
else:
|
||||
|
|
@ -330,6 +338,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
edge_limit = edge_limit,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
|
|
@ -344,6 +353,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
edge_limit = edge_limit,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
|
|
@ -432,6 +442,20 @@ class Processor(FlowProcessor):
|
|||
help=f'Default max path length (default: 2)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-score-limit',
|
||||
type=int,
|
||||
default=30,
|
||||
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-limit',
|
||||
type=int,
|
||||
default=25,
|
||||
help=f'Max edges after LLM scoring (default: 25)'
|
||||
)
|
||||
|
||||
# Note: Explainability triples are now stored in the user's collection
|
||||
# with the named graph urn:graph:retrieval (no separate collection needed)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue