mirror of https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 17:36:23 +02:00
Graph rag optimisations (#527)
* Tech spec for GraphRAG optimisation
* Implement GraphRAG optimisation and update tests
This commit is contained in:
parent fcd15d1833
commit 45a14b5958
4 changed files with 881 additions and 104 deletions
@@ -1,12 +1,56 @@
+import asyncio
 import logging
+import time
+from collections import OrderedDict

 # Module logger
 logger = logging.getLogger(__name__)

 LABEL="http://www.w3.org/2000/01/rdf-schema#label"

+class LRUCacheWithTTL:
+
+    """LRU cache with TTL for label caching
+
+    CRITICAL SECURITY WARNING:
+    This cache is shared within a GraphRag instance but GraphRag instances
+    are created per-request. Cache keys MUST include user:collection prefix
+    to ensure data isolation between different security contexts.
+    """
+
+    def __init__(self, max_size=5000, ttl=300):
+        self.cache = OrderedDict()
+        self.access_times = {}
+        self.max_size = max_size
+        self.ttl = ttl
+
+    def get(self, key):
+        if key not in self.cache:
+            return None
+
+        # Check TTL expiration
+        if time.time() - self.access_times[key] > self.ttl:
+            del self.cache[key]
+            del self.access_times[key]
+            return None
+
+        # Move to end (most recently used)
+        self.cache.move_to_end(key)
+        return self.cache[key]
+
+    def put(self, key, value):
+        if key in self.cache:
+            self.cache.move_to_end(key)
+        else:
+            if len(self.cache) >= self.max_size:
+                # Remove least recently used
+                oldest_key = next(iter(self.cache))
+                del self.cache[oldest_key]
+                del self.access_times[oldest_key]
+
+        self.cache[key] = value
+        self.access_times[key] = time.time()

 class Query:

     def __init__(
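The cache behaves like a small LRU dictionary whose entries also age out, with the TTL measured from when a value was written. A minimal sketch of how the class above behaves (keys, labels, and the short TTL are illustrative):

import time

cache = LRUCacheWithTTL(max_size=2, ttl=3)
cache.put("alice:docs:http://example.org/e1", "Entity One")
cache.put("alice:docs:http://example.org/e2", "Entity Two")

# Hit: returns the label and marks the key most recently used
assert cache.get("alice:docs:http://example.org/e1") == "Entity One"

# A third insert evicts the least recently used key (e2)
cache.put("alice:docs:http://example.org/e3", "Entity Three")
assert cache.get("alice:docs:http://example.org/e2") is None

# Once the TTL elapses, even resident keys read back as None
time.sleep(3.1)
assert cache.get("alice:docs:http://example.org/e1") is None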
@@ -61,8 +105,14 @@ class Query:

     async def maybe_label(self, e):

-        if e in self.rag.label_cache:
-            return self.rag.label_cache[e]
+        # CRITICAL SECURITY: Cache key MUST include user and collection
+        # to prevent data leakage between different contexts
+        cache_key = f"{self.user}:{self.collection}:{e}"
+
+        # Check LRU cache first with isolated key
+        cached_label = self.rag.label_cache.get(cache_key)
+        if cached_label is not None:
+            return cached_label

         res = await self.rag.triples_client.query(
             s=e, p=LABEL, o=None, limit=1,
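The point of the composite key is that the cache itself has no notion of tenancy; isolation comes entirely from the prefix. A sketch with hypothetical users and collections, where the same URI carries a different label per context:

cache = LRUCacheWithTTL()
e = "http://example.org/acme"

cache.put(f"alice:research:{e}", "ACME (internal)")
cache.put(f"bob:public:{e}", "ACME Corp")

# Each security context only ever sees its own entry
assert cache.get(f"alice:research:{e}") == "ACME (internal)"
assert cache.get(f"bob:public:{e}") == "ACME Corp"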
@@ -70,60 +120,104 @@ class Query:
         )

         if len(res) == 0:
-            self.rag.label_cache[e] = e
+            self.rag.label_cache.put(cache_key, e)
             return e

-        self.rag.label_cache[e] = str(res[0].o)
-        return self.rag.label_cache[e]
+        label = str(res[0].o)
+        self.rag.label_cache.put(cache_key, label)
+        return label
+
+    async def execute_batch_triple_queries(self, entities, limit_per_entity):
+        """Execute triple queries for multiple entities concurrently"""
+        tasks = []
+
+        for entity in entities:
+            # Create concurrent tasks for all 3 query types per entity
+            tasks.extend([
+                self.rag.triples_client.query(
+                    s=entity, p=None, o=None,
+                    limit=limit_per_entity,
+                    user=self.user, collection=self.collection
+                ),
+                self.rag.triples_client.query(
+                    s=None, p=entity, o=None,
+                    limit=limit_per_entity,
+                    user=self.user, collection=self.collection
+                ),
+                self.rag.triples_client.query(
+                    s=None, p=None, o=entity,
+                    limit=limit_per_entity,
+                    user=self.user, collection=self.collection
+                )
+            ])
+
+        # Execute all queries concurrently
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Combine all results
+        all_triples = []
+        for result in results:
+            if not isinstance(result, Exception):
+                all_triples.extend(result)
+
+        return all_triples
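The batching rests on asyncio.gather with return_exceptions=True: a failing query comes back as an exception object instead of aborting the whole batch, and is simply skipped. A self-contained sketch of that pattern; stub_query and its failure mode are invented stand-ins for triples_client.query:

import asyncio

async def stub_query(entity, position):
    # Hypothetical stand-in for the triple store client; one input fails
    if entity == "bad":
        raise RuntimeError("store timeout")
    return [f"triple({position}={entity})"]

async def main():
    entities = ["e1", "bad", "e2"]
    tasks = [stub_query(e, pos) for e in entities for pos in ("s", "p", "o")]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Keep successful result lists, drop exceptions, flatten
    all_triples = []
    for r in results:
        if not isinstance(r, Exception):
            all_triples.extend(r)
    print(all_triples)  # triples for e1 and e2 only; "bad" is dropped

asyncio.run(main())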
+
+    async def follow_edges_batch(self, entities, max_depth):
+        """Optimized iterative graph traversal with batching"""
+        visited = set()
+        current_level = set(entities)
+        subgraph = set()
+
+        for depth in range(max_depth):
+            if not current_level or len(subgraph) >= self.max_subgraph_size:
+                break
+
+            # Filter out already visited entities
+            unvisited_entities = [e for e in current_level if e not in visited]
+            if not unvisited_entities:
+                break
+
+            # Batch query all unvisited entities at current level
+            triples = await self.execute_batch_triple_queries(
+                unvisited_entities, self.triple_limit
+            )
+
+            # Process results and collect next level entities
+            next_level = set()
+            for triple in triples:
+                triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
+                subgraph.add(triple_tuple)
+
+                # Collect entities for next level (only from s and o positions)
+                if depth < max_depth - 1:  # Don't collect for final depth
+                    s, p, o = triple_tuple
+                    if s not in visited:
+                        next_level.add(s)
+                    if o not in visited:
+                        next_level.add(o)
+
+                # Stop if subgraph size limit reached
+                if len(subgraph) >= self.max_subgraph_size:
+                    return subgraph
+
+            # Update for next iteration
+            visited.update(current_level)
+            current_level = next_level
+
+        return subgraph
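follow_edges_batch replaces per-entity recursion with a breadth-first frontier: one batched round of queries per depth level, with a visited set so no node is queried twice. The same bookkeeping, reduced to a plain in-memory adjacency map (the toy graph is illustrative):

def bfs_frontier(adjacency, seeds, max_depth):
    visited, frontier, edges = set(), set(seeds), set()
    for depth in range(max_depth):
        unvisited = [n for n in frontier if n not in visited]
        if not unvisited:
            break
        next_frontier = set()
        for n in unvisited:
            for neighbour in adjacency.get(n, []):
                edges.add((n, neighbour))
                # Only grow the frontier if another level will run
                if depth < max_depth - 1 and neighbour not in visited:
                    next_frontier.add(neighbour)
        visited.update(frontier)
        frontier = next_frontier
    return edges

graph = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": ["a"]}
print(bfs_frontier(graph, ["a"], 2))
# Two query rounds total, rather than one round per node visited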

     async def follow_edges(self, ent, subgraph, path_length):

-        # Not needed?
+        """Legacy method - replaced by follow_edges_batch"""
+        # Maintain backward compatibility with early termination checks
         if path_length <= 0:
             return

         # Stop spanning around if the subgraph is already maxed out
         if len(subgraph) >= self.max_subgraph_size:
             return

-        res = await self.rag.triples_client.query(
-            s=ent, p=None, o=None,
-            limit=self.triple_limit,
-            user=self.user, collection=self.collection,
-        )
-
-        for triple in res:
-            subgraph.add(
-                (str(triple.s), str(triple.p), str(triple.o))
-            )
-            if path_length > 1:
-                await self.follow_edges(str(triple.o), subgraph, path_length-1)
-
-        res = await self.rag.triples_client.query(
-            s=None, p=ent, o=None,
-            limit=self.triple_limit,
-            user=self.user, collection=self.collection,
-        )
-
-        for triple in res:
-            subgraph.add(
-                (str(triple.s), str(triple.p), str(triple.o))
-            )
-
-        res = await self.rag.triples_client.query(
-            s=None, p=None, o=ent,
-            limit=self.triple_limit,
-            user=self.user, collection=self.collection,
-        )
-
-        for triple in res:
-            subgraph.add(
-                (str(triple.s), str(triple.p), str(triple.o))
-            )
-            if path_length > 1:
-                await self.follow_edges(
-                    str(triple.s), subgraph, path_length-1
-                )
+        # For backward compatibility, convert to new approach
+        batch_result = await self.follow_edges_batch([ent], path_length)
+        subgraph.update(batch_result)

     async def get_subgraph(self, query):
@@ -132,31 +226,52 @@ class Query:
         if self.verbose:
             logger.debug("Getting subgraph...")

-        subgraph = set()
+        # Use optimized batch traversal instead of sequential processing
+        subgraph = await self.follow_edges_batch(entities, self.max_path_length)

-        for ent in entities:
-            await self.follow_edges(ent, subgraph, self.max_path_length)
+        return list(subgraph)

-        subgraph = list(subgraph)
-
-        return subgraph
+    async def resolve_labels_batch(self, entities):
+        """Resolve labels for multiple entities in parallel"""
+        tasks = []
+        for entity in entities:
+            tasks.append(self.maybe_label(entity))
+
+        return await asyncio.gather(*tasks, return_exceptions=True)

     async def get_labelgraph(self, query):

         subgraph = await self.get_subgraph(query)

+        # Filter out label triples
+        filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
+
+        # Collect all unique entities that need label resolution
+        entities_to_resolve = set()
+        for s, p, o in filtered_subgraph:
+            entities_to_resolve.update([s, p, o])
+
+        # Batch resolve labels for all entities in parallel
+        entity_list = list(entities_to_resolve)
+        resolved_labels = await self.resolve_labels_batch(entity_list)
+
+        # Create entity-to-label mapping
+        label_map = {}
+        for entity, label in zip(entity_list, resolved_labels):
+            if not isinstance(label, Exception):
+                label_map[entity] = label
+            else:
+                label_map[entity] = entity  # Fallback to entity itself
+
+        # Apply labels to subgraph
         sg2 = []

-        for edge in subgraph:
-
-            if edge[1] == LABEL:
-                continue
-
-            s = await self.maybe_label(edge[0])
-            p = await self.maybe_label(edge[1])
-            o = await self.maybe_label(edge[2])
-
-            sg2.append((s, p, o))
+        for s, p, o in filtered_subgraph:
+            labeled_triple = (
+                label_map.get(s, s),
+                label_map.get(p, p),
+                label_map.get(o, o)
+            )
+            sg2.append(labeled_triple)

         sg2 = sg2[0:self.max_subgraph_size]
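Label application is now a pure dictionary lookup: every entity is resolved once, results are zipped back against the input order, and a failed lookup falls back to the raw identifier. A sketch of just the mapping step, with illustrative resolved values:

entity_list = ["http://ex/e1", "http://ex/e2", "http://ex/e3"]
resolved_labels = ["Entity One", RuntimeError("lookup failed"), "Entity Three"]

# An exception result falls back to the entity itself
label_map = {}
for entity, label in zip(entity_list, resolved_labels):
    label_map[entity] = entity if isinstance(label, Exception) else label

triple = ("http://ex/e1", "http://ex/e2", "http://ex/e3")
print(tuple(label_map.get(t, t) for t in triple))
# ('Entity One', 'http://ex/e2', 'Entity Three')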
@@ -171,6 +286,13 @@ class Query:
         return sg2

 class GraphRag:
+    """
+    CRITICAL SECURITY:
+    This class MUST be instantiated per-request to ensure proper isolation
+    between users and collections. The cache within this instance will only
+    live for the duration of a single request, preventing cross-contamination
+    of data between different security contexts.
+    """

     def __init__(
         self, prompt_client, embeddings_client, graph_embeddings_client,
@@ -184,7 +306,9 @@ class GraphRag:
         self.graph_embeddings_client = graph_embeddings_client
         self.triples_client = triples_client

-        self.label_cache = {}
+        # Replace simple dict with LRU cache with TTL
+        # CRITICAL: This cache only lives for one request due to per-request instantiation
+        self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)

         if self.verbose:
             logger.debug("GraphRag initialized")
@@ -45,6 +45,10 @@ class Processor(FlowProcessor):
         self.default_max_subgraph_size = max_subgraph_size
         self.default_max_path_length = max_path_length

+        # CRITICAL SECURITY: NEVER share data between users or collections
+        # Each user/collection combination MUST have isolated data access
+        # Caching must NEVER allow information leakage across these boundaries
+
         self.register_specification(
             ConsumerSpec(
                 name = "request",
@@ -93,11 +97,14 @@ class Processor(FlowProcessor):

         try:

-            self.rag = GraphRag(
-                embeddings_client = flow("embeddings-request"),
-                graph_embeddings_client = flow("graph-embeddings-request"),
-                triples_client = flow("triples-request"),
-                prompt_client = flow("prompt-request"),
+            # CRITICAL SECURITY: Create new GraphRag instance per request
+            # This ensures proper isolation between users and collections
+            # Flow clients are request-scoped and must not be shared
+            rag = GraphRag(
+                embeddings_client=flow("embeddings-request"),
+                graph_embeddings_client=flow("graph-embeddings-request"),
+                triples_client=flow("triples-request"),
+                prompt_client=flow("prompt-request"),
+                verbose=True,
             )
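Binding the instance to a local variable rather than self.rag is what makes the isolation hold: under concurrent requests, an attribute on the shared processor could be overwritten by a second request before the first finished using it. A condensed sketch of the per-request lifecycle (the handler shape is hypothetical; flow names follow the diff):

async def on_request(self, v, flow):
    # Fresh instance per request: the label cache dies with the request
    rag = GraphRag(
        embeddings_client=flow("embeddings-request"),
        graph_embeddings_client=flow("graph-embeddings-request"),
        triples_client=flow("triples-request"),
        prompt_client=flow("prompt-request"),
        verbose=True,
    )
    return await rag.query(
        query=v.query, user=v.user, collection=v.collection,
    )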
@@ -128,7 +135,7 @@ class Processor(FlowProcessor):
         else:
             max_path_length = self.default_max_path_length

-        response = await self.rag.query(
+        response = await rag.query(
             query = v.query, user = v.user, collection = v.collection,
             entity_limit = entity_limit, triple_limit = triple_limit,
             max_subgraph_size = max_subgraph_size,