feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG (#1005)

Replace the three-prompt LLM scoring pipeline (kg-edge-scoring,
kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker
service backed by FlashRank. The new hop_and_filter() method performs
iterative graph traversal with semantic scoring at each hop, replacing
the previous follow_edges/get_subgraph approach.

- Add reranker service (trustgraph-base client/service, FlashRank processor)
- Add gateway dispatch for reranker via API and WebSocket
- Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring
- Remove kg_prompt() and edge_score_limit from prompt client
- Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates
- Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata
- Add tg-invoke-reranker CLI tool
- Add tech spec and UX developer guidance
- Update all unit and integration tests
This commit is contained in:
cybermaggedon 2026-06-30 14:36:37 +01:00 committed by GitHub
parent 1aa9549912
commit 01cc8dbc64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 1613 additions and 792 deletions

View file

@ -0,0 +1,523 @@
# GraphRAG Semantic Filter Improvement
## Problem Statement
The GraphRAG semantic filter is observed to be ineffective with certain
LLM models. Smaller models in particular produce poor-quality edge
relevance scores, and there is a suspicion that models trained or
evaluated heavily on non-Roman-script datasets offer lower performance
on the semantic ranking operation.
The root cause is that the current implementation delegates edge
relevance scoring to the LLM via a prompt that asks the model to
assign a 110 relevance score to each knowledge-graph edge. This
task — ranking structured triples for relevance to a natural-language
query — is not well covered in standard LLM evaluation suites, so
model benchmark scores are not predictive of performance on this
operation. The result is that GraphRAG quality varies unpredictably
across model choices, undermining confidence in the pipeline.
Beyond model variability, the LLM scoring step has further problems:
- **Cost and latency.** The LLM call consumes tokens and adds
latency to every query, yet its output is unreliable. Even when
the model performs well, the cost is disproportionate for what is
fundamentally a ranking operation.
- **Subjective scoring scale.** The 110 relevance scale gives the
model no objective criteria for what constitutes a 5 versus a 7.
Different models interpret the scale differently, and even the same
model can produce inconsistent scores across runs.
- **Redundancy with the embedding pre-filter.** The pipeline already
contains a cosine-similarity stage that ranks edges by semantic
relevance using embeddings. The LLM scoring step is a second
filter applied on top of this, and it is not clear that it adds
enough value to justify the additional cost and risk of
degradation.
### Industry context
Semantic ranking is rigorously evaluated on dedicated benchmarks such
as MTEB (Massive Text Embedding Benchmark) and BEIR (Benchmarking
Information Retrieval), which test retrieval and reranking across
diverse domains. The current TrustGraph approach — prompting a
general-purpose LLM to score and rank documents (the "listwise"
approach) — is known to be poorly optimized for this task. It
suffers from positional bias, formatting failures, and
inconsistency at scale.
The industry standard for semantic ranking has moved to
cross-encoder models: lightweight, purpose-built models that take a
querydocument pair as input and produce a single relevance score.
These models are fine-tuned on millions of relevance-labelled pairs
and dominate retrieval benchmarks. They are fast, deterministic,
and do not require an LLM inference call.
## Architecture
### Cross-encoder service
A new request/response service that exposes a generic semantic
ranking API. The service is not specific to GraphRAG — it is a
reusable building block for any component that needs to rank text
by relevance.
The service interface is pluggable. Alternative implementations
can be swapped in behind the same API.
**Packaging options considered:**
- *`sentence-transformers`.* Full-featured, widely used.
However, it pulls in PyTorch (~2 GB), making containers
very large. Tested at ~1.8 seconds for 2200 edges.
- *`optimum.onnxruntime`.* ONNX-based inference. Still
depends on PyTorch at import time despite using ONNX for
inference. Tested at ~4.2 seconds for 2200 edges.
- *`flashrank`.* Lightweight wrapper around ONNX Runtime
with a clean API (`Ranker`, `RerankRequest`). No PyTorch
dependency. Tested at ~4.4 seconds for 2200 edges.
- *Pure `onnxruntime` + `tokenizers`.* Leanest option
(~200 MB total). Requires manual tokenisation, padding,
and numpy array management — more boilerplate to maintain.
- *External API (e.g. Cohere Rerank).* No local model at
all. Adds network latency and an external dependency.
**Decision:** `flashrank` for the initial implementation.
No PyTorch dependency, clean API, comparable performance.
The pluggable interface allows swapping to another backend
later.
**Request:**
- `queries` — list of `{id, text}` objects. In the GraphRAG use
case these are the concepts extracted from the user's question.
- `documents` — list of `{id, text}` objects. In the GraphRAG
use case these are the candidate knowledge-graph edges
represented as text.
- `limit` — integer. Maximum number of results to return.
**Scoring:**
The service produces the cartesian product of all querydocument
pairs and scores each pair through the cross-encoder model. For
each document, the maximum score across all queries is taken as the
document's relevance score. Documents are then ranked by this
score and the top `limit` results are returned.
**Response:**
A list of the top `limit` results, each containing:
- `document_id` — the ID of the matched document.
- `query_id` — the ID of the query (concept) that produced the
highest score for this document.
- `score` — the relevance score.
Including `query_id` in the response supports the explainability
interface: it records that an edge was selected because it is
related to a specific concept.
### Integration
The cross-encoder service follows the standard TrustGraph service
integration pattern:
- **Base package (trustgraph-base).** Schema definitions for the
cross-encoder request/response messages. A client class that
other components (e.g. GraphRAG) can use to call the
cross-encoder service. Message translator registration so the
pub/sub layer can serialise/deserialise the messages.
- **Flow package (trustgraph-flow).** The cross-encoder service
implementation itself — loads the model, listens for requests,
scores pairs, returns results. Flow definition support so the
cross-encoder can be introduced into a processing flow via the
standard flow configuration. `flashrank` is added as a
dependency of `trustgraph-flow`. The service runs in its own
container.
- **API gateway.** A gateway endpoint that routes cross-encoder
requests from the HTTP API to the service over pub/sub and
returns the response.
- **CLI tool.** A command-line utility
(e.g. `tg-invoke-cross-encoder`) that calls the gateway
endpoint for manual testing and debugging.
### Current GraphRAG pipeline
The current pipeline follows these steps:
1. **Concept extraction.** An LLM prompt extracts key concepts
from the user's query.
2. **Graph exploration.** Seed entities are found via embedding
similarity. A subgraph is built by multi-hop traversal from
the seed entities (up to `max_path_length` hops, capped at
`max_subgraph_size` edges).
3. **Embedding pre-filter.** Each edge is embedded as
`"subject, predicate, object"` and scored by cosine similarity
against the concept embeddings. The top `edge_score_limit`
(default 30) edges are kept.
4. **LLM edge scoring.** The `kg-edge-scoring` prompt asks the
LLM to assign a 110 relevance score to each remaining edge.
The top `edge_limit` (default 25) edges are kept.
5. **LLM edge reasoning.** The `kg-edge-reasoning` prompt asks
the LLM to explain why each selected edge is relevant to the
query. Used for the explainability interface.
6. **Document tracing.** Selected edges are traced back to their
source documents in the librarian. Runs concurrently with
step 5.
7. **Synthesis.** The `kg-synthesis` prompt generates the final
answer from the selected edges and source document metadata.
### Potential improvements
#### Replace LLM edge scoring with cross-encoder (step 4)
The LLM edge scoring step is replaced by a call to the
cross-encoder service. The candidate edges are the documents and
`edge_limit` is the limit. This is a direct substitution: faster,
cheaper, deterministic, and more reliable across model choices.
The LLM `kg-edge-scoring` prompt is retired.
**Cross-encoder query input: concepts vs. raw query.** There are
two options for what to use as the cross-encoder queries:
- *Option A: Raw user query.* Pass the original question as a
single query string. Simpler, no dependency on concept
extraction. However, raw queries contain noise words and
conversational phrasing that do not match well against the
structured vocabulary of knowledge-graph edges. A single query
also means every edge competes against the full question — a
partial match on one aspect is diluted.
- *Option B: Extracted concepts.* Pass the concepts from step 1
as separate queries. The concepts are distilled, focused terms
that are closer to the language of the edges. With multiple
concepts as independent queries, the cross-encoder scores each
edge against each concept separately, giving better coverage —
an edge only needs to match one concept well to be selected.
The trade-off is a dependency on the LLM concept extraction
step, but this is already in the pipeline and is a lightweight,
reliable LLM call.
**Decision:** Option B — use extracted concepts. The concept
extraction is fast, and the resulting terms produce better
cross-encoder matches against structured triples.
#### Edge text representation
The current embedding pre-filter represents each edge as
`"subject, predicate, object"`. Two changes:
- **Drop commas.** Commas add tokenisation noise without semantic
value.
- **Drop the subject.** The subject identifies which entity the
edge belongs to, but it does not contribute to whether the
edge's content is relevant to the query. The predicate and
object carry the semantic meaning — what relationship exists
and what it connects to. Representing edges as `"{p} {o}"`
produces cleaner cross-encoder matches.
#### Remove the embedding pre-filter (step 3)
The embedding pre-filter was introduced to reduce the number of
edges before the expensive LLM scoring call. With the
cross-encoder replacing the LLM call, this cost equation changes.
**Arguments for removal:**
- The cross-encoder is fast enough to score the full subgraph
directly. In testing, 2200 edges scored in ~1.8 seconds; at
the default `max_subgraph_size` of 150 edges, scoring takes
a fraction of a second.
- The pre-filter is a weaker version of what the cross-encoder
does. Bi-encoder cosine similarity embeds the query and
document independently and compares vectors; the cross-encoder
processes both texts together through the full transformer,
giving it much better relevance judgement. Running a weaker
filter before a stronger one adds latency without improving
quality.
- Removing it eliminates an embedding service call (two batches:
concepts + edges) and the associated latency.
**Arguments for keeping it:**
- If the subgraph is very large (thousands of edges), the
cross-encoder's linear scaling could become a bottleneck.
The pre-filter would act as a safety valve.
- The embedding call is cheap compared to an LLM call, so the
overhead is modest.
**Decision:** Remove the pre-filter. The `max_subgraph_size`
parameter (default 150) already caps the number of edges entering
this stage, so the cross-encoder will not face an unbounded
workload. If very large subgraphs become a concern in future,
the pre-filter can be reintroduced or `max_subgraph_size` can be
tuned.
#### Iterative graph traversal with cross-encoder filtering
The current pipeline performs graph exploration and edge filtering
as separate phases: first build the full subgraph (up to
`max_path_length` hops), then score and filter edges. An
alternative is to interleave traversal and filtering — at each
hop, use the cross-encoder to select relevant edges before
expanding further.
**Option A: Big-bang traversal then filter.** Traverse the full
subgraph up to `max_path_length` hops from the seed entities,
collecting all edges up to `max_subgraph_size`. Then
cross-encode the entire result to select the top edges.
- Simple to implement — the current traversal logic is largely
unchanged.
- Produces large, unfocused subgraphs. Irrelevant branches are
explored and scored even though they will be discarded.
- Poorly suited to multi-hop reasoning. For a query about
Voyager 1, the subgraph includes Voyager 2's edges because
they are within hop distance, and the filter must then
separate them.
**Option B: Iterative hop-and-filter.** At each hop:
1. Retrieve all edges one hop from the current frontier nodes.
2. Cross-encode these edges against the query concepts.
3. Select the top relevant edges.
4. The target nodes of the selected edges become the frontier
for the next hop.
5. Repeat up to `max_path_length` hops.
The final set of selected edges across all hops is the input to
synthesis.
- **Guided exploration.** Each hop focuses the search by
pruning irrelevant branches before expanding further. The
working set stays small and relevant at every step.
- **Multi-hop reasoning works naturally.** Following
"Voyager 1 → has-event → crossed the heliopause" succeeds
because each hop is individually relevant and leads to the
next.
- **Smaller total workload.** Fewer edges are scored overall
because irrelevant branches are never expanded.
- **Trade-off: greedy pruning.** An edge discarded at hop 1
cannot lead to relevant edges at hop 2. This is inherent in
any bounded traversal, and the cross-encoder is better
equipped to make this relevance judgement than a blind hop
limit.
- **Trade-off: sequential latency.** Hops cannot be
parallelised since each depends on the previous. However,
each cross-encoder call on a small edge set is very fast
(sub-second for typical working sets).
**Decision:** Option B — iterative hop-and-filter. The guided
traversal produces more focused subgraphs and supports multi-hop
reasoning, which is a significant quality improvement over the
current approach.
#### Replace LLM edge reasoning with cross-encoder metadata (step 5)
The current `kg-edge-reasoning` prompt asks the LLM to explain why
each edge is relevant. With the cross-encoder now making the
selection, this explanation would be a post-hoc fabrication — the
LLM was not involved in the decision.
- *Option A: Keep LLM reasoning.* Generates natural-language
explanations but they are not grounded in the actual selection
process. Adds an LLM call per query.
- *Option B: Record cross-encoder metadata.* The cross-encoder
already returns the matched concept and score for each selected
edge. Use this directly as the explanation.
**Decision:** Option B. The cross-encoder metadata is the true
reason the edge was selected. The `kg-edge-reasoning` prompt is
retired.
#### Explainability interface update
The explainability interface uses a `Focus` entity containing
`EdgeSelection` sub-entities. Each `EdgeSelection` currently
carries an `edge` (the quoted triple) and a `reasoning` field
(free-text LLM prose), stored as `tg:reasoning` in the
provenance graph.
With the cross-encoder replacing LLM reasoning, the
`EdgeSelection` type gains two new predicates and drops one:
- **Remove** `tg:reasoning` — no longer produced.
- **Add** `tg:concept` — the concept text that produced the
highest cross-encoder score for this edge.
- **Add** `tg:score` — the cross-encoder relevance score.
This is an evolution of the existing `EdgeSelection` type, not a
new entity type. The edge selection sub-entities currently have
no `rdf:type` declared; a new `tg:EdgeSelection` type should be
added so that consumers can identify them in the provenance
graph. The `Focus` entity and its relationship to `Exploration`
are unchanged.
The `Focus` entity's token-usage metadata (`tg:inToken`,
`tg:outToken`, `tg:llmModel`) no longer applies since there is
no LLM call. These fields are dropped from the Focus entity.
### Proposed pipeline
1. **Concept extraction.** Unchanged — LLM extracts key concepts
from the user's query.
2. **Seed entity lookup.** Find seed entities via embedding
similarity against the extracted concepts.
3. **Iterative hop-and-filter.** For each hop up to
`max_path_length`:
a. Retrieve all edges one hop from the current frontier nodes.
b. Represent each edge as `"{predicate} {object}"`.
c. Score edges against the extracted concepts using the
cross-encoder service.
d. Select the top relevant edges. The target nodes of the
selected edges become the frontier for the next hop.
4. **Document tracing.** Selected edges are traced back to source
documents.
5. **Synthesis.** The `kg-synthesis` prompt generates the final
answer from the selected edges and source document metadata.
### Implementation order
1. Cross-encoder service with full integration (base schema,
flow service, gateway endpoint, CLI tool).
2. GraphRAG pipeline changes (iterative hop-and-filter,
edge representation, remove pre-filter).
3. Explainability update (`tg:EdgeSelection` type, concept
and score predicates, retire `tg:reasoning`).
4. Retire `kg-edge-scoring` and `kg-edge-reasoning` prompts.
5. Update `tg-invoke-graph-rag` and `tg-show-explain-trace`
to display the new metadata. Use these as the main
end-to-end test.
6. Fix any failing unit tests, then add new tests as needed.
7. Write guidance for UX devs to update the UI for the new
explainability predicates.
## UX developer guidance
This section describes the changes to the explainability interface
that affect frontend rendering of GraphRAG Focus events.
### What changed
Edge selection in GraphRAG previously used LLM-based scoring and
reasoning. Each selected edge carried a `tg:reasoning` predicate
with free-text explanation from the LLM. This has been replaced
by a cross-encoder reranker that scores edges against query
concepts. The explainability data now carries structured metadata
instead of free text.
### Removed
- **`tg:reasoning`** is no longer emitted on edge selection
entities in GraphRAG Focus events. UX code that reads
`edge_sel.reasoning` will get an empty string. Remove any
rendering that displays a "Reasoning" or "Reason" field for
Focus edges.
- The **`kg-edge-scoring`**, **`kg-edge-reasoning`**, and
**`kg-edge-selection`** prompts are retired. Any UX that
references these prompt names should be cleaned up.
### Added
Each edge selection entity within a Focus event now has three
new properties:
| RDF predicate | API field | Type | Description |
|---|---|---|---|
| `rdf:type tg:EdgeSelection` | (type check) | — | Each edge selection entity is now explicitly typed |
| `tg:concept` | `edge_sel.concept` | `str` | The query concept that matched this edge |
| `tg:score` | `edge_sel.score` | `float` or `None` | Cross-encoder relevance score (0.01.0) |
The `tg:edge` predicate (RDF-star quoted triple) is unchanged.
### How to render
The recommended rendering for each selected edge in a Focus event:
```
Edge: (subject_label, predicate_label, object_label)
Concept: <concept> Score: <score formatted to 4 decimal places>
```
Scores near 1.0 indicate high relevance; scores near 0.0 indicate
low relevance. UX could use the score to drive visual indicators
such as colour intensity or a relevance bar.
Edges are not returned in score order — they arrive in traversal
order across hops. If the UX wants to display edges ranked by
relevance, sort by `edge_sel.score` descending.
### API classes (Python)
The `EdgeSelection` dataclass in `trustgraph.api.explainability`
has these fields:
```python
@dataclass
class EdgeSelection:
uri: str
edge: Optional[Dict[str, str]] # {"s": ..., "p": ..., "o": ...}
reasoning: str = "" # Legacy, always empty for new traces
concept: str = "" # Query concept that matched
score: Optional[float] = None # Cross-encoder relevance score
```
These are populated when calling
`ExplainabilityClient.fetch_focus_with_edges()` or when parsing
inline provenance triples from the streaming response.
### WebSocket response format
For inline explainability via the streaming WebSocket, Focus events
arrive as `message_type: "explain"` responses. The `explain_triples`
array contains the edge selection triples. The relevant predicates
in wire format are:
```json
{"s": {"t": "i", "i": "<edge_sel_uri>"},
"p": {"t": "i", "i": "https://trustgraph.ai/ns/concept"},
"o": {"t": "l", "v": "flyby event"}}
{"s": {"t": "i", "i": "<edge_sel_uri>"},
"p": {"t": "i", "i": "https://trustgraph.ai/ns/score"},
"o": {"t": "l", "v": "0.9962"}}
```
Note that `tg:score` is transmitted as a string literal and must
be parsed to a float on the client side.
### Exploration event
The Exploration event's `edge_count` field now reports the number
of edges selected by the cross-encoder across all hops (previously
it reported the total number of edges retrieved before filtering).
The `entities` list continues to report the seed entities found
by vector search.

View file

@ -95,10 +95,6 @@ class TestGraphRagIntegration:
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-reasoning":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
return PromptResult(
response_type="text",
@ -113,14 +109,22 @@ class TestGraphRagIntegration:
client.prompt.side_effect = mock_prompt
return client
@pytest.fixture
def mock_reranker_client(self):
"""Mock reranker client for cross-encoder edge filtering"""
client = AsyncMock()
client.rerank.return_value = []
return client
@pytest.fixture
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_prompt_client):
mock_triples_client, mock_reranker_client, mock_prompt_client):
"""Create GraphRag instance with mocked dependencies"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
prompt_client=mock_prompt_client,
verbose=True
)
@ -167,8 +171,8 @@ class TestGraphRagIntegration:
# 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query_stream.call_count > 0
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
assert mock_prompt_client.prompt.call_count == 4
# 4. Should call prompt twice (extract-concepts + synthesis)
assert mock_prompt_client.prompt.call_count == 2
# Verify final response
response, usage = response

View file

@ -63,11 +63,6 @@ class TestGraphRagStreaming:
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
elif prompt_id == "kg-edge-reasoning":
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
@ -88,14 +83,23 @@ class TestGraphRagStreaming:
client.prompt.side_effect = prompt_side_effect
return client
@pytest.fixture
def mock_reranker_client(self):
"""Mock reranker client for cross-encoder edge filtering"""
client = AsyncMock()
client.rerank.return_value = []
return client
@pytest.fixture
def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_streaming_prompt_client):
mock_triples_client, mock_reranker_client,
mock_streaming_prompt_client):
"""Create GraphRag instance with streaming support"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
prompt_client=mock_streaming_prompt_client,
verbose=True
)

View file

@ -46,7 +46,7 @@ class TestGraphRagStreamingProtocol:
client = AsyncMock()
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
@ -63,14 +63,23 @@ class TestGraphRagStreamingProtocol:
client.prompt.side_effect = prompt_side_effect
return client
@pytest.fixture
def mock_reranker_client(self):
"""Mock reranker client for cross-encoder edge filtering"""
client = AsyncMock()
client.rerank.return_value = []
return client
@pytest.fixture
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_streaming_prompt_client):
mock_triples_client, mock_reranker_client,
mock_streaming_prompt_client):
"""Create GraphRag instance with mocked dependencies"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
prompt_client=mock_streaming_prompt_client,
verbose=False
)
@ -327,7 +336,7 @@ class TestStreamingProtocolEdgeCases:
client = AsyncMock()
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
@ -342,10 +351,14 @@ class TestStreamingProtocolEdgeCases:
client.prompt.side_effect = prompt_with_empties
mock_reranker = AsyncMock()
mock_reranker.rerank.return_value = []
rag = GraphRag(
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
reranker_client=mock_reranker,
prompt_client=client,
verbose=False
)

View file

@ -195,38 +195,6 @@ class TestPromptClientStreamingCallback:
assert callback.call_args_list[0] == call("test", False)
assert callback.call_args_list[1] == call("", True)
@pytest.mark.asyncio
async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
"""Test that kg_prompt correctly passes streaming parameters"""
# Arrange
async def mock_request(request, recipient=None, timeout=600):
if recipient:
responses = [
PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
PromptResponse(text="", object=None, error=None, end_of_stream=True),
]
for resp in responses:
should_stop = await recipient(resp)
if should_stop:
break
prompt_client.request = mock_request
callback = AsyncMock()
# Act
await prompt_client.kg_prompt(
query="What is machine learning?",
kg=[("subject", "predicate", "object")],
streaming=True,
chunk_callback=callback
)
# Assert
assert callback.call_count == 2
assert callback.call_args_list[0] == call("Answer", False)
assert callback.call_args_list[1] == call("", True)
@pytest.mark.asyncio
async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
"""Test that document_prompt correctly passes streaming parameters"""

View file

@ -107,6 +107,7 @@ class TestGraphRagDagStructure:
embeddings_client = AsyncMock()
graph_embeddings_client = AsyncMock()
triples_client = AsyncMock()
reranker_client = AsyncMock()
embeddings_client.embed.return_value = [[0.1, 0.2]]
graph_embeddings_client.query.return_value = [
@ -121,27 +122,22 @@ class TestGraphRagDagStructure:
]
triples_client.query.return_value = []
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.95
reranker_client.rerank.return_value = [result]
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return PromptResult(response_type="text", text="concept")
elif template_id == "kg-edge-scoring":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "score": 10} for e in edges],
)
elif template_id == "kg-edge-reasoning":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
)
elif template_id == "kg-synthesis":
return PromptResult(response_type="text", text="Answer.")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
return (prompt_client, embeddings_client, graph_embeddings_client,
triples_client, reranker_client)
@pytest.mark.asyncio
async def test_dag_chain(self, mock_clients):
@ -152,7 +148,7 @@ class TestGraphRagDagStructure:
events.append({"explain_id": explain_id, "triples": triples})
await rag.query(
query="test", explain_callback=explain_cb, edge_score_limit=0,
query="test", explain_callback=explain_cb,
)
dag = _collect_events(events)

View file

@ -15,54 +15,52 @@ class TestGraphRag:
def test_graph_rag_initialization_with_defaults(self):
"""Test GraphRag initialization with default verbose setting"""
# Create mock clients
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
mock_reranker_client = MagicMock()
# Initialize GraphRag
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.verbose is False # Default value
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
# Create mock clients
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
# Initialize GraphRag with verbose=True
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
verbose=True
reranker_client=mock_reranker_client,
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.reranker_client == mock_reranker_client
assert graph_rag.verbose is False
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
mock_reranker_client = MagicMock()
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
verbose=True,
)
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.reranker_client == mock_reranker_client
assert graph_rag.verbose is True
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
@ -365,244 +363,162 @@ class TestQuery:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_never_passes_workspace(self):
"""Verify follow_edges never passes workspace to query_stream."""
async def test_hop_and_filter_never_passes_workspace(self):
"""Verify hop_and_filter never passes workspace to query_stream."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
mock_triple = MagicMock()
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
mock_triple.s = "e1"
mock_triple.p = "p1"
mock_triple.o = "o1"
mock_triples_client.query_stream.return_value = [mock_triple]
mock_triples_client.query.return_value = []
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.9
mock_reranker_client.rerank.return_value = [result]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
triple_limit=10,
)
subgraph = set()
await query.follow_edges("e1", subgraph, path_length=1)
await query.hop_and_filter(["e1"], ["concept"])
for c in mock_triples_client.query_stream.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
"""Test Query.follow_edges method basic triple discovery"""
async def test_hop_and_filter_basic_functionality(self):
"""Test hop_and_filter retrieves edges and scores them with reranker."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
mock_triple1 = MagicMock()
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
mock_triple = MagicMock()
mock_triple.s = "entity1"
mock_triple.p = "predicate1"
mock_triple.o = "object1"
mock_triples_client.query_stream.return_value = [mock_triple]
mock_triples_client.query.return_value = []
mock_triple2 = MagicMock()
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
mock_triple3 = MagicMock()
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
mock_triples_client.query_stream.side_effect = [
[mock_triple1], # s=ent
[mock_triple2], # p=ent
[mock_triple3], # o=ent
]
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.95
mock_reranker_client.rerank.return_value = [result]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
triple_limit=10,
edge_limit=25,
)
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=1)
assert mock_triples_client.query_stream.call_count == 3
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
collection="test_collection", batch_size=20, g=""
selected, uri_map, edge_meta = await query.hop_and_filter(
["entity1"], ["test concept"],
)
expected_subgraph = {
("entity1", "predicate1", "object1"),
("subject2", "entity1", "object2"),
("subject3", "predicate3", "entity1")
}
assert subgraph == expected_subgraph
assert len(selected) == 1
assert len(uri_map) == 1
assert len(edge_meta) == 1
mock_reranker_client.rerank.assert_called_once()
call_kwargs = mock_reranker_client.rerank.call_args
assert call_kwargs.kwargs["limit"] == 25
@pytest.mark.asyncio
async def test_follow_edges_with_path_length_zero(self):
"""Test Query.follow_edges method with path_length=0"""
async def test_hop_and_filter_with_empty_frontier(self):
"""Test hop_and_filter with no seed entities returns empty."""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
)
selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"])
assert selected == []
assert uri_map == {}
assert edge_meta == {}
@pytest.mark.asyncio
async def test_hop_and_filter_filters_label_triples(self):
"""Test hop_and_filter skips rdfs:label edges."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False
)
label_triple = MagicMock()
label_triple.s = "entity1"
label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label"
label_triple.o = "Entity One"
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=0)
mock_triples_client.query_stream.assert_not_called()
assert subgraph == set()
@pytest.mark.asyncio
async def test_follow_edges_with_max_subgraph_size_limit(self):
"""Test Query.follow_edges method respects max_subgraph_size"""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_triples_client.query_stream.return_value = [label_triple]
mock_triples_client.query.return_value = []
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_subgraph_size=2
triple_limit=10,
)
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
await query.follow_edges("entity1", subgraph, path_length=1)
mock_triples_client.query_stream.assert_not_called()
assert len(subgraph) == 3
@pytest.mark.asyncio
async def test_get_subgraph_method(self):
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_path_length=1
selected, uri_map, edge_meta = await query.hop_and_filter(
["entity1"], ["concept"],
)
# Mock get_entities to return (entities, concepts) tuple
query.get_entities = AsyncMock(
return_value=(["entity1", "entity2"], ["concept1"])
)
query.follow_edges_batch = AsyncMock(return_value=(
{
("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2")
},
{}
))
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
query.get_entities.assert_called_once_with("test query")
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
assert isinstance(subgraph, list)
assert len(subgraph) == 2
assert ("entity1", "predicate1", "object1") in subgraph
assert ("entity2", "predicate2", "object2") in subgraph
assert entities == ["entity1", "entity2"]
assert concepts == ["concept1"]
@pytest.mark.asyncio
async def test_get_labelgraph_method(self):
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_subgraph_size=100
)
test_subgraph = [
("entity1", "predicate1", "object1"),
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
("entity3", "predicate3", "object3")
]
test_entities = ["entity1", "entity3"]
test_concepts = ["concept1"]
query.get_subgraph = AsyncMock(
return_value=(test_subgraph, {}, test_entities, test_concepts)
)
async def mock_maybe_label(entity):
label_map = {
"entity1": "Human Entity One",
"predicate1": "Human Predicate One",
"object1": "Human Object One",
"entity3": "Human Entity Three",
"predicate3": "Human Predicate Three",
"object3": "Human Object Three"
}
return label_map.get(entity, entity)
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
query.get_subgraph.assert_called_once_with("test query")
# Label triples filtered out
assert len(labeled_edges) == 2
# maybe_label called for non-label triples
assert query.maybe_label.call_count == 6
expected_edges = [
("Human Entity One", "Human Predicate One", "Human Object One"),
("Human Entity Three", "Human Predicate Three", "Human Object Three")
]
assert labeled_edges == expected_edges
assert len(uri_map) == 2
assert entities == test_entities
assert concepts == test_concepts
assert selected == []
mock_reranker_client.rerank.assert_not_called()
@pytest.mark.asyncio
async def test_graph_rag_query_method(self):
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
import json
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_graph_embeddings_client = AsyncMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
expected_response = "This is the RAG response"
test_labelgraph = [("Subject", "Predicate", "Object")]
test_edge_id = edge_id("Subject", "Predicate", "Object")
test_selected_edges = [("Subject", "Predicate", "Object")]
test_eid = edge_id("Subject", "Predicate", "Object")
test_uri_map = {
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
}
test_edge_metadata = {
test_eid: {"concept": "test concept", "score": 0.95}
}
test_entities = ["http://example.org/subject"]
test_concepts = ["test concept"]
# Mock prompt responses for the multi-step process
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
mock_graph_embeddings_client.query.return_value = []
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
elif prompt_name == "kg-edge-reasoning":
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
return PromptResult(response_type="text", text="test concept")
elif prompt_name == "kg-synthesis":
return PromptResult(response_type="text", text=expected_response)
return PromptResult(response_type="text", text="")
@ -614,16 +530,16 @@ class TestQuery:
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
verbose=False
reranker_client=mock_reranker_client,
verbose=False,
)
# Patch Query.get_labelgraph to return test data
original_get_labelgraph = Query.get_labelgraph
original_hop_and_filter = Query.hop_and_filter
async def mock_get_labelgraph(self, query_text):
return test_labelgraph, test_uri_map, test_entities, test_concepts
async def mock_hop_and_filter(self, seed_entities, concepts):
return test_selected_edges, test_uri_map, test_edge_metadata
Query.get_labelgraph = mock_get_labelgraph
Query.hop_and_filter = mock_hop_and_filter
provenance_events = []
@ -636,7 +552,7 @@ class TestQuery:
collection="test_collection",
entity_limit=25,
triple_limit=15,
explain_callback=collect_provenance
explain_callback=collect_provenance,
)
response_text, usage = response
@ -650,7 +566,6 @@ class TestQuery:
assert len(triples) > 0
assert prov_id.startswith("urn:trustgraph:")
# Verify order
assert "question" in provenance_events[0][1]
assert "grounding" in provenance_events[1][1]
assert "exploration" in provenance_events[2][1]
@ -658,4 +573,4 @@ class TestQuery:
assert "synthesis" in provenance_events[4][1]
finally:
Query.get_labelgraph = original_get_labelgraph
Query.hop_and_filter = original_hop_and_filter

View file

@ -20,7 +20,7 @@ from trustgraph.provenance.namespaces import (
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_SELECTED_EDGE, TG_EDGE, TG_SCORE, TG_EDGE_SELECTION,
)
@ -91,17 +91,17 @@ def build_mock_clients():
1. prompt_client.prompt("extract-concepts", ...) -> concepts
2. embeddings_client.embed(concepts) -> vectors
3. graph_embeddings_client.query(vector, ...) -> entity matches
4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
4. triples_client.query_stream(s/p/o, ...) -> edges (hop_and_filter)
5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
9. prompt_client.prompt("kg-synthesis", ...) -> final answer
6. reranker_client.rerank(queries, documents, limit) -> scored edges
7. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
8. prompt_client.prompt("kg-synthesis", ...) -> final answer
"""
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
graph_embeddings_client = AsyncMock()
triples_client = AsyncMock()
reranker_client = AsyncMock()
# 1. Concept extraction
prompt_responses = {}
@ -116,7 +116,7 @@ def build_mock_clients():
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
]
# 4. Triple queries (follow_edges_batch) - return our edges
# 4. Triple queries (hop_and_filter) - return our edges
kg_triples = [
make_schema_triple(*EDGE_1),
make_schema_triple(*EDGE_2),
@ -130,9 +130,18 @@ def build_mock_clients():
return [] # No labels found, will fall back to URI
triples_client.query.side_effect = mock_label_query
# 6+7. Edge scoring and reasoning: dynamically score/reason about
# whatever edges the query method sends us, since edge IDs are computed
# from str(Term) representations which include the full dataclass repr.
# 6. Reranker: select all documents with high scores
async def mock_rerank(queries, documents, limit):
results = []
for i, doc in enumerate(documents):
result = MagicMock()
result.document_id = doc["id"]
result.query_id = queries[0]["id"] if queries else "0"
result.score = 0.9 - (i * 0.1)
results.append(result)
return results[:limit]
reranker_client.rerank.side_effect = mock_rerank
synthesis_answer = "Quantum computing applies physics principles to computation."
async def mock_prompt(template_id, variables=None, **kwargs):
@ -141,26 +150,6 @@ def build_mock_clients():
response_type="text",
text=prompt_responses["extract-concepts"],
)
elif template_id == "kg-edge-scoring":
# Score all edges highly, using the IDs that GraphRag computed
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-edge-reasoning":
# Provide reasoning for each edge
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-synthesis":
return PromptResult(
response_type="text",
@ -170,7 +159,8 @@ def build_mock_clients():
prompt_client.prompt.side_effect = mock_prompt
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
return (prompt_client, embeddings_client, graph_embeddings_client,
triples_client, reranker_client)
# ---------------------------------------------------------------------------
@ -197,7 +187,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0, # skip semantic pre-filter for simplicity
)
assert len(events) == 5, (
@ -222,7 +212,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
expected_types = [
@ -260,7 +250,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
uris = [e["explain_id"] for e in events]
@ -297,7 +287,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
q_uri = events[0]["explain_id"]
@ -320,7 +310,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
gnd_uri = events[1]["explain_id"]
@ -344,7 +334,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
exp_uri = events[2]["explain_id"]
@ -355,10 +345,10 @@ class TestGraphRagQueryProvenance:
assert int(t.o.value) > 0
@pytest.mark.asyncio
async def test_focus_has_selected_edges_with_reasoning(self):
async def test_focus_has_selected_edges_with_concept_and_score(self):
"""
The focus event should carry selected edges as quoted triples
with reasoning text.
with cross-encoder concept and score metadata.
"""
clients = build_mock_clients()
rag = GraphRag(*clients)
@ -371,7 +361,6 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
foc_uri = events[3]["explain_id"]
@ -387,11 +376,19 @@ class TestGraphRagQueryProvenance:
for t in edge_t:
assert t.o.triple is not None, "tg:edge object must be a quoted triple"
# Should have reasoning
reasoning = find_triples(foc_triples, TG_REASONING)
assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
reasoning_texts = {t.o.value for t in reasoning}
assert any(r for r in reasoning_texts), "Reasoning should not be empty"
# Edge selections should be typed as EdgeSelection
edge_sel_uris = [t.o.iri for t in selected]
for uri in edge_sel_uris:
assert has_type(foc_triples, uri, TG_EDGE_SELECTION)
# Should have concept and score
concepts = find_triples(foc_triples, TG_CONCEPT)
assert len(concepts) > 0, "Focus should have tg:concept for selected edges"
scores = find_triples(foc_triples, TG_SCORE)
assert len(scores) > 0, "Focus should have tg:score for selected edges"
for t in scores:
float(t.o.value) # Should be parseable as float
@pytest.mark.asyncio
async def test_synthesis_is_answer_type(self):
@ -407,7 +404,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
syn_uri = events[4]["explain_id"]
@ -429,7 +426,7 @@ class TestGraphRagQueryProvenance:
result_text, usage = await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
assert result_text == "Quantum computing applies physics principles to computation."
@ -449,7 +446,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
parent_uri=parent,
)
@ -465,7 +462,7 @@ class TestGraphRagQueryProvenance:
result_text, usage = await rag.query(
query="What is quantum computing?",
edge_score_limit=0,
)
assert result_text == "Quantum computing applies physics principles to computation."
@ -484,7 +481,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
for event in events:

View file

@ -646,6 +646,16 @@ class AsyncFlowInstance:
return await self.request("embeddings", request_data)
async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any):
request_data = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request_data.update(kwargs)
return await self.request("reranker", request_data)
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any):
"""
Query RDF triples using pattern matching.

View file

@ -443,6 +443,19 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("embeddings", self.flow_id, request)
async def rerank(self, queries: list, documents: list, limit: int = 10,
**kwargs):
request = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request.update(kwargs)
return await self.client._send_request(
"reranker", self.flow_id, request,
)
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs):
"""Triple pattern query"""
request = {"limit": limit}

View file

@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_SCORE = TG + "score"
TG_DOCUMENT = TG + "document"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
@dataclass
class EdgeSelection:
"""A selected edge with reasoning from GraphRAG Focus step."""
"""A selected edge with cross-encoder metadata from GraphRAG Focus step."""
uri: str
edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...}
reasoning: str = ""
concept: str = ""
score: Optional[float] = None
@dataclass
@ -209,7 +212,7 @@ class Exploration(ExplainEntity):
@dataclass
class Focus(ExplainEntity):
"""Focus entity - selected edges with LLM reasoning (GraphRAG only)."""
"""Focus entity - selected edges with cross-encoder scoring (GraphRAG only)."""
selected_edge_uris: List[str] = field(default_factory=list)
edge_selections: List[EdgeSelection] = field(default_factory=list)
@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel
uri = triples[0][0] if triples else ""
edge = None
reasoning = ""
concept = ""
score = None
for s, p, o in triples:
if p == TG_EDGE and isinstance(o, dict):
edge = o
elif p == TG_REASONING:
reasoning = o
elif p == TG_CONCEPT:
concept = o
elif p == TG_SCORE:
try:
score = float(o)
except (ValueError, TypeError):
score = None
return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning)
return EdgeSelection(
uri=uri, edge=edge, reasoning=reasoning,
concept=concept, score=score,
)
def extract_term_value(term: Dict[str, Any]) -> Any:

View file

@ -491,6 +491,19 @@ class FlowInstance:
input
)["vectors"]
def rerank(self, queries, documents, limit=10):
input = {
"queries": queries,
"documents": documents,
"limit": limit,
}
return self.request(
"service/reranker",
input
)
def graph_embeddings_query(self, text, collection, limit=10):
"""
Query knowledge graph entities using semantic similarity.

View file

@ -885,6 +885,19 @@ class SocketFlowInstance:
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
def rerank(self, queries: list, documents: list, limit: int = 10,
**kwargs: Any) -> Dict[str, Any]:
request = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request.update(kwargs)
return self.client._send_request_sync(
"reranker", self.flow_id, request, False,
)
def triples_query(
self,
s: Optional[Union[str, Dict[str, Any]]] = None,

View file

@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService
from . tool_service_client import ToolServiceClientSpec
from . agent_client import AgentClientSpec
from . structured_query_client import StructuredQueryClientSpec
from . reranker_client import RerankerClientSpec
from . reranker_service import RerankerService
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
from . collection_config_handler import CollectionConfigHandler

View file

@ -157,21 +157,6 @@ class PromptClient(RequestResponse):
timeout = timeout,
)
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "kg-prompt",
variables = {
"query": query,
"knowledge": [
{ "s": v[0], "p": v[1], "o": v[2] }
for v in kg
]
},
timeout = timeout,
streaming = streaming,
chunk_callback = chunk_callback,
)
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "document-prompt",

View file

@ -0,0 +1,43 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import (
RerankerRequest, RerankerResponse,
RerankerQuery, RerankerDocument,
)
class RerankerClient(RequestResponse):
async def rerank(self, queries, documents, limit=10, timeout=300):
resp = await self.request(
RerankerRequest(
queries=[
RerankerQuery(query_id=q["id"], query_text=q["text"])
for q in queries
],
documents=[
RerankerDocument(
document_id=d["id"], document_text=d["text"]
)
for d in documents
],
limit=limit,
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.results
class RerankerClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(RerankerClientSpec, self).__init__(
request_name = request_name,
request_schema = RerankerRequest,
response_name = response_name,
response_schema = RerankerResponse,
impl = RerankerClient,
)

View file

@ -0,0 +1,109 @@
from __future__ import annotations
from argparse import ArgumentParser
import logging
from .. schema import (
RerankerRequest, RerankerResponse, RerankerResult, Error,
)
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec
logger = logging.getLogger(__name__)
default_ident = "reranker"
default_concurrency = 1
class RerankerService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
concurrency = params.get("concurrency", 1)
super(RerankerService, self).__init__(**params | {
"id": id,
"concurrency": concurrency,
})
self.register_specification(
ConsumerSpec(
name = "request",
schema = RerankerRequest,
handler = self.on_request,
concurrency = concurrency,
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = RerankerResponse
)
)
self.register_specification(
ParameterSpec(
name = "model",
)
)
async def on_request(self, msg, consumer, flow):
try:
request = msg.value()
id = msg.properties()["id"]
logger.debug(f"Handling reranker request {id}...")
model = flow("model")
results = await self.on_rerank(
request.queries, request.documents,
request.limit, model=model,
)
await flow("response").send(
RerankerResponse(
error = None,
results = results,
),
properties={"id": id}
)
logger.debug("Reranker request handled successfully")
except TooManyRequests as e:
raise e
except Exception as e:
logger.error(f"Exception in reranker service: {e}", exc_info=True)
logger.info("Sending error response...")
await flow.producer["response"].send(
RerankerResponse(
error=Error(
type = "reranker-error",
message = str(e),
),
results=[],
),
properties={"id": id}
)
@staticmethod
def add_args(parser: ArgumentParser) -> None:
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)

View file

@ -140,20 +140,6 @@ class PromptClient(BaseClient):
timeout=timeout
)
def request_kg_prompt(self, query, kg, timeout=300):
return self.request(
id="kg-prompt",
variables={
"query": query,
"knowledge": [
{ "s": v[0], "p": v[1], "o": v[2] }
for v in kg
]
},
timeout=timeout
)
def request_document_prompt(self, query, documents, timeout=300):
return self.request(

View file

@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator
from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator
from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator
@ -163,6 +164,12 @@ TranslatorRegistry.register_service(
SparqlQueryResponseTranslator()
)
TranslatorRegistry.register_service(
"reranker",
RerankerRequestTranslator(),
RerankerResponseTranslator()
)
# Register single-direction translators for document loading
TranslatorRegistry.register_request("document", DocumentTranslator())
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())

View file

@ -20,3 +20,4 @@ from .embeddings_query import (
)
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
from .reranker import RerankerRequestTranslator, RerankerResponseTranslator

View file

@ -0,0 +1,73 @@
from typing import Dict, Any, Tuple
from ...schema import (
RerankerRequest, RerankerResponse,
RerankerQuery, RerankerDocument, RerankerResult,
)
from .base import MessageTranslator
class RerankerRequestTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> RerankerRequest:
return RerankerRequest(
queries=[
RerankerQuery(
query_id=q["query_id"],
query_text=q["query_text"],
)
for q in data.get("queries", [])
],
documents=[
RerankerDocument(
document_id=d["document_id"],
document_text=d["document_text"],
)
for d in data.get("documents", [])
],
limit=data.get("limit", 10),
)
def encode(self, obj: RerankerRequest) -> Dict[str, Any]:
return {
"queries": [
{"query_id": q.query_id, "query_text": q.query_text}
for q in obj.queries
],
"documents": [
{"document_id": d.document_id, "document_text": d.document_text}
for d in obj.documents
],
"limit": obj.limit,
}
class RerankerResponseTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> RerankerResponse:
return RerankerResponse(
results=[
RerankerResult(
document_id=r["document_id"],
query_id=r["query_id"],
score=r["score"],
)
for r in data.get("results", [])
],
)
def encode(self, obj: RerankerResponse) -> Dict[str, Any]:
return {
"results": [
{
"document_id": r.document_id,
"query_id": r.query_id,
"score": r.score,
}
for r in obj.results
],
}
def encode_with_completion(
self, obj: RerankerResponse
) -> Tuple[Dict[str, Any], bool]:
return self.encode(obj), True

View file

@ -89,7 +89,9 @@ from . namespaces import (
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE,
# Edge selection entity type
TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Explainability entity types
@ -212,7 +214,9 @@ __all__ = [
"TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE",
# Query-time provenance predicates (GraphRAG)
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE",
# Edge selection entity type
"TG_EDGE_SELECTION",
# Query-time provenance predicates (DocumentRAG)
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
# Explainability entity types

View file

@ -66,8 +66,12 @@ TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_SCORE = TG + "score"
TG_DOCUMENT = TG + "document" # Reference to document in librarian
# Edge selection entity type (cross-encoder scored edge in Focus)
TG_EDGE_SELECTION = TG + "EdgeSelection"
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT = TG + "chunkCount"
TG_SELECTED_CHUNK = TG + "selectedChunk"

View file

@ -24,8 +24,10 @@ from . namespaces import (
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE,
TG_DOCUMENT,
# Edge selection entity type
TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Explainability entity types
@ -536,10 +538,9 @@ def focus_triples(
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
]
# Add each selected edge with its reasoning via intermediate entity
# Add each selected edge with metadata via intermediate entity
for idx, edge_info in enumerate(selected_edges_with_reasoning):
edge = edge_info.get("edge")
reasoning = edge_info.get("reasoning", "")
if edge:
s, p, o = edge
@ -552,13 +553,32 @@ def focus_triples(
_triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri))
)
# Type the edge selection entity
triples.append(
_triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION))
)
# Attach quoted triple to edge selection entity
quoted = _quoted_triple(s, p, o)
triples.append(
Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted)
)
# Attach reasoning to edge selection entity
# Structured cross-encoder metadata
concept = edge_info.get("concept")
if concept:
triples.append(
_triple(edge_sel_uri, TG_CONCEPT, _literal(concept))
)
score = edge_info.get("score")
if score is not None:
triples.append(
_triple(edge_sel_uri, TG_SCORE, _literal(str(score)))
)
# Legacy reasoning text (for non-cross-encoder callers)
reasoning = edge_info.get("reasoning", "")
if reasoning:
triples.append(
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning))

View file

@ -29,6 +29,7 @@ from . namespaces import (
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
TG_EDGE_SELECTION, TG_SCORE,
)
@ -93,6 +94,7 @@ TG_CLASS_LABELS = [
_label_triple(TG_FINDING, "Finding"),
_label_triple(TG_PLAN_TYPE, "Plan"),
_label_triple(TG_STEP_RESULT, "Step Result"),
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
]
# TrustGraph predicate labels
@ -117,6 +119,7 @@ TG_PREDICATE_LABELS = [
_label_triple(TG_ENTITY, "entity"),
_label_triple(TG_SUBAGENT_GOAL, "subagent goal"),
_label_triple(TG_PLAN_STEP, "plan step"),
_label_triple(TG_SCORE, "score"),
]

View file

@ -15,4 +15,5 @@ from .diagnosis import *
from .collection import *
from .storage import *
from .tool_service import *
from .sparql_query import *
from .sparql_query import *
from .reranker import *

View file

@ -6,17 +6,6 @@ from ..core.primitives import Error
# Prompt services, abstract the prompt generation
# extract-definitions:
# chunk -> definitions
# extract-relationships:
# chunk -> relationships
# kg-prompt:
# query, triples -> answer
# document-prompt:
# query, documents -> answer
# extract-rows
# schema, chunk -> rows
@dataclass
class PromptRequest:
id: str = ""
@ -46,4 +35,4 @@ class PromptResponse:
out_token: int | None = None
model: str | None = None
############################################################################
############################################################################

View file

@ -0,0 +1,35 @@
from dataclasses import dataclass, field
from ..core.primitives import Error
############################################################################
# Cross-encoder reranker
@dataclass
class RerankerQuery:
query_id: str = ""
query_text: str = ""
@dataclass
class RerankerDocument:
document_id: str = ""
document_text: str = ""
@dataclass
class RerankerRequest:
queries: list[RerankerQuery] = field(default_factory=list)
documents: list[RerankerDocument] = field(default_factory=list)
limit: int = 10
@dataclass
class RerankerResult:
document_id: str = ""
query_id: str = ""
score: float = 0.0
@dataclass
class RerankerResponse:
error: Error | None = None
results: list[RerankerResult] = field(default_factory=list)

View file

@ -71,6 +71,7 @@ tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main"
tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main"
tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
tg-invoke-reranker = "trustgraph.cli.invoke_reranker:main"
tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"
tg-load-kg-core = "trustgraph.cli.load_kg_core:main"

View file

@ -112,14 +112,13 @@ def _question_explainable_api(
if focus_full and focus_full.edge_selections:
for edge_sel in focus_full.edge_selections:
if edge_sel.edge:
# Resolve labels for edge components
s_label, p_label, o_label = explain_client.resolve_edge_labels(
edge_sel.edge, collection
)
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
if edge_sel.reasoning:
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reason: {r_short}", file=sys.stderr)
if edge_sel.concept or edge_sel.score is not None:
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
print(f" Concept: {edge_sel.concept} Score: {score_str}", file=sys.stderr)
elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr)

View file

@ -0,0 +1,127 @@
"""
Invokes the reranker service to score and rank documents by relevance
to one or more queries.
"""
import argparse
import json
import os
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def query(url, flow_id, queries, documents, limit, token=None,
workspace="default"):
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
try:
query_objects = [
{"query_id": str(i), "query_text": q}
for i, q in enumerate(queries)
]
document_objects = [
{"document_id": str(i), "document_text": d}
for i, d in enumerate(documents)
]
result = flow.rerank(
queries=query_objects,
documents=document_objects,
limit=limit,
)
if "error" in result and result["error"]:
err = result["error"]
print(f"Error: [{err.get('type', '')}] {err.get('message', '')}")
return
for r in result.get("results", []):
doc_idx = int(r["document_id"])
query_idx = int(r["query_id"])
print(
f" {r['score']:.4f} | "
f"query: {queries[query_idx]} | "
f"doc: {documents[doc_idx]}"
)
finally:
socket.close()
def main():
parser = argparse.ArgumentParser(
prog='tg-invoke-reranker',
description=__doc__,
)
parser.add_argument(
'-u', '--url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
default="default",
help=f'Flow ID (default: default)'
)
parser.add_argument(
'-l', '--limit',
type=int,
default=10,
help='Maximum number of results (default: 10)',
)
parser.add_argument(
'-q', '--query',
action='append',
required=True,
help='Query text (can be specified multiple times)',
)
parser.add_argument(
'documents',
nargs='+',
help='Documents to rerank',
)
args = parser.parse_args()
try:
query(
url=args.url,
flow_id=args.flow_id,
queries=args.query,
documents=args.documents,
limit=args.limit,
token=args.token,
workspace=args.workspace,
)
except Exception as e:
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()

View file

@ -203,9 +203,9 @@ def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_
)
print(f" {i}. ({s_label}, {p_label}, {o_label})")
if edge_sel.reasoning:
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reasoning: {r_short}")
if edge_sel.concept or edge_sel.score is not None:
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
print(f" Concept: {edge_sel.concept} Score: {score_str}")
if show_provenance and edge_sel.edge:
provenance = trace_edge_provenance(
@ -519,7 +519,8 @@ def trace_to_dict(trace, trace_type):
"selected_edges": [
{
"edge": edge_sel.edge,
"reasoning": edge_sel.reasoning,
"concept": edge_sel.concept,
"score": edge_sel.score,
}
for edge_sel in focus.edge_selections
],

View file

@ -19,6 +19,7 @@ dependencies = [
"faiss-cpu",
"falkordb",
"fastembed",
"flashrank",
"ibis",
"jsonschema",
"langchain",
@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
graph-rag = "trustgraph.retrieval.graph_rag:run"
reranker-flashrank = "trustgraph.reranker.flashrank:run"
kg-extract-agent = "trustgraph.extract.kg.agent:run"
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
kg-extract-rows = "trustgraph.extract.kg.rows:run"

View file

@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
from . row_embeddings_query import RowEmbeddingsQueryRequestor
from . mcp_tool import McpToolRequestor
from . reranker import RerankerRequestor
from . text_load import TextLoad
from . document_load import DocumentLoad
@ -74,6 +75,7 @@ request_response_dispatchers = {
"structured-diag": StructuredDiagRequestor,
"row-embeddings": RowEmbeddingsQueryRequestor,
"sparql": SparqlQueryRequestor,
"reranker": RerankerRequestor,
}
system_dispatchers = {

View file

@ -0,0 +1,31 @@
from ... schema import RerankerRequest, RerankerResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class RerankerRequestor(ServiceRequestor):
def __init__(
self, backend, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(RerankerRequestor, self).__init__(
backend=backend,
request_queue=request_queue,
response_queue=response_queue,
request_schema=RerankerRequest,
response_schema=RerankerResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("reranker")
self.response_translator = TranslatorRegistry.get_response_translator("reranker")
def to_request(self, body):
return self.request_translator.decode(body)
def from_response(self, message):
return self.response_translator.encode_with_completion(message)

View file

@ -518,6 +518,7 @@ _FLOW_SERVICES = {
"structured-diag": "structured-query:read",
"row-embeddings": "row-embeddings:read",
"sparql": "sparql:read",
"reranker": "reranker",
}
for _kind, _cap in _FLOW_SERVICES.items():
_register_flow_kind("flow-service", _kind, _cap)

View file

@ -72,6 +72,7 @@ _READER_CAPS = {
"row-embeddings:read",
"llm",
"embeddings",
"reranker",
"mcp",
"config:read",
"flows:read",

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,2 @@
from . processor import *

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from . processor import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,109 @@
"""
Reranker service using flashrank.
Scores query-document pairs and returns the top results ranked by
relevance.
"""
import asyncio
import logging
from ... base import RerankerService
from ... schema import RerankerResult
from flashrank import Ranker, RerankRequest
logger = logging.getLogger(__name__)
default_ident = "reranker"
default_model = "ms-marco-MiniLM-L-12-v2"
class Processor(RerankerService):
def __init__(self, **params):
model = params.get("model", default_model)
super(Processor, self).__init__(
**params | { "model": model }
)
self.default_model = model
self.cached_model_name = None
self.ranker = None
self._load_model(model)
def _load_model(self, model_name):
if self.cached_model_name != model_name:
logger.info(f"Loading flashrank model: {model_name}")
self.ranker = Ranker(model_name=model_name)
self.cached_model_name = model_name
logger.info(f"flashrank model {model_name} loaded successfully")
else:
logger.debug(f"Using cached model: {model_name}")
def _run_rerank(self, query, passages):
request = RerankRequest(query=query, passages=passages)
return self.ranker.rerank(request)
async def on_rerank(self, queries, documents, limit, model=None):
if not queries or not documents:
return []
use_model = model or self.default_model
if self.cached_model_name != use_model:
await asyncio.to_thread(self._load_model, use_model)
passages = [
{"id": d.document_id, "text": d.document_text}
for d in documents
]
best_scores = {}
for q in queries:
ranked = await asyncio.to_thread(
self._run_rerank, q.query_text, passages,
)
for r in ranked:
doc_id = r["id"]
score = r["score"]
score = float(score)
if doc_id not in best_scores or score > best_scores[doc_id][1]:
best_scores[doc_id] = (q.query_id, score)
results = sorted(
best_scores.items(),
key=lambda x: x[1][1],
reverse=True,
)[:limit]
return [
RerankerResult(
document_id=doc_id,
query_id=query_id,
score=score,
)
for doc_id, (query_id, score) in results
]
@staticmethod
def add_args(parser):
RerankerService.add_args(parser)
parser.add_argument(
'-m', '--model',
default=default_model,
help=f'Reranker model (default: {default_model})'
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -120,7 +120,7 @@ class Query:
def __init__(
self, rag, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2, track_usage=None,
max_path_length=2, edge_limit=25, track_usage=None,
):
self.rag = rag
self.collection = collection
@ -129,6 +129,7 @@ class Query:
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
self.edge_limit = edge_limit
self.track_usage = track_usage
async def extract_concepts(self, query):
@ -217,12 +218,9 @@ class Query:
logger.debug(f" {ent}")
return entities, concepts
async def maybe_label(self, e):
# The label cache lives on a per-request GraphRag instance — no
# cross-request isolation concern. The collection prefix keeps
# entries from different collections distinct within one request.
cache_key = f"{self.collection}:{e}"
cached_label = self.rag.label_cache.get(cache_key)
@ -244,11 +242,10 @@ class Query:
return label
async def execute_batch_triple_queries(self, entities, limit_per_entity):
"""Execute triple queries for multiple entities concurrently using streaming"""
"""Execute triple queries for multiple entities concurrently."""
tasks = []
for entity in entities:
# Create concurrent streaming tasks for all 3 query types per entity
tasks.extend([
self.rag.triples_client.query_stream(
s=entity, p=None, o=None,
@ -270,10 +267,8 @@ class Query:
)
])
# Execute all queries concurrently
results = await asyncio.gather(*tasks, return_exceptions=True)
# Combine all results
all_triples = []
for result in results:
if not isinstance(result, Exception) and result is not None:
@ -281,168 +276,151 @@ class Query:
return all_triples
async def follow_edges_batch(self, entities, max_depth):
"""Optimized iterative graph traversal with batching.
Returns:
tuple: (subgraph, term_map) where subgraph is a set of
(str, str, str) tuples and term_map maps each string tuple
to its original (Term, Term, Term) for type-preserving
provenance.
"""
visited = set()
current_level = set(entities)
subgraph = set()
term_map = {} # (str, str, str) -> (Term, Term, Term)
for depth in range(max_depth):
if not current_level or len(subgraph) >= self.max_subgraph_size:
break
# Filter out already visited entities
unvisited_entities = [e for e in current_level if e not in visited]
if not unvisited_entities:
break
# Batch query all unvisited entities at current level
triples = await self.execute_batch_triple_queries(
unvisited_entities, self.triple_limit
)
# Process results and collect next level entities
next_level = set()
for triple in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
subgraph.add(triple_tuple)
term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o))
# Collect entities for next level (only from s and o positions)
if depth < max_depth - 1: # Don't collect for final depth
s, p, o = triple_tuple
if s not in visited:
next_level.add(s)
if o not in visited:
next_level.add(o)
# Stop if subgraph size limit reached
if len(subgraph) >= self.max_subgraph_size:
return subgraph, term_map
# Update for next iteration
visited.update(current_level)
current_level = next_level
return subgraph, term_map
async def follow_edges(self, ent, subgraph, path_length):
"""Legacy method - replaced by follow_edges_batch"""
# Maintain backward compatibility with early termination checks
if path_length <= 0:
return
if len(subgraph) >= self.max_subgraph_size:
return
# For backward compatibility, convert to new approach
batch_result, _ = await self.follow_edges_batch([ent], path_length)
subgraph.update(batch_result)
async def get_subgraph(self, query):
"""
Get subgraph by extracting concepts, finding entities, and traversing.
Returns:
tuple: (subgraph, term_map, entities, concepts) where subgraph is
a list of (s, p, o) string tuples, term_map maps each string
tuple to its original (Term, Term, Term), entities is the seed
entity list, and concepts is the extracted concept list.
"""
entities, concepts = await self.get_entities(query)
if self.verbose:
logger.debug("Getting subgraph...")
# Use optimized batch traversal instead of sequential processing
subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length)
return list(subgraph), term_map, entities, concepts
async def resolve_labels_batch(self, entities):
"""Resolve labels for multiple entities in parallel"""
tasks = []
for entity in entities:
tasks.append(self.maybe_label(entity))
"""Resolve labels for multiple entities in parallel."""
tasks = [self.maybe_label(entity) for entity in entities]
return await asyncio.gather(*tasks, return_exceptions=True)
async def get_labelgraph(self, query):
"""
Get subgraph with labels resolved for display.
async def hop_and_filter(self, seed_entities, concepts):
"""Iterative hop-and-filter graph traversal with cross-encoder.
At each hop:
1. Retrieve all edges one hop from the frontier.
2. Resolve labels and represent each edge as "{p} {o}".
3. Score edges against concepts using the cross-encoder.
4. Select the top edges; their target nodes become the next
frontier.
Returns:
tuple: (labeled_edges, uri_map, entities, concepts) where:
- labeled_edges: list of (label_s, label_p, label_o) tuples
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
- entities: list of seed entity URI strings
- concepts: list of concept strings extracted from query
tuple: (selected_edges, uri_map, edge_metadata) where:
- selected_edges: list of (label_s, label_p, label_o)
- uri_map: dict mapping edge_id -> (Term, Term, Term)
- edge_metadata: dict mapping edge_id -> {concept, score}
"""
subgraph, term_map, entities, concepts = await self.get_subgraph(query)
all_selected_edges = []
uri_map = {}
edge_metadata = {}
frontier = set(seed_entities)
visited_entities = set()
seen_edges = set()
# Filter out label triples
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
for hop in range(self.max_path_length):
if not frontier:
break
# Collect all unique entities that need label resolution
entities_to_resolve = set()
for s, p, o in filtered_subgraph:
entities_to_resolve.update([s, p, o])
unvisited = [e for e in frontier if e not in visited_entities]
if not unvisited:
break
# Batch resolve labels for all entities in parallel
entity_list = list(entities_to_resolve)
resolved_labels = await self.resolve_labels_batch(entity_list)
if self.verbose:
logger.debug(
f"Hop {hop + 1}: {len(unvisited)} frontier entities"
)
# Create entity-to-label mapping
label_map = {}
for entity, label in zip(entity_list, resolved_labels):
if not isinstance(label, Exception):
label_map[entity] = label
else:
label_map[entity] = entity # Fallback to entity itself
# Apply labels to subgraph and build URI mapping
labeled_edges = []
uri_map = {} # Maps edge_id of labeled edge -> original Term triple
for s, p, o in filtered_subgraph:
labeled_triple = (
label_map.get(s, s),
label_map.get(p, p),
label_map.get(o, o)
# Retrieve edges one hop from frontier
triples = await self.execute_batch_triple_queries(
unvisited, self.triple_limit,
)
labeled_edges.append(labeled_triple)
# Map from labeled edge ID to original Terms (preserving types)
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o))
# Deduplicate and filter already-seen edges
hop_triples = []
hop_term_map = {}
for triple in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
if triple_tuple[1] == LABEL:
continue
if triple_tuple in seen_edges:
continue
seen_edges.add(triple_tuple)
hop_triples.append(triple_tuple)
hop_term_map[triple_tuple] = (
to_term(triple.s), to_term(triple.p), to_term(triple.o),
)
labeled_edges = labeled_edges[0:self.max_subgraph_size]
if not hop_triples:
visited_entities.update(frontier)
break
if self.verbose:
logger.debug("Subgraph:")
for edge in labeled_edges:
logger.debug(f" {str(edge)}")
if self.verbose:
logger.debug(
f"Hop {hop + 1}: {len(hop_triples)} candidate edges"
)
if self.verbose:
logger.debug("Done.")
# Resolve labels for all entities in hop edges
entities_to_resolve = set()
for s, p, o in hop_triples:
entities_to_resolve.update([s, p, o])
return labeled_edges, uri_map, entities, concepts
entity_list = list(entities_to_resolve)
resolved = await self.resolve_labels_batch(entity_list)
label_map = {}
for entity, label in zip(entity_list, resolved):
if not isinstance(label, Exception):
label_map[entity] = label
else:
label_map[entity] = entity
# Build labeled edges and documents for cross-encoder
labeled_hop = []
for s, p, o in hop_triples:
ls = label_map.get(s, s)
lp = label_map.get(p, p)
lo = label_map.get(o, o)
labeled_hop.append((ls, lp, lo))
documents = [
{"id": str(i), "text": f"{lp} {lo}"}
for i, (ls, lp, lo) in enumerate(labeled_hop)
]
queries = [
{"id": str(i), "text": c}
for i, c in enumerate(concepts)
]
# Score with cross-encoder
results = await self.rag.reranker_client.rerank(
queries=queries,
documents=documents,
limit=self.edge_limit,
)
# Collect selected edges and metadata
next_frontier = set()
for r in results:
idx = int(r.document_id)
ls, lp, lo = labeled_hop[idx]
s, p, o = hop_triples[idx]
eid = edge_id(ls, lp, lo)
all_selected_edges.append((ls, lp, lo))
uri_map[eid] = hop_term_map[(s, p, o)]
edge_metadata[eid] = {
"concept": concepts[int(r.query_id)],
"score": r.score,
}
# Target nodes become next frontier
next_frontier.add(s)
next_frontier.add(o)
if self.verbose:
logger.debug(
f"Hop {hop + 1}: selected {len(results)} edges"
)
visited_entities.update(frontier)
frontier = next_frontier - visited_entities
return all_selected_edges, uri_map, edge_metadata
async def trace_source_documents(self, edge_uris):
"""
Trace selected edges back to their source documents via provenance.
Follows the chain: edge subgraph (via tg:contains) chunk
page document (via prov:wasDerivedFrom), all in urn:graph:source.
Follows the chain: edge -> subgraph (via tg:contains) -> chunk ->
page -> document (via prov:wasDerivedFrom), all in urn:graph:source.
Args:
edge_uris: List of (s, p, o) URI string tuples
@ -453,7 +431,6 @@ class Query:
# Step 1: Find subgraphs containing these edges via tg:contains
subgraph_tasks = []
for s, p, o in edge_uris:
# s, p, o may be Term objects (preserving types) or strings
s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s)
p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p)
o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o)
@ -487,12 +464,10 @@ class Query:
return []
# Step 2: Walk prov:wasDerivedFrom chain to find documents
# Each level: query ?entity prov:wasDerivedFrom ?parent
# Stop when we find entities typed tg:Document
current_uris = subgraph_uris
doc_uris = set()
for depth in range(4): # Max depth: subgraph → chunk → page → doc
for depth in range(4):
if not current_uris:
break
@ -509,7 +484,6 @@ class Query:
*derivation_tasks, return_exceptions=True
)
# URIs with no parent are root documents
next_uris = set()
for uri, result in zip(current_uris, derivation_results):
if isinstance(result, Exception) or not result:
@ -524,7 +498,6 @@ class Query:
return []
# Step 3: Get all document metadata properties
# Skip structural predicates that aren't useful context
SKIP_PREDICATES = {
PROV_WAS_DERIVED_FROM,
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
@ -565,7 +538,7 @@ class GraphRag:
def __init__(
self, prompt_client, embeddings_client, graph_embeddings_client,
triples_client, verbose=False,
triples_client, reranker_client, verbose=False,
):
self.verbose = verbose
@ -574,9 +547,8 @@ class GraphRag:
self.embeddings_client = embeddings_client
self.graph_embeddings_client = graph_embeddings_client
self.triples_client = triples_client
self.reranker_client = reranker_client
# Replace simple dict with LRU cache with TTL
# CRITICAL: This cache only lives for one request due to per-request instantiation
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
if self.verbose:
@ -585,33 +557,12 @@ class GraphRag:
async def query(
self, query, collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
max_path_length = 2, edge_limit = 25,
streaming = False,
chunk_callback = None,
explain_callback = None, save_answer_callback = None,
parent_uri = "",
):
"""
Execute a GraphRAG query with real-time explainability tracking.
Args:
query: The query string
collection: Collection identifier
entity_limit: Max entities to retrieve
triple_limit: Max triples per entity
max_subgraph_size: Max edges in subgraph
max_path_length: Max hops from seed entities
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
edge_limit: Max edges after LLM scoring
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for real-time explainability
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
Returns:
tuple: (answer_text, usage) where usage is a dict with
in_token, out_token, model
"""
# Accumulate token usage across all prompt calls
total_in = 0
total_out = 0
@ -638,7 +589,9 @@ class GraphRag:
foc_uri = make_focus_uri(session_id)
syn_uri = make_synthesis_uri(session_id)
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
timestamp = datetime.now(timezone.utc).isoformat().replace(
"+00:00", "Z",
)
# Emit question explainability immediately
if explain_callback:
@ -657,10 +610,12 @@ class GraphRag:
triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
track_usage = track_usage,
)
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
# Step 1: Extract concepts and find seed entities
seed_entities, concepts = await q.get_entities(query)
# Emit grounding explain after concept extraction
if explain_callback:
@ -676,11 +631,16 @@ class GraphRag:
)
await explain_callback(gnd_triples, gnd_uri)
# Emit exploration explain after graph retrieval completes
# Step 2: Iterative hop-and-filter with cross-encoder
selected_edges, uri_map, edge_metadata = await q.hop_and_filter(
seed_entities, concepts,
)
# Emit exploration explain
if explain_callback:
exp_triples = set_graph(
exploration_triples(
exp_uri, gnd_uri, len(kg),
exp_uri, gnd_uri, len(selected_edges),
entities=seed_entities,
),
GRAPH_RETRIEVAL
@ -688,235 +648,63 @@ class GraphRag:
await explain_callback(exp_triples, exp_uri)
if self.verbose:
logger.debug("Invoking LLM...")
logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}")
# Semantic pre-filter: reduce edges before expensive LLM scoring
if edge_score_limit > 0 and len(kg) > edge_score_limit:
if self.verbose:
logger.debug(f"Selected {len(selected_edges)} edges")
for s, p, o in selected_edges:
eid = edge_id(s, p, o)
meta = edge_metadata.get(eid, {})
logger.debug(
f"Semantic pre-filter: {len(kg)} edges > "
f"limit {edge_score_limit}, filtering..."
f" {meta.get('score', 0):.4f} "
f"[{meta.get('concept', '')}] "
f"{s} | {p} | {o}"
)
# Embed edge descriptions: "subject, predicate, object"
edge_descriptions = [
f"{s}, {p}, {o}" for s, p, o in kg
]
# Embed concepts and edge descriptions concurrently
concept_embed_task = self.embeddings_client.embed(concepts)
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
concept_vectors, edge_vectors = await asyncio.gather(
concept_embed_task, edge_embed_task
)
# Score each edge by max cosine similarity to any concept
def cosine_similarity(a, b):
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
edge_scores = []
for i, edge_vec in enumerate(edge_vectors):
max_sim = max(
cosine_similarity(edge_vec, cv)
for cv in concept_vectors
)
edge_scores.append((max_sim, i))
# Sort by similarity descending and keep top edge_score_limit
edge_scores.sort(reverse=True)
keep_indices = set(
idx for _, idx in edge_scores[:edge_score_limit]
)
# Filter kg and rebuild uri_map
filtered_kg = []
filtered_uri_map = {}
for i, (s, p, o) in enumerate(kg):
if i in keep_indices:
filtered_kg.append((s, p, o))
eid = edge_id(s, p, o)
if eid in uri_map:
filtered_uri_map[eid] = uri_map[eid]
if self.verbose:
logger.debug(
f"Semantic pre-filter kept {len(filtered_kg)} "
f"of {len(kg)} edges"
)
kg = filtered_kg
uri_map = filtered_uri_map
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
edge_map = {}
edges_with_ids = []
for s, p, o in kg:
eid = edge_id(s, p, o)
edge_map[eid] = (s, p, o)
edges_with_ids.append({
"id": eid,
"s": s,
"p": p,
"o": o
})
if self.verbose:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1a: Edge Scoring - LLM scores edges for relevance
scoring_result = await self.prompt_client.prompt(
"kg-edge-scoring",
variables={
"query": query,
"knowledge": edges_with_ids
}
)
track_usage(scoring_result)
if self.verbose:
logger.debug(f"Edge scoring result: {scoring_result}")
# Parse scoring response (jsonl) to get edge IDs with scores
scored_edges = []
for obj in scoring_result.objects or []:
if isinstance(obj, dict) and "id" in obj and "score" in obj:
try:
score = int(obj["score"])
except (ValueError, TypeError):
score = 0
scored_edges.append({"id": obj["id"], "score": score})
# Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit]
selected_ids = {e["id"] for e in top_edges}
if self.verbose:
logger.debug(
f"Scored {len(scored_edges)} edges, "
f"selected top {len(selected_ids)}"
)
# Filter to selected edges
selected_edges = []
for eid in selected_ids:
if eid in edge_map:
selected_edges.append(edge_map[eid])
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
selected_edges_with_ids = [
{"id": eid, "s": s, "p": p, "o": o}
for eid in selected_ids
if eid in edge_map
for s, p, o in [edge_map[eid]]
]
# Collect selected edge URIs for document tracing
# Step 3: Document tracing
selected_edge_uris = [
uri_map[eid]
for eid in selected_ids
if eid in uri_map
uri_map[edge_id(s, p, o)]
for s, p, o in selected_edges
if edge_id(s, p, o) in uri_map
]
# Run reasoning and document tracing concurrently
async def _get_reasoning():
result = await self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
track_usage(result)
return result
reasoning_task = _get_reasoning()
doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_result, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True
source_documents = await q.trace_source_documents(
selected_edge_uris,
)
# Handle exceptions from gather
if isinstance(reasoning_result, Exception):
logger.warning(
f"Edge reasoning failed: {reasoning_result}"
)
reasoning_result = None
if isinstance(source_documents, Exception):
logger.warning(
f"Document tracing failed: {source_documents}"
)
source_documents = []
if self.verbose:
logger.debug(f"Edge reasoning result: {reasoning_result}")
# Parse reasoning response (jsonl) and build explainability data
reasoning_map = {}
if reasoning_result is not None:
for obj in reasoning_result.objects or []:
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
# Build focus explainability data with cross-encoder metadata
selected_edges_with_reasoning = []
for eid in selected_ids:
for s, p, o in selected_edges:
eid = edge_id(s, p, o)
if eid in uri_map:
uri_s, uri_p, uri_o = uri_map[eid]
meta = edge_metadata.get(eid, {})
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": reasoning_map.get(eid, ""),
"concept": meta.get("concept", ""),
"score": meta.get("score", 0),
})
if self.verbose:
logger.debug(f"Filtered to {len(selected_edges)} edges")
# Emit focus explain after edge selection completes
# Emit focus explain
if explain_callback:
# Sum scoring + reasoning token usage for focus event
focus_in = 0
focus_out = 0
focus_model = None
for r in [scoring_result, reasoning_result]:
if r is not None:
if r.in_token is not None:
focus_in += r.in_token
if r.out_token is not None:
focus_out += r.out_token
if r.model is not None:
focus_model = r.model
foc_triples = set_graph(
focus_triples(
foc_uri, exp_uri, selected_edges_with_reasoning, session_id,
in_token=focus_in or None,
out_token=focus_out or None,
model=focus_model,
foc_uri, exp_uri,
selected_edges_with_reasoning, session_id,
),
GRAPH_RETRIEVAL
)
await explain_callback(foc_triples, foc_uri)
# Step 2: Synthesis - LLM generates answer from selected edges only
# Step 4: Synthesis
selected_edge_dicts = [
{"s": s, "p": p, "o": o}
for s, p, o in selected_edges
]
# Add source document metadata as knowledge edges
for s, p, o in source_documents:
selected_edge_dicts.append({
"s": s, "p": p, "o": o,
@ -928,7 +716,6 @@ class GraphRag:
}
if streaming and chunk_callback:
# Accumulate chunks for answer storage while forwarding to callback
accumulated_chunks = []
async def accumulating_callback(chunk, end_of_stream):
@ -942,7 +729,6 @@ class GraphRag:
chunk_callback=accumulating_callback
)
track_usage(synthesis_result)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
synthesis_result = await self.prompt_client.prompt(
@ -955,29 +741,42 @@ class GraphRag:
if self.verbose:
logger.debug("Query processing complete")
# Emit synthesis explain after synthesis completes
# Emit synthesis explain
if explain_callback:
synthesis_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian
if save_answer_callback and answer_text:
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
try:
await save_answer_callback(synthesis_doc_id, answer_text)
if self.verbose:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
logger.debug(
f"Saved answer to librarian: "
f"{synthesis_doc_id}"
)
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
logger.warning(
f"Failed to save answer to librarian: {e}"
)
synthesis_doc_id = None
syn_triples = set_graph(
synthesis_triples(
syn_uri, foc_uri,
document_id=synthesis_doc_id,
in_token=synthesis_result.in_token if synthesis_result else None,
out_token=synthesis_result.out_token if synthesis_result else None,
model=synthesis_result.model if synthesis_result else None,
in_token=(
synthesis_result.in_token
if synthesis_result else None
),
out_token=(
synthesis_result.out_token
if synthesis_result else None
),
model=(
synthesis_result.model
if synthesis_result else None
),
),
GRAPH_RETRIEVAL
)
@ -993,4 +792,3 @@ class GraphRag:
}
return resp, usage

View file

@ -13,6 +13,7 @@ from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
from ... base import RerankerClientSpec
from ... base import LibrarianSpec
# Module logger
@ -32,7 +33,6 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
edge_score_limit = params.get("edge_score_limit", 30)
edge_limit = params.get("edge_limit", 25)
super(Processor, self).__init__(
@ -43,7 +43,6 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"edge_score_limit": edge_score_limit,
"edge_limit": edge_limit,
}
)
@ -52,7 +51,6 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.default_edge_score_limit = edge_score_limit
self.default_edge_limit = edge_limit
# Workspace isolation is enforced by the flow layer (flow.workspace).
@ -96,6 +94,13 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
RerankerClientSpec(
request_name = "reranker-request",
response_name = "reranker-response",
)
)
self.register_specification(
ProducerSpec(
name = "response",
@ -163,6 +168,7 @@ class Processor(FlowProcessor):
graph_embeddings_client=flow("graph-embeddings-request"),
triples_client=flow("triples-request"),
prompt_client=flow("prompt-request"),
reranker_client=flow("reranker-request"),
verbose=True,
)
@ -186,11 +192,6 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
if v.edge_score_limit:
edge_score_limit = v.edge_score_limit
else:
edge_score_limit = self.default_edge_score_limit
if v.edge_limit:
edge_limit = v.edge_limit
else:
@ -225,7 +226,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit,
streaming = True,
chunk_callback = send_chunk,
@ -241,7 +242,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit,
explain_callback = send_explainability,
save_answer_callback = save_answer,
@ -338,18 +339,11 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)'
)
parser.add_argument(
'--edge-score-limit',
type=int,
default=30,
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
)
parser.add_argument(
'--edge-limit',
type=int,
default=25,
help=f'Max edges after LLM scoring (default: 25)'
help=f'Max edges selected per hop by cross-encoder (default: 25)'
)
# Note: Explainability triples are now stored in the request's collection