mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 09:29:38 +02:00
feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG (#1005)
Replace the three-prompt LLM scoring pipeline (kg-edge-scoring, kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker service backed by FlashRank. The new hop_and_filter() method performs iterative graph traversal with semantic scoring at each hop, replacing the previous follow_edges/get_subgraph approach. - Add reranker service (trustgraph-base client/service, FlashRank processor) - Add gateway dispatch for reranker via API and WebSocket - Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring - Remove kg_prompt() and edge_score_limit from prompt client - Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates - Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata - Add tg-invoke-reranker CLI tool - Add tech spec and UX developer guidance - Update all unit and integration tests
This commit is contained in:
parent
1aa9549912
commit
01cc8dbc64
43 changed files with 1613 additions and 792 deletions
|
|
@ -19,6 +19,7 @@ dependencies = [
|
|||
"faiss-cpu",
|
||||
"falkordb",
|
||||
"fastembed",
|
||||
"flashrank",
|
||||
"ibis",
|
||||
"jsonschema",
|
||||
"langchain",
|
||||
|
|
@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:
|
|||
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
|
||||
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
|
||||
graph-rag = "trustgraph.retrieval.graph_rag:run"
|
||||
reranker-flashrank = "trustgraph.reranker.flashrank:run"
|
||||
kg-extract-agent = "trustgraph.extract.kg.agent:run"
|
||||
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
|
||||
kg-extract-rows = "trustgraph.extract.kg.rows:run"
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
|||
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
|
||||
from . row_embeddings_query import RowEmbeddingsQueryRequestor
|
||||
from . mcp_tool import McpToolRequestor
|
||||
from . reranker import RerankerRequestor
|
||||
from . text_load import TextLoad
|
||||
from . document_load import DocumentLoad
|
||||
|
||||
|
|
@ -74,6 +75,7 @@ request_response_dispatchers = {
|
|||
"structured-diag": StructuredDiagRequestor,
|
||||
"row-embeddings": RowEmbeddingsQueryRequestor,
|
||||
"sparql": SparqlQueryRequestor,
|
||||
"reranker": RerankerRequestor,
|
||||
}
|
||||
|
||||
system_dispatchers = {
|
||||
|
|
|
|||
31
trustgraph-flow/trustgraph/gateway/dispatch/reranker.py
Normal file
31
trustgraph-flow/trustgraph/gateway/dispatch/reranker.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
from ... schema import RerankerRequest, RerankerResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class RerankerRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, backend, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(RerankerRequestor, self).__init__(
|
||||
backend=backend,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=RerankerRequest,
|
||||
response_schema=RerankerResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("reranker")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("reranker")
|
||||
|
||||
def to_request(self, body):
|
||||
return self.request_translator.decode(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return self.response_translator.encode_with_completion(message)
|
||||
|
|
@ -518,6 +518,7 @@ _FLOW_SERVICES = {
|
|||
"structured-diag": "structured-query:read",
|
||||
"row-embeddings": "row-embeddings:read",
|
||||
"sparql": "sparql:read",
|
||||
"reranker": "reranker",
|
||||
}
|
||||
for _kind, _cap in _FLOW_SERVICES.items():
|
||||
_register_flow_kind("flow-service", _kind, _cap)
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ _READER_CAPS = {
|
|||
"row-embeddings:read",
|
||||
"llm",
|
||||
"embeddings",
|
||||
"reranker",
|
||||
"mcp",
|
||||
"config:read",
|
||||
"flows:read",
|
||||
|
|
|
|||
1
trustgraph-flow/trustgraph/reranker/__init__.py
Normal file
1
trustgraph-flow/trustgraph/reranker/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
from . processor import *
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . processor import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
109
trustgraph-flow/trustgraph/reranker/flashrank/processor.py
Normal file
109
trustgraph-flow/trustgraph/reranker/flashrank/processor.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
|
||||
"""
|
||||
Reranker service using flashrank.
|
||||
Scores query-document pairs and returns the top results ranked by
|
||||
relevance.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from ... base import RerankerService
|
||||
from ... schema import RerankerResult
|
||||
|
||||
from flashrank import Ranker, RerankRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "reranker"
|
||||
|
||||
default_model = "ms-marco-MiniLM-L-12-v2"
|
||||
|
||||
class Processor(RerankerService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
model = params.get("model", default_model)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | { "model": model }
|
||||
)
|
||||
|
||||
self.default_model = model
|
||||
|
||||
self.cached_model_name = None
|
||||
self.ranker = None
|
||||
|
||||
self._load_model(model)
|
||||
|
||||
def _load_model(self, model_name):
|
||||
if self.cached_model_name != model_name:
|
||||
logger.info(f"Loading flashrank model: {model_name}")
|
||||
self.ranker = Ranker(model_name=model_name)
|
||||
self.cached_model_name = model_name
|
||||
logger.info(f"flashrank model {model_name} loaded successfully")
|
||||
else:
|
||||
logger.debug(f"Using cached model: {model_name}")
|
||||
|
||||
def _run_rerank(self, query, passages):
|
||||
request = RerankRequest(query=query, passages=passages)
|
||||
return self.ranker.rerank(request)
|
||||
|
||||
async def on_rerank(self, queries, documents, limit, model=None):
|
||||
|
||||
if not queries or not documents:
|
||||
return []
|
||||
|
||||
use_model = model or self.default_model
|
||||
|
||||
if self.cached_model_name != use_model:
|
||||
await asyncio.to_thread(self._load_model, use_model)
|
||||
|
||||
passages = [
|
||||
{"id": d.document_id, "text": d.document_text}
|
||||
for d in documents
|
||||
]
|
||||
|
||||
best_scores = {}
|
||||
|
||||
for q in queries:
|
||||
ranked = await asyncio.to_thread(
|
||||
self._run_rerank, q.query_text, passages,
|
||||
)
|
||||
|
||||
for r in ranked:
|
||||
doc_id = r["id"]
|
||||
score = r["score"]
|
||||
score = float(score)
|
||||
if doc_id not in best_scores or score > best_scores[doc_id][1]:
|
||||
best_scores[doc_id] = (q.query_id, score)
|
||||
|
||||
results = sorted(
|
||||
best_scores.items(),
|
||||
key=lambda x: x[1][1],
|
||||
reverse=True,
|
||||
)[:limit]
|
||||
|
||||
return [
|
||||
RerankerResult(
|
||||
document_id=doc_id,
|
||||
query_id=query_id,
|
||||
score=score,
|
||||
)
|
||||
for doc_id, (query_id, score) in results
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
RerankerService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default=default_model,
|
||||
help=f'Reranker model (default: {default_model})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -120,7 +120,7 @@ class Query:
|
|||
def __init__(
|
||||
self, rag, collection, verbose,
|
||||
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
|
||||
max_path_length=2, track_usage=None,
|
||||
max_path_length=2, edge_limit=25, track_usage=None,
|
||||
):
|
||||
self.rag = rag
|
||||
self.collection = collection
|
||||
|
|
@ -129,6 +129,7 @@ class Query:
|
|||
self.triple_limit = triple_limit
|
||||
self.max_subgraph_size = max_subgraph_size
|
||||
self.max_path_length = max_path_length
|
||||
self.edge_limit = edge_limit
|
||||
self.track_usage = track_usage
|
||||
|
||||
async def extract_concepts(self, query):
|
||||
|
|
@ -217,12 +218,9 @@ class Query:
|
|||
logger.debug(f" {ent}")
|
||||
|
||||
return entities, concepts
|
||||
|
||||
|
||||
async def maybe_label(self, e):
|
||||
|
||||
# The label cache lives on a per-request GraphRag instance — no
|
||||
# cross-request isolation concern. The collection prefix keeps
|
||||
# entries from different collections distinct within one request.
|
||||
cache_key = f"{self.collection}:{e}"
|
||||
|
||||
cached_label = self.rag.label_cache.get(cache_key)
|
||||
|
|
@ -244,11 +242,10 @@ class Query:
|
|||
return label
|
||||
|
||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
||||
"""Execute triple queries for multiple entities concurrently using streaming"""
|
||||
"""Execute triple queries for multiple entities concurrently."""
|
||||
tasks = []
|
||||
|
||||
for entity in entities:
|
||||
# Create concurrent streaming tasks for all 3 query types per entity
|
||||
tasks.extend([
|
||||
self.rag.triples_client.query_stream(
|
||||
s=entity, p=None, o=None,
|
||||
|
|
@ -270,10 +267,8 @@ class Query:
|
|||
)
|
||||
])
|
||||
|
||||
# Execute all queries concurrently
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Combine all results
|
||||
all_triples = []
|
||||
for result in results:
|
||||
if not isinstance(result, Exception) and result is not None:
|
||||
|
|
@ -281,168 +276,151 @@ class Query:
|
|||
|
||||
return all_triples
|
||||
|
||||
async def follow_edges_batch(self, entities, max_depth):
|
||||
"""Optimized iterative graph traversal with batching.
|
||||
|
||||
Returns:
|
||||
tuple: (subgraph, term_map) where subgraph is a set of
|
||||
(str, str, str) tuples and term_map maps each string tuple
|
||||
to its original (Term, Term, Term) for type-preserving
|
||||
provenance.
|
||||
"""
|
||||
visited = set()
|
||||
current_level = set(entities)
|
||||
subgraph = set()
|
||||
term_map = {} # (str, str, str) -> (Term, Term, Term)
|
||||
|
||||
for depth in range(max_depth):
|
||||
if not current_level or len(subgraph) >= self.max_subgraph_size:
|
||||
break
|
||||
|
||||
# Filter out already visited entities
|
||||
unvisited_entities = [e for e in current_level if e not in visited]
|
||||
if not unvisited_entities:
|
||||
break
|
||||
|
||||
# Batch query all unvisited entities at current level
|
||||
triples = await self.execute_batch_triple_queries(
|
||||
unvisited_entities, self.triple_limit
|
||||
)
|
||||
|
||||
# Process results and collect next level entities
|
||||
next_level = set()
|
||||
for triple in triples:
|
||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||
subgraph.add(triple_tuple)
|
||||
term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o))
|
||||
|
||||
# Collect entities for next level (only from s and o positions)
|
||||
if depth < max_depth - 1: # Don't collect for final depth
|
||||
s, p, o = triple_tuple
|
||||
if s not in visited:
|
||||
next_level.add(s)
|
||||
if o not in visited:
|
||||
next_level.add(o)
|
||||
|
||||
# Stop if subgraph size limit reached
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return subgraph, term_map
|
||||
|
||||
# Update for next iteration
|
||||
visited.update(current_level)
|
||||
current_level = next_level
|
||||
|
||||
return subgraph, term_map
|
||||
|
||||
async def follow_edges(self, ent, subgraph, path_length):
|
||||
"""Legacy method - replaced by follow_edges_batch"""
|
||||
# Maintain backward compatibility with early termination checks
|
||||
if path_length <= 0:
|
||||
return
|
||||
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return
|
||||
|
||||
# For backward compatibility, convert to new approach
|
||||
batch_result, _ = await self.follow_edges_batch([ent], path_length)
|
||||
subgraph.update(batch_result)
|
||||
|
||||
async def get_subgraph(self, query):
|
||||
"""
|
||||
Get subgraph by extracting concepts, finding entities, and traversing.
|
||||
|
||||
Returns:
|
||||
tuple: (subgraph, term_map, entities, concepts) where subgraph is
|
||||
a list of (s, p, o) string tuples, term_map maps each string
|
||||
tuple to its original (Term, Term, Term), entities is the seed
|
||||
entity list, and concepts is the extracted concept list.
|
||||
"""
|
||||
|
||||
entities, concepts = await self.get_entities(query)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting subgraph...")
|
||||
|
||||
# Use optimized batch traversal instead of sequential processing
|
||||
subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length)
|
||||
|
||||
return list(subgraph), term_map, entities, concepts
|
||||
|
||||
async def resolve_labels_batch(self, entities):
|
||||
"""Resolve labels for multiple entities in parallel"""
|
||||
tasks = []
|
||||
for entity in entities:
|
||||
tasks.append(self.maybe_label(entity))
|
||||
|
||||
"""Resolve labels for multiple entities in parallel."""
|
||||
tasks = [self.maybe_label(entity) for entity in entities]
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def get_labelgraph(self, query):
|
||||
"""
|
||||
Get subgraph with labels resolved for display.
|
||||
async def hop_and_filter(self, seed_entities, concepts):
|
||||
"""Iterative hop-and-filter graph traversal with cross-encoder.
|
||||
|
||||
At each hop:
|
||||
1. Retrieve all edges one hop from the frontier.
|
||||
2. Resolve labels and represent each edge as "{p} {o}".
|
||||
3. Score edges against concepts using the cross-encoder.
|
||||
4. Select the top edges; their target nodes become the next
|
||||
frontier.
|
||||
|
||||
Returns:
|
||||
tuple: (labeled_edges, uri_map, entities, concepts) where:
|
||||
- labeled_edges: list of (label_s, label_p, label_o) tuples
|
||||
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
|
||||
- entities: list of seed entity URI strings
|
||||
- concepts: list of concept strings extracted from query
|
||||
tuple: (selected_edges, uri_map, edge_metadata) where:
|
||||
- selected_edges: list of (label_s, label_p, label_o)
|
||||
- uri_map: dict mapping edge_id -> (Term, Term, Term)
|
||||
- edge_metadata: dict mapping edge_id -> {concept, score}
|
||||
"""
|
||||
subgraph, term_map, entities, concepts = await self.get_subgraph(query)
|
||||
all_selected_edges = []
|
||||
uri_map = {}
|
||||
edge_metadata = {}
|
||||
frontier = set(seed_entities)
|
||||
visited_entities = set()
|
||||
seen_edges = set()
|
||||
|
||||
# Filter out label triples
|
||||
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
|
||||
for hop in range(self.max_path_length):
|
||||
if not frontier:
|
||||
break
|
||||
|
||||
# Collect all unique entities that need label resolution
|
||||
entities_to_resolve = set()
|
||||
for s, p, o in filtered_subgraph:
|
||||
entities_to_resolve.update([s, p, o])
|
||||
unvisited = [e for e in frontier if e not in visited_entities]
|
||||
if not unvisited:
|
||||
break
|
||||
|
||||
# Batch resolve labels for all entities in parallel
|
||||
entity_list = list(entities_to_resolve)
|
||||
resolved_labels = await self.resolve_labels_batch(entity_list)
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Hop {hop + 1}: {len(unvisited)} frontier entities"
|
||||
)
|
||||
|
||||
# Create entity-to-label mapping
|
||||
label_map = {}
|
||||
for entity, label in zip(entity_list, resolved_labels):
|
||||
if not isinstance(label, Exception):
|
||||
label_map[entity] = label
|
||||
else:
|
||||
label_map[entity] = entity # Fallback to entity itself
|
||||
|
||||
# Apply labels to subgraph and build URI mapping
|
||||
labeled_edges = []
|
||||
uri_map = {} # Maps edge_id of labeled edge -> original Term triple
|
||||
|
||||
for s, p, o in filtered_subgraph:
|
||||
labeled_triple = (
|
||||
label_map.get(s, s),
|
||||
label_map.get(p, p),
|
||||
label_map.get(o, o)
|
||||
# Retrieve edges one hop from frontier
|
||||
triples = await self.execute_batch_triple_queries(
|
||||
unvisited, self.triple_limit,
|
||||
)
|
||||
labeled_edges.append(labeled_triple)
|
||||
|
||||
# Map from labeled edge ID to original Terms (preserving types)
|
||||
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
|
||||
uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o))
|
||||
# Deduplicate and filter already-seen edges
|
||||
hop_triples = []
|
||||
hop_term_map = {}
|
||||
for triple in triples:
|
||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||
if triple_tuple[1] == LABEL:
|
||||
continue
|
||||
if triple_tuple in seen_edges:
|
||||
continue
|
||||
seen_edges.add(triple_tuple)
|
||||
hop_triples.append(triple_tuple)
|
||||
hop_term_map[triple_tuple] = (
|
||||
to_term(triple.s), to_term(triple.p), to_term(triple.o),
|
||||
)
|
||||
|
||||
labeled_edges = labeled_edges[0:self.max_subgraph_size]
|
||||
if not hop_triples:
|
||||
visited_entities.update(frontier)
|
||||
break
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Subgraph:")
|
||||
for edge in labeled_edges:
|
||||
logger.debug(f" {str(edge)}")
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Hop {hop + 1}: {len(hop_triples)} candidate edges"
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
# Resolve labels for all entities in hop edges
|
||||
entities_to_resolve = set()
|
||||
for s, p, o in hop_triples:
|
||||
entities_to_resolve.update([s, p, o])
|
||||
|
||||
return labeled_edges, uri_map, entities, concepts
|
||||
entity_list = list(entities_to_resolve)
|
||||
resolved = await self.resolve_labels_batch(entity_list)
|
||||
|
||||
label_map = {}
|
||||
for entity, label in zip(entity_list, resolved):
|
||||
if not isinstance(label, Exception):
|
||||
label_map[entity] = label
|
||||
else:
|
||||
label_map[entity] = entity
|
||||
|
||||
# Build labeled edges and documents for cross-encoder
|
||||
labeled_hop = []
|
||||
for s, p, o in hop_triples:
|
||||
ls = label_map.get(s, s)
|
||||
lp = label_map.get(p, p)
|
||||
lo = label_map.get(o, o)
|
||||
labeled_hop.append((ls, lp, lo))
|
||||
|
||||
documents = [
|
||||
{"id": str(i), "text": f"{lp} {lo}"}
|
||||
for i, (ls, lp, lo) in enumerate(labeled_hop)
|
||||
]
|
||||
|
||||
queries = [
|
||||
{"id": str(i), "text": c}
|
||||
for i, c in enumerate(concepts)
|
||||
]
|
||||
|
||||
# Score with cross-encoder
|
||||
results = await self.rag.reranker_client.rerank(
|
||||
queries=queries,
|
||||
documents=documents,
|
||||
limit=self.edge_limit,
|
||||
)
|
||||
|
||||
# Collect selected edges and metadata
|
||||
next_frontier = set()
|
||||
for r in results:
|
||||
idx = int(r.document_id)
|
||||
ls, lp, lo = labeled_hop[idx]
|
||||
s, p, o = hop_triples[idx]
|
||||
eid = edge_id(ls, lp, lo)
|
||||
|
||||
all_selected_edges.append((ls, lp, lo))
|
||||
uri_map[eid] = hop_term_map[(s, p, o)]
|
||||
edge_metadata[eid] = {
|
||||
"concept": concepts[int(r.query_id)],
|
||||
"score": r.score,
|
||||
}
|
||||
|
||||
# Target nodes become next frontier
|
||||
next_frontier.add(s)
|
||||
next_frontier.add(o)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Hop {hop + 1}: selected {len(results)} edges"
|
||||
)
|
||||
|
||||
visited_entities.update(frontier)
|
||||
frontier = next_frontier - visited_entities
|
||||
|
||||
return all_selected_edges, uri_map, edge_metadata
|
||||
|
||||
async def trace_source_documents(self, edge_uris):
|
||||
"""
|
||||
Trace selected edges back to their source documents via provenance.
|
||||
|
||||
Follows the chain: edge → subgraph (via tg:contains) → chunk →
|
||||
page → document (via prov:wasDerivedFrom), all in urn:graph:source.
|
||||
Follows the chain: edge -> subgraph (via tg:contains) -> chunk ->
|
||||
page -> document (via prov:wasDerivedFrom), all in urn:graph:source.
|
||||
|
||||
Args:
|
||||
edge_uris: List of (s, p, o) URI string tuples
|
||||
|
|
@ -453,7 +431,6 @@ class Query:
|
|||
# Step 1: Find subgraphs containing these edges via tg:contains
|
||||
subgraph_tasks = []
|
||||
for s, p, o in edge_uris:
|
||||
# s, p, o may be Term objects (preserving types) or strings
|
||||
s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s)
|
||||
p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p)
|
||||
o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o)
|
||||
|
|
@ -487,12 +464,10 @@ class Query:
|
|||
return []
|
||||
|
||||
# Step 2: Walk prov:wasDerivedFrom chain to find documents
|
||||
# Each level: query ?entity prov:wasDerivedFrom ?parent
|
||||
# Stop when we find entities typed tg:Document
|
||||
current_uris = subgraph_uris
|
||||
doc_uris = set()
|
||||
|
||||
for depth in range(4): # Max depth: subgraph → chunk → page → doc
|
||||
for depth in range(4):
|
||||
if not current_uris:
|
||||
break
|
||||
|
||||
|
|
@ -509,7 +484,6 @@ class Query:
|
|||
*derivation_tasks, return_exceptions=True
|
||||
)
|
||||
|
||||
# URIs with no parent are root documents
|
||||
next_uris = set()
|
||||
for uri, result in zip(current_uris, derivation_results):
|
||||
if isinstance(result, Exception) or not result:
|
||||
|
|
@ -524,7 +498,6 @@ class Query:
|
|||
return []
|
||||
|
||||
# Step 3: Get all document metadata properties
|
||||
# Skip structural predicates that aren't useful context
|
||||
SKIP_PREDICATES = {
|
||||
PROV_WAS_DERIVED_FROM,
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
|
|
@ -565,7 +538,7 @@ class GraphRag:
|
|||
|
||||
def __init__(
|
||||
self, prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, verbose=False,
|
||||
triples_client, reranker_client, verbose=False,
|
||||
):
|
||||
|
||||
self.verbose = verbose
|
||||
|
|
@ -574,9 +547,8 @@ class GraphRag:
|
|||
self.embeddings_client = embeddings_client
|
||||
self.graph_embeddings_client = graph_embeddings_client
|
||||
self.triples_client = triples_client
|
||||
self.reranker_client = reranker_client
|
||||
|
||||
# Replace simple dict with LRU cache with TTL
|
||||
# CRITICAL: This cache only lives for one request due to per-request instantiation
|
||||
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
|
||||
|
||||
if self.verbose:
|
||||
|
|
@ -585,33 +557,12 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
|
||||
max_path_length = 2, edge_limit = 25,
|
||||
streaming = False,
|
||||
chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
parent_uri = "",
|
||||
):
|
||||
"""
|
||||
Execute a GraphRAG query with real-time explainability tracking.
|
||||
|
||||
Args:
|
||||
query: The query string
|
||||
collection: Collection identifier
|
||||
entity_limit: Max entities to retrieve
|
||||
triple_limit: Max triples per entity
|
||||
max_subgraph_size: Max edges in subgraph
|
||||
max_path_length: Max hops from seed entities
|
||||
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
|
||||
edge_limit: Max edges after LLM scoring
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for real-time explainability
|
||||
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
|
||||
|
||||
Returns:
|
||||
tuple: (answer_text, usage) where usage is a dict with
|
||||
in_token, out_token, model
|
||||
"""
|
||||
# Accumulate token usage across all prompt calls
|
||||
total_in = 0
|
||||
total_out = 0
|
||||
|
|
@ -638,7 +589,9 @@ class GraphRag:
|
|||
foc_uri = make_focus_uri(session_id)
|
||||
syn_uri = make_synthesis_uri(session_id)
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace(
|
||||
"+00:00", "Z",
|
||||
)
|
||||
|
||||
# Emit question explainability immediately
|
||||
if explain_callback:
|
||||
|
|
@ -657,10 +610,12 @@ class GraphRag:
|
|||
triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_limit = edge_limit,
|
||||
track_usage = track_usage,
|
||||
)
|
||||
|
||||
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
|
||||
# Step 1: Extract concepts and find seed entities
|
||||
seed_entities, concepts = await q.get_entities(query)
|
||||
|
||||
# Emit grounding explain after concept extraction
|
||||
if explain_callback:
|
||||
|
|
@ -676,11 +631,16 @@ class GraphRag:
|
|||
)
|
||||
await explain_callback(gnd_triples, gnd_uri)
|
||||
|
||||
# Emit exploration explain after graph retrieval completes
|
||||
# Step 2: Iterative hop-and-filter with cross-encoder
|
||||
selected_edges, uri_map, edge_metadata = await q.hop_and_filter(
|
||||
seed_entities, concepts,
|
||||
)
|
||||
|
||||
# Emit exploration explain
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
exploration_triples(
|
||||
exp_uri, gnd_uri, len(kg),
|
||||
exp_uri, gnd_uri, len(selected_edges),
|
||||
entities=seed_entities,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
|
|
@ -688,235 +648,63 @@ class GraphRag:
|
|||
await explain_callback(exp_triples, exp_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
logger.debug(f"Knowledge graph: {kg}")
|
||||
logger.debug(f"Query: {query}")
|
||||
|
||||
# Semantic pre-filter: reduce edges before expensive LLM scoring
|
||||
if edge_score_limit > 0 and len(kg) > edge_score_limit:
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Selected {len(selected_edges)} edges")
|
||||
for s, p, o in selected_edges:
|
||||
eid = edge_id(s, p, o)
|
||||
meta = edge_metadata.get(eid, {})
|
||||
logger.debug(
|
||||
f"Semantic pre-filter: {len(kg)} edges > "
|
||||
f"limit {edge_score_limit}, filtering..."
|
||||
f" {meta.get('score', 0):.4f} "
|
||||
f"[{meta.get('concept', '')}] "
|
||||
f"{s} | {p} | {o}"
|
||||
)
|
||||
|
||||
# Embed edge descriptions: "subject, predicate, object"
|
||||
edge_descriptions = [
|
||||
f"{s}, {p}, {o}" for s, p, o in kg
|
||||
]
|
||||
|
||||
# Embed concepts and edge descriptions concurrently
|
||||
concept_embed_task = self.embeddings_client.embed(concepts)
|
||||
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
|
||||
|
||||
concept_vectors, edge_vectors = await asyncio.gather(
|
||||
concept_embed_task, edge_embed_task
|
||||
)
|
||||
|
||||
# Score each edge by max cosine similarity to any concept
|
||||
def cosine_similarity(a, b):
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = math.sqrt(sum(x * x for x in a))
|
||||
norm_b = math.sqrt(sum(x * x for x in b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
edge_scores = []
|
||||
for i, edge_vec in enumerate(edge_vectors):
|
||||
max_sim = max(
|
||||
cosine_similarity(edge_vec, cv)
|
||||
for cv in concept_vectors
|
||||
)
|
||||
edge_scores.append((max_sim, i))
|
||||
|
||||
# Sort by similarity descending and keep top edge_score_limit
|
||||
edge_scores.sort(reverse=True)
|
||||
keep_indices = set(
|
||||
idx for _, idx in edge_scores[:edge_score_limit]
|
||||
)
|
||||
|
||||
# Filter kg and rebuild uri_map
|
||||
filtered_kg = []
|
||||
filtered_uri_map = {}
|
||||
for i, (s, p, o) in enumerate(kg):
|
||||
if i in keep_indices:
|
||||
filtered_kg.append((s, p, o))
|
||||
eid = edge_id(s, p, o)
|
||||
if eid in uri_map:
|
||||
filtered_uri_map[eid] = uri_map[eid]
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Semantic pre-filter kept {len(filtered_kg)} "
|
||||
f"of {len(kg)} edges"
|
||||
)
|
||||
|
||||
kg = filtered_kg
|
||||
uri_map = filtered_uri_map
|
||||
|
||||
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
|
||||
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
|
||||
edge_map = {}
|
||||
edges_with_ids = []
|
||||
for s, p, o in kg:
|
||||
eid = edge_id(s, p, o)
|
||||
edge_map[eid] = (s, p, o)
|
||||
edges_with_ids.append({
|
||||
"id": eid,
|
||||
"s": s,
|
||||
"p": p,
|
||||
"o": o
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Built edge map with {len(edge_map)} edges")
|
||||
|
||||
# Step 1a: Edge Scoring - LLM scores edges for relevance
|
||||
scoring_result = await self.prompt_client.prompt(
|
||||
"kg-edge-scoring",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": edges_with_ids
|
||||
}
|
||||
)
|
||||
track_usage(scoring_result)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge scoring result: {scoring_result}")
|
||||
|
||||
# Parse scoring response (jsonl) to get edge IDs with scores
|
||||
scored_edges = []
|
||||
|
||||
for obj in scoring_result.objects or []:
|
||||
if isinstance(obj, dict) and "id" in obj and "score" in obj:
|
||||
try:
|
||||
score = int(obj["score"])
|
||||
except (ValueError, TypeError):
|
||||
score = 0
|
||||
scored_edges.append({"id": obj["id"], "score": score})
|
||||
|
||||
# Select top N edges by score
|
||||
scored_edges.sort(key=lambda x: x["score"], reverse=True)
|
||||
top_edges = scored_edges[:edge_limit]
|
||||
selected_ids = {e["id"] for e in top_edges}
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Scored {len(scored_edges)} edges, "
|
||||
f"selected top {len(selected_ids)}"
|
||||
)
|
||||
|
||||
# Filter to selected edges
|
||||
selected_edges = []
|
||||
for eid in selected_ids:
|
||||
if eid in edge_map:
|
||||
selected_edges.append(edge_map[eid])
|
||||
|
||||
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
|
||||
selected_edges_with_ids = [
|
||||
{"id": eid, "s": s, "p": p, "o": o}
|
||||
for eid in selected_ids
|
||||
if eid in edge_map
|
||||
for s, p, o in [edge_map[eid]]
|
||||
]
|
||||
|
||||
# Collect selected edge URIs for document tracing
|
||||
# Step 3: Document tracing
|
||||
selected_edge_uris = [
|
||||
uri_map[eid]
|
||||
for eid in selected_ids
|
||||
if eid in uri_map
|
||||
uri_map[edge_id(s, p, o)]
|
||||
for s, p, o in selected_edges
|
||||
if edge_id(s, p, o) in uri_map
|
||||
]
|
||||
|
||||
# Run reasoning and document tracing concurrently
|
||||
async def _get_reasoning():
|
||||
result = await self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
track_usage(result)
|
||||
return result
|
||||
|
||||
reasoning_task = _get_reasoning()
|
||||
doc_trace_task = q.trace_source_documents(selected_edge_uris)
|
||||
|
||||
reasoning_result, source_documents = await asyncio.gather(
|
||||
reasoning_task, doc_trace_task, return_exceptions=True
|
||||
source_documents = await q.trace_source_documents(
|
||||
selected_edge_uris,
|
||||
)
|
||||
|
||||
# Handle exceptions from gather
|
||||
if isinstance(reasoning_result, Exception):
|
||||
logger.warning(
|
||||
f"Edge reasoning failed: {reasoning_result}"
|
||||
)
|
||||
reasoning_result = None
|
||||
if isinstance(source_documents, Exception):
|
||||
logger.warning(
|
||||
f"Document tracing failed: {source_documents}"
|
||||
)
|
||||
source_documents = []
|
||||
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge reasoning result: {reasoning_result}")
|
||||
|
||||
# Parse reasoning response (jsonl) and build explainability data
|
||||
reasoning_map = {}
|
||||
|
||||
if reasoning_result is not None:
|
||||
for obj in reasoning_result.objects or []:
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
reasoning_map[obj["id"]] = obj.get("reasoning", "")
|
||||
|
||||
# Build focus explainability data with cross-encoder metadata
|
||||
selected_edges_with_reasoning = []
|
||||
for eid in selected_ids:
|
||||
for s, p, o in selected_edges:
|
||||
eid = edge_id(s, p, o)
|
||||
if eid in uri_map:
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
meta = edge_metadata.get(eid, {})
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": reasoning_map.get(eid, ""),
|
||||
"concept": meta.get("concept", ""),
|
||||
"score": meta.get("score", 0),
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Filtered to {len(selected_edges)} edges")
|
||||
|
||||
# Emit focus explain after edge selection completes
|
||||
# Emit focus explain
|
||||
if explain_callback:
|
||||
# Sum scoring + reasoning token usage for focus event
|
||||
focus_in = 0
|
||||
focus_out = 0
|
||||
focus_model = None
|
||||
for r in [scoring_result, reasoning_result]:
|
||||
if r is not None:
|
||||
if r.in_token is not None:
|
||||
focus_in += r.in_token
|
||||
if r.out_token is not None:
|
||||
focus_out += r.out_token
|
||||
if r.model is not None:
|
||||
focus_model = r.model
|
||||
|
||||
foc_triples = set_graph(
|
||||
focus_triples(
|
||||
foc_uri, exp_uri, selected_edges_with_reasoning, session_id,
|
||||
in_token=focus_in or None,
|
||||
out_token=focus_out or None,
|
||||
model=focus_model,
|
||||
foc_uri, exp_uri,
|
||||
selected_edges_with_reasoning, session_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(foc_triples, foc_uri)
|
||||
|
||||
# Step 2: Synthesis - LLM generates answer from selected edges only
|
||||
# Step 4: Synthesis
|
||||
selected_edge_dicts = [
|
||||
{"s": s, "p": p, "o": o}
|
||||
for s, p, o in selected_edges
|
||||
]
|
||||
|
||||
# Add source document metadata as knowledge edges
|
||||
for s, p, o in source_documents:
|
||||
selected_edge_dicts.append({
|
||||
"s": s, "p": p, "o": o,
|
||||
|
|
@ -928,7 +716,6 @@ class GraphRag:
|
|||
}
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Accumulate chunks for answer storage while forwarding to callback
|
||||
accumulated_chunks = []
|
||||
|
||||
async def accumulating_callback(chunk, end_of_stream):
|
||||
|
|
@ -942,7 +729,6 @@ class GraphRag:
|
|||
chunk_callback=accumulating_callback
|
||||
)
|
||||
track_usage(synthesis_result)
|
||||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
synthesis_result = await self.prompt_client.prompt(
|
||||
|
|
@ -955,29 +741,42 @@ class GraphRag:
|
|||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
||||
# Emit synthesis explain after synthesis completes
|
||||
# Emit synthesis explain
|
||||
if explain_callback:
|
||||
synthesis_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian
|
||||
if save_answer_callback and answer_text:
|
||||
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
|
||||
try:
|
||||
await save_answer_callback(synthesis_doc_id, answer_text)
|
||||
if self.verbose:
|
||||
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
|
||||
logger.debug(
|
||||
f"Saved answer to librarian: "
|
||||
f"{synthesis_doc_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
logger.warning(
|
||||
f"Failed to save answer to librarian: {e}"
|
||||
)
|
||||
synthesis_doc_id = None
|
||||
|
||||
syn_triples = set_graph(
|
||||
synthesis_triples(
|
||||
syn_uri, foc_uri,
|
||||
document_id=synthesis_doc_id,
|
||||
in_token=synthesis_result.in_token if synthesis_result else None,
|
||||
out_token=synthesis_result.out_token if synthesis_result else None,
|
||||
model=synthesis_result.model if synthesis_result else None,
|
||||
in_token=(
|
||||
synthesis_result.in_token
|
||||
if synthesis_result else None
|
||||
),
|
||||
out_token=(
|
||||
synthesis_result.out_token
|
||||
if synthesis_result else None
|
||||
),
|
||||
model=(
|
||||
synthesis_result.model
|
||||
if synthesis_result else None
|
||||
),
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
|
|
@ -993,4 +792,3 @@ class GraphRag:
|
|||
}
|
||||
|
||||
return resp, usage
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from . graph_rag import GraphRag
|
|||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
|
||||
from ... base import RerankerClientSpec
|
||||
from ... base import LibrarianSpec
|
||||
|
||||
# Module logger
|
||||
|
|
@ -32,7 +33,6 @@ class Processor(FlowProcessor):
|
|||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
edge_score_limit = params.get("edge_score_limit", 30)
|
||||
edge_limit = params.get("edge_limit", 25)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
|
|
@ -43,7 +43,6 @@ class Processor(FlowProcessor):
|
|||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"edge_score_limit": edge_score_limit,
|
||||
"edge_limit": edge_limit,
|
||||
}
|
||||
)
|
||||
|
|
@ -52,7 +51,6 @@ class Processor(FlowProcessor):
|
|||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.default_edge_score_limit = edge_score_limit
|
||||
self.default_edge_limit = edge_limit
|
||||
|
||||
# Workspace isolation is enforced by the flow layer (flow.workspace).
|
||||
|
|
@ -96,6 +94,13 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RerankerClientSpec(
|
||||
request_name = "reranker-request",
|
||||
response_name = "reranker-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
|
|
@ -163,6 +168,7 @@ class Processor(FlowProcessor):
|
|||
graph_embeddings_client=flow("graph-embeddings-request"),
|
||||
triples_client=flow("triples-request"),
|
||||
prompt_client=flow("prompt-request"),
|
||||
reranker_client=flow("reranker-request"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
@ -186,11 +192,6 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
if v.edge_score_limit:
|
||||
edge_score_limit = v.edge_score_limit
|
||||
else:
|
||||
edge_score_limit = self.default_edge_score_limit
|
||||
|
||||
if v.edge_limit:
|
||||
edge_limit = v.edge_limit
|
||||
else:
|
||||
|
|
@ -225,7 +226,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
|
||||
edge_limit = edge_limit,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
|
|
@ -241,7 +242,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
|
||||
edge_limit = edge_limit,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
|
|
@ -338,18 +339,11 @@ class Processor(FlowProcessor):
|
|||
help=f'Default max path length (default: 2)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-score-limit',
|
||||
type=int,
|
||||
default=30,
|
||||
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-limit',
|
||||
type=int,
|
||||
default=25,
|
||||
help=f'Max edges after LLM scoring (default: 25)'
|
||||
help=f'Max edges selected per hop by cross-encoder (default: 25)'
|
||||
)
|
||||
|
||||
# Note: Explainability triples are now stored in the request's collection
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue