mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 06:51:00 +02:00
Merge branch 'release/v2.6' into master to keep sync.
commit
4aaa1ce915
59 changed files with 5617 additions and 864 deletions
218
README.dev-install.md
Normal file
@@ -0,0 +1,218 @@
# TrustGraph Developer Install Guide
|
||||
|
||||
A guided installer that gets TrustGraph running locally in a single
|
||||
command. It detects your hardware, recommends an LLM backend, installs
|
||||
missing prerequisites, runs the test suite, generates a compose deployment,
|
||||
starts the stack, and opens the Workbench UI.
|
||||
|
||||
> **macOS only.** This installer has only been tested on macOS. If you are
|
||||
> on Linux or Windows, use the standard docker-compose / podman-compose
|
||||
> installation instructions instead.
|
||||
|
||||
## Quick start
|
||||
|
||||
```bash
./install_trustgraph.sh
```
|
||||
|
||||
The installer walks you through each step interactively. When it finishes,
|
||||
the Workbench UI opens at `http://localhost:8888` and the API gateway is
|
||||
available at `http://localhost:8088/`.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
The installer checks for these and offers to install any that are missing
|
||||
(via Homebrew):
|
||||
|
||||
- **Python 3** with venv support
- **Node.js / npx** (drives the `@trustgraph/config` deployment generator)
- **Docker** (with Compose) or **Podman** (with podman-compose)
- **curl** and **unzip**
- **Ollama** (only if you choose local LLMs)
|
||||
|
||||
The installer can also launch Docker Desktop or the Ollama app for you if
|
||||
they are installed but not running.
|
||||
|
||||
## What the installer does
|
||||
|
||||
1. **Detects hardware** -- OS, architecture, CPU cores, memory, and GPU.
2. **Recommends an LLM mode** -- `ollama` for machines with >= 16 GB RAM and
   a GPU or >= 8 cores; `openai` otherwise.
3. **Collects configuration** -- API key, LLM provider, model choices,
   install directory. Answers are saved to
   `<install-dir>/trustgraph-installer.env` and reused on subsequent runs.
4. **Checks and installs prerequisites** -- Python, Node/npx, Docker or
   Podman, Ollama (if selected).
5. **Downloads Ollama models** (if using Ollama) -- chat model
   (`granite4:350m` by default) and embeddings model (`mxbai-embed-large`).
6. **Creates a Python venv** and installs the local TrustGraph packages into
   it, along with NLTK data and tiktoken caches.
7. **Runs the full pytest suite** against the local source tree.
8. **Runs `npx @trustgraph/config`** -- the existing interactive config
   wizard that produces a `deploy.zip` with a compose file.
9. **Starts the compose stack** and waits for the API gateway to respond.
10. **Bootstraps IAM** and verifies the API key authenticates.
11. **Opens the Workbench UI** in your default browser.
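
The recommendation rule in step 2 can be read as a small predicate. The sketch below is illustrative only and is not taken from the installer script: the function name `recommend_llm_mode` is hypothetical, and it assumes the rule groups as "at least 16 GB RAM, and either a GPU or at least 8 cores". The script itself may group the conditions differently.

```python
def recommend_llm_mode(ram_gb: float, cpu_cores: int, has_gpu: bool) -> str:
    """Return "ollama" or "openai" (hypothetical helper, not installer code).

    Assumed grouping: RAM >= 16 GB AND (GPU present OR >= 8 cores).
    """
    enough_memory = ram_gb >= 16
    enough_compute = has_gpu or cpu_cores >= 8
    return "ollama" if enough_memory and enough_compute else "openai"


if __name__ == "__main__":
    # A 32 GB, 10-core machine without a GPU qualifies under the assumed rule.
    print(recommend_llm_mode(ram_gb=32, cpu_cores=10, has_gpu=False))
    # An 8 GB machine does not, regardless of GPU or core count.
    print(recommend_llm_mode(ram_gb=8, cpu_cores=8, has_gpu=True))
```

Running `./install_trustgraph.sh --dry-run` prints the mode the installer actually recommends for the current machine.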
|
||||
|
||||
## Command-line options
|
||||
|
||||
| Option | Description |
|---|---|
| `--install-dir PATH` | Directory for deployment files (default: `./trustgraph-deploy`) |
| `--api-url URL` | API gateway URL for health checks (default: `http://localhost:8088/`) |
| `--ui-url URL` | Workbench UI URL to open (default: `http://localhost:8888`) |
| `--use-existing-compose FILE` | Skip config generation and start this compose file directly |
| `--skip-tests` | Do not run the pytest suite |
| `--no-launch` | Do not open the Workbench UI at the end |
| `--non-interactive` | Accept all defaults without prompting |
| `--yes` | Auto-accept confirmation prompts |
| `--fresh` | Remove installer-managed files before generating a new deployment |
| `--remove-all` | Uninstall: stop containers, remove compose volumes, delete installer files |
| `--dry-run` | Print detected hardware and planned defaults, then exit |
| `-h`, `--help` | Show the built-in help text |
|
||||
|
||||
## Environment variables
|
||||
|
||||
These override the interactive prompts when set:
|
||||
|
||||
| Variable | Purpose |
|---|---|
| `TRUSTGRAPH_TOKEN` | Admin/bootstrap API key (must start with `tg_`) |
| `TRUSTGRAPH_URL` | API gateway URL |
| `TRUSTGRAPH_UI_URL` | Workbench UI URL |
| `OPENAI_TOKEN` | OpenAI-compatible API key |
| `OPENAI_BASE_URL` | OpenAI-compatible base URL |
| `OLLAMA_HOST` / `OLLAMA_BASE_URL` | Ollama service URL |
| `OLLAMA_MODEL` | Ollama chat model (default: `granite4:350m`) |
| `OLLAMA_EMBEDDINGS_MODEL` | Ollama embeddings model (default: `mxbai-embed-large`) |
| `TG_INSTALL_DIR` | Override the install directory |
| `TG_VENV_DIR` | Override the Python venv location |
| `TG_NLTK_DATA_DIR` | Override the NLTK data directory |
| `TIKTOKEN_CACHE_DIR` | Override the tiktoken cache directory |
| `TG_HEALTH_TIMEOUT` | Seconds to wait for the API gateway (default: 240) |
|
||||
|
||||
## Choosing an LLM mode
|
||||
|
||||
### OpenAI (or any OpenAI-compatible provider)
|
||||
|
||||
Best when you already have an API key or are running against a remote
|
||||
endpoint. The installer asks for a base URL and an API key.
|
||||
|
||||
```bash
OPENAI_TOKEN=sk-... ./install_trustgraph.sh
```
|
||||
|
||||
### Ollama (local models)
|
||||
|
||||
Best on machines with enough RAM to run a small model. The installer detects
|
||||
locally installed Ollama models and offers to pull missing ones. It uses
|
||||
`host.docker.internal` so the Docker containers can reach the host-side
|
||||
Ollama service.
|
||||
|
||||
```bash
./install_trustgraph.sh   # choose "ollama" when prompted
```
|
||||
|
||||
### None
|
||||
|
||||
Start the platform without an LLM. Agent and RAG features will not work
|
||||
until you configure one later through the Workbench.
|
||||
|
||||
## Saved answers and re-running
|
||||
|
||||
The installer saves your answers to
|
||||
`<install-dir>/trustgraph-installer.env`. On the next run it loads those
|
||||
answers as defaults, so you can re-run with a single Enter through each
|
||||
prompt.
|
||||
|
||||
To start completely fresh:
|
||||
|
||||
```bash
./install_trustgraph.sh --fresh
```
|
||||
|
||||
This stops any running containers (keeping Docker volumes), removes
|
||||
installer-managed files, and re-runs the full flow.
|
||||
|
||||
## Using an existing compose file
|
||||
|
||||
If you already have a compose file from the config tool or another source:
|
||||
|
||||
```bash
./install_trustgraph.sh --use-existing-compose path/to/docker-compose.yaml
```
|
||||
|
||||
This skips the config wizard and `npx` prerequisite check, and goes straight
|
||||
to starting the stack.
|
||||
|
||||
## Non-interactive / CI usage
|
||||
|
||||
```bash
TRUSTGRAPH_TOKEN=tg_my-token \
OPENAI_TOKEN=sk-... \
./install_trustgraph.sh --non-interactive --yes --skip-tests
```
|
||||
|
||||
In non-interactive mode the installer uses defaults for every prompt. Pair
|
||||
with `--yes` to auto-accept confirmation prompts and `--skip-tests` if you
|
||||
want a faster run.
|
||||
|
||||
## Dry run
|
||||
|
||||
Preview what the installer would do without making any changes:
|
||||
|
||||
```bash
./install_trustgraph.sh --dry-run
```
|
||||
|
||||
This prints the detected hardware, recommended LLM mode, and planned
|
||||
install paths, then exits.
|
||||
|
||||
## Uninstalling
|
||||
|
||||
```bash
./install_trustgraph.sh --remove-all
```
|
||||
|
||||
This stops containers, removes compose-managed volumes, and deletes
|
||||
installer-managed files (venv, deploy output, logs, saved answers). It does
|
||||
**not** remove Docker/Podman itself, container images, Ollama, or Ollama
|
||||
models.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Logs
|
||||
|
||||
All long-running operations write logs to `<install-dir>/logs/`. Key files:
|
||||
|
||||
- `pytest.log` -- test suite output
|
||||
- `compose-up.log` -- docker compose output
|
||||
- `iam-bootstrap.log` -- IAM bootstrap output
|
||||
- `ollama-pull-*.log` -- Ollama model downloads
|
||||
- `pip-*.log` -- Python package installs
|
||||
- `brew-install-*.log` -- Homebrew installs
|
||||
|
||||
### API key rejected after reinstall
|
||||
|
||||
If the API gateway returns 401/403 with your saved key, the compose volumes
|
||||
likely contain IAM data from a previous install with a different key. Run:
|
||||
|
||||
```bash
./install_trustgraph.sh --remove-all
./install_trustgraph.sh
```
|
||||
|
||||
This clears the old volumes and starts fresh.
|
||||
|
||||
### Ollama not reachable from containers
|
||||
|
||||
The Ollama base URL should use `host.docker.internal` instead of
|
||||
`localhost` so that containers running in Docker Desktop can reach the
|
||||
host-side Ollama service. The installer sets this automatically; if you
|
||||
override `OLLAMA_HOST`, make sure the URL is reachable from inside the
|
||||
container network.
|
||||
|
||||
### Docker daemon not running
|
||||
|
||||
The installer detects Docker Desktop and offers to start it. If that
|
||||
doesn't work, start Docker Desktop manually and re-run the installer.
|
||||
523
docs/tech-specs/graph-rag-semantic-filter.md
Normal file
@@ -0,0 +1,523 @@
# GraphRAG Semantic Filter Improvement
|
||||
|
||||
## Problem Statement
|
||||
|
||||
The GraphRAG semantic filter is observed to be ineffective with certain
|
||||
LLM models. Smaller models in particular produce poor-quality edge
|
||||
relevance scores, and there is a suspicion that models trained or
|
||||
evaluated heavily on non-Roman-script datasets offer lower performance
|
||||
on the semantic ranking operation.
|
||||
|
||||
The root cause is that the current implementation delegates edge
|
||||
relevance scoring to the LLM via a prompt that asks the model to
|
||||
assign a 1–10 relevance score to each knowledge-graph edge. This
|
||||
task — ranking structured triples for relevance to a natural-language
|
||||
query — is not well covered in standard LLM evaluation suites, so
|
||||
model benchmark scores are not predictive of performance on this
|
||||
operation. The result is that GraphRAG quality varies unpredictably
|
||||
across model choices, undermining confidence in the pipeline.
|
||||
|
||||
Beyond model variability, the LLM scoring step has further problems:
|
||||
|
||||
- **Cost and latency.** The LLM call consumes tokens and adds
|
||||
latency to every query, yet its output is unreliable. Even when
|
||||
the model performs well, the cost is disproportionate for what is
|
||||
fundamentally a ranking operation.
|
||||
|
||||
- **Subjective scoring scale.** The 1–10 relevance scale gives the
|
||||
model no objective criteria for what constitutes a 5 versus a 7.
|
||||
Different models interpret the scale differently, and even the same
|
||||
model can produce inconsistent scores across runs.
|
||||
|
||||
- **Redundancy with the embedding pre-filter.** The pipeline already
|
||||
contains a cosine-similarity stage that ranks edges by semantic
|
||||
relevance using embeddings. The LLM scoring step is a second
|
||||
filter applied on top of this, and it is not clear that it adds
|
||||
enough value to justify the additional cost and risk of
|
||||
degradation.
|
||||
|
||||
### Industry context
|
||||
|
||||
Semantic ranking is rigorously evaluated on dedicated benchmarks such
|
||||
as MTEB (Massive Text Embedding Benchmark) and BEIR (Benchmarking
|
||||
Information Retrieval), which test retrieval and reranking across
|
||||
diverse domains. The current TrustGraph approach — prompting a
|
||||
general-purpose LLM to score and rank documents (the "listwise"
|
||||
approach) — is known to be poorly optimized for this task. It
|
||||
suffers from positional bias, formatting failures, and
|
||||
inconsistency at scale.
|
||||
|
||||
The industry standard for semantic ranking has moved to
|
||||
cross-encoder models: lightweight, purpose-built models that take a
|
||||
query–document pair as input and produce a single relevance score.
|
||||
These models are fine-tuned on millions of relevance-labelled pairs
|
||||
and dominate retrieval benchmarks. They are fast, deterministic,
|
||||
and do not require an LLM inference call.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Cross-encoder service
|
||||
|
||||
A new request/response service that exposes a generic semantic
|
||||
ranking API. The service is not specific to GraphRAG — it is a
|
||||
reusable building block for any component that needs to rank text
|
||||
by relevance.
|
||||
|
||||
The service interface is pluggable. Alternative implementations
|
||||
can be swapped in behind the same API.
|
||||
|
||||
**Packaging options considered:**
|
||||
|
||||
- *`sentence-transformers`.* Full-featured, widely used.
|
||||
However, it pulls in PyTorch (~2 GB), making containers
|
||||
very large. Tested at ~1.8 seconds for 2200 edges.
|
||||
|
||||
- *`optimum.onnxruntime`.* ONNX-based inference. Still
|
||||
depends on PyTorch at import time despite using ONNX for
|
||||
inference. Tested at ~4.2 seconds for 2200 edges.
|
||||
|
||||
- *`flashrank`.* Lightweight wrapper around ONNX Runtime
|
||||
with a clean API (`Ranker`, `RerankRequest`). No PyTorch
|
||||
dependency. Tested at ~4.4 seconds for 2200 edges.
|
||||
|
||||
- *Pure `onnxruntime` + `tokenizers`.* Leanest option
|
||||
(~200 MB total). Requires manual tokenisation, padding,
|
||||
and numpy array management — more boilerplate to maintain.
|
||||
|
||||
- *External API (e.g. Cohere Rerank).* No local model at
|
||||
all. Adds network latency and an external dependency.
|
||||
|
||||
**Decision:** `flashrank` for the initial implementation.
|
||||
No PyTorch dependency, clean API, comparable performance.
|
||||
The pluggable interface allows swapping to another backend
|
||||
later.
|
||||
|
||||
**Request:**
|
||||
|
||||
- `queries` — list of `{id, text}` objects. In the GraphRAG use
|
||||
case these are the concepts extracted from the user's question.
|
||||
- `documents` — list of `{id, text}` objects. In the GraphRAG
|
||||
use case these are the candidate knowledge-graph edges
|
||||
represented as text.
|
||||
- `limit` — integer. Maximum number of results to return.
|
||||
|
||||
**Scoring:**
|
||||
|
||||
The service produces the cartesian product of all query–document
|
||||
pairs and scores each pair through the cross-encoder model. For
|
||||
each document, the maximum score across all queries is taken as the
|
||||
document's relevance score. Documents are then ranked by this
|
||||
score and the top `limit` results are returned.
|
||||
|
||||
**Response:**
|
||||
|
||||
A list of the top `limit` results, each containing:
|
||||
|
||||
- `document_id` — the ID of the matched document.
|
||||
- `query_id` — the ID of the query (concept) that produced the
|
||||
highest score for this document.
|
||||
- `score` — the relevance score.
|
||||
|
||||
Including `query_id` in the response supports the explainability
|
||||
interface: it records that an edge was selected because it is
|
||||
related to a specific concept.
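
The scoring rule above (score every query–document pair, keep the maximum per document, return the top `limit`) can be sketched as follows. This is a minimal illustration and not the service implementation: `RankedResult`, `rank_documents`, and `toy_overlap_score` are hypothetical names, and the word-overlap scorer is only a stand-in for a real cross-encoder model.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class RankedResult:
    document_id: str
    query_id: str   # the query (concept) that scored highest for this document
    score: float


def rank_documents(
    queries: List[Dict[str, str]],
    documents: List[Dict[str, str]],
    limit: int,
    score_pair: Callable[[str, str], float],
) -> List[RankedResult]:
    """Max-over-queries ranking; score_pair stands in for the cross-encoder."""
    results: List[RankedResult] = []
    for doc in documents:
        best_query_id, best_score = None, float("-inf")
        # Cartesian product: every query is scored against this document.
        for query in queries:
            s = score_pair(query["text"], doc["text"])
            if s > best_score:
                best_query_id, best_score = query["id"], s
        if best_query_id is not None:
            results.append(RankedResult(doc["id"], best_query_id, best_score))
    # Rank by the per-document maximum and keep the top `limit`.
    results.sort(key=lambda r: r.score, reverse=True)
    return results[:limit]


def toy_overlap_score(query: str, document: str) -> float:
    """Toy stand-in scorer: fraction of query words present in the document."""
    q_words = set(query.lower().split())
    d_words = set(document.lower().split())
    return len(q_words & d_words) / len(q_words) if q_words else 0.0
```

Passing the scorer in as a callable mirrors the pluggable-backend idea: the ranking logic stays the same whichever model produces the pair scores.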
|
||||
|
||||
### Integration
|
||||
|
||||
The cross-encoder service follows the standard TrustGraph service
|
||||
integration pattern:
|
||||
|
||||
- **Base package (trustgraph-base).** Schema definitions for the
|
||||
cross-encoder request/response messages. A client class that
|
||||
other components (e.g. GraphRAG) can use to call the
|
||||
cross-encoder service. Message translator registration so the
|
||||
pub/sub layer can serialise/deserialise the messages.
|
||||
|
||||
- **Flow package (trustgraph-flow).** The cross-encoder service
|
||||
implementation itself — loads the model, listens for requests,
|
||||
scores pairs, returns results. Flow definition support so the
|
||||
cross-encoder can be introduced into a processing flow via the
|
||||
standard flow configuration. `flashrank` is added as a
|
||||
dependency of `trustgraph-flow`. The service runs in its own
|
||||
container.
|
||||
|
||||
- **API gateway.** A gateway endpoint that routes cross-encoder
|
||||
requests from the HTTP API to the service over pub/sub and
|
||||
returns the response.
|
||||
|
||||
- **CLI tool.** A command-line utility
|
||||
(e.g. `tg-invoke-cross-encoder`) that calls the gateway
|
||||
endpoint for manual testing and debugging.
|
||||
|
||||
### Current GraphRAG pipeline
|
||||
|
||||
The current pipeline follows these steps:
|
||||
|
||||
1. **Concept extraction.** An LLM prompt extracts key concepts
|
||||
from the user's query.
|
||||
|
||||
2. **Graph exploration.** Seed entities are found via embedding
|
||||
similarity. A subgraph is built by multi-hop traversal from
|
||||
the seed entities (up to `max_path_length` hops, capped at
|
||||
`max_subgraph_size` edges).
|
||||
|
||||
3. **Embedding pre-filter.** Each edge is embedded as
|
||||
`"subject, predicate, object"` and scored by cosine similarity
|
||||
against the concept embeddings. The top `edge_score_limit`
|
||||
(default 30) edges are kept.
|
||||
|
||||
4. **LLM edge scoring.** The `kg-edge-scoring` prompt asks the
|
||||
LLM to assign a 1–10 relevance score to each remaining edge.
|
||||
The top `edge_limit` (default 25) edges are kept.
|
||||
|
||||
5. **LLM edge reasoning.** The `kg-edge-reasoning` prompt asks
|
||||
the LLM to explain why each selected edge is relevant to the
|
||||
query. Used for the explainability interface.
|
||||
|
||||
6. **Document tracing.** Selected edges are traced back to their
|
||||
source documents in the librarian. Runs concurrently with
|
||||
step 5.
|
||||
|
||||
7. **Synthesis.** The `kg-synthesis` prompt generates the final
|
||||
answer from the selected edges and source document metadata.
|
||||
|
||||
### Potential improvements
|
||||
|
||||
#### Replace LLM edge scoring with cross-encoder (step 4)
|
||||
|
||||
The LLM edge scoring step is replaced by a call to the
|
||||
cross-encoder service. The candidate edges are the documents and
|
||||
`edge_limit` is the limit. This is a direct substitution: faster,
|
||||
cheaper, deterministic, and more reliable across model choices.
|
||||
The LLM `kg-edge-scoring` prompt is retired.
|
||||
|
||||
**Cross-encoder query input: concepts vs. raw query.** There are
|
||||
two options for what to use as the cross-encoder queries:
|
||||
|
||||
- *Option A: Raw user query.* Pass the original question as a
|
||||
single query string. Simpler, no dependency on concept
|
||||
extraction. However, raw queries contain noise words and
|
||||
conversational phrasing that do not match well against the
|
||||
structured vocabulary of knowledge-graph edges. A single query
|
||||
also means every edge competes against the full question — a
|
||||
partial match on one aspect is diluted.
|
||||
|
||||
- *Option B: Extracted concepts.* Pass the concepts from step 1
|
||||
as separate queries. The concepts are distilled, focused terms
|
||||
that are closer to the language of the edges. With multiple
|
||||
concepts as independent queries, the cross-encoder scores each
|
||||
edge against each concept separately, giving better coverage —
|
||||
an edge only needs to match one concept well to be selected.
|
||||
The trade-off is a dependency on the LLM concept extraction
|
||||
step, but this is already in the pipeline and is a lightweight,
|
||||
reliable LLM call.
|
||||
|
||||
**Decision:** Option B — use extracted concepts. The concept
|
||||
extraction is fast, and the resulting terms produce better
|
||||
cross-encoder matches against structured triples.
|
||||
|
||||
#### Edge text representation
|
||||
|
||||
The current embedding pre-filter represents each edge as
|
||||
`"subject, predicate, object"`. Two changes:
|
||||
|
||||
- **Drop commas.** Commas add tokenisation noise without semantic
|
||||
value.
|
||||
|
||||
- **Drop the subject.** The subject identifies which entity the
|
||||
edge belongs to, but it does not contribute to whether the
|
||||
edge's content is relevant to the query. The predicate and
|
||||
object carry the semantic meaning — what relationship exists
|
||||
and what it connects to. Representing edges as `"{p} {o}"`
|
||||
produces cleaner cross-encoder matches.
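
As a small illustration of the two representations, the sketch below assumes edges are already resolved to human-readable labels. The `Edge` tuple and both helper names are hypothetical and exist only for this example.

```python
from typing import NamedTuple


class Edge(NamedTuple):
    s: str  # subject label
    p: str  # predicate label
    o: str  # object label


def legacy_edge_text(edge: Edge) -> str:
    """Current pre-filter form: "subject, predicate, object"."""
    return f"{edge.s}, {edge.p}, {edge.o}"


def cross_encoder_edge_text(edge: Edge) -> str:
    """Proposed form "{p} {o}": no subject, no commas."""
    return f"{edge.p} {edge.o}"
```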
|
||||
|
||||
#### Remove the embedding pre-filter (step 3)
|
||||
|
||||
The embedding pre-filter was introduced to reduce the number of
|
||||
edges before the expensive LLM scoring call. With the
|
||||
cross-encoder replacing the LLM call, this cost equation changes.
|
||||
|
||||
**Arguments for removal:**
|
||||
|
||||
- The cross-encoder is fast enough to score the full subgraph
  directly. In the packaging tests above, 2200 edges scored in
  ~1.8 seconds with `sentence-transformers` and ~4.4 seconds with
  `flashrank`; at the default `max_subgraph_size` of 150 edges,
  scoring should take a fraction of a second with either backend.
|
||||
|
||||
- The pre-filter is a weaker version of what the cross-encoder
|
||||
does. Bi-encoder cosine similarity embeds the query and
|
||||
document independently and compares vectors; the cross-encoder
|
||||
processes both texts together through the full transformer,
|
||||
giving it much better relevance judgement. Running a weaker
|
||||
filter before a stronger one adds latency without improving
|
||||
quality.
|
||||
|
||||
- Removing it eliminates an embedding service call (two batches:
|
||||
concepts + edges) and the associated latency.
|
||||
|
||||
**Arguments for keeping it:**
|
||||
|
||||
- If the subgraph is very large (thousands of edges), the
|
||||
cross-encoder's linear scaling could become a bottleneck.
|
||||
The pre-filter would act as a safety valve.
|
||||
|
||||
- The embedding call is cheap compared to an LLM call, so the
|
||||
overhead is modest.
|
||||
|
||||
**Decision:** Remove the pre-filter. The `max_subgraph_size`
|
||||
parameter (default 150) already caps the number of edges entering
|
||||
this stage, so the cross-encoder will not face an unbounded
|
||||
workload. If very large subgraphs become a concern in future,
|
||||
the pre-filter can be reintroduced or `max_subgraph_size` can be
|
||||
tuned.
|
||||
|
||||
#### Iterative graph traversal with cross-encoder filtering
|
||||
|
||||
The current pipeline performs graph exploration and edge filtering
|
||||
as separate phases: first build the full subgraph (up to
|
||||
`max_path_length` hops), then score and filter edges. An
|
||||
alternative is to interleave traversal and filtering — at each
|
||||
hop, use the cross-encoder to select relevant edges before
|
||||
expanding further.
|
||||
|
||||
**Option A: Big-bang traversal then filter.** Traverse the full
|
||||
subgraph up to `max_path_length` hops from the seed entities,
|
||||
collecting all edges up to `max_subgraph_size`. Then
|
||||
cross-encode the entire result to select the top edges.
|
||||
|
||||
- Simple to implement — the current traversal logic is largely
|
||||
unchanged.
|
||||
- Produces large, unfocused subgraphs. Irrelevant branches are
|
||||
explored and scored even though they will be discarded.
|
||||
- Poorly suited to multi-hop reasoning. For a query about
|
||||
Voyager 1, the subgraph includes Voyager 2's edges because
|
||||
they are within hop distance, and the filter must then
|
||||
separate them.
|
||||
|
||||
**Option B: Iterative hop-and-filter.** At each hop:
|
||||
|
||||
1. Retrieve all edges one hop from the current frontier nodes.
|
||||
2. Cross-encode these edges against the query concepts.
|
||||
3. Select the top relevant edges.
|
||||
4. The target nodes of the selected edges become the frontier
|
||||
for the next hop.
|
||||
5. Repeat up to `max_path_length` hops.
|
||||
|
||||
The final set of selected edges across all hops is the input to
|
||||
synthesis.
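
The hop-and-filter loop above can be sketched as below. This is an illustration under stated assumptions, not the GraphRAG implementation: `hop_and_filter`, `neighbours`, and `select_top` are hypothetical names, `select_top` stands in for the cross-encoder call (steps 2 and 3), and the `visited` set that prevents re-expanding a node is an addition for the example rather than something the spec mandates.

```python
from typing import Callable, List, Set, Tuple

Edge = Tuple[str, str, str]  # (subject, predicate, object)


def hop_and_filter(
    seeds: Set[str],
    neighbours: Callable[[str], List[Edge]],
    select_top: Callable[[List[Edge]], List[Edge]],
    max_path_length: int,
) -> List[Edge]:
    """Interleave one-hop expansion with relevance filtering."""
    frontier = set(seeds)
    visited = set(seeds)          # assumption: never expand a node twice
    selected: List[Edge] = []
    for _ in range(max_path_length):
        # 1. Retrieve all edges one hop from the current frontier nodes.
        candidates = [e for node in sorted(frontier) for e in neighbours(node)]
        candidates = [e for e in candidates if e not in selected]
        if not candidates:
            break
        # 2-3. Score the candidates and keep only the most relevant ones.
        kept = select_top(candidates)
        selected.extend(kept)
        # 4. Target nodes of the kept edges become the next frontier.
        frontier = {o for (_, _, o) in kept} - visited
        visited |= frontier
        if not frontier:
            break
    # Selected edges across all hops are the input to synthesis.
    return selected
```

Because pruned edges never contribute to the frontier, branches judged irrelevant at one hop are not expanded at the next, which is the greedy-pruning trade-off discussed in this section.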
|
||||
|
||||
- **Guided exploration.** Each hop focuses the search by
|
||||
pruning irrelevant branches before expanding further. The
|
||||
working set stays small and relevant at every step.
|
||||
- **Multi-hop reasoning works naturally.** Following
|
||||
"Voyager 1 → has-event → crossed the heliopause" succeeds
|
||||
because each hop is individually relevant and leads to the
|
||||
next.
|
||||
- **Smaller total workload.** Fewer edges are scored overall
|
||||
because irrelevant branches are never expanded.
|
||||
- **Trade-off: greedy pruning.** An edge discarded at hop 1
|
||||
cannot lead to relevant edges at hop 2. This is inherent in
|
||||
any bounded traversal, and the cross-encoder is better
|
||||
equipped to make this relevance judgement than a blind hop
|
||||
limit.
|
||||
- **Trade-off: sequential latency.** Hops cannot be
|
||||
parallelised since each depends on the previous. However,
|
||||
each cross-encoder call on a small edge set is very fast
|
||||
(sub-second for typical working sets).
|
||||
|
||||
**Decision:** Option B — iterative hop-and-filter. The guided
|
||||
traversal produces more focused subgraphs and supports multi-hop
|
||||
reasoning, which is a significant quality improvement over the
|
||||
current approach.
|
||||
|
||||
#### Replace LLM edge reasoning with cross-encoder metadata (step 5)
|
||||
|
||||
The current `kg-edge-reasoning` prompt asks the LLM to explain why
|
||||
each edge is relevant. With the cross-encoder now making the
|
||||
selection, this explanation would be a post-hoc fabrication — the
|
||||
LLM was not involved in the decision.
|
||||
|
||||
- *Option A: Keep LLM reasoning.* Generates natural-language
|
||||
explanations but they are not grounded in the actual selection
|
||||
process. Adds an LLM call per query.
|
||||
|
||||
- *Option B: Record cross-encoder metadata.* The cross-encoder
|
||||
already returns the matched concept and score for each selected
|
||||
edge. Use this directly as the explanation.
|
||||
|
||||
**Decision:** Option B. The cross-encoder metadata is the true
|
||||
reason the edge was selected. The `kg-edge-reasoning` prompt is
|
||||
retired.
|
||||
|
||||
#### Explainability interface update
|
||||
|
||||
The explainability interface uses a `Focus` entity containing
|
||||
`EdgeSelection` sub-entities. Each `EdgeSelection` currently
|
||||
carries an `edge` (the quoted triple) and a `reasoning` field
|
||||
(free-text LLM prose), stored as `tg:reasoning` in the
|
||||
provenance graph.
|
||||
|
||||
With the cross-encoder replacing LLM reasoning, the
|
||||
`EdgeSelection` type gains two new predicates and drops one:
|
||||
|
||||
- **Remove** `tg:reasoning` — no longer produced.
|
||||
- **Add** `tg:concept` — the concept text that produced the
|
||||
highest cross-encoder score for this edge.
|
||||
- **Add** `tg:score` — the cross-encoder relevance score.
|
||||
|
||||
This is an evolution of the existing `EdgeSelection` type, not a
|
||||
new entity type. The edge selection sub-entities currently have
|
||||
no `rdf:type` declared; a new `tg:EdgeSelection` type should be
|
||||
added so that consumers can identify them in the provenance
|
||||
graph. The `Focus` entity and its relationship to `Exploration`
|
||||
are unchanged.
|
||||
|
||||
The `Focus` entity's token-usage metadata (`tg:inToken`,
|
||||
`tg:outToken`, `tg:llmModel`) no longer applies since there is
|
||||
no LLM call. These fields are dropped from the Focus entity.
|
||||
|
||||
### Proposed pipeline
|
||||
|
||||
1. **Concept extraction.** Unchanged — LLM extracts key concepts
|
||||
from the user's query.
|
||||
|
||||
2. **Seed entity lookup.** Find seed entities via embedding
|
||||
similarity against the extracted concepts.
|
||||
|
||||
3. **Iterative hop-and-filter.** For each hop up to
|
||||
`max_path_length`:
|
||||
|
||||
a. Retrieve all edges one hop from the current frontier nodes.
|
||||
|
||||
b. Represent each edge as `"{predicate} {object}"`.
|
||||
|
||||
c. Score edges against the extracted concepts using the
|
||||
cross-encoder service.
|
||||
|
||||
d. Select the top relevant edges. The target nodes of the
|
||||
selected edges become the frontier for the next hop.
|
||||
|
||||
4. **Document tracing.** Selected edges are traced back to source
|
||||
documents.
|
||||
|
||||
5. **Synthesis.** The `kg-synthesis` prompt generates the final
|
||||
answer from the selected edges and source document metadata.
|
||||
|
||||
### Implementation order
|
||||
|
||||
1. Cross-encoder service with full integration (base schema,
|
||||
flow service, gateway endpoint, CLI tool).
|
||||
2. GraphRAG pipeline changes (iterative hop-and-filter,
|
||||
edge representation, remove pre-filter).
|
||||
3. Explainability update (`tg:EdgeSelection` type, concept
|
||||
and score predicates, retire `tg:reasoning`).
|
||||
4. Retire `kg-edge-scoring` and `kg-edge-reasoning` prompts.
|
||||
5. Update `tg-invoke-graph-rag` and `tg-show-explain-trace`
|
||||
to display the new metadata. Use these as the main
|
||||
end-to-end test.
|
||||
6. Fix any failing unit tests, then add new tests as needed.
|
||||
7. Write guidance for UX devs to update the UI for the new
|
||||
explainability predicates.
|
||||
|
||||
## UX developer guidance
|
||||
|
||||
This section describes the changes to the explainability interface
|
||||
that affect frontend rendering of GraphRAG Focus events.
|
||||
|
||||
### What changed
|
||||
|
||||
Edge selection in GraphRAG previously used LLM-based scoring and
|
||||
reasoning. Each selected edge carried a `tg:reasoning` predicate
|
||||
with free-text explanation from the LLM. This has been replaced
|
||||
by a cross-encoder reranker that scores edges against query
|
||||
concepts. The explainability data now carries structured metadata
|
||||
instead of free text.
|
||||
|
||||
### Removed
|
||||
|
||||
- **`tg:reasoning`** is no longer emitted on edge selection
|
||||
entities in GraphRAG Focus events. UX code that reads
|
||||
`edge_sel.reasoning` will get an empty string. Remove any
|
||||
rendering that displays a "Reasoning" or "Reason" field for
|
||||
Focus edges.
|
||||
|
||||
- The **`kg-edge-scoring`**, **`kg-edge-reasoning`**, and
|
||||
**`kg-edge-selection`** prompts are retired. Any UX that
|
||||
references these prompt names should be cleaned up.
|
||||
|
||||
### Added
|
||||
|
||||
Each edge selection entity within a Focus event now has three
|
||||
new properties:
|
||||
|
||||
| RDF predicate | API field | Type | Description |
|---|---|---|---|
| `rdf:type tg:EdgeSelection` | (type check) | - | Each edge selection entity is now explicitly typed |
| `tg:concept` | `edge_sel.concept` | `str` | The query concept that matched this edge |
| `tg:score` | `edge_sel.score` | `float` or `None` | Cross-encoder relevance score (0.0–1.0) |
|
||||
|
||||
The `tg:edge` predicate (RDF-star quoted triple) is unchanged.
|
||||
|
||||
### How to render
|
||||
|
||||
The recommended rendering for each selected edge in a Focus event:
|
||||
|
||||
```
Edge: (subject_label, predicate_label, object_label)
Concept: <concept> Score: <score formatted to 4 decimal places>
```
|
||||
|
||||
Scores near 1.0 indicate high relevance; scores near 0.0 indicate
|
||||
low relevance. UX could use the score to drive visual indicators
|
||||
such as colour intensity or a relevance bar.
|
||||
|
||||
Edges are not returned in score order — they arrive in traversal
|
||||
order across hops. If the UX wants to display edges ranked by
|
||||
relevance, sort by `edge_sel.score` descending.
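
A minimal sketch of the rendering and sorting advice above, written in Python for consistency with the API classes in this guide. The `EdgeSelection` class here is a local stand-in carrying only the fields the example needs, `render_focus_edges` is a hypothetical helper, and placing unscored edges last with an `n/a` score is an assumption of this example.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class EdgeSelection:  # local stand-in, not the trustgraph.api class
    uri: str
    edge: Optional[Dict[str, str]]
    concept: str = ""
    score: Optional[float] = None


def render_focus_edges(selections: List[EdgeSelection]) -> List[str]:
    """Sort by score (descending) and format each edge as two text lines."""
    ranked = sorted(
        selections,
        # Edges without a score sort last.
        key=lambda sel: sel.score if sel.score is not None else float("-inf"),
        reverse=True,
    )
    lines: List[str] = []
    for sel in ranked:
        edge = sel.edge or {}
        score = f"{sel.score:.4f}" if sel.score is not None else "n/a"
        lines.append(
            f"Edge: ({edge.get('s', '?')}, {edge.get('p', '?')}, {edge.get('o', '?')})"
        )
        lines.append(f"Concept: {sel.concept} Score: {score}")
    return lines
```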
|
||||
|
||||
### API classes (Python)
|
||||
|
||||
The `EdgeSelection` dataclass in `trustgraph.api.explainability`
|
||||
has these fields:
|
||||
|
||||
```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class EdgeSelection:
    uri: str
    edge: Optional[Dict[str, str]]  # {"s": ..., "p": ..., "o": ...}
    reasoning: str = ""             # Legacy, always empty for new traces
    concept: str = ""               # Query concept that matched
    score: Optional[float] = None   # Cross-encoder relevance score
```
|
||||
|
||||
These are populated when calling
|
||||
`ExplainabilityClient.fetch_focus_with_edges()` or when parsing
|
||||
inline provenance triples from the streaming response.
|
||||
|
||||
### WebSocket response format
|
||||
|
||||
For inline explainability via the streaming WebSocket, Focus events
|
||||
arrive as `message_type: "explain"` responses. The `explain_triples`
|
||||
array contains the edge selection triples. The relevant predicates
|
||||
in wire format are:
|
||||
|
||||
```json
|
||||
{"s": {"t": "i", "i": "<edge_sel_uri>"},
|
||||
"p": {"t": "i", "i": "https://trustgraph.ai/ns/concept"},
|
||||
"o": {"t": "l", "v": "flyby event"}}
|
||||
|
||||
{"s": {"t": "i", "i": "<edge_sel_uri>"},
|
||||
"p": {"t": "i", "i": "https://trustgraph.ai/ns/score"},
|
||||
"o": {"t": "l", "v": "0.9962"}}
|
||||
```
|
||||
|
||||
Note that `tg:score` is transmitted as a string literal and must
|
||||
be parsed to a float on the client side.
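
A client-side parse of these triples might look like the sketch below. It
assumes only the wire shape shown above (`s`/`p`/`o` objects with `i` for IRIs
and `v` for literal values); `parse_score` and `collect_edge_fields` are
hypothetical helper names, not part of the TrustGraph API.

```python
from typing import Dict, List, Optional

TG_CONCEPT = "https://trustgraph.ai/ns/concept"
TG_SCORE = "https://trustgraph.ai/ns/score"


def parse_score(value) -> Optional[float]:
    """Parse a score literal; return None if it is missing or malformed."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def collect_edge_fields(explain_triples: List[dict]) -> Dict[str, dict]:
    """Group concept and score literals by edge selection URI."""
    out: Dict[str, dict] = {}
    for t in explain_triples:
        predicate = t["p"].get("i")
        if predicate not in (TG_CONCEPT, TG_SCORE):
            continue  # other predicates (e.g. tg:edge) are ignored here
        subject = t["s"].get("i")
        entry = out.setdefault(subject, {"concept": "", "score": None})
        literal = t["o"].get("v")
        if predicate == TG_CONCEPT:
            entry["concept"] = literal or ""
        else:
            entry["score"] = parse_score(literal)
    return out
```

Returning `None` for an unparseable score matches the `float` or `None` type
given for `edge_sel.score`, so downstream rendering needs only one guard.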

### Exploration event

The Exploration event's `edge_count` field now reports the number
of edges selected by the cross-encoder across all hops (previously
it reported the total number of edges retrieved before filtering).
The `entities` list continues to report the seed entities found
by vector search.
2603
install_trustgraph.sh
Normal file
File diff suppressed because it is too large
|
|
@ -95,10 +95,6 @@ class TestGraphRagIntegration:
|
|||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
|
|
@ -113,14 +109,22 @@ class TestGraphRagIntegration:
|
|||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_reranker_client(self):
|
||||
"""Mock reranker client for cross-encoder edge filtering"""
|
||||
client = AsyncMock()
|
||||
client.rerank.return_value = []
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||
mock_triples_client, mock_prompt_client):
|
||||
mock_triples_client, mock_reranker_client, mock_prompt_client):
|
||||
"""Create GraphRag instance with mocked dependencies"""
|
||||
return GraphRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
verbose=True
|
||||
)
|
||||
|
|
@ -167,8 +171,8 @@ class TestGraphRagIntegration:
|
|||
# 3. Should query triples to build knowledge subgraph
|
||||
assert mock_triples_client.query_stream.call_count > 0
|
||||
|
||||
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 4
|
||||
# 4. Should call prompt twice (extract-concepts + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 2
|
||||
|
||||
# Verify final response
|
||||
response, usage = response
|
||||
|
|
|
|||
|
|
@ -63,11 +63,6 @@ class TestGraphRagStreaming:
|
|||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_id == "kg-edge-scoring":
|
||||
# Edge scoring returns JSONL with IDs and scores
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
|
||||
elif prompt_id == "kg-edge-reasoning":
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
|
|
@ -88,14 +83,23 @@ class TestGraphRagStreaming:
|
|||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_reranker_client(self):
|
||||
"""Mock reranker client for cross-encoder edge filtering"""
|
||||
client = AsyncMock()
|
||||
client.rerank.return_value = []
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||
mock_triples_client, mock_streaming_prompt_client):
|
||||
mock_triples_client, mock_reranker_client,
|
||||
mock_streaming_prompt_client):
|
||||
"""Create GraphRag instance with streaming support"""
|
||||
return GraphRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
prompt_client=mock_streaming_prompt_client,
|
||||
verbose=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class TestGraphRagStreamingProtocol:
|
|||
client = AsyncMock()
|
||||
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
|
|
@ -63,14 +63,23 @@ class TestGraphRagStreamingProtocol:
|
|||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_reranker_client(self):
|
||||
"""Mock reranker client for cross-encoder edge filtering"""
|
||||
client = AsyncMock()
|
||||
client.rerank.return_value = []
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||
mock_triples_client, mock_streaming_prompt_client):
|
||||
mock_triples_client, mock_reranker_client,
|
||||
mock_streaming_prompt_client):
|
||||
"""Create GraphRag instance with mocked dependencies"""
|
||||
return GraphRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
prompt_client=mock_streaming_prompt_client,
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -327,7 +336,7 @@ class TestStreamingProtocolEdgeCases:
|
|||
client = AsyncMock()
|
||||
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
|
|
@ -342,10 +351,14 @@ class TestStreamingProtocolEdgeCases:
|
|||
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
mock_reranker = AsyncMock()
|
||||
mock_reranker.rerank.return_value = []
|
||||
|
||||
rag = GraphRag(
|
||||
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
|
||||
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
reranker_client=mock_reranker,
|
||||
prompt_client=client,
|
||||
verbose=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,11 +15,20 @@ from openai.types.chat.chat_completion import Choice
|
|||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from trustgraph.model.text_completion.openai.llm import Processor
|
||||
from trustgraph.model.text_completion.openai.variants import get_variant
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
|
||||
|
||||
def _wire_variant(processor):
|
||||
"""Attach variant methods to a MagicMock processor."""
|
||||
processor.variant = get_variant("openai")
|
||||
processor.thinking = "off"
|
||||
processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor)
|
||||
processor._extract_content = Processor._extract_content.__get__(processor, Processor)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestTextCompletionIntegration:
|
||||
"""Integration tests for OpenAI text completion service coordination"""
|
||||
|
|
@ -66,6 +75,7 @@ class TestTextCompletionIntegration:
|
|||
|
||||
# Add the actual generate_content method from Processor class
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
_wire_variant(processor)
|
||||
|
||||
return processor
|
||||
|
||||
|
|
@ -119,6 +129,7 @@ class TestTextCompletionIntegration:
|
|||
|
||||
# Add the actual generate_content method
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
|
@ -129,7 +140,7 @@ class TestTextCompletionIntegration:
|
|||
assert result.in_token == 50
|
||||
assert result.out_token == 100
|
||||
# Note: result.model comes from mock response, not processor config
|
||||
|
||||
|
||||
# Verify configuration was applied
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_args.kwargs['model'] == config['model']
|
||||
|
|
@ -247,6 +258,7 @@ class TestTextCompletionIntegration:
|
|||
processor.max_output = processor_config["max_output"]
|
||||
processor.openai = mock_openai_client
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
_wire_variant(processor)
|
||||
processors.append(processor)
|
||||
|
||||
# Simulate multiple concurrent requests
|
||||
|
|
@ -354,6 +366,7 @@ class TestTextCompletionIntegration:
|
|||
processor.max_output = 2048
|
||||
processor.openai = mock_openai_client
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionChunk
|
|||
from openai.types.chat.chat_completion_chunk import Choice as StreamChoice, ChoiceDelta
|
||||
|
||||
from trustgraph.model.text_completion.openai.llm import Processor
|
||||
from trustgraph.model.text_completion.openai.variants import get_variant
|
||||
from trustgraph.base import LlmChunk
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
|
|
@ -18,6 +19,14 @@ from tests.utils.streaming_assertions import (
|
|||
)
|
||||
|
||||
|
||||
def _wire_variant(processor):
|
||||
"""Attach variant methods to a MagicMock processor."""
|
||||
processor.variant = get_variant("openai")
|
||||
processor.thinking = "off"
|
||||
processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor)
|
||||
processor._extract_content = Processor._extract_content.__get__(processor, Processor)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestTextCompletionStreaming:
|
||||
"""Integration tests for Text Completion streaming"""
|
||||
|
|
@ -69,6 +78,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
return processor
|
||||
|
||||
|
|
@ -190,6 +200,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act
|
||||
chunks = []
|
||||
|
|
@ -223,6 +234,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act
|
||||
chunks = []
|
||||
|
|
@ -258,6 +270,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
|
|
@ -295,6 +308,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act
|
||||
chunks = []
|
||||
|
|
@ -318,6 +332,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
system_prompt = "You are an expert."
|
||||
user_prompt = "Explain quantum physics."
|
||||
|
|
|
|||
|
|
@ -195,38 +195,6 @@ class TestPromptClientStreamingCallback:
|
|||
assert callback.call_args_list[0] == call("test", False)
|
||||
assert callback.call_args_list[1] == call("", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
|
||||
"""Test that kg_prompt correctly passes streaming parameters"""
|
||||
# Arrange
|
||||
async def mock_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
responses = [
|
||||
PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||
]
|
||||
for resp in responses:
|
||||
should_stop = await recipient(resp)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
prompt_client.request = mock_request
|
||||
|
||||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
await prompt_client.kg_prompt(
|
||||
query="What is machine learning?",
|
||||
kg=[("subject", "predicate", "object")],
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert callback.call_count == 2
|
||||
assert callback.call_args_list[0] == call("Answer", False)
|
||||
assert callback.call_args_list[1] == call("", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
|
||||
"""Test that document_prompt correctly passes streaming parameters"""
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ class TestGraphRagDagStructure:
|
|||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
reranker_client = AsyncMock()
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
graph_embeddings_client.query.return_value = [
|
||||
|
|
@ -121,27 +122,22 @@ class TestGraphRagDagStructure:
|
|||
]
|
||||
triples_client.query.return_value = []
|
||||
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.95
|
||||
reranker_client.rerank.return_value = [result]
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="concept")
|
||||
elif template_id == "kg-edge-scoring":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "score": 10} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text="Answer.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
return (prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, reranker_client)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self, mock_clients):
|
||||
|
|
@ -152,7 +148,7 @@ class TestGraphRagDagStructure:
|
|||
events.append({"explain_id": explain_id, "triples": triples})
|
||||
|
||||
await rag.query(
|
||||
query="test", explain_callback=explain_cb, edge_score_limit=0,
|
||||
query="test", explain_callback=explain_cb,
|
||||
)
|
||||
|
||||
dag = _collect_events(events)
|
||||
|
|
|
|||
|
|
@ -101,27 +101,27 @@ class TestQuery:
|
|||
assert query.rag == mock_rag
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.doc_limit == 20 # Default value
|
||||
assert query.fetch_limit == 20 # Default value
|
||||
|
||||
def test_query_initialization_with_custom_doc_limit(self):
|
||||
"""Test Query initialization with custom doc_limit"""
|
||||
def test_query_initialization_with_custom_fetch_limit(self):
|
||||
"""Test Query initialization with custom fetch_limit"""
|
||||
# Create mock DocumentRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with custom doc_limit
|
||||
# Initialize Query with custom fetch_limit
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
workspace="test_workspace",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
doc_limit=50
|
||||
fetch_limit=50
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.doc_limit == 50
|
||||
assert query.fetch_limit == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_concepts(self):
|
||||
|
|
@ -224,7 +224,7 @@ class TestQuery:
|
|||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=15
|
||||
fetch_limit=15
|
||||
)
|
||||
|
||||
# Call get_docs with concepts list
|
||||
|
|
@ -377,7 +377,7 @@ class TestQuery:
|
|||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=True,
|
||||
doc_limit=5
|
||||
fetch_limit=5
|
||||
)
|
||||
|
||||
# Call get_docs with concepts
|
||||
|
|
@ -615,7 +615,7 @@ class TestQuery:
|
|||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=10
|
||||
fetch_limit=10
|
||||
)
|
||||
|
||||
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])
|
||||
|
|
|
|||
478
tests/unit/test_retrieval/test_document_rag_rerank.py
Normal file
|
|
@ -0,0 +1,478 @@
|
|||
"""
|
||||
Tests for the optional cross-encoder reranking pass in DocumentRag.query().
|
||||
|
||||
Two behaviours are covered:
|
||||
|
||||
1. No-op: when no reranker_client is wired (the default), query() must feed
|
||||
the LLM the exact same chunks, in the same order, that retrieval produced
|
||||
- byte-identical to the pre-reranker behaviour - and must NOT emit a
|
||||
chunk-selection provenance event.
|
||||
|
||||
2. Rerank: when a reranker_client is wired, the retrieved chunks are reordered
|
||||
and truncated according to the reranker's results, the LLM receives the
|
||||
reranked top-N, and a tg:ChunkSelection (focus) provenance event is emitted
|
||||
carrying the per-surviving-chunk scores and chunk references.
|
||||
|
||||
These are pure orchestration tests - the reranker is a stub, so there is no
|
||||
torch / network dependency.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.base import PromptResult
|
||||
from trustgraph.schema import RerankerResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_WAS_DERIVED_FROM,
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_CHUNK_SELECTION, TG_SELECTED_CHUNK, TG_SCORE, TG_DOCUMENT,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate
|
||||
and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
return any(
|
||||
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
|
||||
for t in triples
|
||||
)
|
||||
|
||||
|
||||
def derived_from(triples, subject):
|
||||
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
|
||||
return t.o.iri if t else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkMatch:
|
||||
"""Mimics the result from doc_embeddings_client.query()."""
|
||||
chunk_id: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: three retrievable chunks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
|
||||
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
|
||||
CHUNK_C = "urn:chunk:policy-doc-1:chunk-2"
|
||||
|
||||
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
|
||||
CHUNK_B_CONTENT = "Our stores are open from 9am to 5pm on weekdays."
|
||||
CHUNK_C_CONTENT = "Refunds are processed to the original payment method."
|
||||
|
||||
# Retrieval (post-dedupe) order is A, B, C.
|
||||
ORDERED_CONTENT = [CHUNK_A_CONTENT, CHUNK_B_CONTENT, CHUNK_C_CONTENT]
|
||||
ORDERED_CHUNK_IDS = [CHUNK_A, CHUNK_B, CHUNK_C]
|
||||
|
||||
|
||||
def build_mock_clients():
|
||||
"""
|
||||
Build mock subsidiary clients for a document-rag query returning three
|
||||
distinct chunks (A, B, C) in that order.
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
doc_embeddings_client = AsyncMock()
|
||||
fetch_chunk = AsyncMock()
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="return policy\nrefund")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# Each concept query returns the same three chunks; dedupe keeps A, B, C.
|
||||
doc_embeddings_client.query.return_value = [
|
||||
ChunkMatch(chunk_id=CHUNK_A),
|
||||
ChunkMatch(chunk_id=CHUNK_B),
|
||||
ChunkMatch(chunk_id=CHUNK_C),
|
||||
]
|
||||
|
||||
async def mock_fetch(chunk_id):
|
||||
return {
|
||||
CHUNK_A: CHUNK_A_CONTENT,
|
||||
CHUNK_B: CHUNK_B_CONTENT,
|
||||
CHUNK_C: CHUNK_C_CONTENT,
|
||||
}[chunk_id]
|
||||
|
||||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Items can be returned within 30 days for a full refund.",
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
||||
|
||||
class StubReranker:
|
||||
"""
|
||||
Stub reranker_client mirroring RerankerClient.rerank(): returns a fixed,
|
||||
pre-sorted, truncated list of RerankerResult - exactly the contract the
|
||||
flashrank service guarantees (sorted desc by score, truncated to limit).
|
||||
"""
|
||||
|
||||
def __init__(self, results):
|
||||
self._results = results
|
||||
self.calls = []
|
||||
|
||||
async def rerank(self, queries, documents, limit=10, timeout=300):
|
||||
self.calls.append(
|
||||
{"queries": queries, "documents": documents, "limit": limit}
|
||||
)
|
||||
return self._results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. No-op: reranker_client=None must not change anything
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRerankNoOp:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_documents_passed_to_llm_are_unchanged(self):
|
||||
"""
|
||||
With no reranker wired, document_prompt must receive the retrieved
|
||||
chunks in the original order and length.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients) # reranker_client defaults to None
|
||||
|
||||
await rag.query(query="What is the return policy?")
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
assert passed_docs == ORDERED_CONTENT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_chunk_selection_event_emitted(self):
|
||||
"""
|
||||
Without a reranker, the provenance chain is the original 4 stages:
|
||||
question, grounding, exploration, synthesis - no focus stage.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
assert len(events) == 4
|
||||
types = [
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_SYNTHESIS,
|
||||
]
|
||||
for i, expected in enumerate(types):
|
||||
assert has_type(events[i]["triples"], events[i]["explain_id"], expected)
|
||||
|
||||
# No chunk-selection entity anywhere.
|
||||
for e in events:
|
||||
assert not any(
|
||||
t.o.iri == TG_CHUNK_SELECTION
|
||||
for t in e["triples"]
|
||||
if t.p.iri == RDF_TYPE
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_derives_from_exploration_when_no_rerank(self):
|
||||
"""
|
||||
No-op lineage is unchanged: synthesis derives from exploration
|
||||
(there is no focus stage). Guards the conditional synthesis parent.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
# events: question, grounding, exploration, synthesis
|
||||
exp_uri = events[2]["explain_id"]
|
||||
syn_event = events[3]
|
||||
assert derived_from(syn_event["triples"], syn_event["explain_id"]) == exp_uri
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Rerank: reorder + truncate + provenance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRerankActive:
|
||||
|
||||
def _reranker_keeping_C_then_A(self):
|
||||
# Reranker says chunk index 2 (C) is best, then index 0 (A); B dropped.
|
||||
# Pre-sorted desc by score and truncated to limit, per the contract.
|
||||
return StubReranker([
|
||||
RerankerResult(document_id="2", query_id="0", score=0.95),
|
||||
RerankerResult(document_id="0", query_id="0", score=0.42),
|
||||
])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_documents_reordered_and_truncated(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?")
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
assert passed_docs == [CHUNK_C_CONTENT, CHUNK_A_CONTENT]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reranker_called_with_single_query_and_all_docs(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=2)
|
||||
|
||||
assert len(reranker.calls) == 1
|
||||
c = reranker.calls[0]
|
||||
assert c["queries"] == [{"id": "0", "text": "What is the return policy?"}]
|
||||
assert c["documents"] == [
|
||||
{"id": "0", "text": CHUNK_A_CONTENT},
|
||||
{"id": "1", "text": CHUNK_B_CONTENT},
|
||||
{"id": "2", "text": CHUNK_C_CONTENT},
|
||||
]
|
||||
# The rerank narrows down to the final doc_limit, NOT fetch_limit
|
||||
# (fetch_limit is the over-fetched candidate pool size).
|
||||
assert c["limit"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_fetch_limit_over_fetches_then_narrows(self):
|
||||
"""
|
||||
Semantic guard for the value of reranking AND the maintainer's two-limit
|
||||
contract: an explicit fetch_limit makes retrieval OVER-FETCH a wider
|
||||
candidate pool so the cross-encoder can surface chunks the bi-encoder
|
||||
ranked outside the final doc_limit, then the rerank narrows the pool back
|
||||
down to doc_limit. The fetch_limit is honoured directly (caller controls
|
||||
how hard the reranker works), not overridden by any heuristic.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
# Candidate pool (fetch_limit=60) >> final doc_limit (6).
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?", doc_limit=6, fetch_limit=60,
|
||||
)
|
||||
|
||||
# Over-fetch: the embeddings store is queried with the fetch_limit
|
||||
# budget (60 // 2 concept-vectors = 30 per concept), NOT the doc_limit
|
||||
# budget (6 // 2 = 3). This is the bug guard.
|
||||
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
|
||||
assert q_limit == 30
|
||||
|
||||
# Narrow: the rerank keeps the final doc_limit (6), not fetch_limit.
|
||||
assert reranker.calls[0]["limit"] == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_fetch_limit_derives_overfetch_from_doc_limit(self):
|
||||
"""
|
||||
With no fetch_limit passed to query(), the candidate pool falls back to
|
||||
the OVERFETCH_FACTOR x doc_limit heuristic, so over-fetch scales with
|
||||
doc_limit and reranking keeps its recall benefit out of the box.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
# No fetch_limit -> heuristic default.
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=20)
|
||||
|
||||
# fetch = 3 x 20 = 60 -> 60 // 2 concept-vectors = 30 per concept.
|
||||
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
|
||||
assert q_limit == 30
|
||||
# Rerank narrows to the final doc_limit (20).
|
||||
assert reranker.calls[0]["limit"] == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_limit_floored_at_doc_limit(self):
|
||||
"""
|
||||
A fetch_limit below doc_limit is floored up to doc_limit: retrieval must
|
||||
never fetch fewer candidates than the rerank is asked to keep, else the
|
||||
prompt could not be filled.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?", doc_limit=10, fetch_limit=4,
|
||||
)
|
||||
|
||||
# fetch = max(4, 10) = 10 -> 10 // 2 concept-vectors = 5 per concept.
|
||||
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
|
||||
assert q_limit == 5
|
||||
assert reranker.calls[0]["limit"] == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_selection_event_emitted(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
# Now 5 stages: question, grounding, exploration, focus, synthesis.
|
||||
assert len(events) == 5
|
||||
ordered_types = [
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS,
|
||||
]
|
||||
for i, expected in enumerate(ordered_types):
|
||||
assert has_type(events[i]["triples"], events[i]["explain_id"], expected)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_selection_carries_scores_and_chunk_refs(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
focus_event = events[3]
|
||||
foc_uri = focus_event["explain_id"]
|
||||
triples = focus_event["triples"]
|
||||
|
||||
# focus is derived from exploration
|
||||
exp_uri = events[2]["explain_id"]
|
||||
assert derived_from(triples, foc_uri) == exp_uri
|
||||
|
||||
# Two ChunkSelection sub-entities, linked from focus.
|
||||
sel_links = find_triples(triples, TG_SELECTED_CHUNK, foc_uri)
|
||||
assert len(sel_links) == 2
|
||||
|
||||
# Each selection has a ChunkSelection type, a chunk document ref and a score.
|
||||
chunk_refs = set()
|
||||
scores = set()
|
||||
for link in sel_links:
|
||||
sel_uri = link.o.iri
|
||||
assert has_type(triples, sel_uri, TG_CHUNK_SELECTION)
|
||||
doc_ref = find_triple(triples, TG_DOCUMENT, sel_uri)
|
||||
assert doc_ref is not None
|
||||
chunk_refs.add(doc_ref.o.iri)
|
||||
score_t = find_triple(triples, TG_SCORE, sel_uri)
|
||||
assert score_t is not None
|
||||
scores.add(score_t.o.value)
|
||||
|
||||
# Surviving chunks are C and A (B dropped), with the reranker scores.
|
||||
assert chunk_refs == {CHUNK_C, CHUNK_A}
|
||||
assert scores == {"0.95", "0.42"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_focus_triples_in_retrieval_graph(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
for t in events[3]["triples"]:
|
||||
assert t.g == "urn:graph:retrieval"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_derives_from_focus_when_reranking(self):
|
||||
"""
|
||||
When reranking runs, synthesis must derive from the focus node (the
|
||||
reranked chunks actually fed to the LLM), mirroring GraphRAG - not from
|
||||
exploration, which would leave focus as a dangling branch and
|
||||
misrepresent what fed the answer.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
doc_limit=2,
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
# events: question, grounding, exploration, focus, synthesis
|
||||
foc_uri = events[3]["explain_id"]
|
||||
syn_event = events[4]
|
||||
assert derived_from(syn_event["triples"], syn_event["explain_id"]) == foc_uri
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_docs_skips_reranker(self):
|
||||
"""If retrieval returns no chunks, the reranker is never called."""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
doc_embeddings_client.query.return_value = [] # no matches
|
||||
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?")
|
||||
|
||||
assert reranker.calls == []
|
||||
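The three limit tests above pin down how the candidate pool is sized. A minimal sketch of that arithmetic, assuming an `OVERFETCH_FACTOR` of 3 (inferred from the `3 x 20 = 60` test comment) and treating a falsy `fetch_limit` as unset. The helper name `per_concept_limit` and its `n_concepts` parameter are hypothetical and are not taken from `DocumentRag` itself:

```python
OVERFETCH_FACTOR = 3  # assumed value, inferred from the "3 x 20 = 60" test comment


def per_concept_limit(doc_limit, fetch_limit=0, n_concepts=1):
    """Hypothetical helper: size the per-concept embeddings query."""
    if not fetch_limit:
        # Unset (0 or None): fall back to the over-fetch heuristic.
        fetch = OVERFETCH_FACTOR * doc_limit
    else:
        # An explicit value is honoured, but never below what the rerank must keep.
        fetch = max(fetch_limit, doc_limit)
    # Split the pool evenly across the concept vectors being queried.
    return fetch // max(1, n_concepts)


# The three scenarios exercised by the tests (two concept vectors each).
print(per_concept_limit(doc_limit=6, fetch_limit=60, n_concepts=2))
print(per_concept_limit(doc_limit=20, n_concepts=2))
print(per_concept_limit(doc_limit=10, fetch_limit=4, n_concepts=2))
```

Under these assumptions the three calls reproduce the per-concept limits asserted in the tests (30, 30 and 5), while the rerank step still narrows to `doc_limit` in every case.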
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Cross-layer wiring contract for the Document-RAG reranker (issue #878).
|
||||
|
||||
The Document-RAG processor registers a ``RerankerClientSpec`` for the
|
||||
``reranker-request`` / ``reranker-response`` roles (see
|
||||
``retrieval/document_rag/rag.py``). At flow construction every spec runs
|
||||
``spec.add(flow, processor, definition)``, and ``RequestResponseSpec.add``
|
||||
resolves its topics via ``definition["topics"][name]`` - which raises
|
||||
``KeyError`` if the flow blueprint does not provide those topics.
|
||||
|
||||
This means the monorepo code change is only safe to deploy together with the
|
||||
companion ``trustgraph-templates`` change that wires ``reranker-request`` /
|
||||
``reranker-response`` into the Document-RAG flow (mirroring what templates
|
||||
PR #279 did for GraphRAG via ``graph-store.jsonnet``). These tests pin that
|
||||
contract from the monorepo side:
|
||||
|
||||
* with the reranker topics present (as the updated templates compile them),
|
||||
the spec binds cleanly and registers the client;
|
||||
* without them (the pre-companion blueprint), construction fails fast with a
|
||||
KeyError naming the missing role - documenting exactly why the templates
|
||||
change is required.
|
||||
|
||||
No broker/network: the pub/sub backend is mocked (topics are bound at add()
|
||||
time, connections happen later at start()).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.base import RerankerClientSpec
|
||||
|
||||
|
||||
def _flow():
|
||||
f = MagicMock()
|
||||
f.workspace = "ws"
|
||||
f.name = "document-rag"
|
||||
f.id = "proc1"
|
||||
f.consumer = {}
|
||||
return f
|
||||
|
||||
|
||||
def _processor():
|
||||
p = MagicMock()
|
||||
p.pubsub = MagicMock()
|
||||
p.id = "proc1"
|
||||
p.taskgroup = MagicMock()
|
||||
return p
|
||||
|
||||
|
||||
def _spec():
|
||||
return RerankerClientSpec(
|
||||
request_name="reranker-request",
|
||||
response_name="reranker-response",
|
||||
)
|
||||
|
||||
|
||||
# Topics dict as the UPDATED document-store.jsonnet compiles them
|
||||
# (verified by compiling the template: reranker-request -> request:tg:reranker:{workspace}:{id}).
|
||||
DEFINITION_WITH_RERANKER = {
|
||||
"topics": {
|
||||
"request": "request:tg:document-rag:ws:id",
|
||||
"response": "response:tg:document-rag:ws:id",
|
||||
"reranker-request": "request:tg:reranker:ws:id",
|
||||
"reranker-response": "response:tg:reranker:ws:id",
|
||||
}
|
||||
}
|
||||
|
||||
# Pre-companion blueprint: no reranker topics (document-rag before the templates change).
|
||||
DEFINITION_WITHOUT_RERANKER = {
|
||||
"topics": {
|
||||
"request": "request:tg:document-rag:ws:id",
|
||||
"response": "response:tg:document-rag:ws:id",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_reranker_client_binds_when_flow_provides_topics():
|
||||
flow = _flow()
|
||||
_spec().add(flow, _processor(), DEFINITION_WITH_RERANKER)
|
||||
# The client consumer is registered against the reranker role.
|
||||
assert "reranker-request" in flow.consumer
|
||||
|
||||
|
||||
def test_reranker_client_keyerrors_without_companion_template_topics():
|
||||
with pytest.raises(KeyError) as exc:
|
||||
_spec().add(_flow(), _processor(), DEFINITION_WITHOUT_RERANKER)
|
||||
# Fails fast naming the missing role -> the trustgraph-templates companion
|
||||
# change (wire reranker-request/response into the document-rag flow) is required.
|
||||
assert "reranker-request" in str(exc.value)
|
||||
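The contract described in the module docstring comes down to a plain dictionary lookup. A minimal sketch, assuming only what the docstring states (topics are resolved via `definition["topics"][name]`). The function `resolve_topics` is a hypothetical stand-in and not the `RequestResponseSpec.add` implementation:

```python
def resolve_topics(definition, names):
    """Hypothetical stand-in for the lookup done at flow construction."""
    topics = definition["topics"]
    # A missing role raises KeyError here, before any connection is made.
    return {name: topics[name] for name in names}


ROLES = ["reranker-request", "reranker-response"]

with_reranker = {"topics": {
    "reranker-request": "request:tg:reranker:ws:id",
    "reranker-response": "response:tg:reranker:ws:id",
}}
without_reranker = {"topics": {"request": "request:tg:document-rag:ws:id"}}

# Blueprint provides the roles: every name resolves to a topic string.
bound = resolve_topics(with_reranker, ROLES)

# Blueprint lacks the roles: the first missing name surfaces in the KeyError.
try:
    resolve_topics(without_reranker, ROLES)
    missing = None
except KeyError as exc:
    missing = exc.args[0]
```

Because the lookup happens while the flow is being assembled, a blueprint without the reranker topics fails at construction rather than at the first query, which is the behaviour the second test relies on.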
|
|
@ -66,6 +66,7 @@ class TestDocumentRagService:
|
|||
workspace=ANY, # Workspace comes from flow.workspace (mock)
|
||||
collection="test_coll_1", # Must be from message, not hardcoded default
|
||||
doc_limit=5,
|
||||
fetch_limit=0, # Unset -> core derives the candidate pool
|
||||
explain_callback=ANY, # Explainability callback is always passed
|
||||
save_answer_callback=ANY, # Librarian save callback is always passed
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,54 +15,52 @@ class TestGraphRag:
|
|||
|
||||
def test_graph_rag_initialization_with_defaults(self):
|
||||
"""Test GraphRag initialization with default verbose setting"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
mock_reranker_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is False # Default value
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag with verbose=True
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=True
|
||||
reranker_client=mock_reranker_client,
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.reranker_client == mock_reranker_client
|
||||
assert graph_rag.verbose is False
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
mock_reranker_client = MagicMock()
|
||||
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.reranker_client == mock_reranker_client
|
||||
assert graph_rag.verbose is True
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
|
|
@ -365,244 +363,162 @@ class TestQuery:
|
|||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_never_passes_workspace(self):
|
||||
"""Verify follow_edges never passes workspace to query_stream."""
|
||||
async def test_hop_and_filter_never_passes_workspace(self):
|
||||
"""Verify hop_and_filter never passes workspace to query_stream."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
|
||||
mock_triple.s = "e1"
|
||||
mock_triple.p = "p1"
|
||||
mock_triple.o = "o1"
|
||||
mock_triples_client.query_stream.return_value = [mock_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.9
|
||||
mock_reranker_client.rerank.return_value = [result]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
triple_limit=10,
|
||||
)
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("e1", subgraph, path_length=1)
|
||||
await query.hop_and_filter(["e1"], ["concept"])
|
||||
|
||||
for c in mock_triples_client.query_stream.call_args_list:
|
||||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
"""Test Query.follow_edges method basic triple discovery"""
|
||||
async def test_hop_and_filter_basic_functionality(self):
|
||||
"""Test hop_and_filter retrieves edges and scores them with reranker."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
mock_triple1 = MagicMock()
|
||||
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.s = "entity1"
|
||||
mock_triple.p = "predicate1"
|
||||
mock_triple.o = "object1"
|
||||
mock_triples_client.query_stream.return_value = [mock_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
mock_triple2 = MagicMock()
|
||||
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
|
||||
|
||||
mock_triple3 = MagicMock()
|
||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||
|
||||
mock_triples_client.query_stream.side_effect = [
|
||||
[mock_triple1], # s=ent
|
||||
[mock_triple2], # p=ent
|
||||
[mock_triple3], # o=ent
|
||||
]
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.95
|
||||
mock_reranker_client.rerank.return_value = [result]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
triple_limit=10,
|
||||
edge_limit=25,
|
||||
)
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
assert mock_triples_client.query_stream.call_count == 3
|
||||
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter(
|
||||
["entity1"], ["test concept"],
|
||||
)
|
||||
|
||||
expected_subgraph = {
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "entity1", "object2"),
|
||||
("subject3", "predicate3", "entity1")
|
||||
}
|
||||
assert subgraph == expected_subgraph
|
||||
assert len(selected) == 1
|
||||
assert len(uri_map) == 1
|
||||
assert len(edge_meta) == 1
|
||||
|
||||
mock_reranker_client.rerank.assert_called_once()
|
||||
call_kwargs = mock_reranker_client.rerank.call_args
|
||||
assert call_kwargs.kwargs["limit"] == 25
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_path_length_zero(self):
|
||||
"""Test Query.follow_edges method with path_length=0"""
|
||||
async def test_hop_and_filter_with_empty_frontier(self):
|
||||
"""Test hop_and_filter with no seed entities returns empty."""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"])
|
||||
|
||||
assert selected == []
|
||||
assert uri_map == {}
|
||||
assert edge_meta == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hop_and_filter_filters_label_triples(self):
|
||||
"""Test hop_and_filter skips rdfs:label edges."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
label_triple = MagicMock()
|
||||
label_triple.s = "entity1"
|
||||
label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
label_triple.o = "Entity One"
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
assert subgraph == set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_max_subgraph_size_limit(self):
|
||||
"""Test Query.follow_edges method respects max_subgraph_size"""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_triples_client.query_stream.return_value = [label_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=2
|
||||
triple_limit=10,
|
||||
)
|
||||
|
||||
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
|
||||
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
assert len(subgraph) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter(
|
||||
["entity1"], ["concept"],
|
||||
)
|
||||
|
||||
# Mock get_entities to return (entities, concepts) tuple
|
||||
query.get_entities = AsyncMock(
|
||||
return_value=(["entity1", "entity2"], ["concept1"])
|
||||
)
|
||||
|
||||
query.follow_edges_batch = AsyncMock(return_value=(
|
||||
{
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
},
|
||||
{}
|
||||
))
|
||||
|
||||
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
|
||||
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
||||
assert isinstance(subgraph, list)
|
||||
assert len(subgraph) == 2
|
||||
assert ("entity1", "predicate1", "object1") in subgraph
|
||||
assert ("entity2", "predicate2", "object2") in subgraph
|
||||
assert entities == ["entity1", "entity2"]
|
||||
assert concepts == ["concept1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=100
|
||||
)
|
||||
|
||||
test_subgraph = [
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
|
||||
("entity3", "predicate3", "object3")
|
||||
]
|
||||
test_entities = ["entity1", "entity3"]
|
||||
test_concepts = ["concept1"]
|
||||
query.get_subgraph = AsyncMock(
|
||||
return_value=(test_subgraph, {}, test_entities, test_concepts)
|
||||
)
|
||||
|
||||
async def mock_maybe_label(entity):
|
||||
label_map = {
|
||||
"entity1": "Human Entity One",
|
||||
"predicate1": "Human Predicate One",
|
||||
"object1": "Human Object One",
|
||||
"entity3": "Human Entity Three",
|
||||
"predicate3": "Human Predicate Three",
|
||||
"object3": "Human Object Three"
|
||||
}
|
||||
return label_map.get(entity, entity)
|
||||
|
||||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
|
||||
|
||||
query.get_subgraph.assert_called_once_with("test query")
|
||||
|
||||
# Label triples filtered out
|
||||
assert len(labeled_edges) == 2
|
||||
|
||||
# maybe_label called for non-label triples
|
||||
assert query.maybe_label.call_count == 6
|
||||
|
||||
expected_edges = [
|
||||
("Human Entity One", "Human Predicate One", "Human Object One"),
|
||||
("Human Entity Three", "Human Predicate Three", "Human Object Three")
|
||||
]
|
||||
assert labeled_edges == expected_edges
|
||||
|
||||
assert len(uri_map) == 2
|
||||
assert entities == test_entities
|
||||
assert concepts == test_concepts
|
||||
assert selected == []
|
||||
mock_reranker_client.rerank.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_query_method(self):
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
|
||||
import json
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
|
||||
expected_response = "This is the RAG response"
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
test_edge_id = edge_id("Subject", "Predicate", "Object")
|
||||
test_selected_edges = [("Subject", "Predicate", "Object")]
|
||||
test_eid = edge_id("Subject", "Predicate", "Object")
|
||||
test_uri_map = {
|
||||
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
}
|
||||
test_edge_metadata = {
|
||||
test_eid: {"concept": "test concept", "score": 0.95}
|
||||
}
|
||||
test_entities = ["http://example.org/subject"]
|
||||
test_concepts = ["test concept"]
|
||||
|
||||
# Mock prompt responses for the multi-step process
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
mock_graph_embeddings_client.query.return_value = []
|
||||
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
|
||||
return PromptResult(response_type="text", text="test concept")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text=expected_response)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
|
@ -614,16 +530,16 @@ class TestQuery:
|
|||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=False
|
||||
reranker_client=mock_reranker_client,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Patch Query.get_labelgraph to return test data
|
||||
original_get_labelgraph = Query.get_labelgraph
|
||||
original_hop_and_filter = Query.hop_and_filter
|
||||
|
||||
async def mock_get_labelgraph(self, query_text):
|
||||
return test_labelgraph, test_uri_map, test_entities, test_concepts
|
||||
async def mock_hop_and_filter(self, seed_entities, concepts):
|
||||
return test_selected_edges, test_uri_map, test_edge_metadata
|
||||
|
||||
Query.get_labelgraph = mock_get_labelgraph
|
||||
Query.hop_and_filter = mock_hop_and_filter
|
||||
|
||||
provenance_events = []
|
||||
|
||||
|
|
@ -636,7 +552,7 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
entity_limit=25,
|
||||
triple_limit=15,
|
||||
explain_callback=collect_provenance
|
||||
explain_callback=collect_provenance,
|
||||
)
|
||||
|
||||
response_text, usage = response
|
||||
|
|
@ -650,7 +566,6 @@ class TestQuery:
|
|||
assert len(triples) > 0
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
# Verify order
|
||||
assert "question" in provenance_events[0][1]
|
||||
assert "grounding" in provenance_events[1][1]
|
||||
assert "exploration" in provenance_events[2][1]
|
||||
|
|
@ -658,4 +573,4 @@ class TestQuery:
|
|||
assert "synthesis" in provenance_events[4][1]
|
||||
|
||||
finally:
|
||||
Query.get_labelgraph = original_get_labelgraph
|
||||
Query.hop_and_filter = original_hop_and_filter
|
||||
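The manual save-and-restore of `Query.hop_and_filter` in the test above could also be written with `unittest.mock.patch.object`, which puts the original attribute back when the `with` block exits, including when an assertion fails. A sketch of that alternative; the `Query` class below is a toy stand-in for illustration, not the TrustGraph class:

```python
import asyncio
from unittest.mock import patch


class Query:
    """Toy stand-in; the real method would query the triple store."""

    async def hop_and_filter(self, seed_entities, concepts):
        raise RuntimeError("would hit the real stores")


async def fake_hop_and_filter(self, seed_entities, concepts):
    # Same return shape as in the test: (selected edges, uri_map, edge metadata).
    return [("Subject", "Predicate", "Object")], {}, {}


async def run():
    # The replacement is installed on the class only inside this block.
    with patch.object(Query, "hop_and_filter", fake_hop_and_filter):
        return await Query().hop_and_filter(["e1"], ["concept"])


result = asyncio.run(run())
```

This keeps the patch scoped to the block and removes the need for a `try` / `finally` pair, at the cost of one more level of indentation around the code under test.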
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from trustgraph.provenance.namespaces import (
|
|||
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_SCORE, TG_EDGE_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -91,17 +91,17 @@ def build_mock_clients():
|
|||
1. prompt_client.prompt("extract-concepts", ...) -> concepts
|
||||
2. embeddings_client.embed(concepts) -> vectors
|
||||
3. graph_embeddings_client.query(vector, ...) -> entity matches
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (hop_and_filter)
|
||||
5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
|
||||
6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
|
||||
7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
|
||||
8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
9. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
6. reranker_client.rerank(queries, documents, limit) -> scored edges
|
||||
7. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
8. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
reranker_client = AsyncMock()
|
||||
|
||||
# 1. Concept extraction
|
||||
prompt_responses = {}
|
||||
|
|
@ -116,7 +116,7 @@ def build_mock_clients():
|
|||
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
|
||||
]
|
||||
|
||||
# 4. Triple queries (follow_edges_batch) - return our edges
|
||||
# 4. Triple queries (hop_and_filter) - return our edges
|
||||
kg_triples = [
|
||||
make_schema_triple(*EDGE_1),
|
||||
make_schema_triple(*EDGE_2),
|
||||
|
|
@ -130,9 +130,18 @@ def build_mock_clients():
|
|||
return [] # No labels found, will fall back to URI
|
||||
triples_client.query.side_effect = mock_label_query
|
||||
|
||||
# 6+7. Edge scoring and reasoning: dynamically score/reason about
|
||||
# whatever edges the query method sends us, since edge IDs are computed
|
||||
# from str(Term) representations which include the full dataclass repr.
|
||||
# 6. Reranker: score every document (descending by position), keep the first `limit`
|
||||
async def mock_rerank(queries, documents, limit):
|
||||
results = []
|
||||
for i, doc in enumerate(documents):
|
||||
result = MagicMock()
|
||||
result.document_id = doc["id"]
|
||||
result.query_id = queries[0]["id"] if queries else "0"
|
||||
result.score = 0.9 - (i * 0.1)
|
||||
results.append(result)
|
||||
return results[:limit]
|
||||
reranker_client.rerank.side_effect = mock_rerank
|
||||
|
||||
synthesis_answer = "Quantum computing applies physics principles to computation."
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
|
|
@ -141,26 +150,6 @@ def build_mock_clients():
|
|||
response_type="text",
|
||||
text=prompt_responses["extract-concepts"],
|
||||
)
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
|
|
@ -170,7 +159,8 @@ def build_mock_clients():
|
|||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
return (prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, reranker_client)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -197,7 +187,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0, # skip semantic pre-filter for simplicity
|
||||
|
||||
)
|
||||
|
||||
assert len(events) == 5, (
|
||||
|
|
@ -222,7 +212,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
expected_types = [
|
||||
|
|
@ -260,7 +250,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
|
|
@ -297,7 +287,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
|
|
@ -320,7 +310,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
gnd_uri = events[1]["explain_id"]
|
||||
|
|
@ -344,7 +334,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
|
|
@ -355,10 +345,10 @@ class TestGraphRagQueryProvenance:
|
|||
assert int(t.o.value) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_focus_has_selected_edges_with_reasoning(self):
|
||||
async def test_focus_has_selected_edges_with_concept_and_score(self):
|
||||
"""
|
||||
The focus event should carry selected edges as quoted triples
|
||||
with reasoning text.
|
||||
with cross-encoder concept and score metadata.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
|
@ -371,7 +361,6 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
foc_uri = events[3]["explain_id"]
|
||||
|
|
@ -387,11 +376,19 @@ class TestGraphRagQueryProvenance:
|
|||
for t in edge_t:
|
||||
assert t.o.triple is not None, "tg:edge object must be a quoted triple"
|
||||
|
||||
# Should have reasoning
|
||||
reasoning = find_triples(foc_triples, TG_REASONING)
|
||||
assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
|
||||
reasoning_texts = {t.o.value for t in reasoning}
|
||||
assert any(r for r in reasoning_texts), "Reasoning should not be empty"
|
||||
# Edge selections should be typed as EdgeSelection
|
||||
edge_sel_uris = [t.o.iri for t in selected]
|
||||
for uri in edge_sel_uris:
|
||||
assert has_type(foc_triples, uri, TG_EDGE_SELECTION)
|
||||
|
||||
# Should have concept and score
|
||||
concepts = find_triples(foc_triples, TG_CONCEPT)
|
||||
assert len(concepts) > 0, "Focus should have tg:concept for selected edges"
|
||||
|
||||
scores = find_triples(foc_triples, TG_SCORE)
|
||||
assert len(scores) > 0, "Focus should have tg:score for selected edges"
|
||||
for t in scores:
|
||||
float(t.o.value) # Should be parseable as float
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_is_answer_type(self):
|
||||
|
|
@ -407,7 +404,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
syn_uri = events[4]["explain_id"]
|
||||
|
|
@ -429,7 +426,7 @@ class TestGraphRagQueryProvenance:
|
|||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
|
@ -449,7 +446,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
parent_uri=parent,
|
||||
)
|
||||
|
||||
|
|
@ -465,7 +462,7 @@ class TestGraphRagQueryProvenance:
|
|||
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
|
@ -484,7 +481,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
for event in events:
|
||||
|
|
|
|||
|
|
@ -527,7 +527,8 @@ class AsyncFlowInstance:
|
|||
return result.get("response", "")
|
||||
|
||||
async def document_rag(self, query: str, collection: str,
|
||||
doc_limit: int = 10, **kwargs: Any) -> str:
|
||||
doc_limit: int = 10, fetch_limit: int = 0,
|
||||
**kwargs: Any) -> str:
|
||||
"""
|
||||
Execute document-based RAG query (non-streaming).
|
||||
|
||||
|
|
@ -541,7 +542,9 @@ class AsyncFlowInstance:
|
|||
Args:
|
||||
query: User query text
|
||||
collection: Collection identifier containing documents
|
||||
doc_limit: Maximum number of document chunks to retrieve (default: 10)
|
||||
doc_limit: Document chunks selected into the prompt (default: 10)
|
||||
fetch_limit: Candidate chunks fetched from the vector store before
|
||||
reranking (default: 0 = derive from doc_limit)
|
||||
**kwargs: Additional service-specific parameters
|
||||
|
||||
Returns:
|
||||
|
|
@ -564,6 +567,7 @@ class AsyncFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": False
|
||||
}
|
||||
request_data.update(kwargs)
|
||||
|
|
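Review note: the `doc_limit` / `fetch_limit` split documented in this hunk suggests a two-stage retrieval: fetch a wider candidate window, rerank it, then keep the best `doc_limit` chunks. The sketch below is only an illustration under stated assumptions. `select_chunks`, `score_fn` and `derive_factor` are hypothetical names, and the rule used when `fetch_limit` is 0 is a guess, since the diff says only "derive from doc_limit".

```python
from typing import Callable, List


def select_chunks(
    candidates: List[str],
    score_fn: Callable[[str], float],
    doc_limit: int = 10,
    fetch_limit: int = 0,
    derive_factor: int = 4,  # ASSUMPTION: hypothetical, not taken from the diff
) -> List[str]:
    """Fetch a candidate window, rerank it, keep the top doc_limit chunks."""
    if fetch_limit <= 0:
        # 0 means "derive from doc_limit"; the real rule is not shown here.
        fetch_limit = doc_limit * derive_factor

    # Stage 1: stands in for the vector-store query returning the first N hits.
    fetched = candidates[:fetch_limit]

    # Stage 2: stands in for the reranker; highest score first.
    ranked = sorted(fetched, key=score_fn, reverse=True)

    return ranked[:doc_limit]


chunks = ["a", "bbbb", "cc", "ddddd", "eee"]
narrow = select_chunks(chunks, len, doc_limit=2, fetch_limit=3)
wide = select_chunks(chunks, len, doc_limit=2)
```

With a narrow window the best chunk (`"ddddd"`) is never seen by the reranker, which is presumably why the fetch window is kept larger than the final selection.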
@ -646,6 +650,16 @@ class AsyncFlowInstance:
|
|||
|
||||
return await self.request("embeddings", request_data)
|
||||
|
||||
async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any):
|
||||
request_data = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request_data.update(kwargs)
|
||||
|
||||
return await self.request("reranker", request_data)
|
||||
|
||||
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any):
|
||||
"""
|
||||
Query RDF triples using pattern matching.
|
||||
|
|
|
|||
|
|
@ -379,12 +379,14 @@ class AsyncSocketFlowInstance:
|
|||
yield chunk.content
|
||||
|
||||
async def document_rag(self, query: str, collection: str,
|
||||
doc_limit: int = 10, streaming: bool = False, **kwargs):
|
||||
doc_limit: int = 10, fetch_limit: int = 0,
|
||||
streaming: bool = False, **kwargs):
|
||||
"""Document RAG with optional streaming"""
|
||||
request = {
|
||||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
@ -443,6 +445,19 @@ class AsyncSocketFlowInstance:
|
|||
|
||||
return await self.client._send_request("embeddings", self.flow_id, request)
|
||||
|
||||
async def rerank(self, queries: list, documents: list, limit: int = 10,
|
||||
**kwargs):
|
||||
request = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request(
|
||||
"reranker", self.flow_id, request,
|
||||
)
|
||||
|
||||
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs):
|
||||
"""Triple pattern query"""
|
||||
request = {"limit": limit}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount"
|
|||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_SCORE = TG + "score"
|
||||
TG_DOCUMENT = TG + "document"
|
||||
TG_CONCEPT = TG + "concept"
|
||||
TG_ENTITY = TG + "entity"
|
||||
|
|
@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
|||
|
||||
@dataclass
|
||||
class EdgeSelection:
|
||||
"""A selected edge with reasoning from GraphRAG Focus step."""
|
||||
"""A selected edge with cross-encoder metadata from GraphRAG Focus step."""
|
||||
uri: str
|
||||
edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...}
|
||||
reasoning: str = ""
|
||||
concept: str = ""
|
||||
score: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -209,7 +212,7 @@ class Exploration(ExplainEntity):
|
|||
|
||||
@dataclass
|
||||
class Focus(ExplainEntity):
|
||||
"""Focus entity - selected edges with LLM reasoning (GraphRAG only)."""
|
||||
"""Focus entity - selected edges with cross-encoder scoring (GraphRAG only)."""
|
||||
selected_edge_uris: List[str] = field(default_factory=list)
|
||||
edge_selections: List[EdgeSelection] = field(default_factory=list)
|
||||
|
||||
|
|
@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel
|
|||
uri = triples[0][0] if triples else ""
|
||||
edge = None
|
||||
reasoning = ""
|
||||
concept = ""
|
||||
score = None
|
||||
|
||||
for s, p, o in triples:
|
||||
if p == TG_EDGE and isinstance(o, dict):
|
||||
edge = o
|
||||
elif p == TG_REASONING:
|
||||
reasoning = o
|
||||
elif p == TG_CONCEPT:
|
||||
concept = o
|
||||
elif p == TG_SCORE:
|
||||
try:
|
||||
score = float(o)
|
||||
except (ValueError, TypeError):
|
||||
score = None
|
||||
|
||||
return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning)
|
||||
return EdgeSelection(
|
||||
uri=uri, edge=edge, reasoning=reasoning,
|
||||
concept=concept, score=score,
|
||||
)
|
||||
|
||||
|
||||
def extract_term_value(term: Dict[str, Any]) -> Any:
|
||||
|
|
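Review note: the parsing change above can be exercised in isolation. The following is a self-contained sketch that mirrors the hunk's logic. The prefixed predicate names (`tg:edge`, `tg:score`, ...) are assumptions standing in for the real `TG_*` IRIs, and `parse_edge_selection` is an illustrative name rather than the library function.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

# ASSUMPTION: prefixed names stand in for the full TG_* IRIs.
TG_EDGE = "tg:edge"
TG_REASONING = "tg:reasoning"
TG_CONCEPT = "tg:concept"
TG_SCORE = "tg:score"


@dataclass
class EdgeSelection:
    uri: str
    edge: Optional[Dict[str, str]] = None
    reasoning: str = ""
    concept: str = ""
    score: Optional[float] = None


def parse_edge_selection(triples: List[Tuple[str, str, Any]]) -> EdgeSelection:
    """Collect edge, reasoning, concept and score for one selection entity."""
    sel = EdgeSelection(uri=triples[0][0] if triples else "")
    for _s, p, o in triples:
        if p == TG_EDGE and isinstance(o, dict):
            sel.edge = o
        elif p == TG_REASONING:
            sel.reasoning = o
        elif p == TG_CONCEPT:
            sel.concept = o
        elif p == TG_SCORE:
            try:
                sel.score = float(o)
            except (ValueError, TypeError):
                sel.score = None  # unparseable scores are dropped, not raised
    return sel


good = parse_edge_selection([
    ("urn:sel:0", TG_EDGE, {"s": "a", "p": "b", "o": "c"}),
    ("urn:sel:0", TG_CONCEPT, "quantum computing"),
    ("urn:sel:0", TG_SCORE, "0.83"),
])
bad = parse_edge_selection([("urn:sel:1", TG_SCORE, "n/a")])
```

Scores travel as string literals, so the `float()` conversion with a `None` fallback keeps a malformed literal from failing the whole parse.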
|
|||
|
|
@ -415,7 +415,7 @@ class FlowInstance:
|
|||
|
||||
def document_rag(
|
||||
self, query,collection="default",
|
||||
doc_limit=10,
|
||||
doc_limit=10, fetch_limit=0,
|
||||
):
|
||||
"""
|
||||
Execute document-based Retrieval-Augmented Generation (RAG) query.
|
||||
|
|
@ -426,7 +426,9 @@ class FlowInstance:
|
|||
Args:
|
||||
query: Natural language query
|
||||
collection: Collection identifier (default: "default")
|
||||
doc_limit: Maximum document chunks to retrieve (default: 10)
|
||||
doc_limit: Document chunks selected into the prompt (default: 10)
|
||||
fetch_limit: Candidate chunks fetched from the vector store before
|
||||
reranking (default: 0 = derive from doc_limit)
|
||||
|
||||
Returns:
|
||||
str: Generated response incorporating document context
|
||||
|
|
@ -447,6 +449,7 @@ class FlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
}
|
||||
|
||||
result = self.request(
|
||||
|
|
@ -491,6 +494,19 @@ class FlowInstance:
|
|||
input
|
||||
)["vectors"]
|
||||
|
||||
def rerank(self, queries, documents, limit=10):
|
||||
|
||||
input = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
return self.request(
|
||||
"service/reranker",
|
||||
input
|
||||
)
|
||||
|
||||
def graph_embeddings_query(self, text, collection, limit=10):
|
||||
"""
|
||||
Query knowledge graph entities using semantic similarity.
|
||||
|
|
|
|||
|
|
@ -752,6 +752,7 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
collection: str,
|
||||
doc_limit: int = 10,
|
||||
fetch_limit: int = 0,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
|
|
@ -764,6 +765,7 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
@ -785,6 +787,7 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
collection: str,
|
||||
doc_limit: int = 10,
|
||||
fetch_limit: int = 0,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
||||
"""Execute document-based RAG query with explainability support."""
|
||||
|
|
@ -792,6 +795,7 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": True,
|
||||
"explainable": True,
|
||||
}
|
||||
|
|
@ -885,6 +889,19 @@ class SocketFlowInstance:
|
|||
|
||||
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
|
||||
|
||||
def rerank(self, queries: list, documents: list, limit: int = 10,
|
||||
**kwargs: Any) -> Dict[str, Any]:
|
||||
request = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync(
|
||||
"reranker", self.flow_id, request, False,
|
||||
)
|
||||
|
||||
def triples_query(
|
||||
self,
|
||||
s: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService
|
|||
from . tool_service_client import ToolServiceClientSpec
|
||||
from . agent_client import AgentClientSpec
|
||||
from . structured_query_client import StructuredQueryClientSpec
|
||||
from . reranker_client import RerankerClientSpec
|
||||
from . reranker_service import RerankerService
|
||||
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
|
||||
from . collection_config_handler import CollectionConfigHandler
|
||||
|
||||
|
|
|
|||
|
|
@ -157,21 +157,6 @@ class PromptClient(RequestResponse):
|
|||
timeout = timeout,
|
||||
)
|
||||
|
||||
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "kg-prompt",
|
||||
variables = {
|
||||
"query": query,
|
||||
"knowledge": [
|
||||
{ "s": v[0], "p": v[1], "o": v[2] }
|
||||
for v in kg
|
||||
]
|
||||
},
|
||||
timeout = timeout,
|
||||
streaming = streaming,
|
||||
chunk_callback = chunk_callback,
|
||||
)
|
||||
|
||||
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "document-prompt",
|
||||
|
|
|
|||
43
trustgraph-base/trustgraph/base/reranker_client.py
Normal file
|
|
@ -0,0 +1,43 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import (
    RerankerRequest, RerankerResponse,
    RerankerQuery, RerankerDocument,
)

class RerankerClient(RequestResponse):
    async def rerank(self, queries, documents, limit=10, timeout=300):

        resp = await self.request(
            RerankerRequest(
                queries=[
                    RerankerQuery(query_id=q["id"], query_text=q["text"])
                    for q in queries
                ],
                documents=[
                    RerankerDocument(
                        document_id=d["id"], document_text=d["text"]
                    )
                    for d in documents
                ],
                limit=limit,
            ),
            timeout=timeout
        )

        if resp.error:
            raise RuntimeError(resp.error.message)

        return resp.results

class RerankerClientSpec(RequestResponseSpec):
    def __init__(
        self, request_name, response_name,
    ):
        super(RerankerClientSpec, self).__init__(
            request_name = request_name,
            request_schema = RerankerRequest,
            response_name = response_name,
            response_schema = RerankerResponse,
            impl = RerankerClient,
        )
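Review note: `RerankerClient.rerank` takes plain dicts keyed `id` / `text` and returns result objects carrying `document_id`, `query_id` and `score`. The sketch below imitates that calling convention with a toy scorer so the shape can be tried without a running flow. `toy_rerank` is hypothetical, word overlap is a stand-in for the cross-encoder, and applying `limit` across all pairs (rather than per query) is an assumption.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class RerankerResult:
    document_id: str
    query_id: str
    score: float


def toy_rerank(
    queries: List[Dict[str, str]],
    documents: List[Dict[str, str]],
    limit: int = 10,
) -> List[RerankerResult]:
    """Hypothetical stand-in: scores by shared-word ratio, not a cross-encoder."""
    results = []
    for q in queries:
        q_words = set(q["text"].lower().split())
        for d in documents:
            d_words = set(d["text"].lower().split())
            overlap = len(q_words & d_words) / max(len(q_words), 1)
            results.append(RerankerResult(d["id"], q["id"], overlap))

    # ASSUMPTION: limit is applied to the combined, score-sorted list.
    results.sort(key=lambda r: r.score, reverse=True)
    return results[:limit]


top = toy_rerank(
    queries=[{"id": "q0", "text": "quantum computing"}],
    documents=[
        {"id": "d0", "text": "cooking pasta"},
        {"id": "d1", "text": "quantum computing basics"},
        {"id": "d2", "text": "classical computing"},
    ],
    limit=2,
)
```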
109
trustgraph-base/trustgraph/base/reranker_service.py
Normal file
|
|
@ -0,0 +1,109 @@
from __future__ import annotations

from argparse import ArgumentParser

import logging

from .. schema import (
    RerankerRequest, RerankerResponse, RerankerResult, Error,
)
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec

logger = logging.getLogger(__name__)

default_ident = "reranker"
default_concurrency = 1

class RerankerService(FlowProcessor):

    def __init__(self, **params):

        id = params.get("id")
        concurrency = params.get("concurrency", 1)

        super(RerankerService, self).__init__(**params | {
            "id": id,
            "concurrency": concurrency,
        })

        self.register_specification(
            ConsumerSpec(
                name = "request",
                schema = RerankerRequest,
                handler = self.on_request,
                concurrency = concurrency,
            )
        )

        self.register_specification(
            ProducerSpec(
                name = "response",
                schema = RerankerResponse
            )
        )

        self.register_specification(
            ParameterSpec(
                name = "model",
            )
        )

    async def on_request(self, msg, consumer, flow):

        try:

            request = msg.value()

            id = msg.properties()["id"]

            logger.debug(f"Handling reranker request {id}...")

            model = flow("model")
            results = await self.on_rerank(
                request.queries, request.documents,
                request.limit, model=model,
            )

            await flow("response").send(
                RerankerResponse(
                    error = None,
                    results = results,
                ),
                properties={"id": id}
            )

            logger.debug("Reranker request handled successfully")

        except TooManyRequests as e:
            raise e

        except Exception as e:

            logger.error(f"Exception in reranker service: {e}", exc_info=True)

            logger.info("Sending error response...")

            await flow.producer["response"].send(
                RerankerResponse(
                    error=Error(
                        type = "reranker-error",
                        message = str(e),
                    ),
                    results=[],
                ),
                properties={"id": id}
            )

    @staticmethod
    def add_args(parser: ArgumentParser) -> None:

        parser.add_argument(
            '-c', '--concurrency',
            type=int,
            default=default_concurrency,
            help=f'Concurrent processing threads (default: {default_concurrency})'
        )

        FlowProcessor.add_args(parser)
|
|
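Review note: `RerankerService.on_request` handles the message envelope and delegates scoring to `self.on_rerank`, which concrete services presumably implement. The standalone sketch below shows that delegation and the error-to-response mapping. `SketchRerankerBase` and `LengthReranker` are hypothetical classes, the dict-based request and response are stand-ins for the schema types, and the `TooManyRequests` re-raise and the `model` parameter are left out for brevity.

```python
import asyncio


class SketchRerankerBase:
    """Stand-in for the base service: wraps on_rerank in a response envelope."""

    async def on_request(self, request):
        try:
            results = await self.on_rerank(
                request["queries"], request["documents"], request["limit"],
            )
            return {"error": None, "results": results}
        except Exception as e:
            # Mirrors the diff: failures become an error response, not a crash.
            return {
                "error": {"type": "reranker-error", "message": str(e)},
                "results": [],
            }

    async def on_rerank(self, queries, documents, limit):
        raise NotImplementedError("subclass must implement on_rerank")


class LengthReranker(SketchRerankerBase):
    """Toy implementation: longer documents score higher."""

    async def on_rerank(self, queries, documents, limit):
        scored = [
            {
                "document_id": d["id"],
                "query_id": q["id"],
                "score": float(len(d["text"])),
            }
            for q in queries
            for d in documents
        ]
        scored.sort(key=lambda r: r["score"], reverse=True)
        return scored[:limit]


request = {
    "queries": [{"id": "q0", "text": "anything"}],
    "documents": [{"id": "d0", "text": "short"}, {"id": "d1", "text": "much longer"}],
    "limit": 1,
}
ok = asyncio.run(LengthReranker().on_request(request))
failed = asyncio.run(SketchRerankerBase().on_request(request))
```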
@ -140,20 +140,6 @@ class PromptClient(BaseClient):
|
|||
timeout=timeout
|
||||
)
|
||||
|
||||
def request_kg_prompt(self, query, kg, timeout=300):
|
||||
|
||||
return self.request(
|
||||
id="kg-prompt",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": [
|
||||
{ "s": v[0], "p": v[1], "o": v[2] }
|
||||
for v in kg
|
||||
]
|
||||
},
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
def request_document_prompt(self, query, documents, timeout=300):
|
||||
|
||||
return self.request(
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons
|
|||
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
||||
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
|
||||
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator
|
||||
from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator
|
||||
from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator
|
||||
|
||||
|
|
@ -163,6 +164,12 @@ TranslatorRegistry.register_service(
|
|||
SparqlQueryResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"reranker",
|
||||
RerankerRequestTranslator(),
|
||||
RerankerResponseTranslator()
|
||||
)
|
||||
|
||||
# Register single-direction translators for document loading
|
||||
TranslatorRegistry.register_request("document", DocumentTranslator())
|
||||
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())
|
||||
|
|
|
|||
|
|
@ -20,3 +20,4 @@ from .embeddings_query import (
|
|||
)
|
||||
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
from .reranker import RerankerRequestTranslator, RerankerResponseTranslator
|
||||
|
|
|
|||
73
trustgraph-base/trustgraph/messaging/translators/reranker.py
Normal file
|
|
@ -0,0 +1,73 @@
from typing import Dict, Any, Tuple
from ...schema import (
    RerankerRequest, RerankerResponse,
    RerankerQuery, RerankerDocument, RerankerResult,
)
from .base import MessageTranslator


class RerankerRequestTranslator(MessageTranslator):

    def decode(self, data: Dict[str, Any]) -> RerankerRequest:
        return RerankerRequest(
            queries=[
                RerankerQuery(
                    query_id=q["query_id"],
                    query_text=q["query_text"],
                )
                for q in data.get("queries", [])
            ],
            documents=[
                RerankerDocument(
                    document_id=d["document_id"],
                    document_text=d["document_text"],
                )
                for d in data.get("documents", [])
            ],
            limit=data.get("limit", 10),
        )

    def encode(self, obj: RerankerRequest) -> Dict[str, Any]:
        return {
            "queries": [
                {"query_id": q.query_id, "query_text": q.query_text}
                for q in obj.queries
            ],
            "documents": [
                {"document_id": d.document_id, "document_text": d.document_text}
                for d in obj.documents
            ],
            "limit": obj.limit,
        }


class RerankerResponseTranslator(MessageTranslator):

    def decode(self, data: Dict[str, Any]) -> RerankerResponse:
        return RerankerResponse(
            results=[
                RerankerResult(
                    document_id=r["document_id"],
                    query_id=r["query_id"],
                    score=r["score"],
                )
                for r in data.get("results", [])
            ],
        )

    def encode(self, obj: RerankerResponse) -> Dict[str, Any]:
        return {
            "results": [
                {
                    "document_id": r.document_id,
                    "query_id": r.query_id,
                    "score": r.score,
                }
                for r in obj.results
            ],
        }

    def encode_with_completion(
        self, obj: RerankerResponse
    ) -> Tuple[Dict[str, Any], bool]:
        return self.encode(obj), True
|
|
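Review note: the request translator reads `query_id` / `query_text` and `document_id` / `document_text` from the wire dict, whereas `RerankerClient.rerank` reads `id` / `text`, so the two entry points expect differently keyed dicts. The sketch below reproduces the request decode/encode pair against stand-in dataclasses to show that it round-trips. The dataclass definitions are assumptions, since the real schema classes are not part of this diff.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


# ASSUMPTION: minimal stand-ins for the schema classes imported by the translator.
@dataclass
class RerankerQuery:
    query_id: str
    query_text: str


@dataclass
class RerankerDocument:
    document_id: str
    document_text: str


@dataclass
class RerankerRequest:
    queries: List[RerankerQuery] = field(default_factory=list)
    documents: List[RerankerDocument] = field(default_factory=list)
    limit: int = 10


def decode(data: Dict[str, Any]) -> RerankerRequest:
    """Wire dict -> request object; missing lists and limit fall back to defaults."""
    return RerankerRequest(
        queries=[
            RerankerQuery(query_id=q["query_id"], query_text=q["query_text"])
            for q in data.get("queries", [])
        ],
        documents=[
            RerankerDocument(
                document_id=d["document_id"], document_text=d["document_text"]
            )
            for d in data.get("documents", [])
        ],
        limit=data.get("limit", 10),
    )


def encode(obj: RerankerRequest) -> Dict[str, Any]:
    """Request object -> wire dict."""
    return {
        "queries": [
            {"query_id": q.query_id, "query_text": q.query_text}
            for q in obj.queries
        ],
        "documents": [
            {"document_id": d.document_id, "document_text": d.document_text}
            for d in obj.documents
        ],
        "limit": obj.limit,
    }


wire = {
    "queries": [{"query_id": "q0", "query_text": "What is quantum computing?"}],
    "documents": [{"document_id": "d0", "document_text": "Qubits and gates."}],
    "limit": 5,
}
round_tripped = encode(decode(wire))
defaulted = decode({})
```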
@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
|||
query=data["query"],
|
||||
collection=data.get("collection", "default"),
|
||||
doc_limit=int(data.get("doc-limit", 20)),
|
||||
fetch_limit=int(data.get("fetch-limit", 0)),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
|||
"query": obj.query,
|
||||
"collection": obj.collection,
|
||||
"doc-limit": obj.doc_limit,
|
||||
"fetch-limit": obj.fetch_limit,
|
||||
"streaming": getattr(obj, "streaming", False)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,8 @@ from . uris import (
|
|||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_focus_uri,
|
||||
chunk_selection_uri,
|
||||
docrag_synthesis_uri,
|
||||
)
|
||||
|
||||
|
|
@ -89,9 +91,13 @@ from . namespaces import (
|
|||
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE,
|
||||
# Edge selection entity type
|
||||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Chunk selection entity type
|
||||
TG_CHUNK_SELECTION,
|
||||
# Explainability entity types
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION,
|
||||
|
|
@ -130,6 +136,7 @@ from . triples import (
|
|||
# Query-time provenance triple builders (DocumentRAG)
|
||||
docrag_question_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_chunk_selection_triples,
|
||||
docrag_synthesis_triples,
|
||||
# Utility
|
||||
set_graph,
|
||||
|
|
@ -194,6 +201,8 @@ __all__ = [
|
|||
"docrag_question_uri",
|
||||
"docrag_grounding_uri",
|
||||
"docrag_exploration_uri",
|
||||
"docrag_focus_uri",
|
||||
"chunk_selection_uri",
|
||||
"docrag_synthesis_uri",
|
||||
# Namespaces
|
||||
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
|
||||
|
|
@ -212,9 +221,13 @@ __all__ = [
|
|||
"TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE",
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
|
||||
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
|
||||
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE",
|
||||
# Edge selection entity type
|
||||
"TG_EDGE_SELECTION",
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
|
||||
# Chunk selection entity type
|
||||
"TG_CHUNK_SELECTION",
|
||||
# Explainability entity types
|
||||
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
|
||||
"TG_ANALYSIS", "TG_CONCLUSION",
|
||||
|
|
@ -250,6 +263,7 @@ __all__ = [
|
|||
# Query-time provenance triple builders (DocumentRAG)
|
||||
"docrag_question_triples",
|
||||
"docrag_exploration_triples",
|
||||
"docrag_chunk_selection_triples",
|
||||
"docrag_synthesis_triples",
|
||||
# Agent provenance triple builders
|
||||
"agent_session_triples",
|
||||
|
|
|
|||
|
|
@ -66,12 +66,19 @@ TG_EDGE_COUNT = TG + "edgeCount"
|
|||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_SCORE = TG + "score"
|
||||
TG_DOCUMENT = TG + "document" # Reference to document in librarian
|
||||
|
||||
# Edge selection entity type (cross-encoder scored edge in Focus)
|
||||
TG_EDGE_SELECTION = TG + "EdgeSelection"
|
||||
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT = TG + "chunkCount"
|
||||
TG_SELECTED_CHUNK = TG + "selectedChunk"
|
||||
|
||||
# Chunk selection entity type (cross-encoder reranked chunk in Focus)
|
||||
TG_CHUNK_SELECTION = TG + "ChunkSelection"
|
||||
|
||||
# Extraction provenance entity types
|
||||
TG_DOCUMENT_TYPE = TG + "Document"
|
||||
TG_PAGE_TYPE = TG + "Page"
|
||||
|
|
|
|||
|
|
@ -24,10 +24,14 @@ from . namespaces import (
|
|||
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE,
|
||||
TG_DOCUMENT,
|
||||
# Edge selection entity type
|
||||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Chunk selection entity type
|
||||
TG_CHUNK_SELECTION,
|
||||
# Explainability entity types
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
# Unifying types
|
||||
|
|
@ -38,7 +42,10 @@ from . namespaces import (
|
|||
TG_IN_TOKEN, TG_OUT_TOKEN,
|
||||
)
|
||||
|
||||
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri
|
||||
from . uris import (
|
||||
activity_uri, agent_uri, subgraph_uri, edge_selection_uri,
|
||||
chunk_selection_uri,
|
||||
)
|
||||
|
||||
|
||||
def set_graph(triples: List[Triple], graph: str) -> List[Triple]:
|
||||
|
|
@ -536,10 +543,9 @@ def focus_triples(
|
|||
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
|
||||
]
|
||||
|
||||
# Add each selected edge with its reasoning via intermediate entity
|
||||
# Add each selected edge with metadata via intermediate entity
|
||||
for idx, edge_info in enumerate(selected_edges_with_reasoning):
|
||||
edge = edge_info.get("edge")
|
||||
reasoning = edge_info.get("reasoning", "")
|
||||
|
||||
if edge:
|
||||
s, p, o = edge
|
||||
|
|
@ -552,13 +558,32 @@ def focus_triples(
|
|||
_triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri))
|
||||
)
|
||||
|
||||
# Type the edge selection entity
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION))
|
||||
)
|
||||
|
||||
# Attach quoted triple to edge selection entity
|
||||
quoted = _quoted_triple(s, p, o)
|
||||
triples.append(
|
||||
Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted)
|
||||
)
|
||||
|
||||
# Attach reasoning to edge selection entity
|
||||
# Structured cross-encoder metadata
|
||||
concept = edge_info.get("concept")
|
||||
if concept:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_CONCEPT, _literal(concept))
|
||||
)
|
||||
|
||||
score = edge_info.get("score")
|
||||
if score is not None:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_SCORE, _literal(str(score)))
|
||||
)
|
||||
|
||||
# Legacy reasoning text (for non-cross-encoder callers)
|
||||
reasoning = edge_info.get("reasoning", "")
|
||||
if reasoning:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning))
|
||||
|
|
@ -698,6 +723,75 @@ def docrag_exploration_triples(
    return triples


def docrag_chunk_selection_triples(
    focus_uri: str,
    exploration_uri: str,
    selected_chunks_with_scores: List[dict],
    session_id: str,
) -> List[Triple]:
    """
    Build triples for a document RAG focus entity (chunks selected by the
    cross-encoder reranker).

    Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity
    derived from exploration, with one ChunkSelection sub-entity per surviving
    chunk carrying the chunk reference and the reranker score.

    Structure:
        <focus> a tg:Focus ; prov:wasDerivedFrom <exploration> .
        <focus> tg:selectedChunk <chunk_sel_0> .
        <chunk_sel_0> a tg:ChunkSelection .
        <chunk_sel_0> tg:document <chunk_id> .
        <chunk_sel_0> tg:score "0.97" .

    Args:
        focus_uri: URI of the focus entity (from docrag_focus_uri)
        exploration_uri: URI of the parent exploration entity
        selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score'
        session_id: Session UUID for generating chunk selection URIs

    Returns:
        List of Triple objects
    """
    triples = [
        _triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)),
        _triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)),
        _triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")),
        _triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
    ]

    for idx, chunk_info in enumerate(selected_chunks_with_scores):
        chunk_id = chunk_info.get("chunk_id")
        if not chunk_id:
            continue

        chunk_sel_uri = chunk_selection_uri(session_id, idx)

        # Link focus to chunk selection entity
        triples.append(
            _triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri))
        )

        # Type the chunk selection entity
        triples.append(
            _triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION))
        )

        # Reference the actual chunk (in librarian)
        triples.append(
            _triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id))
        )

        # Cross-encoder score
        score = chunk_info.get("score")
        if score is not None:
            triples.append(
                _triple(chunk_sel_uri, TG_SCORE, _literal(str(score)))
            )

    return triples


def docrag_synthesis_triples(
    synthesis_uri: str,
    exploration_uri: str,
|
|
|
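The shape of the provenance graph built by `docrag_chunk_selection_triples` can be illustrated with a minimal, self-contained sketch. It assumes plain string tuples in place of the real `Triple`, `_iri` and `_literal` helpers, and the prefixed names (`tg:selectedChunk`, `tg:document`, and so on) are hypothetical stand-ins for the namespace constants. Only the URN format and the control flow are taken from the diff, and the `prov:Entity` type and label triples are left out for brevity.

```python
from typing import List, Tuple

# Simplified stand-in for the real Triple type (assumption for illustration).
Triple = Tuple[str, str, str]


def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
    # URN format as shown in the diff.
    return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"


def sketch_chunk_selection_triples(
    focus_uri: str,
    exploration_uri: str,
    selected: List[dict],
    session_id: str,
) -> List[Triple]:
    triples: List[Triple] = [
        (focus_uri, "rdf:type", "tg:Focus"),
        (focus_uri, "prov:wasDerivedFrom", exploration_uri),
    ]
    for idx, info in enumerate(selected):
        chunk_id = info.get("chunk_id")
        if not chunk_id:
            # Entries without a chunk_id are skipped but still consume an index.
            continue
        sel = chunk_selection_uri(session_id, idx)
        triples.append((focus_uri, "tg:selectedChunk", sel))
        triples.append((sel, "rdf:type", "tg:ChunkSelection"))
        triples.append((sel, "tg:document", chunk_id))
        score = info.get("score")
        if score is not None:
            # Scores are stored as string literals.
            triples.append((sel, "tg:score", str(score)))
    return triples


example = sketch_chunk_selection_triples(
    "urn:focus",
    "urn:exploration",
    [{"chunk_id": "c0", "score": 0.97}, {"score": 0.5}, {"chunk_id": "c2"}],
    "s",
)
```

Because the index comes from `enumerate` over the full input list, a skipped entry leaves a gap: in this example the selection URIs end in `:0` and `:2`, with no `:1`.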
|||
|
|
@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str:
|
|||
return f"urn:trustgraph:docrag:{session_id}/exploration"
|
||||
|
||||
|
||||
def docrag_focus_uri(session_id: str) -> str:
    """
    Generate URI for a document RAG focus entity (chunks selected by the
    cross-encoder reranker).

    Args:
        session_id: The session UUID.

    Returns:
        URN in format: urn:trustgraph:docrag:{uuid}/focus
    """
    return f"urn:trustgraph:docrag:{session_id}/focus"


def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
    """
    Generate URI for a chunk selection item (links a reranked chunk to its
    score). Mirrors edge_selection_uri for GraphRAG.

    Args:
        session_id: The session UUID.
        chunk_index: Index of this chunk in the selection (0-based).

    Returns:
        URN in format: urn:trustgraph:prov:chunk:{uuid}:{index}
    """
    return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"
|
||||
|
||||
|
||||
def docrag_synthesis_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for a document RAG synthesis entity (final answer).
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ from . namespaces import (
|
|||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
TG_EDGE_SELECTION, TG_SCORE,
|
||||
TG_CHUNK_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -93,6 +95,8 @@ TG_CLASS_LABELS = [
|
|||
_label_triple(TG_FINDING, "Finding"),
|
||||
_label_triple(TG_PLAN_TYPE, "Plan"),
|
||||
_label_triple(TG_STEP_RESULT, "Step Result"),
|
||||
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
|
||||
_label_triple(TG_CHUNK_SELECTION, "Chunk Selection"),
|
||||
]
|
||||
|
||||
# TrustGraph predicate labels
|
||||
|
|
@ -117,6 +121,7 @@ TG_PREDICATE_LABELS = [
|
|||
_label_triple(TG_ENTITY, "entity"),
|
||||
_label_triple(TG_SUBAGENT_GOAL, "subagent goal"),
|
||||
_label_triple(TG_PLAN_STEP, "plan step"),
|
||||
_label_triple(TG_SCORE, "score"),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,4 +15,5 @@ from .diagnosis import *
|
|||
from .collection import *
|
||||
from .storage import *
|
||||
from .tool_service import *
|
||||
from .sparql_query import *
|
||||
from .sparql_query import *
|
||||
from .reranker import *
|
||||
|
|
@ -6,17 +6,6 @@ from ..core.primitives import Error
|
|||
|
||||
# Prompt services, abstract the prompt generation
|
||||
|
||||
# extract-definitions:
|
||||
# chunk -> definitions
|
||||
# extract-relationships:
|
||||
# chunk -> relationships
|
||||
# kg-prompt:
|
||||
# query, triples -> answer
|
||||
# document-prompt:
|
||||
# query, documents -> answer
|
||||
# extract-rows
|
||||
# schema, chunk -> rows
|
||||
|
||||
@dataclass
|
||||
class PromptRequest:
|
||||
id: str = ""
|
||||
|
|
@ -46,4 +35,4 @@ class PromptResponse:
|
|||
out_token: int | None = None
|
||||
model: str | None = None
|
||||
|
||||
############################################################################
|
||||
############################################################################
|
||||
|
|
|
|||
35
trustgraph-base/trustgraph/schema/services/reranker.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
|
||||
from dataclasses import dataclass, field

from ..core.primitives import Error

############################################################################

# Cross-encoder reranker

@dataclass
class RerankerQuery:
    query_id: str = ""
    query_text: str = ""

@dataclass
class RerankerDocument:
    document_id: str = ""
    document_text: str = ""

@dataclass
class RerankerRequest:
    queries: list[RerankerQuery] = field(default_factory=list)
    documents: list[RerankerDocument] = field(default_factory=list)
    limit: int = 10

@dataclass
class RerankerResult:
    document_id: str = ""
    query_id: str = ""
    score: float = 0.0

@dataclass
class RerankerResponse:
    error: Error | None = None
    results: list[RerankerResult] = field(default_factory=list)
|
||||
|
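A short sketch of how a request in this schema might be assembled from plain strings. The three dataclasses are local copies of the ones in the diff so that the snippet runs on its own; `build_request` is a hypothetical helper, not part of the commit, and the positional-index IDs simply follow the convention used by the `tg-invoke-reranker` CLI added in the same commit.

```python
from dataclasses import dataclass, field, asdict


# Local copies of the schema classes from the diff, so the sketch is standalone.
@dataclass
class RerankerQuery:
    query_id: str = ""
    query_text: str = ""


@dataclass
class RerankerDocument:
    document_id: str = ""
    document_text: str = ""


@dataclass
class RerankerRequest:
    queries: list[RerankerQuery] = field(default_factory=list)
    documents: list[RerankerDocument] = field(default_factory=list)
    limit: int = 10


def build_request(query_texts, document_texts, limit=10):
    """Hypothetical helper: use the position in each list as the ID."""
    return RerankerRequest(
        queries=[
            RerankerQuery(query_id=str(i), query_text=q)
            for i, q in enumerate(query_texts)
        ],
        documents=[
            RerankerDocument(document_id=str(i), document_text=d)
            for i, d in enumerate(document_texts)
        ],
        limit=limit,
    )


request = build_request(["what is a reranker?"], ["doc a", "doc b"], limit=1)
payload = asdict(request)  # plain dicts, convenient for inspection or JSON
```

Keeping IDs as strings lets the caller map each `RerankerResult` back to its input by `document_id` and `query_id`, whatever ordering the service returns.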
|
@ -40,7 +40,10 @@ class GraphRagResponse:
|
|||
class DocumentRagQuery:
|
||||
query: str = ""
|
||||
collection: str = ""
|
||||
doc_limit: int = 0
|
||||
doc_limit: int = 0 # docs selected into the synthesis prompt
|
||||
fetch_limit: int = 0 # candidate pool fetched from the vector store
|
||||
# before reranking (0 = derive from doc_limit;
|
||||
# values below doc_limit are raised to it)
|
||||
streaming: bool = False
|
||||
|
||||
@dataclass
|
||||
|
|
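The comment on `fetch_limit` describes two rules: zero means "derive from `doc_limit`", and a value below `doc_limit` is raised to it. A minimal sketch of that resolution follows. The function name `resolve_fetch_limit` and the `multiplier` default are assumptions made for illustration; the diff does not show how the derived value is actually computed.

```python
def resolve_fetch_limit(doc_limit: int, fetch_limit: int, multiplier: int = 3) -> int:
    """Hypothetical resolution of the candidate pool size before reranking."""
    if fetch_limit <= 0:
        # 0 (treated here the same as any non-positive value) means "derive".
        # The multiplier is an illustrative assumption, not taken from the commit.
        return doc_limit * multiplier
    # A pool smaller than the number of selected documents makes no sense,
    # so it is raised to doc_limit.
    return max(fetch_limit, doc_limit)
```

With `doc_limit=10`, this sketch fetches 30 candidates when `fetch_limit` is left at 0, 10 when it is set to 5, and 50 when it is set to 50.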
|
|||
|
|
@ -71,6 +71,7 @@ tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
|
|||
tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main"
|
||||
tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main"
|
||||
tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
|
||||
tg-invoke-reranker = "trustgraph.cli.invoke_reranker:main"
|
||||
tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
|
||||
tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"
|
||||
tg-load-kg-core = "trustgraph.cli.load_kg_core:main"
|
||||
|
|
|
|||
|
|
@ -21,10 +21,12 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
|||
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||
default_collection = 'default'
|
||||
default_doc_limit = 10
|
||||
default_fetch_limit = 0
|
||||
|
||||
|
||||
def question_explainable(
|
||||
url, flow_id, question_text, collection, doc_limit, token=None, debug=False,
|
||||
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
|
||||
token=None, debug=False,
|
||||
workspace="default",
|
||||
):
|
||||
"""Execute document RAG with explainability - shows provenance events inline."""
|
||||
|
|
@ -39,6 +41,7 @@ def question_explainable(
|
|||
query=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
# Print response content
|
||||
|
|
@ -97,7 +100,7 @@ def question_explainable(
|
|||
|
||||
|
||||
def question(
|
||||
url, flow_id, question_text, collection, doc_limit,
|
||||
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
|
||||
streaming=True, token=None, explainable=False, debug=False,
|
||||
show_usage=False, workspace="default",
|
||||
):
|
||||
|
|
@ -109,6 +112,7 @@ def question(
|
|||
question_text=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
token=token,
|
||||
debug=debug,
|
||||
workspace=workspace,
|
||||
|
|
@ -128,6 +132,7 @@ def question(
|
|||
query=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
|
|
@ -155,6 +160,7 @@ def question(
|
|||
query=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
)
|
||||
print(result.text)
|
||||
|
||||
|
|
@ -214,7 +220,15 @@ def main():
|
|||
'-d', '--doc-limit',
|
||||
type=int,
|
||||
default=default_doc_limit,
|
||||
help=f'Document limit (default: {default_doc_limit})'
|
||||
help=f'Documents selected into the prompt (default: {default_doc_limit})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--fetch-limit',
|
||||
type=int,
|
||||
default=default_fetch_limit,
|
||||
help='Candidate documents fetched from the vector store before '
|
||||
'reranking (default: derive from doc-limit)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -251,6 +265,7 @@ def main():
|
|||
question_text=args.question,
|
||||
collection=args.collection,
|
||||
doc_limit=args.doc_limit,
|
||||
fetch_limit=args.fetch_limit,
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
|
|
|
|||
|
|
@ -112,14 +112,13 @@ def _question_explainable_api(
|
|||
if focus_full and focus_full.edge_selections:
|
||||
for edge_sel in focus_full.edge_selections:
|
||||
if edge_sel.edge:
|
||||
# Resolve labels for edge components
|
||||
s_label, p_label, o_label = explain_client.resolve_edge_labels(
|
||||
edge_sel.edge, collection
|
||||
)
|
||||
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
|
||||
if edge_sel.reasoning:
|
||||
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
|
||||
print(f" Reason: {r_short}", file=sys.stderr)
|
||||
if edge_sel.concept or edge_sel.score is not None:
|
||||
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
|
||||
print(f" Concept: {edge_sel.concept} Score: {score_str}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Synthesis):
|
||||
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
|
||||
|
|
|
|||
127
trustgraph-cli/trustgraph/cli/invoke_reranker.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""
Invokes the reranker service to score and rank documents by relevance
to one or more queries.
"""

import argparse
import json
import os
from trustgraph.api import Api

default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")

def query(url, flow_id, queries, documents, limit, token=None,
          workspace="default"):

    api = Api(url=url, token=token, workspace=workspace)
    socket = api.socket()
    flow = socket.flow(flow_id)

    try:

        query_objects = [
            {"query_id": str(i), "query_text": q}
            for i, q in enumerate(queries)
        ]

        document_objects = [
            {"document_id": str(i), "document_text": d}
            for i, d in enumerate(documents)
        ]

        result = flow.rerank(
            queries=query_objects,
            documents=document_objects,
            limit=limit,
        )

        if "error" in result and result["error"]:
            err = result["error"]
            print(f"Error: [{err.get('type', '')}] {err.get('message', '')}")
            return

        for r in result.get("results", []):
            doc_idx = int(r["document_id"])
            query_idx = int(r["query_id"])
            print(
                f" {r['score']:.4f} | "
                f"query: {queries[query_idx]} | "
                f"doc: {documents[doc_idx]}"
            )

    finally:
        socket.close()

def main():

    parser = argparse.ArgumentParser(
        prog='tg-invoke-reranker',
        description=__doc__,
    )

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'API URL (default: {default_url})',
    )

    parser.add_argument(
        '-t', '--token',
        default=default_token,
        help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
    )

    parser.add_argument(
        '-w', '--workspace',
        default=default_workspace,
        help=f'Workspace (default: {default_workspace})',
    )

    parser.add_argument(
        '-f', '--flow-id',
        default="default",
        help=f'Flow ID (default: default)'
    )

    parser.add_argument(
        '-l', '--limit',
        type=int,
        default=10,
        help='Maximum number of results (default: 10)',
    )

    parser.add_argument(
        '-q', '--query',
        action='append',
        required=True,
        help='Query text (can be specified multiple times)',
    )

    parser.add_argument(
        'documents',
        nargs='+',
        help='Documents to rerank',
    )

    args = parser.parse_args()

    try:

        query(
            url=args.url,
            flow_id=args.flow_id,
            queries=args.query,
            documents=args.documents,
            limit=args.limit,
            token=args.token,
            workspace=args.workspace,
        )

    except Exception as e:

        print("Exception:", e, flush=True)

if __name__ == "__main__":
    main()
|
||||
|
|
@ -203,9 +203,9 @@ def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_
|
|||
)
|
||||
print(f" {i}. ({s_label}, {p_label}, {o_label})")
|
||||
|
||||
if edge_sel.reasoning:
|
||||
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
|
||||
print(f" Reasoning: {r_short}")
|
||||
if edge_sel.concept or edge_sel.score is not None:
|
||||
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
|
||||
print(f" Concept: {edge_sel.concept} Score: {score_str}")
|
||||
|
||||
if show_provenance and edge_sel.edge:
|
||||
provenance = trace_edge_provenance(
|
||||
|
|
@ -519,7 +519,8 @@ def trace_to_dict(trace, trace_type):
|
|||
"selected_edges": [
|
||||
{
|
||||
"edge": edge_sel.edge,
|
||||
"reasoning": edge_sel.reasoning,
|
||||
"concept": edge_sel.concept,
|
||||
"score": edge_sel.score,
|
||||
}
|
||||
for edge_sel in focus.edge_selections
|
||||
],
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ dependencies = [
|
|||
"faiss-cpu",
|
||||
"falkordb",
|
||||
"fastembed",
|
||||
"flashrank",
|
||||
"ibis",
|
||||
"jsonschema",
|
||||
"langchain",
|
||||
|
|
@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:
|
|||
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
|
||||
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
|
||||
graph-rag = "trustgraph.retrieval.graph_rag:run"
|
||||
reranker-flashrank = "trustgraph.reranker.flashrank:run"
|
||||
kg-extract-agent = "trustgraph.extract.kg.agent:run"
|
||||
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
|
||||
kg-extract-rows = "trustgraph.extract.kg.rows:run"
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
|||
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
|
||||
from . row_embeddings_query import RowEmbeddingsQueryRequestor
|
||||
from . mcp_tool import McpToolRequestor
|
||||
from . reranker import RerankerRequestor
|
||||
from . text_load import TextLoad
|
||||
from . document_load import DocumentLoad
|
||||
|
||||
|
|
@ -74,6 +75,7 @@ request_response_dispatchers = {
|
|||
"structured-diag": StructuredDiagRequestor,
|
||||
"row-embeddings": RowEmbeddingsQueryRequestor,
|
||||
"sparql": SparqlQueryRequestor,
|
||||
"reranker": RerankerRequestor,
|
||||
}
|
||||
|
||||
system_dispatchers = {
|
||||
|
|
|
|||
31
trustgraph-flow/trustgraph/gateway/dispatch/reranker.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
from ... schema import RerankerRequest, RerankerResponse
from ... messaging import TranslatorRegistry

from . requestor import ServiceRequestor

class RerankerRequestor(ServiceRequestor):
    def __init__(
            self, backend, request_queue, response_queue, timeout,
            consumer, subscriber,
    ):

        super(RerankerRequestor, self).__init__(
            backend=backend,
            request_queue=request_queue,
            response_queue=response_queue,
            request_schema=RerankerRequest,
            response_schema=RerankerResponse,
            subscription = subscriber,
            consumer_name = consumer,
            timeout=timeout,
        )

        self.request_translator = TranslatorRegistry.get_request_translator("reranker")
        self.response_translator = TranslatorRegistry.get_response_translator("reranker")

    def to_request(self, body):
        return self.request_translator.decode(body)

    def from_response(self, message):
        return self.response_translator.encode_with_completion(message)
|
||||
|
|
@ -518,6 +518,7 @@ _FLOW_SERVICES = {
|
|||
"structured-diag": "structured-query:read",
|
||||
"row-embeddings": "row-embeddings:read",
|
||||
"sparql": "sparql:read",
|
||||
"reranker": "reranker",
|
||||
}
|
||||
for _kind, _cap in _FLOW_SERVICES.items():
|
||||
_register_flow_kind("flow-service", _kind, _cap)
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ _READER_CAPS = {
|
|||
"row-embeddings:read",
|
||||
"llm",
|
||||
"embeddings",
|
||||
"reranker",
|
||||
"mcp",
|
||||
"config:read",
|
||||
"flows:read",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import logging
|
|||
|
||||
from .... exceptions import TooManyRequests, LlmError
|
||||
from .... base import LlmService, LlmResult, LlmChunk
|
||||
from . variants import get_variant, DEFAULT_VARIANT, VARIANTS
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -21,6 +22,7 @@ default_temperature = 0.0
|
|||
default_max_output = 4096
|
||||
default_api_key = os.getenv("OPENAI_TOKEN")
|
||||
default_base_url = os.getenv("OPENAI_BASE_URL")
|
||||
default_thinking = "off"
|
||||
|
||||
if default_base_url is None or default_base_url == "":
|
||||
default_base_url = "https://api.openai.com/v1"
|
||||
|
|
@ -28,16 +30,21 @@ if default_base_url is None or default_base_url == "":
|
|||
class Processor(LlmService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key", default_api_key)
|
||||
base_url = params.get("url", default_base_url)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
thinking = params.get("thinking", default_thinking)
|
||||
variant_name = params.get("variant", DEFAULT_VARIANT)
|
||||
|
||||
if not api_key:
|
||||
api_key = "not-set"
|
||||
|
||||
self.variant = get_variant(variant_name)
|
||||
self.thinking = thinking
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"model": model,
|
||||
|
|
@ -56,13 +63,28 @@ class Processor(LlmService):
|
|||
else:
|
||||
self.openai = OpenAI(api_key=api_key)
|
||||
|
||||
logger.info("OpenAI LLM service initialized")
|
||||
logger.info(
|
||||
f"OpenAI LLM service initialized "
|
||||
f"(variant={self.variant.name}, thinking={self.thinking})"
|
||||
)
|
||||
|
||||
def _build_kwargs(self, model_name, temperature):
|
||||
"""Build API call kwargs using the active variant."""
|
||||
return self.variant.completion_kwargs(
|
||||
max_output=self.max_output,
|
||||
temperature=temperature,
|
||||
thinking=self.thinking,
|
||||
)
|
||||
|
||||
def _extract_content(self, message):
|
||||
"""Extract visible content from a response message."""
|
||||
if hasattr(self.variant, "extract_content"):
|
||||
return self.variant.extract_content(message)
|
||||
return message.content
|
||||
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
|
|
@ -72,31 +94,38 @@ class Processor(LlmService):
|
|||
|
||||
try:
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
temperature=effective_temperature,
|
||||
max_completion_tokens=self.max_output,
|
||||
api_kwargs = self._build_kwargs(model_name, effective_temperature)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
resp = self.variant.create_completion(
|
||||
self.openai, model_name, messages, **api_kwargs,
|
||||
)
|
||||
|
||||
|
||||
inputtokens = resp.usage.prompt_tokens
|
||||
outputtokens = resp.usage.completion_tokens
|
||||
logger.debug(f"LLM response: {resp.choices[0].message.content}")
|
||||
|
||||
content = self._extract_content(resp.choices[0].message)
|
||||
thinking = self.variant.extract_thinking(resp.choices[0].message)
|
||||
|
||||
logger.debug(f"LLM response: {content}")
|
||||
if thinking:
|
||||
logger.debug(f"LLM thinking: {thinking[:200]}...")
|
||||
logger.info(f"Input Tokens: {inputtokens}")
|
||||
logger.info(f"Output Tokens: {outputtokens}")
|
||||
|
||||
resp = LlmResult(
|
||||
text = resp.choices[0].message.content,
|
||||
text = content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = model_name
|
||||
|
|
@ -136,9 +165,7 @@ class Processor(LlmService):
|
|||
Stream content generation from OpenAI.
|
||||
Yields LlmChunk objects with is_final=True on the last chunk.
|
||||
"""
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model (streaming): {model_name}")
|
||||
|
|
@ -147,30 +174,26 @@ class Processor(LlmService):
|
|||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
response = self.openai.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
temperature=effective_temperature,
|
||||
max_completion_tokens=self.max_output,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True}
|
||||
)
|
||||
api_kwargs = self._build_kwargs(model_name, effective_temperature)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
# Stream chunks
|
||||
for chunk in response:
|
||||
async for chunk in self.variant.create_completion_stream(
|
||||
self.openai, model_name, messages, **api_kwargs,
|
||||
):
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield LlmChunk(
|
||||
text=chunk.choices[0].delta.content,
|
||||
|
|
@ -254,6 +277,20 @@ class Processor(LlmService):
|
|||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--thinking',
|
||||
choices=["off", "low", "medium", "high"],
|
||||
default=default_thinking,
|
||||
help=f'Thinking/reasoning effort level (default: {default_thinking})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--variant',
|
||||
choices=sorted(VARIANTS.keys()),
|
||||
default=DEFAULT_VARIANT,
|
||||
help=f'API variant (default: {DEFAULT_VARIANT})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,219 @@
|
|||
"""
OpenAI API variant profiles.

Different providers expose OpenAI-compatible APIs with subtle differences
in parameter names, thinking/reasoning support, and temperature handling.
Each variant encapsulates those quirks so the processor doesn't need
provider-specific conditionals.
"""

import re
import logging

logger = logging.getLogger(__name__)


class Variant:
    """Base variant — defines the interface all variants implement."""

    name = None
    token_param = "max_completion_tokens"
    temperature_with_thinking = False

    def completion_kwargs(self, max_output, temperature, thinking):
        """Build provider-specific kwargs for chat.completions.create().

        Parameters
        ----------
        max_output : int
            Configured max output tokens.
        temperature : float
            Configured temperature.
        thinking : str
            Thinking effort level: "off", "low", "medium", "high".

        Returns
        -------
        dict
            Extra kwargs to spread into the API call.
        """
        kwargs = {self.token_param: max_output}

        if thinking != "off":
            kwargs.update(self.thinking_kwargs(thinking))
            if not self.temperature_with_thinking:
                kwargs["temperature"] = 1.0
            else:
                kwargs["temperature"] = temperature
        else:
            kwargs["temperature"] = temperature

        return kwargs

    def thinking_kwargs(self, effort):
        """Return kwargs to enable thinking at the given effort level."""
        return {}

    def extract_thinking(self, message):
        """Extract thinking/reasoning content from a response message."""
        return getattr(message, "reasoning_content", None)

    def extract_thinking_stream(self, delta):
        """Extract thinking content from a streaming delta."""
        return getattr(delta, "reasoning_content", None)

    def create_completion(self, client, model, messages, **kwargs):
        """Call the completions API. Override for non-standard SDKs."""
        return client.chat.completions.create(
            model=model, messages=messages, **kwargs,
        )

    async def create_completion_stream(self, client, model, messages, **kwargs):
        """Call the streaming completions API. Override for non-standard SDKs."""
        for chunk in client.chat.completions.create(
            model=model, messages=messages, stream=True,
            stream_options={"include_usage": True}, **kwargs,
        ):
            yield chunk


class OpenAIVariant(Variant):
    """Standard OpenAI API (GPT-4o, o1, o3, etc.)."""

    name = "openai"
    token_param = "max_completion_tokens"
    temperature_with_thinking = False

    def thinking_kwargs(self, effort):
        return {"reasoning_effort": effort}


class DeepSeekVariant(Variant):
    """DeepSeek API (R1, V3, etc.)."""

    name = "deepseek"
    token_param = "max_completion_tokens"
    temperature_with_thinking = True

    def completion_kwargs(self, max_output, temperature, thinking):
        enabled = "enabled" if thinking != "off" else "disabled"
        kwargs = {
            self.token_param: max_output,
            "temperature": temperature,
            "extra_body": {
                "thinking": {"type": enabled},
            },
        }
        return kwargs

    def thinking_kwargs(self, effort):
        return {}


class DashScopeVariant(Variant):
    """Alibaba Cloud DashScope API (Qwen models)."""

    name = "dashscope"
    token_param = "max_completion_tokens"
    temperature_with_thinking = True

    def completion_kwargs(self, max_output, temperature, thinking):
        enabled = thinking != "off"
        return {
            self.token_param: max_output,
            "temperature": temperature,
            "extra_body": {
                "enable_thinking": enabled,
            },
        }

    def thinking_kwargs(self, effort):
        return {}


class QwenVariant(DashScopeVariant):
    """Qwen — alias for DashScope."""

    name = "qwen"


class MistralVariant(Variant):
    """Mistral API (Mistral Large, etc.)."""

    name = "mistral"
    token_param = "max_tokens"
    temperature_with_thinking = False

    def thinking_kwargs(self, effort):
        return {"reasoning_effort": effort}


class GlmVariant(Variant):
    """GLM / Zhipu AI API (GLM-4, GLM-4.7, etc.)."""

    name = "glm"
    token_param = "max_tokens"
    temperature_with_thinking = True

    def completion_kwargs(self, max_output, temperature, thinking):
        enabled = "enabled" if thinking != "off" else "disabled"
        kwargs = {
            self.token_param: max_output,
            "temperature": temperature,
            "extra_body": {
                "thinking": {"type": enabled},
            },
        }
        return kwargs

    def thinking_kwargs(self, effort):
        return {}


class LlamaVariant(Variant):
    """Llama models via OpenAI-compatible servers (vLLM, Ollama, etc.).

    Thinking is typically always-on or always-off depending on the model.
    When present, thinking appears inline as <think>...</think> tags.
    """

    name = "llama"
    token_param = "max_tokens"
    temperature_with_thinking = True

    def thinking_kwargs(self, effort):
        return {}

    def extract_thinking(self, message):
        content = message.content or ""
        match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
        return match.group(1).strip() if match else None

    def extract_content(self, message):
        """Strip think tags from visible content."""
        content = message.content or ""
        return re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()


VARIANTS = {
    "openai": OpenAIVariant,
    "deepseek": DeepSeekVariant,
    "qwen": QwenVariant,
    "mistral": MistralVariant,
    "dashscope": DashScopeVariant,
    "glm": GlmVariant,
    "llama": LlamaVariant,
}

DEFAULT_VARIANT = "openai"


def get_variant(name):
    """Look up a variant by name, raising ValueError if unknown."""
    cls = VARIANTS.get(name)
    if cls is None:
        raise ValueError(
            f"Unknown variant {name!r}. "
            f"Available: {', '.join(sorted(VARIANTS))}"
        )
    return cls()
|
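The temperature rule in the base `Variant.completion_kwargs` is easy to miss: when thinking is enabled and the variant does not allow a custom temperature alongside it, the configured temperature is replaced by 1.0. The sketch below restates that rule as a free function so it can be run without the class hierarchy. The function signature and the lambda standing in for `thinking_kwargs` are illustrative assumptions; the branching mirrors the diff.

```python
def completion_kwargs(
    token_param,
    temperature_with_thinking,
    thinking_kwargs,
    max_output,
    temperature,
    thinking,
):
    """Standalone restatement of the base-variant logic (illustrative only)."""
    kwargs = {token_param: max_output}
    if thinking != "off":
        kwargs.update(thinking_kwargs(thinking))
        # Variants that cannot combine thinking with a custom temperature
        # get the temperature forced to 1.0.
        kwargs["temperature"] = (
            temperature if temperature_with_thinking else 1.0
        )
    else:
        kwargs["temperature"] = temperature
    return kwargs


# Profile matching the "openai" variant in the diff.
with_thinking = completion_kwargs(
    "max_completion_tokens", False,
    lambda effort: {"reasoning_effort": effort},
    4096, 0.0, "high",
)
without_thinking = completion_kwargs(
    "max_completion_tokens", False,
    lambda effort: {"reasoning_effort": effort},
    4096, 0.0, "off",
)
```

Here `with_thinking` carries `reasoning_effort="high"` and a temperature of 1.0 even though 0.0 was configured, while `without_thinking` passes the configured 0.0 through unchanged. Variants such as DeepSeek, DashScope and GLM override `completion_kwargs` entirely and always forward the configured temperature.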
||||
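The inline `<think>...</think>` handling in `LlamaVariant` can be tried in isolation. The sketch below reuses the same regular expression on a stand-in message object; `SimpleNamespace` is only an assumption to avoid depending on the OpenAI SDK's message type, and the sample text is made up.

```python
import re
from types import SimpleNamespace

# Same pattern as in the diff: non-greedy, spanning newlines.
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def extract_thinking(message):
    """Return the text of the first <think> block, or None if there is none."""
    content = message.content or ""
    match = THINK_RE.search(content)
    return match.group(1).strip() if match else None


def extract_content(message):
    """Return the visible text with every <think> block removed."""
    content = message.content or ""
    return THINK_RE.sub("", content).strip()


message = SimpleNamespace(
    content="<think>\ncheck the units\n</think>\nThe answer is 42."
)
thinking = extract_thinking(message)
visible = extract_content(message)
```

One asymmetry worth noting: `re.search` returns only the first `<think>` block, whereas `re.sub` strips all of them, so a response with several blocks would lose the later reasoning text from the debug log while still producing clean visible content.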
1
trustgraph-flow/trustgraph/reranker/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
from . processor import *
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . processor import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
109
trustgraph-flow/trustgraph/reranker/flashrank/processor.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
|
||||
"""
Reranker service using flashrank.
Scores query-document pairs and returns the top results ranked by
relevance.
"""

import asyncio
import logging

from ... base import RerankerService
from ... schema import RerankerResult

from flashrank import Ranker, RerankRequest

logger = logging.getLogger(__name__)

default_ident = "reranker"

default_model = "ms-marco-MiniLM-L-12-v2"

class Processor(RerankerService):

    def __init__(self, **params):

        model = params.get("model", default_model)

        super(Processor, self).__init__(
            **params | { "model": model }
        )

        self.default_model = model

        self.cached_model_name = None
        self.ranker = None

        self._load_model(model)

    def _load_model(self, model_name):
        if self.cached_model_name != model_name:
            logger.info(f"Loading flashrank model: {model_name}")
            self.ranker = Ranker(model_name=model_name)
            self.cached_model_name = model_name
            logger.info(f"flashrank model {model_name} loaded successfully")
        else:
            logger.debug(f"Using cached model: {model_name}")

    def _run_rerank(self, query, passages):
        request = RerankRequest(query=query, passages=passages)
        return self.ranker.rerank(request)

    async def on_rerank(self, queries, documents, limit, model=None):

        if not queries or not documents:
            return []

        use_model = model or self.default_model

        if self.cached_model_name != use_model:
            await asyncio.to_thread(self._load_model, use_model)

        passages = [
            {"id": d.document_id, "text": d.document_text}
            for d in documents
        ]

        best_scores = {}

        for q in queries:
            ranked = await asyncio.to_thread(
                self._run_rerank, q.query_text, passages,
            )

            for r in ranked:
                doc_id = r["id"]
                score = r["score"]
                score = float(score)
                if doc_id not in best_scores or score > best_scores[doc_id][1]:
                    best_scores[doc_id] = (q.query_id, score)

        results = sorted(
            best_scores.items(),
            key=lambda x: x[1][1],
            reverse=True,
        )[:limit]

        return [
            RerankerResult(
                document_id=doc_id,
                query_id=query_id,
                score=score,
            )
            for doc_id, (query_id, score) in results
        ]

    @staticmethod
    def add_args(parser):

        RerankerService.add_args(parser)

        parser.add_argument(
            '-m', '--model',
            default=default_model,
            help=f'Reranker model (default: {default_model})'
        )
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
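The score aggregation inside `on_rerank` can be tried without flashrank installed. The sketch below is an illustration only, not code from this commit: `best_score_per_document` and the `(query_id, doc_id, score)` tuples are hypothetical stand-ins for the per-query output of `Ranker.rerank`. It keeps each document's best score across all queries, remembers which query produced it, and truncates to `limit`.

```python
def best_score_per_document(scored, limit):
    """Keep each document's best score across queries, then take the top `limit`.

    `scored` is an iterable of (query_id, doc_id, score) tuples
    (a hypothetical flattening of the per-query rerank output).
    """
    best = {}
    for query_id, doc_id, score in scored:
        score = float(score)
        # Overwrite only when this query scores the document higher.
        if doc_id not in best or score > best[doc_id][1]:
            best[doc_id] = (query_id, score)

    # Sort documents by their best score, highest first.
    ranked = sorted(best.items(), key=lambda item: item[1][1], reverse=True)
    return [
        (doc_id, query_id, score)
        for doc_id, (query_id, score) in ranked[:limit]
    ]


scored = [
    ("q0", "a", 0.2), ("q0", "b", 0.9),
    ("q1", "a", 0.7), ("q1", "b", 0.1), ("q1", "c", 0.4),
]
top = best_score_per_document(scored, limit=2)
```

Each document appears at most once in the output, so `limit` bounds the number of distinct documents rather than the number of query-document pairs.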
@@ -9,10 +9,12 @@ from trustgraph.provenance import (
     docrag_question_uri,
     docrag_grounding_uri,
     docrag_exploration_uri,
+    docrag_focus_uri,
     docrag_synthesis_uri,
     docrag_question_triples,
     grounding_triples,
     docrag_exploration_triples,
+    docrag_chunk_selection_triples,
     docrag_synthesis_triples,
     set_graph,
     GRAPH_RETRIEVAL,
@@ -21,19 +23,25 @@ from trustgraph.provenance import (
 # Module logger
 logger = logging.getLogger(__name__)

+# When the caller does not specify a fetch_limit, reranking over-fetches this
+# many times the final doc_limit as the candidate pool, so the cross-encoder can
+# recover relevant chunks the bi-encoder ranked just outside the top doc_limit.
+# This is only the fallback default: an explicit fetch_limit overrides it.
+OVERFETCH_FACTOR = 3
+
 LABEL="http://www.w3.org/2000/01/rdf-schema#label"

 class Query:

     def __init__(
         self, rag, workspace, collection, verbose,
-        doc_limit=20, track_usage=None,
+        fetch_limit=20, track_usage=None,
     ):
         self.rag = rag
         self.workspace = workspace
         self.collection = collection
         self.verbose = verbose
-        self.doc_limit = doc_limit
+        self.fetch_limit = fetch_limit
         self.track_usage = track_usage

     async def extract_concepts(self, query):
@@ -91,7 +99,7 @@ class Query:

         # Query chunk matches for each concept concurrently
         per_concept_limit = max(
-            1, self.doc_limit // len(vectors)
+            1, self.fetch_limit // len(vectors)
         )

         async def query_concept(vec):
@@ -140,6 +148,7 @@ class DocumentRag:
     def __init__(
         self, prompt_client, embeddings_client, doc_embeddings_client,
         fetch_chunk,
+        reranker_client=None,
         verbose=False,
     ):

@@ -150,12 +159,16 @@ class DocumentRag:
         self.doc_embeddings_client = doc_embeddings_client
         self.fetch_chunk = fetch_chunk

+        # Optional cross-encoder reranker. When None, the retrieval path is
+        # byte-identical to the pre-reranker behaviour.
+        self.reranker_client = reranker_client
+
         if self.verbose:
             logger.debug("DocumentRag initialized")

     async def query(
         self, query, workspace="default", collection="default",
-        doc_limit=20, streaming=False, chunk_callback=None,
+        doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None,
         explain_callback=None, save_answer_callback=None,
     ):
         """
@@ -165,7 +178,10 @@ class DocumentRag:
             query: The query string
             workspace: Workspace for isolation (also scopes chunk lookup)
             collection: Collection identifier
-            doc_limit: Max chunks to retrieve
+            doc_limit: Chunks selected into the synthesis prompt (after rerank)
+            fetch_limit: Candidate pool fetched from the vector store before
+                reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a
+                reranker is wired, else doc_limit).
             streaming: Enable streaming LLM response
             chunk_callback: async def callback(chunk, end_of_stream) for streaming
             explain_callback: async def callback(triples, explain_id) for explainability
@@ -197,6 +213,7 @@ class DocumentRag:
         q_uri = docrag_question_uri(session_id)
         gnd_uri = docrag_grounding_uri(session_id)
         exp_uri = docrag_exploration_uri(session_id)
+        foc_uri = docrag_focus_uri(session_id)
         syn_uri = docrag_synthesis_uri(session_id)

         timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
@@ -209,10 +226,21 @@ class DocumentRag:
             )
             await explain_callback(q_triples, q_uri)

+        # Resolve the candidate-pool size fetched from the vector store. When a
+        # reranker is wired, honour an explicit fetch_limit; if unset, fall back
+        # to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit,
+        # else the rerank could not fill the prompt. Without a reranker, fetch
+        # doc_limit as before (byte-identical behaviour).
+        if self.reranker_client is not None:
+            fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit)
+            fetch_count = max(fl, doc_limit)
+        else:
+            fetch_count = doc_limit
+
         q = Query(
             rag=self, workspace=workspace, collection=collection,
             verbose=self.verbose,
-            doc_limit=doc_limit, track_usage=track_usage,
+            fetch_limit=fetch_count, track_usage=track_usage,
         )

         # Extract concepts from query (grounding step)
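The candidate-pool resolution in the hunk above can be sketched as a standalone function. This is a minimal illustration under stated assumptions, not the commit's code: `resolve_fetch_count` and its `has_reranker` flag are hypothetical names that replace the `self.reranker_client is not None` check, while `OVERFETCH_FACTOR = 3` mirrors the constant added in this diff.

```python
OVERFETCH_FACTOR = 3  # same value as the constant introduced in the diff


def resolve_fetch_count(doc_limit, fetch_limit=0, has_reranker=True):
    """Return how many candidate chunks to pull from the vector store.

    Hypothetical helper: `has_reranker` stands in for the
    `self.reranker_client is not None` test in DocumentRag.query.
    """
    if not has_reranker:
        # No reranker: fetch exactly what the prompt will use.
        return doc_limit
    # 0 (or None) means "derive from doc_limit".
    requested = fetch_limit or (OVERFETCH_FACTOR * doc_limit)
    # Never fetch fewer candidates than the prompt needs.
    return max(requested, doc_limit)
```

For example, `resolve_fetch_count(5)` derives a pool of 15, an explicit `fetch_limit=40` is honoured as given, and a `fetch_limit` smaller than `doc_limit` is raised to `doc_limit`.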
@@ -235,6 +263,7 @@ class DocumentRag:
         docs, chunk_ids = await q.get_docs(concepts)

         # Emit exploration explainability after chunks retrieved
+        # (full candidate set, before any reranking)
         if explain_callback:
             exp_triples = set_graph(
                 docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
@@ -242,6 +271,45 @@ class DocumentRag:
             )
             await explain_callback(exp_triples, exp_uri)

+        # Optional cross-encoder reranking pass between retrieval and
+        # synthesis. Mirrors GraphRAG's reranker usage but with a single
+        # query (the question). When no reranker is wired, this block is
+        # skipped entirely and behaviour is byte-identical to before.
+        reranked = False
+        if self.reranker_client is not None and docs:
+            results = await self.reranker_client.rerank(
+                queries=[{"id": "0", "text": query}],
+                documents=[
+                    {"id": str(i), "text": d} for i, d in enumerate(docs)
+                ],
+                # Narrow the over-fetched candidate pool down to the final
+                # doc_limit requested for synthesis.
+                limit=doc_limit,
+            )
+
+            # results are sorted desc by score and truncated to limit by the
+            # reranker service, so order gives the surviving top-N directly.
+            order = [int(r.document_id) for r in results]
+            docs = [docs[i] for i in order]
+            chunk_ids = [chunk_ids[i] for i in order]
+            reranked = True
+
+            # Emit chunk-selection (focus) explainability: surviving chunks
+            # with their cross-encoder scores, derived from exploration.
+            if explain_callback:
+                selected_chunks_with_scores = [
+                    {"chunk_id": chunk_ids[i], "score": r.score}
+                    for i, r in enumerate(results)
+                ]
+                foc_triples = set_graph(
+                    docrag_chunk_selection_triples(
+                        foc_uri, exp_uri,
+                        selected_chunks_with_scores, session_id,
+                    ),
+                    GRAPH_RETRIEVAL
+                )
+                await explain_callback(foc_triples, foc_uri)
+
         if self.verbose:
             logger.debug("Invoking LLM...")
             logger.debug(f"Documents: {docs}")
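The reordering step in the hunk above relies on the reranker returning document ids that are stringified positions in the candidate list. A small sketch of that bookkeeping follows; `RerankHit` and `apply_rerank` are hypothetical names introduced here for illustration, and the real reranker client's result type may carry more fields.

```python
from collections import namedtuple

# Hypothetical stand-in for the objects returned by the reranker client.
RerankHit = namedtuple("RerankHit", ["document_id", "score"])


def apply_rerank(docs, chunk_ids, results):
    """Reorder docs and chunk_ids to follow the reranker's ranking.

    `results` is assumed to be sorted by descending score and already
    truncated, with `document_id` holding the original list index.
    """
    order = [int(r.document_id) for r in results]
    kept_docs = [docs[i] for i in order]
    kept_ids = [chunk_ids[i] for i in order]
    # Pair each surviving chunk id with the score that selected it.
    scored = [
        {"chunk_id": cid, "score": r.score}
        for cid, r in zip(kept_ids, results)
    ]
    return kept_docs, kept_ids, scored


docs = ["alpha", "beta", "gamma"]
chunk_ids = ["c0", "c1", "c2"]
results = [RerankHit("2", 0.91), RerankHit("0", 0.40)]
kept_docs, kept_ids, scored = apply_rerank(docs, chunk_ids, results)
```

Because `docs` and `chunk_ids` are reordered with the same index list, they stay aligned, which is what lets the diff index `chunk_ids[i]` by position in `results` afterwards.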
@@ -291,9 +359,15 @@ class DocumentRag:
                 logger.warning(f"Failed to save answer to librarian: {e}")
                 synthesis_doc_id = None

+            # When reranking ran, synthesis derives from the focus (the
+            # reranked chunks actually fed to the LLM), as GraphRAG always does.
+            # When no reranker is wired, there is no focus stage, so synthesis
+            # derives from exploration (the unchanged no-op lineage) - a
+            # deliberate divergence from GraphRAG's always-on focus.
+            syn_parent = foc_uri if reranked else exp_uri
             syn_triples = set_graph(
                 docrag_synthesis_triples(
-                    syn_uri, exp_uri,
+                    syn_uri, syn_parent,
                     document_id=synthesis_doc_id,
                     in_token=synthesis_result.in_token if synthesis_result else None,
                     out_token=synthesis_result.out_token if synthesis_result else None,
@@ -13,6 +13,7 @@ from . document_rag import DocumentRag
 from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
 from ... base import PromptClientSpec, EmbeddingsClientSpec
 from ... base import DocumentEmbeddingsClientSpec
+from ... base import RerankerClientSpec
 from ... base import LibrarianSpec

 # Module logger
@@ -28,14 +29,21 @@ class Processor(FlowProcessor):

         doc_limit = params.get("doc_limit", 5)

+        # Instance-default candidate-pool size fetched before cross-encoder
+        # reranking; the rerank step narrows it back down to doc_limit for the
+        # LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
+        fetch_limit = params.get("fetch_limit", 0)
+
         super(Processor, self).__init__(
             **params | {
                 "id": id,
                 "doc_limit": doc_limit,
+                "fetch_limit": fetch_limit,
             }
         )

         self.doc_limit = doc_limit
+        self.fetch_limit = fetch_limit

         self.register_specification(
             ConsumerSpec(
@@ -66,6 +74,13 @@ class Processor(FlowProcessor):
             )
         )

+        self.register_specification(
+            RerankerClientSpec(
+                request_name = "reranker-request",
+                response_name = "reranker-response",
+            )
+        )
+
         self.register_specification(
             ProducerSpec(
                 name = "response",
@@ -105,6 +120,7 @@ class Processor(FlowProcessor):
                 doc_embeddings_client = flow("document-embeddings-request"),
                 prompt_client = flow("prompt-request"),
                 fetch_chunk = fetch_chunk,
+                reranker_client = flow("reranker-request"),
                 verbose=True,
             )

@@ -113,6 +129,13 @@ class Processor(FlowProcessor):
             else:
                 doc_limit = self.doc_limit

+            # Candidate-pool size: per-request override, else the instance
+            # default; 0 lets the core derive it from doc_limit.
+            if v.fetch_limit:
+                fetch_limit = v.fetch_limit
+            else:
+                fetch_limit = self.fetch_limit
+
             async def send_explainability(triples, explain_id):
                 await flow("explainability").send(Triples(
                     metadata=Metadata(
@@ -163,6 +186,7 @@ class Processor(FlowProcessor):
                     workspace=flow.workspace,
                     collection=v.collection,
                     doc_limit=doc_limit,
+                    fetch_limit=fetch_limit,
                     streaming=True,
                     chunk_callback=send_chunk,
                     explain_callback=send_explainability,
@@ -188,6 +212,7 @@ class Processor(FlowProcessor):
                     workspace=flow.workspace,
                     collection=v.collection,
                     doc_limit=doc_limit,
+                    fetch_limit=fetch_limit,
                     explain_callback=send_explainability,
                     save_answer_callback=save_answer,
                 )
@@ -243,6 +268,15 @@ class Processor(FlowProcessor):
             help=f'Default document fetch limit (default: 10)'
         )

+        parser.add_argument(
+            '--fetch-limit',
+            type=int,
+            default=0,
+            help='Candidate chunks to fetch from the vector store and rerank '
+                 'before keeping the top doc-limit for the LLM '
+                 '(default: derive from doc-limit)'
+        )
+
 def run():

     Processor.launch(default_ident, __doc__)
@ -120,7 +120,7 @@ class Query:
|
|||
def __init__(
|
||||
self, rag, collection, verbose,
|
||||
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
|
||||
max_path_length=2, track_usage=None,
|
||||
max_path_length=2, edge_limit=25, track_usage=None,
|
||||
):
|
||||
self.rag = rag
|
||||
self.collection = collection
|
||||
|
|
@ -129,6 +129,7 @@ class Query:
|
|||
self.triple_limit = triple_limit
|
||||
self.max_subgraph_size = max_subgraph_size
|
||||
self.max_path_length = max_path_length
|
||||
self.edge_limit = edge_limit
|
||||
self.track_usage = track_usage
|
||||
|
||||
async def extract_concepts(self, query):
|
||||
|
|
@ -217,12 +218,9 @@ class Query:
|
|||
logger.debug(f" {ent}")
|
||||
|
||||
return entities, concepts
|
||||
|
||||
|
||||
async def maybe_label(self, e):
|
||||
|
||||
# The label cache lives on a per-request GraphRag instance — no
|
||||
# cross-request isolation concern. The collection prefix keeps
|
||||
# entries from different collections distinct within one request.
|
||||
cache_key = f"{self.collection}:{e}"
|
||||
|
||||
cached_label = self.rag.label_cache.get(cache_key)
|
||||
|
|
@ -244,11 +242,10 @@ class Query:
|
|||
return label
|
||||
|
||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
||||
"""Execute triple queries for multiple entities concurrently using streaming"""
|
||||
"""Execute triple queries for multiple entities concurrently."""
|
||||
tasks = []
|
||||
|
||||
for entity in entities:
|
||||
# Create concurrent streaming tasks for all 3 query types per entity
|
||||
tasks.extend([
|
||||
self.rag.triples_client.query_stream(
|
||||
s=entity, p=None, o=None,
|
||||
|
|
@ -270,10 +267,8 @@ class Query:
|
|||
)
|
||||
])
|
||||
|
||||
# Execute all queries concurrently
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Combine all results
|
||||
all_triples = []
|
||||
for result in results:
|
||||
if not isinstance(result, Exception) and result is not None:
|
||||
|
|
@ -281,168 +276,151 @@ class Query:
|
|||
|
||||
return all_triples
|
||||
|
||||
async def follow_edges_batch(self, entities, max_depth):
|
||||
"""Optimized iterative graph traversal with batching.
|
||||
|
||||
Returns:
|
||||
tuple: (subgraph, term_map) where subgraph is a set of
|
||||
(str, str, str) tuples and term_map maps each string tuple
|
||||
to its original (Term, Term, Term) for type-preserving
|
||||
provenance.
|
||||
"""
|
||||
visited = set()
|
||||
current_level = set(entities)
|
||||
subgraph = set()
|
||||
term_map = {} # (str, str, str) -> (Term, Term, Term)
|
||||
|
||||
for depth in range(max_depth):
|
||||
if not current_level or len(subgraph) >= self.max_subgraph_size:
|
||||
break
|
||||
|
||||
# Filter out already visited entities
|
||||
unvisited_entities = [e for e in current_level if e not in visited]
|
||||
if not unvisited_entities:
|
||||
break
|
||||
|
||||
# Batch query all unvisited entities at current level
|
||||
triples = await self.execute_batch_triple_queries(
|
||||
unvisited_entities, self.triple_limit
|
||||
)
|
||||
|
||||
# Process results and collect next level entities
|
||||
next_level = set()
|
||||
for triple in triples:
|
||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||
subgraph.add(triple_tuple)
|
||||
term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o))
|
||||
|
||||
# Collect entities for next level (only from s and o positions)
|
||||
if depth < max_depth - 1: # Don't collect for final depth
|
||||
s, p, o = triple_tuple
|
||||
if s not in visited:
|
||||
next_level.add(s)
|
||||
if o not in visited:
|
||||
next_level.add(o)
|
||||
|
||||
# Stop if subgraph size limit reached
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return subgraph, term_map
|
||||
|
||||
# Update for next iteration
|
||||
visited.update(current_level)
|
||||
current_level = next_level
|
||||
|
||||
return subgraph, term_map
|
||||
|
||||
async def follow_edges(self, ent, subgraph, path_length):
|
||||
"""Legacy method - replaced by follow_edges_batch"""
|
||||
# Maintain backward compatibility with early termination checks
|
||||
if path_length <= 0:
|
||||
return
|
||||
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return
|
||||
|
||||
# For backward compatibility, convert to new approach
|
||||
batch_result, _ = await self.follow_edges_batch([ent], path_length)
|
||||
subgraph.update(batch_result)
|
||||
|
||||
async def get_subgraph(self, query):
|
||||
"""
|
||||
Get subgraph by extracting concepts, finding entities, and traversing.
|
||||
|
||||
Returns:
|
||||
tuple: (subgraph, term_map, entities, concepts) where subgraph is
|
||||
a list of (s, p, o) string tuples, term_map maps each string
|
||||
tuple to its original (Term, Term, Term), entities is the seed
|
||||
entity list, and concepts is the extracted concept list.
|
||||
"""
|
||||
|
||||
entities, concepts = await self.get_entities(query)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting subgraph...")
|
||||
|
||||
# Use optimized batch traversal instead of sequential processing
|
||||
subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length)
|
||||
|
||||
return list(subgraph), term_map, entities, concepts
|
||||
|
||||
async def resolve_labels_batch(self, entities):
|
||||
"""Resolve labels for multiple entities in parallel"""
|
||||
tasks = []
|
||||
for entity in entities:
|
||||
tasks.append(self.maybe_label(entity))
|
||||
|
||||
"""Resolve labels for multiple entities in parallel."""
|
||||
tasks = [self.maybe_label(entity) for entity in entities]
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def get_labelgraph(self, query):
|
||||
"""
|
||||
Get subgraph with labels resolved for display.
|
||||
async def hop_and_filter(self, seed_entities, concepts):
|
||||
"""Iterative hop-and-filter graph traversal with cross-encoder.
|
||||
|
||||
At each hop:
|
||||
1. Retrieve all edges one hop from the frontier.
|
||||
2. Resolve labels and represent each edge as "{p} {o}".
|
||||
3. Score edges against concepts using the cross-encoder.
|
||||
4. Select the top edges; their target nodes become the next
|
||||
frontier.
|
||||
|
||||
Returns:
|
||||
tuple: (labeled_edges, uri_map, entities, concepts) where:
|
||||
- labeled_edges: list of (label_s, label_p, label_o) tuples
|
||||
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
|
||||
- entities: list of seed entity URI strings
|
||||
- concepts: list of concept strings extracted from query
|
||||
tuple: (selected_edges, uri_map, edge_metadata) where:
|
||||
- selected_edges: list of (label_s, label_p, label_o)
|
||||
- uri_map: dict mapping edge_id -> (Term, Term, Term)
|
||||
- edge_metadata: dict mapping edge_id -> {concept, score}
|
||||
"""
|
||||
subgraph, term_map, entities, concepts = await self.get_subgraph(query)
|
||||
all_selected_edges = []
|
||||
uri_map = {}
|
||||
edge_metadata = {}
|
||||
frontier = set(seed_entities)
|
||||
visited_entities = set()
|
||||
seen_edges = set()
|
||||
|
||||
# Filter out label triples
|
||||
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
|
||||
for hop in range(self.max_path_length):
|
||||
if not frontier:
|
||||
break
|
||||
|
||||
# Collect all unique entities that need label resolution
|
||||
entities_to_resolve = set()
|
||||
for s, p, o in filtered_subgraph:
|
||||
entities_to_resolve.update([s, p, o])
|
||||
unvisited = [e for e in frontier if e not in visited_entities]
|
||||
if not unvisited:
|
||||
break
|
||||
|
||||
# Batch resolve labels for all entities in parallel
|
||||
entity_list = list(entities_to_resolve)
|
||||
resolved_labels = await self.resolve_labels_batch(entity_list)
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Hop {hop + 1}: {len(unvisited)} frontier entities"
|
||||
)
|
||||
|
||||
# Create entity-to-label mapping
|
||||
label_map = {}
|
||||
for entity, label in zip(entity_list, resolved_labels):
|
||||
if not isinstance(label, Exception):
|
||||
label_map[entity] = label
|
||||
else:
|
||||
label_map[entity] = entity # Fallback to entity itself
|
||||
|
||||
# Apply labels to subgraph and build URI mapping
|
||||
labeled_edges = []
|
||||
uri_map = {} # Maps edge_id of labeled edge -> original Term triple
|
||||
|
||||
for s, p, o in filtered_subgraph:
|
||||
labeled_triple = (
|
||||
label_map.get(s, s),
|
||||
label_map.get(p, p),
|
||||
label_map.get(o, o)
|
||||
# Retrieve edges one hop from frontier
|
||||
triples = await self.execute_batch_triple_queries(
|
||||
unvisited, self.triple_limit,
|
||||
)
|
||||
labeled_edges.append(labeled_triple)
|
||||
|
||||
# Map from labeled edge ID to original Terms (preserving types)
|
||||
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
|
||||
uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o))
|
||||
# Deduplicate and filter already-seen edges
|
||||
hop_triples = []
|
||||
hop_term_map = {}
|
||||
for triple in triples:
|
||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||
if triple_tuple[1] == LABEL:
|
||||
continue
|
||||
if triple_tuple in seen_edges:
|
||||
continue
|
||||
seen_edges.add(triple_tuple)
|
||||
hop_triples.append(triple_tuple)
|
||||
hop_term_map[triple_tuple] = (
|
||||
to_term(triple.s), to_term(triple.p), to_term(triple.o),
|
||||
)
|
||||
|
||||
labeled_edges = labeled_edges[0:self.max_subgraph_size]
|
||||
if not hop_triples:
|
||||
visited_entities.update(frontier)
|
||||
break
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Subgraph:")
|
||||
for edge in labeled_edges:
|
||||
logger.debug(f" {str(edge)}")
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Hop {hop + 1}: {len(hop_triples)} candidate edges"
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
# Resolve labels for all entities in hop edges
|
||||
entities_to_resolve = set()
|
||||
for s, p, o in hop_triples:
|
||||
entities_to_resolve.update([s, p, o])
|
||||
|
||||
return labeled_edges, uri_map, entities, concepts
|
||||
entity_list = list(entities_to_resolve)
|
||||
resolved = await self.resolve_labels_batch(entity_list)
|
||||
|
||||
label_map = {}
|
||||
for entity, label in zip(entity_list, resolved):
|
||||
if not isinstance(label, Exception):
|
||||
label_map[entity] = label
|
||||
else:
|
||||
label_map[entity] = entity
|
||||
|
||||
# Build labeled edges and documents for cross-encoder
|
||||
labeled_hop = []
|
||||
for s, p, o in hop_triples:
|
||||
ls = label_map.get(s, s)
|
||||
lp = label_map.get(p, p)
|
||||
lo = label_map.get(o, o)
|
||||
labeled_hop.append((ls, lp, lo))
|
||||
|
||||
documents = [
|
||||
{"id": str(i), "text": f"{lp} {lo}"}
|
||||
for i, (ls, lp, lo) in enumerate(labeled_hop)
|
||||
]
|
||||
|
||||
queries = [
|
||||
{"id": str(i), "text": c}
|
||||
for i, c in enumerate(concepts)
|
||||
]
|
||||
|
||||
# Score with cross-encoder
|
||||
results = await self.rag.reranker_client.rerank(
|
||||
queries=queries,
|
||||
documents=documents,
|
||||
limit=self.edge_limit,
|
||||
)
|
||||
|
||||
# Collect selected edges and metadata
|
||||
next_frontier = set()
|
||||
for r in results:
|
||||
idx = int(r.document_id)
|
||||
ls, lp, lo = labeled_hop[idx]
|
||||
s, p, o = hop_triples[idx]
|
||||
eid = edge_id(ls, lp, lo)
|
||||
|
||||
all_selected_edges.append((ls, lp, lo))
|
||||
uri_map[eid] = hop_term_map[(s, p, o)]
|
||||
edge_metadata[eid] = {
|
||||
"concept": concepts[int(r.query_id)],
|
||||
"score": r.score,
|
||||
}
|
||||
|
||||
# Target nodes become next frontier
|
||||
next_frontier.add(s)
|
||||
next_frontier.add(o)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Hop {hop + 1}: selected {len(results)} edges"
|
||||
)
|
||||
|
||||
visited_entities.update(frontier)
|
||||
frontier = next_frontier - visited_entities
|
||||
|
||||
return all_selected_edges, uri_map, edge_metadata
|
||||
|
||||
async def trace_source_documents(self, edge_uris):
|
||||
"""
|
||||
Trace selected edges back to their source documents via provenance.
|
||||
|
||||
Follows the chain: edge → subgraph (via tg:contains) → chunk →
|
||||
page → document (via prov:wasDerivedFrom), all in urn:graph:source.
|
||||
Follows the chain: edge -> subgraph (via tg:contains) -> chunk ->
|
||||
page -> document (via prov:wasDerivedFrom), all in urn:graph:source.
|
||||
|
||||
Args:
|
||||
edge_uris: List of (s, p, o) URI string tuples
|
||||
|
|
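The hop-and-filter loop introduced in this hunk can be sketched in isolation. The code below is a simplified, synchronous illustration rather than the code in the commit: `hop_and_filter_sketch`, the in-memory `graph` list and the `score` callable are hypothetical stand-ins for the triples client and the cross-encoder reranker, label resolution is omitted, and only the subject and object positions are matched against the frontier.

```python
def hop_and_filter_sketch(graph, seeds, score, max_hops=2, edge_limit=2):
    """Expand from `seeds`, keeping only the best-scoring edges per hop.

    graph: list of (s, p, o) string tuples (hypothetical in-memory store).
    score: callable mapping an edge to a float (stand-in for the reranker).
    """
    selected = []
    seen_edges = set()
    frontier = set(seeds)
    visited = set()

    for _ in range(max_hops):
        unvisited = frontier - visited
        if not unvisited:
            break

        # Edges one hop from the frontier that were not offered before.
        candidates = [
            e for e in graph
            if (e[0] in unvisited or e[2] in unvisited)
            and e not in seen_edges
        ]
        seen_edges.update(candidates)
        if not candidates:
            break

        # Keep the top `edge_limit` edges for this hop.
        kept = sorted(candidates, key=score, reverse=True)[:edge_limit]
        selected.extend(kept)

        # Endpoints of kept edges become the next frontier.
        visited |= frontier
        frontier = {n for s, _, o in kept for n in (s, o)} - visited

    return selected


graph = [
    ("a", "knows", "b"), ("a", "likes", "c"), ("a", "owns", "f"),
    ("b", "knows", "d"), ("c", "likes", "e"),
]
weights = {"knows": 1.0, "likes": 0.5, "owns": 0.1}
edges = hop_and_filter_sketch(graph, ["a"], lambda e: weights[e[1]])
```

Filtering at every hop keeps the frontier small: the low-scoring `("a", "owns", "f")` edge is dropped in the first hop, so `f` is never expanded in the second.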
@ -453,7 +431,6 @@ class Query:
|
|||
# Step 1: Find subgraphs containing these edges via tg:contains
|
||||
subgraph_tasks = []
|
||||
for s, p, o in edge_uris:
|
||||
# s, p, o may be Term objects (preserving types) or strings
|
||||
s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s)
|
||||
p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p)
|
||||
o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o)
|
||||
|
|
@ -487,12 +464,10 @@ class Query:
|
|||
return []
|
||||
|
||||
# Step 2: Walk prov:wasDerivedFrom chain to find documents
|
||||
# Each level: query ?entity prov:wasDerivedFrom ?parent
|
||||
# Stop when we find entities typed tg:Document
|
||||
current_uris = subgraph_uris
|
||||
doc_uris = set()
|
||||
|
||||
for depth in range(4): # Max depth: subgraph → chunk → page → doc
|
||||
for depth in range(4):
|
||||
if not current_uris:
|
||||
break
|
||||
|
||||
|
|
@ -509,7 +484,6 @@ class Query:
|
|||
*derivation_tasks, return_exceptions=True
|
||||
)
|
||||
|
||||
# URIs with no parent are root documents
|
||||
next_uris = set()
|
||||
for uri, result in zip(current_uris, derivation_results):
|
||||
if isinstance(result, Exception) or not result:
|
||||
|
|
@ -524,7 +498,6 @@ class Query:
|
|||
return []
|
||||
|
||||
# Step 3: Get all document metadata properties
|
||||
# Skip structural predicates that aren't useful context
|
||||
SKIP_PREDICATES = {
|
||||
PROV_WAS_DERIVED_FROM,
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
|
|
@ -565,7 +538,7 @@ class GraphRag:
|
|||
|
||||
def __init__(
|
||||
self, prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, verbose=False,
|
||||
triples_client, reranker_client, verbose=False,
|
||||
):
|
||||
|
||||
self.verbose = verbose
|
||||
|
|
@ -574,9 +547,8 @@ class GraphRag:
|
|||
self.embeddings_client = embeddings_client
|
||||
self.graph_embeddings_client = graph_embeddings_client
|
||||
self.triples_client = triples_client
|
||||
self.reranker_client = reranker_client
|
||||
|
||||
# Replace simple dict with LRU cache with TTL
|
||||
# CRITICAL: This cache only lives for one request due to per-request instantiation
|
||||
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
|
||||
|
||||
if self.verbose:
|
||||
|
|
@ -585,33 +557,12 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
|
||||
max_path_length = 2, edge_limit = 25,
|
||||
streaming = False,
|
||||
chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
parent_uri = "",
|
||||
):
|
||||
"""
|
||||
Execute a GraphRAG query with real-time explainability tracking.
|
||||
|
||||
Args:
|
||||
query: The query string
|
||||
collection: Collection identifier
|
||||
entity_limit: Max entities to retrieve
|
||||
triple_limit: Max triples per entity
|
||||
max_subgraph_size: Max edges in subgraph
|
||||
max_path_length: Max hops from seed entities
|
||||
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
|
||||
edge_limit: Max edges after LLM scoring
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for real-time explainability
|
||||
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
|
||||
|
||||
Returns:
|
||||
tuple: (answer_text, usage) where usage is a dict with
|
||||
in_token, out_token, model
|
||||
"""
|
||||
# Accumulate token usage across all prompt calls
|
||||
total_in = 0
|
||||
total_out = 0
|
||||
|
|
@ -638,7 +589,9 @@ class GraphRag:
|
|||
foc_uri = make_focus_uri(session_id)
|
||||
syn_uri = make_synthesis_uri(session_id)
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace(
|
||||
"+00:00", "Z",
|
||||
)
|
||||
|
||||
# Emit question explainability immediately
|
||||
if explain_callback:
|
||||
|
|
@ -657,10 +610,12 @@ class GraphRag:
|
|||
triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_limit = edge_limit,
|
||||
track_usage = track_usage,
|
||||
)
|
||||
|
||||
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
|
||||
# Step 1: Extract concepts and find seed entities
|
||||
seed_entities, concepts = await q.get_entities(query)
|
||||
|
||||
# Emit grounding explain after concept extraction
|
||||
if explain_callback:
|
||||
|
|
@ -676,11 +631,16 @@ class GraphRag:
|
|||
)
|
||||
await explain_callback(gnd_triples, gnd_uri)
|
||||
|
||||
# Emit exploration explain after graph retrieval completes
|
||||
# Step 2: Iterative hop-and-filter with cross-encoder
|
||||
selected_edges, uri_map, edge_metadata = await q.hop_and_filter(
|
||||
seed_entities, concepts,
|
||||
)
|
||||
|
||||
# Emit exploration explain
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
exploration_triples(
|
||||
exp_uri, gnd_uri, len(kg),
|
||||
exp_uri, gnd_uri, len(selected_edges),
|
||||
entities=seed_entities,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
|
|
@ -688,235 +648,63 @@ class GraphRag:
|
|||
await explain_callback(exp_triples, exp_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
logger.debug(f"Knowledge graph: {kg}")
|
||||
logger.debug(f"Query: {query}")
|
||||
|
||||
# Semantic pre-filter: reduce edges before expensive LLM scoring
|
||||
if edge_score_limit > 0 and len(kg) > edge_score_limit:
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Selected {len(selected_edges)} edges")
|
||||
for s, p, o in selected_edges:
|
||||
eid = edge_id(s, p, o)
|
||||
meta = edge_metadata.get(eid, {})
|
||||
logger.debug(
|
||||
f"Semantic pre-filter: {len(kg)} edges > "
|
||||
f"limit {edge_score_limit}, filtering..."
|
||||
f" {meta.get('score', 0):.4f} "
|
||||
f"[{meta.get('concept', '')}] "
|
||||
f"{s} | {p} | {o}"
|
||||
)
|
||||
|
||||
# Embed edge descriptions: "subject, predicate, object"
|
||||
edge_descriptions = [
|
||||
f"{s}, {p}, {o}" for s, p, o in kg
|
||||
]
|
||||
|
||||
# Embed concepts and edge descriptions concurrently
|
||||
concept_embed_task = self.embeddings_client.embed(concepts)
|
||||
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
|
||||
|
||||
concept_vectors, edge_vectors = await asyncio.gather(
|
||||
concept_embed_task, edge_embed_task
|
||||
)
|
||||
|
||||
# Score each edge by max cosine similarity to any concept
|
||||
def cosine_similarity(a, b):
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = math.sqrt(sum(x * x for x in a))
|
||||
norm_b = math.sqrt(sum(x * x for x in b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
edge_scores = []
|
||||
for i, edge_vec in enumerate(edge_vectors):
|
||||
max_sim = max(
|
||||
cosine_similarity(edge_vec, cv)
|
||||
for cv in concept_vectors
|
||||
)
|
||||
edge_scores.append((max_sim, i))
|
||||
|
||||
# Sort by similarity descending and keep top edge_score_limit
|
||||
edge_scores.sort(reverse=True)
|
||||
keep_indices = set(
|
||||
idx for _, idx in edge_scores[:edge_score_limit]
|
||||
)
|
||||
|
||||
# Filter kg and rebuild uri_map
|
||||
filtered_kg = []
|
||||
filtered_uri_map = {}
|
||||
for i, (s, p, o) in enumerate(kg):
|
||||
if i in keep_indices:
|
||||
filtered_kg.append((s, p, o))
|
||||
eid = edge_id(s, p, o)
|
||||
if eid in uri_map:
|
||||
filtered_uri_map[eid] = uri_map[eid]
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Semantic pre-filter kept {len(filtered_kg)} "
|
||||
f"of {len(kg)} edges"
|
||||
)
|
||||
|
||||
kg = filtered_kg
|
||||
uri_map = filtered_uri_map
|
||||
|
||||
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
|
||||
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
|
||||
edge_map = {}
|
||||
edges_with_ids = []
|
||||
for s, p, o in kg:
|
||||
eid = edge_id(s, p, o)
|
||||
edge_map[eid] = (s, p, o)
|
||||
edges_with_ids.append({
|
||||
"id": eid,
|
||||
"s": s,
|
||||
"p": p,
|
||||
"o": o
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Built edge map with {len(edge_map)} edges")
|
||||
|
||||
# Step 1a: Edge Scoring - LLM scores edges for relevance
|
||||
scoring_result = await self.prompt_client.prompt(
|
||||
"kg-edge-scoring",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": edges_with_ids
|
||||
}
|
||||
)
|
||||
track_usage(scoring_result)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge scoring result: {scoring_result}")
|
||||
|
||||
# Parse scoring response (jsonl) to get edge IDs with scores
|
||||
scored_edges = []
|
||||
|
||||
for obj in scoring_result.objects or []:
|
||||
if isinstance(obj, dict) and "id" in obj and "score" in obj:
|
||||
try:
|
||||
score = int(obj["score"])
|
||||
except (ValueError, TypeError):
|
||||
score = 0
|
||||
scored_edges.append({"id": obj["id"], "score": score})
|
||||
|
||||
# Select top N edges by score
|
||||
scored_edges.sort(key=lambda x: x["score"], reverse=True)
|
||||
top_edges = scored_edges[:edge_limit]
|
||||
selected_ids = {e["id"] for e in top_edges}
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Scored {len(scored_edges)} edges, "
|
||||
f"selected top {len(selected_ids)}"
|
||||
)
|
||||
|
||||
# Filter to selected edges
|
||||
selected_edges = []
|
||||
for eid in selected_ids:
|
||||
if eid in edge_map:
|
||||
selected_edges.append(edge_map[eid])
|
||||
|
||||
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
|
||||
selected_edges_with_ids = [
|
||||
{"id": eid, "s": s, "p": p, "o": o}
|
||||
for eid in selected_ids
|
||||
if eid in edge_map
|
||||
for s, p, o in [edge_map[eid]]
|
||||
]
|
||||
|
||||
# Collect selected edge URIs for document tracing
|
||||
# Step 3: Document tracing
|
||||
selected_edge_uris = [
|
||||
uri_map[eid]
|
||||
for eid in selected_ids
|
||||
if eid in uri_map
|
||||
uri_map[edge_id(s, p, o)]
|
||||
for s, p, o in selected_edges
|
||||
if edge_id(s, p, o) in uri_map
|
||||
]
|
||||
|
||||
# Run reasoning and document tracing concurrently
|
||||
async def _get_reasoning():
|
||||
result = await self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
track_usage(result)
|
||||
return result
|
||||
|
||||
reasoning_task = _get_reasoning()
|
||||
doc_trace_task = q.trace_source_documents(selected_edge_uris)
|
||||
|
||||
reasoning_result, source_documents = await asyncio.gather(
|
||||
reasoning_task, doc_trace_task, return_exceptions=True
|
||||
source_documents = await q.trace_source_documents(
|
||||
selected_edge_uris,
|
||||
)
|
||||
|
||||
# Handle exceptions from gather
|
||||
if isinstance(reasoning_result, Exception):
|
||||
logger.warning(
|
||||
f"Edge reasoning failed: {reasoning_result}"
|
||||
)
|
||||
reasoning_result = None
|
||||
if isinstance(source_documents, Exception):
|
||||
logger.warning(
|
||||
f"Document tracing failed: {source_documents}"
|
||||
)
|
||||
source_documents = []
|
||||
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge reasoning result: {reasoning_result}")
|
||||
|
||||
# Parse reasoning response (jsonl) and build explainability data
|
||||
reasoning_map = {}
|
||||
|
||||
if reasoning_result is not None:
|
||||
for obj in reasoning_result.objects or []:
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
reasoning_map[obj["id"]] = obj.get("reasoning", "")
|
||||
|
||||
# Build focus explainability data with cross-encoder metadata
|
||||
selected_edges_with_reasoning = []
|
||||
for eid in selected_ids:
|
||||
for s, p, o in selected_edges:
|
||||
eid = edge_id(s, p, o)
|
||||
if eid in uri_map:
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
meta = edge_metadata.get(eid, {})
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": reasoning_map.get(eid, ""),
|
||||
"concept": meta.get("concept", ""),
|
||||
"score": meta.get("score", 0),
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Filtered to {len(selected_edges)} edges")
|
||||
|
||||
# Emit focus explain after edge selection completes
|
||||
# Emit focus explain
|
||||
if explain_callback:
|
||||
# Sum scoring + reasoning token usage for focus event
|
||||
focus_in = 0
|
||||
focus_out = 0
|
||||
focus_model = None
|
||||
for r in [scoring_result, reasoning_result]:
|
||||
if r is not None:
|
||||
if r.in_token is not None:
|
||||
focus_in += r.in_token
|
||||
if r.out_token is not None:
|
||||
focus_out += r.out_token
|
||||
if r.model is not None:
|
||||
focus_model = r.model
|
||||
|
||||
foc_triples = set_graph(
|
||||
focus_triples(
|
||||
foc_uri, exp_uri, selected_edges_with_reasoning, session_id,
|
||||
in_token=focus_in or None,
|
||||
out_token=focus_out or None,
|
||||
model=focus_model,
|
||||
foc_uri, exp_uri,
|
||||
selected_edges_with_reasoning, session_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(foc_triples, foc_uri)
|
||||
|
||||
# Step 2: Synthesis - LLM generates answer from selected edges only
|
||||
# Step 4: Synthesis
|
||||
selected_edge_dicts = [
|
||||
{"s": s, "p": p, "o": o}
|
||||
for s, p, o in selected_edges
|
||||
]
|
||||
|
||||
# Add source document metadata as knowledge edges
|
||||
for s, p, o in source_documents:
|
||||
selected_edge_dicts.append({
|
||||
"s": s, "p": p, "o": o,
|
||||
|
|
@ -928,7 +716,6 @@ class GraphRag:
|
|||
}
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Accumulate chunks for answer storage while forwarding to callback
|
||||
accumulated_chunks = []
|
||||
|
||||
async def accumulating_callback(chunk, end_of_stream):
|
||||
|
|
@ -942,7 +729,6 @@ class GraphRag:
|
|||
chunk_callback=accumulating_callback
|
||||
)
|
||||
track_usage(synthesis_result)
|
||||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
synthesis_result = await self.prompt_client.prompt(
|
||||
|
|
@ -955,29 +741,42 @@ class GraphRag:
|
|||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
||||
# Emit synthesis explain after synthesis completes
|
||||
# Emit synthesis explain
|
||||
if explain_callback:
|
||||
synthesis_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian
|
||||
if save_answer_callback and answer_text:
|
||||
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
|
||||
try:
|
||||
await save_answer_callback(synthesis_doc_id, answer_text)
|
||||
if self.verbose:
|
||||
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
|
||||
logger.debug(
|
||||
f"Saved answer to librarian: "
|
||||
f"{synthesis_doc_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
logger.warning(
|
||||
f"Failed to save answer to librarian: {e}"
|
||||
)
|
||||
synthesis_doc_id = None
|
||||
|
||||
syn_triples = set_graph(
|
||||
synthesis_triples(
|
||||
syn_uri, foc_uri,
|
||||
document_id=synthesis_doc_id,
|
||||
in_token=synthesis_result.in_token if synthesis_result else None,
|
||||
out_token=synthesis_result.out_token if synthesis_result else None,
|
||||
model=synthesis_result.model if synthesis_result else None,
|
||||
in_token=(
|
||||
synthesis_result.in_token
|
||||
if synthesis_result else None
|
||||
),
|
||||
out_token=(
|
||||
synthesis_result.out_token
|
||||
if synthesis_result else None
|
||||
),
|
||||
model=(
|
||||
synthesis_result.model
|
||||
if synthesis_result else None
|
||||
),
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
|
|
@ -993,4 +792,3 @@ class GraphRag:
|
|||
}
|
||||
|
||||
return resp, usage
|
||||
|
||||
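Note on the hunk above: the embedding-based semantic pre-filter (score each edge by its maximum cosine similarity to any concept, keep the top `edge_score_limit`) appears to be among the lines removed in favour of `hop_and_filter`. For reviewers who want to compare behaviours, here is a minimal standalone sketch of that older technique. The function name `prefilter_edges` and its argument layout are assumptions for illustration; they are not identifiers from the repository, and the vectors are assumed to be plain lists of floats.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors; 0.0 if either is zero."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def prefilter_edges(edges, edge_vectors, concept_vectors, limit):
    """Keep the `limit` edges most similar to any concept (hypothetical helper).

    edges:           list of (s, p, o) tuples
    edge_vectors:    one embedding per edge, same order as `edges`
    concept_vectors: one embedding per extracted concept
    """
    scores = []
    for i, edge_vec in enumerate(edge_vectors):
        # An edge is as relevant as its best-matching concept.
        best = max(
            (cosine_similarity(edge_vec, cv) for cv in concept_vectors),
            default=0.0,
        )
        scores.append((best, i))

    # Highest similarity first; keep only the top `limit` indices.
    scores.sort(key=lambda t: t[0], reverse=True)
    keep = {i for _, i in scores[:limit]}

    # Preserve the original edge order in the output.
    return [edge for i, edge in enumerate(edges) if i in keep]
```

Unlike the removed code, this sketch returns 0.0 for an edge when no concept vectors are supplied rather than raising on an empty `max()`; that guard is an addition, not something the diff shows.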
|
|
|
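Note on the streaming branch in the hunks above: the synthesis step wraps the caller's `chunk_callback` in an `accumulating_callback(chunk, end_of_stream)` so the full answer can be joined and stored afterwards. The body of that wrapper is elided in the diff, so the following is only a sketch of the pattern under stated assumptions: the caller's callback is assumed to be a coroutine function, and `stream_answer` and `produce` are hypothetical names standing in for the prompt-client call.

```python
import asyncio


async def stream_answer(produce, chunk_callback):
    """Forward streamed chunks to `chunk_callback` and return the joined text.

    produce:        hypothetical coroutine function that takes a callback and
                    invokes it as callback(chunk, end_of_stream) per chunk
    chunk_callback: the caller's coroutine callback (assumed async)
    """
    accumulated_chunks = []

    async def accumulating_callback(chunk, end_of_stream):
        # Remember non-empty chunks so the full answer can be stored later.
        if chunk:
            accumulated_chunks.append(chunk)
        # Always forward, including the final end-of-stream marker.
        await chunk_callback(chunk, end_of_stream)

    await produce(accumulating_callback)

    # Combine all chunks into the full response.
    return "".join(accumulated_chunks)
```

A caller would run it with something like `asyncio.run(stream_answer(my_producer, my_callback))`, where both arguments are coroutine functions supplied by the caller.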
|||
|
|
@ -13,6 +13,7 @@ from . graph_rag import GraphRag
|
|||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
|
||||
from ... base import RerankerClientSpec
|
||||
from ... base import LibrarianSpec
|
||||
|
||||
# Module logger
|
||||
|
|
@ -32,7 +33,6 @@ class Processor(FlowProcessor):
|
|||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
edge_score_limit = params.get("edge_score_limit", 30)
|
||||
edge_limit = params.get("edge_limit", 25)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
|
|
@ -43,7 +43,6 @@ class Processor(FlowProcessor):
|
|||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"edge_score_limit": edge_score_limit,
|
||||
"edge_limit": edge_limit,
|
||||
}
|
||||
)
|
||||
|
|
@ -52,7 +51,6 @@ class Processor(FlowProcessor):
|
|||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.default_edge_score_limit = edge_score_limit
|
||||
self.default_edge_limit = edge_limit
|
||||
|
||||
# Workspace isolation is enforced by the flow layer (flow.workspace).
|
||||
|
|
@ -96,6 +94,13 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RerankerClientSpec(
|
||||
request_name = "reranker-request",
|
||||
response_name = "reranker-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
|
|
@ -163,6 +168,7 @@ class Processor(FlowProcessor):
|
|||
graph_embeddings_client=flow("graph-embeddings-request"),
|
||||
triples_client=flow("triples-request"),
|
||||
prompt_client=flow("prompt-request"),
|
||||
reranker_client=flow("reranker-request"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
@ -186,11 +192,6 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
if v.edge_score_limit:
|
||||
edge_score_limit = v.edge_score_limit
|
||||
else:
|
||||
edge_score_limit = self.default_edge_score_limit
|
||||
|
||||
if v.edge_limit:
|
||||
edge_limit = v.edge_limit
|
||||
else:
|
||||
|
|
@ -225,7 +226,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
|
||||
edge_limit = edge_limit,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
|
|
@ -241,7 +242,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
|
||||
edge_limit = edge_limit,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
|
|
@ -338,18 +339,11 @@ class Processor(FlowProcessor):
|
|||
help=f'Default max path length (default: 2)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-score-limit',
|
||||
type=int,
|
||||
default=30,
|
||||
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-limit',
|
||||
type=int,
|
||||
default=25,
|
||||
help=f'Max edges after LLM scoring (default: 25)'
|
||||
help=f'Max edges selected per hop by cross-encoder (default: 25)'
|
||||
)
|
||||
|
||||
# Note: Explainability triples are now stored in the request's collection
|
||||
|
|
|
|||
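Note on the parameter handling in the processor hunks above: each per-request value (`v.edge_limit`, `v.max_path_length`, and so on) is used when truthy, otherwise the processor-level default applies. A minimal sketch of that fallback follows; the `Request` dataclass and the `resolve` helper are hypothetical stand-ins, not types from the repository.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Request:
    """Hypothetical stand-in for the incoming request object `v`."""
    edge_limit: Optional[int] = None
    max_path_length: Optional[int] = None


def resolve(value, default):
    """Return the per-request value if truthy, else the configured default."""
    return value if value else default


# Defaults as shown in the diff: edge_limit 25, max_path_length 2.
DEFAULT_EDGE_LIMIT = 25
DEFAULT_MAX_PATH_LENGTH = 2

v = Request(edge_limit=10)
edge_limit = resolve(v.edge_limit, DEFAULT_EDGE_LIMIT)
max_path_length = resolve(v.max_path_length, DEFAULT_MAX_PATH_LENGTH)
```

One consequence of the truthiness test, which the diff's `if v.edge_limit:` form shares, is that a request cannot explicitly ask for `0`: a zero value falls back to the default just as a missing value does.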