mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-01 09:29:38 +02:00
feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG (#1005)
Replace the three-prompt LLM scoring pipeline (kg-edge-scoring, kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker service backed by FlashRank. The new hop_and_filter() method performs iterative graph traversal with semantic scoring at each hop, replacing the previous follow_edges/get_subgraph approach. - Add reranker service (trustgraph-base client/service, FlashRank processor) - Add gateway dispatch for reranker via API and WebSocket - Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring - Remove kg_prompt() and edge_score_limit from prompt client - Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates - Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata - Add tg-invoke-reranker CLI tool - Add tech spec and UX developer guidance - Update all unit and integration tests
This commit is contained in:
parent
1aa9549912
commit
01cc8dbc64
43 changed files with 1613 additions and 792 deletions
|
|
@ -646,6 +646,16 @@ class AsyncFlowInstance:
|
|||
|
||||
return await self.request("embeddings", request_data)
|
||||
|
||||
async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any):
|
||||
request_data = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request_data.update(kwargs)
|
||||
|
||||
return await self.request("reranker", request_data)
|
||||
|
||||
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any):
|
||||
"""
|
||||
Query RDF triples using pattern matching.
|
||||
|
|
|
|||
|
|
@ -443,6 +443,19 @@ class AsyncSocketFlowInstance:
|
|||
|
||||
return await self.client._send_request("embeddings", self.flow_id, request)
|
||||
|
||||
async def rerank(self, queries: list, documents: list, limit: int = 10,
|
||||
**kwargs):
|
||||
request = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request(
|
||||
"reranker", self.flow_id, request,
|
||||
)
|
||||
|
||||
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs):
|
||||
"""Triple pattern query"""
|
||||
request = {"limit": limit}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount"
|
|||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_SCORE = TG + "score"
|
||||
TG_DOCUMENT = TG + "document"
|
||||
TG_CONCEPT = TG + "concept"
|
||||
TG_ENTITY = TG + "entity"
|
||||
|
|
@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
|||
|
||||
@dataclass
|
||||
class EdgeSelection:
|
||||
"""A selected edge with reasoning from GraphRAG Focus step."""
|
||||
"""A selected edge with cross-encoder metadata from GraphRAG Focus step."""
|
||||
uri: str
|
||||
edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...}
|
||||
reasoning: str = ""
|
||||
concept: str = ""
|
||||
score: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -209,7 +212,7 @@ class Exploration(ExplainEntity):
|
|||
|
||||
@dataclass
|
||||
class Focus(ExplainEntity):
|
||||
"""Focus entity - selected edges with LLM reasoning (GraphRAG only)."""
|
||||
"""Focus entity - selected edges with cross-encoder scoring (GraphRAG only)."""
|
||||
selected_edge_uris: List[str] = field(default_factory=list)
|
||||
edge_selections: List[EdgeSelection] = field(default_factory=list)
|
||||
|
||||
|
|
@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel
|
|||
uri = triples[0][0] if triples else ""
|
||||
edge = None
|
||||
reasoning = ""
|
||||
concept = ""
|
||||
score = None
|
||||
|
||||
for s, p, o in triples:
|
||||
if p == TG_EDGE and isinstance(o, dict):
|
||||
edge = o
|
||||
elif p == TG_REASONING:
|
||||
reasoning = o
|
||||
elif p == TG_CONCEPT:
|
||||
concept = o
|
||||
elif p == TG_SCORE:
|
||||
try:
|
||||
score = float(o)
|
||||
except (ValueError, TypeError):
|
||||
score = None
|
||||
|
||||
return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning)
|
||||
return EdgeSelection(
|
||||
uri=uri, edge=edge, reasoning=reasoning,
|
||||
concept=concept, score=score,
|
||||
)
|
||||
|
||||
|
||||
def extract_term_value(term: Dict[str, Any]) -> Any:
|
||||
|
|
|
|||
|
|
@ -491,6 +491,19 @@ class FlowInstance:
|
|||
input
|
||||
)["vectors"]
|
||||
|
||||
def rerank(self, queries, documents, limit=10):
|
||||
|
||||
input = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
return self.request(
|
||||
"service/reranker",
|
||||
input
|
||||
)
|
||||
|
||||
def graph_embeddings_query(self, text, collection, limit=10):
|
||||
"""
|
||||
Query knowledge graph entities using semantic similarity.
|
||||
|
|
|
|||
|
|
@ -885,6 +885,19 @@ class SocketFlowInstance:
|
|||
|
||||
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
|
||||
|
||||
def rerank(self, queries: list, documents: list, limit: int = 10,
|
||||
**kwargs: Any) -> Dict[str, Any]:
|
||||
request = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync(
|
||||
"reranker", self.flow_id, request, False,
|
||||
)
|
||||
|
||||
def triples_query(
|
||||
self,
|
||||
s: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService
|
|||
from . tool_service_client import ToolServiceClientSpec
|
||||
from . agent_client import AgentClientSpec
|
||||
from . structured_query_client import StructuredQueryClientSpec
|
||||
from . reranker_client import RerankerClientSpec
|
||||
from . reranker_service import RerankerService
|
||||
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
|
||||
from . collection_config_handler import CollectionConfigHandler
|
||||
|
||||
|
|
|
|||
|
|
@ -157,21 +157,6 @@ class PromptClient(RequestResponse):
|
|||
timeout = timeout,
|
||||
)
|
||||
|
||||
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "kg-prompt",
|
||||
variables = {
|
||||
"query": query,
|
||||
"knowledge": [
|
||||
{ "s": v[0], "p": v[1], "o": v[2] }
|
||||
for v in kg
|
||||
]
|
||||
},
|
||||
timeout = timeout,
|
||||
streaming = streaming,
|
||||
chunk_callback = chunk_callback,
|
||||
)
|
||||
|
||||
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "document-prompt",
|
||||
|
|
|
|||
43
trustgraph-base/trustgraph/base/reranker_client.py
Normal file
43
trustgraph-base/trustgraph/base/reranker_client.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import (
|
||||
RerankerRequest, RerankerResponse,
|
||||
RerankerQuery, RerankerDocument,
|
||||
)
|
||||
|
||||
class RerankerClient(RequestResponse):
|
||||
async def rerank(self, queries, documents, limit=10, timeout=300):
|
||||
|
||||
resp = await self.request(
|
||||
RerankerRequest(
|
||||
queries=[
|
||||
RerankerQuery(query_id=q["id"], query_text=q["text"])
|
||||
for q in queries
|
||||
],
|
||||
documents=[
|
||||
RerankerDocument(
|
||||
document_id=d["id"], document_text=d["text"]
|
||||
)
|
||||
for d in documents
|
||||
],
|
||||
limit=limit,
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.results
|
||||
|
||||
class RerankerClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
self, request_name, response_name,
|
||||
):
|
||||
super(RerankerClientSpec, self).__init__(
|
||||
request_name = request_name,
|
||||
request_schema = RerankerRequest,
|
||||
response_name = response_name,
|
||||
response_schema = RerankerResponse,
|
||||
impl = RerankerClient,
|
||||
)
|
||||
109
trustgraph-base/trustgraph/base/reranker_service.py
Normal file
109
trustgraph-base/trustgraph/base/reranker_service.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import logging
|
||||
|
||||
from .. schema import (
|
||||
RerankerRequest, RerankerResponse, RerankerResult, Error,
|
||||
)
|
||||
from .. exceptions import TooManyRequests
|
||||
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "reranker"
|
||||
default_concurrency = 1
|
||||
|
||||
class RerankerService(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
|
||||
super(RerankerService, self).__init__(**params | {
|
||||
"id": id,
|
||||
"concurrency": concurrency,
|
||||
})
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = RerankerRequest,
|
||||
handler = self.on_request,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
schema = RerankerResponse
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ParameterSpec(
|
||||
name = "model",
|
||||
)
|
||||
)
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
try:
|
||||
|
||||
request = msg.value()
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
logger.debug(f"Handling reranker request {id}...")
|
||||
|
||||
model = flow("model")
|
||||
results = await self.on_rerank(
|
||||
request.queries, request.documents,
|
||||
request.limit, model=model,
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
RerankerResponse(
|
||||
error = None,
|
||||
results = results,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
logger.debug("Reranker request handled successfully")
|
||||
|
||||
except TooManyRequests as e:
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"Exception in reranker service: {e}", exc_info=True)
|
||||
|
||||
logger.info("Sending error response...")
|
||||
|
||||
await flow.producer["response"].send(
|
||||
RerankerResponse(
|
||||
error=Error(
|
||||
type = "reranker-error",
|
||||
message = str(e),
|
||||
),
|
||||
results=[],
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser: ArgumentParser) -> None:
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
|
@ -140,20 +140,6 @@ class PromptClient(BaseClient):
|
|||
timeout=timeout
|
||||
)
|
||||
|
||||
def request_kg_prompt(self, query, kg, timeout=300):
|
||||
|
||||
return self.request(
|
||||
id="kg-prompt",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": [
|
||||
{ "s": v[0], "p": v[1], "o": v[2] }
|
||||
for v in kg
|
||||
]
|
||||
},
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
def request_document_prompt(self, query, documents, timeout=300):
|
||||
|
||||
return self.request(
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons
|
|||
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
||||
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
|
||||
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator
|
||||
from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator
|
||||
from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator
|
||||
|
||||
|
|
@ -163,6 +164,12 @@ TranslatorRegistry.register_service(
|
|||
SparqlQueryResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"reranker",
|
||||
RerankerRequestTranslator(),
|
||||
RerankerResponseTranslator()
|
||||
)
|
||||
|
||||
# Register single-direction translators for document loading
|
||||
TranslatorRegistry.register_request("document", DocumentTranslator())
|
||||
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())
|
||||
|
|
|
|||
|
|
@ -20,3 +20,4 @@ from .embeddings_query import (
|
|||
)
|
||||
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
from .reranker import RerankerRequestTranslator, RerankerResponseTranslator
|
||||
|
|
|
|||
73
trustgraph-base/trustgraph/messaging/translators/reranker.py
Normal file
73
trustgraph-base/trustgraph/messaging/translators/reranker.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import (
|
||||
RerankerRequest, RerankerResponse,
|
||||
RerankerQuery, RerankerDocument, RerankerResult,
|
||||
)
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class RerankerRequestTranslator(MessageTranslator):
|
||||
|
||||
def decode(self, data: Dict[str, Any]) -> RerankerRequest:
|
||||
return RerankerRequest(
|
||||
queries=[
|
||||
RerankerQuery(
|
||||
query_id=q["query_id"],
|
||||
query_text=q["query_text"],
|
||||
)
|
||||
for q in data.get("queries", [])
|
||||
],
|
||||
documents=[
|
||||
RerankerDocument(
|
||||
document_id=d["document_id"],
|
||||
document_text=d["document_text"],
|
||||
)
|
||||
for d in data.get("documents", [])
|
||||
],
|
||||
limit=data.get("limit", 10),
|
||||
)
|
||||
|
||||
def encode(self, obj: RerankerRequest) -> Dict[str, Any]:
|
||||
return {
|
||||
"queries": [
|
||||
{"query_id": q.query_id, "query_text": q.query_text}
|
||||
for q in obj.queries
|
||||
],
|
||||
"documents": [
|
||||
{"document_id": d.document_id, "document_text": d.document_text}
|
||||
for d in obj.documents
|
||||
],
|
||||
"limit": obj.limit,
|
||||
}
|
||||
|
||||
|
||||
class RerankerResponseTranslator(MessageTranslator):
|
||||
|
||||
def decode(self, data: Dict[str, Any]) -> RerankerResponse:
|
||||
return RerankerResponse(
|
||||
results=[
|
||||
RerankerResult(
|
||||
document_id=r["document_id"],
|
||||
query_id=r["query_id"],
|
||||
score=r["score"],
|
||||
)
|
||||
for r in data.get("results", [])
|
||||
],
|
||||
)
|
||||
|
||||
def encode(self, obj: RerankerResponse) -> Dict[str, Any]:
|
||||
return {
|
||||
"results": [
|
||||
{
|
||||
"document_id": r.document_id,
|
||||
"query_id": r.query_id,
|
||||
"score": r.score,
|
||||
}
|
||||
for r in obj.results
|
||||
],
|
||||
}
|
||||
|
||||
def encode_with_completion(
|
||||
self, obj: RerankerResponse
|
||||
) -> Tuple[Dict[str, Any], bool]:
|
||||
return self.encode(obj), True
|
||||
|
|
@ -89,7 +89,9 @@ from . namespaces import (
|
|||
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE,
|
||||
# Edge selection entity type
|
||||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Explainability entity types
|
||||
|
|
@ -212,7 +214,9 @@ __all__ = [
|
|||
"TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE",
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
|
||||
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
|
||||
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE",
|
||||
# Edge selection entity type
|
||||
"TG_EDGE_SELECTION",
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
|
||||
# Explainability entity types
|
||||
|
|
|
|||
|
|
@ -66,8 +66,12 @@ TG_EDGE_COUNT = TG + "edgeCount"
|
|||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_SCORE = TG + "score"
|
||||
TG_DOCUMENT = TG + "document" # Reference to document in librarian
|
||||
|
||||
# Edge selection entity type (cross-encoder scored edge in Focus)
|
||||
TG_EDGE_SELECTION = TG + "EdgeSelection"
|
||||
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT = TG + "chunkCount"
|
||||
TG_SELECTED_CHUNK = TG + "selectedChunk"
|
||||
|
|
|
|||
|
|
@ -24,8 +24,10 @@ from . namespaces import (
|
|||
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE,
|
||||
TG_DOCUMENT,
|
||||
# Edge selection entity type
|
||||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Explainability entity types
|
||||
|
|
@ -536,10 +538,9 @@ def focus_triples(
|
|||
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
|
||||
]
|
||||
|
||||
# Add each selected edge with its reasoning via intermediate entity
|
||||
# Add each selected edge with metadata via intermediate entity
|
||||
for idx, edge_info in enumerate(selected_edges_with_reasoning):
|
||||
edge = edge_info.get("edge")
|
||||
reasoning = edge_info.get("reasoning", "")
|
||||
|
||||
if edge:
|
||||
s, p, o = edge
|
||||
|
|
@ -552,13 +553,32 @@ def focus_triples(
|
|||
_triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri))
|
||||
)
|
||||
|
||||
# Type the edge selection entity
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION))
|
||||
)
|
||||
|
||||
# Attach quoted triple to edge selection entity
|
||||
quoted = _quoted_triple(s, p, o)
|
||||
triples.append(
|
||||
Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted)
|
||||
)
|
||||
|
||||
# Attach reasoning to edge selection entity
|
||||
# Structured cross-encoder metadata
|
||||
concept = edge_info.get("concept")
|
||||
if concept:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_CONCEPT, _literal(concept))
|
||||
)
|
||||
|
||||
score = edge_info.get("score")
|
||||
if score is not None:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_SCORE, _literal(str(score)))
|
||||
)
|
||||
|
||||
# Legacy reasoning text (for non-cross-encoder callers)
|
||||
reasoning = edge_info.get("reasoning", "")
|
||||
if reasoning:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning))
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from . namespaces import (
|
|||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
TG_EDGE_SELECTION, TG_SCORE,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -93,6 +94,7 @@ TG_CLASS_LABELS = [
|
|||
_label_triple(TG_FINDING, "Finding"),
|
||||
_label_triple(TG_PLAN_TYPE, "Plan"),
|
||||
_label_triple(TG_STEP_RESULT, "Step Result"),
|
||||
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
|
||||
]
|
||||
|
||||
# TrustGraph predicate labels
|
||||
|
|
@ -117,6 +119,7 @@ TG_PREDICATE_LABELS = [
|
|||
_label_triple(TG_ENTITY, "entity"),
|
||||
_label_triple(TG_SUBAGENT_GOAL, "subagent goal"),
|
||||
_label_triple(TG_PLAN_STEP, "plan step"),
|
||||
_label_triple(TG_SCORE, "score"),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,4 +15,5 @@ from .diagnosis import *
|
|||
from .collection import *
|
||||
from .storage import *
|
||||
from .tool_service import *
|
||||
from .sparql_query import *
|
||||
from .sparql_query import *
|
||||
from .reranker import *
|
||||
|
|
@ -6,17 +6,6 @@ from ..core.primitives import Error
|
|||
|
||||
# Prompt services, abstract the prompt generation
|
||||
|
||||
# extract-definitions:
|
||||
# chunk -> definitions
|
||||
# extract-relationships:
|
||||
# chunk -> relationships
|
||||
# kg-prompt:
|
||||
# query, triples -> answer
|
||||
# document-prompt:
|
||||
# query, documents -> answer
|
||||
# extract-rows
|
||||
# schema, chunk -> rows
|
||||
|
||||
@dataclass
|
||||
class PromptRequest:
|
||||
id: str = ""
|
||||
|
|
@ -46,4 +35,4 @@ class PromptResponse:
|
|||
out_token: int | None = None
|
||||
model: str | None = None
|
||||
|
||||
############################################################################
|
||||
############################################################################
|
||||
|
|
|
|||
35
trustgraph-base/trustgraph/schema/services/reranker.py
Normal file
35
trustgraph-base/trustgraph/schema/services/reranker.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..core.primitives import Error
|
||||
|
||||
############################################################################
|
||||
|
||||
# Cross-encoder reranker
|
||||
|
||||
@dataclass
|
||||
class RerankerQuery:
|
||||
query_id: str = ""
|
||||
query_text: str = ""
|
||||
|
||||
@dataclass
|
||||
class RerankerDocument:
|
||||
document_id: str = ""
|
||||
document_text: str = ""
|
||||
|
||||
@dataclass
|
||||
class RerankerRequest:
|
||||
queries: list[RerankerQuery] = field(default_factory=list)
|
||||
documents: list[RerankerDocument] = field(default_factory=list)
|
||||
limit: int = 10
|
||||
|
||||
@dataclass
|
||||
class RerankerResult:
|
||||
document_id: str = ""
|
||||
query_id: str = ""
|
||||
score: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class RerankerResponse:
|
||||
error: Error | None = None
|
||||
results: list[RerankerResult] = field(default_factory=list)
|
||||
Loading…
Add table
Add a link
Reference in a new issue