Mirror of https://github.com/trustgraph-ai/trustgraph.git, synced 2026-06-30 17:09:38 +02:00
feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG (#1005)
Replace the three-prompt LLM scoring pipeline (kg-edge-scoring, kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker service backed by FlashRank. The new hop_and_filter() method performs iterative graph traversal with semantic scoring at each hop, replacing the previous follow_edges/get_subgraph approach.

- Add reranker service (trustgraph-base client/service, FlashRank processor)
- Add gateway dispatch for reranker via API and WebSocket
- Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring
- Remove kg_prompt() and edge_score_limit from prompt client
- Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates
- Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata
- Add tg-invoke-reranker CLI tool
- Add tech spec and UX developer guidance
- Update all unit and integration tests
Parent: 1aa9549912
Commit: 01cc8dbc64
43 changed files with 1613 additions and 792 deletions

docs/tech-specs/graph-rag-semantic-filter.md (normal file, 523 lines added)

@@ -0,0 +1,523 @@
# GraphRAG Semantic Filter Improvement

## Problem Statement

The GraphRAG semantic filter has been observed to be ineffective with
certain LLMs. Smaller models in particular produce poor-quality edge
relevance scores, and there is a suspicion that models trained or
evaluated heavily on non-Roman-script datasets perform worse on the
semantic ranking operation.

The root cause is that the current implementation delegates edge
relevance scoring to the LLM via a prompt that asks the model to
assign a 1–10 relevance score to each knowledge-graph edge. This
task (ranking structured triples for relevance to a natural-language
query) is not well covered in standard LLM evaluation suites, so
model benchmark scores are not predictive of performance on this
operation. The result is that GraphRAG quality varies unpredictably
across model choices, undermining confidence in the pipeline.

Beyond model variability, the LLM scoring step has further problems:

- **Cost and latency.** The LLM call consumes tokens and adds
  latency to every query, yet its output is unreliable. Even when
  the model performs well, the cost is disproportionate for what is
  fundamentally a ranking operation.

- **Subjective scoring scale.** The 1–10 relevance scale gives the
  model no objective criteria for what constitutes a 5 versus a 7.
  Different models interpret the scale differently, and even the same
  model can produce inconsistent scores across runs.

- **Redundancy with the embedding pre-filter.** The pipeline already
  contains a cosine-similarity stage that ranks edges by semantic
  relevance using embeddings. The LLM scoring step is a second
  filter applied on top of this, and it is not clear that it adds
  enough value to justify the additional cost and risk of
  degradation.

### Industry context

Semantic ranking is rigorously evaluated on dedicated benchmarks such
as MTEB (Massive Text Embedding Benchmark) and BEIR (Benchmarking
Information Retrieval), which test retrieval and reranking across
diverse domains. The current TrustGraph approach, prompting a
general-purpose LLM to score and rank documents (the "listwise"
approach), is known to be poorly optimized for this task. It
suffers from positional bias, formatting failures, and
inconsistency at scale.

The industry standard for semantic ranking has moved to
cross-encoder models: lightweight, purpose-built models that take a
query–document pair as input and produce a single relevance score.
These models are fine-tuned on millions of relevance-labelled pairs
and dominate retrieval benchmarks. They are fast, deterministic,
and do not require an LLM inference call.

## Architecture

### Cross-encoder service

A new request/response service exposes a generic semantic
ranking API. The service is not specific to GraphRAG; it is a
reusable building block for any component that needs to rank text
by relevance.

The service interface is pluggable. Alternative implementations
can be swapped in behind the same API.

**Packaging options considered:**

- *`sentence-transformers`.* Full-featured, widely used.
  However, it pulls in PyTorch (~2 GB), making containers
  very large. Tested at ~1.8 seconds for 2200 edges.

- *`optimum.onnxruntime`.* ONNX-based inference. Still
  depends on PyTorch at import time despite using ONNX for
  inference. Tested at ~4.2 seconds for 2200 edges.

- *`flashrank`.* Lightweight wrapper around ONNX Runtime
  with a clean API (`Ranker`, `RerankRequest`). No PyTorch
  dependency. Tested at ~4.4 seconds for 2200 edges.

- *Pure `onnxruntime` + `tokenizers`.* Leanest option
  (~200 MB total). Requires manual tokenisation, padding,
  and numpy array management, which means more boilerplate to
  maintain.

- *External API (e.g. Cohere Rerank).* No local model at
  all. Adds network latency and an external dependency.

**Decision:** `flashrank` for the initial implementation.
It has no PyTorch dependency and a clean API, and its measured
performance (~4.4 seconds for 2200 edges) is comparable to
`optimum.onnxruntime` (~4.2 seconds), though slower than
`sentence-transformers` (~1.8 seconds). The pluggable interface
allows swapping to another backend later.

**Request:**

- `queries`: list of `{id, text}` objects. In the GraphRAG use
  case these are the concepts extracted from the user's question.
- `documents`: list of `{id, text}` objects. In the GraphRAG
  use case these are the candidate knowledge-graph edges
  represented as text.
- `limit`: integer. Maximum number of results to return.

**Scoring:**

The service forms the Cartesian product of all query–document
pairs and scores each pair with the cross-encoder model. For
each document, the maximum score across all queries is taken as the
document's relevance score. Documents are then ranked by this
score and the top `limit` results are returned.

**Response:**

A list of the top `limit` results, each containing:

- `document_id`: the ID of the matched document.
- `query_id`: the ID of the query (concept) that produced the
  highest score for this document.
- `score`: the relevance score.

Including `query_id` in the response supports the explainability
interface: it records that an edge was selected because it is
related to a specific concept.
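
The scoring rule described above can be sketched in a few lines of Python. This is an illustrative sketch, not the service implementation: `rerank`, `RankedResult`, and the `score_pair` callable are hypothetical names, and the toy word-overlap scorer only stands in for a real cross-encoder model.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class RankedResult:
    document_id: str
    query_id: str
    score: float


def rerank(
    queries: List[Dict[str, str]],
    documents: List[Dict[str, str]],
    limit: int,
    score_pair: Callable[[str, str], float],
) -> List[RankedResult]:
    """Score every query-document pair and keep the best query per document."""
    results = []
    for doc in documents:
        best = None
        for query in queries:
            score = score_pair(query["text"], doc["text"])
            if best is None or score > best.score:
                best = RankedResult(doc["id"], query["id"], score)
        if best is not None:
            results.append(best)
    # Rank documents by their best score and truncate to the limit.
    results.sort(key=lambda r: r.score, reverse=True)
    return results[:limit]


def toy_overlap_score(query: str, document: str) -> float:
    """Stand-in scorer (assumption): fraction of query words found in the document."""
    q_words = set(query.lower().split())
    d_words = set(document.lower().split())
    return len(q_words & d_words) / len(q_words) if q_words else 0.0
```

Because each result keeps the `query_id` of its best-scoring query, the caller can report which concept caused a document to be selected without a second pass.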

### Integration

The cross-encoder service follows the standard TrustGraph service
integration pattern:

- **Base package (trustgraph-base).** Schema definitions for the
  cross-encoder request/response messages. A client class that
  other components (e.g. GraphRAG) can use to call the
  cross-encoder service. Message translator registration so the
  pub/sub layer can serialise/deserialise the messages.

- **Flow package (trustgraph-flow).** The cross-encoder service
  implementation itself: it loads the model, listens for requests,
  scores pairs, and returns results. Flow definition support so the
  cross-encoder can be introduced into a processing flow via the
  standard flow configuration. `flashrank` is added as a
  dependency of `trustgraph-flow`. The service runs in its own
  container.

- **API gateway.** A gateway endpoint that routes cross-encoder
  requests from the HTTP API to the service over pub/sub and
  returns the response.

- **CLI tool.** A command-line utility
  (e.g. `tg-invoke-cross-encoder`) that calls the gateway
  endpoint for manual testing and debugging.

### Current GraphRAG pipeline

The current pipeline follows these steps:

1. **Concept extraction.** An LLM prompt extracts key concepts
   from the user's query.

2. **Graph exploration.** Seed entities are found via embedding
   similarity. A subgraph is built by multi-hop traversal from
   the seed entities (up to `max_path_length` hops, capped at
   `max_subgraph_size` edges).

3. **Embedding pre-filter.** Each edge is embedded as
   `"subject, predicate, object"` and scored by cosine similarity
   against the concept embeddings. The top `edge_score_limit`
   (default 30) edges are kept.

4. **LLM edge scoring.** The `kg-edge-scoring` prompt asks the
   LLM to assign a 1–10 relevance score to each remaining edge.
   The top `edge_limit` (default 25) edges are kept.

5. **LLM edge reasoning.** The `kg-edge-reasoning` prompt asks
   the LLM to explain why each selected edge is relevant to the
   query. The output is used for the explainability interface.

6. **Document tracing.** Selected edges are traced back to their
   source documents in the librarian. This runs concurrently with
   step 5.

7. **Synthesis.** The `kg-synthesis` prompt generates the final
   answer from the selected edges and source document metadata.

### Potential improvements

#### Replace LLM edge scoring with cross-encoder (step 4)

The LLM edge scoring step is replaced by a call to the
cross-encoder service. The candidate edges are the documents and
`edge_limit` is the limit. This is a direct substitution: faster,
cheaper, deterministic, and more reliable across model choices.
The LLM `kg-edge-scoring` prompt is retired.

**Cross-encoder query input: concepts vs. raw query.** There are
two options for what to use as the cross-encoder queries:

- *Option A: Raw user query.* Pass the original question as a
  single query string. Simpler, no dependency on concept
  extraction. However, raw queries contain noise words and
  conversational phrasing that do not match well against the
  structured vocabulary of knowledge-graph edges. A single query
  also means every edge competes against the full question, so a
  partial match on one aspect is diluted.

- *Option B: Extracted concepts.* Pass the concepts from step 1
  as separate queries. The concepts are distilled, focused terms
  that are closer to the language of the edges. With multiple
  concepts as independent queries, the cross-encoder scores each
  edge against each concept separately, giving better coverage:
  an edge only needs to match one concept well to be selected.
  The trade-off is a dependency on the LLM concept extraction
  step, but this is already in the pipeline and is a lightweight,
  reliable LLM call.

**Decision:** Option B (use extracted concepts). The concept
extraction is fast, and the resulting terms produce better
cross-encoder matches against structured triples.

#### Edge text representation

The current embedding pre-filter represents each edge as
`"subject, predicate, object"`. Two changes are made:

- **Drop commas.** Commas add tokenisation noise without semantic
  value.

- **Drop the subject.** The subject identifies which entity the
  edge belongs to, but it does not contribute to whether the
  edge's content is relevant to the query. The predicate and
  object carry the semantic meaning: what relationship exists
  and what it connects to. Representing edges as `"{p} {o}"`
  produces cleaner cross-encoder matches.
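
As a small illustration of the two representations, the helpers below render the same edge both ways. The function names are hypothetical and the labels are example values; only the two string formats come from the text above.

```python
def edge_text(predicate: str, obj: str) -> str:
    """Render an edge for cross-encoder scoring as "{p} {o}" (subject dropped)."""
    return f"{predicate} {obj}"


def legacy_edge_text(subject: str, predicate: str, obj: str) -> str:
    """The earlier pre-filter form "subject, predicate, object", for comparison."""
    return f"{subject}, {predicate}, {obj}"


print(legacy_edge_text("Voyager 1", "has-event", "crossed the heliopause"))
print(edge_text("has-event", "crossed the heliopause"))
```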

#### Remove the embedding pre-filter (step 3)

The embedding pre-filter was introduced to reduce the number of
edges before the expensive LLM scoring call. With the
cross-encoder replacing the LLM call, this cost equation changes.

**Arguments for removal:**

- The cross-encoder is fast enough to score the full subgraph
  directly. In testing, 2200 edges scored in ~1.8 seconds with
  `sentence-transformers` and ~4.4 seconds with `flashrank`; at
  the default `max_subgraph_size` of 150 edges, scoring takes
  a fraction of a second.

- The pre-filter is a weaker version of what the cross-encoder
  does. Bi-encoder cosine similarity embeds the query and
  document independently and compares vectors; the cross-encoder
  processes both texts together through the full transformer,
  giving it much better relevance judgement. Running a weaker
  filter before a stronger one adds latency without improving
  quality.

- Removing it eliminates an embedding service call (two batches:
  concepts + edges) and the associated latency.
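
The "fraction of a second" estimate follows from the measured timings if scoring cost is assumed to scale linearly with the number of edges. That linearity is an assumption, and any fixed per-call overhead is ignored. `estimate_seconds` is an illustrative helper; the two timings are the `sentence-transformers` and `flashrank` measurements quoted under the packaging options.

```python
def estimate_seconds(measured_seconds: float, measured_edges: int, edges: int) -> float:
    """Linearly scale a measured scoring time to a different edge count."""
    return measured_seconds * edges / measured_edges


# Timings quoted in this document for scoring 2200 edges.
for backend, seconds in [("sentence-transformers", 1.8), ("flashrank", 4.4)]:
    estimate = estimate_seconds(seconds, 2200, 150)
    print(f"{backend}: ~{estimate:.2f} s for 150 edges")
```

Under this assumption both backends come out well below one second at the default subgraph size (roughly 0.12 s and 0.30 s).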

**Arguments for keeping it:**

- If the subgraph is very large (thousands of edges), the
  cross-encoder's linear scaling could become a bottleneck.
  The pre-filter would act as a safety valve.

- The embedding call is cheap compared to an LLM call, so the
  overhead is modest.

**Decision:** Remove the pre-filter. The `max_subgraph_size`
parameter (default 150) already caps the number of edges entering
this stage, so the cross-encoder will not face an unbounded
workload. If very large subgraphs become a concern in future,
the pre-filter can be reintroduced or `max_subgraph_size` can be
tuned.

#### Iterative graph traversal with cross-encoder filtering

The current pipeline performs graph exploration and edge filtering
as separate phases: first build the full subgraph (up to
`max_path_length` hops), then score and filter edges. An
alternative is to interleave traversal and filtering: at each
hop, use the cross-encoder to select relevant edges before
expanding further.

**Option A: Big-bang traversal then filter.** Traverse the full
subgraph up to `max_path_length` hops from the seed entities,
collecting all edges up to `max_subgraph_size`. Then
cross-encode the entire result to select the top edges.

- Simple to implement, since the current traversal logic is largely
  unchanged.
- Produces large, unfocused subgraphs. Irrelevant branches are
  explored and scored even though they will be discarded.
- Poorly suited to multi-hop reasoning. For a query about
  Voyager 1, the subgraph includes Voyager 2's edges because
  they are within hop distance, and the filter must then
  separate them.

**Option B: Iterative hop-and-filter.** At each hop:

1. Retrieve all edges one hop from the current frontier nodes.
2. Cross-encode these edges against the query concepts.
3. Select the top relevant edges.
4. The target nodes of the selected edges become the frontier
   for the next hop.
5. Repeat up to `max_path_length` hops.

The final set of selected edges across all hops is the input to
synthesis.

- **Guided exploration.** Each hop focuses the search by
  pruning irrelevant branches before expanding further. The
  working set stays small and relevant at every step.
- **Multi-hop reasoning works naturally.** Following
  "Voyager 1 → has-event → crossed the heliopause" succeeds
  because each hop is individually relevant and leads to the
  next.
- **Smaller total workload.** Fewer edges are scored overall
  because irrelevant branches are never expanded.
- **Trade-off: greedy pruning.** An edge discarded at hop 1
  cannot lead to relevant edges at hop 2. This is inherent in
  any bounded traversal, and the cross-encoder is better
  equipped to make this relevance judgement than a blind hop
  limit.
- **Trade-off: sequential latency.** Hops cannot be
  parallelised since each depends on the previous. However,
  each cross-encoder call on a small edge set is very fast
  (sub-second for typical working sets).

**Decision:** Option B (iterative hop-and-filter). The guided
traversal produces more focused subgraphs and supports multi-hop
reasoning, which is a significant quality improvement over the
current approach.
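
The hop-and-filter loop can be sketched over a small in-memory graph. This is a minimal sketch under stated assumptions, not the pipeline code: the function borrows the `hop_and_filter` name from the commit message, but its signature, the `graph` adjacency mapping, the `per_hop_limit` parameter, and the `score_edge` callable (a stand-in for the cross-encoder service) are all hypothetical.

```python
from typing import Callable, Dict, List, Set, Tuple

Edge = Tuple[str, str, str]  # (subject, predicate, object)


def hop_and_filter(
    graph: Dict[str, List[Edge]],
    seeds: List[str],
    concepts: List[str],
    score_edge: Callable[[str, str], float],
    max_path_length: int = 2,
    per_hop_limit: int = 2,
) -> List[Tuple[Edge, str, float]]:
    """Expand from the seeds one hop at a time, keeping only top-scoring edges."""
    selected: List[Tuple[Edge, str, float]] = []
    seen: Set[Edge] = set()
    frontier = list(seeds)
    for _ in range(max_path_length):
        # 1. Retrieve all not-yet-seen edges one hop from the frontier.
        candidates = []
        for node in frontier:
            for edge in graph.get(node, []):
                if edge not in seen:
                    seen.add(edge)
                    candidates.append(edge)
        if not candidates or not concepts:
            break
        # 2. Score each edge as "{p} {o}" against every concept; keep the best.
        scored = []
        for edge in candidates:
            text = f"{edge[1]} {edge[2]}"
            concept = max(concepts, key=lambda c: score_edge(c, text))
            scored.append((edge, concept, score_edge(concept, text)))
        # 3. Select the top edges for this hop.
        scored.sort(key=lambda item: item[2], reverse=True)
        kept = scored[:per_hop_limit]
        selected.extend(kept)
        # 4. Objects of the kept edges become the next frontier.
        frontier = [edge[2] for edge, _, _ in kept]
    return selected
```

Edges pruned at one hop are never expanded, which is the greedy trade-off noted above; in exchange, each hop scores only the neighbours of edges already judged relevant.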

#### Replace LLM edge reasoning with cross-encoder metadata (step 5)

The current `kg-edge-reasoning` prompt asks the LLM to explain why
each edge is relevant. With the cross-encoder now making the
selection, this explanation would be a post-hoc fabrication, because
the LLM was not involved in the decision.

- *Option A: Keep LLM reasoning.* Generates natural-language
  explanations but they are not grounded in the actual selection
  process. Adds an LLM call per query.

- *Option B: Record cross-encoder metadata.* The cross-encoder
  already returns the matched concept and score for each selected
  edge. Use this directly as the explanation.

**Decision:** Option B. The cross-encoder metadata is the true
reason the edge was selected. The `kg-edge-reasoning` prompt is
retired.

#### Explainability interface update

The explainability interface uses a `Focus` entity containing
`EdgeSelection` sub-entities. Each `EdgeSelection` currently
carries an `edge` (the quoted triple) and a `reasoning` field
(free-text LLM prose), stored as `tg:reasoning` in the
provenance graph.

With the cross-encoder replacing LLM reasoning, the
`EdgeSelection` type gains two new predicates and drops one:

- **Remove** `tg:reasoning`: no longer produced.
- **Add** `tg:concept`: the concept text that produced the
  highest cross-encoder score for this edge.
- **Add** `tg:score`: the cross-encoder relevance score.

This is an evolution of the existing `EdgeSelection` type, not a
new entity type. The edge selection sub-entities currently have
no `rdf:type` declared; a new `tg:EdgeSelection` type should be
added so that consumers can identify them in the provenance
graph. The `Focus` entity and its relationship to `Exploration`
are unchanged.

The `Focus` entity's token-usage metadata (`tg:inToken`,
`tg:outToken`, `tg:llmModel`) no longer applies since there is
no LLM call. These fields are dropped from the Focus entity.
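
To make the predicate changes concrete, the sketch below builds the new triples for one edge selection entity as plain tuples. It is illustrative only: `edge_selection_triples` is a hypothetical helper, the `https://trustgraph.ai/ns/` prefix is taken from the wire-format examples in the UX developer guidance, the full URI for `tg:EdgeSelection` is assumed to follow the same prefix, and the unchanged `tg:edge` triple is omitted.

```python
from typing import List, Tuple

RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"
TG = "https://trustgraph.ai/ns/"  # prefix as it appears in the wire-format examples

Triple = Tuple[str, str, str]


def edge_selection_triples(uri: str, concept: str, score: float) -> List[Triple]:
    """Build the type, concept, and score triples for one edge selection."""
    return [
        # Assumed URI for the tg:EdgeSelection type.
        (uri, RDF_TYPE, TG + "EdgeSelection"),
        (uri, TG + "concept", concept),
        # Score carried as a string literal; 4 decimal places is an assumption.
        (uri, TG + "score", f"{score:.4f}"),
    ]


for triple in edge_selection_triples("urn:example:edge-sel-1", "flyby event", 0.99618):
    print(triple)
```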

### Proposed pipeline

1. **Concept extraction.** Unchanged: an LLM prompt extracts key
   concepts from the user's query.

2. **Seed entity lookup.** Find seed entities via embedding
   similarity against the extracted concepts.

3. **Iterative hop-and-filter.** For each hop up to
   `max_path_length`:

   a. Retrieve all edges one hop from the current frontier nodes.

   b. Represent each edge as `"{predicate} {object}"`.

   c. Score edges against the extracted concepts using the
      cross-encoder service.

   d. Select the top relevant edges. The target nodes of the
      selected edges become the frontier for the next hop.

4. **Document tracing.** Selected edges are traced back to source
   documents.

5. **Synthesis.** The `kg-synthesis` prompt generates the final
   answer from the selected edges and source document metadata.

### Implementation order

1. Cross-encoder service with full integration (base schema,
   flow service, gateway endpoint, CLI tool).
2. GraphRAG pipeline changes (iterative hop-and-filter,
   edge representation, remove pre-filter).
3. Explainability update (`tg:EdgeSelection` type, concept
   and score predicates, retire `tg:reasoning`).
4. Retire `kg-edge-scoring` and `kg-edge-reasoning` prompts.
5. Update `tg-invoke-graph-rag` and `tg-show-explain-trace`
   to display the new metadata. Use these as the main
   end-to-end test.
6. Fix any failing unit tests, then add new tests as needed.
7. Write guidance for UX devs to update the UI for the new
   explainability predicates.

## UX developer guidance

This section describes the changes to the explainability interface
that affect frontend rendering of GraphRAG Focus events.

### What changed

Edge selection in GraphRAG previously used LLM-based scoring and
reasoning. Each selected edge carried a `tg:reasoning` predicate
with a free-text explanation from the LLM. This has been replaced
by a cross-encoder reranker that scores edges against query
concepts. The explainability data now carries structured metadata
instead of free text.

### Removed

- **`tg:reasoning`** is no longer emitted on edge selection
  entities in GraphRAG Focus events. UX code that reads
  `edge_sel.reasoning` will get an empty string. Remove any
  rendering that displays a "Reasoning" or "Reason" field for
  Focus edges.

- The **`kg-edge-scoring`**, **`kg-edge-reasoning`**, and
  **`kg-edge-selection`** prompts are retired. Any UX that
  references these prompt names should be cleaned up.

### Added

Each edge selection entity within a Focus event now has three
new properties:

| RDF predicate | API field | Type | Description |
|---|---|---|---|
| `rdf:type tg:EdgeSelection` | (type check) | n/a | Each edge selection entity is now explicitly typed |
| `tg:concept` | `edge_sel.concept` | `str` | The query concept that matched this edge |
| `tg:score` | `edge_sel.score` | `float` or `None` | Cross-encoder relevance score (0.0–1.0) |

The `tg:edge` predicate (RDF-star quoted triple) is unchanged.

### How to render

The recommended rendering for each selected edge in a Focus event:

```
Edge: (subject_label, predicate_label, object_label)
  Concept: <concept>  Score: <score formatted to 4 decimal places>
```

Scores near 1.0 indicate high relevance; scores near 0.0 indicate
low relevance. UX could use the score to drive visual indicators
such as colour intensity or a relevance bar.

Edges are not returned in score order; they arrive in traversal
order across hops. If the UX wants to display edges ranked by
relevance, sort by `edge_sel.score` descending.
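
A possible rendering helper is sketched below. It assumes a stand-in `EdgeSel` dataclass carrying only the `edge`, `concept`, and `score` fields, and it prints the raw `s`/`p`/`o` values where a real UI would resolve labels. `render_edge` and `ranked` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class EdgeSel:
    """Stand-in carrying only the fields this sketch needs (assumption)."""
    edge: Optional[Dict[str, str]]
    concept: str = ""
    score: Optional[float] = None


def render_edge(sel: EdgeSel) -> str:
    """Format one selected edge in the recommended two-line layout."""
    edge = sel.edge or {}
    s, p, o = (edge.get(k, "?") for k in ("s", "p", "o"))
    score = "n/a" if sel.score is None else f"{sel.score:.4f}"
    return f"Edge: ({s}, {p}, {o})\n  Concept: {sel.concept}  Score: {score}"


def ranked(selections: List[EdgeSel]) -> List[EdgeSel]:
    """Sort by score descending; edges without a score go last."""
    return sorted(
        selections,
        key=lambda sel: (sel.score is not None, sel.score or 0.0),
        reverse=True,
    )
```

Sorting on a `(has_score, score)` key keeps entries whose score is `None` (for example, older traces) from raising a comparison error.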

### API classes (Python)

The `EdgeSelection` dataclass in `trustgraph.api.explainability`
has these fields:

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class EdgeSelection:
    uri: str
    edge: Optional[Dict[str, str]]  # {"s": ..., "p": ..., "o": ...}
    reasoning: str = ""             # Legacy, always empty for new traces
    concept: str = ""               # Query concept that matched
    score: Optional[float] = None   # Cross-encoder relevance score
```

These are populated when calling
`ExplainabilityClient.fetch_focus_with_edges()` or when parsing
inline provenance triples from the streaming response.

### WebSocket response format

For inline explainability via the streaming WebSocket, Focus events
arrive as `message_type: "explain"` responses. The `explain_triples`
array contains the edge selection triples. The relevant predicates
in wire format are:

```json
{"s": {"t": "i", "i": "<edge_sel_uri>"},
 "p": {"t": "i", "i": "https://trustgraph.ai/ns/concept"},
 "o": {"t": "l", "v": "flyby event"}}

{"s": {"t": "i", "i": "<edge_sel_uri>"},
 "p": {"t": "i", "i": "https://trustgraph.ai/ns/score"},
 "o": {"t": "l", "v": "0.9962"}}
```

Note that `tg:score` is transmitted as a string literal and must
be parsed to a float on the client side.
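
Client-side parsing of these triples might look like the following sketch. `parse_edge_selections` is a hypothetical helper, not part of the TrustGraph API; it assumes each triple has the `s`/`p`/`o` shape shown above and ignores predicates other than `tg:concept` and `tg:score`.

```python
from typing import Any, Dict, List

TG_CONCEPT = "https://trustgraph.ai/ns/concept"
TG_SCORE = "https://trustgraph.ai/ns/score"


def parse_edge_selections(
    explain_triples: List[Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """Group concept and score values by edge selection URI."""
    selections: Dict[str, Dict[str, Any]] = {}
    for triple in explain_triples:
        predicate = triple["p"].get("i")
        if predicate not in (TG_CONCEPT, TG_SCORE):
            continue
        subject = triple["s"].get("i")
        value = triple["o"].get("v")
        if subject is None or value is None:
            continue
        entry = selections.setdefault(subject, {"concept": "", "score": None})
        if predicate == TG_CONCEPT:
            entry["concept"] = value
        else:
            try:
                # The score arrives as a string literal and is parsed here.
                entry["score"] = float(value)
            except ValueError:
                entry["score"] = None  # leave unparseable scores unset
    return selections
```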

### Exploration event

The Exploration event's `edge_count` field now reports the number
of edges selected by the cross-encoder across all hops (previously
it reported the total number of edges retrieved before filtering).
The `entities` list continues to report the seed entities found
by vector search.

@@ -95,10 +95,6 @@ class TestGraphRagIntegration:
         async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
             if prompt_name == "extract-concepts":
                 return PromptResult(response_type="text", text="")
-            elif prompt_name == "kg-edge-scoring":
-                return PromptResult(response_type="text", text="")
-            elif prompt_name == "kg-edge-reasoning":
-                return PromptResult(response_type="text", text="")
             elif prompt_name == "kg-synthesis":
                 return PromptResult(
                     response_type="text",
@@ -113,14 +109,22 @@ class TestGraphRagIntegration:
         client.prompt.side_effect = mock_prompt
         return client

+    @pytest.fixture
+    def mock_reranker_client(self):
+        """Mock reranker client for cross-encoder edge filtering"""
+        client = AsyncMock()
+        client.rerank.return_value = []
+        return client
+
     @pytest.fixture
     def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
-                  mock_triples_client, mock_prompt_client):
+                  mock_triples_client, mock_reranker_client, mock_prompt_client):
         """Create GraphRag instance with mocked dependencies"""
         return GraphRag(
             embeddings_client=mock_embeddings_client,
             graph_embeddings_client=mock_graph_embeddings_client,
             triples_client=mock_triples_client,
+            reranker_client=mock_reranker_client,
             prompt_client=mock_prompt_client,
             verbose=True
         )
@@ -167,8 +171,8 @@ class TestGraphRagIntegration:
         # 3. Should query triples to build knowledge subgraph
         assert mock_triples_client.query_stream.call_count > 0

-        # 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
-        assert mock_prompt_client.prompt.call_count == 4
+        # 4. Should call prompt twice (extract-concepts + synthesis)
+        assert mock_prompt_client.prompt.call_count == 2

         # Verify final response
         response, usage = response

@@ -63,11 +63,6 @@ class TestGraphRagStreaming:
         async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
             if prompt_id == "extract-concepts":
                 return PromptResult(response_type="text", text="")
-            elif prompt_id == "kg-edge-scoring":
-                # Edge scoring returns JSONL with IDs and scores
-                return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
-            elif prompt_id == "kg-edge-reasoning":
-                return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
             elif prompt_id == "kg-synthesis":
                 if streaming and chunk_callback:
                     # Simulate streaming chunks with end_of_stream flags
@@ -88,14 +83,23 @@ class TestGraphRagStreaming:
         client.prompt.side_effect = prompt_side_effect
         return client

+    @pytest.fixture
+    def mock_reranker_client(self):
+        """Mock reranker client for cross-encoder edge filtering"""
+        client = AsyncMock()
+        client.rerank.return_value = []
+        return client
+
     @pytest.fixture
     def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client,
-                            mock_triples_client, mock_streaming_prompt_client):
+                            mock_triples_client, mock_reranker_client,
+                            mock_streaming_prompt_client):
         """Create GraphRag instance with streaming support"""
         return GraphRag(
             embeddings_client=mock_embeddings_client,
             graph_embeddings_client=mock_graph_embeddings_client,
             triples_client=mock_triples_client,
+            reranker_client=mock_reranker_client,
             prompt_client=mock_streaming_prompt_client,
             verbose=True
         )

@@ -46,7 +46,7 @@ class TestGraphRagStreamingProtocol:
         client = AsyncMock()

         async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
-            if prompt_name == "kg-edge-selection":
+            if prompt_name == "extract-concepts":
                 return PromptResult(response_type="text", text="")
             elif prompt_name == "kg-synthesis":
                 if streaming and chunk_callback:
@@ -63,14 +63,23 @@ class TestGraphRagStreamingProtocol:
         client.prompt.side_effect = prompt_side_effect
         return client

+    @pytest.fixture
+    def mock_reranker_client(self):
+        """Mock reranker client for cross-encoder edge filtering"""
+        client = AsyncMock()
+        client.rerank.return_value = []
+        return client
+
     @pytest.fixture
     def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
-                  mock_triples_client, mock_streaming_prompt_client):
+                  mock_triples_client, mock_reranker_client,
+                  mock_streaming_prompt_client):
         """Create GraphRag instance with mocked dependencies"""
         return GraphRag(
             embeddings_client=mock_embeddings_client,
             graph_embeddings_client=mock_graph_embeddings_client,
             triples_client=mock_triples_client,
+            reranker_client=mock_reranker_client,
             prompt_client=mock_streaming_prompt_client,
             verbose=False
         )
@@ -327,7 +336,7 @@ class TestStreamingProtocolEdgeCases:
         client = AsyncMock()

         async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
-            if prompt_name == "kg-edge-selection":
+            if prompt_name == "extract-concepts":
                 return PromptResult(response_type="text", text="")
             elif prompt_name == "kg-synthesis":
                 if streaming and chunk_callback:
@@ -342,10 +351,14 @@ class TestStreamingProtocolEdgeCases:

         client.prompt.side_effect = prompt_with_empties

+        mock_reranker = AsyncMock()
+        mock_reranker.rerank.return_value = []
+
         rag = GraphRag(
             embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
             graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
             triples_client=AsyncMock(query=AsyncMock(return_value=[])),
+            reranker_client=mock_reranker,
             prompt_client=client,
             verbose=False
         )

@@ -195,38 +195,6 @@ class TestPromptClientStreamingCallback:
         assert callback.call_args_list[0] == call("test", False)
         assert callback.call_args_list[1] == call("", True)

-    @pytest.mark.asyncio
-    async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
-        """Test that kg_prompt correctly passes streaming parameters"""
-        # Arrange
-        async def mock_request(request, recipient=None, timeout=600):
-            if recipient:
-                responses = [
-                    PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
-                    PromptResponse(text="", object=None, error=None, end_of_stream=True),
-                ]
-                for resp in responses:
-                    should_stop = await recipient(resp)
-                    if should_stop:
-                        break
-
-        prompt_client.request = mock_request
-
-        callback = AsyncMock()
-
-        # Act
-        await prompt_client.kg_prompt(
-            query="What is machine learning?",
-            kg=[("subject", "predicate", "object")],
-            streaming=True,
-            chunk_callback=callback
-        )
-
-        # Assert
-        assert callback.call_count == 2
-        assert callback.call_args_list[0] == call("Answer", False)
-        assert callback.call_args_list[1] == call("", True)
-
     @pytest.mark.asyncio
     async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
         """Test that document_prompt correctly passes streaming parameters"""

@ -107,6 +107,7 @@ class TestGraphRagDagStructure:
|
|||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
reranker_client = AsyncMock()
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
graph_embeddings_client.query.return_value = [
|
||||
|
|
@ -121,27 +122,22 @@ class TestGraphRagDagStructure:
|
|||
]
|
||||
triples_client.query.return_value = []
|
||||
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.95
|
||||
reranker_client.rerank.return_value = [result]
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="concept")
|
||||
elif template_id == "kg-edge-scoring":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "score": 10} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text="Answer.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
return (prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, reranker_client)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self, mock_clients):
|
||||
|
|
@ -152,7 +148,7 @@ class TestGraphRagDagStructure:
|
|||
events.append({"explain_id": explain_id, "triples": triples})
|
||||
|
||||
await rag.query(
|
||||
query="test", explain_callback=explain_cb, edge_score_limit=0,
|
||||
query="test", explain_callback=explain_cb,
|
||||
)
|
||||
|
||||
dag = _collect_events(events)
|
||||
|
|
|
|||
|
|
@ -15,54 +15,52 @@ class TestGraphRag:
|
|||
|
||||
def test_graph_rag_initialization_with_defaults(self):
|
||||
"""Test GraphRag initialization with default verbose setting"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
mock_reranker_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is False # Default value
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag with verbose=True
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=True
|
||||
reranker_client=mock_reranker_client,
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.reranker_client == mock_reranker_client
|
||||
assert graph_rag.verbose is False
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
mock_reranker_client = MagicMock()
|
||||
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.reranker_client == mock_reranker_client
|
||||
assert graph_rag.verbose is True
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
|
|
@ -365,244 +363,162 @@ class TestQuery:
|
|||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_never_passes_workspace(self):
|
||||
"""Verify follow_edges never passes workspace to query_stream."""
|
||||
async def test_hop_and_filter_never_passes_workspace(self):
|
||||
"""Verify hop_and_filter never passes workspace to query_stream."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
|
||||
mock_triple.s = "e1"
|
||||
mock_triple.p = "p1"
|
||||
mock_triple.o = "o1"
|
||||
mock_triples_client.query_stream.return_value = [mock_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.9
|
||||
mock_reranker_client.rerank.return_value = [result]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
triple_limit=10,
|
||||
)
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("e1", subgraph, path_length=1)
|
||||
await query.hop_and_filter(["e1"], ["concept"])
|
||||
|
||||
for c in mock_triples_client.query_stream.call_args_list:
|
||||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
"""Test Query.follow_edges method basic triple discovery"""
|
||||
async def test_hop_and_filter_basic_functionality(self):
|
||||
"""Test hop_and_filter retrieves edges and scores them with reranker."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
mock_triple1 = MagicMock()
|
||||
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.s = "entity1"
|
||||
mock_triple.p = "predicate1"
|
||||
mock_triple.o = "object1"
|
||||
mock_triples_client.query_stream.return_value = [mock_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
mock_triple2 = MagicMock()
|
||||
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
|
||||
|
||||
mock_triple3 = MagicMock()
|
||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||
|
||||
mock_triples_client.query_stream.side_effect = [
|
||||
[mock_triple1], # s=ent
|
||||
[mock_triple2], # p=ent
|
||||
[mock_triple3], # o=ent
|
||||
]
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.95
|
||||
mock_reranker_client.rerank.return_value = [result]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
triple_limit=10,
|
||||
edge_limit=25,
|
||||
)
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
assert mock_triples_client.query_stream.call_count == 3
|
||||
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter(
|
||||
["entity1"], ["test concept"],
|
||||
)
|
||||
|
||||
expected_subgraph = {
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "entity1", "object2"),
|
||||
("subject3", "predicate3", "entity1")
|
||||
}
|
||||
assert subgraph == expected_subgraph
|
||||
assert len(selected) == 1
|
||||
assert len(uri_map) == 1
|
||||
assert len(edge_meta) == 1
|
||||
|
||||
mock_reranker_client.rerank.assert_called_once()
|
||||
call_kwargs = mock_reranker_client.rerank.call_args
|
||||
assert call_kwargs.kwargs["limit"] == 25
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_path_length_zero(self):
|
||||
"""Test Query.follow_edges method with path_length=0"""
|
||||
async def test_hop_and_filter_with_empty_frontier(self):
|
||||
"""Test hop_and_filter with no seed entities returns empty."""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"])
|
||||
|
||||
assert selected == []
|
||||
assert uri_map == {}
|
||||
assert edge_meta == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hop_and_filter_filters_label_triples(self):
|
||||
"""Test hop_and_filter skips rdfs:label edges."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
label_triple = MagicMock()
|
||||
label_triple.s = "entity1"
|
||||
label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
label_triple.o = "Entity One"
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
assert subgraph == set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_max_subgraph_size_limit(self):
|
||||
"""Test Query.follow_edges method respects max_subgraph_size"""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_triples_client.query_stream.return_value = [label_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=2
|
||||
triple_limit=10,
|
||||
)
|
||||
|
||||
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
|
||||
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
assert len(subgraph) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter(
|
||||
["entity1"], ["concept"],
|
||||
)
|
||||
|
||||
# Mock get_entities to return (entities, concepts) tuple
|
||||
query.get_entities = AsyncMock(
|
||||
return_value=(["entity1", "entity2"], ["concept1"])
|
||||
)
|
||||
|
||||
query.follow_edges_batch = AsyncMock(return_value=(
|
||||
{
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
},
|
||||
{}
|
||||
))
|
||||
|
||||
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
|
||||
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
||||
assert isinstance(subgraph, list)
|
||||
assert len(subgraph) == 2
|
||||
assert ("entity1", "predicate1", "object1") in subgraph
|
||||
assert ("entity2", "predicate2", "object2") in subgraph
|
||||
assert entities == ["entity1", "entity2"]
|
||||
assert concepts == ["concept1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=100
|
||||
)
|
||||
|
||||
test_subgraph = [
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
|
||||
("entity3", "predicate3", "object3")
|
||||
]
|
||||
test_entities = ["entity1", "entity3"]
|
||||
test_concepts = ["concept1"]
|
||||
query.get_subgraph = AsyncMock(
|
||||
return_value=(test_subgraph, {}, test_entities, test_concepts)
|
||||
)
|
||||
|
||||
async def mock_maybe_label(entity):
|
||||
label_map = {
|
||||
"entity1": "Human Entity One",
|
||||
"predicate1": "Human Predicate One",
|
||||
"object1": "Human Object One",
|
||||
"entity3": "Human Entity Three",
|
||||
"predicate3": "Human Predicate Three",
|
||||
"object3": "Human Object Three"
|
||||
}
|
||||
return label_map.get(entity, entity)
|
||||
|
||||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
|
||||
|
||||
query.get_subgraph.assert_called_once_with("test query")
|
||||
|
||||
# Label triples filtered out
|
||||
assert len(labeled_edges) == 2
|
||||
|
||||
# maybe_label called for non-label triples
|
||||
assert query.maybe_label.call_count == 6
|
||||
|
||||
expected_edges = [
|
||||
("Human Entity One", "Human Predicate One", "Human Object One"),
|
||||
("Human Entity Three", "Human Predicate Three", "Human Object Three")
|
||||
]
|
||||
assert labeled_edges == expected_edges
|
||||
|
||||
assert len(uri_map) == 2
|
||||
assert entities == test_entities
|
||||
assert concepts == test_concepts
|
||||
assert selected == []
|
||||
mock_reranker_client.rerank.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_query_method(self):
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
|
||||
import json
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
|
||||
expected_response = "This is the RAG response"
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
test_edge_id = edge_id("Subject", "Predicate", "Object")
|
||||
test_selected_edges = [("Subject", "Predicate", "Object")]
|
||||
test_eid = edge_id("Subject", "Predicate", "Object")
|
||||
test_uri_map = {
|
||||
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
}
|
||||
test_edge_metadata = {
|
||||
test_eid: {"concept": "test concept", "score": 0.95}
|
||||
}
|
||||
test_entities = ["http://example.org/subject"]
|
||||
test_concepts = ["test concept"]
|
||||
|
||||
# Mock prompt responses for the multi-step process
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
mock_graph_embeddings_client.query.return_value = []
|
||||
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
|
||||
return PromptResult(response_type="text", text="test concept")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text=expected_response)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
|
@ -614,16 +530,16 @@ class TestQuery:
|
|||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=False
|
||||
reranker_client=mock_reranker_client,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Patch Query.get_labelgraph to return test data
|
||||
original_get_labelgraph = Query.get_labelgraph
|
||||
original_hop_and_filter = Query.hop_and_filter
|
||||
|
||||
async def mock_get_labelgraph(self, query_text):
|
||||
return test_labelgraph, test_uri_map, test_entities, test_concepts
|
||||
async def mock_hop_and_filter(self, seed_entities, concepts):
|
||||
return test_selected_edges, test_uri_map, test_edge_metadata
|
||||
|
||||
Query.get_labelgraph = mock_get_labelgraph
|
||||
Query.hop_and_filter = mock_hop_and_filter
|
||||
|
||||
provenance_events = []
|
||||
|
||||
|
|
@ -636,7 +552,7 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
entity_limit=25,
|
||||
triple_limit=15,
|
||||
explain_callback=collect_provenance
|
||||
explain_callback=collect_provenance,
|
||||
)
|
||||
|
||||
response_text, usage = response
|
||||
|
|
@ -650,7 +566,6 @@ class TestQuery:
|
|||
assert len(triples) > 0
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
# Verify order
|
||||
assert "question" in provenance_events[0][1]
|
||||
assert "grounding" in provenance_events[1][1]
|
||||
assert "exploration" in provenance_events[2][1]
|
||||
|
|
@ -658,4 +573,4 @@ class TestQuery:
|
|||
assert "synthesis" in provenance_events[4][1]
|
||||
|
||||
finally:
|
||||
Query.get_labelgraph = original_get_labelgraph
|
||||
Query.hop_and_filter = original_hop_and_filter
|
||||
|
|
|
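Review note: the test above swaps `Query.hop_and_filter` by hand and restores it in a `finally` block. A minimal sketch of the same idea using `unittest.mock.patch.object` follows. The `Query` class here is a hypothetical stand-in rather than the real TrustGraph class, and the three-element return value only mirrors the `(selected, uri_map, edge_meta)` shape that the tests unpack.

```python
import asyncio
from unittest.mock import AsyncMock, patch


class Query:
    """Hypothetical stand-in for the real Query class (illustration only)."""

    async def hop_and_filter(self, seed_entities, concepts):
        raise RuntimeError("would query the triple store and the reranker")


async def run_with_patched_hop():
    # Canned (selected_edges, uri_map, edge_metadata) tuple, as unpacked in the tests.
    canned = ([("Subject", "Predicate", "Object")], {}, {})
    # patch.object puts the original attribute back when the block exits,
    # even if the body raises, so no explicit try/finally is required.
    with patch.object(
        Query, "hop_and_filter", new=AsyncMock(return_value=canned)
    ) as mocked:
        result = await Query().hop_and_filter(["e1"], ["concept"])
    return result, mocked


result, mocked = asyncio.run(run_with_patched_hop())
```

Because the replacement is an `AsyncMock`, the call can also be checked with `assert_awaited_once_with`, which the hand-rolled replacement function does not offer.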
|||
|
|
@ -20,7 +20,7 @@ from trustgraph.provenance.namespaces import (
|
|||
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_SCORE, TG_EDGE_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -91,17 +91,17 @@ def build_mock_clients():
|
|||
1. prompt_client.prompt("extract-concepts", ...) -> concepts
|
||||
2. embeddings_client.embed(concepts) -> vectors
|
||||
3. graph_embeddings_client.query(vector, ...) -> entity matches
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (hop_and_filter)
|
||||
5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
|
||||
6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
|
||||
7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
|
||||
8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
9. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
6. reranker_client.rerank(queries, documents, limit) -> scored edges
|
||||
7. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
8. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
reranker_client = AsyncMock()
|
||||
|
||||
# 1. Concept extraction
|
||||
prompt_responses = {}
|
||||
|
|
@ -116,7 +116,7 @@ def build_mock_clients():
|
|||
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
|
||||
]
|
||||
|
||||
# 4. Triple queries (follow_edges_batch) - return our edges
|
||||
# 4. Triple queries (hop_and_filter) - return our edges
|
||||
kg_triples = [
|
||||
make_schema_triple(*EDGE_1),
|
||||
make_schema_triple(*EDGE_2),
|
||||
|
|
@ -130,9 +130,18 @@ def build_mock_clients():
|
|||
return [] # No labels found, will fall back to URI
|
||||
triples_client.query.side_effect = mock_label_query
|
||||
|
||||
# 6+7. Edge scoring and reasoning: dynamically score/reason about
|
||||
# whatever edges the query method sends us, since edge IDs are computed
|
||||
# from str(Term) representations which include the full dataclass repr.
|
||||
# 6. Reranker: select all documents with high scores
|
||||
async def mock_rerank(queries, documents, limit):
|
||||
results = []
|
||||
for i, doc in enumerate(documents):
|
||||
result = MagicMock()
|
||||
result.document_id = doc["id"]
|
||||
result.query_id = queries[0]["id"] if queries else "0"
|
||||
result.score = 0.9 - (i * 0.1)
|
||||
results.append(result)
|
||||
return results[:limit]
|
||||
reranker_client.rerank.side_effect = mock_rerank
|
||||
|
||||
synthesis_answer = "Quantum computing applies physics principles to computation."
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
|
|
@ -141,26 +150,6 @@ def build_mock_clients():
|
|||
response_type="text",
|
||||
text=prompt_responses["extract-concepts"],
|
||||
)
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
|
|
@ -170,7 +159,8 @@ def build_mock_clients():
|
|||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
return (prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, reranker_client)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -197,7 +187,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0, # skip semantic pre-filter for simplicity
|
||||
|
||||
)
|
||||
|
||||
assert len(events) == 5, (
|
||||
|
|
@ -222,7 +212,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
expected_types = [
|
||||
|
|
@ -260,7 +250,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
|
|
@ -297,7 +287,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
|
|
@ -320,7 +310,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
gnd_uri = events[1]["explain_id"]
|
||||
|
|
@ -344,7 +334,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
|
|
@ -355,10 +345,10 @@ class TestGraphRagQueryProvenance:
|
|||
assert int(t.o.value) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_focus_has_selected_edges_with_reasoning(self):
|
||||
async def test_focus_has_selected_edges_with_concept_and_score(self):
|
||||
"""
|
||||
The focus event should carry selected edges as quoted triples
|
||||
with reasoning text.
|
||||
with cross-encoder concept and score metadata.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
|
@ -371,7 +361,6 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
foc_uri = events[3]["explain_id"]
|
||||
|
|
@ -387,11 +376,19 @@ class TestGraphRagQueryProvenance:
|
|||
for t in edge_t:
|
||||
assert t.o.triple is not None, "tg:edge object must be a quoted triple"
|
||||
|
||||
# Should have reasoning
|
||||
reasoning = find_triples(foc_triples, TG_REASONING)
|
||||
assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
|
||||
reasoning_texts = {t.o.value for t in reasoning}
|
||||
assert any(r for r in reasoning_texts), "Reasoning should not be empty"
|
||||
# Edge selections should be typed as EdgeSelection
|
||||
edge_sel_uris = [t.o.iri for t in selected]
|
||||
for uri in edge_sel_uris:
|
||||
assert has_type(foc_triples, uri, TG_EDGE_SELECTION)
|
||||
|
||||
# Should have concept and score
|
||||
concepts = find_triples(foc_triples, TG_CONCEPT)
|
||||
assert len(concepts) > 0, "Focus should have tg:concept for selected edges"
|
||||
|
||||
scores = find_triples(foc_triples, TG_SCORE)
|
||||
assert len(scores) > 0, "Focus should have tg:score for selected edges"
|
||||
for t in scores:
|
||||
float(t.o.value) # Raises ValueError (failing the test) if the score is not numeric
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_is_answer_type(self):
|
||||
|
|
@ -407,7 +404,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
syn_uri = events[4]["explain_id"]
|
||||
|
|
@ -429,7 +426,7 @@ class TestGraphRagQueryProvenance:
|
|||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
|
@ -449,7 +446,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
parent_uri=parent,
|
||||
)
|
||||
|
||||
|
|
@ -465,7 +462,7 @@ class TestGraphRagQueryProvenance:
|
|||
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
|
@ -484,7 +481,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
for event in events:
|
||||
|
|
|
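Review note: the reranker mock in `build_mock_clients()` can be exercised on its own. The sketch below is a standalone variant that uses `SimpleNamespace` in place of `MagicMock` for the result objects. The `"text"` field on queries and documents is an assumption for illustration; the diff only shows that each item carries an `"id"`.

```python
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock


async def fake_rerank(queries, documents, limit):
    """Score every document against the first query, best first, then truncate."""
    query_id = queries[0]["id"] if queries else "0"
    results = [
        SimpleNamespace(
            document_id=doc["id"],
            query_id=query_id,
            score=round(0.9 - i * 0.1, 2),  # descending scores, as in the mock above
        )
        for i, doc in enumerate(documents)
    ]
    return results[:limit]


reranker_client = AsyncMock()
reranker_client.rerank.side_effect = fake_rerank

ranked = asyncio.run(
    reranker_client.rerank(
        queries=[{"id": "q0", "text": "quantum computing"}],  # "text" is assumed
        documents=[{"id": str(i), "text": f"edge {i}"} for i in range(4)],
        limit=2,
    )
)
```

Plain attribute containers make an accidental misspelling such as `result.doc_id` fail with `AttributeError`, whereas a `MagicMock` would silently return another mock.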
|||
|
|
@ -646,6 +646,16 @@ class AsyncFlowInstance:
|
|||
|
||||
return await self.request("embeddings", request_data)
|
||||
|
||||
async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any):
|
||||
request_data = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request_data.update(kwargs)
|
||||
|
||||
return await self.request("reranker", request_data)
|
||||
|
||||
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any):
|
||||
"""
|
||||
Query RDF triples using pattern matching.
|
||||
|
|
|
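Review note: the payload built by `rerank()` can be illustrated without a running gateway. The helper below is a hypothetical, standalone copy of the dictionary construction shown in the diff. The `id`/`text` item shape and the `model` keyword are assumptions used only to show how extra keyword arguments are merged.

```python
from typing import Any


def build_rerank_request(
    queries: list, documents: list, limit: int = 10, **kwargs: Any
) -> dict:
    """Build the request body that rerank() passes to the "reranker" service."""
    request_data = {
        "queries": queries,
        "documents": documents,
        "limit": limit,
    }
    # Extra keyword arguments are merged last and appear as additional fields.
    request_data.update(kwargs)
    return request_data


payload = build_rerank_request(
    queries=[{"id": "q0", "text": "What is quantum computing?"}],
    documents=[{"id": "0", "text": "quantum computing applies physics"}],
    limit=25,
    model="hypothetical-model",  # assumed extra field, not taken from the diff
)
```

Since `queries`, `documents` and `limit` are named parameters, `**kwargs` can only add fields to the request; it cannot replace those three.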
|||
|
|
@ -443,6 +443,19 @@ class AsyncSocketFlowInstance:
|
|||
|
||||
return await self.client._send_request("embeddings", self.flow_id, request)
|
||||
|
||||
async def rerank(self, queries: list, documents: list, limit: int = 10,
|
||||
**kwargs):
|
||||
request = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request(
|
||||
"reranker", self.flow_id, request,
|
||||
)
|
||||
|
||||
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs):
|
||||
"""Triple pattern query"""
|
||||
request = {"limit": limit}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount"
|
|||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_SCORE = TG + "score"
|
||||
TG_DOCUMENT = TG + "document"
|
||||
TG_CONCEPT = TG + "concept"
|
||||
TG_ENTITY = TG + "entity"
|
||||
|
|
@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
|||
|
||||
@dataclass
|
||||
class EdgeSelection:
|
||||
"""A selected edge with reasoning from GraphRAG Focus step."""
|
||||
"""A selected edge with cross-encoder metadata from GraphRAG Focus step."""
|
||||
uri: str
|
||||
edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...}
|
||||
reasoning: str = ""
|
||||
concept: str = ""
|
||||
score: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -209,7 +212,7 @@ class Exploration(ExplainEntity):
|
|||
|
||||
@dataclass
|
||||
class Focus(ExplainEntity):
|
||||
"""Focus entity - selected edges with LLM reasoning (GraphRAG only)."""
|
||||
"""Focus entity - selected edges with cross-encoder scoring (GraphRAG only)."""
|
||||
selected_edge_uris: List[str] = field(default_factory=list)
|
||||
edge_selections: List[EdgeSelection] = field(default_factory=list)
|
||||
|
||||
|
|
@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel
|
|||
uri = triples[0][0] if triples else ""
|
||||
edge = None
|
||||
reasoning = ""
|
||||
concept = ""
|
||||
score = None
|
||||
|
||||
for s, p, o in triples:
|
||||
if p == TG_EDGE and isinstance(o, dict):
|
||||
edge = o
|
||||
elif p == TG_REASONING:
|
||||
reasoning = o
|
||||
elif p == TG_CONCEPT:
|
||||
concept = o
|
||||
elif p == TG_SCORE:
|
||||
try:
|
||||
score = float(o)
|
||||
except (ValueError, TypeError):
|
||||
score = None
|
||||
|
||||
return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning)
|
||||
return EdgeSelection(
|
||||
uri=uri, edge=edge, reasoning=reasoning,
|
||||
concept=concept, score=score,
|
||||
)
|
||||
|
||||
|
||||
def extract_term_value(term: Dict[str, Any]) -> Any:
|
||||
|
|
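The score handling in `parse_edge_selection_triples` above can be sketched in isolation. This is a minimal illustration, not the project's code: `parse_score` is a hypothetical helper name, and it assumes the score arrives as a string literal, as the provenance writer elsewhere in this commit stores it with `str(score)`.

```python
from typing import Any, Optional


def parse_score(value: Any) -> Optional[float]:
    """Convert a stored score literal to a float, or None if unparseable.

    Hypothetical helper: mirrors the try/except around float(o) in the diff.
    """
    try:
        return float(value)
    except (ValueError, TypeError):
        # ValueError: non-numeric string; TypeError: None or another non-scalar
        return None


# A score written as str(score) reads back as the same float.
stored = str(0.8731)
recovered = parse_score(stored)
missing = parse_score(None)
garbled = parse_score("n/a")
```

Returning `None` rather than `0.0` keeps "no score recorded" distinguishable from a genuinely low score, which matters to callers that print `?` for a missing value.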
|
|||
|
|
@ -491,6 +491,19 @@ class FlowInstance:
|
|||
input
|
||||
)["vectors"]
|
||||
|
||||
def rerank(self, queries, documents, limit=10):
|
||||
|
||||
input = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
return self.request(
|
||||
"service/reranker",
|
||||
input
|
||||
)
|
||||
|
||||
def graph_embeddings_query(self, text, collection, limit=10):
|
||||
"""
|
||||
Query knowledge graph entities using semantic similarity.
|
||||
|
|
|
|||
|
|
@ -885,6 +885,19 @@ class SocketFlowInstance:
|
|||
|
||||
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
|
||||
|
||||
def rerank(self, queries: list, documents: list, limit: int = 10,
|
||||
**kwargs: Any) -> Dict[str, Any]:
|
||||
request = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync(
|
||||
"reranker", self.flow_id, request, False,
|
||||
)
|
||||
|
||||
def triples_query(
|
||||
self,
|
||||
s: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService
|
|||
from . tool_service_client import ToolServiceClientSpec
|
||||
from . agent_client import AgentClientSpec
|
||||
from . structured_query_client import StructuredQueryClientSpec
|
||||
from . reranker_client import RerankerClientSpec
|
||||
from . reranker_service import RerankerService
|
||||
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
|
||||
from . collection_config_handler import CollectionConfigHandler
|
||||
|
||||
|
|
|
|||
|
|
@ -157,21 +157,6 @@ class PromptClient(RequestResponse):
|
|||
timeout = timeout,
|
||||
)
|
||||
|
||||
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "kg-prompt",
|
||||
variables = {
|
||||
"query": query,
|
||||
"knowledge": [
|
||||
{ "s": v[0], "p": v[1], "o": v[2] }
|
||||
for v in kg
|
||||
]
|
||||
},
|
||||
timeout = timeout,
|
||||
streaming = streaming,
|
||||
chunk_callback = chunk_callback,
|
||||
)
|
||||
|
||||
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "document-prompt",
|
||||
|
|
|
|||
43
trustgraph-base/trustgraph/base/reranker_client.py
Normal file
|
|
@ -0,0 +1,43 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import (
    RerankerRequest, RerankerResponse,
    RerankerQuery, RerankerDocument,
)

class RerankerClient(RequestResponse):
    async def rerank(self, queries, documents, limit=10, timeout=300):

        resp = await self.request(
            RerankerRequest(
                queries=[
                    RerankerQuery(query_id=q["id"], query_text=q["text"])
                    for q in queries
                ],
                documents=[
                    RerankerDocument(
                        document_id=d["id"], document_text=d["text"]
                    )
                    for d in documents
                ],
                limit=limit,
            ),
            timeout=timeout
        )

        if resp.error:
            raise RuntimeError(resp.error.message)

        return resp.results

class RerankerClientSpec(RequestResponseSpec):
    def __init__(
            self, request_name, response_name,
    ):
        super(RerankerClientSpec, self).__init__(
            request_name = request_name,
            request_schema = RerankerRequest,
            response_name = response_name,
            response_schema = RerankerResponse,
            impl = RerankerClient,
        )

109
trustgraph-base/trustgraph/base/reranker_service.py
Normal file
|
|
@ -0,0 +1,109 @@
from __future__ import annotations

from argparse import ArgumentParser

import logging

from .. schema import (
    RerankerRequest, RerankerResponse, RerankerResult, Error,
)
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec

logger = logging.getLogger(__name__)

default_ident = "reranker"
default_concurrency = 1

class RerankerService(FlowProcessor):

    def __init__(self, **params):

        id = params.get("id")
        concurrency = params.get("concurrency", 1)

        super(RerankerService, self).__init__(**params | {
            "id": id,
            "concurrency": concurrency,
        })

        self.register_specification(
            ConsumerSpec(
                name = "request",
                schema = RerankerRequest,
                handler = self.on_request,
                concurrency = concurrency,
            )
        )

        self.register_specification(
            ProducerSpec(
                name = "response",
                schema = RerankerResponse
            )
        )

        self.register_specification(
            ParameterSpec(
                name = "model",
            )
        )

    async def on_request(self, msg, consumer, flow):

        try:

            request = msg.value()

            id = msg.properties()["id"]

            logger.debug(f"Handling reranker request {id}...")

            model = flow("model")
            results = await self.on_rerank(
                request.queries, request.documents,
                request.limit, model=model,
            )

            await flow("response").send(
                RerankerResponse(
                    error = None,
                    results = results,
                ),
                properties={"id": id}
            )

            logger.debug("Reranker request handled successfully")

        except TooManyRequests as e:
            raise e

        except Exception as e:

            logger.error(f"Exception in reranker service: {e}", exc_info=True)

            logger.info("Sending error response...")

            await flow.producer["response"].send(
                RerankerResponse(
                    error=Error(
                        type = "reranker-error",
                        message = str(e),
                    ),
                    results=[],
                ),
                properties={"id": id}
            )

    @staticmethod
    def add_args(parser: ArgumentParser) -> None:

        parser.add_argument(
            '-c', '--concurrency',
            type=int,
            default=default_concurrency,
            help=f'Concurrent processing threads (default: {default_concurrency})'
        )

        FlowProcessor.add_args(parser)

|
|
@ -140,20 +140,6 @@ class PromptClient(BaseClient):
|
|||
timeout=timeout
|
||||
)
|
||||
|
||||
def request_kg_prompt(self, query, kg, timeout=300):
|
||||
|
||||
return self.request(
|
||||
id="kg-prompt",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": [
|
||||
{ "s": v[0], "p": v[1], "o": v[2] }
|
||||
for v in kg
|
||||
]
|
||||
},
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
def request_document_prompt(self, query, documents, timeout=300):
|
||||
|
||||
return self.request(
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons
|
|||
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
||||
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
|
||||
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator
|
||||
from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator
|
||||
from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator
|
||||
|
||||
|
|
@ -163,6 +164,12 @@ TranslatorRegistry.register_service(
|
|||
SparqlQueryResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"reranker",
|
||||
RerankerRequestTranslator(),
|
||||
RerankerResponseTranslator()
|
||||
)
|
||||
|
||||
# Register single-direction translators for document loading
|
||||
TranslatorRegistry.register_request("document", DocumentTranslator())
|
||||
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())
|
||||
|
|
|
|||
|
|
@ -20,3 +20,4 @@ from .embeddings_query import (
|
|||
)
|
||||
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
from .reranker import RerankerRequestTranslator, RerankerResponseTranslator
|
||||
|
|
|
|||
73
trustgraph-base/trustgraph/messaging/translators/reranker.py
Normal file
|
|
@ -0,0 +1,73 @@
from typing import Dict, Any, Tuple
from ...schema import (
    RerankerRequest, RerankerResponse,
    RerankerQuery, RerankerDocument, RerankerResult,
)
from .base import MessageTranslator


class RerankerRequestTranslator(MessageTranslator):

    def decode(self, data: Dict[str, Any]) -> RerankerRequest:
        return RerankerRequest(
            queries=[
                RerankerQuery(
                    query_id=q["query_id"],
                    query_text=q["query_text"],
                )
                for q in data.get("queries", [])
            ],
            documents=[
                RerankerDocument(
                    document_id=d["document_id"],
                    document_text=d["document_text"],
                )
                for d in data.get("documents", [])
            ],
            limit=data.get("limit", 10),
        )

    def encode(self, obj: RerankerRequest) -> Dict[str, Any]:
        return {
            "queries": [
                {"query_id": q.query_id, "query_text": q.query_text}
                for q in obj.queries
            ],
            "documents": [
                {"document_id": d.document_id, "document_text": d.document_text}
                for d in obj.documents
            ],
            "limit": obj.limit,
        }


class RerankerResponseTranslator(MessageTranslator):

    def decode(self, data: Dict[str, Any]) -> RerankerResponse:
        return RerankerResponse(
            results=[
                RerankerResult(
                    document_id=r["document_id"],
                    query_id=r["query_id"],
                    score=r["score"],
                )
                for r in data.get("results", [])
            ],
        )

    def encode(self, obj: RerankerResponse) -> Dict[str, Any]:
        return {
            "results": [
                {
                    "document_id": r.document_id,
                    "query_id": r.query_id,
                    "score": r.score,
                }
                for r in obj.results
            ],
        }

    def encode_with_completion(
        self, obj: RerankerResponse
    ) -> Tuple[Dict[str, Any], bool]:
        return self.encode(obj), True

|
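To make the wire format handled by the request translator above concrete, here is a small self-contained sketch of the decode/encode round trip. It assumes local stand-in dataclasses whose field names are copied from the schema added in this commit; `decode_request` is a hypothetical free function used only for illustration, and `dataclasses.asdict` stands in for the translator's hand-written `encode`.

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List


# Stand-ins for the schema classes; field names follow the commit's schema.
@dataclass
class RerankerQuery:
    query_id: str = ""
    query_text: str = ""


@dataclass
class RerankerDocument:
    document_id: str = ""
    document_text: str = ""


@dataclass
class RerankerRequest:
    queries: List[RerankerQuery] = field(default_factory=list)
    documents: List[RerankerDocument] = field(default_factory=list)
    limit: int = 10


def decode_request(data: Dict[str, Any]) -> RerankerRequest:
    """Build a request object from a JSON-style dict (hypothetical helper)."""
    return RerankerRequest(
        queries=[RerankerQuery(**q) for q in data.get("queries", [])],
        documents=[RerankerDocument(**d) for d in data.get("documents", [])],
        limit=data.get("limit", 10),  # same default as the translator
    )


payload = {
    "queries": [{"query_id": "0", "query_text": "What is quantum computing?"}],
    "documents": [
        {"document_id": "0", "document_text": "Qubits encode superpositions."},
        {"document_id": "1", "document_text": "Pasta is boiled in water."},
    ],
    "limit": 5,
}

request = decode_request(payload)
round_tripped = asdict(request)  # nested dataclasses become nested dicts
```

A payload that omits `limit` decodes to a request with `limit == 10`, matching the default used by the translator and the client methods.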
|
@ -89,7 +89,9 @@ from . namespaces import (
|
|||
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE,
|
||||
# Edge selection entity type
|
||||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Explainability entity types
|
||||
|
|
@ -212,7 +214,9 @@ __all__ = [
|
|||
"TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE",
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
|
||||
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
|
||||
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE",
|
||||
# Edge selection entity type
|
||||
"TG_EDGE_SELECTION",
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
|
||||
# Explainability entity types
|
||||
|
|
|
|||
|
|
@ -66,8 +66,12 @@ TG_EDGE_COUNT = TG + "edgeCount"
|
|||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_SCORE = TG + "score"
|
||||
TG_DOCUMENT = TG + "document" # Reference to document in librarian
|
||||
|
||||
# Edge selection entity type (cross-encoder scored edge in Focus)
|
||||
TG_EDGE_SELECTION = TG + "EdgeSelection"
|
||||
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT = TG + "chunkCount"
|
||||
TG_SELECTED_CHUNK = TG + "selectedChunk"
|
||||
|
|
|
|||
|
|
@ -24,8 +24,10 @@ from . namespaces import (
|
|||
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE,
|
||||
TG_DOCUMENT,
|
||||
# Edge selection entity type
|
||||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Explainability entity types
|
||||
|
|
@ -536,10 +538,9 @@ def focus_triples(
|
|||
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
|
||||
]
|
||||
|
||||
# Add each selected edge with its reasoning via intermediate entity
|
||||
# Add each selected edge with metadata via intermediate entity
|
||||
for idx, edge_info in enumerate(selected_edges_with_reasoning):
|
||||
edge = edge_info.get("edge")
|
||||
reasoning = edge_info.get("reasoning", "")
|
||||
|
||||
if edge:
|
||||
s, p, o = edge
|
||||
|
|
@ -552,13 +553,32 @@ def focus_triples(
|
|||
_triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri))
|
||||
)
|
||||
|
||||
# Type the edge selection entity
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION))
|
||||
)
|
||||
|
||||
# Attach quoted triple to edge selection entity
|
||||
quoted = _quoted_triple(s, p, o)
|
||||
triples.append(
|
||||
Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted)
|
||||
)
|
||||
|
||||
# Attach reasoning to edge selection entity
|
||||
# Structured cross-encoder metadata
|
||||
concept = edge_info.get("concept")
|
||||
if concept:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_CONCEPT, _literal(concept))
|
||||
)
|
||||
|
||||
score = edge_info.get("score")
|
||||
if score is not None:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_SCORE, _literal(str(score)))
|
||||
)
|
||||
|
||||
# Legacy reasoning text (for non-cross-encoder callers)
|
||||
reasoning = edge_info.get("reasoning", "")
|
||||
if reasoning:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning))
|
||||
|
|
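The metadata-emission rules in the `focus_triples` hunk above (concept only when non-empty, score whenever it is not `None`, legacy reasoning only when non-empty) can be sketched with plain tuples. This is an illustration under assumptions: `edge_selection_metadata` is a hypothetical name, triples are simplified to `(subject, predicate, object)` tuples, and the compact `tg:` predicate labels stand in for the full namespace IRIs, which this excerpt does not show.

```python
from typing import Any, Dict, List, Tuple

Triple = Tuple[str, str, str]


def edge_selection_metadata(edge_sel_uri: str, edge_info: Dict[str, Any]) -> List[Triple]:
    """Return the metadata triples for one selected edge (simplified sketch)."""
    triples: List[Triple] = []

    concept = edge_info.get("concept")
    if concept:
        triples.append((edge_sel_uri, "tg:concept", concept))

    score = edge_info.get("score")
    # Explicit None check: a score of 0.0 is falsy but still a real score.
    if score is not None:
        triples.append((edge_sel_uri, "tg:score", str(score)))

    reasoning = edge_info.get("reasoning", "")
    if reasoning:
        triples.append((edge_sel_uri, "tg:reasoning", reasoning))

    return triples


scored = edge_selection_metadata("urn:example:edge/0", {"concept": "qubit", "score": 0.0})
empty = edge_selection_metadata("urn:example:edge/1", {})
```

The `is not None` test is the design point worth noting: a truthiness check would silently drop edges that the cross-encoder scored at exactly zero.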
|
|||
|
|
@ -29,6 +29,7 @@ from . namespaces import (
|
|||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
TG_EDGE_SELECTION, TG_SCORE,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -93,6 +94,7 @@ TG_CLASS_LABELS = [
|
|||
_label_triple(TG_FINDING, "Finding"),
|
||||
_label_triple(TG_PLAN_TYPE, "Plan"),
|
||||
_label_triple(TG_STEP_RESULT, "Step Result"),
|
||||
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
|
||||
]
|
||||
|
||||
# TrustGraph predicate labels
|
||||
|
|
@ -117,6 +119,7 @@ TG_PREDICATE_LABELS = [
|
|||
_label_triple(TG_ENTITY, "entity"),
|
||||
_label_triple(TG_SUBAGENT_GOAL, "subagent goal"),
|
||||
_label_triple(TG_PLAN_STEP, "plan step"),
|
||||
_label_triple(TG_SCORE, "score"),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,4 +15,5 @@ from .diagnosis import *
|
|||
from .collection import *
|
||||
from .storage import *
|
||||
from .tool_service import *
|
||||
from .sparql_query import *
|
||||
from .sparql_query import *
|
||||
from .reranker import *
|
||||
|
|
@ -6,17 +6,6 @@ from ..core.primitives import Error
|
|||
|
||||
# Prompt services, abstract the prompt generation
|
||||
|
||||
# extract-definitions:
|
||||
# chunk -> definitions
|
||||
# extract-relationships:
|
||||
# chunk -> relationships
|
||||
# kg-prompt:
|
||||
# query, triples -> answer
|
||||
# document-prompt:
|
||||
# query, documents -> answer
|
||||
# extract-rows
|
||||
# schema, chunk -> rows
|
||||
|
||||
@dataclass
|
||||
class PromptRequest:
|
||||
id: str = ""
|
||||
|
|
@ -46,4 +35,4 @@ class PromptResponse:
|
|||
out_token: int | None = None
|
||||
model: str | None = None
|
||||
|
||||
############################################################################
|
||||
############################################################################
|
||||
|
|
|
|||
35
trustgraph-base/trustgraph/schema/services/reranker.py
Normal file
|
|
@ -0,0 +1,35 @@
from dataclasses import dataclass, field

from ..core.primitives import Error

############################################################################

# Cross-encoder reranker

@dataclass
class RerankerQuery:
    query_id: str = ""
    query_text: str = ""

@dataclass
class RerankerDocument:
    document_id: str = ""
    document_text: str = ""

@dataclass
class RerankerRequest:
    queries: list[RerankerQuery] = field(default_factory=list)
    documents: list[RerankerDocument] = field(default_factory=list)
    limit: int = 10

@dataclass
class RerankerResult:
    document_id: str = ""
    query_id: str = ""
    score: float = 0.0

@dataclass
class RerankerResponse:
    error: Error | None = None
    results: list[RerankerResult] = field(default_factory=list)

|
|
@ -71,6 +71,7 @@ tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
|
|||
tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main"
|
||||
tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main"
|
||||
tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
|
||||
tg-invoke-reranker = "trustgraph.cli.invoke_reranker:main"
|
||||
tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
|
||||
tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"
|
||||
tg-load-kg-core = "trustgraph.cli.load_kg_core:main"
|
||||
|
|
|
|||
|
|
@ -112,14 +112,13 @@ def _question_explainable_api(
|
|||
if focus_full and focus_full.edge_selections:
|
||||
for edge_sel in focus_full.edge_selections:
|
||||
if edge_sel.edge:
|
||||
# Resolve labels for edge components
|
||||
s_label, p_label, o_label = explain_client.resolve_edge_labels(
|
||||
edge_sel.edge, collection
|
||||
)
|
||||
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
|
||||
if edge_sel.reasoning:
|
||||
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
|
||||
print(f" Reason: {r_short}", file=sys.stderr)
|
||||
if edge_sel.concept or edge_sel.score is not None:
|
||||
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
|
||||
print(f" Concept: {edge_sel.concept} Score: {score_str}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Synthesis):
|
||||
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
|
||||
|
|
|
|||
127
trustgraph-cli/trustgraph/cli/invoke_reranker.py
Normal file
|
|
@ -0,0 +1,127 @@
"""
Invokes the reranker service to score and rank documents by relevance
to one or more queries.
"""

import argparse
import json
import os
from trustgraph.api import Api

default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")

def query(url, flow_id, queries, documents, limit, token=None,
          workspace="default"):

    api = Api(url=url, token=token, workspace=workspace)
    socket = api.socket()
    flow = socket.flow(flow_id)

    try:

        query_objects = [
            {"query_id": str(i), "query_text": q}
            for i, q in enumerate(queries)
        ]

        document_objects = [
            {"document_id": str(i), "document_text": d}
            for i, d in enumerate(documents)
        ]

        result = flow.rerank(
            queries=query_objects,
            documents=document_objects,
            limit=limit,
        )

        if "error" in result and result["error"]:
            err = result["error"]
            print(f"Error: [{err.get('type', '')}] {err.get('message', '')}")
            return

        for r in result.get("results", []):
            doc_idx = int(r["document_id"])
            query_idx = int(r["query_id"])
            print(
                f"  {r['score']:.4f} | "
                f"query: {queries[query_idx]} | "
                f"doc: {documents[doc_idx]}"
            )

    finally:
        socket.close()

def main():

    parser = argparse.ArgumentParser(
        prog='tg-invoke-reranker',
        description=__doc__,
    )

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'API URL (default: {default_url})',
    )

    parser.add_argument(
        '-t', '--token',
        default=default_token,
        help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
    )

    parser.add_argument(
        '-w', '--workspace',
        default=default_workspace,
        help=f'Workspace (default: {default_workspace})',
    )

    parser.add_argument(
        '-f', '--flow-id',
        default="default",
        help=f'Flow ID (default: default)'
    )

    parser.add_argument(
        '-l', '--limit',
        type=int,
        default=10,
        help='Maximum number of results (default: 10)',
    )

    parser.add_argument(
        '-q', '--query',
        action='append',
        required=True,
        help='Query text (can be specified multiple times)',
    )

    parser.add_argument(
        'documents',
        nargs='+',
        help='Documents to rerank',
    )

    args = parser.parse_args()

    try:

        query(
            url=args.url,
            flow_id=args.flow_id,
            queries=args.query,
            documents=args.documents,
            limit=args.limit,
            token=args.token,
            workspace=args.workspace,
        )

    except Exception as e:

        print("Exception:", e, flush=True)

if __name__ == "__main__":
    main()

|
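The CLI above assigns positional string IDs to queries and documents and later maps result IDs back to the original text. The sketch below isolates that bookkeeping without contacting any service. `build_rerank_payload` and `format_results` are hypothetical helper names, and the sample response is hand-written in the response shape shown in this commit rather than real reranker output.

```python
from typing import Any, Dict, List


def build_rerank_payload(queries: List[str], documents: List[str],
                         limit: int = 10) -> Dict[str, Any]:
    """Assign positional string IDs, as the CLI does with enumerate()."""
    return {
        "queries": [
            {"query_id": str(i), "query_text": q}
            for i, q in enumerate(queries)
        ],
        "documents": [
            {"document_id": str(i), "document_text": d}
            for i, d in enumerate(documents)
        ],
        "limit": limit,
    }


def format_results(response: Dict[str, Any], queries: List[str],
                   documents: List[str]) -> List[str]:
    """Map result IDs back to the original query and document text."""
    lines = []
    for r in response.get("results", []):
        query_text = queries[int(r["query_id"])]
        doc_text = documents[int(r["document_id"])]
        lines.append(f"{r['score']:.4f} | query: {query_text} | doc: {doc_text}")
    return lines


queries = ["What is quantum computing?"]
documents = ["Qubits encode superpositions.", "Pasta is boiled in water."]
payload = build_rerank_payload(queries, documents, limit=5)

# Hand-written example response, not real service output.
sample_response = {
    "results": [{"document_id": "0", "query_id": "0", "score": 0.5}],
}
report = format_results(sample_response, queries, documents)
```

Because the IDs are list indices rendered as strings, the mapping back is a plain `int()` conversion; a caller with its own stable identifiers would need a lookup table instead.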
|
@ -203,9 +203,9 @@ def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_
|
|||
)
|
||||
print(f" {i}. ({s_label}, {p_label}, {o_label})")
|
||||
|
||||
if edge_sel.reasoning:
|
||||
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
|
||||
print(f" Reasoning: {r_short}")
|
||||
if edge_sel.concept or edge_sel.score is not None:
|
||||
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
|
||||
print(f" Concept: {edge_sel.concept} Score: {score_str}")
|
||||
|
||||
if show_provenance and edge_sel.edge:
|
||||
provenance = trace_edge_provenance(
|
||||
|
|
@ -519,7 +519,8 @@ def trace_to_dict(trace, trace_type):
|
|||
"selected_edges": [
|
||||
{
|
||||
"edge": edge_sel.edge,
|
||||
"reasoning": edge_sel.reasoning,
|
||||
"concept": edge_sel.concept,
|
||||
"score": edge_sel.score,
|
||||
}
|
||||
for edge_sel in focus.edge_selections
|
||||
],
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ dependencies = [
|
|||
"faiss-cpu",
|
||||
"falkordb",
|
||||
"fastembed",
|
||||
"flashrank",
|
||||
"ibis",
|
||||
"jsonschema",
|
||||
"langchain",
|
||||
|
|
@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:
|
|||
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
|
||||
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
|
||||
graph-rag = "trustgraph.retrieval.graph_rag:run"
|
||||
reranker-flashrank = "trustgraph.reranker.flashrank:run"
|
||||
kg-extract-agent = "trustgraph.extract.kg.agent:run"
|
||||
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
|
||||
kg-extract-rows = "trustgraph.extract.kg.rows:run"
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
|||
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
|
||||
from . row_embeddings_query import RowEmbeddingsQueryRequestor
|
||||
from . mcp_tool import McpToolRequestor
|
||||
from . reranker import RerankerRequestor
|
||||
from . text_load import TextLoad
|
||||
from . document_load import DocumentLoad
|
||||
|
||||
|
|
@ -74,6 +75,7 @@ request_response_dispatchers = {
|
|||
"structured-diag": StructuredDiagRequestor,
|
||||
"row-embeddings": RowEmbeddingsQueryRequestor,
|
||||
"sparql": SparqlQueryRequestor,
|
||||
"reranker": RerankerRequestor,
|
||||
}
|
||||
|
||||
system_dispatchers = {
|
||||
|
|
|
|||
31
trustgraph-flow/trustgraph/gateway/dispatch/reranker.py
Normal file
|
|
@ -0,0 +1,31 @@
from ... schema import RerankerRequest, RerankerResponse
from ... messaging import TranslatorRegistry

from . requestor import ServiceRequestor

class RerankerRequestor(ServiceRequestor):
    def __init__(
            self, backend, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):

        super(RerankerRequestor, self).__init__(
            backend=backend,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=RerankerRequest,
            response_schema=RerankerResponse,
            subscription = subscriber,
            consumer_name = consumer,
            timeout=timeout,
        )

        self.request_translator = TranslatorRegistry.get_request_translator("reranker")
        self.response_translator = TranslatorRegistry.get_response_translator("reranker")

    def to_request(self, body):
        return self.request_translator.decode(body)

    def from_response(self, message):
        return self.response_translator.encode_with_completion(message)

|
|
@ -518,6 +518,7 @@ _FLOW_SERVICES = {
|
|||
"structured-diag": "structured-query:read",
|
||||
"row-embeddings": "row-embeddings:read",
|
||||
"sparql": "sparql:read",
|
||||
"reranker": "reranker",
|
||||
}
|
||||
for _kind, _cap in _FLOW_SERVICES.items():
|
||||
_register_flow_kind("flow-service", _kind, _cap)
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ _READER_CAPS = {
|
|||
"row-embeddings:read",
|
||||
"llm",
|
||||
"embeddings",
|
||||
"reranker",
|
||||
"mcp",
|
||||
"config:read",
|
||||
"flows:read",
|
||||
|
|
|
|||
1
trustgraph-flow/trustgraph/reranker/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
from . processor import *
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . processor import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
109
trustgraph-flow/trustgraph/reranker/flashrank/processor.py
Normal file
|
|
@ -0,0 +1,109 @@
"""
Reranker service using flashrank.
Scores query-document pairs and returns the top results ranked by
relevance.
"""

import asyncio
import logging

from ... base import RerankerService
from ... schema import RerankerResult

from flashrank import Ranker, RerankRequest

logger = logging.getLogger(__name__)

default_ident = "reranker"

default_model = "ms-marco-MiniLM-L-12-v2"

class Processor(RerankerService):

    def __init__(self, **params):

        model = params.get("model", default_model)

        super(Processor, self).__init__(
            **params | { "model": model }
        )

        self.default_model = model

        self.cached_model_name = None
        self.ranker = None

        self._load_model(model)

    def _load_model(self, model_name):
        if self.cached_model_name != model_name:
            logger.info(f"Loading flashrank model: {model_name}")
            self.ranker = Ranker(model_name=model_name)
            self.cached_model_name = model_name
            logger.info(f"flashrank model {model_name} loaded successfully")
        else:
            logger.debug(f"Using cached model: {model_name}")

    def _run_rerank(self, query, passages):
        request = RerankRequest(query=query, passages=passages)
        return self.ranker.rerank(request)

    async def on_rerank(self, queries, documents, limit, model=None):

        if not queries or not documents:
            return []

        use_model = model or self.default_model

        if self.cached_model_name != use_model:
            await asyncio.to_thread(self._load_model, use_model)

        passages = [
            {"id": d.document_id, "text": d.document_text}
            for d in documents
        ]

        best_scores = {}

        for q in queries:
            ranked = await asyncio.to_thread(
                self._run_rerank, q.query_text, passages,
            )

            for r in ranked:
                doc_id = r["id"]
                score = r["score"]
                score = float(score)
                if doc_id not in best_scores or score > best_scores[doc_id][1]:
                    best_scores[doc_id] = (q.query_id, score)

        results = sorted(
            best_scores.items(),
            key=lambda x: x[1][1],
            reverse=True,
        )[:limit]

        return [
            RerankerResult(
                document_id=doc_id,
                query_id=query_id,
                score=score,
            )
            for doc_id, (query_id, score) in results
        ]

    @staticmethod
    def add_args(parser):

        RerankerService.add_args(parser)

        parser.add_argument(
            '-m', '--model',
            default=default_model,
            help=f'Reranker model (default: {default_model})'
        )

def run():

    Processor.launch(default_ident, __doc__)

|
|
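The `on_rerank()` handler above scores every document against every query and keeps, per document, only the best-scoring query before truncating to `limit`. A minimal sketch of that aggregation follows, assuming plain tuples instead of the service's request objects and a hypothetical word-overlap scorer (`overlap_score`) standing in for the FlashRank cross-encoder:

```python
def overlap_score(query, text):
    # Hypothetical stand-in for a cross-encoder: the fraction of query
    # words that also appear in the document text.
    q_words = set(query.lower().split())
    t_words = set(text.lower().split())
    return len(q_words & t_words) / len(q_words) if q_words else 0.0


def best_per_document(queries, documents, score_fn, limit):
    """For each document keep the highest score over all queries,
    then return the top `limit` documents by that score."""
    best = {}  # doc_id -> (query_id, score)
    for query_id, query_text in queries:
        for doc_id, doc_text in documents:
            score = float(score_fn(query_text, doc_text))
            if doc_id not in best or score > best[doc_id][1]:
                best[doc_id] = (query_id, score)
    ranked = sorted(best.items(), key=lambda kv: kv[1][1], reverse=True)
    return [(doc_id, qid, score) for doc_id, (qid, score) in ranked[:limit]]


queries = [("0", "solar power"), ("1", "wind turbine")]
documents = [
    ("a", "solar power plant"),
    ("b", "offshore wind farm"),
    ("c", "coal mine"),
]
top = best_per_document(queries, documents, overlap_score, limit=2)
```

Each returned tuple records which query "won" for that document, which is what lets the caller later attach a concept to every selected edge.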
@@ -120,7 +120,7 @@ class Query:
     def __init__(
             self, rag, collection, verbose,
             entity_limit=50, triple_limit=30, max_subgraph_size=1000,
-            max_path_length=2, track_usage=None,
+            max_path_length=2, edge_limit=25, track_usage=None,
     ):
         self.rag = rag
         self.collection = collection
@@ -129,6 +129,7 @@ class Query:
         self.triple_limit = triple_limit
         self.max_subgraph_size = max_subgraph_size
         self.max_path_length = max_path_length
+        self.edge_limit = edge_limit
         self.track_usage = track_usage

     async def extract_concepts(self, query):
@@ -217,12 +218,9 @@ class Query:
                 logger.debug(f" {ent}")

         return entities, concepts

+
     async def maybe_label(self, e):
-
-        # The label cache lives on a per-request GraphRag instance — no
-        # cross-request isolation concern. The collection prefix keeps
-        # entries from different collections distinct within one request.
         cache_key = f"{self.collection}:{e}"

         cached_label = self.rag.label_cache.get(cache_key)
@@ -244,11 +242,10 @@ class Query:
         return label

     async def execute_batch_triple_queries(self, entities, limit_per_entity):
-        """Execute triple queries for multiple entities concurrently using streaming"""
+        """Execute triple queries for multiple entities concurrently."""
         tasks = []

         for entity in entities:
-            # Create concurrent streaming tasks for all 3 query types per entity
             tasks.extend([
                 self.rag.triples_client.query_stream(
                     s=entity, p=None, o=None,
@@ -270,10 +267,8 @@ class Query:
                 )
             ])

-        # Execute all queries concurrently
         results = await asyncio.gather(*tasks, return_exceptions=True)

-        # Combine all results
         all_triples = []
         for result in results:
             if not isinstance(result, Exception) and result is not None:
|
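The hunk above gathers all triple queries with `return_exceptions=True` and then skips any result that is an exception or `None`, so one failed query does not abort the whole batch. A minimal, self-contained sketch of that pattern, assuming a hypothetical `fetch_triples()` coroutine in place of the real triples client and assuming the surviving results are simply concatenated (the line that does this falls outside the hunk shown):

```python
import asyncio


async def fetch_triples(entity):
    # Hypothetical stand-in for a triples-store query.
    if entity == "bad":
        raise RuntimeError("query failed")
    if entity == "empty":
        return None
    return [(entity, "p", "o")]


async def fetch_all(entities):
    # return_exceptions=True makes gather() hand back exception objects
    # in place of results instead of raising the first failure.
    results = await asyncio.gather(
        *(fetch_triples(e) for e in entities),
        return_exceptions=True,
    )
    all_triples = []
    for result in results:
        if not isinstance(result, Exception) and result is not None:
            all_triples.extend(result)
    return all_triples


triples = asyncio.run(fetch_all(["a", "bad", "empty", "b"]))
```

Results come back in the same order as the awaitables passed to `gather()`, so the output order follows the input entity order regardless of which query finishes first.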
@@ -281,168 +276,151 @@ class Query:

         return all_triples

-    async def follow_edges_batch(self, entities, max_depth):
-        """Optimized iterative graph traversal with batching.
-
-        Returns:
-            tuple: (subgraph, term_map) where subgraph is a set of
-            (str, str, str) tuples and term_map maps each string tuple
-            to its original (Term, Term, Term) for type-preserving
-            provenance.
-        """
-        visited = set()
-        current_level = set(entities)
-        subgraph = set()
-        term_map = {} # (str, str, str) -> (Term, Term, Term)
-
-        for depth in range(max_depth):
-            if not current_level or len(subgraph) >= self.max_subgraph_size:
-                break
-
-            # Filter out already visited entities
-            unvisited_entities = [e for e in current_level if e not in visited]
-            if not unvisited_entities:
-                break
-
-            # Batch query all unvisited entities at current level
-            triples = await self.execute_batch_triple_queries(
-                unvisited_entities, self.triple_limit
-            )
-
-            # Process results and collect next level entities
-            next_level = set()
-            for triple in triples:
-                triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
-                subgraph.add(triple_tuple)
-                term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o))
-
-                # Collect entities for next level (only from s and o positions)
-                if depth < max_depth - 1: # Don't collect for final depth
-                    s, p, o = triple_tuple
-                    if s not in visited:
-                        next_level.add(s)
-                    if o not in visited:
-                        next_level.add(o)
-
-                # Stop if subgraph size limit reached
-                if len(subgraph) >= self.max_subgraph_size:
-                    return subgraph, term_map
-
-            # Update for next iteration
-            visited.update(current_level)
-            current_level = next_level
-
-        return subgraph, term_map
-
-    async def follow_edges(self, ent, subgraph, path_length):
-        """Legacy method - replaced by follow_edges_batch"""
-        # Maintain backward compatibility with early termination checks
-        if path_length <= 0:
-            return
-
-        if len(subgraph) >= self.max_subgraph_size:
-            return
-
-        # For backward compatibility, convert to new approach
-        batch_result, _ = await self.follow_edges_batch([ent], path_length)
-        subgraph.update(batch_result)
-
-    async def get_subgraph(self, query):
-        """
-        Get subgraph by extracting concepts, finding entities, and traversing.
-
-        Returns:
-            tuple: (subgraph, term_map, entities, concepts) where subgraph is
-            a list of (s, p, o) string tuples, term_map maps each string
-            tuple to its original (Term, Term, Term), entities is the seed
-            entity list, and concepts is the extracted concept list.
-        """
-
-        entities, concepts = await self.get_entities(query)
-
-        if self.verbose:
-            logger.debug("Getting subgraph...")
-
-        # Use optimized batch traversal instead of sequential processing
-        subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length)
-
-        return list(subgraph), term_map, entities, concepts
-
     async def resolve_labels_batch(self, entities):
-        """Resolve labels for multiple entities in parallel"""
-        tasks = []
-        for entity in entities:
-            tasks.append(self.maybe_label(entity))
-
+        """Resolve labels for multiple entities in parallel."""
+        tasks = [self.maybe_label(entity) for entity in entities]
         return await asyncio.gather(*tasks, return_exceptions=True)

-    async def get_labelgraph(self, query):
-        """
-        Get subgraph with labels resolved for display.
+    async def hop_and_filter(self, seed_entities, concepts):
+        """Iterative hop-and-filter graph traversal with cross-encoder.
+
+        At each hop:
+        1. Retrieve all edges one hop from the frontier.
+        2. Resolve labels and represent each edge as "{p} {o}".
+        3. Score edges against concepts using the cross-encoder.
+        4. Select the top edges; their target nodes become the next
+           frontier.

         Returns:
-            tuple: (labeled_edges, uri_map, entities, concepts) where:
-            - labeled_edges: list of (label_s, label_p, label_o) tuples
-            - uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
-            - entities: list of seed entity URI strings
-            - concepts: list of concept strings extracted from query
+            tuple: (selected_edges, uri_map, edge_metadata) where:
+            - selected_edges: list of (label_s, label_p, label_o)
+            - uri_map: dict mapping edge_id -> (Term, Term, Term)
+            - edge_metadata: dict mapping edge_id -> {concept, score}
         """
-        subgraph, term_map, entities, concepts = await self.get_subgraph(query)
+        all_selected_edges = []
+        uri_map = {}
+        edge_metadata = {}
+        frontier = set(seed_entities)
+        visited_entities = set()
+        seen_edges = set()

-        # Filter out label triples
-        filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
+        for hop in range(self.max_path_length):
+            if not frontier:
+                break

-        # Collect all unique entities that need label resolution
-        entities_to_resolve = set()
-        for s, p, o in filtered_subgraph:
-            entities_to_resolve.update([s, p, o])
+            unvisited = [e for e in frontier if e not in visited_entities]
+            if not unvisited:
+                break

-        # Batch resolve labels for all entities in parallel
-        entity_list = list(entities_to_resolve)
-        resolved_labels = await self.resolve_labels_batch(entity_list)
+            if self.verbose:
+                logger.debug(
+                    f"Hop {hop + 1}: {len(unvisited)} frontier entities"
+                )

-        # Create entity-to-label mapping
-        label_map = {}
-        for entity, label in zip(entity_list, resolved_labels):
-            if not isinstance(label, Exception):
-                label_map[entity] = label
-            else:
-                label_map[entity] = entity # Fallback to entity itself
-
-        # Apply labels to subgraph and build URI mapping
-        labeled_edges = []
-        uri_map = {} # Maps edge_id of labeled edge -> original Term triple
-
-        for s, p, o in filtered_subgraph:
-            labeled_triple = (
-                label_map.get(s, s),
-                label_map.get(p, p),
-                label_map.get(o, o)
+            # Retrieve edges one hop from frontier
+            triples = await self.execute_batch_triple_queries(
+                unvisited, self.triple_limit,
             )
-            labeled_edges.append(labeled_triple)

-            # Map from labeled edge ID to original Terms (preserving types)
-            labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
-            uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o))
+            # Deduplicate and filter already-seen edges
+            hop_triples = []
+            hop_term_map = {}
+            for triple in triples:
+                triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
+                if triple_tuple[1] == LABEL:
+                    continue
+                if triple_tuple in seen_edges:
+                    continue
+                seen_edges.add(triple_tuple)
+                hop_triples.append(triple_tuple)
+                hop_term_map[triple_tuple] = (
+                    to_term(triple.s), to_term(triple.p), to_term(triple.o),
+                )

-        labeled_edges = labeled_edges[0:self.max_subgraph_size]
+            if not hop_triples:
+                visited_entities.update(frontier)
+                break

-        if self.verbose:
-            logger.debug("Subgraph:")
-            for edge in labeled_edges:
-                logger.debug(f" {str(edge)}")
+            if self.verbose:
+                logger.debug(
+                    f"Hop {hop + 1}: {len(hop_triples)} candidate edges"
+                )

-        if self.verbose:
-            logger.debug("Done.")
+            # Resolve labels for all entities in hop edges
+            entities_to_resolve = set()
+            for s, p, o in hop_triples:
+                entities_to_resolve.update([s, p, o])

-        return labeled_edges, uri_map, entities, concepts
+            entity_list = list(entities_to_resolve)
+            resolved = await self.resolve_labels_batch(entity_list)
+
+            label_map = {}
+            for entity, label in zip(entity_list, resolved):
+                if not isinstance(label, Exception):
+                    label_map[entity] = label
+                else:
+                    label_map[entity] = entity
+
+            # Build labeled edges and documents for cross-encoder
+            labeled_hop = []
+            for s, p, o in hop_triples:
+                ls = label_map.get(s, s)
+                lp = label_map.get(p, p)
+                lo = label_map.get(o, o)
+                labeled_hop.append((ls, lp, lo))
+
+            documents = [
+                {"id": str(i), "text": f"{lp} {lo}"}
+                for i, (ls, lp, lo) in enumerate(labeled_hop)
+            ]
+
+            queries = [
+                {"id": str(i), "text": c}
+                for i, c in enumerate(concepts)
+            ]
+
+            # Score with cross-encoder
+            results = await self.rag.reranker_client.rerank(
+                queries=queries,
+                documents=documents,
+                limit=self.edge_limit,
+            )
+
+            # Collect selected edges and metadata
+            next_frontier = set()
+            for r in results:
+                idx = int(r.document_id)
+                ls, lp, lo = labeled_hop[idx]
+                s, p, o = hop_triples[idx]
+                eid = edge_id(ls, lp, lo)
+
+                all_selected_edges.append((ls, lp, lo))
+                uri_map[eid] = hop_term_map[(s, p, o)]
+                edge_metadata[eid] = {
+                    "concept": concepts[int(r.query_id)],
+                    "score": r.score,
+                }
+
+                # Target nodes become next frontier
+                next_frontier.add(s)
+                next_frontier.add(o)
+
+            if self.verbose:
+                logger.debug(
+                    f"Hop {hop + 1}: selected {len(results)} edges"
+                )
+
+            visited_entities.update(frontier)
+            frontier = next_frontier - visited_entities
+
+        return all_selected_edges, uri_map, edge_metadata

     async def trace_source_documents(self, edge_uris):
         """
         Trace selected edges back to their source documents via provenance.

-        Follows the chain: edge → subgraph (via tg:contains) → chunk →
-        page → document (via prov:wasDerivedFrom), all in urn:graph:source.
+        Follows the chain: edge -> subgraph (via tg:contains) -> chunk ->
+        page -> document (via prov:wasDerivedFrom), all in urn:graph:source.

         Args:
             edge_uris: List of (s, p, o) URI string tuples
|
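The `hop_and_filter()` method added in the hunk above alternates between expanding the frontier by one hop and pruning the candidate edges with a relevance score. The loop can be sketched in isolation as follows; this is a simplified, synchronous illustration under several assumptions, not the commit's implementation: the graph is an in-memory list of already-labelled triples, only subject and object positions are matched, and a hypothetical word-overlap function (`overlap_score`) stands in for the cross-encoder service.

```python
def overlap_score(concept, text):
    # Hypothetical stand-in for the cross-encoder: fraction of concept
    # words that also occur in the edge text.
    c_words = set(concept.lower().split())
    t_words = set(text.lower().split())
    return len(c_words & t_words) / len(c_words) if c_words else 0.0


def hop_and_filter_sketch(graph, seeds, concepts, max_hops=2, edge_limit=1):
    frontier, visited, seen = set(seeds), set(), set()
    selected = []
    for _ in range(max_hops):
        unvisited = [e for e in frontier if e not in visited]
        if not unvisited:
            break
        # Candidate edges: touch the frontier and were not offered before.
        candidates = []
        for edge in graph:
            s, p, o = edge
            if (s in unvisited or o in unvisited) and edge not in seen:
                seen.add(edge)
                candidates.append(edge)
        if not candidates:
            break
        # Score each edge as "{p} {o}" against every concept, keep the best.
        top = sorted(
            candidates,
            key=lambda e: max(
                overlap_score(c, f"{e[1]} {e[2]}") for c in concepts
            ),
            reverse=True,
        )[:edge_limit]
        next_frontier = set()
        for s, p, o in top:
            selected.append((s, p, o))
            next_frontier.update((s, o))
        visited.update(frontier)
        frontier = next_frontier - visited
    return selected


graph = [
    ("cat", "eats", "fish"),
    ("cat", "has colour", "black"),
    ("fish", "lives in", "water"),
    ("fish", "has colour", "silver"),
]
selected = hop_and_filter_sketch(
    graph, seeds=["cat"], concepts=["eats fish", "lives in water"],
)
```

Because pruning happens at every hop, only nodes reached through a selected edge are expanded further, which keeps the candidate set small compared with collecting the whole subgraph first and filtering afterwards.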
@ -453,7 +431,6 @@ class Query:
|
|||
# Step 1: Find subgraphs containing these edges via tg:contains
|
||||
subgraph_tasks = []
|
||||
for s, p, o in edge_uris:
|
||||
# s, p, o may be Term objects (preserving types) or strings
|
||||
s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s)
|
||||
p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p)
|
||||
o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o)
|
||||
|
|
@ -487,12 +464,10 @@ class Query:
|
|||
return []
|
||||
|
||||
# Step 2: Walk prov:wasDerivedFrom chain to find documents
|
||||
# Each level: query ?entity prov:wasDerivedFrom ?parent
|
||||
# Stop when we find entities typed tg:Document
|
||||
current_uris = subgraph_uris
|
||||
doc_uris = set()
|
||||
|
||||
for depth in range(4): # Max depth: subgraph → chunk → page → doc
|
||||
for depth in range(4):
|
||||
if not current_uris:
|
||||
break
|
||||
|
||||
|
|
@ -509,7 +484,6 @@ class Query:
|
|||
*derivation_tasks, return_exceptions=True
|
||||
)
|
||||
|
||||
# URIs with no parent are root documents
|
||||
next_uris = set()
|
||||
for uri, result in zip(current_uris, derivation_results):
|
||||
if isinstance(result, Exception) or not result:
|
||||
|
|
@ -524,7 +498,6 @@ class Query:
|
|||
return []
|
||||
|
||||
# Step 3: Get all document metadata properties
|
||||
# Skip structural predicates that aren't useful context
|
||||
SKIP_PREDICATES = {
|
||||
PROV_WAS_DERIVED_FROM,
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
|
|
@@ -565,7 +538,7 @@ class GraphRag:

     def __init__(
             self, prompt_client, embeddings_client, graph_embeddings_client,
-            triples_client, verbose=False,
+            triples_client, reranker_client, verbose=False,
     ):

         self.verbose = verbose
@@ -574,9 +547,8 @@ class GraphRag:
         self.embeddings_client = embeddings_client
         self.graph_embeddings_client = graph_embeddings_client
         self.triples_client = triples_client
+        self.reranker_client = reranker_client

-        # Replace simple dict with LRU cache with TTL
-        # CRITICAL: This cache only lives for one request due to per-request instantiation
         self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)

         if self.verbose:
@ -585,33 +557,12 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
|
||||
max_path_length = 2, edge_limit = 25,
|
||||
streaming = False,
|
||||
chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
parent_uri = "",
|
||||
):
|
||||
"""
|
||||
Execute a GraphRAG query with real-time explainability tracking.
|
||||
|
||||
Args:
|
||||
query: The query string
|
||||
collection: Collection identifier
|
||||
entity_limit: Max entities to retrieve
|
||||
triple_limit: Max triples per entity
|
||||
max_subgraph_size: Max edges in subgraph
|
||||
max_path_length: Max hops from seed entities
|
||||
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
|
||||
edge_limit: Max edges after LLM scoring
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for real-time explainability
|
||||
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
|
||||
|
||||
Returns:
|
||||
tuple: (answer_text, usage) where usage is a dict with
|
||||
in_token, out_token, model
|
||||
"""
|
||||
# Accumulate token usage across all prompt calls
|
||||
total_in = 0
|
||||
total_out = 0
|
||||
|
|
@ -638,7 +589,9 @@ class GraphRag:
|
|||
foc_uri = make_focus_uri(session_id)
|
||||
syn_uri = make_synthesis_uri(session_id)
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace(
|
||||
"+00:00", "Z",
|
||||
)
|
||||
|
||||
# Emit question explainability immediately
|
||||
if explain_callback:
|
||||
|
|
@@ -657,10 +610,12 @@ class GraphRag:
             triple_limit = triple_limit,
             max_subgraph_size = max_subgraph_size,
             max_path_length = max_path_length,
+            edge_limit = edge_limit,
             track_usage = track_usage,
         )

-        kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
+        # Step 1: Extract concepts and find seed entities
+        seed_entities, concepts = await q.get_entities(query)

         # Emit grounding explain after concept extraction
         if explain_callback:
@@ -676,11 +631,16 @@ class GraphRag:
             )
             await explain_callback(gnd_triples, gnd_uri)

-        # Emit exploration explain after graph retrieval completes
+        # Step 2: Iterative hop-and-filter with cross-encoder
+        selected_edges, uri_map, edge_metadata = await q.hop_and_filter(
+            seed_entities, concepts,
+        )
+
+        # Emit exploration explain
         if explain_callback:
             exp_triples = set_graph(
                 exploration_triples(
-                    exp_uri, gnd_uri, len(kg),
+                    exp_uri, gnd_uri, len(selected_edges),
                     entities=seed_entities,
                 ),
                 GRAPH_RETRIEVAL
@ -688,235 +648,63 @@ class GraphRag:
|
|||
await explain_callback(exp_triples, exp_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
logger.debug(f"Knowledge graph: {kg}")
|
||||
logger.debug(f"Query: {query}")
|
||||
|
||||
# Semantic pre-filter: reduce edges before expensive LLM scoring
|
||||
if edge_score_limit > 0 and len(kg) > edge_score_limit:
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Selected {len(selected_edges)} edges")
|
||||
for s, p, o in selected_edges:
|
||||
eid = edge_id(s, p, o)
|
||||
meta = edge_metadata.get(eid, {})
|
||||
logger.debug(
|
||||
f"Semantic pre-filter: {len(kg)} edges > "
|
||||
f"limit {edge_score_limit}, filtering..."
|
||||
f" {meta.get('score', 0):.4f} "
|
||||
f"[{meta.get('concept', '')}] "
|
||||
f"{s} | {p} | {o}"
|
||||
)
|
||||
|
||||
# Embed edge descriptions: "subject, predicate, object"
|
||||
edge_descriptions = [
|
||||
f"{s}, {p}, {o}" for s, p, o in kg
|
||||
]
|
||||
|
||||
# Embed concepts and edge descriptions concurrently
|
||||
concept_embed_task = self.embeddings_client.embed(concepts)
|
||||
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
|
||||
|
||||
concept_vectors, edge_vectors = await asyncio.gather(
|
||||
concept_embed_task, edge_embed_task
|
||||
)
|
||||
|
||||
# Score each edge by max cosine similarity to any concept
|
||||
def cosine_similarity(a, b):
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = math.sqrt(sum(x * x for x in a))
|
||||
norm_b = math.sqrt(sum(x * x for x in b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
edge_scores = []
|
||||
for i, edge_vec in enumerate(edge_vectors):
|
||||
max_sim = max(
|
||||
cosine_similarity(edge_vec, cv)
|
||||
for cv in concept_vectors
|
||||
)
|
||||
edge_scores.append((max_sim, i))
|
||||
|
||||
# Sort by similarity descending and keep top edge_score_limit
|
||||
edge_scores.sort(reverse=True)
|
||||
keep_indices = set(
|
||||
idx for _, idx in edge_scores[:edge_score_limit]
|
||||
)
|
||||
|
||||
# Filter kg and rebuild uri_map
|
||||
filtered_kg = []
|
||||
filtered_uri_map = {}
|
||||
for i, (s, p, o) in enumerate(kg):
|
||||
if i in keep_indices:
|
||||
filtered_kg.append((s, p, o))
|
||||
eid = edge_id(s, p, o)
|
||||
if eid in uri_map:
|
||||
filtered_uri_map[eid] = uri_map[eid]
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Semantic pre-filter kept {len(filtered_kg)} "
|
||||
f"of {len(kg)} edges"
|
||||
)
|
||||
|
||||
kg = filtered_kg
|
||||
uri_map = filtered_uri_map
|
||||
|
||||
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
|
||||
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
|
||||
edge_map = {}
|
||||
edges_with_ids = []
|
||||
for s, p, o in kg:
|
||||
eid = edge_id(s, p, o)
|
||||
edge_map[eid] = (s, p, o)
|
||||
edges_with_ids.append({
|
||||
"id": eid,
|
||||
"s": s,
|
||||
"p": p,
|
||||
"o": o
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Built edge map with {len(edge_map)} edges")
|
||||
|
||||
# Step 1a: Edge Scoring - LLM scores edges for relevance
|
||||
scoring_result = await self.prompt_client.prompt(
|
||||
"kg-edge-scoring",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": edges_with_ids
|
||||
}
|
||||
)
|
||||
track_usage(scoring_result)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge scoring result: {scoring_result}")
|
||||
|
||||
# Parse scoring response (jsonl) to get edge IDs with scores
|
||||
scored_edges = []
|
||||
|
||||
for obj in scoring_result.objects or []:
|
||||
if isinstance(obj, dict) and "id" in obj and "score" in obj:
|
||||
try:
|
||||
score = int(obj["score"])
|
||||
except (ValueError, TypeError):
|
||||
score = 0
|
||||
scored_edges.append({"id": obj["id"], "score": score})
|
||||
|
||||
# Select top N edges by score
|
||||
scored_edges.sort(key=lambda x: x["score"], reverse=True)
|
||||
top_edges = scored_edges[:edge_limit]
|
||||
selected_ids = {e["id"] for e in top_edges}
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Scored {len(scored_edges)} edges, "
|
||||
f"selected top {len(selected_ids)}"
|
||||
)
|
||||
|
||||
# Filter to selected edges
|
||||
selected_edges = []
|
||||
for eid in selected_ids:
|
||||
if eid in edge_map:
|
||||
selected_edges.append(edge_map[eid])
|
||||
|
||||
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
|
||||
selected_edges_with_ids = [
|
||||
{"id": eid, "s": s, "p": p, "o": o}
|
||||
for eid in selected_ids
|
||||
if eid in edge_map
|
||||
for s, p, o in [edge_map[eid]]
|
||||
]
|
||||
|
||||
# Collect selected edge URIs for document tracing
|
||||
# Step 3: Document tracing
|
||||
selected_edge_uris = [
|
||||
uri_map[eid]
|
||||
for eid in selected_ids
|
||||
if eid in uri_map
|
||||
uri_map[edge_id(s, p, o)]
|
||||
for s, p, o in selected_edges
|
||||
if edge_id(s, p, o) in uri_map
|
||||
]
|
||||
|
||||
# Run reasoning and document tracing concurrently
|
||||
async def _get_reasoning():
|
||||
result = await self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
track_usage(result)
|
||||
return result
|
||||
|
||||
reasoning_task = _get_reasoning()
|
||||
doc_trace_task = q.trace_source_documents(selected_edge_uris)
|
||||
|
||||
reasoning_result, source_documents = await asyncio.gather(
|
||||
reasoning_task, doc_trace_task, return_exceptions=True
|
||||
source_documents = await q.trace_source_documents(
|
||||
selected_edge_uris,
|
||||
)
|
||||
|
||||
# Handle exceptions from gather
|
||||
if isinstance(reasoning_result, Exception):
|
||||
logger.warning(
|
||||
f"Edge reasoning failed: {reasoning_result}"
|
||||
)
|
||||
reasoning_result = None
|
||||
if isinstance(source_documents, Exception):
|
||||
logger.warning(
|
||||
f"Document tracing failed: {source_documents}"
|
||||
)
|
||||
source_documents = []
|
||||
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge reasoning result: {reasoning_result}")
|
||||
|
||||
# Parse reasoning response (jsonl) and build explainability data
|
||||
reasoning_map = {}
|
||||
|
||||
if reasoning_result is not None:
|
||||
for obj in reasoning_result.objects or []:
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
reasoning_map[obj["id"]] = obj.get("reasoning", "")
|
||||
|
||||
# Build focus explainability data with cross-encoder metadata
|
||||
selected_edges_with_reasoning = []
|
||||
for eid in selected_ids:
|
||||
for s, p, o in selected_edges:
|
||||
eid = edge_id(s, p, o)
|
||||
if eid in uri_map:
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
meta = edge_metadata.get(eid, {})
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": reasoning_map.get(eid, ""),
|
||||
"concept": meta.get("concept", ""),
|
||||
"score": meta.get("score", 0),
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Filtered to {len(selected_edges)} edges")
|
||||
|
||||
# Emit focus explain after edge selection completes
|
||||
# Emit focus explain
|
||||
if explain_callback:
|
||||
# Sum scoring + reasoning token usage for focus event
|
||||
focus_in = 0
|
||||
focus_out = 0
|
||||
focus_model = None
|
||||
for r in [scoring_result, reasoning_result]:
|
||||
if r is not None:
|
||||
if r.in_token is not None:
|
||||
focus_in += r.in_token
|
||||
if r.out_token is not None:
|
||||
focus_out += r.out_token
|
||||
if r.model is not None:
|
||||
focus_model = r.model
|
||||
|
||||
foc_triples = set_graph(
|
||||
focus_triples(
|
||||
foc_uri, exp_uri, selected_edges_with_reasoning, session_id,
|
||||
in_token=focus_in or None,
|
||||
out_token=focus_out or None,
|
||||
model=focus_model,
|
||||
foc_uri, exp_uri,
|
||||
selected_edges_with_reasoning, session_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(foc_triples, foc_uri)
|
||||
|
||||
# Step 2: Synthesis - LLM generates answer from selected edges only
|
||||
# Step 4: Synthesis
|
||||
selected_edge_dicts = [
|
||||
{"s": s, "p": p, "o": o}
|
||||
for s, p, o in selected_edges
|
||||
]
|
||||
|
||||
# Add source document metadata as knowledge edges
|
||||
for s, p, o in source_documents:
|
||||
selected_edge_dicts.append({
|
||||
"s": s, "p": p, "o": o,
|
||||
|
|
@ -928,7 +716,6 @@ class GraphRag:
|
|||
}
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Accumulate chunks for answer storage while forwarding to callback
|
||||
accumulated_chunks = []
|
||||
|
||||
async def accumulating_callback(chunk, end_of_stream):
|
||||
|
|
@ -942,7 +729,6 @@ class GraphRag:
|
|||
chunk_callback=accumulating_callback
|
||||
)
|
||||
track_usage(synthesis_result)
|
||||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
synthesis_result = await self.prompt_client.prompt(
|
||||
|
|
@ -955,29 +741,42 @@ class GraphRag:
|
|||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
||||
# Emit synthesis explain after synthesis completes
|
||||
# Emit synthesis explain
|
||||
if explain_callback:
|
||||
synthesis_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian
|
||||
if save_answer_callback and answer_text:
|
||||
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
|
||||
try:
|
||||
await save_answer_callback(synthesis_doc_id, answer_text)
|
||||
if self.verbose:
|
||||
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
|
||||
logger.debug(
|
||||
f"Saved answer to librarian: "
|
||||
f"{synthesis_doc_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
logger.warning(
|
||||
f"Failed to save answer to librarian: {e}"
|
||||
)
|
||||
synthesis_doc_id = None
|
||||
|
||||
syn_triples = set_graph(
|
||||
synthesis_triples(
|
||||
syn_uri, foc_uri,
|
||||
document_id=synthesis_doc_id,
|
||||
in_token=synthesis_result.in_token if synthesis_result else None,
|
||||
out_token=synthesis_result.out_token if synthesis_result else None,
|
||||
model=synthesis_result.model if synthesis_result else None,
|
||||
in_token=(
|
||||
synthesis_result.in_token
|
||||
if synthesis_result else None
|
||||
),
|
||||
out_token=(
|
||||
synthesis_result.out_token
|
||||
if synthesis_result else None
|
||||
),
|
||||
model=(
|
||||
synthesis_result.model
|
||||
if synthesis_result else None
|
||||
),
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
|
|
@ -993,4 +792,3 @@ class GraphRag:
|
|||
}
|
||||
|
||||
return resp, usage
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from . graph_rag import GraphRag
|
|||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
|
||||
from ... base import RerankerClientSpec
|
||||
from ... base import LibrarianSpec
|
||||
|
||||
# Module logger
|
||||
|
|
@ -32,7 +33,6 @@ class Processor(FlowProcessor):
|
|||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
edge_score_limit = params.get("edge_score_limit", 30)
|
||||
edge_limit = params.get("edge_limit", 25)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
|
|
@ -43,7 +43,6 @@ class Processor(FlowProcessor):
|
|||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"edge_score_limit": edge_score_limit,
|
||||
"edge_limit": edge_limit,
|
||||
}
|
||||
)
|
||||
|
|
@ -52,7 +51,6 @@ class Processor(FlowProcessor):
|
|||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.default_edge_score_limit = edge_score_limit
|
||||
self.default_edge_limit = edge_limit
|
||||
|
||||
# Workspace isolation is enforced by the flow layer (flow.workspace).
|
||||
|
|
@ -96,6 +94,13 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RerankerClientSpec(
|
||||
request_name = "reranker-request",
|
||||
response_name = "reranker-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
|
|
@ -163,6 +168,7 @@ class Processor(FlowProcessor):
|
|||
graph_embeddings_client=flow("graph-embeddings-request"),
|
||||
triples_client=flow("triples-request"),
|
||||
prompt_client=flow("prompt-request"),
|
||||
reranker_client=flow("reranker-request"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
@ -186,11 +192,6 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
if v.edge_score_limit:
|
||||
edge_score_limit = v.edge_score_limit
|
||||
else:
|
||||
edge_score_limit = self.default_edge_score_limit
|
||||
|
||||
if v.edge_limit:
|
||||
edge_limit = v.edge_limit
|
||||
else:
|
||||
|
|
@ -225,7 +226,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
|
||||
edge_limit = edge_limit,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
|
|
@ -241,7 +242,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
|
||||
edge_limit = edge_limit,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
|
|
@@ -338,18 +339,11 @@ class Processor(FlowProcessor):
             help=f'Default max path length (default: 2)'
         )

-        parser.add_argument(
-            '--edge-score-limit',
-            type=int,
-            default=30,
-            help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
-        )
-
         parser.add_argument(
             '--edge-limit',
             type=int,
             default=25,
-            help=f'Max edges after LLM scoring (default: 25)'
+            help=f'Max edges selected per hop by cross-encoder (default: 25)'
         )

         # Note: Explainability triples are now stored in the request's collection