feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG (#1005)

Replace the three-prompt LLM scoring pipeline (kg-edge-scoring,
kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker
service backed by FlashRank. The new hop_and_filter() method performs
iterative graph traversal with semantic scoring at each hop, replacing
the previous follow_edges/get_subgraph approach.

- Add reranker service (trustgraph-base client/service, FlashRank processor)
- Add gateway dispatch for reranker via API and WebSocket
- Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring
- Remove kg_prompt() and edge_score_limit from prompt client
- Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates
- Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata
- Add tg-invoke-reranker CLI tool
- Add tech spec and UX developer guidance
- Update all unit and integration tests
This commit is contained in:
cybermaggedon 2026-06-30 14:36:37 +01:00 committed by GitHub
parent 1aa9549912
commit 01cc8dbc64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 1613 additions and 792 deletions

View file

@ -646,6 +646,16 @@ class AsyncFlowInstance:
return await self.request("embeddings", request_data)
async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any):
request_data = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request_data.update(kwargs)
return await self.request("reranker", request_data)
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any):
"""
Query RDF triples using pattern matching.

View file

@ -443,6 +443,19 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("embeddings", self.flow_id, request)
async def rerank(self, queries: list, documents: list, limit: int = 10,
**kwargs):
request = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request.update(kwargs)
return await self.client._send_request(
"reranker", self.flow_id, request,
)
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs):
"""Triple pattern query"""
request = {"limit": limit}

View file

@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_SCORE = TG + "score"
TG_DOCUMENT = TG + "document"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
@dataclass
class EdgeSelection:
"""A selected edge with reasoning from GraphRAG Focus step."""
"""A selected edge with cross-encoder metadata from GraphRAG Focus step."""
uri: str
edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...}
reasoning: str = ""
concept: str = ""
score: Optional[float] = None
@dataclass
@ -209,7 +212,7 @@ class Exploration(ExplainEntity):
@dataclass
class Focus(ExplainEntity):
"""Focus entity - selected edges with LLM reasoning (GraphRAG only)."""
"""Focus entity - selected edges with cross-encoder scoring (GraphRAG only)."""
selected_edge_uris: List[str] = field(default_factory=list)
edge_selections: List[EdgeSelection] = field(default_factory=list)
@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel
uri = triples[0][0] if triples else ""
edge = None
reasoning = ""
concept = ""
score = None
for s, p, o in triples:
if p == TG_EDGE and isinstance(o, dict):
edge = o
elif p == TG_REASONING:
reasoning = o
elif p == TG_CONCEPT:
concept = o
elif p == TG_SCORE:
try:
score = float(o)
except (ValueError, TypeError):
score = None
return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning)
return EdgeSelection(
uri=uri, edge=edge, reasoning=reasoning,
concept=concept, score=score,
)
def extract_term_value(term: Dict[str, Any]) -> Any:

View file

@ -491,6 +491,19 @@ class FlowInstance:
input
)["vectors"]
def rerank(self, queries, documents, limit=10):
input = {
"queries": queries,
"documents": documents,
"limit": limit,
}
return self.request(
"service/reranker",
input
)
def graph_embeddings_query(self, text, collection, limit=10):
"""
Query knowledge graph entities using semantic similarity.

View file

@ -885,6 +885,19 @@ class SocketFlowInstance:
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
def rerank(self, queries: list, documents: list, limit: int = 10,
**kwargs: Any) -> Dict[str, Any]:
request = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request.update(kwargs)
return self.client._send_request_sync(
"reranker", self.flow_id, request, False,
)
def triples_query(
self,
s: Optional[Union[str, Dict[str, Any]]] = None,

View file

@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService
from . tool_service_client import ToolServiceClientSpec
from . agent_client import AgentClientSpec
from . structured_query_client import StructuredQueryClientSpec
from . reranker_client import RerankerClientSpec
from . reranker_service import RerankerService
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
from . collection_config_handler import CollectionConfigHandler

View file

@ -157,21 +157,6 @@ class PromptClient(RequestResponse):
timeout = timeout,
)
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "kg-prompt",
variables = {
"query": query,
"knowledge": [
{ "s": v[0], "p": v[1], "o": v[2] }
for v in kg
]
},
timeout = timeout,
streaming = streaming,
chunk_callback = chunk_callback,
)
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "document-prompt",

View file

@ -0,0 +1,43 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import (
RerankerRequest, RerankerResponse,
RerankerQuery, RerankerDocument,
)
class RerankerClient(RequestResponse):
async def rerank(self, queries, documents, limit=10, timeout=300):
resp = await self.request(
RerankerRequest(
queries=[
RerankerQuery(query_id=q["id"], query_text=q["text"])
for q in queries
],
documents=[
RerankerDocument(
document_id=d["id"], document_text=d["text"]
)
for d in documents
],
limit=limit,
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.results
class RerankerClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(RerankerClientSpec, self).__init__(
request_name = request_name,
request_schema = RerankerRequest,
response_name = response_name,
response_schema = RerankerResponse,
impl = RerankerClient,
)

View file

@ -0,0 +1,109 @@
from __future__ import annotations
from argparse import ArgumentParser
import logging
from .. schema import (
RerankerRequest, RerankerResponse, RerankerResult, Error,
)
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec
logger = logging.getLogger(__name__)
default_ident = "reranker"
default_concurrency = 1
class RerankerService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
concurrency = params.get("concurrency", 1)
super(RerankerService, self).__init__(**params | {
"id": id,
"concurrency": concurrency,
})
self.register_specification(
ConsumerSpec(
name = "request",
schema = RerankerRequest,
handler = self.on_request,
concurrency = concurrency,
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = RerankerResponse
)
)
self.register_specification(
ParameterSpec(
name = "model",
)
)
async def on_request(self, msg, consumer, flow):
try:
request = msg.value()
id = msg.properties()["id"]
logger.debug(f"Handling reranker request {id}...")
model = flow("model")
results = await self.on_rerank(
request.queries, request.documents,
request.limit, model=model,
)
await flow("response").send(
RerankerResponse(
error = None,
results = results,
),
properties={"id": id}
)
logger.debug("Reranker request handled successfully")
except TooManyRequests as e:
raise e
except Exception as e:
logger.error(f"Exception in reranker service: {e}", exc_info=True)
logger.info("Sending error response...")
await flow.producer["response"].send(
RerankerResponse(
error=Error(
type = "reranker-error",
message = str(e),
),
results=[],
),
properties={"id": id}
)
@staticmethod
def add_args(parser: ArgumentParser) -> None:
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)

View file

@ -140,20 +140,6 @@ class PromptClient(BaseClient):
timeout=timeout
)
def request_kg_prompt(self, query, kg, timeout=300):
return self.request(
id="kg-prompt",
variables={
"query": query,
"knowledge": [
{ "s": v[0], "p": v[1], "o": v[2] }
for v in kg
]
},
timeout=timeout
)
def request_document_prompt(self, query, documents, timeout=300):
return self.request(

View file

@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator
from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator
from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator
@ -163,6 +164,12 @@ TranslatorRegistry.register_service(
SparqlQueryResponseTranslator()
)
TranslatorRegistry.register_service(
"reranker",
RerankerRequestTranslator(),
RerankerResponseTranslator()
)
# Register single-direction translators for document loading
TranslatorRegistry.register_request("document", DocumentTranslator())
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())

View file

@ -20,3 +20,4 @@ from .embeddings_query import (
)
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
from .reranker import RerankerRequestTranslator, RerankerResponseTranslator

View file

@ -0,0 +1,73 @@
from typing import Dict, Any, Tuple
from ...schema import (
RerankerRequest, RerankerResponse,
RerankerQuery, RerankerDocument, RerankerResult,
)
from .base import MessageTranslator
class RerankerRequestTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> RerankerRequest:
return RerankerRequest(
queries=[
RerankerQuery(
query_id=q["query_id"],
query_text=q["query_text"],
)
for q in data.get("queries", [])
],
documents=[
RerankerDocument(
document_id=d["document_id"],
document_text=d["document_text"],
)
for d in data.get("documents", [])
],
limit=data.get("limit", 10),
)
def encode(self, obj: RerankerRequest) -> Dict[str, Any]:
return {
"queries": [
{"query_id": q.query_id, "query_text": q.query_text}
for q in obj.queries
],
"documents": [
{"document_id": d.document_id, "document_text": d.document_text}
for d in obj.documents
],
"limit": obj.limit,
}
class RerankerResponseTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> RerankerResponse:
return RerankerResponse(
results=[
RerankerResult(
document_id=r["document_id"],
query_id=r["query_id"],
score=r["score"],
)
for r in data.get("results", [])
],
)
def encode(self, obj: RerankerResponse) -> Dict[str, Any]:
return {
"results": [
{
"document_id": r.document_id,
"query_id": r.query_id,
"score": r.score,
}
for r in obj.results
],
}
def encode_with_completion(
self, obj: RerankerResponse
) -> Tuple[Dict[str, Any], bool]:
return self.encode(obj), True

View file

@ -89,7 +89,9 @@ from . namespaces import (
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE,
# Edge selection entity type
TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Explainability entity types
@ -212,7 +214,9 @@ __all__ = [
"TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE",
# Query-time provenance predicates (GraphRAG)
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE",
# Edge selection entity type
"TG_EDGE_SELECTION",
# Query-time provenance predicates (DocumentRAG)
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
# Explainability entity types

View file

@ -66,8 +66,12 @@ TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_SCORE = TG + "score"
TG_DOCUMENT = TG + "document" # Reference to document in librarian
# Edge selection entity type (cross-encoder scored edge in Focus)
TG_EDGE_SELECTION = TG + "EdgeSelection"
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT = TG + "chunkCount"
TG_SELECTED_CHUNK = TG + "selectedChunk"

View file

@ -24,8 +24,10 @@ from . namespaces import (
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE,
TG_DOCUMENT,
# Edge selection entity type
TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Explainability entity types
@ -536,10 +538,9 @@ def focus_triples(
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
]
# Add each selected edge with its reasoning via intermediate entity
# Add each selected edge with metadata via intermediate entity
for idx, edge_info in enumerate(selected_edges_with_reasoning):
edge = edge_info.get("edge")
reasoning = edge_info.get("reasoning", "")
if edge:
s, p, o = edge
@ -552,13 +553,32 @@ def focus_triples(
_triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri))
)
# Type the edge selection entity
triples.append(
_triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION))
)
# Attach quoted triple to edge selection entity
quoted = _quoted_triple(s, p, o)
triples.append(
Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted)
)
# Attach reasoning to edge selection entity
# Structured cross-encoder metadata
concept = edge_info.get("concept")
if concept:
triples.append(
_triple(edge_sel_uri, TG_CONCEPT, _literal(concept))
)
score = edge_info.get("score")
if score is not None:
triples.append(
_triple(edge_sel_uri, TG_SCORE, _literal(str(score)))
)
# Legacy reasoning text (for non-cross-encoder callers)
reasoning = edge_info.get("reasoning", "")
if reasoning:
triples.append(
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning))

View file

@ -29,6 +29,7 @@ from . namespaces import (
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
TG_EDGE_SELECTION, TG_SCORE,
)
@ -93,6 +94,7 @@ TG_CLASS_LABELS = [
_label_triple(TG_FINDING, "Finding"),
_label_triple(TG_PLAN_TYPE, "Plan"),
_label_triple(TG_STEP_RESULT, "Step Result"),
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
]
# TrustGraph predicate labels
@ -117,6 +119,7 @@ TG_PREDICATE_LABELS = [
_label_triple(TG_ENTITY, "entity"),
_label_triple(TG_SUBAGENT_GOAL, "subagent goal"),
_label_triple(TG_PLAN_STEP, "plan step"),
_label_triple(TG_SCORE, "score"),
]

View file

@ -15,4 +15,5 @@ from .diagnosis import *
from .collection import *
from .storage import *
from .tool_service import *
from .sparql_query import *
from .sparql_query import *
from .reranker import *

View file

@ -6,17 +6,6 @@ from ..core.primitives import Error
# Prompt services, abstract the prompt generation
# extract-definitions:
# chunk -> definitions
# extract-relationships:
# chunk -> relationships
# kg-prompt:
# query, triples -> answer
# document-prompt:
# query, documents -> answer
# extract-rows
# schema, chunk -> rows
@dataclass
class PromptRequest:
id: str = ""
@ -46,4 +35,4 @@ class PromptResponse:
out_token: int | None = None
model: str | None = None
############################################################################
############################################################################

View file

@ -0,0 +1,35 @@
from dataclasses import dataclass, field
from ..core.primitives import Error
############################################################################
# Cross-encoder reranker
@dataclass
class RerankerQuery:
query_id: str = ""
query_text: str = ""
@dataclass
class RerankerDocument:
document_id: str = ""
document_text: str = ""
@dataclass
class RerankerRequest:
queries: list[RerankerQuery] = field(default_factory=list)
documents: list[RerankerDocument] = field(default_factory=list)
limit: int = 10
@dataclass
class RerankerResult:
document_id: str = ""
query_id: str = ""
score: float = 0.0
@dataclass
class RerankerResponse:
error: Error | None = None
results: list[RerankerResult] = field(default_factory=list)