Merge branch 'release/v2.6' into master to keep sync.

Cyber MacGeddon 2026-07-02 14:58:37 +01:00
commit 4aaa1ce915
59 changed files with 5617 additions and 864 deletions

README.dev-install.md Normal file

@ -0,0 +1,218 @@
# TrustGraph Developer Install Guide
A guided installer that gets TrustGraph running locally in a single
command. It detects your hardware, recommends an LLM backend, installs
missing prerequisites, runs the test suite, generates a compose deployment,
starts the stack, and opens the Workbench UI.
> **macOS only.** This installer has only been tested on macOS. If you are
> on Linux or Windows, use the standard docker-compose / podman-compose
> installation instructions instead.
## Quick start
```bash
./install_trustgraph.sh
```
The installer walks you through each step interactively. When it finishes,
the Workbench UI opens at `http://localhost:8888` and the API gateway is
available at `http://localhost:8088/`.
## Prerequisites
The installer checks for these and offers to install any that are missing
(via Homebrew):
- **Python 3** with venv support
- **Node.js / npx** (drives the `@trustgraph/config` deployment generator)
- **Docker** (with Compose) or **Podman** (with podman-compose)
- **curl** and **unzip**
- **Ollama** (only if you choose local LLMs)
The installer can also launch Docker Desktop or the Ollama app for you if
they are installed but not running.
## What the installer does
1. **Detects hardware** -- OS, architecture, CPU cores, memory, and GPU.
2. **Recommends an LLM mode** -- `ollama` for machines with >= 16 GB RAM and
a GPU or >= 8 cores; `openai` otherwise.
3. **Collects configuration** -- API key, LLM provider, model choices,
install directory. Answers are saved to
`<install-dir>/trustgraph-installer.env` and reused on subsequent runs.
4. **Checks and installs prerequisites** -- Python, Node/npx, Docker or
Podman, Ollama (if selected).
5. **Downloads Ollama models** (if using Ollama) -- chat model
(`granite4:350m` by default) and embeddings model (`mxbai-embed-large`).
6. **Creates a Python venv** and installs the local TrustGraph packages into
it, along with NLTK data and tiktoken caches.
7. **Runs the full pytest suite** against the local source tree.
8. **Runs `npx @trustgraph/config`** -- the existing interactive config
wizard that produces a `deploy.zip` with a compose file.
9. **Starts the compose stack** and waits for the API gateway to respond.
10. **Bootstraps IAM** and verifies the API key authenticates.
11. **Opens the Workbench UI** in your default browser.
## Command-line options
| Option | Description |
|---|---|
| `--install-dir PATH` | Directory for deployment files (default: `./trustgraph-deploy`) |
| `--api-url URL` | API gateway URL for health checks (default: `http://localhost:8088/`) |
| `--ui-url URL` | Workbench UI URL to open (default: `http://localhost:8888`) |
| `--use-existing-compose FILE` | Skip config generation and start this compose file directly |
| `--skip-tests` | Do not run the pytest suite |
| `--no-launch` | Do not open the Workbench UI at the end |
| `--non-interactive` | Accept all defaults without prompting |
| `--yes` | Auto-accept confirmation prompts |
| `--fresh` | Remove installer-managed files before generating a new deployment |
| `--remove-all` | Uninstall: stop containers, remove compose volumes, delete installer files |
| `--dry-run` | Print detected hardware and planned defaults, then exit |
| `-h`, `--help` | Show the built-in help text |
## Environment variables
These override the interactive prompts when set:
| Variable | Purpose |
|---|---|
| `TRUSTGRAPH_TOKEN` | Admin/bootstrap API key (must start with `tg_`) |
| `TRUSTGRAPH_URL` | API gateway URL |
| `TRUSTGRAPH_UI_URL` | Workbench UI URL |
| `OPENAI_TOKEN` | OpenAI-compatible API key |
| `OPENAI_BASE_URL` | OpenAI-compatible base URL |
| `OLLAMA_HOST` / `OLLAMA_BASE_URL` | Ollama service URL |
| `OLLAMA_MODEL` | Ollama chat model (default: `granite4:350m`) |
| `OLLAMA_EMBEDDINGS_MODEL` | Ollama embeddings model (default: `mxbai-embed-large`) |
| `TG_INSTALL_DIR` | Override the install directory |
| `TG_VENV_DIR` | Override the Python venv location |
| `TG_NLTK_DATA_DIR` | Override the NLTK data directory |
| `TIKTOKEN_CACHE_DIR` | Override the tiktoken cache directory |
| `TG_HEALTH_TIMEOUT` | Seconds to wait for the API gateway (default: 240) |
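
The `TG_HEALTH_TIMEOUT` wait can be pictured as a simple polling loop. The sketch below is in Python purely for illustration (the installer itself is a shell script); `wait_for_gateway`, `http_probe`, the 2-second poll interval, and the rule that any HTTP response counts as "up" are all assumptions, not taken from the installer.

```python
import time
import urllib.error
import urllib.request
from typing import Callable


def http_probe(url: str) -> bool:
    """Return True if the URL answers at all (assumption: any HTTP status counts)."""
    try:
        with urllib.request.urlopen(url, timeout=5):
            return True
    except urllib.error.HTTPError:
        # The server responded, just not with a 2xx status; treat as reachable.
        return True
    except (urllib.error.URLError, OSError):
        return False


def wait_for_gateway(
    url: str,
    timeout_s: float,
    probe: Callable[[str], bool] = http_probe,
    interval_s: float = 2.0,          # hypothetical poll interval
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> bool:
    """Poll `url` until `probe` succeeds or `timeout_s` seconds have elapsed."""
    deadline = clock() + timeout_s
    while True:
        if probe(url):
            return True
        if clock() >= deadline:
            return False
        sleep(interval_s)
```

With the documented defaults this would be called as `wait_for_gateway("http://localhost:8088/", 240)`.
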
## Choosing an LLM mode
### OpenAI (or any OpenAI-compatible provider)
Best when you already have an API key or are running against a remote
endpoint. The installer asks for a base URL and an API key.
```bash
OPENAI_TOKEN=sk-... ./install_trustgraph.sh
```
### Ollama (local models)
Best on machines with enough RAM to run a small model. The installer detects
locally installed Ollama models and offers to pull missing ones. It uses
`host.docker.internal` so the Docker containers can reach the host-side
Ollama service.
```bash
./install_trustgraph.sh # choose "ollama" when prompted
```
### None
Start the platform without an LLM. Agent and RAG features will not work
until you configure one later through the Workbench.
## Saved answers and re-running
The installer saves your answers to
`<install-dir>/trustgraph-installer.env`. On the next run it loads those
answers as defaults, so you can re-run with a single Enter through each
prompt.
To start completely fresh:
```bash
./install_trustgraph.sh --fresh
```
This stops any running containers (keeping Docker volumes), removes
installer-managed files, and re-runs the full flow.
## Using an existing compose file
If you already have a compose file from the config tool or another source:
```bash
./install_trustgraph.sh --use-existing-compose path/to/docker-compose.yaml
```
This skips the config wizard and `npx` prerequisite check, and goes straight
to starting the stack.
## Non-interactive / CI usage
```bash
TRUSTGRAPH_TOKEN=tg_my-token \
OPENAI_TOKEN=sk-... \
./install_trustgraph.sh --non-interactive --yes --skip-tests
```
In non-interactive mode the installer uses defaults for every prompt. Pair
with `--yes` to auto-accept confirmation prompts and `--skip-tests` if you
want a faster run.
## Dry run
Preview what the installer would do without making any changes:
```bash
./install_trustgraph.sh --dry-run
```
This prints the detected hardware, recommended LLM mode, and planned
install paths, then exits.
## Uninstalling
```bash
./install_trustgraph.sh --remove-all
```
This stops containers, removes compose-managed volumes, and deletes
installer-managed files (venv, deploy output, logs, saved answers). It does
**not** remove Docker/Podman itself, container images, Ollama, or Ollama
models.
## Troubleshooting
### Logs
All long-running operations write logs to `<install-dir>/logs/`. Key files:
- `pytest.log` -- test suite output
- `compose-up.log` -- docker compose output
- `iam-bootstrap.log` -- IAM bootstrap output
- `ollama-pull-*.log` -- Ollama model downloads
- `pip-*.log` -- Python package installs
- `brew-install-*.log` -- Homebrew installs
### API key rejected after reinstall
If the API gateway returns 401/403 with your saved key, the compose volumes
likely contain IAM data from a previous install with a different key. Run:
```bash
./install_trustgraph.sh --remove-all
./install_trustgraph.sh
```
This clears the old volumes and starts fresh.
### Ollama not reachable from containers
The Ollama base URL should use `host.docker.internal` instead of
`localhost` so that containers running in Docker Desktop can reach the
host-side Ollama service. The installer sets this automatically; if you
override `OLLAMA_HOST`, make sure the URL is reachable from inside the
container network.
### Docker daemon not running
The installer detects Docker Desktop and offers to start it. If that
doesn't work, start Docker Desktop manually and re-run the installer.


@ -0,0 +1,523 @@
# GraphRAG Semantic Filter Improvement
## Problem Statement
The GraphRAG semantic filter is ineffective with certain LLMs.
Smaller models in particular produce poor-quality edge relevance
scores, and we suspect (without firm evidence yet) that models
trained or evaluated heavily on non-Roman-script datasets perform
worse on the semantic ranking operation.
The root cause is that the current implementation delegates edge
relevance scoring to the LLM via a prompt that asks the model to
assign a 1-10 relevance score to each knowledge-graph edge. This
task — ranking structured triples for relevance to a natural-language
query — is not well covered in standard LLM evaluation suites, so
model benchmark scores are not predictive of performance on this
operation. The result is that GraphRAG quality varies unpredictably
across model choices, undermining confidence in the pipeline.
Beyond model variability, the LLM scoring step has further problems:
- **Cost and latency.** The LLM call consumes tokens and adds
latency to every query, yet its output is unreliable. Even when
the model performs well, the cost is disproportionate for what is
fundamentally a ranking operation.
- **Subjective scoring scale.** The 1-10 relevance scale gives the
model no objective criteria for what constitutes a 5 versus a 7.
Different models interpret the scale differently, and even the same
model can produce inconsistent scores across runs.
- **Redundancy with the embedding pre-filter.** The pipeline already
contains a cosine-similarity stage that ranks edges by semantic
relevance using embeddings. The LLM scoring step is a second
filter applied on top of this, and it is not clear that it adds
enough value to justify the additional cost and risk of
degradation.
### Industry context
Semantic ranking is rigorously evaluated on dedicated benchmarks such
as MTEB (Massive Text Embedding Benchmark) and BEIR (Benchmarking
Information Retrieval), which test retrieval and reranking across
diverse domains. The current TrustGraph approach — prompting a
general-purpose LLM to score and rank documents (the "listwise"
approach) — is known to be poorly optimized for this task. It
suffers from positional bias, formatting failures, and
inconsistency at scale.
The industry standard for semantic ranking has moved to
cross-encoder models: lightweight, purpose-built models that take a
query-document pair as input and produce a single relevance score.
These models are fine-tuned on millions of relevance-labelled pairs
and dominate retrieval benchmarks. They are fast, deterministic,
and do not require an LLM inference call.
## Architecture
### Cross-encoder service
A new request/response service that exposes a generic semantic
ranking API. The service is not specific to GraphRAG — it is a
reusable building block for any component that needs to rank text
by relevance.
The service interface is pluggable. Alternative implementations
can be swapped in behind the same API.
**Packaging options considered:**
- *`sentence-transformers`.* Full-featured, widely used.
However, it pulls in PyTorch (~2 GB), making containers
very large. Tested at ~1.8 seconds for 2200 edges.
- *`optimum.onnxruntime`.* ONNX-based inference. Still
depends on PyTorch at import time despite using ONNX for
inference. Tested at ~4.2 seconds for 2200 edges.
- *`flashrank`.* Lightweight wrapper around ONNX Runtime
with a clean API (`Ranker`, `RerankRequest`). No PyTorch
dependency. Tested at ~4.4 seconds for 2200 edges.
- *Pure `onnxruntime` + `tokenizers`.* Leanest option
(~200 MB total). Requires manual tokenisation, padding,
and numpy array management — more boilerplate to maintain.
- *External API (e.g. Cohere Rerank).* No local model at
all. Adds network latency and an external dependency.
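**Decision:** `flashrank` for the initial implementation.
No PyTorch dependency and a clean API. At ~4.4 seconds for 2200
edges it is slower than `sentence-transformers` (~1.8 seconds)
but comparable to `optimum.onnxruntime` (~4.2 seconds).
The pluggable interface allows swapping to another backend
later.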
**Request:**
- `queries` — list of `{id, text}` objects. In the GraphRAG use
case these are the concepts extracted from the user's question.
- `documents` — list of `{id, text}` objects. In the GraphRAG
use case these are the candidate knowledge-graph edges
represented as text.
- `limit` — integer. Maximum number of results to return.
**Scoring:**
The service produces the cartesian product of all query-document
pairs and scores each pair through the cross-encoder model. For
each document, the maximum score across all queries is taken as the
document's relevance score. Documents are then ranked by this
score and the top `limit` results are returned.
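
The scoring rule above can be sketched in a few lines. This is an illustration under stated assumptions rather than the service implementation: `rank_documents`, `RankedResult`, and `toy_overlap` are hypothetical names, and `score_pair` stands in for the cross-encoder model call (here replaced by a toy word-overlap scorer so the example runs without a model).

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class RankedResult:
    document_id: str
    query_id: str   # the query that produced the best score for this document
    score: float


def rank_documents(
    queries: List[Dict[str, str]],
    documents: List[Dict[str, str]],
    limit: int,
    score_pair: Callable[[str, str], float],
) -> List[RankedResult]:
    """Score every (query, document) pair, keep each document's best score,
    and return the top `limit` documents ordered by that score."""
    best: Dict[str, RankedResult] = {}
    for doc in documents:
        for query in queries:
            s = score_pair(query["text"], doc["text"])
            current = best.get(doc["id"])
            if current is None or s > current.score:
                best[doc["id"]] = RankedResult(doc["id"], query["id"], s)
    ranked = sorted(best.values(), key=lambda r: r.score, reverse=True)
    return ranked[:limit]


def toy_overlap(query_text: str, doc_text: str) -> float:
    """Toy stand-in for a cross-encoder: fraction of query words found in the document."""
    q = set(query_text.lower().split())
    d = set(doc_text.lower().split())
    return len(q & d) / len(q) if q else 0.0
```

Taking the maximum rather than the mean means a document only has to match one query well to rank highly, which is the behaviour the text describes.
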
**Response:**
A list of the top `limit` results, each containing:
- `document_id` — the ID of the matched document.
- `query_id` — the ID of the query (concept) that produced the
highest score for this document.
- `score` — the relevance score.
Including `query_id` in the response supports the explainability
interface: it records that an edge was selected because it is
related to a specific concept.
### Integration
The cross-encoder service follows the standard TrustGraph service
integration pattern:
- **Base package (trustgraph-base).** Schema definitions for the
cross-encoder request/response messages. A client class that
other components (e.g. GraphRAG) can use to call the
cross-encoder service. Message translator registration so the
pub/sub layer can serialise/deserialise the messages.
- **Flow package (trustgraph-flow).** The cross-encoder service
implementation itself — loads the model, listens for requests,
scores pairs, returns results. Flow definition support so the
cross-encoder can be introduced into a processing flow via the
standard flow configuration. `flashrank` is added as a
dependency of `trustgraph-flow`. The service runs in its own
container.
- **API gateway.** A gateway endpoint that routes cross-encoder
requests from the HTTP API to the service over pub/sub and
returns the response.
- **CLI tool.** A command-line utility
(e.g. `tg-invoke-cross-encoder`) that calls the gateway
endpoint for manual testing and debugging.
### Current GraphRAG pipeline
The current pipeline follows these steps:
1. **Concept extraction.** An LLM prompt extracts key concepts
from the user's query.
2. **Graph exploration.** Seed entities are found via embedding
similarity. A subgraph is built by multi-hop traversal from
the seed entities (up to `max_path_length` hops, capped at
`max_subgraph_size` edges).
3. **Embedding pre-filter.** Each edge is embedded as
`"subject, predicate, object"` and scored by cosine similarity
against the concept embeddings. The top `edge_score_limit`
(default 30) edges are kept.
4. **LLM edge scoring.** The `kg-edge-scoring` prompt asks the
LLM to assign a 1-10 relevance score to each remaining edge.
The top `edge_limit` (default 25) edges are kept.
5. **LLM edge reasoning.** The `kg-edge-reasoning` prompt asks
the LLM to explain why each selected edge is relevant to the
query. Used for the explainability interface.
6. **Document tracing.** Selected edges are traced back to their
source documents in the librarian. Runs concurrently with
step 5.
7. **Synthesis.** The `kg-synthesis` prompt generates the final
answer from the selected edges and source document metadata.
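
For reference, the embedding pre-filter in step 3 amounts to a bi-encoder top-k selection. The sketch below assumes the embeddings are already computed and that an edge's score is its best cosine similarity against any concept; that combination rule and the names `cosine` and `prefilter_edges` are assumptions, not taken from the current code.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def prefilter_edges(
    edge_vectors: List[Sequence[float]],
    concept_vectors: List[Sequence[float]],
    edge_score_limit: int = 30,
) -> List[int]:
    """Return the indices of the top `edge_score_limit` edges by cosine similarity."""
    scored = [
        (max((cosine(v, c) for c in concept_vectors), default=0.0), i)
        for i, v in enumerate(edge_vectors)
    ]
    # Stable sort: ties keep their original edge order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [i for _, i in scored[:edge_score_limit]]
```
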
### Potential improvements
#### Replace LLM edge scoring with cross-encoder (step 4)
The LLM edge scoring step is replaced by a call to the
cross-encoder service. The candidate edges are the documents and
`edge_limit` is the limit. This is a direct substitution: faster,
cheaper, deterministic, and more reliable across model choices.
The LLM `kg-edge-scoring` prompt is retired.
**Cross-encoder query input: concepts vs. raw query.** There are
two options for what to use as the cross-encoder queries:
- *Option A: Raw user query.* Pass the original question as a
single query string. Simpler, no dependency on concept
extraction. However, raw queries contain noise words and
conversational phrasing that do not match well against the
structured vocabulary of knowledge-graph edges. A single query
also means every edge competes against the full question — a
partial match on one aspect is diluted.
- *Option B: Extracted concepts.* Pass the concepts from step 1
as separate queries. The concepts are distilled, focused terms
that are closer to the language of the edges. With multiple
concepts as independent queries, the cross-encoder scores each
edge against each concept separately, giving better coverage —
an edge only needs to match one concept well to be selected.
The trade-off is a dependency on the LLM concept extraction
step, but this is already in the pipeline and is a lightweight,
reliable LLM call.
**Decision:** Option B — use extracted concepts. The concept
extraction is fast, and the resulting terms produce better
cross-encoder matches against structured triples.
#### Edge text representation
The current embedding pre-filter represents each edge as
`"subject, predicate, object"`. Two changes:
- **Drop commas.** Commas add tokenisation noise without semantic
value.
- **Drop the subject.** The subject identifies which entity the
edge belongs to, but it does not contribute to whether the
edge's content is relevant to the query. The predicate and
object carry the semantic meaning — what relationship exists
and what it connects to. Representing edges as `"{p} {o}"`
produces cleaner cross-encoder matches.
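
A minimal sketch of the two representations, assuming the subject, predicate, and object have already been resolved to human-readable labels (the function names are hypothetical):

```python
def legacy_edge_text(subject: str, predicate: str, obj: str) -> str:
    """Representation used by the current embedding pre-filter."""
    return f"{subject}, {predicate}, {obj}"


def edge_text(predicate: str, obj: str) -> str:
    """Proposed representation: predicate and object only, no commas.

    Whitespace is collapsed so stray spaces in labels do not add noise.
    """
    return " ".join(f"{predicate} {obj}".split())
```

For an edge such as (Voyager 1, has-event, crossed the heliopause), the proposed form keeps only `has-event crossed the heliopause`.
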
#### Remove the embedding pre-filter (step 3)
The embedding pre-filter was introduced to reduce the number of
edges before the expensive LLM scoring call. With the
cross-encoder replacing the LLM call, this cost equation changes.
**Arguments for removal:**
- The cross-encoder is fast enough to score the full subgraph
directly. In testing, 2200 edges scored in ~1.8 seconds with
`sentence-transformers` and ~4.4 seconds with `flashrank`; at
the default `max_subgraph_size` of 150 edges, scoring takes
a fraction of a second either way.
- The pre-filter is a weaker version of what the cross-encoder
does. Bi-encoder cosine similarity embeds the query and
document independently and compares vectors; the cross-encoder
processes both texts together through the full transformer,
giving it much better relevance judgement. Running a weaker
filter before a stronger one adds latency without improving
quality.
- Removing it eliminates an embedding service call (two batches:
concepts + edges) and the associated latency.
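
As a rough check on the "fraction of a second" claim, the measured timings can be scaled down to 150 edges. This assumes scoring time grows linearly with the number of edges and ignores fixed overheads such as model loading; `estimate_seconds` is a hypothetical helper.

```python
def estimate_seconds(edges: int, ref_edges: int = 2200, ref_seconds: float = 1.8) -> float:
    """Linearly scale a measured timing (ref_seconds for ref_edges) to `edges`."""
    return ref_seconds * edges / ref_edges


# sentence-transformers measurement: ~1.8 s for 2200 edges
st_estimate = estimate_seconds(150)                          # roughly 0.12 s
# flashrank measurement: ~4.4 s for 2200 edges
flashrank_estimate = estimate_seconds(150, ref_seconds=4.4)  # roughly 0.3 s
```
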
**Arguments for keeping it:**
- If the subgraph is very large (thousands of edges), the
cross-encoder's linear scaling could become a bottleneck.
The pre-filter would act as a safety valve.
- The embedding call is cheap compared to an LLM call, so the
overhead is modest.
**Decision:** Remove the pre-filter. The `max_subgraph_size`
parameter (default 150) already caps the number of edges entering
this stage, so the cross-encoder will not face an unbounded
workload. If very large subgraphs become a concern in future,
the pre-filter can be reintroduced or `max_subgraph_size` can be
tuned.
#### Iterative graph traversal with cross-encoder filtering
The current pipeline performs graph exploration and edge filtering
as separate phases: first build the full subgraph (up to
`max_path_length` hops), then score and filter edges. An
alternative is to interleave traversal and filtering — at each
hop, use the cross-encoder to select relevant edges before
expanding further.
**Option A: Big-bang traversal then filter.** Traverse the full
subgraph up to `max_path_length` hops from the seed entities,
collecting all edges up to `max_subgraph_size`. Then
cross-encode the entire result to select the top edges.
- Simple to implement — the current traversal logic is largely
unchanged.
- Produces large, unfocused subgraphs. Irrelevant branches are
explored and scored even though they will be discarded.
- Poorly suited to multi-hop reasoning. For a query about
Voyager 1, the subgraph includes Voyager 2's edges because
they are within hop distance, and the filter must then
separate them.
**Option B: Iterative hop-and-filter.** At each hop:
1. Retrieve all edges one hop from the current frontier nodes.
2. Cross-encode these edges against the query concepts.
3. Select the top relevant edges.
4. The target nodes of the selected edges become the frontier
for the next hop.
5. Repeat up to `max_path_length` hops.
The final set of selected edges across all hops is the input to
synthesis.
- **Guided exploration.** Each hop focuses the search by
pruning irrelevant branches before expanding further. The
working set stays small and relevant at every step.
- **Multi-hop reasoning works naturally.** Following
"Voyager 1 → has-event → crossed the heliopause" succeeds
because each hop is individually relevant and leads to the
next.
- **Smaller total workload.** Fewer edges are scored overall
because irrelevant branches are never expanded.
- **Trade-off: greedy pruning.** An edge discarded at hop 1
cannot lead to relevant edges at hop 2. This is inherent in
any bounded traversal, and the cross-encoder is better
equipped to make this relevance judgement than a blind hop
limit.
- **Trade-off: sequential latency.** Hops cannot be
parallelised since each depends on the previous. However,
each cross-encoder call on a small edge set is very fast
(sub-second for typical working sets).
**Decision:** Option B — iterative hop-and-filter. The guided
traversal produces more focused subgraphs and supports multi-hop
reasoning, which is a significant quality improvement over the
current approach.
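
The hop-and-filter loop can be sketched as follows. This is an outline under assumptions, not the GraphRAG implementation: `neighbours` stands in for the triples query, `select` stands in for the cross-encoder call that returns the edges to keep, only outgoing edges are followed, and a visited set is added so that cycles are not re-expanded.

```python
from typing import Callable, Iterable, List, Set, Tuple

Edge = Tuple[str, str, str]  # (subject, predicate, object)


def hop_and_filter(
    seeds: Iterable[str],
    neighbours: Callable[[str], List[Edge]],
    select: Callable[[List[Edge]], List[Edge]],
    max_path_length: int,
) -> List[Edge]:
    """Expand the graph one hop at a time, keeping only edges chosen by `select`."""
    frontier: Set[str] = set(seeds)
    visited: Set[str] = set(frontier)
    selected: List[Edge] = []
    for _ in range(max_path_length):
        if not frontier:
            break
        # 1. Retrieve all edges one hop from the current frontier (deduplicated).
        candidates = list(dict.fromkeys(
            edge for node in sorted(frontier) for edge in neighbours(node)
        ))
        if not candidates:
            break
        # 2-3. Score the candidates and keep the relevant ones.
        kept = select(candidates)
        selected.extend(kept)
        # 4. Targets of the kept edges become the next frontier.
        frontier = {obj for (_, _, obj) in kept} - visited
        visited |= frontier
    return selected
```

Because rejected edges never contribute frontier nodes, a branch such as Voyager 2 is not expanded once its linking edge is filtered out at the first hop.
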
#### Replace LLM edge reasoning with cross-encoder metadata (step 5)
The current `kg-edge-reasoning` prompt asks the LLM to explain why
each edge is relevant. With the cross-encoder now making the
selection, this explanation would be a post-hoc fabrication — the
LLM was not involved in the decision.
- *Option A: Keep LLM reasoning.* Generates natural-language
explanations but they are not grounded in the actual selection
process. Adds an LLM call per query.
- *Option B: Record cross-encoder metadata.* The cross-encoder
already returns the matched concept and score for each selected
edge. Use this directly as the explanation.
**Decision:** Option B. The cross-encoder metadata is the true
reason the edge was selected. The `kg-edge-reasoning` prompt is
retired.
#### Explainability interface update
The explainability interface uses a `Focus` entity containing
`EdgeSelection` sub-entities. Each `EdgeSelection` currently
carries an `edge` (the quoted triple) and a `reasoning` field
(free-text LLM prose), stored as `tg:reasoning` in the
provenance graph.
With the cross-encoder replacing LLM reasoning, the
`EdgeSelection` type gains two new predicates and drops one:
- **Remove** `tg:reasoning` — no longer produced.
- **Add** `tg:concept` — the concept text that produced the
highest cross-encoder score for this edge.
- **Add** `tg:score` — the cross-encoder relevance score.
This is an evolution of the existing `EdgeSelection` type, not a
new entity type. The edge selection sub-entities currently have
no `rdf:type` declared; a new `tg:EdgeSelection` type should be
added so that consumers can identify them in the provenance
graph. The `Focus` entity and its relationship to `Exploration`
are unchanged.
The `Focus` entity's token-usage metadata (`tg:inToken`,
`tg:outToken`, `tg:llmModel`) no longer applies since there is
no LLM call. These fields are dropped from the Focus entity.
### Proposed pipeline
1. **Concept extraction.** Unchanged — LLM extracts key concepts
from the user's query.
2. **Seed entity lookup.** Find seed entities via embedding
similarity against the extracted concepts.
3. **Iterative hop-and-filter.** For each hop up to
`max_path_length`:
a. Retrieve all edges one hop from the current frontier nodes.
b. Represent each edge as `"{predicate} {object}"`.
c. Score edges against the extracted concepts using the
cross-encoder service.
d. Select the top relevant edges. The target nodes of the
selected edges become the frontier for the next hop.
4. **Document tracing.** Selected edges are traced back to source
documents.
5. **Synthesis.** The `kg-synthesis` prompt generates the final
answer from the selected edges and source document metadata.
### Implementation order
1. Cross-encoder service with full integration (base schema,
flow service, gateway endpoint, CLI tool).
2. GraphRAG pipeline changes (iterative hop-and-filter,
edge representation, remove pre-filter).
3. Explainability update (`tg:EdgeSelection` type, concept
and score predicates, retire `tg:reasoning`).
4. Retire `kg-edge-scoring` and `kg-edge-reasoning` prompts.
5. Update `tg-invoke-graph-rag` and `tg-show-explain-trace`
to display the new metadata. Use these as the main
end-to-end test.
6. Fix any failing unit tests, then add new tests as needed.
7. Write guidance for UX devs to update the UI for the new
explainability predicates.
## UX developer guidance
This section describes the changes to the explainability interface
that affect frontend rendering of GraphRAG Focus events.
### What changed
Edge selection in GraphRAG previously used LLM-based scoring and
reasoning. Each selected edge carried a `tg:reasoning` predicate
with free-text explanation from the LLM. This has been replaced
by a cross-encoder reranker that scores edges against query
concepts. The explainability data now carries structured metadata
instead of free text.
### Removed
- **`tg:reasoning`** is no longer emitted on edge selection
entities in GraphRAG Focus events. UX code that reads
`edge_sel.reasoning` will get an empty string. Remove any
rendering that displays a "Reasoning" or "Reason" field for
Focus edges.
- The **`kg-edge-scoring`**, **`kg-edge-reasoning`**, and
**`kg-edge-selection`** prompts are retired. Any UX that
references these prompt names should be cleaned up.
### Added
Each edge selection entity within a Focus event now has three
new properties:
| RDF predicate | API field | Type | Description |
|---|---|---|---|
| `rdf:type tg:EdgeSelection` | (type check) | — | Each edge selection entity is now explicitly typed |
| `tg:concept` | `edge_sel.concept` | `str` | The query concept that matched this edge |
| `tg:score` | `edge_sel.score` | `float` or `None` | Cross-encoder relevance score (0.0-1.0) |
The `tg:edge` predicate (RDF-star quoted triple) is unchanged.
### How to render
The recommended rendering for each selected edge in a Focus event:
```
Edge: (subject_label, predicate_label, object_label)
Concept: <concept> Score: <score formatted to 4 decimal places>
```
Scores near 1.0 indicate high relevance; scores near 0.0 indicate
low relevance. UX could use the score to drive visual indicators
such as colour intensity or a relevance bar.
Edges are not returned in score order — they arrive in traversal
order across hops. If the UX wants to display edges ranked by
relevance, sort by `edge_sel.score` descending.
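
A minimal sketch of ranking and rendering, assuming edges without a score should sort last and be shown as `n/a`. `EdgeRow` is a hypothetical stand-in that carries only the fields used here, with the edge already formatted as a label string.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EdgeRow:
    label: str                      # e.g. "(subject_label, predicate_label, object_label)"
    concept: str = ""
    score: Optional[float] = None


def rank_for_display(rows: List[EdgeRow]) -> List[EdgeRow]:
    """Sort by score descending; rows with no score go to the end."""
    return sorted(rows, key=lambda r: (r.score is None, -(r.score or 0.0)))


def render(row: EdgeRow) -> str:
    """Render one edge in the recommended two-line layout."""
    score = "n/a" if row.score is None else f"{row.score:.4f}"
    return f"Edge: {row.label}\nConcept: {row.concept} Score: {score}"
```
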
### API classes (Python)
The `EdgeSelection` dataclass in `trustgraph.api.explainability`
has these fields:
```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class EdgeSelection:
    uri: str
    edge: Optional[Dict[str, str]]  # {"s": ..., "p": ..., "o": ...}
    reasoning: str = ""             # Legacy, always empty for new traces
    concept: str = ""               # Query concept that matched
    score: Optional[float] = None   # Cross-encoder relevance score
```
These are populated when calling
`ExplainabilityClient.fetch_focus_with_edges()` or when parsing
inline provenance triples from the streaming response.
### WebSocket response format
For inline explainability via the streaming WebSocket, Focus events
arrive as `message_type: "explain"` responses. The `explain_triples`
array contains the edge selection triples. The relevant predicates
in wire format are:
```json
{"s": {"t": "i", "i": "<edge_sel_uri>"},
"p": {"t": "i", "i": "https://trustgraph.ai/ns/concept"},
"o": {"t": "l", "v": "flyby event"}}
{"s": {"t": "i", "i": "<edge_sel_uri>"},
"p": {"t": "i", "i": "https://trustgraph.ai/ns/score"},
"o": {"t": "l", "v": "0.9962"}}
```
Note that `tg:score` is transmitted as a string literal and must
be parsed to a float on the client side.
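
A client-side sketch of collecting these predicates from `explain_triples` and parsing the score is shown below. The function name and the returned dictionary shape are assumptions; only the predicate URIs and the `s`/`p`/`o` wire fields come from the format above.

```python
from typing import Any, Dict, List

TG_CONCEPT = "https://trustgraph.ai/ns/concept"
TG_SCORE = "https://trustgraph.ai/ns/score"


def collect_edge_metadata(explain_triples: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Group concept and score literals by edge-selection URI.

    Returns {edge_sel_uri: {"concept": str, "score": float or None}}.
    """
    result: Dict[str, Dict[str, Any]] = {}
    for triple in explain_triples:
        predicate = triple["p"].get("i")
        if predicate not in (TG_CONCEPT, TG_SCORE):
            continue  # other predicates (e.g. the quoted edge) are ignored here
        entry = result.setdefault(triple["s"]["i"], {"concept": "", "score": None})
        value = triple["o"].get("v")
        if predicate == TG_CONCEPT:
            entry["concept"] = value or ""
        else:
            # The score arrives as a string literal; tolerate malformed values.
            try:
                entry["score"] = float(value)
            except (TypeError, ValueError):
                entry["score"] = None
    return result
```
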
### Exploration event
The Exploration event's `edge_count` field now reports the number
of edges selected by the cross-encoder across all hops (previously
it reported the total number of edges retrieved before filtering).
The `entities` list continues to report the seed entities found
by vector search.

install_trustgraph.sh Normal file

File diff suppressed because it is too large.


@ -95,10 +95,6 @@ class TestGraphRagIntegration:
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-reasoning":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
return PromptResult(
response_type="text",
@ -113,14 +109,22 @@ class TestGraphRagIntegration:
client.prompt.side_effect = mock_prompt
return client
@pytest.fixture
def mock_reranker_client(self):
"""Mock reranker client for cross-encoder edge filtering"""
client = AsyncMock()
client.rerank.return_value = []
return client
@pytest.fixture
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_prompt_client):
mock_triples_client, mock_reranker_client, mock_prompt_client):
"""Create GraphRag instance with mocked dependencies"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
prompt_client=mock_prompt_client,
verbose=True
)
@ -167,8 +171,8 @@ class TestGraphRagIntegration:
# 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query_stream.call_count > 0
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
assert mock_prompt_client.prompt.call_count == 4
# 4. Should call prompt twice (extract-concepts + synthesis)
assert mock_prompt_client.prompt.call_count == 2
# Verify final response
response, usage = response


@ -63,11 +63,6 @@ class TestGraphRagStreaming:
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
elif prompt_id == "kg-edge-reasoning":
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
@ -88,14 +83,23 @@ class TestGraphRagStreaming:
client.prompt.side_effect = prompt_side_effect
return client
@pytest.fixture
def mock_reranker_client(self):
"""Mock reranker client for cross-encoder edge filtering"""
client = AsyncMock()
client.rerank.return_value = []
return client
@pytest.fixture
def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_streaming_prompt_client):
mock_triples_client, mock_reranker_client,
mock_streaming_prompt_client):
"""Create GraphRag instance with streaming support"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
prompt_client=mock_streaming_prompt_client,
verbose=True
)


@ -46,7 +46,7 @@ class TestGraphRagStreamingProtocol:
client = AsyncMock()
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
@ -63,14 +63,23 @@ class TestGraphRagStreamingProtocol:
client.prompt.side_effect = prompt_side_effect
return client
@pytest.fixture
def mock_reranker_client(self):
"""Mock reranker client for cross-encoder edge filtering"""
client = AsyncMock()
client.rerank.return_value = []
return client
@pytest.fixture
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_streaming_prompt_client):
mock_triples_client, mock_reranker_client,
mock_streaming_prompt_client):
"""Create GraphRag instance with mocked dependencies"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
prompt_client=mock_streaming_prompt_client,
verbose=False
)
@ -327,7 +336,7 @@ class TestStreamingProtocolEdgeCases:
client = AsyncMock()
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
@ -342,10 +351,14 @@ class TestStreamingProtocolEdgeCases:
client.prompt.side_effect = prompt_with_empties
mock_reranker = AsyncMock()
mock_reranker.rerank.return_value = []
rag = GraphRag(
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
reranker_client=mock_reranker,
prompt_client=client,
verbose=False
)


@ -15,11 +15,20 @@ from openai.types.chat.chat_completion import Choice
from openai.types.completion_usage import CompletionUsage
from trustgraph.model.text_completion.openai.llm import Processor
from trustgraph.model.text_completion.openai.variants import get_variant
from trustgraph.exceptions import TooManyRequests
from trustgraph.base import LlmResult
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse, Error
def _wire_variant(processor):
"""Attach variant methods to a MagicMock processor."""
processor.variant = get_variant("openai")
processor.thinking = "off"
processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor)
processor._extract_content = Processor._extract_content.__get__(processor, Processor)
@pytest.mark.integration
class TestTextCompletionIntegration:
"""Integration tests for OpenAI text completion service coordination"""
@ -66,6 +75,7 @@ class TestTextCompletionIntegration:
# Add the actual generate_content method from Processor class
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
_wire_variant(processor)
return processor
@ -119,6 +129,7 @@ class TestTextCompletionIntegration:
# Add the actual generate_content method
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
_wire_variant(processor)
# Act
result = await processor.generate_content("System prompt", "User prompt")
@ -129,7 +140,7 @@ class TestTextCompletionIntegration:
assert result.in_token == 50
assert result.out_token == 100
# Note: result.model comes from mock response, not processor config
# Verify configuration was applied
call_args = mock_openai_client.chat.completions.create.call_args
assert call_args.kwargs['model'] == config['model']
@ -247,6 +258,7 @@ class TestTextCompletionIntegration:
processor.max_output = processor_config["max_output"]
processor.openai = mock_openai_client
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
_wire_variant(processor)
processors.append(processor)
# Simulate multiple concurrent requests
@ -354,6 +366,7 @@ class TestTextCompletionIntegration:
processor.max_output = 2048
processor.openai = mock_openai_client
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
_wire_variant(processor)
# Act
await processor.generate_content("System prompt", "User prompt")
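The `_wire_variant` helper above relies on binding real methods onto a `MagicMock` through the descriptor protocol (`function.__get__(instance, cls)`). A minimal sketch of that technique follows; `Greeter`, `greeting`, and `wire_real_methods` are hypothetical names used only for illustration and are not part of TrustGraph.

```python
from unittest.mock import MagicMock


class Greeter:
    """Hypothetical stand-in for Processor; illustration only."""

    def greet(self, who):
        # Delegates to a private helper, much as generate_content
        # delegates to _build_kwargs / _extract_content in the tests above.
        return self._format(who)

    def _format(self, who):
        return f"{self.greeting}, {who}"


def wire_real_methods(mock):
    """Bind the real Greeter methods onto a mock instance."""
    # function.__get__(obj, cls) returns a bound method whose self is obj.
    mock.greet = Greeter.greet.__get__(mock, Greeter)
    mock._format = Greeter._format.__get__(mock, Greeter)


mock = MagicMock()
mock.greeting = "hello"
wire_real_methods(mock)
message = mock.greet("world")  # runs the real code paths against the mock
```

If only `greet` were bound, `self._format(...)` would resolve to an auto-created `MagicMock` attribute and the call would return a mock rather than a string, which is presumably why the tests wire the private helpers as well.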


@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as StreamChoice, ChoiceDelta
from trustgraph.model.text_completion.openai.llm import Processor
from trustgraph.model.text_completion.openai.variants import get_variant
from trustgraph.base import LlmChunk
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
@ -18,6 +19,14 @@ from tests.utils.streaming_assertions import (
)
def _wire_variant(processor):
"""Attach variant methods to a MagicMock processor."""
processor.variant = get_variant("openai")
processor.thinking = "off"
processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor)
processor._extract_content = Processor._extract_content.__get__(processor, Processor)
@pytest.mark.integration
class TestTextCompletionStreaming:
"""Integration tests for Text Completion streaming"""
@ -69,6 +78,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
return processor
@ -190,6 +200,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
# Act
chunks = []
@ -223,6 +234,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
# Act
chunks = []
@ -258,6 +270,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
# Act & Assert
with pytest.raises(Exception) as exc_info:
@ -295,6 +308,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
# Act
chunks = []
@ -318,6 +332,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
system_prompt = "You are an expert."
user_prompt = "Explain quantum physics."


@ -195,38 +195,6 @@ class TestPromptClientStreamingCallback:
assert callback.call_args_list[0] == call("test", False)
assert callback.call_args_list[1] == call("", True)
@pytest.mark.asyncio
async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
"""Test that kg_prompt correctly passes streaming parameters"""
# Arrange
async def mock_request(request, recipient=None, timeout=600):
if recipient:
responses = [
PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
PromptResponse(text="", object=None, error=None, end_of_stream=True),
]
for resp in responses:
should_stop = await recipient(resp)
if should_stop:
break
prompt_client.request = mock_request
callback = AsyncMock()
# Act
await prompt_client.kg_prompt(
query="What is machine learning?",
kg=[("subject", "predicate", "object")],
streaming=True,
chunk_callback=callback
)
# Assert
assert callback.call_count == 2
assert callback.call_args_list[0] == call("Answer", False)
assert callback.call_args_list[1] == call("", True)
@pytest.mark.asyncio
async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
"""Test that document_prompt correctly passes streaming parameters"""


@ -107,6 +107,7 @@ class TestGraphRagDagStructure:
embeddings_client = AsyncMock()
graph_embeddings_client = AsyncMock()
triples_client = AsyncMock()
reranker_client = AsyncMock()
embeddings_client.embed.return_value = [[0.1, 0.2]]
graph_embeddings_client.query.return_value = [
@ -121,27 +122,22 @@ class TestGraphRagDagStructure:
]
triples_client.query.return_value = []
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.95
reranker_client.rerank.return_value = [result]
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return PromptResult(response_type="text", text="concept")
elif template_id == "kg-edge-scoring":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "score": 10} for e in edges],
)
elif template_id == "kg-edge-reasoning":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
)
elif template_id == "kg-synthesis":
return PromptResult(response_type="text", text="Answer.")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
return (prompt_client, embeddings_client, graph_embeddings_client,
triples_client, reranker_client)
@pytest.mark.asyncio
async def test_dag_chain(self, mock_clients):
@ -152,7 +148,7 @@ class TestGraphRagDagStructure:
events.append({"explain_id": explain_id, "triples": triples})
await rag.query(
query="test", explain_callback=explain_cb, edge_score_limit=0,
query="test", explain_callback=explain_cb,
)
dag = _collect_events(events)


@ -101,27 +101,27 @@ class TestQuery:
assert query.rag == mock_rag
assert query.collection == "test_collection"
assert query.verbose is False
assert query.doc_limit == 20 # Default value
assert query.fetch_limit == 20 # Default value
def test_query_initialization_with_custom_doc_limit(self):
"""Test Query initialization with custom doc_limit"""
def test_query_initialization_with_custom_fetch_limit(self):
"""Test Query initialization with custom fetch_limit"""
# Create mock DocumentRag
mock_rag = MagicMock()
# Initialize Query with custom doc_limit
# Initialize Query with custom fetch_limit
query = Query(
rag=mock_rag,
workspace="test_workspace",
collection="custom_collection",
verbose=True,
doc_limit=50
fetch_limit=50
)
# Verify initialization
assert query.rag == mock_rag
assert query.collection == "custom_collection"
assert query.verbose is True
assert query.doc_limit == 50
assert query.fetch_limit == 50
@pytest.mark.asyncio
async def test_extract_concepts(self):
@ -224,7 +224,7 @@ class TestQuery:
workspace="test_workspace",
collection="test_collection",
verbose=False,
doc_limit=15
fetch_limit=15
)
# Call get_docs with concepts list
@ -377,7 +377,7 @@ class TestQuery:
workspace="test_workspace",
collection="test_collection",
verbose=True,
doc_limit=5
fetch_limit=5
)
# Call get_docs with concepts
@ -615,7 +615,7 @@ class TestQuery:
workspace="test_workspace",
collection="test_collection",
verbose=False,
doc_limit=10
fetch_limit=10
)
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])


@ -0,0 +1,478 @@
"""
Tests for the optional cross-encoder reranking pass in DocumentRag.query().

Two behaviours are covered:

1. No-op: when no reranker_client is wired (the default), query() must feed
   the LLM the exact same chunks, in the same order, that retrieval produced
   - byte-identical to the pre-reranker behaviour - and must NOT emit a
   chunk-selection provenance event.

2. Rerank: when a reranker_client is wired, the retrieved chunks are reordered
   and truncated according to the reranker's results, the LLM receives the
   reranked top-N, and a tg:ChunkSelection (focus) provenance event is emitted
   carrying the per-surviving-chunk scores and chunk references.

These are pure orchestration tests - the reranker is a stub, so there is no
torch / network dependency.
"""

import pytest
from unittest.mock import AsyncMock
from dataclasses import dataclass

from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.base import PromptResult
from trustgraph.schema import RerankerResult
from trustgraph.provenance.namespaces import (
    RDF_TYPE, PROV_WAS_DERIVED_FROM,
    TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
    TG_FOCUS, TG_SYNTHESIS,
    TG_CHUNK_SELECTION, TG_SELECTED_CHUNK, TG_SCORE, TG_DOCUMENT,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def find_triple(triples, predicate, subject=None):
    for t in triples:
        if t.p.iri == predicate:
            if subject is None or t.s.iri == subject:
                return t
    return None


def find_triples(triples, predicate, subject=None):
    return [
        t for t in triples
        if t.p.iri == predicate
        and (subject is None or t.s.iri == subject)
    ]


def has_type(triples, subject, rdf_type):
    return any(
        t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
        for t in triples
    )


def derived_from(triples, subject):
    t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
    return t.o.iri if t else None


@dataclass
class ChunkMatch:
    """Mimics the result from doc_embeddings_client.query()."""
    chunk_id: str

# ---------------------------------------------------------------------------
# Fixtures: three retrievable chunks
# ---------------------------------------------------------------------------

CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
CHUNK_C = "urn:chunk:policy-doc-1:chunk-2"

CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
CHUNK_B_CONTENT = "Our stores are open from 9am to 5pm on weekdays."
CHUNK_C_CONTENT = "Refunds are processed to the original payment method."

# Retrieval (post-dedupe) order is A, B, C.
ORDERED_CONTENT = [CHUNK_A_CONTENT, CHUNK_B_CONTENT, CHUNK_C_CONTENT]
ORDERED_CHUNK_IDS = [CHUNK_A, CHUNK_B, CHUNK_C]


def build_mock_clients():
    """
    Build mock subsidiary clients for a document-rag query returning three
    distinct chunks (A, B, C) in that order.
    """
    prompt_client = AsyncMock()
    embeddings_client = AsyncMock()
    doc_embeddings_client = AsyncMock()
    fetch_chunk = AsyncMock()

    async def mock_prompt(template_id, variables=None, **kwargs):
        if template_id == "extract-concepts":
            return PromptResult(response_type="text", text="return policy\nrefund")
        return PromptResult(response_type="text", text="")

    prompt_client.prompt.side_effect = mock_prompt

    embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]

    # Each concept query returns the same three chunks; dedupe keeps A, B, C.
    doc_embeddings_client.query.return_value = [
        ChunkMatch(chunk_id=CHUNK_A),
        ChunkMatch(chunk_id=CHUNK_B),
        ChunkMatch(chunk_id=CHUNK_C),
    ]

    async def mock_fetch(chunk_id):
        return {
            CHUNK_A: CHUNK_A_CONTENT,
            CHUNK_B: CHUNK_B_CONTENT,
            CHUNK_C: CHUNK_C_CONTENT,
        }[chunk_id]

    fetch_chunk.side_effect = mock_fetch

    prompt_client.document_prompt.return_value = PromptResult(
        response_type="text",
        text="Items can be returned within 30 days for a full refund.",
    )

    return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk


class StubReranker:
    """
    Stub reranker_client mirroring RerankerClient.rerank(): returns a fixed,
    pre-sorted, truncated list of RerankerResult - exactly the contract the
    flashrank service guarantees (sorted desc by score, truncated to limit).
    """

    def __init__(self, results):
        self._results = results
        self.calls = []

    async def rerank(self, queries, documents, limit=10, timeout=300):
        self.calls.append(
            {"queries": queries, "documents": documents, "limit": limit}
        )
        return self._results

# ---------------------------------------------------------------------------
# 1. No-op: reranker_client=None must not change anything
# ---------------------------------------------------------------------------

class TestRerankNoOp:

    @pytest.mark.asyncio
    async def test_documents_passed_to_llm_are_unchanged(self):
        """
        With no reranker wired, document_prompt must receive the retrieved
        chunks in the original order and length.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)  # reranker_client defaults to None

        await rag.query(query="What is the return policy?")

        call = rag.prompt_client.document_prompt.call_args
        passed_docs = call.kwargs["documents"]
        assert passed_docs == ORDERED_CONTENT

    @pytest.mark.asyncio
    async def test_no_chunk_selection_event_emitted(self):
        """
        Without a reranker, the provenance chain is the original 4 stages:
        question, grounding, exploration, synthesis - no focus stage.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        assert len(events) == 4
        types = [
            TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_SYNTHESIS,
        ]
        for i, expected in enumerate(types):
            assert has_type(events[i]["triples"], events[i]["explain_id"], expected)

        # No chunk-selection entity anywhere.
        for e in events:
            assert not any(
                t.o.iri == TG_CHUNK_SELECTION
                for t in e["triples"]
                if t.p.iri == RDF_TYPE
            )

    @pytest.mark.asyncio
    async def test_synthesis_derives_from_exploration_when_no_rerank(self):
        """
        No-op lineage is unchanged: synthesis derives from exploration
        (there is no focus stage). Guards the conditional synthesis parent.
        """
        clients = build_mock_clients()
        rag = DocumentRag(*clients)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        # events: question, grounding, exploration, synthesis
        exp_uri = events[2]["explain_id"]
        syn_event = events[3]
        assert derived_from(syn_event["triples"], syn_event["explain_id"]) == exp_uri

# ---------------------------------------------------------------------------
# 2. Rerank: reorder + truncate + provenance
# ---------------------------------------------------------------------------

class TestRerankActive:

    def _reranker_keeping_C_then_A(self):
        # Reranker says chunk index 2 (C) is best, then index 0 (A); B dropped.
        # Pre-sorted desc by score and truncated to limit, per the contract.
        return StubReranker([
            RerankerResult(document_id="2", query_id="0", score=0.95),
            RerankerResult(document_id="0", query_id="0", score=0.42),
        ])

    @pytest.mark.asyncio
    async def test_documents_reordered_and_truncated(self):
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?")

        call = rag.prompt_client.document_prompt.call_args
        passed_docs = call.kwargs["documents"]
        assert passed_docs == [CHUNK_C_CONTENT, CHUNK_A_CONTENT]

    @pytest.mark.asyncio
    async def test_reranker_called_with_single_query_and_all_docs(self):
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?", doc_limit=2)

        assert len(reranker.calls) == 1
        c = reranker.calls[0]
        assert c["queries"] == [{"id": "0", "text": "What is the return policy?"}]
        assert c["documents"] == [
            {"id": "0", "text": CHUNK_A_CONTENT},
            {"id": "1", "text": CHUNK_B_CONTENT},
            {"id": "2", "text": CHUNK_C_CONTENT},
        ]
        # The rerank narrows down to the final doc_limit, NOT fetch_limit
        # (fetch_limit is the over-fetched candidate pool size).
        assert c["limit"] == 2

    @pytest.mark.asyncio
    async def test_explicit_fetch_limit_over_fetches_then_narrows(self):
        """
        Semantic guard for the value of reranking AND the maintainer's two-limit
        contract: an explicit fetch_limit makes retrieval OVER-FETCH a wider
        candidate pool so the cross-encoder can surface chunks the bi-encoder
        ranked outside the final doc_limit, then the rerank narrows the pool back
        down to doc_limit. The fetch_limit is honoured directly (caller controls
        how hard the reranker works), not overridden by any heuristic.
        """
        clients = build_mock_clients()
        prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
        reranker = self._reranker_keeping_C_then_A()

        # Candidate pool (fetch_limit=60) >> final doc_limit (6).
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(
            query="What is the return policy?", doc_limit=6, fetch_limit=60,
        )

        # Over-fetch: the embeddings store is queried with the fetch_limit
        # budget (60 // 2 concept-vectors = 30 per concept), NOT the doc_limit
        # budget (6 // 2 = 3). This is the bug guard.
        q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
        assert q_limit == 30

        # Narrow: the rerank keeps the final doc_limit (6), not fetch_limit.
        assert reranker.calls[0]["limit"] == 6

    @pytest.mark.asyncio
    async def test_default_fetch_limit_derives_overfetch_from_doc_limit(self):
        """
        With no fetch_limit passed to query(), the candidate pool falls back to
        the OVERFETCH_FACTOR x doc_limit heuristic, so over-fetch scales with
        doc_limit and reranking keeps its recall benefit out of the box.
        """
        clients = build_mock_clients()
        prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
        reranker = self._reranker_keeping_C_then_A()

        # No fetch_limit -> heuristic default.
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?", doc_limit=20)

        # fetch = 3 x 20 = 60 -> 60 // 2 concept-vectors = 30 per concept.
        q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
        assert q_limit == 30

        # Rerank narrows to the final doc_limit (20).
        assert reranker.calls[0]["limit"] == 20

    @pytest.mark.asyncio
    async def test_fetch_limit_floored_at_doc_limit(self):
        """
        A fetch_limit below doc_limit is floored up to doc_limit: retrieval must
        never fetch fewer candidates than the rerank is asked to keep, else the
        prompt could not be filled.
        """
        clients = build_mock_clients()
        prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
        reranker = self._reranker_keeping_C_then_A()

        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(
            query="What is the return policy?", doc_limit=10, fetch_limit=4,
        )

        # fetch = max(4, 10) = 10 -> 10 // 2 concept-vectors = 5 per concept.
        q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
        assert q_limit == 5
        assert reranker.calls[0]["limit"] == 10

    @pytest.mark.asyncio
    async def test_chunk_selection_event_emitted(self):
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        # Now 5 stages: question, grounding, exploration, focus, synthesis.
        assert len(events) == 5
        ordered_types = [
            TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
            TG_FOCUS, TG_SYNTHESIS,
        ]
        for i, expected in enumerate(ordered_types):
            assert has_type(events[i]["triples"], events[i]["explain_id"], expected)

    @pytest.mark.asyncio
    async def test_chunk_selection_carries_scores_and_chunk_refs(self):
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        focus_event = events[3]
        foc_uri = focus_event["explain_id"]
        triples = focus_event["triples"]

        # focus is derived from exploration
        exp_uri = events[2]["explain_id"]
        assert derived_from(triples, foc_uri) == exp_uri

        # Two ChunkSelection sub-entities, linked from focus.
        sel_links = find_triples(triples, TG_SELECTED_CHUNK, foc_uri)
        assert len(sel_links) == 2

        # Each selection has a ChunkSelection type, a chunk document ref and a score.
        chunk_refs = set()
        scores = set()
        for link in sel_links:
            sel_uri = link.o.iri
            assert has_type(triples, sel_uri, TG_CHUNK_SELECTION)
            doc_ref = find_triple(triples, TG_DOCUMENT, sel_uri)
            assert doc_ref is not None
            chunk_refs.add(doc_ref.o.iri)
            score_t = find_triple(triples, TG_SCORE, sel_uri)
            assert score_t is not None
            scores.add(score_t.o.value)

        # Surviving chunks are C and A (B dropped), with the reranker scores.
        assert chunk_refs == {CHUNK_C, CHUNK_A}
        assert scores == {"0.95", "0.42"}

    @pytest.mark.asyncio
    async def test_all_focus_triples_in_retrieval_graph(self):
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            explain_callback=explain_callback,
        )

        for t in events[3]["triples"]:
            assert t.g == "urn:graph:retrieval"

    @pytest.mark.asyncio
    async def test_synthesis_derives_from_focus_when_reranking(self):
        """
        When reranking runs, synthesis must derive from the focus node (the
        reranked chunks actually fed to the LLM), mirroring GraphRAG - not from
        exploration, which would leave focus as a dangling branch and
        misrepresent what fed the answer.
        """
        clients = build_mock_clients()
        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        events = []

        async def explain_callback(triples, explain_id):
            events.append({"triples": triples, "explain_id": explain_id})

        await rag.query(
            query="What is the return policy?",
            doc_limit=2,
            explain_callback=explain_callback,
        )

        # events: question, grounding, exploration, focus, synthesis
        foc_uri = events[3]["explain_id"]
        syn_event = events[4]
        assert derived_from(syn_event["triples"], syn_event["explain_id"]) == foc_uri

    @pytest.mark.asyncio
    async def test_empty_docs_skips_reranker(self):
        """If retrieval returns no chunks, the reranker is never called."""
        clients = build_mock_clients()
        prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
        doc_embeddings_client.query.return_value = []  # no matches

        reranker = self._reranker_keeping_C_then_A()
        rag = DocumentRag(*clients, reranker_client=reranker)

        await rag.query(query="What is the return policy?")

        assert reranker.calls == []
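
The two-limit contract these tests pin down can be summarised in a small sketch. It is inferred only from the assertions and comments above, not taken from the `DocumentRag` implementation: the helper names, the `Scored` stand-in for `RerankerResult`, and the guard against zero concepts are all assumptions, and `OVERFETCH_FACTOR = 3` is read off the "3 x 20" test comment.

```python
from dataclasses import dataclass

OVERFETCH_FACTOR = 3  # assumed value, taken from the "3 x 20 = 60" test comment


@dataclass
class Scored:
    """Hypothetical stand-in for trustgraph.schema.RerankerResult."""
    document_id: str
    query_id: str
    score: float


def candidate_pool_size(doc_limit, fetch_limit=0):
    """Size of the over-fetched pool; fetch_limit=0 is treated as 'unset'."""
    if fetch_limit <= 0:
        fetch_limit = OVERFETCH_FACTOR * doc_limit
    # Never fetch fewer candidates than the rerank is asked to keep.
    return max(fetch_limit, doc_limit)


def per_concept_limit(doc_limit, fetch_limit, n_concepts):
    """Per-concept query budget: the pool is split across concept vectors."""
    # The max(..., 1) guard is an assumption to avoid dividing by zero.
    return candidate_pool_size(doc_limit, fetch_limit) // max(n_concepts, 1)


def apply_rerank(docs, results):
    """Reorder and truncate docs; document_id is the index as a string."""
    return [docs[int(r.document_id)] for r in results]


# Mirrors the three fetch-limit tests (two concept vectors in each).
explicit = per_concept_limit(doc_limit=6, fetch_limit=60, n_concepts=2)
default = per_concept_limit(doc_limit=20, fetch_limit=0, n_concepts=2)
floored = per_concept_limit(doc_limit=10, fetch_limit=4, n_concepts=2)
```

In every case the reranker itself is asked for `doc_limit` results, so `fetch_limit` only controls how wide a pool the cross-encoder gets to score.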


@ -0,0 +1,89 @@
"""
Cross-layer wiring contract for the Document-RAG reranker (issue #878).

The Document-RAG processor registers a ``RerankerClientSpec`` for the
``reranker-request`` / ``reranker-response`` roles (see
``retrieval/document_rag/rag.py``). At flow construction every spec runs
``spec.add(flow, processor, definition)``, and ``RequestResponseSpec.add``
resolves its topics via ``definition["topics"][name]`` - which raises
``KeyError`` if the flow blueprint does not provide those topics.

This means the monorepo code change is only safe to deploy together with the
companion ``trustgraph-templates`` change that wires ``reranker-request`` /
``reranker-response`` into the Document-RAG flow (mirroring what templates
PR #279 did for GraphRAG via ``graph-store.jsonnet``). These tests pin that
contract from the monorepo side:

  * with the reranker topics present (as the updated templates compile them),
    the spec binds cleanly and registers the client;
  * without them (the pre-companion blueprint), construction fails fast with a
    KeyError naming the missing role - documenting exactly why the templates
    change is required.

No broker/network: the pub/sub backend is mocked (topics are bound at add()
time, connections happen later at start()).
"""

import pytest
from unittest.mock import MagicMock

from trustgraph.base import RerankerClientSpec


def _flow():
    f = MagicMock()
    f.workspace = "ws"
    f.name = "document-rag"
    f.id = "proc1"
    f.consumer = {}
    return f


def _processor():
    p = MagicMock()
    p.pubsub = MagicMock()
    p.id = "proc1"
    p.taskgroup = MagicMock()
    return p


def _spec():
    return RerankerClientSpec(
        request_name="reranker-request",
        response_name="reranker-response",
    )


# Topics dict as the UPDATED document-store.jsonnet compiles them
# (verified by compiling the template: reranker-request -> request:tg:reranker:{workspace}:{id}).
DEFINITION_WITH_RERANKER = {
    "topics": {
        "request": "request:tg:document-rag:ws:id",
        "response": "response:tg:document-rag:ws:id",
        "reranker-request": "request:tg:reranker:ws:id",
        "reranker-response": "response:tg:reranker:ws:id",
    }
}

# Pre-companion blueprint: no reranker topics (document-rag before the templates change).
DEFINITION_WITHOUT_RERANKER = {
    "topics": {
        "request": "request:tg:document-rag:ws:id",
        "response": "response:tg:document-rag:ws:id",
    }
}


def test_reranker_client_binds_when_flow_provides_topics():
    flow = _flow()
    _spec().add(flow, _processor(), DEFINITION_WITH_RERANKER)
    # The client consumer is registered against the reranker role.
    assert "reranker-request" in flow.consumer


def test_reranker_client_keyerrors_without_companion_template_topics():
    with pytest.raises(KeyError) as exc:
        _spec().add(_flow(), _processor(), DEFINITION_WITHOUT_RERANKER)
    # Fails fast naming the missing role -> the trustgraph-templates companion
    # change (wire reranker-request/response into the document-rag flow) is required.
    assert "reranker-request" in str(exc.value)
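
The fail-fast behaviour described in the docstring comes down to plain dictionary lookup. The sketch below shows only that mechanism; `resolve_topics` is a hypothetical helper, not the real `RequestResponseSpec.add`, which also registers the client on the flow.

```python
def resolve_topics(definition, request_name, response_name):
    """Look up both topic names; a missing role raises KeyError naming it."""
    topics = definition["topics"]
    return topics[request_name], topics[response_name]


with_reranker = {
    "topics": {
        "reranker-request": "request:tg:reranker:ws:id",
        "reranker-response": "response:tg:reranker:ws:id",
    }
}
without_reranker = {"topics": {}}

bound = resolve_topics(with_reranker, "reranker-request", "reranker-response")

try:
    resolve_topics(without_reranker, "reranker-request", "reranker-response")
    missing_role = None
except KeyError as exc:
    # KeyError carries the missing key as its first argument.
    missing_role = exc.args[0]
```

Because the lookup happens at `add()` time, a blueprint without the reranker topics fails at flow construction rather than on the first query.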


@ -66,6 +66,7 @@ class TestDocumentRagService:
workspace=ANY, # Workspace comes from flow.workspace (mock)
collection="test_coll_1", # Must be from message, not hardcoded default
doc_limit=5,
fetch_limit=0, # Unset -> core derives the candidate pool
explain_callback=ANY, # Explainability callback is always passed
save_answer_callback=ANY, # Librarian save callback is always passed
)

View file

@ -15,54 +15,52 @@ class TestGraphRag:
def test_graph_rag_initialization_with_defaults(self):
"""Test GraphRag initialization with default verbose setting"""
# Create mock clients
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
mock_reranker_client = MagicMock()
# Initialize GraphRag
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.verbose is False # Default value
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
# Create mock clients
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
# Initialize GraphRag with verbose=True
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
verbose=True
reranker_client=mock_reranker_client,
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.reranker_client == mock_reranker_client
assert graph_rag.verbose is False
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
mock_reranker_client = MagicMock()
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
verbose=True,
)
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.reranker_client == mock_reranker_client
assert graph_rag.verbose is True
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
@ -365,244 +363,162 @@ class TestQuery:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_never_passes_workspace(self):
"""Verify follow_edges never passes workspace to query_stream."""
async def test_hop_and_filter_never_passes_workspace(self):
"""Verify hop_and_filter never passes workspace to query_stream."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
mock_triple = MagicMock()
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
mock_triple.s = "e1"
mock_triple.p = "p1"
mock_triple.o = "o1"
mock_triples_client.query_stream.return_value = [mock_triple]
mock_triples_client.query.return_value = []
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.9
mock_reranker_client.rerank.return_value = [result]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
triple_limit=10,
)
subgraph = set()
await query.follow_edges("e1", subgraph, path_length=1)
await query.hop_and_filter(["e1"], ["concept"])
for c in mock_triples_client.query_stream.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
"""Test Query.follow_edges method basic triple discovery"""
async def test_hop_and_filter_basic_functionality(self):
"""Test hop_and_filter retrieves edges and scores them with reranker."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
mock_triple1 = MagicMock()
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
mock_triple = MagicMock()
mock_triple.s = "entity1"
mock_triple.p = "predicate1"
mock_triple.o = "object1"
mock_triples_client.query_stream.return_value = [mock_triple]
mock_triples_client.query.return_value = []
mock_triple2 = MagicMock()
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
mock_triple3 = MagicMock()
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
mock_triples_client.query_stream.side_effect = [
[mock_triple1], # s=ent
[mock_triple2], # p=ent
[mock_triple3], # o=ent
]
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.95
mock_reranker_client.rerank.return_value = [result]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
triple_limit=10,
edge_limit=25,
)
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=1)
assert mock_triples_client.query_stream.call_count == 3
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
collection="test_collection", batch_size=20, g=""
selected, uri_map, edge_meta = await query.hop_and_filter(
["entity1"], ["test concept"],
)
expected_subgraph = {
("entity1", "predicate1", "object1"),
("subject2", "entity1", "object2"),
("subject3", "predicate3", "entity1")
}
assert subgraph == expected_subgraph
assert len(selected) == 1
assert len(uri_map) == 1
assert len(edge_meta) == 1
mock_reranker_client.rerank.assert_called_once()
call_kwargs = mock_reranker_client.rerank.call_args
assert call_kwargs.kwargs["limit"] == 25
@pytest.mark.asyncio
async def test_follow_edges_with_path_length_zero(self):
"""Test Query.follow_edges method with path_length=0"""
async def test_hop_and_filter_with_empty_frontier(self):
"""Test hop_and_filter with no seed entities returns empty."""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
)
selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"])
assert selected == []
assert uri_map == {}
assert edge_meta == {}
@pytest.mark.asyncio
async def test_hop_and_filter_filters_label_triples(self):
"""Test hop_and_filter skips rdfs:label edges."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False
)
label_triple = MagicMock()
label_triple.s = "entity1"
label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label"
label_triple.o = "Entity One"
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=0)
mock_triples_client.query_stream.assert_not_called()
assert subgraph == set()
@pytest.mark.asyncio
async def test_follow_edges_with_max_subgraph_size_limit(self):
"""Test Query.follow_edges method respects max_subgraph_size"""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_triples_client.query_stream.return_value = [label_triple]
mock_triples_client.query.return_value = []
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_subgraph_size=2
triple_limit=10,
)
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
await query.follow_edges("entity1", subgraph, path_length=1)
mock_triples_client.query_stream.assert_not_called()
assert len(subgraph) == 3
@pytest.mark.asyncio
async def test_get_subgraph_method(self):
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_path_length=1
selected, uri_map, edge_meta = await query.hop_and_filter(
["entity1"], ["concept"],
)
# Mock get_entities to return (entities, concepts) tuple
query.get_entities = AsyncMock(
return_value=(["entity1", "entity2"], ["concept1"])
)
query.follow_edges_batch = AsyncMock(return_value=(
{
("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2")
},
{}
))
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
query.get_entities.assert_called_once_with("test query")
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
assert isinstance(subgraph, list)
assert len(subgraph) == 2
assert ("entity1", "predicate1", "object1") in subgraph
assert ("entity2", "predicate2", "object2") in subgraph
assert entities == ["entity1", "entity2"]
assert concepts == ["concept1"]
@pytest.mark.asyncio
async def test_get_labelgraph_method(self):
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_subgraph_size=100
)
test_subgraph = [
("entity1", "predicate1", "object1"),
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
("entity3", "predicate3", "object3")
]
test_entities = ["entity1", "entity3"]
test_concepts = ["concept1"]
query.get_subgraph = AsyncMock(
return_value=(test_subgraph, {}, test_entities, test_concepts)
)
async def mock_maybe_label(entity):
label_map = {
"entity1": "Human Entity One",
"predicate1": "Human Predicate One",
"object1": "Human Object One",
"entity3": "Human Entity Three",
"predicate3": "Human Predicate Three",
"object3": "Human Object Three"
}
return label_map.get(entity, entity)
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
query.get_subgraph.assert_called_once_with("test query")
# Label triples filtered out
assert len(labeled_edges) == 2
# maybe_label called for non-label triples
assert query.maybe_label.call_count == 6
expected_edges = [
("Human Entity One", "Human Predicate One", "Human Object One"),
("Human Entity Three", "Human Predicate Three", "Human Object Three")
]
assert labeled_edges == expected_edges
assert len(uri_map) == 2
assert entities == test_entities
assert concepts == test_concepts
assert selected == []
mock_reranker_client.rerank.assert_not_called()
@pytest.mark.asyncio
async def test_graph_rag_query_method(self):
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
import json
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_graph_embeddings_client = AsyncMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
expected_response = "This is the RAG response"
test_labelgraph = [("Subject", "Predicate", "Object")]
test_edge_id = edge_id("Subject", "Predicate", "Object")
test_selected_edges = [("Subject", "Predicate", "Object")]
test_eid = edge_id("Subject", "Predicate", "Object")
test_uri_map = {
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
}
test_edge_metadata = {
test_eid: {"concept": "test concept", "score": 0.95}
}
test_entities = ["http://example.org/subject"]
test_concepts = ["test concept"]
# Mock prompt responses for the multi-step process
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
mock_graph_embeddings_client.query.return_value = []
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
elif prompt_name == "kg-edge-reasoning":
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
return PromptResult(response_type="text", text="test concept")
elif prompt_name == "kg-synthesis":
return PromptResult(response_type="text", text=expected_response)
return PromptResult(response_type="text", text="")
@ -614,16 +530,16 @@ class TestQuery:
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
verbose=False
reranker_client=mock_reranker_client,
verbose=False,
)
# Patch Query.get_labelgraph to return test data
original_get_labelgraph = Query.get_labelgraph
original_hop_and_filter = Query.hop_and_filter
async def mock_get_labelgraph(self, query_text):
return test_labelgraph, test_uri_map, test_entities, test_concepts
async def mock_hop_and_filter(self, seed_entities, concepts):
return test_selected_edges, test_uri_map, test_edge_metadata
Query.get_labelgraph = mock_get_labelgraph
Query.hop_and_filter = mock_hop_and_filter
provenance_events = []
@ -636,7 +552,7 @@ class TestQuery:
collection="test_collection",
entity_limit=25,
triple_limit=15,
explain_callback=collect_provenance
explain_callback=collect_provenance,
)
response_text, usage = response
@ -650,7 +566,6 @@ class TestQuery:
assert len(triples) > 0
assert prov_id.startswith("urn:trustgraph:")
# Verify order
assert "question" in provenance_events[0][1]
assert "grounding" in provenance_events[1][1]
assert "exploration" in provenance_events[2][1]
@ -658,4 +573,4 @@ class TestQuery:
assert "synthesis" in provenance_events[4][1]
finally:
Query.get_labelgraph = original_get_labelgraph
Query.hop_and_filter = original_hop_and_filter


@ -20,7 +20,7 @@ from trustgraph.provenance.namespaces import (
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_SELECTED_EDGE, TG_EDGE, TG_SCORE, TG_EDGE_SELECTION,
)
@ -91,17 +91,17 @@ def build_mock_clients():
1. prompt_client.prompt("extract-concepts", ...) -> concepts
2. embeddings_client.embed(concepts) -> vectors
3. graph_embeddings_client.query(vector, ...) -> entity matches
4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
4. triples_client.query_stream(s/p/o, ...) -> edges (hop_and_filter)
5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
9. prompt_client.prompt("kg-synthesis", ...) -> final answer
6. reranker_client.rerank(queries, documents, limit) -> scored edges
7. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
8. prompt_client.prompt("kg-synthesis", ...) -> final answer
"""
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
graph_embeddings_client = AsyncMock()
triples_client = AsyncMock()
reranker_client = AsyncMock()
# 1. Concept extraction
prompt_responses = {}
@ -116,7 +116,7 @@ def build_mock_clients():
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
]
# 4. Triple queries (follow_edges_batch) - return our edges
# 4. Triple queries (hop_and_filter) - return our edges
kg_triples = [
make_schema_triple(*EDGE_1),
make_schema_triple(*EDGE_2),
@ -130,9 +130,18 @@ def build_mock_clients():
return [] # No labels found, will fall back to URI
triples_client.query.side_effect = mock_label_query
# 6+7. Edge scoring and reasoning: dynamically score/reason about
# whatever edges the query method sends us, since edge IDs are computed
# from str(Term) representations which include the full dataclass repr.
# 6. Reranker: select all documents with high scores
async def mock_rerank(queries, documents, limit):
results = []
for i, doc in enumerate(documents):
result = MagicMock()
result.document_id = doc["id"]
result.query_id = queries[0]["id"] if queries else "0"
result.score = 0.9 - (i * 0.1)
results.append(result)
return results[:limit]
reranker_client.rerank.side_effect = mock_rerank
synthesis_answer = "Quantum computing applies physics principles to computation."
async def mock_prompt(template_id, variables=None, **kwargs):
@ -141,26 +150,6 @@ def build_mock_clients():
response_type="text",
text=prompt_responses["extract-concepts"],
)
elif template_id == "kg-edge-scoring":
# Score all edges highly, using the IDs that GraphRag computed
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-edge-reasoning":
# Provide reasoning for each edge
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-synthesis":
return PromptResult(
response_type="text",
@ -170,7 +159,8 @@ def build_mock_clients():
prompt_client.prompt.side_effect = mock_prompt
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
return (prompt_client, embeddings_client, graph_embeddings_client,
triples_client, reranker_client)
# ---------------------------------------------------------------------------
@ -197,7 +187,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0, # skip semantic pre-filter for simplicity
)
assert len(events) == 5, (
@ -222,7 +212,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
expected_types = [
@ -260,7 +250,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
uris = [e["explain_id"] for e in events]
@ -297,7 +287,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
q_uri = events[0]["explain_id"]
@ -320,7 +310,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
gnd_uri = events[1]["explain_id"]
@ -344,7 +334,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
exp_uri = events[2]["explain_id"]
@ -355,10 +345,10 @@ class TestGraphRagQueryProvenance:
assert int(t.o.value) > 0
@pytest.mark.asyncio
async def test_focus_has_selected_edges_with_reasoning(self):
async def test_focus_has_selected_edges_with_concept_and_score(self):
"""
The focus event should carry selected edges as quoted triples
with reasoning text.
with cross-encoder concept and score metadata.
"""
clients = build_mock_clients()
rag = GraphRag(*clients)
@ -371,7 +361,6 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
foc_uri = events[3]["explain_id"]
@ -387,11 +376,19 @@ class TestGraphRagQueryProvenance:
for t in edge_t:
assert t.o.triple is not None, "tg:edge object must be a quoted triple"
# Should have reasoning
reasoning = find_triples(foc_triples, TG_REASONING)
assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
reasoning_texts = {t.o.value for t in reasoning}
assert any(r for r in reasoning_texts), "Reasoning should not be empty"
# Edge selections should be typed as EdgeSelection
edge_sel_uris = [t.o.iri for t in selected]
for uri in edge_sel_uris:
assert has_type(foc_triples, uri, TG_EDGE_SELECTION)
# Should have concept and score
concepts = find_triples(foc_triples, TG_CONCEPT)
assert len(concepts) > 0, "Focus should have tg:concept for selected edges"
scores = find_triples(foc_triples, TG_SCORE)
assert len(scores) > 0, "Focus should have tg:score for selected edges"
for t in scores:
float(t.o.value) # Should be parseable as float
@pytest.mark.asyncio
async def test_synthesis_is_answer_type(self):
@ -407,7 +404,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
syn_uri = events[4]["explain_id"]
@ -429,7 +426,7 @@ class TestGraphRagQueryProvenance:
result_text, usage = await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
assert result_text == "Quantum computing applies physics principles to computation."
@ -449,7 +446,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
parent_uri=parent,
)
@ -465,7 +462,7 @@ class TestGraphRagQueryProvenance:
result_text, usage = await rag.query(
query="What is quantum computing?",
edge_score_limit=0,
)
assert result_text == "Quantum computing applies physics principles to computation."
@ -484,7 +481,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
for event in events:


@ -527,7 +527,8 @@ class AsyncFlowInstance:
return result.get("response", "")
async def document_rag(self, query: str, collection: str,
doc_limit: int = 10, **kwargs: Any) -> str:
doc_limit: int = 10, fetch_limit: int = 0,
**kwargs: Any) -> str:
"""
Execute document-based RAG query (non-streaming).
@ -541,7 +542,9 @@ class AsyncFlowInstance:
Args:
query: User query text
collection: Collection identifier containing documents
doc_limit: Maximum number of document chunks to retrieve (default: 10)
doc_limit: Document chunks selected into the prompt (default: 10)
fetch_limit: Candidate chunks fetched from the vector store before
reranking (default: 0 = derive from doc_limit)
**kwargs: Additional service-specific parameters
Returns:
@ -564,6 +567,7 @@ class AsyncFlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": False
}
request_data.update(kwargs)
@ -646,6 +650,16 @@ class AsyncFlowInstance:
return await self.request("embeddings", request_data)
async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any):
request_data = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request_data.update(kwargs)
return await self.request("reranker", request_data)
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any):
"""
Query RDF triples using pattern matching.


@ -379,12 +379,14 @@ class AsyncSocketFlowInstance:
yield chunk.content
async def document_rag(self, query: str, collection: str,
doc_limit: int = 10, streaming: bool = False, **kwargs):
doc_limit: int = 10, fetch_limit: int = 0,
streaming: bool = False, **kwargs):
"""Document RAG with optional streaming"""
request = {
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": streaming
}
request.update(kwargs)
@ -443,6 +445,19 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("embeddings", self.flow_id, request)
async def rerank(self, queries: list, documents: list, limit: int = 10,
**kwargs):
request = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request.update(kwargs)
return await self.client._send_request(
"reranker", self.flow_id, request,
)
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs):
"""Triple pattern query"""
request = {"limit": limit}


@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_SCORE = TG + "score"
TG_DOCUMENT = TG + "document"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
@dataclass
class EdgeSelection:
"""A selected edge with reasoning from GraphRAG Focus step."""
"""A selected edge with cross-encoder metadata from GraphRAG Focus step."""
uri: str
edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...}
reasoning: str = ""
concept: str = ""
score: Optional[float] = None
@dataclass
@ -209,7 +212,7 @@ class Exploration(ExplainEntity):
@dataclass
class Focus(ExplainEntity):
"""Focus entity - selected edges with LLM reasoning (GraphRAG only)."""
"""Focus entity - selected edges with cross-encoder scoring (GraphRAG only)."""
selected_edge_uris: List[str] = field(default_factory=list)
edge_selections: List[EdgeSelection] = field(default_factory=list)
@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel
uri = triples[0][0] if triples else ""
edge = None
reasoning = ""
concept = ""
score = None
for s, p, o in triples:
if p == TG_EDGE and isinstance(o, dict):
edge = o
elif p == TG_REASONING:
reasoning = o
elif p == TG_CONCEPT:
concept = o
elif p == TG_SCORE:
try:
score = float(o)
except (ValueError, TypeError):
score = None
return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning)
return EdgeSelection(
uri=uri, edge=edge, reasoning=reasoning,
concept=concept, score=score,
)
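
The parsing change above can be exercised in isolation. The following is a minimal sketch, assuming a placeholder `TG` namespace (the real prefix is defined elsewhere and is not shown in this diff) and a local copy of `EdgeSelection`. It illustrates that an unparseable `tg:score` value leaves `score` as `None` rather than raising.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

# Placeholder namespace for illustration only; the real TG prefix is defined
# in the provenance namespaces module and is not visible in this diff.
TG = "https://example.org/tg/"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_CONCEPT = TG + "concept"
TG_SCORE = TG + "score"


@dataclass
class EdgeSelection:
    """Local copy of the dataclass shape shown in the diff."""
    uri: str
    edge: Optional[Dict[str, str]] = None
    reasoning: str = ""
    concept: str = ""
    score: Optional[float] = None


def parse_edge_selection_triples(
    triples: List[Tuple[str, str, Any]],
) -> EdgeSelection:
    # The subject of the first triple is taken as the selection URI.
    uri = triples[0][0] if triples else ""
    edge, reasoning, concept, score = None, "", "", None
    for _s, p, o in triples:
        if p == TG_EDGE and isinstance(o, dict):
            edge = o
        elif p == TG_REASONING:
            reasoning = o
        elif p == TG_CONCEPT:
            concept = o
        elif p == TG_SCORE:
            # Scores arrive as literals; tolerate values that are not numeric.
            try:
                score = float(o)
            except (ValueError, TypeError):
                score = None
    return EdgeSelection(
        uri=uri, edge=edge, reasoning=reasoning,
        concept=concept, score=score,
    )


URI = "urn:example:edge-selection-1"  # hypothetical URI
parsed = parse_edge_selection_triples([
    (URI, TG_EDGE, {"s": "a", "p": "b", "o": "c"}),
    (URI, TG_CONCEPT, "quantum computing"),
    (URI, TG_SCORE, "0.95"),
])
malformed = parse_edge_selection_triples([(URI, TG_SCORE, "not-a-number")])
```

Keeping `reasoning` in the dataclass means triples written before this change, which carry `tg:reasoning` and no score, would still parse.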
def extract_term_value(term: Dict[str, Any]) -> Any:


@ -415,7 +415,7 @@ class FlowInstance:
def document_rag(
self, query,collection="default",
doc_limit=10,
doc_limit=10, fetch_limit=0,
):
"""
Execute document-based Retrieval-Augmented Generation (RAG) query.
@ -426,7 +426,9 @@ class FlowInstance:
Args:
query: Natural language query
collection: Collection identifier (default: "default")
doc_limit: Maximum document chunks to retrieve (default: 10)
doc_limit: Document chunks selected into the prompt (default: 10)
fetch_limit: Candidate chunks fetched from the vector store before
reranking (default: 0 = derive from doc_limit)
Returns:
str: Generated response incorporating document context
@ -447,6 +449,7 @@ class FlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
}
result = self.request(
@ -491,6 +494,19 @@ class FlowInstance:
input
)["vectors"]
def rerank(self, queries, documents, limit=10):
input = {
"queries": queries,
"documents": documents,
"limit": limit,
}
return self.request(
"service/reranker",
input
)
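
The `rerank` methods added to the API clients forward `queries` and `documents` unchanged, so callers need to supply the wire shape themselves. A minimal sketch of that shape follows, assuming the key names read by `RerankerRequestTranslator.decode` in this same commit (`query_id`, `query_text`, `document_id`, `document_text`). The helper names `build_rerank_request` and `top_documents` are hypothetical, and using stringified list indices as IDs is an assumption made for illustration.

```python
from typing import Any, Dict, List, Tuple


def build_rerank_request(
    queries: List[str], documents: List[str], limit: int = 10,
) -> Dict[str, Any]:
    """Build a request dict whose keys match the rerank() parameters."""
    return {
        "queries": [
            {"query_id": str(i), "query_text": text}
            for i, text in enumerate(queries)
        ],
        "documents": [
            {"document_id": str(i), "document_text": text}
            for i, text in enumerate(documents)
        ],
        "limit": limit,
    }


def top_documents(
    response: Dict[str, Any], documents: List[str],
) -> List[Tuple[str, float]]:
    """Map a response's results back to document text, best score first."""
    ranked = sorted(
        response.get("results", []),
        key=lambda r: r["score"],
        reverse=True,
    )
    # IDs are the stringified indices assigned in build_rerank_request.
    return [(documents[int(r["document_id"])], r["score"]) for r in ranked]


docs = ["graphs store triples", "rerankers score pairs"]
payload = build_rerank_request(["what does a reranker do?"], docs, limit=2)

# Canned response in the shape produced by RerankerResponseTranslator.encode.
canned = {
    "results": [
        {"document_id": "0", "query_id": "0", "score": 0.2},
        {"document_id": "1", "query_id": "0", "score": 0.9},
    ]
}
ordered = top_documents(canned, docs)
```

Because the payload keys match the parameter names, a client instance could presumably be called as `flow.rerank(**payload)`; that call is not exercised here.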
def graph_embeddings_query(self, text, collection, limit=10):
"""
Query knowledge graph entities using semantic similarity.


@ -752,6 +752,7 @@ class SocketFlowInstance:
query: str,
collection: str,
doc_limit: int = 10,
fetch_limit: int = 0,
streaming: bool = False,
**kwargs: Any
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
@ -764,6 +765,7 @@ class SocketFlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": streaming
}
request.update(kwargs)
@ -785,6 +787,7 @@ class SocketFlowInstance:
query: str,
collection: str,
doc_limit: int = 10,
fetch_limit: int = 0,
**kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""Execute document-based RAG query with explainability support."""
@ -792,6 +795,7 @@ class SocketFlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": True,
"explainable": True,
}
@ -885,6 +889,19 @@ class SocketFlowInstance:
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
def rerank(self, queries: list, documents: list, limit: int = 10,
**kwargs: Any) -> Dict[str, Any]:
request = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request.update(kwargs)
return self.client._send_request_sync(
"reranker", self.flow_id, request, False,
)
def triples_query(
self,
s: Optional[Union[str, Dict[str, Any]]] = None,


@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService
from . tool_service_client import ToolServiceClientSpec
from . agent_client import AgentClientSpec
from . structured_query_client import StructuredQueryClientSpec
from . reranker_client import RerankerClientSpec
from . reranker_service import RerankerService
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
from . collection_config_handler import CollectionConfigHandler


@ -157,21 +157,6 @@ class PromptClient(RequestResponse):
timeout = timeout,
)
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "kg-prompt",
variables = {
"query": query,
"knowledge": [
{ "s": v[0], "p": v[1], "o": v[2] }
for v in kg
]
},
timeout = timeout,
streaming = streaming,
chunk_callback = chunk_callback,
)
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "document-prompt",


@ -0,0 +1,43 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import (
RerankerRequest, RerankerResponse,
RerankerQuery, RerankerDocument,
)
class RerankerClient(RequestResponse):
async def rerank(self, queries, documents, limit=10, timeout=300):
resp = await self.request(
RerankerRequest(
queries=[
RerankerQuery(query_id=q["id"], query_text=q["text"])
for q in queries
],
documents=[
RerankerDocument(
document_id=d["id"], document_text=d["text"]
)
for d in documents
],
limit=limit,
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.results
class RerankerClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(RerankerClientSpec, self).__init__(
request_name = request_name,
request_schema = RerankerRequest,
response_name = response_name,
response_schema = RerankerResponse,
impl = RerankerClient,
)
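
Unlike the gateway translator, `RerankerClient.rerank` takes dicts keyed by `id` and `text` and returns result objects exposing `document_id`, `query_id` and `score`. The sketch below shows that call shape from the caller's side using a stub in place of the real client. `StubRerankerClient`, `StubResult` and `select_edges` are hypothetical names, the canned scores mirror the approach of the test mocks in this commit, and joining an edge's terms with spaces is only an assumption about how edges might be turned into document text.

```python
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class StubResult:
    """Stand-in exposing the fields the diff reads from reranker results."""
    document_id: str
    query_id: str
    score: float


class StubRerankerClient:
    """Same call shape as RerankerClient.rerank, with canned scores."""

    async def rerank(self, queries: List[Dict[str, str]],
                     documents: List[Dict[str, str]],
                     limit: int = 10, timeout: int = 300) -> List[StubResult]:
        query_id = queries[0]["id"] if queries else "0"
        results = [
            StubResult(d["id"], query_id, 1.0 / (i + 1))
            for i, d in enumerate(documents)
        ]
        return results[:limit]


async def select_edges(client, concept: str,
                       edges: List[Tuple[str, str, str]], limit: int):
    """Score edges against one concept and keep the top `limit`."""
    queries = [{"id": "0", "text": concept}]
    # Assumption: an edge is rendered as "s p o" for scoring purposes.
    documents = [
        {"id": str(i), "text": " ".join(edge)}
        for i, edge in enumerate(edges)
    ]
    results = await client.rerank(queries, documents, limit=limit)
    return [(edges[int(r.document_id)], r.score) for r in results]


edges = [("a", "p", "b"), ("c", "q", "d"), ("e", "r", "f")]
selected = asyncio.run(
    select_edges(StubRerankerClient(), "concept", edges, limit=2)
)
```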


@ -0,0 +1,109 @@
from __future__ import annotations
from argparse import ArgumentParser
import logging
from .. schema import (
RerankerRequest, RerankerResponse, RerankerResult, Error,
)
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec
logger = logging.getLogger(__name__)
default_ident = "reranker"
default_concurrency = 1
class RerankerService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
concurrency = params.get("concurrency", 1)
super(RerankerService, self).__init__(**params | {
"id": id,
"concurrency": concurrency,
})
self.register_specification(
ConsumerSpec(
name = "request",
schema = RerankerRequest,
handler = self.on_request,
concurrency = concurrency,
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = RerankerResponse
)
)
self.register_specification(
ParameterSpec(
name = "model",
)
)
async def on_request(self, msg, consumer, flow):
try:
request = msg.value()
id = msg.properties()["id"]
logger.debug(f"Handling reranker request {id}...")
model = flow("model")
results = await self.on_rerank(
request.queries, request.documents,
request.limit, model=model,
)
await flow("response").send(
RerankerResponse(
error = None,
results = results,
),
properties={"id": id}
)
logger.debug("Reranker request handled successfully")
except TooManyRequests as e:
raise e
except Exception as e:
logger.error(f"Exception in reranker service: {e}", exc_info=True)
logger.info("Sending error response...")
await flow.producer["response"].send(
RerankerResponse(
error=Error(
type = "reranker-error",
message = str(e),
),
results=[],
),
properties={"id": id}
)
@staticmethod
def add_args(parser: ArgumentParser) -> None:
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)
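
`RerankerService.on_request` delegates to `self.on_rerank(queries, documents, limit, model=model)`, which a concrete service is expected to provide. The following is a toy sketch of such a method, not a real cross-encoder: it scores each query/document pair by Jaccard token overlap. The dataclasses are local stand-ins using the field names seen in this commit, `ToyReranker` is a hypothetical class, and applying `limit` to the combined result list (rather than per query) is an assumption.

```python
import asyncio
from dataclasses import dataclass
from typing import List, Optional


# Local stand-ins with the field names used by the schema in this commit.
@dataclass
class RerankerQuery:
    query_id: str
    query_text: str


@dataclass
class RerankerDocument:
    document_id: str
    document_text: str


@dataclass
class RerankerResult:
    document_id: str
    query_id: str
    score: float


class ToyReranker:
    """Illustrative on_rerank: lexical overlap instead of a model."""

    async def on_rerank(self, queries: List[RerankerQuery],
                        documents: List[RerankerDocument],
                        limit: int,
                        model: Optional[str] = None) -> List[RerankerResult]:
        results = []
        for q in queries:
            q_tokens = set(q.query_text.lower().split())
            for d in documents:
                d_tokens = set(d.document_text.lower().split())
                union = q_tokens | d_tokens
                # Jaccard similarity; 0.0 when both texts are empty.
                score = len(q_tokens & d_tokens) / len(union) if union else 0.0
                results.append(
                    RerankerResult(d.document_id, q.query_id, score)
                )
        # Assumption: limit caps the combined list, best score first.
        results.sort(key=lambda r: r.score, reverse=True)
        return results[:limit]


ranked = asyncio.run(ToyReranker().on_rerank(
    [RerankerQuery("q0", "quantum computing")],
    [
        RerankerDocument("d0", "classical physics"),
        RerankerDocument("d1", "quantum computing basics"),
    ],
    limit=1,
))
```

The `model` argument is accepted but ignored here; a real implementation would presumably use it to choose which reranking model to load.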


@ -140,20 +140,6 @@ class PromptClient(BaseClient):
timeout=timeout
)
def request_kg_prompt(self, query, kg, timeout=300):
return self.request(
id="kg-prompt",
variables={
"query": query,
"knowledge": [
{ "s": v[0], "p": v[1], "o": v[2] }
for v in kg
]
},
timeout=timeout
)
def request_document_prompt(self, query, documents, timeout=300):
return self.request(


@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator
from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator
from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator
@ -163,6 +164,12 @@ TranslatorRegistry.register_service(
SparqlQueryResponseTranslator()
)
TranslatorRegistry.register_service(
"reranker",
RerankerRequestTranslator(),
RerankerResponseTranslator()
)
# Register single-direction translators for document loading
TranslatorRegistry.register_request("document", DocumentTranslator())
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())


@ -20,3 +20,4 @@ from .embeddings_query import (
)
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
from .reranker import RerankerRequestTranslator, RerankerResponseTranslator


@ -0,0 +1,73 @@
from typing import Dict, Any, Tuple
from ...schema import (
RerankerRequest, RerankerResponse,
RerankerQuery, RerankerDocument, RerankerResult,
)
from .base import MessageTranslator
class RerankerRequestTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> RerankerRequest:
return RerankerRequest(
queries=[
RerankerQuery(
query_id=q["query_id"],
query_text=q["query_text"],
)
for q in data.get("queries", [])
],
documents=[
RerankerDocument(
document_id=d["document_id"],
document_text=d["document_text"],
)
for d in data.get("documents", [])
],
limit=data.get("limit", 10),
)
def encode(self, obj: RerankerRequest) -> Dict[str, Any]:
return {
"queries": [
{"query_id": q.query_id, "query_text": q.query_text}
for q in obj.queries
],
"documents": [
{"document_id": d.document_id, "document_text": d.document_text}
for d in obj.documents
],
"limit": obj.limit,
}
class RerankerResponseTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> RerankerResponse:
return RerankerResponse(
results=[
RerankerResult(
document_id=r["document_id"],
query_id=r["query_id"],
score=r["score"],
)
for r in data.get("results", [])
],
)
def encode(self, obj: RerankerResponse) -> Dict[str, Any]:
return {
"results": [
{
"document_id": r.document_id,
"query_id": r.query_id,
"score": r.score,
}
for r in obj.results
],
}
def encode_with_completion(
self, obj: RerankerResponse
) -> Tuple[Dict[str, Any], bool]:
return self.encode(obj), True
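
To make the wire format concrete, here is a minimal sketch of a request round trip. The dataclasses are local stand-ins that mirror the reranker schema added in this commit, and `decode_request` / `encode_request` are hypothetical free functions written for illustration, not the translator classes themselves.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


# Local stand-ins mirroring the schema dataclasses (assumption: same fields).
@dataclass
class RerankerQuery:
    query_id: str = ""
    query_text: str = ""


@dataclass
class RerankerDocument:
    document_id: str = ""
    document_text: str = ""


@dataclass
class RerankerRequest:
    queries: List[RerankerQuery] = field(default_factory=list)
    documents: List[RerankerDocument] = field(default_factory=list)
    limit: int = 10


def decode_request(data: Dict[str, Any]) -> RerankerRequest:
    """Hypothetical helper: dict -> request, defaulting limit to 10."""
    return RerankerRequest(
        queries=[
            RerankerQuery(q["query_id"], q["query_text"])
            for q in data.get("queries", [])
        ],
        documents=[
            RerankerDocument(d["document_id"], d["document_text"])
            for d in data.get("documents", [])
        ],
        limit=data.get("limit", 10),
    )


def encode_request(obj: RerankerRequest) -> Dict[str, Any]:
    """Hypothetical helper: request -> dict with the same key names."""
    return {
        "queries": [
            {"query_id": q.query_id, "query_text": q.query_text}
            for q in obj.queries
        ],
        "documents": [
            {"document_id": d.document_id, "document_text": d.document_text}
            for d in obj.documents
        ],
        "limit": obj.limit,
    }


payload = {
    "queries": [{"query_id": "0", "query_text": "what do cats eat?"}],
    "documents": [
        {"document_id": "0", "document_text": "Cats are carnivores."},
        {"document_id": "1", "document_text": "Paris is in France."},
    ],
}

request = decode_request(payload)
round_tripped = encode_request(request)
```

Note that `limit` is optional on the way in and always present on the way out, so a decoded-then-encoded payload can gain a key it did not start with.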

View file

@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
query=data["query"],
collection=data.get("collection", "default"),
doc_limit=int(data.get("doc-limit", 20)),
fetch_limit=int(data.get("fetch-limit", 0)),
streaming=data.get("streaming", False)
)
@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
"query": obj.query,
"collection": obj.collection,
"doc-limit": obj.doc_limit,
"fetch-limit": obj.fetch_limit,
"streaming": getattr(obj, "streaming", False)
}

View file

@ -64,6 +64,8 @@ from . uris import (
docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri,
docrag_focus_uri,
chunk_selection_uri,
docrag_synthesis_uri,
)
@ -89,9 +91,13 @@ from . namespaces import (
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE,
# Edge selection entity type
TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Chunk selection entity type
TG_CHUNK_SELECTION,
# Explainability entity types
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION,
@ -130,6 +136,7 @@ from . triples import (
# Query-time provenance triple builders (DocumentRAG)
docrag_question_triples,
docrag_exploration_triples,
docrag_chunk_selection_triples,
docrag_synthesis_triples,
# Utility
set_graph,
@ -194,6 +201,8 @@ __all__ = [
"docrag_question_uri",
"docrag_grounding_uri",
"docrag_exploration_uri",
"docrag_focus_uri",
"chunk_selection_uri",
"docrag_synthesis_uri",
# Namespaces
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
@ -212,9 +221,13 @@ __all__ = [
"TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE",
# Query-time provenance predicates (GraphRAG)
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE",
# Edge selection entity type
"TG_EDGE_SELECTION",
# Query-time provenance predicates (DocumentRAG)
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
# Chunk selection entity type
"TG_CHUNK_SELECTION",
# Explainability entity types
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
"TG_ANALYSIS", "TG_CONCLUSION",
@ -250,6 +263,7 @@ __all__ = [
# Query-time provenance triple builders (DocumentRAG)
"docrag_question_triples",
"docrag_exploration_triples",
"docrag_chunk_selection_triples",
"docrag_synthesis_triples",
# Agent provenance triple builders
"agent_session_triples",

View file

@ -66,12 +66,19 @@ TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_SCORE = TG + "score"
TG_DOCUMENT = TG + "document" # Reference to document in librarian
# Edge selection entity type (cross-encoder scored edge in Focus)
TG_EDGE_SELECTION = TG + "EdgeSelection"
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT = TG + "chunkCount"
TG_SELECTED_CHUNK = TG + "selectedChunk"
# Chunk selection entity type (cross-encoder reranked chunk in Focus)
TG_CHUNK_SELECTION = TG + "ChunkSelection"
# Extraction provenance entity types
TG_DOCUMENT_TYPE = TG + "Document"
TG_PAGE_TYPE = TG + "Page"

View file

@ -24,10 +24,14 @@ from . namespaces import (
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE,
TG_DOCUMENT,
# Edge selection entity type
TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Chunk selection entity type
TG_CHUNK_SELECTION,
# Explainability entity types
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
# Unifying types
@ -38,7 +42,10 @@ from . namespaces import (
TG_IN_TOKEN, TG_OUT_TOKEN,
)
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri
from . uris import (
activity_uri, agent_uri, subgraph_uri, edge_selection_uri,
chunk_selection_uri,
)
def set_graph(triples: List[Triple], graph: str) -> List[Triple]:
@ -536,10 +543,9 @@ def focus_triples(
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
]
# Add each selected edge with its reasoning via intermediate entity
# Add each selected edge with metadata via intermediate entity
for idx, edge_info in enumerate(selected_edges_with_reasoning):
edge = edge_info.get("edge")
reasoning = edge_info.get("reasoning", "")
if edge:
s, p, o = edge
@ -552,13 +558,32 @@ def focus_triples(
_triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri))
)
# Type the edge selection entity
triples.append(
_triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION))
)
# Attach quoted triple to edge selection entity
quoted = _quoted_triple(s, p, o)
triples.append(
Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted)
)
# Attach reasoning to edge selection entity
# Structured cross-encoder metadata
concept = edge_info.get("concept")
if concept:
triples.append(
_triple(edge_sel_uri, TG_CONCEPT, _literal(concept))
)
score = edge_info.get("score")
if score is not None:
triples.append(
_triple(edge_sel_uri, TG_SCORE, _literal(str(score)))
)
# Legacy reasoning text (for non-cross-encoder callers)
reasoning = edge_info.get("reasoning", "")
if reasoning:
triples.append(
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning))
@ -698,6 +723,75 @@ def docrag_exploration_triples(
return triples
def docrag_chunk_selection_triples(
focus_uri: str,
exploration_uri: str,
selected_chunks_with_scores: List[dict],
session_id: str,
) -> List[Triple]:
"""
Build triples for a document RAG focus entity (chunks selected by the
cross-encoder reranker).
Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity
derived from exploration, with one ChunkSelection sub-entity per surviving
chunk carrying the chunk reference and the reranker score.
Structure:
<focus> a tg:Focus ; prov:wasDerivedFrom <exploration> .
<focus> tg:selectedChunk <chunk_sel_0> .
<chunk_sel_0> a tg:ChunkSelection .
<chunk_sel_0> tg:document <chunk_id> .
<chunk_sel_0> tg:score "0.97" .
Args:
focus_uri: URI of the focus entity (from docrag_focus_uri)
exploration_uri: URI of the parent exploration entity
selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score'
session_id: Session UUID for generating chunk selection URIs
Returns:
List of Triple objects
"""
triples = [
_triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)),
_triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")),
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
]
for idx, chunk_info in enumerate(selected_chunks_with_scores):
chunk_id = chunk_info.get("chunk_id")
if not chunk_id:
continue
chunk_sel_uri = chunk_selection_uri(session_id, idx)
# Link focus to chunk selection entity
triples.append(
_triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri))
)
# Type the chunk selection entity
triples.append(
_triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION))
)
# Reference the actual chunk (in librarian)
triples.append(
_triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id))
)
# Cross-encoder score
score = chunk_info.get("score")
if score is not None:
triples.append(
_triple(chunk_sel_uri, TG_SCORE, _literal(str(score)))
)
return triples
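
The per-chunk structure described in the docstring above can be sketched without the real `Triple` type. In this illustration plain tuples stand in for `Triple` objects and prefixed names such as `tg:score` stand in for the namespace IRIs; both are simplifications, and `sketch_chunk_selection` is a hypothetical name. The focus-level typing and label triples are left out for brevity.

```python
def docrag_focus_uri(session_id):
    # Same URN format as the helper added in this commit.
    return f"urn:trustgraph:docrag:{session_id}/focus"


def chunk_selection_uri(session_id, chunk_index):
    # Same URN format as the helper added in this commit.
    return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"


def sketch_chunk_selection(session_id, chunks):
    """Return (subject, predicate, object) tuples for each usable chunk."""
    focus = docrag_focus_uri(session_id)
    out = []
    for idx, info in enumerate(chunks):
        chunk_id = info.get("chunk_id")
        if not chunk_id:
            # Entries without a chunk_id are skipped, but idx still advances.
            continue
        sel = chunk_selection_uri(session_id, idx)
        out.append((focus, "tg:selectedChunk", sel))
        out.append((sel, "rdf:type", "tg:ChunkSelection"))
        out.append((sel, "tg:document", chunk_id))
        score = info.get("score")
        if score is not None:
            # Scores are stored as string literals.
            out.append((sel, "tg:score", str(score)))
    return out


triples = sketch_chunk_selection(
    "abc",
    [
        {"chunk_id": "urn:chunk:1", "score": 0.97},
        {"score": 0.5},                 # no chunk_id: skipped
        {"chunk_id": "urn:chunk:3"},    # no score: no tg:score triple
    ],
)
```

Because the index comes from `enumerate` over the full input list, a skipped entry leaves a gap in the selection URIs rather than renumbering the survivors.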
def docrag_synthesis_triples(
synthesis_uri: str,
exploration_uri: str,

View file

@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str:
return f"urn:trustgraph:docrag:{session_id}/exploration"
def docrag_focus_uri(session_id: str) -> str:
"""
Generate URI for a document RAG focus entity (chunks selected by the
cross-encoder reranker).
Args:
session_id: The session UUID.
Returns:
URN in format: urn:trustgraph:docrag:{uuid}/focus
"""
return f"urn:trustgraph:docrag:{session_id}/focus"
def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
"""
Generate URI for a chunk selection item (links a reranked chunk to its
score). Mirrors edge_selection_uri for GraphRAG.
Args:
session_id: The session UUID.
chunk_index: Index of this chunk in the selection (0-based).
Returns:
URN in format: urn:trustgraph:prov:chunk:{uuid}:{index}
"""
return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"
def docrag_synthesis_uri(session_id: str) -> str:
"""
Generate URI for a document RAG synthesis entity (final answer).

View file

@ -29,6 +29,8 @@ from . namespaces import (
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
TG_EDGE_SELECTION, TG_SCORE,
TG_CHUNK_SELECTION,
)
@ -93,6 +95,8 @@ TG_CLASS_LABELS = [
_label_triple(TG_FINDING, "Finding"),
_label_triple(TG_PLAN_TYPE, "Plan"),
_label_triple(TG_STEP_RESULT, "Step Result"),
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
_label_triple(TG_CHUNK_SELECTION, "Chunk Selection"),
]
# TrustGraph predicate labels
@ -117,6 +121,7 @@ TG_PREDICATE_LABELS = [
_label_triple(TG_ENTITY, "entity"),
_label_triple(TG_SUBAGENT_GOAL, "subagent goal"),
_label_triple(TG_PLAN_STEP, "plan step"),
_label_triple(TG_SCORE, "score"),
]

View file

@ -15,4 +15,5 @@ from .diagnosis import *
from .collection import *
from .storage import *
from .tool_service import *
from .sparql_query import *
from .sparql_query import *
from .reranker import *

View file

@ -6,17 +6,6 @@ from ..core.primitives import Error
# Prompt services, abstract the prompt generation
# extract-definitions:
# chunk -> definitions
# extract-relationships:
# chunk -> relationships
# kg-prompt:
# query, triples -> answer
# document-prompt:
# query, documents -> answer
# extract-rows
# schema, chunk -> rows
@dataclass
class PromptRequest:
id: str = ""
@ -46,4 +35,4 @@ class PromptResponse:
out_token: int | None = None
model: str | None = None
############################################################################
############################################################################

View file

@ -0,0 +1,35 @@
from dataclasses import dataclass, field
from ..core.primitives import Error
############################################################################
# Cross-encoder reranker
@dataclass
class RerankerQuery:
query_id: str = ""
query_text: str = ""
@dataclass
class RerankerDocument:
document_id: str = ""
document_text: str = ""
@dataclass
class RerankerRequest:
queries: list[RerankerQuery] = field(default_factory=list)
documents: list[RerankerDocument] = field(default_factory=list)
limit: int = 10
@dataclass
class RerankerResult:
document_id: str = ""
query_id: str = ""
score: float = 0.0
@dataclass
class RerankerResponse:
error: Error | None = None
results: list[RerankerResult] = field(default_factory=list)

View file

@ -40,7 +40,10 @@ class GraphRagResponse:
class DocumentRagQuery:
query: str = ""
collection: str = ""
doc_limit: int = 0
doc_limit: int = 0 # docs selected into the synthesis prompt
fetch_limit: int = 0 # candidate pool fetched from the vector store
# before reranking (0 = derive from doc_limit;
# values below doc_limit are raised to it)
streaming: bool = False
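
The comments on `fetch_limit` describe two rules: 0 means "derive from `doc_limit`", and any value below `doc_limit` is raised to it. A minimal sketch of that resolution follows. The function name `resolve_fetch_limit` and the `multiplier` used for the derived case are assumptions made for illustration; this diff does not show how the service actually derives the pool size.

```python
def resolve_fetch_limit(doc_limit, fetch_limit, multiplier=4):
    """Pick the candidate pool size fetched before reranking.

    multiplier is a hypothetical value; the real derivation rule is
    not shown in this commit.
    """
    if fetch_limit <= 0:
        # 0 (unset): derive the pool size from doc_limit.
        return doc_limit * multiplier
    # Never fetch fewer candidates than will be selected.
    return max(fetch_limit, doc_limit)


derived = resolve_fetch_limit(doc_limit=10, fetch_limit=0)
raised = resolve_fetch_limit(doc_limit=10, fetch_limit=5)
explicit = resolve_fetch_limit(doc_limit=10, fetch_limit=50)
```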
@dataclass

View file

@ -71,6 +71,7 @@ tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main"
tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main"
tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
tg-invoke-reranker = "trustgraph.cli.invoke_reranker:main"
tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"
tg-load-kg-core = "trustgraph.cli.load_kg_core:main"

View file

@ -21,10 +21,12 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
default_collection = 'default'
default_doc_limit = 10
default_fetch_limit = 0
def question_explainable(
url, flow_id, question_text, collection, doc_limit, token=None, debug=False,
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
token=None, debug=False,
workspace="default",
):
"""Execute document RAG with explainability - shows provenance events inline."""
@ -39,6 +41,7 @@ def question_explainable(
query=question_text,
collection=collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
):
if isinstance(item, RAGChunk):
# Print response content
@ -97,7 +100,7 @@ def question_explainable(
def question(
url, flow_id, question_text, collection, doc_limit,
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
streaming=True, token=None, explainable=False, debug=False,
show_usage=False, workspace="default",
):
@ -109,6 +112,7 @@ def question(
question_text=question_text,
collection=collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
token=token,
debug=debug,
workspace=workspace,
@ -128,6 +132,7 @@ def question(
query=question_text,
collection=collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
streaming=True
)
@ -155,6 +160,7 @@ def question(
query=question_text,
collection=collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
)
print(result.text)
@ -214,7 +220,15 @@ def main():
'-d', '--doc-limit',
type=int,
default=default_doc_limit,
help=f'Document limit (default: {default_doc_limit})'
help=f'Documents selected into the prompt (default: {default_doc_limit})'
)
parser.add_argument(
'--fetch-limit',
type=int,
default=default_fetch_limit,
help='Candidate documents fetched from the vector store before '
'reranking (default: derive from doc-limit)'
)
parser.add_argument(
@ -251,6 +265,7 @@ def main():
question_text=args.question,
collection=args.collection,
doc_limit=args.doc_limit,
fetch_limit=args.fetch_limit,
streaming=not args.no_streaming,
token=args.token,
explainable=args.explainable,

View file

@ -112,14 +112,13 @@ def _question_explainable_api(
if focus_full and focus_full.edge_selections:
for edge_sel in focus_full.edge_selections:
if edge_sel.edge:
# Resolve labels for edge components
s_label, p_label, o_label = explain_client.resolve_edge_labels(
edge_sel.edge, collection
)
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
if edge_sel.reasoning:
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reason: {r_short}", file=sys.stderr)
if edge_sel.concept or edge_sel.score is not None:
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
print(f" Concept: {edge_sel.concept} Score: {score_str}", file=sys.stderr)
elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr)

View file

@ -0,0 +1,127 @@
"""
Invokes the reranker service to score and rank documents by relevance
to one or more queries.
"""
import argparse
import json
import os
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def query(url, flow_id, queries, documents, limit, token=None,
workspace="default"):
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
try:
query_objects = [
{"query_id": str(i), "query_text": q}
for i, q in enumerate(queries)
]
document_objects = [
{"document_id": str(i), "document_text": d}
for i, d in enumerate(documents)
]
result = flow.rerank(
queries=query_objects,
documents=document_objects,
limit=limit,
)
if "error" in result and result["error"]:
err = result["error"]
print(f"Error: [{err.get('type', '')}] {err.get('message', '')}")
return
for r in result.get("results", []):
doc_idx = int(r["document_id"])
query_idx = int(r["query_id"])
print(
f" {r['score']:.4f} | "
f"query: {queries[query_idx]} | "
f"doc: {documents[doc_idx]}"
)
finally:
socket.close()
def main():
parser = argparse.ArgumentParser(
prog='tg-invoke-reranker',
description=__doc__,
)
parser.add_argument(
'-u', '--url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
default="default",
help='Flow ID (default: default)'
)
parser.add_argument(
'-l', '--limit',
type=int,
default=10,
help='Maximum number of results (default: 10)',
)
parser.add_argument(
'-q', '--query',
action='append',
required=True,
help='Query text (can be specified multiple times)',
)
parser.add_argument(
'documents',
nargs='+',
help='Documents to rerank',
)
args = parser.parse_args()
try:
query(
url=args.url,
flow_id=args.flow_id,
queries=args.query,
documents=args.documents,
limit=args.limit,
token=args.token,
workspace=args.workspace,
)
except Exception as e:
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()

View file

@ -203,9 +203,9 @@ def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_
)
print(f" {i}. ({s_label}, {p_label}, {o_label})")
if edge_sel.reasoning:
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reasoning: {r_short}")
if edge_sel.concept or edge_sel.score is not None:
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
print(f" Concept: {edge_sel.concept} Score: {score_str}")
if show_provenance and edge_sel.edge:
provenance = trace_edge_provenance(
@ -519,7 +519,8 @@ def trace_to_dict(trace, trace_type):
"selected_edges": [
{
"edge": edge_sel.edge,
"reasoning": edge_sel.reasoning,
"concept": edge_sel.concept,
"score": edge_sel.score,
}
for edge_sel in focus.edge_selections
],

View file

@ -19,6 +19,7 @@ dependencies = [
"faiss-cpu",
"falkordb",
"fastembed",
"flashrank",
"ibis",
"jsonschema",
"langchain",
@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
graph-rag = "trustgraph.retrieval.graph_rag:run"
reranker-flashrank = "trustgraph.reranker.flashrank:run"
kg-extract-agent = "trustgraph.extract.kg.agent:run"
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
kg-extract-rows = "trustgraph.extract.kg.rows:run"

View file

@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
from . row_embeddings_query import RowEmbeddingsQueryRequestor
from . mcp_tool import McpToolRequestor
from . reranker import RerankerRequestor
from . text_load import TextLoad
from . document_load import DocumentLoad
@ -74,6 +75,7 @@ request_response_dispatchers = {
"structured-diag": StructuredDiagRequestor,
"row-embeddings": RowEmbeddingsQueryRequestor,
"sparql": SparqlQueryRequestor,
"reranker": RerankerRequestor,
}
system_dispatchers = {

View file

@ -0,0 +1,31 @@
from ... schema import RerankerRequest, RerankerResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class RerankerRequestor(ServiceRequestor):
def __init__(
self, backend, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(RerankerRequestor, self).__init__(
backend=backend,
request_queue=request_queue,
response_queue=response_queue,
request_schema=RerankerRequest,
response_schema=RerankerResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("reranker")
self.response_translator = TranslatorRegistry.get_response_translator("reranker")
def to_request(self, body):
return self.request_translator.decode(body)
def from_response(self, message):
return self.response_translator.encode_with_completion(message)

View file

@ -518,6 +518,7 @@ _FLOW_SERVICES = {
"structured-diag": "structured-query:read",
"row-embeddings": "row-embeddings:read",
"sparql": "sparql:read",
"reranker": "reranker",
}
for _kind, _cap in _FLOW_SERVICES.items():
_register_flow_kind("flow-service", _kind, _cap)

View file

@ -72,6 +72,7 @@ _READER_CAPS = {
"row-embeddings:read",
"llm",
"embeddings",
"reranker",
"mcp",
"config:read",
"flows:read",

View file

@ -10,6 +10,7 @@ import logging
from .... exceptions import TooManyRequests, LlmError
from .... base import LlmService, LlmResult, LlmChunk
from . variants import get_variant, DEFAULT_VARIANT, VARIANTS
# Module logger
logger = logging.getLogger(__name__)
@ -21,6 +22,7 @@ default_temperature = 0.0
default_max_output = 4096
default_api_key = os.getenv("OPENAI_TOKEN")
default_base_url = os.getenv("OPENAI_BASE_URL")
default_thinking = "off"
if default_base_url is None or default_base_url == "":
default_base_url = "https://api.openai.com/v1"
@ -28,16 +30,21 @@ if default_base_url is None or default_base_url == "":
class Processor(LlmService):
def __init__(self, **params):
model = params.get("model", default_model)
api_key = params.get("api_key", default_api_key)
base_url = params.get("url", default_base_url)
temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output)
thinking = params.get("thinking", default_thinking)
variant_name = params.get("variant", DEFAULT_VARIANT)
if not api_key:
api_key = "not-set"
self.variant = get_variant(variant_name)
self.thinking = thinking
super(Processor, self).__init__(
**params | {
"model": model,
@ -56,13 +63,28 @@ class Processor(LlmService):
else:
self.openai = OpenAI(api_key=api_key)
logger.info("OpenAI LLM service initialized")
logger.info(
f"OpenAI LLM service initialized "
f"(variant={self.variant.name}, thinking={self.thinking})"
)
def _build_kwargs(self, model_name, temperature):
"""Build API call kwargs using the active variant."""
return self.variant.completion_kwargs(
max_output=self.max_output,
temperature=temperature,
thinking=self.thinking,
)
def _extract_content(self, message):
"""Extract visible content from a response message."""
if hasattr(self.variant, "extract_content"):
return self.variant.extract_content(message)
return message.content
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
@ -72,31 +94,38 @@ class Processor(LlmService):
try:
resp = self.openai.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
],
temperature=effective_temperature,
max_completion_tokens=self.max_output,
api_kwargs = self._build_kwargs(model_name, effective_temperature)
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
]
resp = self.variant.create_completion(
self.openai, model_name, messages, **api_kwargs,
)
inputtokens = resp.usage.prompt_tokens
outputtokens = resp.usage.completion_tokens
logger.debug(f"LLM response: {resp.choices[0].message.content}")
content = self._extract_content(resp.choices[0].message)
thinking = self.variant.extract_thinking(resp.choices[0].message)
logger.debug(f"LLM response: {content}")
if thinking:
logger.debug(f"LLM thinking: {thinking[:200]}...")
logger.info(f"Input Tokens: {inputtokens}")
logger.info(f"Output Tokens: {outputtokens}")
resp = LlmResult(
text = resp.choices[0].message.content,
text = content,
in_token = inputtokens,
out_token = outputtokens,
model = model_name
@ -136,9 +165,7 @@ class Processor(LlmService):
Stream content generation from OpenAI.
Yields LlmChunk objects with is_final=True on the last chunk.
"""
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model (streaming): {model_name}")
@ -147,30 +174,26 @@ class Processor(LlmService):
prompt = system + "\n\n" + prompt
try:
response = self.openai.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
],
temperature=effective_temperature,
max_completion_tokens=self.max_output,
stream=True,
stream_options={"include_usage": True}
)
api_kwargs = self._build_kwargs(model_name, effective_temperature)
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
]
total_input_tokens = 0
total_output_tokens = 0
# Stream chunks
for chunk in response:
async for chunk in self.variant.create_completion_stream(
self.openai, model_name, messages, **api_kwargs,
):
if chunk.choices and chunk.choices[0].delta.content:
yield LlmChunk(
text=chunk.choices[0].delta.content,
@ -254,6 +277,20 @@ class Processor(LlmService):
help=f'LLM max output tokens (default: {default_max_output})'
)
parser.add_argument(
'--thinking',
choices=["off", "low", "medium", "high"],
default=default_thinking,
help=f'Thinking/reasoning effort level (default: {default_thinking})'
)
parser.add_argument(
'--variant',
choices=sorted(VARIANTS.keys()),
default=DEFAULT_VARIANT,
help=f'API variant (default: {DEFAULT_VARIANT})'
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,219 @@
"""
OpenAI API variant profiles.
Different providers expose OpenAI-compatible APIs with subtle differences
in parameter names, thinking/reasoning support, and temperature handling.
Each variant encapsulates those quirks so the processor doesn't need
provider-specific conditionals.
"""
import re
import logging
logger = logging.getLogger(__name__)
class Variant:
"""Base variant — defines the interface all variants implement."""
name = None
token_param = "max_completion_tokens"
temperature_with_thinking = False
def completion_kwargs(self, max_output, temperature, thinking):
"""Build provider-specific kwargs for chat.completions.create().
Parameters
----------
max_output : int
Configured max output tokens.
temperature : float
Configured temperature.
thinking : str
Thinking effort level: "off", "low", "medium", "high".
Returns
-------
dict
Extra kwargs to spread into the API call.
"""
kwargs = {self.token_param: max_output}
if thinking != "off":
kwargs.update(self.thinking_kwargs(thinking))
if not self.temperature_with_thinking:
kwargs["temperature"] = 1.0
else:
kwargs["temperature"] = temperature
else:
kwargs["temperature"] = temperature
return kwargs
def thinking_kwargs(self, effort):
"""Return kwargs to enable thinking at the given effort level."""
return {}
def extract_thinking(self, message):
"""Extract thinking/reasoning content from a response message."""
return getattr(message, "reasoning_content", None)
def extract_thinking_stream(self, delta):
"""Extract thinking content from a streaming delta."""
return getattr(delta, "reasoning_content", None)
def create_completion(self, client, model, messages, **kwargs):
"""Call the completions API. Override for non-standard SDKs."""
return client.chat.completions.create(
model=model, messages=messages, **kwargs,
)
async def create_completion_stream(self, client, model, messages, **kwargs):
"""Call the streaming completions API. Override for non-standard SDKs."""
for chunk in client.chat.completions.create(
model=model, messages=messages, stream=True,
stream_options={"include_usage": True}, **kwargs,
):
yield chunk
class OpenAIVariant(Variant):
"""Standard OpenAI API (GPT-4o, o1, o3, etc.)."""
name = "openai"
token_param = "max_completion_tokens"
temperature_with_thinking = False
def thinking_kwargs(self, effort):
return {"reasoning_effort": effort}
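
As a rough illustration of how the base `completion_kwargs` logic combines with the OpenAI variant's settings, the sketch below flattens the two into one standalone function. `openai_style_kwargs` is a hypothetical name, not part of the module.

```python
def openai_style_kwargs(max_output, temperature, thinking):
    """Flattened view of Variant.completion_kwargs for the openai variant."""
    kwargs = {"max_completion_tokens": max_output}
    if thinking != "off":
        kwargs["reasoning_effort"] = thinking
        # temperature_with_thinking is False for this variant, so the
        # configured temperature is replaced by 1.0 while thinking is on.
        kwargs["temperature"] = 1.0
    else:
        kwargs["temperature"] = temperature
    return kwargs


plain = openai_style_kwargs(max_output=4096, temperature=0.0, thinking="off")
reasoning = openai_style_kwargs(max_output=4096, temperature=0.0, thinking="high")
```

The practical consequence is that a configured temperature is silently ignored whenever `--thinking` is set to anything other than `off` on a variant with `temperature_with_thinking = False`.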
class DeepSeekVariant(Variant):
"""DeepSeek API (R1, V3, etc.)."""
name = "deepseek"
token_param = "max_completion_tokens"
temperature_with_thinking = True
def completion_kwargs(self, max_output, temperature, thinking):
enabled = "enabled" if thinking != "off" else "disabled"
kwargs = {
self.token_param: max_output,
"temperature": temperature,
"extra_body": {
"thinking": {"type": enabled},
},
}
return kwargs
def thinking_kwargs(self, effort):
return {}
class DashScopeVariant(Variant):
"""Alibaba Cloud DashScope API (Qwen models)."""
name = "dashscope"
token_param = "max_completion_tokens"
temperature_with_thinking = True
def completion_kwargs(self, max_output, temperature, thinking):
enabled = thinking != "off"
return {
self.token_param: max_output,
"temperature": temperature,
"extra_body": {
"enable_thinking": enabled,
},
}
def thinking_kwargs(self, effort):
return {}
class QwenVariant(DashScopeVariant):
"""Qwen — alias for DashScope."""
name = "qwen"
class MistralVariant(Variant):
"""Mistral API (Mistral Large, etc.)."""
name = "mistral"
token_param = "max_tokens"
temperature_with_thinking = False
def thinking_kwargs(self, effort):
return {"reasoning_effort": effort}
class GlmVariant(Variant):
"""GLM / Zhipu AI API (GLM-4, GLM-4.7, etc.)."""
name = "glm"
token_param = "max_tokens"
temperature_with_thinking = True
def completion_kwargs(self, max_output, temperature, thinking):
enabled = "enabled" if thinking != "off" else "disabled"
kwargs = {
self.token_param: max_output,
"temperature": temperature,
"extra_body": {
"thinking": {"type": enabled},
},
}
return kwargs
def thinking_kwargs(self, effort):
return {}
class LlamaVariant(Variant):
"""Llama models via OpenAI-compatible servers (vLLM, Ollama, etc.).
Thinking is typically always-on or always-off depending on the model.
When present, thinking appears inline as <think>...</think> tags.
"""
name = "llama"
token_param = "max_tokens"
temperature_with_thinking = True
def thinking_kwargs(self, effort):
return {}
def extract_thinking(self, message):
content = message.content or ""
match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
return match.group(1).strip() if match else None
def extract_content(self, message):
"""Strip think tags from visible content."""
content = message.content or ""
return re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
VARIANTS = {
"openai": OpenAIVariant,
"deepseek": DeepSeekVariant,
"qwen": QwenVariant,
"mistral": MistralVariant,
"dashscope": DashScopeVariant,
"glm": GlmVariant,
"llama": LlamaVariant,
}
DEFAULT_VARIANT = "openai"
def get_variant(name):
"""Look up a variant by name, raising ValueError if unknown."""
cls = VARIANTS.get(name)
if cls is None:
raise ValueError(
f"Unknown variant {name!r}. "
f"Available: {', '.join(sorted(VARIANTS))}"
)
return cls()
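
The inline-thinking handling in `LlamaVariant` above can be tried in isolation. The following is a minimal sketch, not part of the commit: `split_thinking` and `THINK_RE` are hypothetical names, and the function takes a plain string rather than the message object the variant receives. It reuses the same regular expression as the diff.

```python
import re

# Same pattern as LlamaVariant; DOTALL lets '.' match across newlines.
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def split_thinking(content):
    """Return (thinking, visible) for a raw completion string.

    Hypothetical helper combining what extract_thinking and
    extract_content do separately in the diff.
    """
    content = content or ""
    match = THINK_RE.search(content)
    # Only the first <think> block is reported as thinking ...
    thinking = match.group(1).strip() if match else None
    # ... but every <think> block is stripped from the visible text.
    visible = THINK_RE.sub("", content).strip()
    return thinking, visible


raw = "<think>\nUser wants a sum.\n</think>\nThe answer is 4."
thinking, visible = split_thinking(raw)
print(thinking)
print(visible)
```

Note that, as in the diff, a response with several `<think>` blocks yields only the first block as thinking, while all of them are removed from the visible content.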


@ -0,0 +1 @@


@ -0,0 +1,2 @@
from . processor import *


@ -0,0 +1,6 @@
#!/usr/bin/env python3
from . processor import run
if __name__ == '__main__':
run()


@ -0,0 +1,109 @@
"""
Reranker service using flashrank.
Scores query-document pairs and returns the top results ranked by
relevance.
"""
import asyncio
import logging
from ... base import RerankerService
from ... schema import RerankerResult
from flashrank import Ranker, RerankRequest
logger = logging.getLogger(__name__)
default_ident = "reranker"
default_model = "ms-marco-MiniLM-L-12-v2"
class Processor(RerankerService):
def __init__(self, **params):
model = params.get("model", default_model)
super(Processor, self).__init__(
**params | { "model": model }
)
self.default_model = model
self.cached_model_name = None
self.ranker = None
self._load_model(model)
def _load_model(self, model_name):
if self.cached_model_name != model_name:
logger.info(f"Loading flashrank model: {model_name}")
self.ranker = Ranker(model_name=model_name)
self.cached_model_name = model_name
logger.info(f"flashrank model {model_name} loaded successfully")
else:
logger.debug(f"Using cached model: {model_name}")
def _run_rerank(self, query, passages):
request = RerankRequest(query=query, passages=passages)
return self.ranker.rerank(request)
async def on_rerank(self, queries, documents, limit, model=None):
if not queries or not documents:
return []
use_model = model or self.default_model
if self.cached_model_name != use_model:
await asyncio.to_thread(self._load_model, use_model)
passages = [
{"id": d.document_id, "text": d.document_text}
for d in documents
]
best_scores = {}
for q in queries:
ranked = await asyncio.to_thread(
self._run_rerank, q.query_text, passages,
)
for r in ranked:
doc_id = r["id"]
score = r["score"]
score = float(score)
if doc_id not in best_scores or score > best_scores[doc_id][1]:
best_scores[doc_id] = (q.query_id, score)
results = sorted(
best_scores.items(),
key=lambda x: x[1][1],
reverse=True,
)[:limit]
return [
RerankerResult(
document_id=doc_id,
query_id=query_id,
score=score,
)
for doc_id, (query_id, score) in results
]
@staticmethod
def add_args(parser):
RerankerService.add_args(parser)
parser.add_argument(
'-m', '--model',
default=default_model,
help=f'Reranker model (default: {default_model})'
)
def run():
Processor.launch(default_ident, __doc__)
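
The multi-query merge in `on_rerank` keeps, for each document, the best score seen across all queries, then truncates to `limit`. A minimal sketch of that merge follows, with the flashrank call replaced by pre-computed rankings so it runs without the model. `merge_best_scores` and the sample data are hypothetical, not part of the commit.

```python
def merge_best_scores(rankings, limit):
    """Merge per-query rankings into one list of top documents.

    rankings: {query_id: [{"id": doc_id, "score": number}, ...]}
    Returns [(doc_id, query_id, score), ...], best score first.
    """
    best = {}
    for query_id, ranked in rankings.items():
        for r in ranked:
            doc_id = r["id"]
            score = float(r["score"])
            # Keep the highest score per document, and the query that gave it.
            if doc_id not in best or score > best[doc_id][1]:
                best[doc_id] = (query_id, score)

    ordered = sorted(best.items(), key=lambda kv: kv[1][1], reverse=True)
    return [(doc_id, qid, score) for doc_id, (qid, score) in ordered[:limit]]


rankings = {
    "q0": [{"id": "a", "score": 0.9}, {"id": "b", "score": 0.2}],
    "q1": [
        {"id": "a", "score": 0.4},
        {"id": "b", "score": 0.7},
        {"id": "c", "score": 0.1},
    ],
}
top = merge_best_scores(rankings, limit=2)
print(top)
```

Each document appears at most once in the output, attributed to whichever query scored it highest.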


@ -9,10 +9,12 @@ from trustgraph.provenance import (
docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri,
docrag_focus_uri,
docrag_synthesis_uri,
docrag_question_triples,
grounding_triples,
docrag_exploration_triples,
docrag_chunk_selection_triples,
docrag_synthesis_triples,
set_graph,
GRAPH_RETRIEVAL,
@ -21,19 +23,25 @@ from trustgraph.provenance import (
# Module logger
logger = logging.getLogger(__name__)
# When the caller does not specify a fetch_limit, reranking over-fetches this
# many times the final doc_limit as the candidate pool, so the cross-encoder can
# recover relevant chunks the bi-encoder ranked just outside the top doc_limit.
# This is only the fallback default: an explicit fetch_limit overrides it.
OVERFETCH_FACTOR = 3
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
class Query:
def __init__(
self, rag, workspace, collection, verbose,
doc_limit=20, track_usage=None,
fetch_limit=20, track_usage=None,
):
self.rag = rag
self.workspace = workspace
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
self.fetch_limit = fetch_limit
self.track_usage = track_usage
async def extract_concepts(self, query):
@ -91,7 +99,7 @@ class Query:
# Query chunk matches for each concept concurrently
per_concept_limit = max(
1, self.doc_limit // len(vectors)
1, self.fetch_limit // len(vectors)
)
async def query_concept(vec):
@ -140,6 +148,7 @@ class DocumentRag:
def __init__(
self, prompt_client, embeddings_client, doc_embeddings_client,
fetch_chunk,
reranker_client=None,
verbose=False,
):
@ -150,12 +159,16 @@ class DocumentRag:
self.doc_embeddings_client = doc_embeddings_client
self.fetch_chunk = fetch_chunk
# Optional cross-encoder reranker. When None, the retrieval path is
# byte-identical to the pre-reranker behaviour.
self.reranker_client = reranker_client
if self.verbose:
logger.debug("DocumentRag initialized")
async def query(
self, query, workspace="default", collection="default",
doc_limit=20, streaming=False, chunk_callback=None,
doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None,
explain_callback=None, save_answer_callback=None,
):
"""
@ -165,7 +178,10 @@ class DocumentRag:
query: The query string
workspace: Workspace for isolation (also scopes chunk lookup)
collection: Collection identifier
doc_limit: Max chunks to retrieve
doc_limit: Chunks selected into the synthesis prompt (after rerank)
fetch_limit: Candidate pool fetched from the vector store before
reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a
reranker is wired, else doc_limit).
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for explainability
@ -197,6 +213,7 @@ class DocumentRag:
q_uri = docrag_question_uri(session_id)
gnd_uri = docrag_grounding_uri(session_id)
exp_uri = docrag_exploration_uri(session_id)
foc_uri = docrag_focus_uri(session_id)
syn_uri = docrag_synthesis_uri(session_id)
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
@ -209,10 +226,21 @@ class DocumentRag:
)
await explain_callback(q_triples, q_uri)
# Resolve the candidate-pool size fetched from the vector store. When a
# reranker is wired, honour an explicit fetch_limit; if unset, fall back
# to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit,
# else the rerank could not fill the prompt. Without a reranker, fetch
# doc_limit as before (byte-identical behaviour).
if self.reranker_client is not None:
fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit)
fetch_count = max(fl, doc_limit)
else:
fetch_count = doc_limit
q = Query(
rag=self, workspace=workspace, collection=collection,
verbose=self.verbose,
doc_limit=doc_limit, track_usage=track_usage,
fetch_limit=fetch_count, track_usage=track_usage,
)
# Extract concepts from query (grounding step)
@ -235,6 +263,7 @@ class DocumentRag:
docs, chunk_ids = await q.get_docs(concepts)
# Emit exploration explainability after chunks retrieved
# (full candidate set, before any reranking)
if explain_callback:
exp_triples = set_graph(
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
@ -242,6 +271,45 @@ class DocumentRag:
)
await explain_callback(exp_triples, exp_uri)
# Optional cross-encoder reranking pass between retrieval and
# synthesis. Mirrors GraphRAG's reranker usage but with a single
# query (the question). When no reranker is wired, this block is
# skipped entirely and behaviour is byte-identical to before.
reranked = False
if self.reranker_client is not None and docs:
results = await self.reranker_client.rerank(
queries=[{"id": "0", "text": query}],
documents=[
{"id": str(i), "text": d} for i, d in enumerate(docs)
],
# Narrow the over-fetched candidate pool down to the final
# doc_limit requested for synthesis.
limit=doc_limit,
)
# results are sorted desc by score and truncated to limit by the
# reranker service, so order gives the surviving top-N directly.
order = [int(r.document_id) for r in results]
docs = [docs[i] for i in order]
chunk_ids = [chunk_ids[i] for i in order]
reranked = True
# Emit chunk-selection (focus) explainability: surviving chunks
# with their cross-encoder scores, derived from exploration.
if explain_callback:
selected_chunks_with_scores = [
{"chunk_id": chunk_ids[i], "score": r.score}
for i, r in enumerate(results)
]
foc_triples = set_graph(
docrag_chunk_selection_triples(
foc_uri, exp_uri,
selected_chunks_with_scores, session_id,
),
GRAPH_RETRIEVAL
)
await explain_callback(foc_triples, foc_uri)
if self.verbose:
logger.debug("Invoking LLM...")
logger.debug(f"Documents: {docs}")
@ -291,9 +359,15 @@ class DocumentRag:
logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None
# When reranking ran, synthesis derives from the focus (the
# reranked chunks actually fed to the LLM), as GraphRAG always does.
# When no reranker is wired, there is no focus stage, so synthesis
# derives from exploration (the unchanged no-op lineage) - a
# deliberate divergence from GraphRAG's always-on focus.
syn_parent = foc_uri if reranked else exp_uri
syn_triples = set_graph(
docrag_synthesis_triples(
syn_uri, exp_uri,
syn_uri, syn_parent,
document_id=synthesis_doc_id,
in_token=synthesis_result.in_token if synthesis_result else None,
out_token=synthesis_result.out_token if synthesis_result else None,
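
The candidate-pool sizing and post-rerank reordering introduced in this file can be sketched as two small pure functions. This is an illustration under assumptions, not the commit's code: `resolve_fetch_count` and `apply_rerank_order` are hypothetical names, and the reranker output is reduced to a list of document-id strings.

```python
OVERFETCH_FACTOR = 3  # mirrors the constant added in the diff


def resolve_fetch_count(doc_limit, fetch_limit=0, has_reranker=False):
    """How many chunks to fetch from the vector store."""
    if not has_reranker:
        # No reranker: fetch exactly what the prompt will use.
        return doc_limit
    # Explicit fetch_limit wins; 0 means "derive from doc_limit".
    wanted = fetch_limit or (OVERFETCH_FACTOR * doc_limit)
    # Never fetch fewer than doc_limit, or the prompt could not be filled.
    return max(wanted, doc_limit)


def apply_rerank_order(docs, chunk_ids, ranked_ids):
    """Reorder and truncate docs/chunk_ids to the reranker's output order.

    ranked_ids are the stringified positions sent to the reranker,
    already sorted best-first and cut to doc_limit.
    """
    order = [int(i) for i in ranked_ids]
    return [docs[i] for i in order], [chunk_ids[i] for i in order]


print(resolve_fetch_count(5, has_reranker=True))
print(apply_rerank_order(["d0", "d1", "d2"], ["c0", "c1", "c2"], ["2", "0"]))
```

Using list positions as document ids, as the diff does, keeps `docs` and `chunk_ids` aligned after reranking without a separate lookup table.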


@ -13,6 +13,7 @@ from . document_rag import DocumentRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import DocumentEmbeddingsClientSpec
from ... base import RerankerClientSpec
from ... base import LibrarianSpec
# Module logger
@ -28,14 +29,21 @@ class Processor(FlowProcessor):
doc_limit = params.get("doc_limit", 5)
# Instance-default candidate-pool size fetched before cross-encoder
# reranking; the rerank step narrows it back down to doc_limit for the
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
fetch_limit = params.get("fetch_limit", 0)
super(Processor, self).__init__(
**params | {
"id": id,
"doc_limit": doc_limit,
"fetch_limit": fetch_limit,
}
)
self.doc_limit = doc_limit
self.fetch_limit = fetch_limit
self.register_specification(
ConsumerSpec(
@ -66,6 +74,13 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
RerankerClientSpec(
request_name = "reranker-request",
response_name = "reranker-response",
)
)
self.register_specification(
ProducerSpec(
name = "response",
@ -105,6 +120,7 @@ class Processor(FlowProcessor):
doc_embeddings_client = flow("document-embeddings-request"),
prompt_client = flow("prompt-request"),
fetch_chunk = fetch_chunk,
reranker_client = flow("reranker-request"),
verbose=True,
)
@ -113,6 +129,13 @@ class Processor(FlowProcessor):
else:
doc_limit = self.doc_limit
# Candidate-pool size: per-request override, else the instance
# default; 0 lets the core derive it from doc_limit.
if v.fetch_limit:
fetch_limit = v.fetch_limit
else:
fetch_limit = self.fetch_limit
async def send_explainability(triples, explain_id):
await flow("explainability").send(Triples(
metadata=Metadata(
@ -163,6 +186,7 @@ class Processor(FlowProcessor):
workspace=flow.workspace,
collection=v.collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
streaming=True,
chunk_callback=send_chunk,
explain_callback=send_explainability,
@ -188,6 +212,7 @@ class Processor(FlowProcessor):
workspace=flow.workspace,
collection=v.collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
explain_callback=send_explainability,
save_answer_callback=save_answer,
)
@ -243,6 +268,15 @@ class Processor(FlowProcessor):
help=f'Default document fetch limit (default: 10)'
)
parser.add_argument(
'--fetch-limit',
type=int,
default=0,
help='Candidate chunks to fetch from the vector store and rerank '
'before keeping the top doc-limit for the LLM '
'(default: derive from doc-limit)'
)
def run():
Processor.launch(default_ident, __doc__)


@ -120,7 +120,7 @@ class Query:
def __init__(
self, rag, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
max_path_length=2, track_usage=None,
max_path_length=2, edge_limit=25, track_usage=None,
):
self.rag = rag
self.collection = collection
@ -129,6 +129,7 @@ class Query:
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
self.edge_limit = edge_limit
self.track_usage = track_usage
async def extract_concepts(self, query):
@ -217,12 +218,9 @@ class Query:
logger.debug(f" {ent}")
return entities, concepts
async def maybe_label(self, e):
# The label cache lives on a per-request GraphRag instance — no
# cross-request isolation concern. The collection prefix keeps
# entries from different collections distinct within one request.
cache_key = f"{self.collection}:{e}"
cached_label = self.rag.label_cache.get(cache_key)
@ -244,11 +242,10 @@ class Query:
return label
async def execute_batch_triple_queries(self, entities, limit_per_entity):
"""Execute triple queries for multiple entities concurrently using streaming"""
"""Execute triple queries for multiple entities concurrently."""
tasks = []
for entity in entities:
# Create concurrent streaming tasks for all 3 query types per entity
tasks.extend([
self.rag.triples_client.query_stream(
s=entity, p=None, o=None,
@ -270,10 +267,8 @@ class Query:
)
])
# Execute all queries concurrently
results = await asyncio.gather(*tasks, return_exceptions=True)
# Combine all results
all_triples = []
for result in results:
if not isinstance(result, Exception) and result is not None:
@ -281,168 +276,151 @@ class Query:
return all_triples
async def follow_edges_batch(self, entities, max_depth):
"""Optimized iterative graph traversal with batching.
Returns:
tuple: (subgraph, term_map) where subgraph is a set of
(str, str, str) tuples and term_map maps each string tuple
to its original (Term, Term, Term) for type-preserving
provenance.
"""
visited = set()
current_level = set(entities)
subgraph = set()
term_map = {} # (str, str, str) -> (Term, Term, Term)
for depth in range(max_depth):
if not current_level or len(subgraph) >= self.max_subgraph_size:
break
# Filter out already visited entities
unvisited_entities = [e for e in current_level if e not in visited]
if not unvisited_entities:
break
# Batch query all unvisited entities at current level
triples = await self.execute_batch_triple_queries(
unvisited_entities, self.triple_limit
)
# Process results and collect next level entities
next_level = set()
for triple in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
subgraph.add(triple_tuple)
term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o))
# Collect entities for next level (only from s and o positions)
if depth < max_depth - 1: # Don't collect for final depth
s, p, o = triple_tuple
if s not in visited:
next_level.add(s)
if o not in visited:
next_level.add(o)
# Stop if subgraph size limit reached
if len(subgraph) >= self.max_subgraph_size:
return subgraph, term_map
# Update for next iteration
visited.update(current_level)
current_level = next_level
return subgraph, term_map
async def follow_edges(self, ent, subgraph, path_length):
"""Legacy method - replaced by follow_edges_batch"""
# Maintain backward compatibility with early termination checks
if path_length <= 0:
return
if len(subgraph) >= self.max_subgraph_size:
return
# For backward compatibility, convert to new approach
batch_result, _ = await self.follow_edges_batch([ent], path_length)
subgraph.update(batch_result)
async def get_subgraph(self, query):
"""
Get subgraph by extracting concepts, finding entities, and traversing.
Returns:
tuple: (subgraph, term_map, entities, concepts) where subgraph is
a list of (s, p, o) string tuples, term_map maps each string
tuple to its original (Term, Term, Term), entities is the seed
entity list, and concepts is the extracted concept list.
"""
entities, concepts = await self.get_entities(query)
if self.verbose:
logger.debug("Getting subgraph...")
# Use optimized batch traversal instead of sequential processing
subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length)
return list(subgraph), term_map, entities, concepts
async def resolve_labels_batch(self, entities):
"""Resolve labels for multiple entities in parallel"""
tasks = []
for entity in entities:
tasks.append(self.maybe_label(entity))
"""Resolve labels for multiple entities in parallel."""
tasks = [self.maybe_label(entity) for entity in entities]
return await asyncio.gather(*tasks, return_exceptions=True)
async def get_labelgraph(self, query):
"""
Get subgraph with labels resolved for display.
async def hop_and_filter(self, seed_entities, concepts):
"""Iterative hop-and-filter graph traversal with cross-encoder.
At each hop:
1. Retrieve all edges one hop from the frontier.
2. Resolve labels and represent each edge as "{p} {o}".
3. Score edges against concepts using the cross-encoder.
4. Select the top edges; their target nodes become the next
frontier.
Returns:
tuple: (labeled_edges, uri_map, entities, concepts) where:
- labeled_edges: list of (label_s, label_p, label_o) tuples
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
- entities: list of seed entity URI strings
- concepts: list of concept strings extracted from query
tuple: (selected_edges, uri_map, edge_metadata) where:
- selected_edges: list of (label_s, label_p, label_o)
- uri_map: dict mapping edge_id -> (Term, Term, Term)
- edge_metadata: dict mapping edge_id -> {concept, score}
"""
subgraph, term_map, entities, concepts = await self.get_subgraph(query)
all_selected_edges = []
uri_map = {}
edge_metadata = {}
frontier = set(seed_entities)
visited_entities = set()
seen_edges = set()
# Filter out label triples
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
for hop in range(self.max_path_length):
if not frontier:
break
# Collect all unique entities that need label resolution
entities_to_resolve = set()
for s, p, o in filtered_subgraph:
entities_to_resolve.update([s, p, o])
unvisited = [e for e in frontier if e not in visited_entities]
if not unvisited:
break
# Batch resolve labels for all entities in parallel
entity_list = list(entities_to_resolve)
resolved_labels = await self.resolve_labels_batch(entity_list)
if self.verbose:
logger.debug(
f"Hop {hop + 1}: {len(unvisited)} frontier entities"
)
# Create entity-to-label mapping
label_map = {}
for entity, label in zip(entity_list, resolved_labels):
if not isinstance(label, Exception):
label_map[entity] = label
else:
label_map[entity] = entity # Fallback to entity itself
# Apply labels to subgraph and build URI mapping
labeled_edges = []
uri_map = {} # Maps edge_id of labeled edge -> original Term triple
for s, p, o in filtered_subgraph:
labeled_triple = (
label_map.get(s, s),
label_map.get(p, p),
label_map.get(o, o)
# Retrieve edges one hop from frontier
triples = await self.execute_batch_triple_queries(
unvisited, self.triple_limit,
)
labeled_edges.append(labeled_triple)
# Map from labeled edge ID to original Terms (preserving types)
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o))
# Deduplicate and filter already-seen edges
hop_triples = []
hop_term_map = {}
for triple in triples:
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
if triple_tuple[1] == LABEL:
continue
if triple_tuple in seen_edges:
continue
seen_edges.add(triple_tuple)
hop_triples.append(triple_tuple)
hop_term_map[triple_tuple] = (
to_term(triple.s), to_term(triple.p), to_term(triple.o),
)
labeled_edges = labeled_edges[0:self.max_subgraph_size]
if not hop_triples:
visited_entities.update(frontier)
break
if self.verbose:
logger.debug("Subgraph:")
for edge in labeled_edges:
logger.debug(f" {str(edge)}")
if self.verbose:
logger.debug(
f"Hop {hop + 1}: {len(hop_triples)} candidate edges"
)
if self.verbose:
logger.debug("Done.")
# Resolve labels for all entities in hop edges
entities_to_resolve = set()
for s, p, o in hop_triples:
entities_to_resolve.update([s, p, o])
return labeled_edges, uri_map, entities, concepts
entity_list = list(entities_to_resolve)
resolved = await self.resolve_labels_batch(entity_list)
label_map = {}
for entity, label in zip(entity_list, resolved):
if not isinstance(label, Exception):
label_map[entity] = label
else:
label_map[entity] = entity
# Build labeled edges and documents for cross-encoder
labeled_hop = []
for s, p, o in hop_triples:
ls = label_map.get(s, s)
lp = label_map.get(p, p)
lo = label_map.get(o, o)
labeled_hop.append((ls, lp, lo))
documents = [
{"id": str(i), "text": f"{lp} {lo}"}
for i, (ls, lp, lo) in enumerate(labeled_hop)
]
queries = [
{"id": str(i), "text": c}
for i, c in enumerate(concepts)
]
# Score with cross-encoder
results = await self.rag.reranker_client.rerank(
queries=queries,
documents=documents,
limit=self.edge_limit,
)
# Collect selected edges and metadata
next_frontier = set()
for r in results:
idx = int(r.document_id)
ls, lp, lo = labeled_hop[idx]
s, p, o = hop_triples[idx]
eid = edge_id(ls, lp, lo)
all_selected_edges.append((ls, lp, lo))
uri_map[eid] = hop_term_map[(s, p, o)]
edge_metadata[eid] = {
"concept": concepts[int(r.query_id)],
"score": r.score,
}
# Target nodes become next frontier
next_frontier.add(s)
next_frontier.add(o)
if self.verbose:
logger.debug(
f"Hop {hop + 1}: selected {len(results)} edges"
)
visited_entities.update(frontier)
frontier = next_frontier - visited_entities
return all_selected_edges, uri_map, edge_metadata
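
The hop-and-filter loop above can be illustrated on an in-memory edge list. The sketch below is a simplification under stated assumptions, not the commit's implementation: it is synchronous, skips label resolution, matches frontier entities only in the subject and object positions, and uses a toy word-overlap function (`overlap_score`, hypothetical) in place of the cross-encoder.

```python
def overlap_score(query, text):
    """Toy stand-in for a cross-encoder: share of query words found in text."""
    q = set(query.lower().split())
    t = set(text.lower().split())
    return len(q & t) / len(q) if q else 0.0


def hop_and_filter(edges, seeds, concepts, max_hops=2, edge_limit=2):
    """Return [(score, (s, p, o)), ...] selected hop by hop."""
    frontier = set(seeds)
    visited, seen, selected = set(), set(), []

    for _hop in range(max_hops):
        unvisited = frontier - visited
        if not unvisited:
            break

        # Edges one hop from the frontier that were not considered before.
        candidates = [
            e for e in edges
            if e not in seen and (e[0] in unvisited or e[2] in unvisited)
        ]
        if not candidates:
            break
        seen.update(candidates)

        # Represent each edge as "{p} {o}" and keep its best concept score.
        scored = [
            (max((overlap_score(c, f"{p} {o}") for c in concepts), default=0.0),
             (s, p, o))
            for s, p, o in candidates
        ]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        kept = scored[:edge_limit]
        selected.extend(kept)

        # Endpoints of the kept edges form the next frontier.
        visited |= frontier
        frontier = {n for _score, (s, _p, o) in kept for n in (s, o)} - visited

    return selected


edges = [
    ("alice", "works at", "acme"),
    ("alice", "likes", "tea"),
    ("acme", "located in", "paris"),
    ("acme", "sells", "anvils"),
    ("paris", "capital of", "france"),
]
result = hop_and_filter(
    edges, ["alice"], ["works at acme", "located in paris"], edge_limit=1,
)
for score, edge in result:
    print(score, edge)
```

Because only the kept edges feed the next frontier, low-scoring branches such as `likes tea` are never expanded, which is what bounds the traversal without a separate subgraph-size cut-off.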
async def trace_source_documents(self, edge_uris):
"""
Trace selected edges back to their source documents via provenance.
Follows the chain: edge → subgraph (via tg:contains) → chunk →
page → document (via prov:wasDerivedFrom), all in urn:graph:source.
Follows the chain: edge -> subgraph (via tg:contains) -> chunk ->
page -> document (via prov:wasDerivedFrom), all in urn:graph:source.
Args:
edge_uris: List of (s, p, o) URI string tuples
@ -453,7 +431,6 @@ class Query:
# Step 1: Find subgraphs containing these edges via tg:contains
subgraph_tasks = []
for s, p, o in edge_uris:
# s, p, o may be Term objects (preserving types) or strings
s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s)
p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p)
o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o)
@ -487,12 +464,10 @@ class Query:
return []
# Step 2: Walk prov:wasDerivedFrom chain to find documents
# Each level: query ?entity prov:wasDerivedFrom ?parent
# Stop when we find entities typed tg:Document
current_uris = subgraph_uris
doc_uris = set()
for depth in range(4): # Max depth: subgraph → chunk → page → doc
for depth in range(4):
if not current_uris:
break
@ -509,7 +484,6 @@ class Query:
*derivation_tasks, return_exceptions=True
)
# URIs with no parent are root documents
next_uris = set()
for uri, result in zip(current_uris, derivation_results):
if isinstance(result, Exception) or not result:
@ -524,7 +498,6 @@ class Query:
return []
# Step 3: Get all document metadata properties
# Skip structural predicates that aren't useful context
SKIP_PREDICATES = {
PROV_WAS_DERIVED_FROM,
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
@ -565,7 +538,7 @@ class GraphRag:
def __init__(
self, prompt_client, embeddings_client, graph_embeddings_client,
triples_client, verbose=False,
triples_client, reranker_client, verbose=False,
):
self.verbose = verbose
@ -574,9 +547,8 @@ class GraphRag:
self.embeddings_client = embeddings_client
self.graph_embeddings_client = graph_embeddings_client
self.triples_client = triples_client
self.reranker_client = reranker_client
# Replace simple dict with LRU cache with TTL
# CRITICAL: This cache only lives for one request due to per-request instantiation
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
if self.verbose:
@ -585,33 +557,12 @@ class GraphRag:
async def query(
self, query, collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
max_path_length = 2, edge_limit = 25,
streaming = False,
chunk_callback = None,
explain_callback = None, save_answer_callback = None,
parent_uri = "",
):
"""
Execute a GraphRAG query with real-time explainability tracking.
Args:
query: The query string
collection: Collection identifier
entity_limit: Max entities to retrieve
triple_limit: Max triples per entity
max_subgraph_size: Max edges in subgraph
max_path_length: Max hops from seed entities
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
edge_limit: Max edges after LLM scoring
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for real-time explainability
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
Returns:
tuple: (answer_text, usage) where usage is a dict with
in_token, out_token, model
"""
# Accumulate token usage across all prompt calls
total_in = 0
total_out = 0
@ -638,7 +589,9 @@ class GraphRag:
foc_uri = make_focus_uri(session_id)
syn_uri = make_synthesis_uri(session_id)
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
timestamp = datetime.now(timezone.utc).isoformat().replace(
"+00:00", "Z",
)
# Emit question explainability immediately
if explain_callback:
@ -657,10 +610,12 @@ class GraphRag:
triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
track_usage = track_usage,
)
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
# Step 1: Extract concepts and find seed entities
seed_entities, concepts = await q.get_entities(query)
# Emit grounding explain after concept extraction
if explain_callback:
@ -676,11 +631,16 @@ class GraphRag:
)
await explain_callback(gnd_triples, gnd_uri)
# Emit exploration explain after graph retrieval completes
# Step 2: Iterative hop-and-filter with cross-encoder
selected_edges, uri_map, edge_metadata = await q.hop_and_filter(
seed_entities, concepts,
)
# Emit exploration explain
if explain_callback:
exp_triples = set_graph(
exploration_triples(
exp_uri, gnd_uri, len(kg),
exp_uri, gnd_uri, len(selected_edges),
entities=seed_entities,
),
GRAPH_RETRIEVAL
@ -688,235 +648,63 @@ class GraphRag:
await explain_callback(exp_triples, exp_uri)
if self.verbose:
logger.debug("Invoking LLM...")
logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}")
# Semantic pre-filter: reduce edges before expensive LLM scoring
if edge_score_limit > 0 and len(kg) > edge_score_limit:
if self.verbose:
logger.debug(f"Selected {len(selected_edges)} edges")
for s, p, o in selected_edges:
eid = edge_id(s, p, o)
meta = edge_metadata.get(eid, {})
logger.debug(
f"Semantic pre-filter: {len(kg)} edges > "
f"limit {edge_score_limit}, filtering..."
f" {meta.get('score', 0):.4f} "
f"[{meta.get('concept', '')}] "
f"{s} | {p} | {o}"
)
# Embed edge descriptions: "subject, predicate, object"
edge_descriptions = [
f"{s}, {p}, {o}" for s, p, o in kg
]
# Embed concepts and edge descriptions concurrently
concept_embed_task = self.embeddings_client.embed(concepts)
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
concept_vectors, edge_vectors = await asyncio.gather(
concept_embed_task, edge_embed_task
)
# Score each edge by max cosine similarity to any concept
def cosine_similarity(a, b):
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
edge_scores = []
for i, edge_vec in enumerate(edge_vectors):
max_sim = max(
cosine_similarity(edge_vec, cv)
for cv in concept_vectors
)
edge_scores.append((max_sim, i))
# Sort by similarity descending and keep top edge_score_limit
edge_scores.sort(reverse=True)
keep_indices = set(
idx for _, idx in edge_scores[:edge_score_limit]
)
# Filter kg and rebuild uri_map
filtered_kg = []
filtered_uri_map = {}
for i, (s, p, o) in enumerate(kg):
if i in keep_indices:
filtered_kg.append((s, p, o))
eid = edge_id(s, p, o)
if eid in uri_map:
filtered_uri_map[eid] = uri_map[eid]
if self.verbose:
logger.debug(
f"Semantic pre-filter kept {len(filtered_kg)} "
f"of {len(kg)} edges"
)
kg = filtered_kg
uri_map = filtered_uri_map
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
edge_map = {}
edges_with_ids = []
for s, p, o in kg:
eid = edge_id(s, p, o)
edge_map[eid] = (s, p, o)
edges_with_ids.append({
"id": eid,
"s": s,
"p": p,
"o": o
})
if self.verbose:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1a: Edge Scoring - LLM scores edges for relevance
scoring_result = await self.prompt_client.prompt(
"kg-edge-scoring",
variables={
"query": query,
"knowledge": edges_with_ids
}
)
track_usage(scoring_result)
if self.verbose:
logger.debug(f"Edge scoring result: {scoring_result}")
# Parse scoring response (jsonl) to get edge IDs with scores
scored_edges = []
for obj in scoring_result.objects or []:
if isinstance(obj, dict) and "id" in obj and "score" in obj:
try:
score = int(obj["score"])
except (ValueError, TypeError):
score = 0
scored_edges.append({"id": obj["id"], "score": score})
# Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit]
selected_ids = {e["id"] for e in top_edges}
if self.verbose:
logger.debug(
f"Scored {len(scored_edges)} edges, "
f"selected top {len(selected_ids)}"
)
# Filter to selected edges
selected_edges = []
for eid in selected_ids:
if eid in edge_map:
selected_edges.append(edge_map[eid])
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
selected_edges_with_ids = [
{"id": eid, "s": s, "p": p, "o": o}
for eid in selected_ids
if eid in edge_map
for s, p, o in [edge_map[eid]]
]
# Collect selected edge URIs for document tracing
# Step 3: Document tracing
selected_edge_uris = [
uri_map[eid]
for eid in selected_ids
if eid in uri_map
uri_map[edge_id(s, p, o)]
for s, p, o in selected_edges
if edge_id(s, p, o) in uri_map
]
# Run reasoning and document tracing concurrently
async def _get_reasoning():
result = await self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
track_usage(result)
return result
reasoning_task = _get_reasoning()
doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_result, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True
source_documents = await q.trace_source_documents(
selected_edge_uris,
)
# Handle exceptions from gather
if isinstance(reasoning_result, Exception):
logger.warning(
f"Edge reasoning failed: {reasoning_result}"
)
reasoning_result = None
if isinstance(source_documents, Exception):
logger.warning(
f"Document tracing failed: {source_documents}"
)
source_documents = []
if self.verbose:
logger.debug(f"Edge reasoning result: {reasoning_result}")
# Parse reasoning response (jsonl) and build explainability data
reasoning_map = {}
if reasoning_result is not None:
for obj in reasoning_result.objects or []:
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
# Build focus explainability data with cross-encoder metadata
selected_edges_with_reasoning = []
for eid in selected_ids:
for s, p, o in selected_edges:
eid = edge_id(s, p, o)
if eid in uri_map:
uri_s, uri_p, uri_o = uri_map[eid]
meta = edge_metadata.get(eid, {})
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": reasoning_map.get(eid, ""),
"concept": meta.get("concept", ""),
"score": meta.get("score", 0),
})
if self.verbose:
logger.debug(f"Filtered to {len(selected_edges)} edges")
# Emit focus explain after edge selection completes
# Emit focus explain
if explain_callback:
# Sum scoring + reasoning token usage for focus event
focus_in = 0
focus_out = 0
focus_model = None
for r in [scoring_result, reasoning_result]:
if r is not None:
if r.in_token is not None:
focus_in += r.in_token
if r.out_token is not None:
focus_out += r.out_token
if r.model is not None:
focus_model = r.model
foc_triples = set_graph(
focus_triples(
foc_uri, exp_uri, selected_edges_with_reasoning, session_id,
in_token=focus_in or None,
out_token=focus_out or None,
model=focus_model,
foc_uri, exp_uri,
selected_edges_with_reasoning, session_id,
),
GRAPH_RETRIEVAL
)
await explain_callback(foc_triples, foc_uri)
# Step 2: Synthesis - LLM generates answer from selected edges only
# Step 4: Synthesis
selected_edge_dicts = [
{"s": s, "p": p, "o": o}
for s, p, o in selected_edges
]
# Add source document metadata as knowledge edges
for s, p, o in source_documents:
selected_edge_dicts.append({
"s": s, "p": p, "o": o,
@ -928,7 +716,6 @@ class GraphRag:
}
if streaming and chunk_callback:
# Accumulate chunks for answer storage while forwarding to callback
accumulated_chunks = []
async def accumulating_callback(chunk, end_of_stream):
@ -942,7 +729,6 @@ class GraphRag:
chunk_callback=accumulating_callback
)
track_usage(synthesis_result)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
synthesis_result = await self.prompt_client.prompt(
@ -955,29 +741,42 @@ class GraphRag:
if self.verbose:
logger.debug("Query processing complete")
# Emit synthesis explain after synthesis completes
# Emit synthesis explain
if explain_callback:
synthesis_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian
if save_answer_callback and answer_text:
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
try:
await save_answer_callback(synthesis_doc_id, answer_text)
if self.verbose:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
logger.debug(
f"Saved answer to librarian: "
f"{synthesis_doc_id}"
)
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
logger.warning(
f"Failed to save answer to librarian: {e}"
)
synthesis_doc_id = None
syn_triples = set_graph(
synthesis_triples(
syn_uri, foc_uri,
document_id=synthesis_doc_id,
in_token=synthesis_result.in_token if synthesis_result else None,
out_token=synthesis_result.out_token if synthesis_result else None,
model=synthesis_result.model if synthesis_result else None,
in_token=(
synthesis_result.in_token
if synthesis_result else None
),
out_token=(
synthesis_result.out_token
if synthesis_result else None
),
model=(
synthesis_result.model
if synthesis_result else None
),
),
GRAPH_RETRIEVAL
)
@ -993,4 +792,3 @@ class GraphRag:
}
return resp, usage
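
Reviewer note: the `asyncio.gather(..., return_exceptions=True)` handling shown earlier in this file's diff can be tried in isolation. The sketch below is a minimal illustration, not the project's code: `get_reasoning`, `trace_documents`, and `run_both` are hypothetical stand-ins, and the failure is simulated. It shows only that one failing task leaves the other task's result usable.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def get_reasoning():
    # Hypothetical stand-in for the LLM reasoning call; fails on purpose.
    raise RuntimeError("reasoning backend unavailable")


async def trace_documents(edge_uris):
    # Hypothetical stand-in for source document tracing.
    return [(uri, "derived-from", "doc-1") for uri in edge_uris]


async def run_both(edge_uris):
    # With return_exceptions=True, gather returns raised exceptions as
    # ordinary result values instead of propagating the first one.
    reasoning, documents = await asyncio.gather(
        get_reasoning(),
        trace_documents(edge_uris),
        return_exceptions=True,
    )

    # Replace each failure with a neutral fallback so later steps can run.
    if isinstance(reasoning, Exception):
        logger.warning(f"Edge reasoning failed: {reasoning}")
        reasoning = None
    if isinstance(documents, Exception):
        logger.warning(f"Document tracing failed: {documents}")
        documents = []

    return reasoning, documents


result = asyncio.run(run_both(["urn:edge:1"]))
```

Here `result` holds `None` for the failed reasoning step and the traced triples for the step that succeeded, which matches the fallback values used in the diff (`None` and `[]`).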


@ -13,6 +13,7 @@ from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
from ... base import RerankerClientSpec
from ... base import LibrarianSpec
# Module logger
@ -32,7 +33,6 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
edge_score_limit = params.get("edge_score_limit", 30)
edge_limit = params.get("edge_limit", 25)
super(Processor, self).__init__(
@ -43,7 +43,6 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"edge_score_limit": edge_score_limit,
"edge_limit": edge_limit,
}
)
@ -52,7 +51,6 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.default_edge_score_limit = edge_score_limit
self.default_edge_limit = edge_limit
# Workspace isolation is enforced by the flow layer (flow.workspace).
@ -96,6 +94,13 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
RerankerClientSpec(
request_name = "reranker-request",
response_name = "reranker-response",
)
)
self.register_specification(
ProducerSpec(
name = "response",
@ -163,6 +168,7 @@ class Processor(FlowProcessor):
graph_embeddings_client=flow("graph-embeddings-request"),
triples_client=flow("triples-request"),
prompt_client=flow("prompt-request"),
reranker_client=flow("reranker-request"),
verbose=True,
)
@ -186,11 +192,6 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
if v.edge_score_limit:
edge_score_limit = v.edge_score_limit
else:
edge_score_limit = self.default_edge_score_limit
if v.edge_limit:
edge_limit = v.edge_limit
else:
@ -225,7 +226,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit,
streaming = True,
chunk_callback = send_chunk,
@ -241,7 +242,7 @@ class Processor(FlowProcessor):
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_score_limit = edge_score_limit,
edge_limit = edge_limit,
explain_callback = send_explainability,
save_answer_callback = save_answer,
@ -338,18 +339,11 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)'
)
parser.add_argument(
'--edge-score-limit',
type=int,
default=30,
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
)
parser.add_argument(
'--edge-limit',
type=int,
default=25,
help=f'Max edges after LLM scoring (default: 25)'
help=f'Max edges selected per hop by cross-encoder (default: 25)'
)
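
Reviewer note: with `--edge-score-limit` removed, `--edge-limit` is the only remaining edge cap in this part of the diff. The sketch below illustrates, under assumptions, how a command-line default and a per-request override could combine. `build_parser` and `resolve_edge_limit` are hypothetical helpers written for this note and do not appear in the commit; only the flag name, type, default, and help text come from the diff.

```python
import argparse


def build_parser():
    # Mirrors the --edge-limit argument from the diff; the parser itself
    # is a standalone stand-in for the processor's argument setup.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--edge-limit',
        type=int,
        default=25,
        help='Max edges selected per hop by cross-encoder (default: 25)',
    )
    return parser


def resolve_edge_limit(request_value, default):
    # Hypothetical helper: a falsy request value (None or 0) falls back
    # to the configured default, as the if/else blocks in the diff do.
    return request_value if request_value else default


args = build_parser().parse_args(['--edge-limit', '10'])
defaults = build_parser().parse_args([])
```

One consequence of the truthiness check is that a request cannot set the limit to `0`; it is treated the same as an unset value.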
# Note: Explainability triples are now stored in the request's collection