trustgraph/trustgraph-base/trustgraph/clients/prompt_client.py
cybermaggedon 01cc8dbc64
feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG (#1005)
Replace the three-prompt LLM scoring pipeline (kg-edge-scoring,
kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker
service backed by FlashRank. The new hop_and_filter() method performs
iterative graph traversal with semantic scoring at each hop, replacing
the previous follow_edges/get_subgraph approach.

- Add reranker service (trustgraph-base client/service, FlashRank processor)
- Add gateway dispatch for reranker via API and WebSocket
- Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring
- Remove kg_prompt() and edge_score_limit from prompt client
- Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates
- Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata
- Add tg-invoke-reranker CLI tool
- Add tech spec and UX developer guidance
- Update all unit and integration tests
2026-06-30 14:36:37 +01:00

154 lines
3.6 KiB
Python

import json
import dataclasses
from .. schema import PromptRequest, PromptResponse
from .. schema import prompt_request_queue
from .. schema import prompt_response_queue
from . base import BaseClient
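# NOTE(review): the original comment here just said "Ugly" -- presumably
# flagging these plain result containers (filled in by the request_*
# extraction methods of PromptClient) as a stopgap; confirm intent.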
@dataclasses.dataclass(init=True, repr=True, eq=True)
class Definition:
    """An entity name paired with its textual definition.

    Instances are built by ``PromptClient.request_definitions`` from the
    ``entity`` / ``definition`` keys of each extracted item.
    """

    name: str        # entity name (the ``entity`` key of the prompt result)
    definition: str  # definition text for that entity
@dataclasses.dataclass(init=True, repr=True, eq=True)
class Relationship:
    """A subject / predicate / object triple from relationship extraction.

    Instances are built by ``PromptClient.request_relationships``.
    """

    s: str         # subject (``subject`` key of the prompt result)
    p: str         # predicate (``predicate`` key)
    o: str         # object (``object`` key)
    # Value of the ``object-entity`` key. Annotated as str, but presumably
    # indicates whether the object is an entity rather than a literal --
    # the prompt may actually return a boolean; TODO confirm.
    o_entity: str
@dataclasses.dataclass(init=True, repr=True, eq=True)
class Topic:
    """A topic name paired with its textual definition.

    Instances are built by ``PromptClient.request_topics`` from the
    ``topic`` / ``definition`` keys of each extracted item.
    """

    name: str        # topic name (the ``topic`` key of the prompt result)
    definition: str  # definition text for that topic
class PromptClient(BaseClient):
    """Client for the prompt service.

    Each call names a prompt template by id, sends its variables as a
    ``PromptRequest`` and decodes the ``PromptResponse``. The ``request_*``
    helpers wrap specific templates and shape their results.
    """

    # Default timeout handed through to ``BaseClient.call`` by every request
    # method (units are whatever ``call`` expects).
    DEFAULT_TIMEOUT = 300

    def __init__(
            self,
            subscriber=None,
            input_queue=None,
            output_queue=None,
            pulsar_host="pulsar://pulsar:6650",
            pulsar_api_key=None,
    ):
        """Create the client.

        ``input_queue`` / ``output_queue`` default to the standard prompt
        request / response queues. All settings are forwarded to
        ``BaseClient`` along with the prompt request/response schemas.
        """
        if input_queue is None:
            input_queue = prompt_request_queue

        if output_queue is None:
            output_queue = prompt_response_queue

        super().__init__(
            subscriber=subscriber,
            input_queue=input_queue,
            output_queue=output_queue,
            pulsar_host=pulsar_host,
            pulsar_api_key=pulsar_api_key,
            input_schema=PromptRequest,
            output_schema=PromptResponse,
        )

    def request(self, id, variables, timeout=DEFAULT_TIMEOUT):
        """Invoke prompt template *id* with *variables* and return its result.

        Every variable value is JSON-encoded before sending, so values may
        be any JSON-serialisable object (strings, lists, dicts, ...).

        Returns the response ``text`` when it is non-empty; otherwise the
        response ``object`` field decoded with ``json.loads`` (any decoding
        error propagates to the caller).
        """
        terms = {k: json.dumps(v) for k, v in variables.items()}

        resp = self.call(id=id, terms=terms, timeout=timeout)

        # A non-empty text payload takes precedence; otherwise the template
        # produced a structured result serialised as JSON.
        if resp.text:
            return resp.text

        return json.loads(resp.object)

    def _extract(self, prompt_id, chunk, timeout):
        """Run an extraction template whose only variable is ``text``."""
        return self.request(
            id=prompt_id,
            variables={"text": chunk},
            timeout=timeout,
        )

    def request_definitions(self, chunk, timeout=DEFAULT_TIMEOUT):
        """Extract entity definitions from text *chunk*.

        Returns a list of ``Definition``, one per item returned by the
        ``extract-definitions`` template (keys ``entity``, ``definition``).
        """
        defs = self._extract("extract-definitions", chunk, timeout)

        return [
            Definition(name=d["entity"], definition=d["definition"])
            for d in defs
        ]

    def request_relationships(self, chunk, timeout=DEFAULT_TIMEOUT):
        """Extract relationships from text *chunk*.

        Returns a list of ``Relationship``, one per item returned by the
        ``extract-relationships`` template (keys ``subject``, ``predicate``,
        ``object``, ``object-entity``).
        """
        rels = self._extract("extract-relationships", chunk, timeout)

        return [
            Relationship(
                s=d["subject"],
                p=d["predicate"],
                o=d["object"],
                o_entity=d["object-entity"],
            )
            for d in rels
        ]

    def request_topics(self, chunk, timeout=DEFAULT_TIMEOUT):
        """Extract topics from text *chunk*.

        Returns a list of ``Topic``, one per item returned by the
        ``extract-topics`` template (keys ``topic``, ``definition``).
        """
        topics = self._extract("extract-topics", chunk, timeout)

        return [
            Topic(name=d["topic"], definition=d["definition"])
            for d in topics
        ]

    def request_rows(self, schema, chunk, timeout=DEFAULT_TIMEOUT):
        """Extract structured rows matching *schema* from text *chunk*.

        *schema* must expose ``name``, ``description`` and ``fields``; each
        field must expose ``name``, ``type``, ``size``, ``primary`` and
        ``description``. The raw ``extract-rows`` template result is
        returned unchanged.
        """
        row_schema = {
            "name": schema.name,
            "description": schema.description,
            "fields": [
                {
                    # The field type is sent as its string form.
                    "name": f.name, "type": str(f.type),
                    "size": f.size, "primary": f.primary,
                    "description": f.description,
                }
                for f in schema.fields
            ],
        }

        return self.request(
            id="extract-rows",
            variables={
                "chunk": chunk,
                "row-schema": row_schema,
            },
            timeout=timeout,
        )

    def request_document_prompt(self, query, documents,
                                timeout=DEFAULT_TIMEOUT):
        """Answer *query* using *documents* via the ``document-prompt``
        template, returning the template's raw result.
        """
        return self.request(
            id="document-prompt",
            variables={
                "query": query,
                "documents": documents,
            },
            timeout=timeout,
        )