mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 09:29:38 +02:00
Replace the three-prompt LLM scoring pipeline (kg-edge-scoring, kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker service backed by FlashRank. The new hop_and_filter() method performs iterative graph traversal with semantic scoring at each hop, replacing the previous follow_edges/get_subgraph approach. - Add reranker service (trustgraph-base client/service, FlashRank processor) - Add gateway dispatch for reranker via API and WebSocket - Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring - Remove kg_prompt() and edge_score_limit from prompt client - Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates - Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata - Add tg-invoke-reranker CLI tool - Add tech spec and UX developer guidance - Update all unit and integration tests
109 lines
2.8 KiB
Python
109 lines
2.8 KiB
Python
|
|
"""
|
|
Reranker service using flashrank.
|
|
Scores query-document pairs and returns the top results ranked by
|
|
relevance.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
|
|
from ... base import RerankerService
|
|
from ... schema import RerankerResult
|
|
|
|
from flashrank import Ranker, RerankRequest
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Service identifier handed to Processor.launch() by run().
default_ident: str = "reranker"

# flashrank model used when neither the configuration nor a request
# names one explicitly.
default_model: str = "ms-marco-MiniLM-L-12-v2"
|
|
|
|
class Processor(RerankerService):
    """Reranker service backed by a flashrank ``Ranker``.

    One ranker is held in memory at a time. A request that names a model
    other than the cached one causes that model to be loaded in its place.
    """

    def __init__(self, **params):

        model = params.get("model", default_model)

        super(Processor, self).__init__(
            **params | { "model": model }
        )

        # Model used when a request does not specify one.
        self.default_model = model

        # Name of the model currently held in self.ranker (None until loaded).
        self.cached_model_name = None
        self.ranker = None

        # Load eagerly so the first request does not pay the model-load cost.
        self._load_model(model)

    def _load_model(self, model_name):
        """Make ``model_name`` the cached ranker and return that ranker.

        Callers should keep the returned object rather than re-reading
        ``self.ranker`` later. A concurrent request for a different model
        may replace the cached ranker at any time.

        NOTE(review): the cache check and the two attribute writes are not
        atomic as a pair. A load running in another thread can still race
        here, but only in a very small window.
        """
        if self.cached_model_name != model_name:
            logger.info("Loading flashrank model: %s", model_name)
            ranker = Ranker(model_name=model_name)
            self.ranker = ranker
            self.cached_model_name = model_name
            logger.info("flashrank model %s loaded successfully", model_name)
            return ranker

        logger.debug("Using cached model: %s", model_name)
        return self.ranker

    def _run_rerank(self, query, passages, ranker=None):
        """Score ``passages`` against ``query`` with flashrank (blocking).

        ``ranker`` defaults to the currently cached ranker. ``on_rerank``
        passes the ranker it captured, so that every query in one request
        is scored by the same model.
        """
        if ranker is None:
            ranker = self.ranker
        request = RerankRequest(query=query, passages=passages)
        return ranker.rerank(request)

    async def on_rerank(self, queries, documents, limit, model=None):
        """Rank ``documents`` against every query and return the best matches.

        Args:
            queries: items exposing ``query_id`` and ``query_text``.
            documents: items exposing ``document_id`` and ``document_text``.
            limit: maximum number of results, applied as a slice.
            model: optional model name; falls back to ``self.default_model``.

        Returns:
            A list of ``RerankerResult`` sorted by descending score. Each
            document appears at most once, paired with the query that gave
            it the highest score. On ties, the earliest such query wins.
            Returns an empty list if there are no queries or no documents.
        """

        if not queries or not documents:
            return []

        use_model = model or self.default_model

        # Capture the ranker once for this request. Reading self.ranker
        # after each await would let a concurrent request for another
        # model swap it mid-request.
        if self.cached_model_name != use_model:
            ranker = await asyncio.to_thread(self._load_model, use_model)
        else:
            ranker = self.ranker

        passages = [
            {"id": d.document_id, "text": d.document_text}
            for d in documents
        ]

        # doc_id -> (query_id, best score seen so far)
        best_scores = {}

        for q in queries:
            # flashrank inference is blocking, so run it in a worker thread.
            ranked = await asyncio.to_thread(
                self._run_rerank, q.query_text, passages, ranker,
            )

            for r in ranked:
                doc_id = r["id"]
                # Convert to a plain Python float (flashrank may return a
                # non-builtin numeric type).
                score = float(r["score"])
                best = best_scores.get(doc_id)
                if best is None or score > best[1]:
                    best_scores[doc_id] = (q.query_id, score)

        results = sorted(
            best_scores.items(),
            key=lambda item: item[1][1],
            reverse=True,
        )[:limit]

        return [
            RerankerResult(
                document_id=doc_id,
                query_id=query_id,
                score=score,
            )
            for doc_id, (query_id, score) in results
        ]

    @staticmethod
    def add_args(parser):

        RerankerService.add_args(parser)

        parser.add_argument(
            '-m', '--model',
            default=default_model,
            help=f'Reranker model (default: {default_model})'
        )
|
|
|
|
def run():
    """Start the flashrank reranker service.

    Passes the default service identifier and this module's docstring to
    ``Processor.launch``.
    """
    ident, description = default_ident, __doc__
    Processor.launch(ident, description)
|