mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 09:26:22 +02:00
Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding (#697)
Enhance retrieval pipelines: 4-stage GraphRAG, DocRAG grounding, consistent PROV-O

GraphRAG:
- Split retrieval into 4 prompt stages: extract-concepts, kg-edge-scoring, kg-edge-reasoning, kg-synthesis (was single-stage)
- Add concept extraction (grounding) for per-concept embedding
- Filter main query to default graph, ignoring provenance/explainability edges
- Add source document edges to knowledge graph

DocumentRAG:
- Add grounding step with concept extraction, matching GraphRAG's pattern: Question → Grounding → Exploration → Synthesis
- Per-concept embedding and chunk retrieval with deduplication

Cross-pipeline:
- Make PROV-O derivation links consistent: wasGeneratedBy for first entity from Activity, wasDerivedFrom for entity-to-entity chains
- Update CLIs (tg-invoke-agent, tg-invoke-graph-rag, tg-invoke-document-rag) for new explainability structure
- Fix all affected unit and integration tests
This commit is contained in:
parent
29b4300808
commit
a115ec06ab
25 changed files with 1537 additions and 1008 deletions
|
|
@ -8,20 +8,23 @@ import uuid
|
|||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
|
||||
from ... schema import IRI, LITERAL
|
||||
from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE
|
||||
|
||||
# Provenance imports
|
||||
from trustgraph.provenance import (
|
||||
question_uri,
|
||||
grounding_uri as make_grounding_uri,
|
||||
exploration_uri as make_exploration_uri,
|
||||
focus_uri as make_focus_uri,
|
||||
synthesis_uri as make_synthesis_uri,
|
||||
question_triples,
|
||||
grounding_triples,
|
||||
exploration_triples,
|
||||
focus_triples,
|
||||
synthesis_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL,
|
||||
GRAPH_RETRIEVAL, GRAPH_SOURCE,
|
||||
TG_CONTAINS, PROV_WAS_DERIVED_FROM,
|
||||
)
|
||||
|
||||
# Module logger
|
||||
|
|
@ -47,6 +50,8 @@ def edge_id(s, p, o):
|
|||
edge_str = f"{s}|{p}|{o}"
|
||||
return hashlib.sha256(edge_str.encode()).hexdigest()[:8]
|
||||
|
||||
|
||||
|
||||
class LRUCacheWithTTL:
|
||||
"""LRU cache with TTL for label caching
|
||||
|
||||
|
|
@ -105,42 +110,88 @@ class Query:
|
|||
self.max_subgraph_size = max_subgraph_size
|
||||
self.max_path_length = max_path_length
|
||||
|
||||
async def get_vector(self, query):
|
||||
async def extract_concepts(self, query):
    """Extract key concepts from query for independent embedding.

    Sends the query to the "extract-concepts" prompt and normalises the
    reply into a list of concept strings.

    Args:
        query: The question text to extract concepts from.

    Returns:
        list: Whitespace-stripped, non-empty concept strings in the order
        they were returned, with exact duplicates removed. Falls back to
        ``[query]`` when the reply yields no usable concept.
    """
    response = await self.rag.prompt_client.prompt(
        "extract-concepts",
        variables={"query": query}
    )

    # Other prompt replies in this file are handled both as newline-separated
    # text and as an already-parsed list, so accept either form here rather
    # than silently discarding a list reply.
    if isinstance(response, str):
        candidates = response.strip().split('\n')
    elif isinstance(response, (list, tuple)):
        candidates = [item for item in response if isinstance(item, str)]
    else:
        candidates = []

    # dict.fromkeys() removes repeats while keeping first-seen order; a
    # repeated concept would otherwise be embedded and searched twice.
    stripped = (candidate.strip() for candidate in candidates)
    concepts = list(dict.fromkeys(c for c in stripped if c))

    if self.verbose:
        logger.debug(f"Extracted concepts: {concepts}")

    # Fall back to raw query if extraction returns nothing
    return concepts if concepts else [query]
|
||||
|
||||
async def get_vectors(self, concepts):
    """Embed a list of concepts with one batched embeddings request.

    Args:
        concepts: List of concept strings to embed.

    Returns:
        The embeddings client's reply for the batch, or an empty list when
        there is nothing to embed or the client returns None.
        NOTE(review): presumably one vector per concept, in input order —
        confirm against the embeddings client.
    """
    # Nothing to embed: skip the round trip to the embeddings service.
    if not concepts:
        return []

    if self.verbose:
        logger.debug("Computing embeddings...")

    qembeds = await self.rag.embeddings_client.embed(concepts)

    if self.verbose:
        logger.debug("Done.")

    # Callers take len() of the result, so never hand back None.
    return [] if qembeds is None else qembeds
|
||||
|
||||
async def get_entities(self, query):
    """
    Extract concepts from query, embed them, and retrieve matching entities.

    Args:
        query: The question text.

    Returns:
        tuple: (entities, concepts) where entities is a list of entity URI
        strings and concepts is the list of concept strings extracted
        from the query. entities is empty when no embedding vectors are
        available or no lookup succeeds.
    """
    concepts = await self.extract_concepts(query)

    vectors = await self.get_vectors(concepts)

    # No vectors means nothing to search with; returning here also avoids
    # dividing by zero when the entity budget is shared out below.
    if not vectors:
        return [], concepts

    if self.verbose:
        logger.debug("Getting entities...")

    # Share the entity budget across concepts, at least one each
    per_concept_limit = max(
        1, self.entity_limit // len(vectors)
    )

    # Query entity matches for each concept concurrently
    entity_tasks = [
        self.rag.graph_embeddings_client.query(
            vector=v, limit=per_concept_limit,
            user=self.user, collection=self.collection,
        )
        for v in vectors
    ]

    results = await asyncio.gather(*entity_tasks, return_exceptions=True)

    # Deduplicate while preserving order
    seen = set()
    entities = []
    for result in results:
        # gather() can hand back CancelledError, which is a BaseException
        # but not an Exception, so test against the base class.
        if isinstance(result, BaseException):
            logger.warning(f"Entity lookup failed: {result}")
            continue
        if not result:
            continue
        for e in result:
            entity = term_to_string(e.entity)
            if entity not in seen:
                seen.add(entity)
                entities.append(entity)

    if self.verbose:
        logger.debug("Entities:")
        for ent in entities:
            logger.debug(f" {ent}")

    return entities, concepts
|
||||
|
||||
async def maybe_label(self, e):
|
||||
|
||||
|
|
@ -156,6 +207,7 @@ class Query:
|
|||
res = await self.rag.triples_client.query(
|
||||
s=e, p=LABEL, o=None, limit=1,
|
||||
user=self.user, collection=self.collection,
|
||||
g="",
|
||||
)
|
||||
|
||||
if len(res) == 0:
|
||||
|
|
@ -177,19 +229,19 @@ class Query:
|
|||
s=entity, p=None, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
batch_size=20, g="",
|
||||
),
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=entity, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
batch_size=20, g="",
|
||||
),
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=None, o=entity,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
batch_size=20, g="",
|
||||
)
|
||||
])
|
||||
|
||||
|
|
@ -262,8 +314,16 @@ class Query:
|
|||
subgraph.update(batch_result)
|
||||
|
||||
async def get_subgraph(self, query):
|
||||
"""
|
||||
Get subgraph by extracting concepts, finding entities, and traversing.
|
||||
|
||||
entities = await self.get_entities(query)
|
||||
Returns:
|
||||
tuple: (subgraph, entities, concepts) where subgraph is a list of
|
||||
(s, p, o) tuples, entities is the seed entity list, and concepts
|
||||
is the extracted concept list.
|
||||
"""
|
||||
|
||||
entities, concepts = await self.get_entities(query)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting subgraph...")
|
||||
|
|
@ -271,7 +331,7 @@ class Query:
|
|||
# Use optimized batch traversal instead of sequential processing
|
||||
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
|
||||
|
||||
return list(subgraph)
|
||||
return list(subgraph), entities, concepts
|
||||
|
||||
async def resolve_labels_batch(self, entities):
|
||||
"""Resolve labels for multiple entities in parallel"""
|
||||
|
|
@ -286,11 +346,13 @@ class Query:
|
|||
Get subgraph with labels resolved for display.
|
||||
|
||||
Returns:
|
||||
tuple: (labeled_edges, uri_map) where:
|
||||
tuple: (labeled_edges, uri_map, entities, concepts) where:
|
||||
- labeled_edges: list of (label_s, label_p, label_o) tuples
|
||||
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
|
||||
- entities: list of seed entity URI strings
|
||||
- concepts: list of concept strings extracted from query
|
||||
"""
|
||||
subgraph = await self.get_subgraph(query)
|
||||
subgraph, entities, concepts = await self.get_subgraph(query)
|
||||
|
||||
# Filter out label triples
|
||||
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
|
||||
|
|
@ -338,8 +400,125 @@ class Query:
|
|||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
|
||||
return labeled_edges, uri_map
|
||||
|
||||
return labeled_edges, uri_map, entities, concepts
|
||||
|
||||
async def trace_source_documents(self, edge_uris):
    """
    Trace selected edges back to their source documents via provenance.

    Follows the chain: edge → subgraph (via tg:contains) → chunk →
    page → document (via prov:wasDerivedFrom), querying the GRAPH_SOURCE
    graph, then fetches the properties of every document reached.

    Args:
        edge_uris: List of (s, p, o) URI string tuples

    Returns:
        list: (s, p, o) string tuples describing the source documents,
        excluding rdf:type and prov:wasDerivedFrom triples. Empty when
        no edge can be traced to a document.
    """
    if not edge_uris:
        return []

    # Step 1: Find subgraphs containing these edges via tg:contains
    # NOTE(review): the object is always wrapped as an IRI; an edge whose
    # object is a literal would presumably never match — confirm.
    subgraph_tasks = []
    for s, p, o in edge_uris:
        quoted = Term(
            type=TRIPLE,
            triple=SchemaTriple(
                s=Term(type=IRI, iri=s),
                p=Term(type=IRI, iri=p),
                o=Term(type=IRI, iri=o),
            )
        )
        subgraph_tasks.append(
            self.rag.triples_client.query(
                s=None, p=TG_CONTAINS, o=quoted, limit=1,
                user=self.user, collection=self.collection,
                g=GRAPH_SOURCE,
            )
        )

    subgraph_results = await asyncio.gather(
        *subgraph_tasks, return_exceptions=True
    )

    # Collect unique subgraph URIs
    subgraph_uris = set()
    for result in subgraph_results:
        # BaseException so a CancelledError returned by gather() is
        # skipped rather than iterated.
        if isinstance(result, BaseException):
            logger.warning(f"Subgraph lookup failed: {result}")
            continue
        if not result:
            continue
        for triple in result:
            subgraph_uris.add(str(triple.s))

    if not subgraph_uris:
        return []

    # Step 2: Walk prov:wasDerivedFrom chain to find documents
    # Each level: query ?entity prov:wasDerivedFrom ?parent
    current_uris = subgraph_uris
    doc_uris = set()

    for _ in range(4):  # Max depth: subgraph → chunk → page → doc
        if not current_uris:
            break

        # Freeze the order so each result pairs with the URI queried.
        frontier = list(current_uris)

        derivation_tasks = [
            self.rag.triples_client.query(
                s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5,
                user=self.user, collection=self.collection,
                g=GRAPH_SOURCE,
            )
            for uri in frontier
        ]

        derivation_results = await asyncio.gather(
            *derivation_tasks, return_exceptions=True
        )

        next_uris = set()
        for uri, result in zip(frontier, derivation_results):
            if isinstance(result, BaseException):
                # A failed lookup says nothing about whether the URI has
                # a parent, so do not promote it to a document.
                logger.warning(
                    f"Derivation lookup failed for {uri}: {result}"
                )
                continue
            if not result:
                # URIs with no parent are root documents
                doc_uris.add(uri)
                continue
            for triple in result:
                next_uris.add(str(triple.o))

        current_uris = next_uris - doc_uris

    if not doc_uris:
        return []

    # Step 3: Get all document metadata properties
    # Skip structural predicates that aren't useful context
    SKIP_PREDICATES = {
        PROV_WAS_DERIVED_FROM,
        "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
    }

    # NOTE(review): no g= argument here, unlike the lookups above —
    # confirm which graph document metadata should be read from.
    metadata_tasks = [
        self.rag.triples_client.query(
            s=uri, p=None, o=None, limit=50,
            user=self.user, collection=self.collection,
        )
        for uri in doc_uris
    ]

    metadata_results = await asyncio.gather(
        *metadata_tasks, return_exceptions=True
    )

    doc_edges = []
    for result in metadata_results:
        if isinstance(result, BaseException):
            logger.warning(f"Document metadata lookup failed: {result}")
            continue
        if not result:
            continue
        for triple in result:
            p = str(triple.p)
            if p in SKIP_PREDICATES:
                continue
            doc_edges.append((
                str(triple.s), p, str(triple.o)
            ))

    return doc_edges
|
||||
|
||||
class GraphRag:
|
||||
"""
|
||||
CRITICAL SECURITY:
|
||||
|
|
@ -371,7 +550,8 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, user = "trustgraph", collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, streaming = False, chunk_callback = None,
|
||||
max_path_length = 2, edge_limit = 25, streaming = False,
|
||||
chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
):
|
||||
"""
|
||||
|
|
@ -399,6 +579,7 @@ class GraphRag:
|
|||
# Generate explainability URIs upfront
|
||||
session_id = str(uuid.uuid4())
|
||||
q_uri = question_uri(session_id)
|
||||
gnd_uri = make_grounding_uri(session_id)
|
||||
exp_uri = make_exploration_uri(session_id)
|
||||
foc_uri = make_focus_uri(session_id)
|
||||
syn_uri = make_synthesis_uri(session_id)
|
||||
|
|
@ -421,12 +602,23 @@ class GraphRag:
|
|||
max_path_length = max_path_length,
|
||||
)
|
||||
|
||||
kg, uri_map = await q.get_labelgraph(query)
|
||||
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
|
||||
|
||||
# Emit grounding explain after concept extraction
|
||||
if explain_callback:
|
||||
gnd_triples = set_graph(
|
||||
grounding_triples(gnd_uri, q_uri, concepts),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(gnd_triples, gnd_uri)
|
||||
|
||||
# Emit exploration explain after graph retrieval completes
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
exploration_triples(exp_uri, q_uri, len(kg)),
|
||||
exploration_triples(
|
||||
exp_uri, gnd_uri, len(kg),
|
||||
entities=seed_entities,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(exp_triples, exp_uri)
|
||||
|
|
@ -453,9 +645,9 @@ class GraphRag:
|
|||
if self.verbose:
|
||||
logger.debug(f"Built edge map with {len(edge_map)} edges")
|
||||
|
||||
# Step 1: Edge Selection - LLM selects relevant edges with reasoning
|
||||
selection_response = await self.prompt_client.prompt(
|
||||
"kg-edge-selection",
|
||||
# Step 1a: Edge Scoring - LLM scores edges for relevance
|
||||
scoring_response = await self.prompt_client.prompt(
|
||||
"kg-edge-scoring",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": edges_with_ids
|
||||
|
|
@ -463,52 +655,44 @@ class GraphRag:
|
|||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge selection response: {selection_response}")
|
||||
logger.debug(f"Edge scoring response: {scoring_response}")
|
||||
|
||||
# Parse response to get selected edge IDs and reasoning
|
||||
# Response can be a string (JSONL) or a list (JSON array)
|
||||
selected_ids = set()
|
||||
selected_edges_with_reasoning = [] # For explain
|
||||
# Parse scoring response to get edge IDs with scores
|
||||
scored_edges = []
|
||||
|
||||
if isinstance(selection_response, list):
|
||||
# JSON array response
|
||||
for obj in selection_response:
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
selected_ids.add(obj["id"])
|
||||
# Capture original URI edge (not labels) and reasoning for explain
|
||||
eid = obj["id"]
|
||||
if eid in uri_map:
|
||||
# Use original URIs for provenance tracing
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": obj.get("reasoning", ""),
|
||||
})
|
||||
elif isinstance(selection_response, str):
|
||||
# JSONL string response
|
||||
for line in selection_response.strip().split('\n'):
|
||||
def parse_scored_edge(obj):
|
||||
if isinstance(obj, dict) and "id" in obj and "score" in obj:
|
||||
try:
|
||||
score = int(obj["score"])
|
||||
except (ValueError, TypeError):
|
||||
score = 0
|
||||
scored_edges.append({"id": obj["id"], "score": score})
|
||||
|
||||
if isinstance(scoring_response, list):
|
||||
for obj in scoring_response:
|
||||
parse_scored_edge(obj)
|
||||
elif isinstance(scoring_response, str):
|
||||
for line in scoring_response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
if "id" in obj:
|
||||
selected_ids.add(obj["id"])
|
||||
# Capture original URI edge (not labels) and reasoning for explain
|
||||
eid = obj["id"]
|
||||
if eid in uri_map:
|
||||
# Use original URIs for provenance tracing
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": obj.get("reasoning", ""),
|
||||
})
|
||||
parse_scored_edge(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse edge selection line: {line}")
|
||||
continue
|
||||
logger.warning(
|
||||
f"Failed to parse edge scoring line: {line}"
|
||||
)
|
||||
|
||||
# Select top N edges by score
|
||||
scored_edges.sort(key=lambda x: x["score"], reverse=True)
|
||||
top_edges = scored_edges[:edge_limit]
|
||||
selected_ids = {e["id"] for e in top_edges}
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Selected {len(selected_ids)} edges: {selected_ids}")
|
||||
logger.debug(
|
||||
f"Scored {len(scored_edges)} edges, "
|
||||
f"selected top {len(selected_ids)}"
|
||||
)
|
||||
|
||||
# Filter to selected edges
|
||||
selected_edges = []
|
||||
|
|
@ -516,6 +700,82 @@ class GraphRag:
|
|||
if eid in edge_map:
|
||||
selected_edges.append(edge_map[eid])
|
||||
|
||||
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
|
||||
selected_edges_with_ids = [
|
||||
{"id": eid, "s": s, "p": p, "o": o}
|
||||
for eid in selected_ids
|
||||
if eid in edge_map
|
||||
for s, p, o in [edge_map[eid]]
|
||||
]
|
||||
|
||||
# Collect selected edge URIs for document tracing
|
||||
selected_edge_uris = [
|
||||
uri_map[eid]
|
||||
for eid in selected_ids
|
||||
if eid in uri_map
|
||||
]
|
||||
|
||||
# Run reasoning and document tracing concurrently
|
||||
reasoning_task = self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
doc_trace_task = q.trace_source_documents(selected_edge_uris)
|
||||
|
||||
reasoning_response, source_documents = await asyncio.gather(
|
||||
reasoning_task, doc_trace_task, return_exceptions=True
|
||||
)
|
||||
|
||||
# Handle exceptions from gather
|
||||
if isinstance(reasoning_response, Exception):
|
||||
logger.warning(
|
||||
f"Edge reasoning failed: {reasoning_response}"
|
||||
)
|
||||
reasoning_response = ""
|
||||
if isinstance(source_documents, Exception):
|
||||
logger.warning(
|
||||
f"Document tracing failed: {source_documents}"
|
||||
)
|
||||
source_documents = []
|
||||
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge reasoning response: {reasoning_response}")
|
||||
|
||||
# Parse reasoning response and build explainability data
|
||||
reasoning_map = {}
|
||||
|
||||
def parse_reasoning(obj):
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
reasoning_map[obj["id"]] = obj.get("reasoning", "")
|
||||
|
||||
if isinstance(reasoning_response, list):
|
||||
for obj in reasoning_response:
|
||||
parse_reasoning(obj)
|
||||
elif isinstance(reasoning_response, str):
|
||||
for line in reasoning_response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
parse_reasoning(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge reasoning line: {line}"
|
||||
)
|
||||
|
||||
selected_edges_with_reasoning = []
|
||||
for eid in selected_ids:
|
||||
if eid in uri_map:
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": reasoning_map.get(eid, ""),
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Filtered to {len(selected_edges)} edges")
|
||||
|
||||
|
|
@ -534,6 +794,18 @@ class GraphRag:
|
|||
{"s": s, "p": p, "o": o}
|
||||
for s, p, o in selected_edges
|
||||
]
|
||||
|
||||
# Add source document metadata as knowledge edges
|
||||
for s, p, o in source_documents:
|
||||
selected_edge_dicts.append({
|
||||
"s": s, "p": p, "o": o,
|
||||
})
|
||||
|
||||
synthesis_variables = {
|
||||
"query": query,
|
||||
"knowledge": selected_edge_dicts,
|
||||
}
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Accumulate chunks for answer storage while forwarding to callback
|
||||
accumulated_chunks = []
|
||||
|
|
@ -544,10 +816,7 @@ class GraphRag:
|
|||
|
||||
await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edge_dicts
|
||||
},
|
||||
variables=synthesis_variables,
|
||||
streaming=True,
|
||||
chunk_callback=accumulating_callback
|
||||
)
|
||||
|
|
@ -556,10 +825,7 @@ class GraphRag:
|
|||
else:
|
||||
resp = await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edge_dicts
|
||||
}
|
||||
variables=synthesis_variables,
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
|
|
@ -570,9 +836,8 @@ class GraphRag:
|
|||
synthesis_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian if callback provided
|
||||
# Save answer to librarian
|
||||
if save_answer_callback and answer_text:
|
||||
# Generate document ID as URN matching query-time provenance format
|
||||
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
|
||||
try:
|
||||
await save_answer_callback(synthesis_doc_id, answer_text)
|
||||
|
|
@ -580,13 +845,11 @@ class GraphRag:
|
|||
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
synthesis_doc_id = None # Fall back to inline content
|
||||
synthesis_doc_id = None
|
||||
|
||||
# Generate triples with document reference or inline content
|
||||
syn_triples = set_graph(
|
||||
synthesis_triples(
|
||||
syn_uri, foc_uri,
|
||||
answer_text="" if synthesis_doc_id else answer_text,
|
||||
document_id=synthesis_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ class Processor(FlowProcessor):
|
|||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
edge_limit = params.get("edge_limit", 25)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
|
|
@ -48,6 +49,7 @@ class Processor(FlowProcessor):
|
|||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"edge_limit": edge_limit,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -55,6 +57,7 @@ class Processor(FlowProcessor):
|
|||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.default_edge_limit = edge_limit
|
||||
|
||||
# CRITICAL SECURITY: NEVER share data between users or collections
|
||||
# Each user/collection combination MUST have isolated data access
|
||||
|
|
@ -292,6 +295,11 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
if v.edge_limit:
|
||||
edge_limit = v.edge_limit
|
||||
else:
|
||||
edge_limit = self.default_edge_limit
|
||||
|
||||
# Callback to save answer content to librarian
|
||||
async def save_answer(doc_id, answer_text):
|
||||
await self.save_answer_content(
|
||||
|
|
@ -322,6 +330,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_limit = edge_limit,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
explain_callback = send_explainability,
|
||||
|
|
@ -335,6 +344,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_limit = edge_limit,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue