Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding (#697)

Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding,
consistent PROV-O

GraphRAG:
- Split retrieval into 4 prompt stages: extract-concepts, kg-edge-scoring,
  kg-edge-reasoning, kg-synthesis (was single-stage; see the sketch after
  this list)
- Add concept extraction (grounding) for per-concept embedding
- Filter main query to default graph, ignoring
  provenance/explainability edges
- Add source document edges to knowledge graph
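
A minimal sketch of the new stage flow. Here prompt() stands in for the
pipeline's prompt client, subgraph_for() is a hypothetical helper for the
graph traversal between stages, and parsed-JSON list responses are
assumed (the real code also handles JSONL strings):

    async def graph_rag_stages(query, prompt, subgraph_for, edge_limit=25):
        # Stage 1: grounding - extract key concepts from the query
        concepts = await prompt("extract-concepts", variables={"query": query})

        # Retrieve candidate edges seeded from the concepts (traversal elided)
        edges = await subgraph_for(concepts)

        # Stage 2: score every candidate edge for relevance to the query
        scored = await prompt("kg-edge-scoring",
                              variables={"query": query, "knowledge": edges})

        # Keep only the top-N edges by score
        top = sorted(scored, key=lambda e: e["score"], reverse=True)[:edge_limit]

        # Stage 3: per-edge reasoning over the selected edges only
        await prompt("kg-edge-reasoning",
                     variables={"query": query, "knowledge": top})

        # Stage 4: synthesise the final answer from the selected edges
        return await prompt("kg-synthesis",
                            variables={"query": query, "knowledge": top})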

DocumentRAG:
- Add grounding step with concept extraction, matching GraphRAG's
  pattern: Question → Grounding → Exploration → Synthesis
- Per-concept embedding and chunk retrieval with deduplication (see the
  sketch after this list)
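
A minimal sketch of the grounding and retrieval step. Here
extract_concepts, embed and chunk_store.query stand in for the prompt,
embeddings and document-embeddings clients, and the chunk .id attribute
is an assumption:

    async def retrieve_chunks(query, extract_concepts, embed, chunk_store,
                              limit=20):
        # Grounding: extract concepts, falling back to the raw query
        concepts = await extract_concepts(query) or [query]

        # One embedding per concept; share the retrieval budget between them
        vectors = await embed(concepts)
        per_concept = max(1, limit // len(vectors))

        # Exploration: per-concept retrieval, deduplicated preserving order
        seen, chunks = set(), []
        for v in vectors:
            for chunk in await chunk_store.query(vector=v, limit=per_concept):
                if chunk.id not in seen:
                    seen.add(chunk.id)
                    chunks.append(chunk)
        return chunks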

Cross-pipeline:
- Make PROV-O derivation links consistent: wasGeneratedBy for the first
  entity from an Activity, wasDerivedFrom for entity-to-entity chains
  (illustrated after this list)
- Update CLIs (tg-invoke-agent, tg-invoke-graph-rag,
  tg-invoke-document-rag) for new explainability structure
- Fix all affected unit and integration tests
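
The derivation pattern, as illustrative triples. Session URN segments are
abbreviated with "..."; the exact URIs come from trustgraph.provenance,
and the question is treated as the Activity here, matching the
grounding_triples(gnd_uri, q_uri, ...) call in the diff:

    # One retrieval session, S = session id. The first entity an Activity
    # produces links back with prov:wasGeneratedBy; each later entity
    # chains off its predecessor with prov:wasDerivedFrom.
    provenance = [
        ("urn:...:grounding:S",   "prov:wasGeneratedBy", "urn:...:question:S"),
        ("urn:...:exploration:S", "prov:wasDerivedFrom", "urn:...:grounding:S"),
        ("urn:...:focus:S",       "prov:wasDerivedFrom", "urn:...:exploration:S"),
        ("urn:...:synthesis:S",   "prov:wasDerivedFrom", "urn:...:focus:S"),
    ]
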
cybermaggedon committed 2026-03-16 12:12:13 +00:00 (commit a115ec06ab, parent 29b4300808)
25 changed files with 1537 additions and 1008 deletions


@@ -8,20 +8,23 @@ import uuid
from collections import OrderedDict
from datetime import datetime
from ... schema import IRI, LITERAL
from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE
# Provenance imports
from trustgraph.provenance import (
question_uri,
grounding_uri as make_grounding_uri,
exploration_uri as make_exploration_uri,
focus_uri as make_focus_uri,
synthesis_uri as make_synthesis_uri,
question_triples,
grounding_triples,
exploration_triples,
focus_triples,
synthesis_triples,
set_graph,
GRAPH_RETRIEVAL,
GRAPH_RETRIEVAL, GRAPH_SOURCE,
TG_CONTAINS, PROV_WAS_DERIVED_FROM,
)
# Module logger
@@ -47,6 +50,8 @@ def edge_id(s, p, o):
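# Stable 8-hex-char edge ID, used so prompt responses can reference edges compactly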
edge_str = f"{s}|{p}|{o}"
return hashlib.sha256(edge_str.encode()).hexdigest()[:8]
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
@@ -105,42 +110,88 @@ class Query:
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
async def get_vector(self, query):
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
response = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
concepts = []
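# Concepts are expected one per line; blank lines are skipped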
if isinstance(response, str):
for line in response.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
if self.verbose:
logger.debug(f"Extracted concepts: {concepts}")
# Fall back to raw query if extraction returns nothing
return concepts if concepts else [query]
async def get_vectors(self, concepts):
"""Embed multiple concepts concurrently."""
if self.verbose:
logger.debug("Computing embeddings...")
qembeds = await self.rag.embeddings_client.embed([query])
qembeds = await self.rag.embeddings_client.embed(concepts)
if self.verbose:
logger.debug("Done.")
# Return the vector set for the first (only) text
return qembeds[0] if qembeds else []
return qembeds
async def get_entities(self, query):
vectors = await self.get_vector(query)
"""
Extract concepts from query, embed them, and retrieve matching entities.

Returns:
tuple: (entities, concepts) where entities is a list of entity URI
strings and concepts is the list of concept strings extracted
from the query.
"""
concepts = await self.extract_concepts(query)
vectors = await self.get_vectors(concepts)
if self.verbose:
logger.debug("Getting entities...")
entity_matches = await self.rag.graph_embeddings_client.query(
vector=vectors, limit=self.entity_limit,
user=self.user, collection=self.collection,
# Query entity matches for each concept concurrently
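# Split the total entity budget evenly across concepts, but never below one per concept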
per_concept_limit = max(
1, self.entity_limit // len(vectors)
)
entities = [
term_to_string(e.entity)
for e in entity_matches
entity_tasks = [
self.rag.graph_embeddings_client.query(
vector=v, limit=per_concept_limit,
user=self.user, collection=self.collection,
)
for v in vectors
]
results = await asyncio.gather(*entity_tasks, return_exceptions=True)
# Deduplicate while preserving order
seen = set()
entities = []
for result in results:
if isinstance(result, Exception) or not result:
continue
for e in result:
entity = term_to_string(e.entity)
if entity not in seen:
seen.add(entity)
entities.append(entity)
if self.verbose:
logger.debug("Entities:")
for ent in entities:
logger.debug(f" {ent}")
return entities
return entities, concepts
async def maybe_label(self, e):
@@ -156,6 +207,7 @@ class Query:
res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1,
user=self.user, collection=self.collection,
g="",
)
if len(res) == 0:
@@ -177,19 +229,19 @@
s=entity, p=None, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection,
batch_size=20,
batch_size=20, g="",
),
self.rag.triples_client.query_stream(
s=None, p=entity, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection,
batch_size=20,
batch_size=20, g="",
),
self.rag.triples_client.query_stream(
s=None, p=None, o=entity,
limit=limit_per_entity,
user=self.user, collection=self.collection,
batch_size=20,
batch_size=20, g="",
)
])
@@ -262,8 +314,16 @@ class Query:
subgraph.update(batch_result)
async def get_subgraph(self, query):
entities = await self.get_entities(query)
"""
Get subgraph by extracting concepts, finding entities, and traversing.

Returns:
tuple: (subgraph, entities, concepts) where subgraph is a list of
(s, p, o) tuples, entities is the seed entity list, and concepts
is the extracted concept list.
"""
entities, concepts = await self.get_entities(query)
if self.verbose:
logger.debug("Getting subgraph...")
@@ -271,7 +331,7 @@
# Use optimized batch traversal instead of sequential processing
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
return list(subgraph)
return list(subgraph), entities, concepts
async def resolve_labels_batch(self, entities):
"""Resolve labels for multiple entities in parallel"""
@@ -286,11 +346,13 @@
Get subgraph with labels resolved for display.
Returns:
tuple: (labeled_edges, uri_map) where:
tuple: (labeled_edges, uri_map, entities, concepts) where:
- labeled_edges: list of (label_s, label_p, label_o) tuples
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
- entities: list of seed entity URI strings
- concepts: list of concept strings extracted from query
"""
subgraph = await self.get_subgraph(query)
subgraph, entities, concepts = await self.get_subgraph(query)
# Filter out label triples
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
@@ -338,8 +400,125 @@ class Query:
if self.verbose:
logger.debug("Done.")
return labeled_edges, uri_map
return labeled_edges, uri_map, entities, concepts
async def trace_source_documents(self, edge_uris):
"""
Trace selected edges back to their source documents via provenance.
Follows the chain: edge → subgraph (via tg:contains) → chunk →
page → document (via prov:wasDerivedFrom), all in urn:graph:source.
Args:
edge_uris: List of (s, p, o) URI string tuples
Returns:
List of (s, p, o) document metadata edges for the traced documents
"""
# Step 1: Find subgraphs containing these edges via tg:contains
subgraph_tasks = []
for s, p, o in edge_uris:
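# Build a quoted (RDF-star style) triple term so the edge itself can be matched as the object of tg:contains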
quoted = Term(
type=TRIPLE,
triple=SchemaTriple(
s=Term(type=IRI, iri=s),
p=Term(type=IRI, iri=p),
o=Term(type=IRI, iri=o),
)
)
subgraph_tasks.append(
self.rag.triples_client.query(
s=None, p=TG_CONTAINS, o=quoted, limit=1,
user=self.user, collection=self.collection,
g=GRAPH_SOURCE,
)
)
subgraph_results = await asyncio.gather(
*subgraph_tasks, return_exceptions=True
)
# Collect unique subgraph URIs
subgraph_uris = set()
for result in subgraph_results:
if isinstance(result, Exception) or not result:
continue
for triple in result:
subgraph_uris.add(str(triple.s))
if not subgraph_uris:
return []
# Step 2: Walk prov:wasDerivedFrom chain to find documents
# Each level: query ?entity prov:wasDerivedFrom ?parent
# Stop when we find entities typed tg:Document
current_uris = subgraph_uris
doc_uris = set()
for depth in range(4): # Max depth: subgraph → chunk → page → doc
if not current_uris:
break
derivation_tasks = [
self.rag.triples_client.query(
s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5,
user=self.user, collection=self.collection,
g=GRAPH_SOURCE,
)
for uri in current_uris
]
derivation_results = await asyncio.gather(
*derivation_tasks, return_exceptions=True
)
# URIs with no parent are root documents
next_uris = set()
for uri, result in zip(current_uris, derivation_results):
if isinstance(result, Exception) or not result:
doc_uris.add(uri)
continue
for triple in result:
next_uris.add(str(triple.o))
current_uris = next_uris - doc_uris
if not doc_uris:
return []
# Step 3: Get all document metadata properties
# Skip structural predicates that aren't useful context
SKIP_PREDICATES = {
PROV_WAS_DERIVED_FROM,
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
}
metadata_tasks = [
self.rag.triples_client.query(
s=uri, p=None, o=None, limit=50,
user=self.user, collection=self.collection,
)
for uri in doc_uris
]
metadata_results = await asyncio.gather(
*metadata_tasks, return_exceptions=True
)
doc_edges = []
for result in metadata_results:
if isinstance(result, Exception) or not result:
continue
for triple in result:
p = str(triple.p)
if p in SKIP_PREDICATES:
continue
doc_edges.append((
str(triple.s), p, str(triple.o)
))
return doc_edges
class GraphRag:
"""
CRITICAL SECURITY:
@@ -371,7 +550,8 @@ class GraphRag:
async def query(
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, streaming = False, chunk_callback = None,
max_path_length = 2, edge_limit = 25, streaming = False,
chunk_callback = None,
explain_callback = None, save_answer_callback = None,
):
"""
@@ -399,6 +579,7 @@
# Generate explainability URIs upfront
session_id = str(uuid.uuid4())
q_uri = question_uri(session_id)
gnd_uri = make_grounding_uri(session_id)
exp_uri = make_exploration_uri(session_id)
foc_uri = make_focus_uri(session_id)
syn_uri = make_synthesis_uri(session_id)
@@ -421,12 +602,23 @@
max_path_length = max_path_length,
)
kg, uri_map = await q.get_labelgraph(query)
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
# Emit grounding explain after concept extraction
if explain_callback:
gnd_triples = set_graph(
grounding_triples(gnd_uri, q_uri, concepts),
GRAPH_RETRIEVAL
)
await explain_callback(gnd_triples, gnd_uri)
# Emit exploration explain after graph retrieval completes
if explain_callback:
exp_triples = set_graph(
exploration_triples(exp_uri, q_uri, len(kg)),
exploration_triples(
exp_uri, gnd_uri, len(kg),
entities=seed_entities,
),
GRAPH_RETRIEVAL
)
await explain_callback(exp_triples, exp_uri)
@@ -453,9 +645,9 @@
if self.verbose:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1: Edge Selection - LLM selects relevant edges with reasoning
selection_response = await self.prompt_client.prompt(
"kg-edge-selection",
# Step 1a: Edge Scoring - LLM scores edges for relevance
scoring_response = await self.prompt_client.prompt(
"kg-edge-scoring",
variables={
"query": query,
"knowledge": edges_with_ids
@@ -463,52 +655,44 @@
)
if self.verbose:
logger.debug(f"Edge selection response: {selection_response}")
logger.debug(f"Edge scoring response: {scoring_response}")
# Parse response to get selected edge IDs and reasoning
# Response can be a string (JSONL) or a list (JSON array)
selected_ids = set()
selected_edges_with_reasoning = [] # For explain
# Parse scoring response to get edge IDs with scores
scored_edges = []
if isinstance(selection_response, list):
# JSON array response
for obj in selection_response:
if isinstance(obj, dict) and "id" in obj:
selected_ids.add(obj["id"])
# Capture original URI edge (not labels) and reasoning for explain
eid = obj["id"]
if eid in uri_map:
# Use original URIs for provenance tracing
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": obj.get("reasoning", ""),
})
elif isinstance(selection_response, str):
# JSONL string response
for line in selection_response.strip().split('\n'):
def parse_scored_edge(obj):
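# Expected item shape: {"id": "<edge id>", "score": <int>}; anything else is ignored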
if isinstance(obj, dict) and "id" in obj and "score" in obj:
try:
score = int(obj["score"])
except (ValueError, TypeError):
score = 0
scored_edges.append({"id": obj["id"], "score": score})
if isinstance(scoring_response, list):
for obj in scoring_response:
parse_scored_edge(obj)
elif isinstance(scoring_response, str):
for line in scoring_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
if "id" in obj:
selected_ids.add(obj["id"])
# Capture original URI edge (not labels) and reasoning for explain
eid = obj["id"]
if eid in uri_map:
# Use original URIs for provenance tracing
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": obj.get("reasoning", ""),
})
parse_scored_edge(json.loads(line))
except json.JSONDecodeError:
logger.warning(f"Failed to parse edge selection line: {line}")
continue
logger.warning(
f"Failed to parse edge scoring line: {line}"
)
# Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit]
selected_ids = {e["id"] for e in top_edges}
if self.verbose:
logger.debug(f"Selected {len(selected_ids)} edges: {selected_ids}")
logger.debug(
f"Scored {len(scored_edges)} edges, "
f"selected top {len(selected_ids)}"
)
# Filter to selected edges
selected_edges = []
@@ -516,6 +700,82 @@
if eid in edge_map:
selected_edges.append(edge_map[eid])
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
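# (the single-element list in the comprehension below just unpacks the (s, p, o) tuple)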
selected_edges_with_ids = [
{"id": eid, "s": s, "p": p, "o": o}
for eid in selected_ids
if eid in edge_map
for s, p, o in [edge_map[eid]]
]
# Collect selected edge URIs for document tracing
selected_edge_uris = [
uri_map[eid]
for eid in selected_ids
if eid in uri_map
]
# Run reasoning and document tracing concurrently
reasoning_task = self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_response, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True
)
# Handle exceptions from gather
if isinstance(reasoning_response, Exception):
logger.warning(
f"Edge reasoning failed: {reasoning_response}"
)
reasoning_response = ""
if isinstance(source_documents, Exception):
logger.warning(
f"Document tracing failed: {source_documents}"
)
source_documents = []
if self.verbose:
logger.debug(f"Edge reasoning response: {reasoning_response}")
# Parse reasoning response and build explainability data
reasoning_map = {}
def parse_reasoning(obj):
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
if isinstance(reasoning_response, list):
for obj in reasoning_response:
parse_reasoning(obj)
elif isinstance(reasoning_response, str):
for line in reasoning_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_reasoning(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge reasoning line: {line}"
)
selected_edges_with_reasoning = []
for eid in selected_ids:
if eid in uri_map:
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": reasoning_map.get(eid, ""),
})
if self.verbose:
logger.debug(f"Filtered to {len(selected_edges)} edges")
@@ -534,6 +794,18 @@
{"s": s, "p": p, "o": o}
for s, p, o in selected_edges
]
# Add source document metadata as knowledge edges
for s, p, o in source_documents:
selected_edge_dicts.append({
"s": s, "p": p, "o": o,
})
synthesis_variables = {
"query": query,
"knowledge": selected_edge_dicts,
}
if streaming and chunk_callback:
# Accumulate chunks for answer storage while forwarding to callback
accumulated_chunks = []
@@ -544,10 +816,7 @@
await self.prompt_client.prompt(
"kg-synthesis",
variables={
"query": query,
"knowledge": selected_edge_dicts
},
variables=synthesis_variables,
streaming=True,
chunk_callback=accumulating_callback
)
@@ -556,10 +825,7 @@
else:
resp = await self.prompt_client.prompt(
"kg-synthesis",
variables={
"query": query,
"knowledge": selected_edge_dicts
}
variables=synthesis_variables,
)
if self.verbose:
@@ -570,9 +836,8 @@
synthesis_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian if callback provided
# Save answer to librarian
if save_answer_callback and answer_text:
# Generate document ID as URN matching query-time provenance format
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
try:
await save_answer_callback(synthesis_doc_id, answer_text)
@@ -580,13 +845,11 @@
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None # Fall back to inline content
synthesis_doc_id = None
# Generate triples with document reference or inline content
syn_triples = set_graph(
synthesis_triples(
syn_uri, foc_uri,
answer_text="" if synthesis_doc_id else answer_text,
document_id=synthesis_doc_id,
),
GRAPH_RETRIEVAL


@@ -39,6 +39,7 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
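# Cap on the number of top-scored edges passed to the reasoning and synthesis stages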
edge_limit = params.get("edge_limit", 25)
super(Processor, self).__init__(
**params | {
@@ -48,6 +49,7 @@
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"edge_limit": edge_limit,
}
)
@@ -55,6 +57,7 @@
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.default_edge_limit = edge_limit
# CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access
@@ -292,6 +295,11 @@
else:
max_path_length = self.default_max_path_length
if v.edge_limit:
edge_limit = v.edge_limit
else:
edge_limit = self.default_edge_limit
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
@@ -322,6 +330,7 @@
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
streaming = True,
chunk_callback = send_chunk,
explain_callback = send_explainability,
@@ -335,6 +344,7 @@
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
explain_callback = send_explainability,
save_answer_callback = save_answer,
)