Graph rag optimisations (#527)

* Tech spec for GraphRAG optimisation

* Implement GraphRAG optimisation and update tests
This commit is contained in:
cybermaggedon 2025-09-23 21:05:51 +01:00 committed by GitHub
parent fcd15d1833
commit 45a14b5958
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 881 additions and 104 deletions

View file

@ -1,12 +1,56 @@
import asyncio
import logging
import time
from collections import OrderedDict
# Module logger
logger = logging.getLogger(__name__)
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
CRITICAL SECURITY WARNING:
This cache is shared within a GraphRag instance but GraphRag instances
are created per-request. Cache keys MUST include user:collection prefix
to ensure data isolation between different security contexts.
"""
def __init__(self, max_size=5000, ttl=300):
self.cache = OrderedDict()
self.access_times = {}
self.max_size = max_size
self.ttl = ttl
def get(self, key):
if key not in self.cache:
return None
# Check TTL expiration
if time.time() - self.access_times[key] > self.ttl:
del self.cache[key]
del self.access_times[key]
return None
# Move to end (most recently used)
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key, value):
if key in self.cache:
self.cache.move_to_end(key)
else:
if len(self.cache) >= self.max_size:
# Remove least recently used
oldest_key = next(iter(self.cache))
del self.cache[oldest_key]
del self.access_times[oldest_key]
self.cache[key] = value
self.access_times[key] = time.time()
class Query:
def __init__(
@ -61,8 +105,14 @@ class Query:
async def maybe_label(self, e):
if e in self.rag.label_cache:
return self.rag.label_cache[e]
# CRITICAL SECURITY: Cache key MUST include user and collection
# to prevent data leakage between different contexts
cache_key = f"{self.user}:{self.collection}:{e}"
# Check LRU cache first with isolated key
cached_label = self.rag.label_cache.get(cache_key)
if cached_label is not None:
return cached_label
res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1,
@ -70,60 +120,104 @@ class Query:
)
if len(res) == 0:
self.rag.label_cache[e] = e
self.rag.label_cache.put(cache_key, e)
return e
self.rag.label_cache[e] = str(res[0].o)
return self.rag.label_cache[e]
label = str(res[0].o)
self.rag.label_cache.put(cache_key, label)
return label
async def execute_batch_triple_queries(self, entities, limit_per_entity):
"""Execute triple queries for multiple entities concurrently"""
tasks = []
for entity in entities:
# Create concurrent tasks for all 3 query types per entity
tasks.extend([
self.rag.triples_client.query(
s=entity, p=None, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection
),
self.rag.triples_client.query(
s=None, p=entity, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection
),
self.rag.triples_client.query(
s=None, p=None, o=entity,
limit=limit_per_entity,
user=self.user, collection=self.collection
)
])
# Execute all queries concurrently
results = await asyncio.gather(*tasks, return_exceptions=True)
# Combine all results
all_triples = []
for result in results:
if not isinstance(result, Exception):
all_triples.extend(result)
return all_triples
async def follow_edges_batch(self, entities, max_depth):
"""Optimized iterative graph traversal with batching"""
visited = set()
current_level = set(entities)
subgraph = set()
for depth in range(max_depth):
if not current_level or len(subgraph) >= self.max_subgraph_size:
break
# Filter out already visited entities
unvisited_entities = [e for e in current_level if e not in visited]
if not unvisited_entities:
break
# Batch query all unvisited entities at current level
triples = await self.execute_batch_triple_queries(
unvisited_entities, self.triple_limit
)
# Process results and collect next level entities
next_level = set()
for triple in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
subgraph.add(triple_tuple)
# Collect entities for next level (only from s and o positions)
if depth < max_depth - 1: # Don't collect for final depth
s, p, o = triple_tuple
if s not in visited:
next_level.add(s)
if o not in visited:
next_level.add(o)
# Stop if subgraph size limit reached
if len(subgraph) >= self.max_subgraph_size:
return subgraph
# Update for next iteration
visited.update(current_level)
current_level = next_level
return subgraph
async def follow_edges(self, ent, subgraph, path_length):
# Not needed?
"""Legacy method - replaced by follow_edges_batch"""
# Maintain backward compatibility with early termination checks
if path_length <= 0:
return
# Stop spanning around if the subgraph is already maxed out
if len(subgraph) >= self.max_subgraph_size:
return
res = await self.rag.triples_client.query(
s=ent, p=None, o=None,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
if path_length > 1:
await self.follow_edges(str(triple.o), subgraph, path_length-1)
res = await self.rag.triples_client.query(
s=None, p=ent, o=None,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
res = await self.rag.triples_client.query(
s=None, p=None, o=ent,
limit=self.triple_limit,
user=self.user, collection=self.collection,
)
for triple in res:
subgraph.add(
(str(triple.s), str(triple.p), str(triple.o))
)
if path_length > 1:
await self.follow_edges(
str(triple.s), subgraph, path_length-1
)
# For backward compatibility, convert to new approach
batch_result = await self.follow_edges_batch([ent], path_length)
subgraph.update(batch_result)
async def get_subgraph(self, query):
@ -132,31 +226,52 @@ class Query:
if self.verbose:
logger.debug("Getting subgraph...")
subgraph = set()
# Use optimized batch traversal instead of sequential processing
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
for ent in entities:
await self.follow_edges(ent, subgraph, self.max_path_length)
return list(subgraph)
subgraph = list(subgraph)
async def resolve_labels_batch(self, entities):
"""Resolve labels for multiple entities in parallel"""
tasks = []
for entity in entities:
tasks.append(self.maybe_label(entity))
return subgraph
return await asyncio.gather(*tasks, return_exceptions=True)
async def get_labelgraph(self, query):
subgraph = await self.get_subgraph(query)
# Filter out label triples
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
# Collect all unique entities that need label resolution
entities_to_resolve = set()
for s, p, o in filtered_subgraph:
entities_to_resolve.update([s, p, o])
# Batch resolve labels for all entities in parallel
entity_list = list(entities_to_resolve)
resolved_labels = await self.resolve_labels_batch(entity_list)
# Create entity-to-label mapping
label_map = {}
for entity, label in zip(entity_list, resolved_labels):
if not isinstance(label, Exception):
label_map[entity] = label
else:
label_map[entity] = entity # Fallback to entity itself
# Apply labels to subgraph
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = await self.maybe_label(edge[0])
p = await self.maybe_label(edge[1])
o = await self.maybe_label(edge[2])
sg2.append((s, p, o))
for s, p, o in filtered_subgraph:
labeled_triple = (
label_map.get(s, s),
label_map.get(p, p),
label_map.get(o, o)
)
sg2.append(labeled_triple)
sg2 = sg2[0:self.max_subgraph_size]
@ -171,6 +286,13 @@ class Query:
return sg2
class GraphRag:
"""
CRITICAL SECURITY:
This class MUST be instantiated per-request to ensure proper isolation
between users and collections. The cache within this instance will only
live for the duration of a single request, preventing cross-contamination
of data between different security contexts.
"""
def __init__(
self, prompt_client, embeddings_client, graph_embeddings_client,
@ -184,7 +306,9 @@ class GraphRag:
self.graph_embeddings_client = graph_embeddings_client
self.triples_client = triples_client
self.label_cache = {}
# Replace simple dict with LRU cache with TTL
# CRITICAL: This cache only lives for one request due to per-request instantiation
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
if self.verbose:
logger.debug("GraphRag initialized")

View file

@ -45,6 +45,10 @@ class Processor(FlowProcessor):
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
# CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access
# Caching must NEVER allow information leakage across these boundaries
self.register_specification(
ConsumerSpec(
name = "request",
@ -93,11 +97,14 @@ class Processor(FlowProcessor):
try:
self.rag = GraphRag(
embeddings_client = flow("embeddings-request"),
graph_embeddings_client = flow("graph-embeddings-request"),
triples_client = flow("triples-request"),
prompt_client = flow("prompt-request"),
# CRITICAL SECURITY: Create new GraphRag instance per request
# This ensures proper isolation between users and collections
# Flow clients are request-scoped and must not be shared
rag = GraphRag(
embeddings_client=flow("embeddings-request"),
graph_embeddings_client=flow("graph-embeddings-request"),
triples_client=flow("triples-request"),
prompt_client=flow("prompt-request"),
verbose=True,
)
@ -128,7 +135,7 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
response = await self.rag.query(
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,