feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG

Replace the three-prompt LLM scoring pipeline (kg-edge-scoring,
kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker
service backed by FlashRank. The new hop_and_filter() method performs
iterative graph traversal with semantic scoring at each hop, replacing
the previous follow_edges/get_subgraph approach.

- Add reranker service (trustgraph-base client/service, FlashRank processor)
- Add gateway dispatch for reranker via API and WebSocket
- Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring
- Remove kg_prompt() and edge_score_limit from prompt client
- Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates
- Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata
- Add tg-invoke-reranker CLI tool
- Add tech spec and UX developer guidance
- Update all unit and integration tests
This commit is contained in:
Cyber MacGeddon 2026-06-30 09:39:35 +01:00
parent 1aa9549912
commit 1346cbebb4
43 changed files with 1613 additions and 792 deletions

View file

@ -19,6 +19,7 @@ dependencies = [
"faiss-cpu",
"falkordb",
"fastembed",
"flashrank",
"ibis",
"jsonschema",
"langchain",
@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
graph-rag = "trustgraph.retrieval.graph_rag:run"
reranker-flashrank = "trustgraph.reranker.flashrank:run"
kg-extract-agent = "trustgraph.extract.kg.agent:run"
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
kg-extract-rows = "trustgraph.extract.kg.rows:run"

View file

@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
from . row_embeddings_query import RowEmbeddingsQueryRequestor
from . mcp_tool import McpToolRequestor
from . reranker import RerankerRequestor
from . text_load import TextLoad
from . document_load import DocumentLoad
@ -74,6 +75,7 @@ request_response_dispatchers = {
"structured-diag": StructuredDiagRequestor,
"row-embeddings": RowEmbeddingsQueryRequestor,
"sparql": SparqlQueryRequestor,
"reranker": RerankerRequestor,
}
system_dispatchers = {

View file

@ -0,0 +1,31 @@
from ... schema import RerankerRequest, RerankerResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class RerankerRequestor(ServiceRequestor):
    """Gateway requestor for the reranker service.

    Configures the underlying ``ServiceRequestor`` with the reranker
    request/response schemas and converts message bodies using the
    translators registered under the name "reranker".
    """

    def __init__(
            self, backend, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):
        """Set up the requestor and look up the reranker translators.

        Args:
            backend: Passed through to ``ServiceRequestor``.
            request_queue: Passed through to ``ServiceRequestor``.
            response_queue: Passed through to ``ServiceRequestor``.
            timeout: Passed through to ``ServiceRequestor``.
            consumer: Forwarded as ``consumer_name``.
            subscriber: Forwarded as ``subscription``.
        """
        super().__init__(
            backend=backend,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=RerankerRequest,
            response_schema=RerankerResponse,
            subscription=subscriber,
            consumer_name=consumer,
            timeout=timeout,
        )

        # Both directions use the translators registered as "reranker".
        registry = TranslatorRegistry
        self.request_translator = registry.get_request_translator(
            "reranker"
        )
        self.response_translator = registry.get_response_translator(
            "reranker"
        )

    def to_request(self, body):
        """Decode an incoming body with the request translator."""
        translator = self.request_translator
        return translator.decode(body)

    def from_response(self, message):
        """Encode a service response via ``encode_with_completion``."""
        translator = self.response_translator
        return translator.encode_with_completion(message)

View file

@ -518,6 +518,7 @@ _FLOW_SERVICES = {
"structured-diag": "structured-query:read",
"row-embeddings": "row-embeddings:read",
"sparql": "sparql:read",
"reranker": "reranker",
}
for _kind, _cap in _FLOW_SERVICES.items():
_register_flow_kind("flow-service", _kind, _cap)

View file

@ -72,6 +72,7 @@ _READER_CAPS = {
"row-embeddings:read",
"llm",
"embeddings",
"reranker",
"mcp",
"config:read",
"flows:read",

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,2 @@
from . processor import *

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from . processor import run
# Start the processor only when this module is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    run()

View file

@ -0,0 +1,109 @@
"""
Reranker service using flashrank.
Scores query-document pairs and returns the top results ranked by
relevance.
"""
import asyncio
import logging
from ... base import RerankerService
from ... schema import RerankerResult
from flashrank import Ranker, RerankRequest
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Identifier handed to Processor.launch() by run().
default_ident = "reranker"
# Model name used when no "model" parameter / --model argument is given.
default_model = "ms-marco-MiniLM-L-12-v2"
class Processor(RerankerService):
    """Reranker service that scores documents with a flashrank model.

    One flashrank ``Ranker`` is held at a time in ``self.ranker``; a
    request that names a different model causes that model to be loaded
    in its place.
    """

    def __init__(self, **params):
        """Initialise the service and load the configured model.

        Args:
            **params: Service parameters, forwarded to
                ``RerankerService``. ``model`` selects the flashrank
                model and falls back to ``default_model`` when absent.
        """
        model = params.get("model", default_model)

        super(Processor, self).__init__(
            **params | {"model": model}
        )

        self.default_model = model
        self.cached_model_name = None
        self.ranker = None

        self._load_model(model)

    def _load_model(self, model_name):
        """Load ``model_name`` into ``self.ranker`` unless already cached.

        Blocking; ``on_rerank`` runs it in a worker thread.
        """
        if self.cached_model_name != model_name:
            logger.info(f"Loading flashrank model: {model_name}")
            self.ranker = Ranker(model_name=model_name)
            self.cached_model_name = model_name
            logger.info(f"flashrank model {model_name} loaded successfully")
        else:
            logger.debug(f"Using cached model: {model_name}")

    def _run_rerank(self, query, passages, ranker=None):
        """Run one blocking rerank call for a single query.

        Args:
            query: Query text.
            passages: List of ``{"id": ..., "text": ...}`` dicts.
            ranker: Ranker instance to use; defaults to ``self.ranker``
                so existing two-argument callers keep working.

        Returns:
            Whatever ``Ranker.rerank`` returns for the request.
        """
        if ranker is None:
            ranker = self.ranker
        request = RerankRequest(query=query, passages=passages)
        return ranker.rerank(request)

    async def on_rerank(self, queries, documents, limit, model=None):
        """Score documents against every query and return the best ones.

        Each document keeps its highest score across all queries, along
        with the id of the query that produced that score. Results are
        sorted by score, highest first, and sliced with ``[:limit]``.

        Args:
            queries: Items exposing ``query_id`` and ``query_text``.
            documents: Items exposing ``document_id`` and
                ``document_text``.
            limit: Upper slice bound applied to the sorted results.
            model: Optional model name; ``self.default_model`` is used
                when falsy.

        Returns:
            list[RerankerResult]: empty when there are no queries or no
            documents.
        """
        if not queries or not documents:
            return []

        use_model = model or self.default_model

        if self.cached_model_name != use_model:
            await asyncio.to_thread(self._load_model, use_model)

        # Pin the ranker for the whole request. This coroutine yields at
        # every await below, and another request may replace self.ranker
        # with a different model in the meantime; without the snapshot,
        # scores from two models could be merged into one ranking.
        ranker = self.ranker

        passages = [
            {"id": d.document_id, "text": d.document_text}
            for d in documents
        ]

        # document_id -> (query_id, score) for the best-scoring query.
        best_scores = {}

        for q in queries:
            ranked = await asyncio.to_thread(
                self._run_rerank, q.query_text, passages, ranker,
            )
            for r in ranked:
                doc_id = r["id"]
                score = float(r["score"])
                best = best_scores.get(doc_id)
                if best is None or score > best[1]:
                    best_scores[doc_id] = (q.query_id, score)

        results = sorted(
            best_scores.items(),
            key=lambda x: x[1][1],
            reverse=True,
        )[:limit]

        return [
            RerankerResult(
                document_id=doc_id,
                query_id=query_id,
                score=score,
            )
            for doc_id, (query_id, score) in results
        ]

    @staticmethod
    def add_args(parser):
        """Register base service arguments plus ``-m/--model``."""
        RerankerService.add_args(parser)

        parser.add_argument(
            '-m', '--model',
            default=default_model,
            help=f'Reranker model (default: {default_model})'
        )
def run():
    """Launch the processor with the default ident and module docstring."""
    Processor.launch(default_ident, __doc__)

View file

@ -120,7 +120,7 @@ class Query:
def __init__(
self, rag, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2, track_usage=None,
max_path_length=2, edge_limit=25, track_usage=None,
):
self.rag = rag
self.collection = collection
@ -129,6 +129,7 @@ class Query:
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
self.edge_limit = edge_limit
self.track_usage = track_usage
async def extract_concepts(self, query):
@ -217,12 +218,9 @@ class Query:
logger.debug(f" {ent}")
return entities, concepts
async def maybe_label(self, e):
# The label cache lives on a per-request GraphRag instance — no
# cross-request isolation concern. The collection prefix keeps
# entries from different collections distinct within one request.
cache_key = f"{self.collection}:{e}"
cached_label = self.rag.label_cache.get(cache_key)
@ -244,11 +242,10 @@ class Query:
return label
async def execute_batch_triple_queries(self, entities, limit_per_entity):
"""Execute triple queries for multiple entities concurrently using streaming"""
"""Execute triple queries for multiple entities concurrently."""
tasks = []
for entity in entities:
# Create concurrent streaming tasks for all 3 query types per entity
tasks.extend([
self.rag.triples_client.query_stream(
s=entity, p=None, o=None,
@ -270,10 +267,8 @@ class Query:
)
])
# Execute all queries concurrently
results = await asyncio.gather(*tasks, return_exceptions=True)
# Combine all results
all_triples = []
for result in results:
if not isinstance(result, Exception) and result is not None:
@ -281,168 +276,151 @@ class Query:
return all_triples
async def follow_edges_batch(self, entities, max_depth):
"""Optimized iterative graph traversal with batching.
Returns:
tuple: (subgraph, term_map) where subgraph is a set of
(str, str, str) tuples and term_map maps each string tuple
to its original (Term, Term, Term) for type-preserving
provenance.
"""
visited = set()
current_level = set(entities)
subgraph = set()
term_map = {} # (str, str, str) -> (Term, Term, Term)
for depth in range(max_depth):
if not current_level or len(subgraph) >= self.max_subgraph_size:
break
# Filter out already visited entities
unvisited_entities = [e for e in current_level if e not in visited]
if not unvisited_entities:
break
# Batch query all unvisited entities at current level
triples = await self.execute_batch_triple_queries(
unvisited_entities, self.triple_limit
)
# Process results and collect next level entities
next_level = set()
for triple in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
subgraph.add(triple_tuple)
term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o))
# Collect entities for next level (only from s and o positions)
if depth < max_depth - 1: # Don't collect for final depth
s, p, o = triple_tuple
if s not in visited:
next_level.add(s)
if o not in visited:
next_level.add(o)
# Stop if subgraph size limit reached
if len(subgraph) >= self.max_subgraph_size:
return subgraph, term_map
# Update for next iteration
visited.update(current_level)
current_level = next_level
return subgraph, term_map
async def follow_edges(self, ent, subgraph, path_length):
"""Legacy method - replaced by follow_edges_batch"""
# Maintain backward compatibility with early termination checks
if path_length <= 0:
return
if len(subgraph) >= self.max_subgraph_size:
return
# For backward compatibility, convert to new approach
batch_result, _ = await self.follow_edges_batch([ent], path_length)
subgraph.update(batch_result)
async def get_subgraph(self, query):
"""
Get subgraph by extracting concepts, finding entities, and traversing.
Returns:
tuple: (subgraph, term_map, entities, concepts) where subgraph is
a list of (s, p, o) string tuples, term_map maps each string
tuple to its original (Term, Term, Term), entities is the seed
entity list, and concepts is the extracted concept list.
"""
entities, concepts = await self.get_entities(query)
if self.verbose:
logger.debug("Getting subgraph...")
# Use optimized batch traversal instead of sequential processing
subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length)
return list(subgraph), term_map, entities, concepts
async def resolve_labels_batch(self, entities):
"""Resolve labels for multiple entities in parallel"""
tasks = []
for entity in entities:
tasks.append(self.maybe_label(entity))
"""Resolve labels for multiple entities in parallel."""
tasks = [self.maybe_label(entity) for entity in entities]
return await asyncio.gather(*tasks, return_exceptions=True)
async def get_labelgraph(self, query):
"""
Get subgraph with labels resolved for display.
async def hop_and_filter(self, seed_entities, concepts):
"""Iterative hop-and-filter graph traversal with cross-encoder.
At each hop:
1. Retrieve all edges one hop from the frontier.
2. Resolve labels and represent each edge as "{p} {o}".
3. Score edges against concepts using the cross-encoder.
4. Select the top edges; their target nodes become the next
frontier.
Returns:
tuple: (labeled_edges, uri_map, entities, concepts) where:
- labeled_edges: list of (label_s, label_p, label_o) tuples
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
- entities: list of seed entity URI strings
- concepts: list of concept strings extracted from query
tuple: (selected_edges, uri_map, edge_metadata) where:
- selected_edges: list of (label_s, label_p, label_o)
- uri_map: dict mapping edge_id -> (Term, Term, Term)
- edge_metadata: dict mapping edge_id -> {concept, score}
"""
subgraph, term_map, entities, concepts = await self.get_subgraph(query)
all_selected_edges = []
uri_map = {}
edge_metadata = {}
frontier = set(seed_entities)
visited_entities = set()
seen_edges = set()
# Filter out label triples
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
for hop in range(self.max_path_length):
if not frontier:
break
# Collect all unique entities that need label resolution
entities_to_resolve = set()
for s, p, o in filtered_subgraph:
entities_to_resolve.update([s, p, o])
unvisited = [e for e in frontier if e not in visited_entities]
if not unvisited:
break
# Batch resolve labels for all entities in parallel
entity_list = list(entities_to_resolve)
resolved_labels = await self.resolve_labels_batch(entity_list)
if self.verbose:
logger.debug(
f"Hop {hop + 1}: {len(unvisited)} frontier entities"
)
# Create entity-to-label mapping
label_map = {}
for entity, label in zip(entity_list, resolved_labels):
if not isinstance(label, Exception):
label_map[entity] = label
else:
label_map[entity] = entity # Fallback to entity itself
# Apply labels to subgraph and build URI mapping
labeled_edges = []
uri_map = {} # Maps edge_id of labeled edge -> original Term triple
for s, p, o in filtered_subgraph:
labeled_triple = (
label_map.get(s, s),
label_map.get(p, p),
label_map.get(o, o)
# Retrieve edges one hop from frontier
triples = await self.execute_batch_triple_queries(
unvisited, self.triple_limit,
)
labeled_edges.append(labeled_triple)
# Map from labeled edge ID to original Terms (preserving types)
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o))
# Deduplicate and filter already-seen edges
hop_triples = []
hop_term_map = {}
for triple in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
if triple_tuple[1] == LABEL:
continue
if triple_tuple in seen_edges:
continue
seen_edges.add(triple_tuple)
hop_triples.append(triple_tuple)
hop_term_map[triple_tuple] = (
to_term(triple.s), to_term(triple.p), to_term(triple.o),
)
labeled_edges = labeled_edges[0:self.max_subgraph_size]
if not hop_triples:
visited_entities.update(frontier)
break
if self.verbose:
logger.debug("Subgraph:")
for edge in labeled_edges:
logger.debug(f" {str(edge)}")
if self.verbose:
logger.debug(
f"Hop {hop + 1}: {len(hop_triples)} candidate edges"
)
if self.verbose:
logger.debug("Done.")
# Resolve labels for all entities in hop edges
entities_to_resolve = set()
for s, p, o in hop_triples:
entities_to_resolve.update([s, p, o])
return labeled_edges, uri_map, entities, concepts
entity_list = list(entities_to_resolve)
resolved = await self.resolve_labels_batch(entity_list)
label_map = {}
for entity, label in zip(entity_list, resolved):
if not isinstance(label, Exception):
label_map[entity] = label
else:
label_map[entity] = entity
# Build labeled edges and documents for cross-encoder
labeled_hop = []
for s, p, o in hop_triples:
ls = label_map.get(s, s)
lp = label_map.get(p, p)
lo = label_map.get(o, o)
labeled_hop.append((ls, lp, lo))
documents = [
{"id": str(i), "text": f"{lp} {lo}"}
for i, (ls, lp, lo) in enumerate(labeled_hop)
]
queries = [
{"id": str(i), "text": c}
for i, c in enumerate(concepts)
]
# Score with cross-encoder
results = await self.rag.reranker_client.rerank(
queries=queries,
documents=documents,
limit=self.edge_limit,
)
# Collect selected edges and metadata
next_frontier = set()
for r in results:
idx = int(r.document_id)
ls, lp, lo = labeled_hop[idx]
s, p, o = hop_triples[idx]
eid = edge_id(ls, lp, lo)
all_selected_edges.append((ls, lp, lo))
uri_map[eid] = hop_term_map[(s, p, o)]
edge_metadata[eid] = {
"concept": concepts[int(r.query_id)],
"score": r.score,
}
# Target nodes become next frontier
next_frontier.add(s)
next_frontier.add(o)
if self.verbose:
logger.debug(
f"Hop {hop + 1}: selected {len(results)} edges"
)
visited_entities.update(frontier)
frontier = next_frontier - visited_entities
return all_selected_edges, uri_map, edge_metadata
async def trace_source_documents(self, edge_uris):
"""
Trace selected edges back to their source documents via provenance.
Follows the chain: edge subgraph (via tg:contains) chunk
page document (via prov:wasDerivedFrom), all in urn:graph:source.
Follows the chain: edge -> subgraph (via tg:contains) -> chunk ->
page -> document (via prov:wasDerivedFrom), all in urn:graph:source.
Args:
edge_uris: List of (s, p, o) URI string tuples
@ -453,7 +431,6 @@ class Query:
# Step 1: Find subgraphs containing these edges via tg:contains
subgraph_tasks = []
for s, p, o in edge_uris:
# s, p, o may be Term objects (preserving types) or strings
s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s)
p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p)
o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o)
@ -487,12 +464,10 @@ class Query:
return []
# Step 2: Walk prov:wasDerivedFrom chain to find documents
# Each level: query ?entity prov:wasDerivedFrom ?parent
# Stop when we find entities typed tg:Document
current_uris = subgraph_uris
doc_uris = set()
for depth in range(4): # Max depth: subgraph → chunk → page → doc
for depth in range(4):
if not current_uris:
break
@ -509,7 +484,6 @@ class Query:
*derivation_tasks, return_exceptions=True
)
# URIs with no parent are root documents
next_uris = set()
for uri, result in zip(current_uris, derivation_results):
if isinstance(result, Exception) or not result:
@ -524,7 +498,6 @@ class Query:
return []
# Step 3: Get all document metadata properties
# Skip structural predicates that aren't useful context
SKIP_PREDICATES = {
PROV_WAS_DERIVED_FROM,
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
@ -565,7 +538,7 @@ class GraphRag:
def __init__(
self, prompt_client, embeddings_client, graph_embeddings_client,
triples_client, verbose=False,
triples_client, reranker_client, verbose=False,
):
self.verbose = verbose
@ -574,9 +547,8 @@ class GraphRag:
self.embeddings_client = embeddings_client
self.graph_embeddings_client = graph_embeddings_client
self.triples_client = triples_client
self.reranker_client = reranker_client
# Replace simple dict with LRU cache with TTL
# CRITICAL: This cache only lives for one request due to per-request instantiation
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
if self.verbose:
@ -585,33 +557,12 @@ class GraphRag:
async def query(
self, query, collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
max_path_length = 2, edge_limit = 25,
streaming = False,
chunk_callback = None,
explain_callback = None, save_answer_callback = None,
parent_uri = "",
):
"""
Execute a GraphRAG query with real-time explainability tracking.
Args:
query: The query string
collection: Collection identifier
entity_limit: Max entities to retrieve
triple_limit: Max triples per entity
max_subgraph_size: Max edges in subgraph
max_path_length: Max hops from seed entities
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
edge_limit: Max edges after LLM scoring
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for real-time explainability
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
Returns:
tuple: (answer_text, usage) where usage is a dict with
in_token, out_token, model
"""
# Accumulate token usage across all prompt calls
total_in = 0
total_out = 0
@ -638,7 +589,9 @@ class GraphRag:
foc_uri = make_focus_uri(session_id)
syn_uri = make_synthesis_uri(session_id)
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
timestamp = datetime.now(timezone.utc).isoformat().replace(
"+00:00", "Z",
)
# Emit question explainability immediately
if explain_callback:
@ -657,10 +610,12 @@ class GraphRag:
triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
track_usage = track_usage,
)
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
# Step 1: Extract concepts and find seed entities
seed_entities, concepts = await q.get_entities(query)
# Emit grounding explain after concept extraction
if explain_callback:
@ -676,11 +631,16 @@ class GraphRag:
)
await explain_callback(gnd_triples, gnd_uri)
# Emit exploration explain after graph retrieval completes
# Step 2: Iterative hop-and-filter with cross-encoder
selected_edges, uri_map, edge_metadata = await q.hop_and_filter(
seed_entities, concepts,
)
# Emit exploration explain
if explain_callback:
exp_triples = set_graph(
exploration_triples(
exp_uri, gnd_uri, len(kg),
exp_uri, gnd_uri, len(selected_edges),
entities=seed_entities,
),
GRAPH_RETRIEVAL
@ -688,235 +648,63 @@ class GraphRag:
await explain_callback(exp_triples, exp_uri)
if self.verbose:
logger.debug("Invoking LLM...")
logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}")
# Semantic pre-filter: reduce edges before expensive LLM scoring
if edge_score_limit > 0 and len(kg) > edge_score_limit:
if self.verbose:
logger.debug(f"Selected {len(selected_edges)} edges")
for s, p, o in selected_edges:
eid = edge_id(s, p, o)
meta = edge_metadata.get(eid, {})
logger.debug(
f"Semantic pre-filter: {len(kg)} edges > "
f"limit {edge_score_limit}, filtering..."
f" {meta.get('score', 0):.4f} "
f"[{meta.get('concept', '')}] "
f"{s} | {p} | {o}"
)
# Embed edge descriptions: "subject, predicate, object"
edge_descriptions = [
f"{s}, {p}, {o}" for s, p, o in kg
]
# Embed concepts and edge descriptions concurrently
concept_embed_task = self.embeddings_client.embed(concepts)
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
concept_vectors, edge_vectors = await asyncio.gather(
concept_embed_task, edge_embed_task
)
# Score each edge by max cosine similarity to any concept
def cosine_similarity(a, b):
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
edge_scores = []
for i, edge_vec in enumerate(edge_vectors):
max_sim = max(
cosine_similarity(edge_vec, cv)
for cv in concept_vectors
)
edge_scores.append((max_sim, i))
# Sort by similarity descending and keep top edge_score_limit
edge_scores.sort(reverse=True)
keep_indices = set(
idx for _, idx in edge_scores[:edge_score_limit]
)
# Filter kg and rebuild uri_map
filtered_kg = []
filtered_uri_map = {}
for i, (s, p, o) in enumerate(kg):
if i in keep_indices:
filtered_kg.append((s, p, o))
eid = edge_id(s, p, o)
if eid in uri_map:
filtered_uri_map[eid] = uri_map[eid]
if self.verbose:
logger.debug(
f"Semantic pre-filter kept {len(filtered_kg)} "
f"of {len(kg)} edges"
)
kg = filtered_kg
uri_map = filtered_uri_map
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
edge_map = {}
edges_with_ids = []
for s, p, o in kg:
eid = edge_id(s, p, o)
edge_map[eid] = (s, p, o)
edges_with_ids.append({
"id": eid,
"s": s,
"p": p,
"o": o
})
if self.verbose:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1a: Edge Scoring - LLM scores edges for relevance
scoring_result = await self.prompt_client.prompt(
"kg-edge-scoring",
variables={
"query": query,
"knowledge": edges_with_ids
}
)
track_usage(scoring_result)
if self.verbose:
logger.debug(f"Edge scoring result: {scoring_result}")
# Parse scoring response (jsonl) to get edge IDs with scores
scored_edges = []
for obj in scoring_result.objects or []:
if isinstance(obj, dict) and "id" in obj and "score" in obj:
try:
score = int(obj["score"])
except (ValueError, TypeError):
score = 0
scored_edges.append({"id": obj["id"], "score": score})
# Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit]
selected_ids = {e["id"] for e in top_edges}
if self.verbose:
logger.debug(
f"Scored {len(scored_edges)} edges, "
f"selected top {len(selected_ids)}"
)
# Filter to selected edges
selected_edges = []
for eid in selected_ids:
if eid in edge_map:
selected_edges.append(edge_map[eid])
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
selected_edges_with_ids = [
{"id": eid, "s": s, "p": p, "o": o}
for eid in selected_ids
if eid in edge_map
for s, p, o in [edge_map[eid]]
]
# Collect selected edge URIs for document tracing
# Step 3: Document tracing
selected_edge_uris = [
uri_map[eid]
for eid in selected_ids
if eid in uri_map
uri_map[edge_id(s, p, o)]
for s, p, o in selected_edges
if edge_id(s, p, o) in uri_map
]
# Run reasoning and document tracing concurrently
async def _get_reasoning():
result = await self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
track_usage(result)
return result
reasoning_task = _get_reasoning()
doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_result, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True
source_documents = await q.trace_source_documents(
selected_edge_uris,
)
# Handle exceptions from gather
if isinstance(reasoning_result, Exception):
logger.warning(
f"Edge reasoning failed: {reasoning_result}"
)
reasoning_result = None
if isinstance(source_documents, Exception):
logger.warning(
f"Document tracing failed: {source_documents}"
)
source_documents = []
if self.verbose:
logger.debug(f"Edge reasoning result: {reasoning_result}")
# Parse reasoning response (jsonl) and build explainability data
reasoning_map = {}
if reasoning_result is not None:
for obj in reasoning_result.objects or []:
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
# Build focus explainability data with cross-encoder metadata
selected_edges_with_reasoning = []
for eid in selected_ids:
for s, p, o in selected_edges:
eid = edge_id(s, p, o)
if eid in uri_map:
uri_s, uri_p, uri_o = uri_map[eid]
meta = edge_metadata.get(eid, {})
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": reasoning_map.get(eid, ""),
"concept": meta.get("concept", ""),
"score": meta.get("score", 0),
})
if self.verbose:
logger.debug(f"Filtered to {len(selected_edges)} edges")
# Emit focus explain after edge selection completes
# Emit focus explain
if explain_callback:
# Sum scoring + reasoning token usage for focus event
focus_in = 0
focus_out = 0
focus_model = None
for r in [scoring_result, reasoning_result]:
if r is not None:
if r.in_token is not None:
focus_in += r.in_token
if r.out_token is not None:
focus_out += r.out_token
if r.model is not None:
focus_model = r.model
foc_triples = set_graph(
focus_triples(
foc_uri, exp_uri, selected_edges_with_reasoning, session_id,
in_token=focus_in or None,
out_token=focus_out or None,
model=focus_model,
foc_uri, exp_uri,
selected_edges_with_reasoning, session_id,
),
GRAPH_RETRIEVAL
)
await explain_callback(foc_triples, foc_uri)
# Step 2: Synthesis - LLM generates answer from selected edges only
# Step 4: Synthesis
selected_edge_dicts = [
{"s": s, "p": p, "o": o}
for s, p, o in selected_edges
]
# Add source document metadata as knowledge edges
for s, p, o in source_documents:
selected_edge_dicts.append({
"s": s, "p": p, "o": o,
@ -928,7 +716,6 @@ class GraphRag:
}
if streaming and chunk_callback:
# Accumulate chunks for answer storage while forwarding to callback
accumulated_chunks = []
async def accumulating_callback(chunk, end_of_stream):
@ -942,7 +729,6 @@ class GraphRag:
chunk_callback=accumulating_callback
)
track_usage(synthesis_result)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
synthesis_result = await self.prompt_client.prompt(
@ -955,29 +741,42 @@ class GraphRag:
if self.verbose:
logger.debug("Query processing complete")
# Emit synthesis explain after synthesis completes
# Emit synthesis explain
if explain_callback:
synthesis_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian
if save_answer_callback and answer_text:
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
try:
await save_answer_callback(synthesis_doc_id, answer_text)
if self.verbose:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
logger.debug(
f"Saved answer to librarian: "
f"{synthesis_doc_id}"
)
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
logger.warning(
f"Failed to save answer to librarian: {e}"
)
synthesis_doc_id = None
syn_triples = set_graph(
synthesis_triples(
syn_uri, foc_uri,
document_id=synthesis_doc_id,
in_token=synthesis_result.in_token if synthesis_result else None,
out_token=synthesis_result.out_token if synthesis_result else None,
model=synthesis_result.model if synthesis_result else None,
in_token=(
synthesis_result.in_token
if synthesis_result else None
),
out_token=(
synthesis_result.out_token
if synthesis_result else None
),
model=(
synthesis_result.model
if synthesis_result else None
),
),
GRAPH_RETRIEVAL
)
@ -993,4 +792,3 @@ class GraphRag:
}
return resp, usage

View file

@ -13,6 +13,7 @@ from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
from ... base import RerankerClientSpec
from ... base import LibrarianSpec
# Module logger
@ -32,7 +33,6 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
edge_score_limit = params.get("edge_score_limit", 30)
edge_limit = params.get("edge_limit", 25)
super(Processor, self).__init__(
@ -43,7 +43,6 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"edge_score_limit": edge_score_limit,
"edge_limit": edge_limit,
}
)
@ -52,7 +51,6 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.default_edge_score_limit = edge_score_limit
self.default_edge_limit = edge_limit
# Workspace isolation is enforced by the flow layer (flow.workspace).
@ -96,6 +94,13 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
RerankerClientSpec(
request_name = "reranker-request",
response_name = "reranker-response",
)
)
self.register_specification(
ProducerSpec(
name = "response",
@ -163,6 +168,7 @@ class Processor(FlowProcessor):
graph_embeddings_client=flow("graph-embeddings-request"),
triples_client=flow("triples-request"),
prompt_client=flow("prompt-request"),
reranker_client=flow("reranker-request"),
verbose=True,
)
@ -186,11 +192,6 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
if v.edge_score_limit:
edge_score_limit = v.edge_score_limit
else:
edge_score_limit = self.default_edge_score_limit
if v.edge_limit:
edge_limit = v.edge_limit
else:
@ -225,7 +226,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit,
streaming = True,
chunk_callback = send_chunk,
@ -241,7 +242,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit,
explain_callback = send_explainability,
save_answer_callback = save_answer,
@ -338,18 +339,11 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)'
)
parser.add_argument(
'--edge-score-limit',
type=int,
default=30,
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
)
parser.add_argument(
'--edge-limit',
type=int,
default=25,
help=f'Max edges after LLM scoring (default: 25)'
help=f'Max edges selected per hop by cross-encoder (default: 25)'
)
# Note: Explainability triples are now stored in the request's collection