mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-04 21:02:37 +02:00
Graph rag optimisations (#527)
* Tech spec for GraphRAG optimisation * Implement GraphRAG optimisation and update tests
This commit is contained in:
parent
fcd15d1833
commit
45a14b5958
4 changed files with 881 additions and 104 deletions
629
docs/tech-specs/graphrag-performance-optimization.md
Normal file
629
docs/tech-specs/graphrag-performance-optimization.md
Normal file
|
|
@ -0,0 +1,629 @@
|
||||||
|
# GraphRAG Performance Optimisation Technical Specification
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This specification describes comprehensive performance optimisations for the GraphRAG (Graph Retrieval-Augmented Generation) algorithm in TrustGraph. The current implementation suffers from significant performance bottlenecks that limit scalability and response times. This specification addresses four primary optimisation areas:
|
||||||
|
|
||||||
|
1. **Graph Traversal Optimisation**: Eliminate inefficient recursive database queries and implement batched graph exploration
|
||||||
|
2. **Label Resolution Optimisation**: Replace sequential label fetching with parallel/batched operations
|
||||||
|
3. **Caching Strategy Enhancement**: Implement intelligent caching with LRU eviction and prefetching
|
||||||
|
4. **Query Optimisation**: Add result memoisation and embedding caching for improved response times
|
||||||
|
|
||||||
|
## Goals
|
||||||
|
|
||||||
|
- **Reduce Database Query Volume**: Achieve 50-80% reduction in total database queries through batching and caching
|
||||||
|
- **Improve Response Times**: Target 3-5x faster subgraph construction and 2-3x faster label resolution
|
||||||
|
- **Enhance Scalability**: Support larger knowledge graphs with better memory management
|
||||||
|
- **Maintain Accuracy**: Preserve existing GraphRAG functionality and result quality
|
||||||
|
- **Enable Concurrency**: Improve parallel processing capabilities for multiple concurrent requests
|
||||||
|
- **Reduce Memory Footprint**: Implement efficient data structures and memory management
|
||||||
|
- **Add Observability**: Include performance metrics and monitoring capabilities
|
||||||
|
- **Ensure Reliability**: Add proper error handling and timeout mechanisms
|
||||||
|
|
||||||
|
## Background
|
||||||
|
|
||||||
|
The current GraphRAG implementation in `trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py` exhibits several critical performance issues that severely impact system scalability:
|
||||||
|
|
||||||
|
### Current Performance Problems
|
||||||
|
|
||||||
|
**1. Inefficient Graph Traversal (`follow_edges` function, lines 79-127)**
|
||||||
|
- Makes 3 separate database queries per entity per depth level
|
||||||
|
- Query pattern: subject-based, predicate-based, and object-based queries for each entity
|
||||||
|
- No batching: Each query processes only one entity at a time
|
||||||
|
- No cycle detection: Can revisit the same nodes multiple times
|
||||||
|
- Recursive implementation without memoisation leads to exponential complexity
|
||||||
|
- Time complexity: O(entities × max_path_length × triple_limit³)
|
||||||
|
|
||||||
|
**2. Sequential Label Resolution (`get_labelgraph` function, lines 144-171)**
|
||||||
|
- Processes each triple component (subject, predicate, object) sequentially
|
||||||
|
- Each `maybe_label` call potentially triggers a database query
|
||||||
|
- No parallel execution or batching of label queries
|
||||||
|
- Results in up to 3 × subgraph_size individual database calls
|
||||||
|
|
||||||
|
**3. Primitive Caching Strategy (`maybe_label` function, lines 62-77)**
|
||||||
|
- Simple dictionary cache without size limits or TTL
|
||||||
|
- No cache eviction policy leads to unbounded memory growth
|
||||||
|
- Cache misses trigger individual database queries
|
||||||
|
- No prefetching or intelligent cache warming
|
||||||
|
|
||||||
|
**4. Suboptimal Query Patterns**
|
||||||
|
- Entity vector similarity queries not cached between similar requests
|
||||||
|
- No result memoisation for repeated query patterns
|
||||||
|
- Missing query optimisation for common access patterns
|
||||||
|
|
||||||
|
**5. Critical Object Lifetime Issues (`rag.py:96-102`)**
|
||||||
|
- **GraphRag object recreated per request**: Fresh instance created for every query, losing all cache benefits
|
||||||
|
- **Query object extremely short-lived**: Created and destroyed within single query execution (lines 201-207)
|
||||||
|
- **Label cache reset per request**: Cache warming and accumulated knowledge lost between requests
|
||||||
|
- **Client recreation overhead**: Database clients potentially re-established for each request
|
||||||
|
- **No cross-request optimisation**: Cannot benefit from query patterns or result sharing
|
||||||
|
|
||||||
|
### Performance Impact Analysis
|
||||||
|
|
||||||
|
Current worst-case scenario for a typical query:
|
||||||
|
- **Entity Retrieval**: 1 vector similarity query
|
||||||
|
- **Graph Traversal**: entities × max_path_length × 3 × triple_limit queries
|
||||||
|
- **Label Resolution**: subgraph_size × 3 individual label queries
|
||||||
|
|
||||||
|
For default parameters (50 entities, path length 2, 30 triple limit, 150 subgraph size):
|
||||||
|
- **Minimum queries**: 1 + (50 × 2 × 3 × 30) + (150 × 3) = **9,451 database queries**
|
||||||
|
- **Response time**: 15-30 seconds for moderate-sized graphs
|
||||||
|
- **Memory usage**: Unbounded cache growth over time
|
||||||
|
- **Cache effectiveness**: 0% - caches reset on every request
|
||||||
|
- **Object creation overhead**: GraphRag + Query objects created/destroyed per request
|
||||||
|
|
||||||
|
This specification addresses these gaps by implementing batched queries, intelligent caching, and parallel processing. By optimizing query patterns and data access, TrustGraph can:
|
||||||
|
- Support enterprise-scale knowledge graphs with millions of entities
|
||||||
|
- Provide sub-second response times for typical queries
|
||||||
|
- Handle hundreds of concurrent GraphRAG requests
|
||||||
|
- Scale efficiently with graph size and complexity
|
||||||
|
|
||||||
|
## Technical Design
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
|
||||||
|
The GraphRAG performance optimisation requires the following technical components:
|
||||||
|
|
||||||
|
#### 1. **Object Lifetime Architectural Refactor**
|
||||||
|
- **Make GraphRag long-lived**: Move GraphRag instance to Processor level for persistence across requests
|
||||||
|
- **Preserve caches**: Maintain label cache, embedding cache, and query result cache between requests
|
||||||
|
- **Optimize Query object**: Refactor Query as lightweight execution context, not data container
|
||||||
|
- **Connection persistence**: Maintain database client connections across requests
|
||||||
|
|
||||||
|
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py` (modified)
|
||||||
|
|
||||||
|
#### 2. **Optimized Graph Traversal Engine**
|
||||||
|
- Replace recursive `follow_edges` with iterative breadth-first search
|
||||||
|
- Implement batched entity processing at each traversal level
|
||||||
|
- Add cycle detection using visited node tracking
|
||||||
|
- Include early termination when limits are reached
|
||||||
|
|
||||||
|
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/optimized_traversal.py`
|
||||||
|
|
||||||
|
#### 3. **Parallel Label Resolution System**
|
||||||
|
- Batch label queries for multiple entities simultaneously
|
||||||
|
- Implement async/await patterns for concurrent database access
|
||||||
|
- Add intelligent prefetching for common label patterns
|
||||||
|
- Include label cache warming strategies
|
||||||
|
|
||||||
|
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/label_resolver.py`
|
||||||
|
|
||||||
|
#### 4. **Conservative Label Caching Layer**
|
||||||
|
- LRU cache with short TTL for labels only (5min) to balance performance vs consistency
|
||||||
|
- Cache metrics and hit ratio monitoring
|
||||||
|
- **No embedding caching**: Already cached per-query, no cross-query benefit
|
||||||
|
- **No query result caching**: Due to graph mutation consistency concerns
|
||||||
|
|
||||||
|
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/cache_manager.py`
|
||||||
|
|
||||||
|
#### 5. **Query Optimisation Framework**
|
||||||
|
- Query pattern analysis and optimisation suggestions
|
||||||
|
- Batch query coordinator for database access
|
||||||
|
- Connection pooling and query timeout management
|
||||||
|
- Performance monitoring and metrics collection
|
||||||
|
|
||||||
|
Module: `trustgraph-flow/trustgraph/retrieval/graph_rag/query_optimizer.py`
|
||||||
|
|
||||||
|
### Data Models
|
||||||
|
|
||||||
|
#### Optimized Graph Traversal State
|
||||||
|
|
||||||
|
The traversal engine maintains state to avoid redundant operations:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class TraversalState:
|
||||||
|
visited_entities: Set[str]
|
||||||
|
current_level_entities: Set[str]
|
||||||
|
next_level_entities: Set[str]
|
||||||
|
subgraph: Set[Tuple[str, str, str]]
|
||||||
|
depth: int
|
||||||
|
query_batch: List[TripleQuery]
|
||||||
|
```
|
||||||
|
|
||||||
|
This approach allows:
|
||||||
|
- Efficient cycle detection through visited entity tracking
|
||||||
|
- Batched query preparation at each traversal level
|
||||||
|
- Memory-efficient state management
|
||||||
|
- Early termination when size limits are reached
|
||||||
|
|
||||||
|
#### Enhanced Cache Structure
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class CacheEntry:
|
||||||
|
value: Any
|
||||||
|
timestamp: float
|
||||||
|
access_count: int
|
||||||
|
ttl: Optional[float]
|
||||||
|
|
||||||
|
class CacheManager:
|
||||||
|
label_cache: LRUCache[str, CacheEntry]
|
||||||
|
embedding_cache: LRUCache[str, CacheEntry]
|
||||||
|
query_result_cache: LRUCache[str, CacheEntry]
|
||||||
|
cache_stats: CacheStatistics
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Batch Query Structures
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class BatchTripleQuery:
|
||||||
|
entities: List[str]
|
||||||
|
query_type: QueryType # SUBJECT, PREDICATE, OBJECT
|
||||||
|
limit_per_entity: int
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BatchLabelQuery:
|
||||||
|
entities: List[str]
|
||||||
|
predicate: str = LABEL
|
||||||
|
```
|
||||||
|
|
||||||
|
### APIs
|
||||||
|
|
||||||
|
#### New APIs:
|
||||||
|
|
||||||
|
**GraphTraversal API**
|
||||||
|
```python
|
||||||
|
async def optimized_follow_edges_batch(
|
||||||
|
entities: List[str],
|
||||||
|
max_depth: int,
|
||||||
|
triple_limit: int,
|
||||||
|
max_subgraph_size: int
|
||||||
|
) -> Set[Tuple[str, str, str]]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Batch Label Resolution API**
|
||||||
|
```python
|
||||||
|
async def resolve_labels_batch(
|
||||||
|
entities: List[str],
|
||||||
|
cache_manager: CacheManager
|
||||||
|
) -> Dict[str, str]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Cache Management API**
|
||||||
|
```python
|
||||||
|
class CacheManager:
|
||||||
|
async def get_or_fetch_label(self, entity: str) -> str
|
||||||
|
async def get_or_fetch_embeddings(self, query: str) -> List[float]
|
||||||
|
async def cache_query_result(self, query_hash: str, result: Any, ttl: int)
|
||||||
|
def get_cache_statistics(self) -> CacheStatistics
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Modified APIs:
|
||||||
|
|
||||||
|
**GraphRag.query()** - Enhanced with performance optimisations:
|
||||||
|
- Add cache_manager parameter for cache control
|
||||||
|
- Include performance_metrics return value
|
||||||
|
- Add query_timeout parameter for reliability
|
||||||
|
|
||||||
|
**Query class** - Refactored for batch processing:
|
||||||
|
- Replace individual entity processing with batch operations
|
||||||
|
- Add async context managers for resource cleanup
|
||||||
|
- Include progress callbacks for long-running operations
|
||||||
|
|
||||||
|
### Implementation Details
|
||||||
|
|
||||||
|
#### Phase 0: Critical Architectural Lifetime Refactor
|
||||||
|
|
||||||
|
**Current Problematic Implementation:**
|
||||||
|
```python
|
||||||
|
# INEFFICIENT: GraphRag recreated every request
|
||||||
|
class Processor(FlowProcessor):
|
||||||
|
async def on_request(self, msg, consumer, flow):
|
||||||
|
# PROBLEM: New GraphRag instance per request!
|
||||||
|
self.rag = GraphRag(
|
||||||
|
embeddings_client = flow("embeddings-request"),
|
||||||
|
graph_embeddings_client = flow("graph-embeddings-request"),
|
||||||
|
triples_client = flow("triples-request"),
|
||||||
|
prompt_client = flow("prompt-request"),
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
# Cache starts empty every time - no benefit from previous requests
|
||||||
|
response = await self.rag.query(...)
|
||||||
|
|
||||||
|
# VERY SHORT-LIVED: Query object created/destroyed per request
|
||||||
|
class GraphRag:
|
||||||
|
async def query(self, query, user="trustgraph", collection="default", ...):
|
||||||
|
q = Query(rag=self, user=user, collection=collection, ...) # Created
|
||||||
|
kg = await q.get_labelgraph(query) # Used briefly
|
||||||
|
# q automatically destroyed when function exits
|
||||||
|
```
|
||||||
|
|
||||||
|
**Optimized Long-Lived Architecture:**
|
||||||
|
```python
|
||||||
|
class Processor(FlowProcessor):
|
||||||
|
def __init__(self, **params):
|
||||||
|
super().__init__(**params)
|
||||||
|
self.rag_instance = None # Will be initialized once
|
||||||
|
self.client_connections = {}
|
||||||
|
|
||||||
|
async def initialize_rag(self, flow):
|
||||||
|
"""Initialize GraphRag once, reuse for all requests"""
|
||||||
|
if self.rag_instance is None:
|
||||||
|
self.rag_instance = LongLivedGraphRag(
|
||||||
|
embeddings_client=flow("embeddings-request"),
|
||||||
|
graph_embeddings_client=flow("graph-embeddings-request"),
|
||||||
|
triples_client=flow("triples-request"),
|
||||||
|
prompt_client=flow("prompt-request"),
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
return self.rag_instance
|
||||||
|
|
||||||
|
async def on_request(self, msg, consumer, flow):
|
||||||
|
# REUSE the same GraphRag instance - caches persist!
|
||||||
|
rag = await self.initialize_rag(flow)
|
||||||
|
|
||||||
|
# Query object becomes lightweight execution context
|
||||||
|
response = await rag.query_with_context(
|
||||||
|
query=v.query,
|
||||||
|
execution_context=QueryContext(
|
||||||
|
user=v.user,
|
||||||
|
collection=v.collection,
|
||||||
|
entity_limit=entity_limit,
|
||||||
|
# ... other params
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
class LongLivedGraphRag:
|
||||||
|
def __init__(self, ...):
|
||||||
|
# CONSERVATIVE caches - balance performance vs consistency
|
||||||
|
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300) # 5min TTL for freshness
|
||||||
|
# Note: No embedding cache - already cached per-query, no cross-query benefit
|
||||||
|
# Note: No query result cache due to consistency concerns
|
||||||
|
self.performance_metrics = PerformanceTracker()
|
||||||
|
|
||||||
|
async def query_with_context(self, query: str, context: QueryContext):
|
||||||
|
# Use lightweight QueryExecutor instead of heavyweight Query object
|
||||||
|
executor = QueryExecutor(self, context) # Minimal object
|
||||||
|
return await executor.execute(query)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QueryContext:
|
||||||
|
"""Lightweight execution context - no heavy operations"""
|
||||||
|
user: str
|
||||||
|
collection: str
|
||||||
|
entity_limit: int
|
||||||
|
triple_limit: int
|
||||||
|
max_subgraph_size: int
|
||||||
|
max_path_length: int
|
||||||
|
|
||||||
|
class QueryExecutor:
|
||||||
|
"""Lightweight execution context - replaces old Query class"""
|
||||||
|
def __init__(self, rag: LongLivedGraphRag, context: QueryContext):
|
||||||
|
self.rag = rag
|
||||||
|
self.context = context
|
||||||
|
# No heavy initialization - just references
|
||||||
|
|
||||||
|
async def execute(self, query: str):
|
||||||
|
# All heavy lifting uses persistent rag caches
|
||||||
|
return await self.rag.execute_optimized_query(query, self.context)
|
||||||
|
```
|
||||||
|
|
||||||
|
This architectural change provides:
|
||||||
|
- **10-20% database query reduction** for graphs with common relationships (vs 0% currently)
|
||||||
|
- **Eliminated object creation overhead** for every request
|
||||||
|
- **Persistent connection pooling** and client reuse
|
||||||
|
- **Cross-request optimization** within cache TTL windows
|
||||||
|
|
||||||
|
**Important Cache Consistency Limitation:**
|
||||||
|
Long-term caching introduces staleness risk when entities/labels are deleted or modified in the underlying graph. The LRU cache with TTL provides a balance between performance gains and data freshness, but cannot detect real-time graph changes.
|
||||||
|
|
||||||
|
#### Phase 1: Graph Traversal Optimisation
|
||||||
|
|
||||||
|
**Current Implementation Problems:**
|
||||||
|
```python
|
||||||
|
# INEFFICIENT: 3 queries per entity per level
|
||||||
|
async def follow_edges(self, ent, subgraph, path_length):
|
||||||
|
# Query 1: s=ent, p=None, o=None
|
||||||
|
res = await self.rag.triples_client.query(s=ent, p=None, o=None, limit=self.triple_limit)
|
||||||
|
# Query 2: s=None, p=ent, o=None
|
||||||
|
res = await self.rag.triples_client.query(s=None, p=ent, o=None, limit=self.triple_limit)
|
||||||
|
# Query 3: s=None, p=None, o=ent
|
||||||
|
res = await self.rag.triples_client.query(s=None, p=None, o=ent, limit=self.triple_limit)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Optimized Implementation:**
|
||||||
|
```python
|
||||||
|
async def optimized_traversal(self, entities: List[str], max_depth: int) -> Set[Triple]:
|
||||||
|
visited = set()
|
||||||
|
current_level = set(entities)
|
||||||
|
subgraph = set()
|
||||||
|
|
||||||
|
for depth in range(max_depth):
|
||||||
|
if not current_level or len(subgraph) >= self.max_subgraph_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Batch all queries for current level
|
||||||
|
batch_queries = []
|
||||||
|
for entity in current_level:
|
||||||
|
if entity not in visited:
|
||||||
|
batch_queries.extend([
|
||||||
|
TripleQuery(s=entity, p=None, o=None),
|
||||||
|
TripleQuery(s=None, p=entity, o=None),
|
||||||
|
TripleQuery(s=None, p=None, o=entity)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Execute all queries concurrently
|
||||||
|
results = await self.execute_batch_queries(batch_queries)
|
||||||
|
|
||||||
|
# Process results and prepare next level
|
||||||
|
next_level = set()
|
||||||
|
for result in results:
|
||||||
|
subgraph.update(result.triples)
|
||||||
|
next_level.update(result.new_entities)
|
||||||
|
|
||||||
|
visited.update(current_level)
|
||||||
|
current_level = next_level - visited
|
||||||
|
|
||||||
|
return subgraph
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Phase 2: Parallel Label Resolution
|
||||||
|
|
||||||
|
**Current Sequential Implementation:**
|
||||||
|
```python
|
||||||
|
# INEFFICIENT: Sequential processing
|
||||||
|
for edge in subgraph:
|
||||||
|
s = await self.maybe_label(edge[0]) # Individual query
|
||||||
|
p = await self.maybe_label(edge[1]) # Individual query
|
||||||
|
o = await self.maybe_label(edge[2]) # Individual query
|
||||||
|
```
|
||||||
|
|
||||||
|
**Optimized Parallel Implementation:**
|
||||||
|
```python
|
||||||
|
async def resolve_labels_parallel(self, subgraph: List[Triple]) -> List[Triple]:
|
||||||
|
# Collect all unique entities needing labels
|
||||||
|
entities_to_resolve = set()
|
||||||
|
for s, p, o in subgraph:
|
||||||
|
entities_to_resolve.update([s, p, o])
|
||||||
|
|
||||||
|
# Remove already cached entities
|
||||||
|
uncached_entities = [e for e in entities_to_resolve if e not in self.label_cache]
|
||||||
|
|
||||||
|
# Batch query for all uncached labels
|
||||||
|
if uncached_entities:
|
||||||
|
label_results = await self.batch_label_query(uncached_entities)
|
||||||
|
self.label_cache.update(label_results)
|
||||||
|
|
||||||
|
# Apply labels to subgraph
|
||||||
|
return [
|
||||||
|
(self.label_cache.get(s, s), self.label_cache.get(p, p), self.label_cache.get(o, o))
|
||||||
|
for s, p, o in subgraph
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Phase 3: Advanced Caching Strategy
|
||||||
|
|
||||||
|
**LRU Cache with TTL:**
|
||||||
|
```python
|
||||||
|
class LRUCacheWithTTL:
|
||||||
|
def __init__(self, max_size: int, default_ttl: int = 3600):
|
||||||
|
self.cache = OrderedDict()
|
||||||
|
self.max_size = max_size
|
||||||
|
self.default_ttl = default_ttl
|
||||||
|
self.access_times = {}
|
||||||
|
|
||||||
|
async def get(self, key: str) -> Optional[Any]:
|
||||||
|
if key in self.cache:
|
||||||
|
# Check TTL expiration
|
||||||
|
if time.time() - self.access_times[key] > self.default_ttl:
|
||||||
|
del self.cache[key]
|
||||||
|
del self.access_times[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Move to end (most recently used)
|
||||||
|
self.cache.move_to_end(key)
|
||||||
|
return self.cache[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def put(self, key: str, value: Any):
|
||||||
|
if key in self.cache:
|
||||||
|
self.cache.move_to_end(key)
|
||||||
|
else:
|
||||||
|
if len(self.cache) >= self.max_size:
|
||||||
|
# Remove least recently used
|
||||||
|
oldest_key = next(iter(self.cache))
|
||||||
|
del self.cache[oldest_key]
|
||||||
|
del self.access_times[oldest_key]
|
||||||
|
|
||||||
|
self.cache[key] = value
|
||||||
|
self.access_times[key] = time.time()
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Phase 4: Query Optimisation and Monitoring
|
||||||
|
|
||||||
|
**Performance Metrics Collection:**
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class PerformanceMetrics:
|
||||||
|
total_queries: int
|
||||||
|
cache_hits: int
|
||||||
|
cache_misses: int
|
||||||
|
avg_response_time: float
|
||||||
|
subgraph_construction_time: float
|
||||||
|
label_resolution_time: float
|
||||||
|
total_entities_processed: int
|
||||||
|
memory_usage_mb: float
|
||||||
|
```
|
||||||
|
|
||||||
|
**Query Timeout and Circuit Breaker:**
|
||||||
|
```python
|
||||||
|
async def execute_with_timeout(self, query_func, timeout: int = 30):
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(query_func(), timeout=timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"Query timeout after {timeout}s")
|
||||||
|
raise GraphRagTimeoutError(f"Query exceeded timeout of {timeout}s")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Cache Consistency Considerations
|
||||||
|
|
||||||
|
**Data Staleness Trade-offs:**
|
||||||
|
- **Label cache (5min TTL)**: Risk of serving deleted/renamed entity labels
|
||||||
|
- **No embedding caching**: Not needed - embeddings already cached per-query
|
||||||
|
- **No result caching**: Prevents stale subgraph results from deleted entities/relationships
|
||||||
|
|
||||||
|
**Mitigation Strategies:**
|
||||||
|
- **Conservative TTL values**: Balance performance gains (10-20%) with data freshness
|
||||||
|
- **Cache invalidation hooks**: Optional integration with graph mutation events
|
||||||
|
- **Monitoring dashboards**: Track cache hit rates vs staleness incidents
|
||||||
|
- **Configurable cache policies**: Allow per-deployment tuning based on mutation frequency
|
||||||
|
|
||||||
|
**Recommended Cache Configuration by Graph Mutation Rate:**
|
||||||
|
- **High mutation (>100 changes/hour)**: TTL=60s, smaller cache sizes
|
||||||
|
- **Medium mutation (10-100 changes/hour)**: TTL=300s (default)
|
||||||
|
- **Low mutation (<10 changes/hour)**: TTL=600s, larger cache sizes
|
||||||
|
|
||||||
|
## Security Considerations
|
||||||
|
|
||||||
|
**Query Injection Prevention:**
|
||||||
|
- Validate all entity identifiers and query parameters
|
||||||
|
- Use parameterized queries for all database interactions
|
||||||
|
- Implement query complexity limits to prevent DoS attacks
|
||||||
|
|
||||||
|
**Resource Protection:**
|
||||||
|
- Enforce maximum subgraph size limits
|
||||||
|
- Implement query timeouts to prevent resource exhaustion
|
||||||
|
- Add memory usage monitoring and limits
|
||||||
|
|
||||||
|
**Access Control:**
|
||||||
|
- Maintain existing user and collection isolation
|
||||||
|
- Add audit logging for performance-impacting operations
|
||||||
|
- Implement rate limiting for expensive operations
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
### Expected Performance Improvements
|
||||||
|
|
||||||
|
**Query Reduction:**
|
||||||
|
- Current: ~9,000+ queries for typical request
|
||||||
|
- Optimized: ~50-100 batched queries (98% reduction)
|
||||||
|
|
||||||
|
**Response Time Improvements:**
|
||||||
|
- Graph traversal: 15-20s → 3-5s (4-5x faster)
|
||||||
|
- Label resolution: 8-12s → 2-4s (3x faster)
|
||||||
|
- Overall query: 25-35s → 6-10s (3-4x improvement)
|
||||||
|
|
||||||
|
**Memory Efficiency:**
|
||||||
|
- Bounded cache sizes prevent memory leaks
|
||||||
|
- Efficient data structures reduce memory footprint by ~40%
|
||||||
|
- Better garbage collection through proper resource cleanup
|
||||||
|
|
||||||
|
**Realistic Performance Expectations:**
|
||||||
|
- **Label cache**: 10-20% query reduction for graphs with common relationships
|
||||||
|
- **Batching optimization**: 50-80% query reduction (primary optimization)
|
||||||
|
- **Object lifetime optimization**: Eliminate per-request creation overhead
|
||||||
|
- **Overall improvement**: 3-4x response time improvement primarily from batching
|
||||||
|
|
||||||
|
**Scalability Improvements:**
|
||||||
|
- Support for 3-5x larger knowledge graphs (limited by cache consistency needs)
|
||||||
|
- 3-5x higher concurrent request capacity
|
||||||
|
- Better resource utilization through connection reuse
|
||||||
|
|
||||||
|
### Performance Monitoring
|
||||||
|
|
||||||
|
**Real-time Metrics:**
|
||||||
|
- Query execution times by operation type
|
||||||
|
- Cache hit ratios and effectiveness
|
||||||
|
- Database connection pool utilisation
|
||||||
|
- Memory usage and garbage collection impact
|
||||||
|
|
||||||
|
**Performance Benchmarking:**
|
||||||
|
- Automated performance regression testing
|
||||||
|
- Load testing with realistic data volumes
|
||||||
|
- Comparison benchmarks against current implementation
|
||||||
|
|
||||||
|
## Testing Strategy
|
||||||
|
|
||||||
|
### Unit Testing
|
||||||
|
- Individual component testing for traversal, caching, and label resolution
|
||||||
|
- Mock database interactions for performance testing
|
||||||
|
- Cache eviction and TTL expiration testing
|
||||||
|
- Error handling and timeout scenarios
|
||||||
|
|
||||||
|
### Integration Testing
|
||||||
|
- End-to-end GraphRAG query testing with optimisations
|
||||||
|
- Database interaction testing with real data
|
||||||
|
- Concurrent request handling and resource management
|
||||||
|
- Memory leak detection and resource cleanup verification
|
||||||
|
|
||||||
|
### Performance Testing
|
||||||
|
- Benchmark testing against current implementation
|
||||||
|
- Load testing with varying graph sizes and complexities
|
||||||
|
- Stress testing for memory and connection limits
|
||||||
|
- Regression testing for performance improvements
|
||||||
|
|
||||||
|
### Compatibility Testing
|
||||||
|
- Verify existing GraphRAG API compatibility
|
||||||
|
- Test with various graph database backends
|
||||||
|
- Validate result accuracy compared to current implementation
|
||||||
|
|
||||||
|
## Implementation Plan
|
||||||
|
|
||||||
|
### Direct Implementation Approach
|
||||||
|
Since APIs are allowed to change, implement optimizations directly without migration complexity:
|
||||||
|
|
||||||
|
1. **Replace `follow_edges` method**: Rewrite with iterative batched traversal
|
||||||
|
2. **Optimize `get_labelgraph`**: Implement parallel label resolution
|
||||||
|
3. **Add long-lived GraphRag**: Modify Processor to maintain persistent instance
|
||||||
|
4. **Implement label caching**: Add LRU cache with TTL to GraphRag class
|
||||||
|
|
||||||
|
### Scope of Changes
|
||||||
|
- **Query class**: Replace ~50 lines in `follow_edges`, add ~30 lines batch handling
|
||||||
|
- **GraphRag class**: Add caching layer (~40 lines)
|
||||||
|
- **Processor class**: Modify to use persistent GraphRag instance (~20 lines)
|
||||||
|
- **Total**: ~140 lines of focused changes, mostly within existing classes
|
||||||
|
|
||||||
|
## Timeline
|
||||||
|
|
||||||
|
**Week 1: Core Implementation**
|
||||||
|
- Replace `follow_edges` with batched iterative traversal
|
||||||
|
- Implement parallel label resolution in `get_labelgraph`
|
||||||
|
- Add long-lived GraphRag instance to Processor
|
||||||
|
- Implement label caching layer
|
||||||
|
|
||||||
|
**Week 2: Testing and Integration**
|
||||||
|
- Unit tests for new traversal and caching logic
|
||||||
|
- Performance benchmarking against current implementation
|
||||||
|
- Integration testing with real graph data
|
||||||
|
- Code review and optimization
|
||||||
|
|
||||||
|
**Week 3: Deployment**
|
||||||
|
- Deploy optimized implementation
|
||||||
|
- Monitor performance improvements
|
||||||
|
- Fine-tune cache TTL and batch sizes based on real usage
|
||||||
|
|
||||||
|
## Open Questions
|
||||||
|
|
||||||
|
- **Database Connection Pooling**: Should we implement custom connection pooling or rely on existing database client pooling?
|
||||||
|
- **Cache Persistence**: Should label and embedding caches persist across service restarts?
|
||||||
|
- **Distributed Caching**: For multi-instance deployments, should we implement distributed caching with Redis/Memcached?
|
||||||
|
- **Query Result Format**: Should we optimize the internal triple representation for better memory efficiency?
|
||||||
|
- **Monitoring Integration**: Which metrics should be exposed to existing monitoring systems (Prometheus, etc.)?
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [GraphRAG Original Implementation](trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py)
|
||||||
|
- [TrustGraph Architecture Principles](architecture-principles.md)
|
||||||
|
- [Collection Management Specification](collection-management.md)
|
||||||
|
|
@ -34,7 +34,9 @@ class TestGraphRag:
|
||||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||||
assert graph_rag.triples_client == mock_triples_client
|
assert graph_rag.triples_client == mock_triples_client
|
||||||
assert graph_rag.verbose is False # Default value
|
assert graph_rag.verbose is False # Default value
|
||||||
assert graph_rag.label_cache == {} # Empty cache initially
|
# Verify label_cache is an LRUCacheWithTTL instance
|
||||||
|
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||||
|
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||||
|
|
||||||
def test_graph_rag_initialization_with_verbose(self):
|
def test_graph_rag_initialization_with_verbose(self):
|
||||||
"""Test GraphRag initialization with verbose enabled"""
|
"""Test GraphRag initialization with verbose enabled"""
|
||||||
|
|
@ -59,7 +61,9 @@ class TestGraphRag:
|
||||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||||
assert graph_rag.triples_client == mock_triples_client
|
assert graph_rag.triples_client == mock_triples_client
|
||||||
assert graph_rag.verbose is True
|
assert graph_rag.verbose is True
|
||||||
assert graph_rag.label_cache == {} # Empty cache initially
|
# Verify label_cache is an LRUCacheWithTTL instance
|
||||||
|
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||||
|
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||||
|
|
||||||
|
|
||||||
class TestQuery:
|
class TestQuery:
|
||||||
|
|
@ -228,7 +232,10 @@ class TestQuery:
|
||||||
"""Test Query.maybe_label method with cached label"""
|
"""Test Query.maybe_label method with cached label"""
|
||||||
# Create mock GraphRag with label cache
|
# Create mock GraphRag with label cache
|
||||||
mock_rag = MagicMock()
|
mock_rag = MagicMock()
|
||||||
mock_rag.label_cache = {"entity1": "Entity One Label"}
|
# Create mock LRUCacheWithTTL
|
||||||
|
mock_cache = MagicMock()
|
||||||
|
mock_cache.get.return_value = "Entity One Label"
|
||||||
|
mock_rag.label_cache = mock_cache
|
||||||
|
|
||||||
# Initialize Query
|
# Initialize Query
|
||||||
query = Query(
|
query = Query(
|
||||||
|
|
@ -243,13 +250,18 @@ class TestQuery:
|
||||||
|
|
||||||
# Verify cached label is returned
|
# Verify cached label is returned
|
||||||
assert result == "Entity One Label"
|
assert result == "Entity One Label"
|
||||||
|
# Verify cache was checked with proper key format (user:collection:entity)
|
||||||
|
mock_cache.get.assert_called_once_with("test_user:test_collection:entity1")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_maybe_label_with_label_lookup(self):
|
async def test_maybe_label_with_label_lookup(self):
|
||||||
"""Test Query.maybe_label method with database label lookup"""
|
"""Test Query.maybe_label method with database label lookup"""
|
||||||
# Create mock GraphRag with triples client
|
# Create mock GraphRag with triples client
|
||||||
mock_rag = MagicMock()
|
mock_rag = MagicMock()
|
||||||
mock_rag.label_cache = {} # Empty cache
|
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||||
|
mock_cache = MagicMock()
|
||||||
|
mock_cache.get.return_value = None
|
||||||
|
mock_rag.label_cache = mock_cache
|
||||||
mock_triples_client = AsyncMock()
|
mock_triples_client = AsyncMock()
|
||||||
mock_rag.triples_client = mock_triples_client
|
mock_rag.triples_client = mock_triples_client
|
||||||
|
|
||||||
|
|
@ -279,16 +291,20 @@ class TestQuery:
|
||||||
collection="test_collection"
|
collection="test_collection"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify result and cache update
|
# Verify result and cache update with proper key
|
||||||
assert result == "Human Readable Label"
|
assert result == "Human Readable Label"
|
||||||
assert mock_rag.label_cache["http://example.com/entity"] == "Human Readable Label"
|
cache_key = "test_user:test_collection:http://example.com/entity"
|
||||||
|
mock_cache.put.assert_called_once_with(cache_key, "Human Readable Label")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_maybe_label_with_no_label_found(self):
|
async def test_maybe_label_with_no_label_found(self):
|
||||||
"""Test Query.maybe_label method when no label is found"""
|
"""Test Query.maybe_label method when no label is found"""
|
||||||
# Create mock GraphRag with triples client
|
# Create mock GraphRag with triples client
|
||||||
mock_rag = MagicMock()
|
mock_rag = MagicMock()
|
||||||
mock_rag.label_cache = {} # Empty cache
|
# Create mock LRUCacheWithTTL that returns None (cache miss)
|
||||||
|
mock_cache = MagicMock()
|
||||||
|
mock_cache.get.return_value = None
|
||||||
|
mock_rag.label_cache = mock_cache
|
||||||
mock_triples_client = AsyncMock()
|
mock_triples_client = AsyncMock()
|
||||||
mock_rag.triples_client = mock_triples_client
|
mock_rag.triples_client = mock_triples_client
|
||||||
|
|
||||||
|
|
@ -318,7 +334,8 @@ class TestQuery:
|
||||||
|
|
||||||
# Verify result is entity itself and cache is updated
|
# Verify result is entity itself and cache is updated
|
||||||
assert result == "unlabeled_entity"
|
assert result == "unlabeled_entity"
|
||||||
assert mock_rag.label_cache["unlabeled_entity"] == "unlabeled_entity"
|
cache_key = "test_user:test_collection:unlabeled_entity"
|
||||||
|
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_follow_edges_basic_functionality(self):
|
async def test_follow_edges_basic_functionality(self):
|
||||||
|
|
@ -441,7 +458,7 @@ class TestQuery:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_subgraph_method(self):
|
async def test_get_subgraph_method(self):
|
||||||
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
|
"""Test Query.get_subgraph method orchestrates entity and edge discovery"""
|
||||||
# Create mock Query that patches get_entities and follow_edges
|
# Create mock Query that patches get_entities and follow_edges_batch
|
||||||
mock_rag = MagicMock()
|
mock_rag = MagicMock()
|
||||||
|
|
||||||
query = Query(
|
query = Query(
|
||||||
|
|
@ -455,11 +472,11 @@ class TestQuery:
|
||||||
# Mock get_entities to return test entities
|
# Mock get_entities to return test entities
|
||||||
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
|
query.get_entities = AsyncMock(return_value=["entity1", "entity2"])
|
||||||
|
|
||||||
# Mock follow_edges to add triples to subgraph
|
# Mock follow_edges_batch to return test triples
|
||||||
async def mock_follow_edges(ent, subgraph, path_length):
|
query.follow_edges_batch = AsyncMock(return_value={
|
||||||
subgraph.add((ent, "predicate", "object"))
|
("entity1", "predicate1", "object1"),
|
||||||
|
("entity2", "predicate2", "object2")
|
||||||
query.follow_edges = AsyncMock(side_effect=mock_follow_edges)
|
})
|
||||||
|
|
||||||
# Call get_subgraph
|
# Call get_subgraph
|
||||||
result = await query.get_subgraph("test query")
|
result = await query.get_subgraph("test query")
|
||||||
|
|
@ -467,14 +484,14 @@ class TestQuery:
|
||||||
# Verify get_entities was called
|
# Verify get_entities was called
|
||||||
query.get_entities.assert_called_once_with("test query")
|
query.get_entities.assert_called_once_with("test query")
|
||||||
|
|
||||||
# Verify follow_edges was called for each entity
|
# Verify follow_edges_batch was called with entities and max_path_length
|
||||||
assert query.follow_edges.call_count == 2
|
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||||
query.follow_edges.assert_any_call("entity1", unittest.mock.ANY, 1)
|
|
||||||
query.follow_edges.assert_any_call("entity2", unittest.mock.ANY, 1)
|
|
||||||
|
|
||||||
# Verify result is list format
|
# Verify result is list format and contains expected triples
|
||||||
assert isinstance(result, list)
|
assert isinstance(result, list)
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
|
assert ("entity1", "predicate1", "object1") in result
|
||||||
|
assert ("entity2", "predicate2", "object2") in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_labelgraph_method(self):
|
async def test_get_labelgraph_method(self):
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,56 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||||
|
|
||||||
|
class LRUCacheWithTTL:
|
||||||
|
"""LRU cache with TTL for label caching
|
||||||
|
|
||||||
|
CRITICAL SECURITY WARNING:
|
||||||
|
This cache is shared within a GraphRag instance but GraphRag instances
|
||||||
|
are created per-request. Cache keys MUST include user:collection prefix
|
||||||
|
to ensure data isolation between different security contexts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_size=5000, ttl=300):
|
||||||
|
self.cache = OrderedDict()
|
||||||
|
self.access_times = {}
|
||||||
|
self.max_size = max_size
|
||||||
|
self.ttl = ttl
|
||||||
|
|
||||||
|
def get(self, key):
|
||||||
|
if key not in self.cache:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check TTL expiration
|
||||||
|
if time.time() - self.access_times[key] > self.ttl:
|
||||||
|
del self.cache[key]
|
||||||
|
del self.access_times[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Move to end (most recently used)
|
||||||
|
self.cache.move_to_end(key)
|
||||||
|
return self.cache[key]
|
||||||
|
|
||||||
|
def put(self, key, value):
|
||||||
|
if key in self.cache:
|
||||||
|
self.cache.move_to_end(key)
|
||||||
|
else:
|
||||||
|
if len(self.cache) >= self.max_size:
|
||||||
|
# Remove least recently used
|
||||||
|
oldest_key = next(iter(self.cache))
|
||||||
|
del self.cache[oldest_key]
|
||||||
|
del self.access_times[oldest_key]
|
||||||
|
|
||||||
|
self.cache[key] = value
|
||||||
|
self.access_times[key] = time.time()
|
||||||
|
|
||||||
class Query:
|
class Query:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -61,8 +105,14 @@ class Query:
|
||||||
|
|
||||||
async def maybe_label(self, e):
|
async def maybe_label(self, e):
|
||||||
|
|
||||||
if e in self.rag.label_cache:
|
# CRITICAL SECURITY: Cache key MUST include user and collection
|
||||||
return self.rag.label_cache[e]
|
# to prevent data leakage between different contexts
|
||||||
|
cache_key = f"{self.user}:{self.collection}:{e}"
|
||||||
|
|
||||||
|
# Check LRU cache first with isolated key
|
||||||
|
cached_label = self.rag.label_cache.get(cache_key)
|
||||||
|
if cached_label is not None:
|
||||||
|
return cached_label
|
||||||
|
|
||||||
res = await self.rag.triples_client.query(
|
res = await self.rag.triples_client.query(
|
||||||
s=e, p=LABEL, o=None, limit=1,
|
s=e, p=LABEL, o=None, limit=1,
|
||||||
|
|
@ -70,60 +120,104 @@ class Query:
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(res) == 0:
|
if len(res) == 0:
|
||||||
self.rag.label_cache[e] = e
|
self.rag.label_cache.put(cache_key, e)
|
||||||
return e
|
return e
|
||||||
|
|
||||||
self.rag.label_cache[e] = str(res[0].o)
|
label = str(res[0].o)
|
||||||
return self.rag.label_cache[e]
|
self.rag.label_cache.put(cache_key, label)
|
||||||
|
return label
|
||||||
|
|
||||||
|
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
||||||
|
"""Execute triple queries for multiple entities concurrently"""
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
for entity in entities:
|
||||||
|
# Create concurrent tasks for all 3 query types per entity
|
||||||
|
tasks.extend([
|
||||||
|
self.rag.triples_client.query(
|
||||||
|
s=entity, p=None, o=None,
|
||||||
|
limit=limit_per_entity,
|
||||||
|
user=self.user, collection=self.collection
|
||||||
|
),
|
||||||
|
self.rag.triples_client.query(
|
||||||
|
s=None, p=entity, o=None,
|
||||||
|
limit=limit_per_entity,
|
||||||
|
user=self.user, collection=self.collection
|
||||||
|
),
|
||||||
|
self.rag.triples_client.query(
|
||||||
|
s=None, p=None, o=entity,
|
||||||
|
limit=limit_per_entity,
|
||||||
|
user=self.user, collection=self.collection
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Execute all queries concurrently
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Combine all results
|
||||||
|
all_triples = []
|
||||||
|
for result in results:
|
||||||
|
if not isinstance(result, Exception):
|
||||||
|
all_triples.extend(result)
|
||||||
|
|
||||||
|
return all_triples
|
||||||
|
|
||||||
|
async def follow_edges_batch(self, entities, max_depth):
|
||||||
|
"""Optimized iterative graph traversal with batching"""
|
||||||
|
visited = set()
|
||||||
|
current_level = set(entities)
|
||||||
|
subgraph = set()
|
||||||
|
|
||||||
|
for depth in range(max_depth):
|
||||||
|
if not current_level or len(subgraph) >= self.max_subgraph_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Filter out already visited entities
|
||||||
|
unvisited_entities = [e for e in current_level if e not in visited]
|
||||||
|
if not unvisited_entities:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Batch query all unvisited entities at current level
|
||||||
|
triples = await self.execute_batch_triple_queries(
|
||||||
|
unvisited_entities, self.triple_limit
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process results and collect next level entities
|
||||||
|
next_level = set()
|
||||||
|
for triple in triples:
|
||||||
|
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||||
|
subgraph.add(triple_tuple)
|
||||||
|
|
||||||
|
# Collect entities for next level (only from s and o positions)
|
||||||
|
if depth < max_depth - 1: # Don't collect for final depth
|
||||||
|
s, p, o = triple_tuple
|
||||||
|
if s not in visited:
|
||||||
|
next_level.add(s)
|
||||||
|
if o not in visited:
|
||||||
|
next_level.add(o)
|
||||||
|
|
||||||
|
# Stop if subgraph size limit reached
|
||||||
|
if len(subgraph) >= self.max_subgraph_size:
|
||||||
|
return subgraph
|
||||||
|
|
||||||
|
# Update for next iteration
|
||||||
|
visited.update(current_level)
|
||||||
|
current_level = next_level
|
||||||
|
|
||||||
|
return subgraph
|
||||||
|
|
||||||
async def follow_edges(self, ent, subgraph, path_length):
|
async def follow_edges(self, ent, subgraph, path_length):
|
||||||
|
"""Legacy method - replaced by follow_edges_batch"""
|
||||||
# Not needed?
|
# Maintain backward compatibility with early termination checks
|
||||||
if path_length <= 0:
|
if path_length <= 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Stop spanning around if the subgraph is already maxed out
|
|
||||||
if len(subgraph) >= self.max_subgraph_size:
|
if len(subgraph) >= self.max_subgraph_size:
|
||||||
return
|
return
|
||||||
|
|
||||||
res = await self.rag.triples_client.query(
|
# For backward compatibility, convert to new approach
|
||||||
s=ent, p=None, o=None,
|
batch_result = await self.follow_edges_batch([ent], path_length)
|
||||||
limit=self.triple_limit,
|
subgraph.update(batch_result)
|
||||||
user=self.user, collection=self.collection,
|
|
||||||
)
|
|
||||||
|
|
||||||
for triple in res:
|
|
||||||
subgraph.add(
|
|
||||||
(str(triple.s), str(triple.p), str(triple.o))
|
|
||||||
)
|
|
||||||
if path_length > 1:
|
|
||||||
await self.follow_edges(str(triple.o), subgraph, path_length-1)
|
|
||||||
|
|
||||||
res = await self.rag.triples_client.query(
|
|
||||||
s=None, p=ent, o=None,
|
|
||||||
limit=self.triple_limit,
|
|
||||||
user=self.user, collection=self.collection,
|
|
||||||
)
|
|
||||||
|
|
||||||
for triple in res:
|
|
||||||
subgraph.add(
|
|
||||||
(str(triple.s), str(triple.p), str(triple.o))
|
|
||||||
)
|
|
||||||
|
|
||||||
res = await self.rag.triples_client.query(
|
|
||||||
s=None, p=None, o=ent,
|
|
||||||
limit=self.triple_limit,
|
|
||||||
user=self.user, collection=self.collection,
|
|
||||||
)
|
|
||||||
|
|
||||||
for triple in res:
|
|
||||||
subgraph.add(
|
|
||||||
(str(triple.s), str(triple.p), str(triple.o))
|
|
||||||
)
|
|
||||||
if path_length > 1:
|
|
||||||
await self.follow_edges(
|
|
||||||
str(triple.s), subgraph, path_length-1
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_subgraph(self, query):
|
async def get_subgraph(self, query):
|
||||||
|
|
||||||
|
|
@ -132,31 +226,52 @@ class Query:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("Getting subgraph...")
|
logger.debug("Getting subgraph...")
|
||||||
|
|
||||||
subgraph = set()
|
# Use optimized batch traversal instead of sequential processing
|
||||||
|
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
|
||||||
|
|
||||||
for ent in entities:
|
return list(subgraph)
|
||||||
await self.follow_edges(ent, subgraph, self.max_path_length)
|
|
||||||
|
|
||||||
subgraph = list(subgraph)
|
async def resolve_labels_batch(self, entities):
|
||||||
|
"""Resolve labels for multiple entities in parallel"""
|
||||||
|
tasks = []
|
||||||
|
for entity in entities:
|
||||||
|
tasks.append(self.maybe_label(entity))
|
||||||
|
|
||||||
return subgraph
|
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
async def get_labelgraph(self, query):
|
async def get_labelgraph(self, query):
|
||||||
|
|
||||||
subgraph = await self.get_subgraph(query)
|
subgraph = await self.get_subgraph(query)
|
||||||
|
|
||||||
|
# Filter out label triples
|
||||||
|
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
|
||||||
|
|
||||||
|
# Collect all unique entities that need label resolution
|
||||||
|
entities_to_resolve = set()
|
||||||
|
for s, p, o in filtered_subgraph:
|
||||||
|
entities_to_resolve.update([s, p, o])
|
||||||
|
|
||||||
|
# Batch resolve labels for all entities in parallel
|
||||||
|
entity_list = list(entities_to_resolve)
|
||||||
|
resolved_labels = await self.resolve_labels_batch(entity_list)
|
||||||
|
|
||||||
|
# Create entity-to-label mapping
|
||||||
|
label_map = {}
|
||||||
|
for entity, label in zip(entity_list, resolved_labels):
|
||||||
|
if not isinstance(label, Exception):
|
||||||
|
label_map[entity] = label
|
||||||
|
else:
|
||||||
|
label_map[entity] = entity # Fallback to entity itself
|
||||||
|
|
||||||
|
# Apply labels to subgraph
|
||||||
sg2 = []
|
sg2 = []
|
||||||
|
for s, p, o in filtered_subgraph:
|
||||||
for edge in subgraph:
|
labeled_triple = (
|
||||||
|
label_map.get(s, s),
|
||||||
if edge[1] == LABEL:
|
label_map.get(p, p),
|
||||||
continue
|
label_map.get(o, o)
|
||||||
|
)
|
||||||
s = await self.maybe_label(edge[0])
|
sg2.append(labeled_triple)
|
||||||
p = await self.maybe_label(edge[1])
|
|
||||||
o = await self.maybe_label(edge[2])
|
|
||||||
|
|
||||||
sg2.append((s, p, o))
|
|
||||||
|
|
||||||
sg2 = sg2[0:self.max_subgraph_size]
|
sg2 = sg2[0:self.max_subgraph_size]
|
||||||
|
|
||||||
|
|
@ -171,6 +286,13 @@ class Query:
|
||||||
return sg2
|
return sg2
|
||||||
|
|
||||||
class GraphRag:
|
class GraphRag:
|
||||||
|
"""
|
||||||
|
CRITICAL SECURITY:
|
||||||
|
This class MUST be instantiated per-request to ensure proper isolation
|
||||||
|
between users and collections. The cache within this instance will only
|
||||||
|
live for the duration of a single request, preventing cross-contamination
|
||||||
|
of data between different security contexts.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, prompt_client, embeddings_client, graph_embeddings_client,
|
self, prompt_client, embeddings_client, graph_embeddings_client,
|
||||||
|
|
@ -184,7 +306,9 @@ class GraphRag:
|
||||||
self.graph_embeddings_client = graph_embeddings_client
|
self.graph_embeddings_client = graph_embeddings_client
|
||||||
self.triples_client = triples_client
|
self.triples_client = triples_client
|
||||||
|
|
||||||
self.label_cache = {}
|
# Replace simple dict with LRU cache with TTL
|
||||||
|
# CRITICAL: This cache only lives for one request due to per-request instantiation
|
||||||
|
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.debug("GraphRag initialized")
|
logger.debug("GraphRag initialized")
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,10 @@ class Processor(FlowProcessor):
|
||||||
self.default_max_subgraph_size = max_subgraph_size
|
self.default_max_subgraph_size = max_subgraph_size
|
||||||
self.default_max_path_length = max_path_length
|
self.default_max_path_length = max_path_length
|
||||||
|
|
||||||
|
# CRITICAL SECURITY: NEVER share data between users or collections
|
||||||
|
# Each user/collection combination MUST have isolated data access
|
||||||
|
# Caching must NEVER allow information leakage across these boundaries
|
||||||
|
|
||||||
self.register_specification(
|
self.register_specification(
|
||||||
ConsumerSpec(
|
ConsumerSpec(
|
||||||
name = "request",
|
name = "request",
|
||||||
|
|
@ -93,11 +97,14 @@ class Processor(FlowProcessor):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
self.rag = GraphRag(
|
# CRITICAL SECURITY: Create new GraphRag instance per request
|
||||||
embeddings_client = flow("embeddings-request"),
|
# This ensures proper isolation between users and collections
|
||||||
graph_embeddings_client = flow("graph-embeddings-request"),
|
# Flow clients are request-scoped and must not be shared
|
||||||
triples_client = flow("triples-request"),
|
rag = GraphRag(
|
||||||
prompt_client = flow("prompt-request"),
|
embeddings_client=flow("embeddings-request"),
|
||||||
|
graph_embeddings_client=flow("graph-embeddings-request"),
|
||||||
|
triples_client=flow("triples-request"),
|
||||||
|
prompt_client=flow("prompt-request"),
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -128,7 +135,7 @@ class Processor(FlowProcessor):
|
||||||
else:
|
else:
|
||||||
max_path_length = self.default_max_path_length
|
max_path_length = self.default_max_path_length
|
||||||
|
|
||||||
response = await self.rag.query(
|
response = await rag.query(
|
||||||
query = v.query, user = v.user, collection = v.collection,
|
query = v.query, user = v.user, collection = v.collection,
|
||||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||||
max_subgraph_size = max_subgraph_size,
|
max_subgraph_size = max_subgraph_size,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue