mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 15:01:00 +02:00
Compare commits
42 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
508d0bb5c1 | ||
|
|
c05296376e | ||
|
|
f04ae5331d | ||
|
|
db7fdbc652 | ||
|
|
4aaa1ce915 | ||
|
|
9cf7dcb578 | ||
|
|
6c9a545a06 | ||
|
|
f18d48dc39 | ||
|
|
6887076ce0 | ||
|
|
55e2a2a3ce | ||
|
|
11ca7c89c4 | ||
|
|
656ca430b9 | ||
|
|
f20b50cfb2 | ||
|
|
04c5921687 | ||
|
|
01cc8dbc64 | ||
|
|
1aa9549912 | ||
|
|
5cb4f83afa | ||
|
|
0a828379be | ||
|
|
16f8cfd972 | ||
|
|
a3df4f62bb | ||
|
|
09b8a1d347 | ||
|
|
fa264ded46 | ||
|
|
cae931409a | ||
|
|
6b0475e315 | ||
|
|
cb0ad1a450 | ||
|
|
fc0ecc770a | ||
|
|
345da375b1 | ||
|
|
0ba1eeeda0 | ||
|
|
eb1e38d7d0 | ||
|
|
b8770a6005 | ||
|
|
28802a644a | ||
|
|
8797d9d9ff | ||
|
|
627c669097 | ||
|
|
8b0619e5d8 | ||
|
|
e3f9f8c357 | ||
|
|
81d57826c8 | ||
|
|
28a51c244f | ||
|
|
fa5ebe2393 | ||
|
|
97453d9b83 | ||
|
|
6dfa47aac8 | ||
|
|
dcee842455 | ||
|
|
36eadbda3a |
93 changed files with 7994 additions and 2215 deletions
2
.github/workflows/pull-request.yaml
vendored
2
.github/workflows/pull-request.yaml
vendored
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup packages
|
||||
run: make update-package-versions VERSION=2.5.999
|
||||
run: make update-package-versions VERSION=2.6.999
|
||||
|
||||
- name: Setup environment
|
||||
run: python3 -m venv env
|
||||
|
|
|
|||
218
README.dev-install.md
Normal file
218
README.dev-install.md
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
# TrustGraph Developer Install Guide
|
||||
|
||||
A guided installer that gets TrustGraph running locally in a single
|
||||
command. It detects your hardware, recommends an LLM backend, installs
|
||||
missing prerequisites, runs the test suite, generates a compose deployment,
|
||||
starts the stack, and opens the Workbench UI.
|
||||
|
||||
> **macOS only.** This installer has only been tested on macOS. If you are
|
||||
> on Linux or Windows, use the standard docker-compose / podman-compose
|
||||
> installation instructions instead.
|
||||
|
||||
## Quick start
|
||||
|
||||
```bash
|
||||
./install_trustgraph.sh
|
||||
```
|
||||
|
||||
The installer walks you through each step interactively. When it finishes,
|
||||
the Workbench UI opens at `http://localhost:8888` and the API gateway is
|
||||
available at `http://localhost:8088/`.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
The installer checks for these and offers to install any that are missing
|
||||
(via Homebrew):
|
||||
|
||||
- **Python 3** with venv support
|
||||
- **Node.js / npx** (drives the `@trustgraph/config` deployment generator)
|
||||
- **Docker** (with Compose) or **Podman** (with podman-compose)
|
||||
- **curl** and **unzip**
|
||||
- **Ollama** (only if you choose local LLMs)
|
||||
|
||||
The installer can also launch Docker Desktop or the Ollama app for you if
|
||||
they are installed but not running.
|
||||
|
||||
## What the installer does
|
||||
|
||||
1. **Detects hardware** -- OS, architecture, CPU cores, memory, and GPU.
|
||||
2. **Recommends an LLM mode** -- `ollama` for machines with >= 16 GB RAM and
|
||||
a GPU or >= 8 cores; `openai` otherwise.
|
||||
3. **Collects configuration** -- API key, LLM provider, model choices,
|
||||
install directory. Answers are saved to
|
||||
`<install-dir>/trustgraph-installer.env` and reused on subsequent runs.
|
||||
4. **Checks and installs prerequisites** -- Python, Node/npx, Docker or
|
||||
Podman, Ollama (if selected).
|
||||
5. **Downloads Ollama models** (if using Ollama) -- chat model
|
||||
(`granite4:350m` by default) and embeddings model (`mxbai-embed-large`).
|
||||
6. **Creates a Python venv** and installs the local TrustGraph packages into
|
||||
it, along with NLTK data and tiktoken caches.
|
||||
7. **Runs the full pytest suite** against the local source tree.
|
||||
8. **Runs `npx @trustgraph/config`** -- the existing interactive config
|
||||
wizard that produces a `deploy.zip` with a compose file.
|
||||
9. **Starts the compose stack** and waits for the API gateway to respond.
|
||||
10. **Bootstraps IAM** and verifies the API key authenticates.
|
||||
11. **Opens the Workbench UI** in your default browser.
|
||||
|
||||
## Command-line options
|
||||
|
||||
| Option | Description |
|
||||
|---|---|
|
||||
| `--install-dir PATH` | Directory for deployment files (default: `./trustgraph-deploy`) |
|
||||
| `--api-url URL` | API gateway URL for health checks (default: `http://localhost:8088/`) |
|
||||
| `--ui-url URL` | Workbench UI URL to open (default: `http://localhost:8888`) |
|
||||
| `--use-existing-compose FILE` | Skip config generation and start this compose file directly |
|
||||
| `--skip-tests` | Do not run the pytest suite |
|
||||
| `--no-launch` | Do not open the Workbench UI at the end |
|
||||
| `--non-interactive` | Accept all defaults without prompting |
|
||||
| `--yes` | Auto-accept confirmation prompts |
|
||||
| `--fresh` | Remove installer-managed files before generating a new deployment |
|
||||
| `--remove-all` | Uninstall: stop containers, remove compose volumes, delete installer files |
|
||||
| `--dry-run` | Print detected hardware and planned defaults, then exit |
|
||||
| `-h`, `--help` | Show the built-in help text |
|
||||
|
||||
## Environment variables
|
||||
|
||||
These override the interactive prompts when set:
|
||||
|
||||
| Variable | Purpose |
|
||||
|---|---|
|
||||
| `TRUSTGRAPH_TOKEN` | Admin/bootstrap API key (must start with `tg_`) |
|
||||
| `TRUSTGRAPH_URL` | API gateway URL |
|
||||
| `TRUSTGRAPH_UI_URL` | Workbench UI URL |
|
||||
| `OPENAI_TOKEN` | OpenAI-compatible API key |
|
||||
| `OPENAI_BASE_URL` | OpenAI-compatible base URL |
|
||||
| `OLLAMA_HOST` / `OLLAMA_BASE_URL` | Ollama service URL |
|
||||
| `OLLAMA_MODEL` | Ollama chat model (default: `granite4:350m`) |
|
||||
| `OLLAMA_EMBEDDINGS_MODEL` | Ollama embeddings model (default: `mxbai-embed-large`) |
|
||||
| `TG_INSTALL_DIR` | Override the install directory |
|
||||
| `TG_VENV_DIR` | Override the Python venv location |
|
||||
| `TG_NLTK_DATA_DIR` | Override the NLTK data directory |
|
||||
| `TIKTOKEN_CACHE_DIR` | Override the tiktoken cache directory |
|
||||
| `TG_HEALTH_TIMEOUT` | Seconds to wait for the API gateway (default: 240) |
|
||||
|
||||
## Choosing an LLM mode
|
||||
|
||||
### OpenAI (or any OpenAI-compatible provider)
|
||||
|
||||
Best when you already have an API key or are running against a remote
|
||||
endpoint. The installer asks for a base URL and an API key.
|
||||
|
||||
```bash
|
||||
OPENAI_TOKEN=sk-... ./install_trustgraph.sh
|
||||
```
|
||||
|
||||
### Ollama (local models)
|
||||
|
||||
Best on machines with enough RAM to run a small model. The installer detects
|
||||
locally installed Ollama models and offers to pull missing ones. It uses
|
||||
`host.docker.internal` so the Docker containers can reach the host-side
|
||||
Ollama service.
|
||||
|
||||
```bash
|
||||
./install_trustgraph.sh # choose "ollama" when prompted
|
||||
```
|
||||
|
||||
### None
|
||||
|
||||
Start the platform without an LLM. Agent and RAG features will not work
|
||||
until you configure one later through the Workbench.
|
||||
|
||||
## Saved answers and re-running
|
||||
|
||||
The installer saves your answers to
|
||||
`<install-dir>/trustgraph-installer.env`. On the next run it loads those
|
||||
answers as defaults, so you can re-run with a single Enter through each
|
||||
prompt.
|
||||
|
||||
To start completely fresh:
|
||||
|
||||
```bash
|
||||
./install_trustgraph.sh --fresh
|
||||
```
|
||||
|
||||
This stops any running containers (keeping Docker volumes), removes
|
||||
installer-managed files, and re-runs the full flow.
|
||||
|
||||
## Using an existing compose file
|
||||
|
||||
If you already have a compose file from the config tool or another source:
|
||||
|
||||
```bash
|
||||
./install_trustgraph.sh --use-existing-compose path/to/docker-compose.yaml
|
||||
```
|
||||
|
||||
This skips the config wizard and `npx` prerequisite check, and goes straight
|
||||
to starting the stack.
|
||||
|
||||
## Non-interactive / CI usage
|
||||
|
||||
```bash
|
||||
TRUSTGRAPH_TOKEN=tg_my-token \
|
||||
OPENAI_TOKEN=sk-... \
|
||||
./install_trustgraph.sh --non-interactive --yes --skip-tests
|
||||
```
|
||||
|
||||
In non-interactive mode the installer uses defaults for every prompt. Pair
|
||||
with `--yes` to auto-accept confirmation prompts and `--skip-tests` if you
|
||||
want a faster run.
|
||||
|
||||
## Dry run
|
||||
|
||||
Preview what the installer would do without making any changes:
|
||||
|
||||
```bash
|
||||
./install_trustgraph.sh --dry-run
|
||||
```
|
||||
|
||||
This prints the detected hardware, recommended LLM mode, and planned
|
||||
install paths, then exits.
|
||||
|
||||
## Uninstalling
|
||||
|
||||
```bash
|
||||
./install_trustgraph.sh --remove-all
|
||||
```
|
||||
|
||||
This stops containers, removes compose-managed volumes, and deletes
|
||||
installer-managed files (venv, deploy output, logs, saved answers). It does
|
||||
**not** remove Docker/Podman itself, container images, Ollama, or Ollama
|
||||
models.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Logs
|
||||
|
||||
All long-running operations write logs to `<install-dir>/logs/`. Key files:
|
||||
|
||||
- `pytest.log` -- test suite output
|
||||
- `compose-up.log` -- docker compose output
|
||||
- `iam-bootstrap.log` -- IAM bootstrap output
|
||||
- `ollama-pull-*.log` -- Ollama model downloads
|
||||
- `pip-*.log` -- Python package installs
|
||||
- `brew-install-*.log` -- Homebrew installs
|
||||
|
||||
### API key rejected after reinstall
|
||||
|
||||
If the API gateway returns 401/403 with your saved key, the compose volumes
|
||||
likely contain IAM data from a previous install with a different key. Run:
|
||||
|
||||
```bash
|
||||
./install_trustgraph.sh --remove-all
|
||||
./install_trustgraph.sh
|
||||
```
|
||||
|
||||
This clears the old volumes and starts fresh.
|
||||
|
||||
### Ollama not reachable from containers
|
||||
|
||||
The Ollama base URL should use `host.docker.internal` instead of
|
||||
`localhost` so that containers running in Docker Desktop can reach the
|
||||
host-side Ollama service. The installer sets this automatically; if you
|
||||
override `OLLAMA_HOST`, make sure the URL is reachable from inside the
|
||||
container network.
|
||||
|
||||
### Docker daemon not running
|
||||
|
||||
The installer detects Docker Desktop and offers to start it. If that
|
||||
doesn't work, start Docker Desktop manually and re-run the installer.
|
||||
284
README.md
284
README.md
|
|
@ -3,52 +3,97 @@
|
|||
|
||||
<img src="TG-fullname-logo.svg" width=100% />
|
||||
|
||||
[](https://pypi.org/project/trustgraph/) [](LICENSE) 
|
||||
[](https://pypi.org/project/trustgraph/)  
|
||||
[](https://discord.gg/sQMwkRz5GX) [](https://deepwiki.com/trustgraph-ai/trustgraph)
|
||||
|
||||
[**Website**](https://trustgraph.ai) | [**Docs**](https://docs.trustgraph.ai) | [**YouTube**](https://www.youtube.com/@TrustGraphAI?sub_confirmation=1) | [**Configuration Terminal**](https://config-ui.demo.trustgraph.ai/) | [**Discord**](https://discord.gg/sQMwkRz5GX) | [**Blog**](https://blog.trustgraph.ai/subscribe)
|
||||
[**Website**](https://trustgraph.ai) | [**Docs**](https://docs.trustgraph.ai) | [**YouTube**](https://www.youtube.com/@TrustGraphAI?sub_confirmation=1) | [**Configuration Terminal**](https://config-ui.demo.trustgraph.ai/) | [**Discord**](https://discord.gg/yUWRkfbD) | [**Blog**](https://blog.trustgraph.ai/subscribe)
|
||||
|
||||
<a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
# The agent runtime platform
|
||||
# Write context once. Run agents anywhere.
|
||||
|
||||
</div>
|
||||
|
||||
TrustGraph is an agent runtime platform built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for precision-critical agent workloads.
|
||||
Stop rebuilding context from scratch. TrustGraph treats context as a holon — a modular, independent whole that naturally snaps into a larger domain-wide intelligence layer. By deploying context as holonic context graphs, TrustGraph powers multi-tenant agent workflows, dramatically reduces token consumption, and aligns with semantic web standards (RDF, OWL, SKOS, SHACL). Version your context, share it across teams, and scale with full provenance.
|
||||
|
||||
The platform:
|
||||
- [x] Multi-model and multimodal database system
|
||||
- [x] Tabular/relational, key-value
|
||||
- [x] Document, graph, and vectors
|
||||
- [x] Images, video, and audio
|
||||
- [x] Context Graph engine
|
||||
- [x] Automated entity and relationship extraction
|
||||
- [x] Ontology-driven graph construction
|
||||
- [x] Graph-grounded retrieval for explainable outputs
|
||||
- [x] Automated data ingest and loading
|
||||
- [x] Quick ingest with semantic similarity retrieval
|
||||
- [x] Ontology structuring for precision retrieval
|
||||
- [x] Out-of-the-box RAG pipelines
|
||||
- [x] DocumentRAG
|
||||
- [x] GraphRAG
|
||||
- [x] OntologyRAG
|
||||
- [x] 3D GraphViz for exploring context
|
||||
- [x] Fully Agentic System
|
||||
- [x] Single or Multi Agent
|
||||
- [x] ReAct, Plan-then-Execute, and Supervisor patterns
|
||||
- [x] MCP integration
|
||||
- [x] Run anywhere
|
||||
- [x] Deploy locally with Docker
|
||||
- [x] Deploy in cloud with Kubernetes
|
||||
- [x] Support for all major LLMs
|
||||
- [x] API support for Anthropic, Cohere, Gemini, Mistral, OpenAI, and others
|
||||
- [x] Model inferencing with vLLM, Ollama, TGI, LM Studio, and Llamafiles
|
||||
- [x] Developer friendly
|
||||
- [x] REST API [Docs](https://docs.trustgraph.ai/reference/apis/rest.html)
|
||||
- [x] Websocket API [Docs](https://docs.trustgraph.ai/reference/apis/websocket.html)
|
||||
- [x] Python API [Docs](https://docs.trustgraph.ai/reference/apis/python)
|
||||
- [x] CLI [Docs](https://docs.trustgraph.ai/reference/cli/)
|
||||
## What TrustGraph Does
|
||||
|
||||
TrustGraph is a complete holonic context harness for all LLMs. It provides the full infrastructure layer underneath your agents: knowledge ingestion, structured storage, graph-grounded retrieval, agent orchestration, and a full LLM inferencing stack.
|
||||
|
||||
TrustGraph relies on absolutely no 3rd party services aside from optional API integrations to cloud-hosted LLMs. Whether you are using Anthropic's or OpenAI's API, or self-hosting Qwen3.7 via vLLM, TrustGraph handles it all with pre-built API connectors and a full LLM inferencing stack to enrich the models with a sovereign, private holonic system that grounds your agents in reality.
|
||||
|
||||
## The Problem: Why Agents Break
|
||||
|
||||
When you build an AI agent today, you spend most of your time fighting context:
|
||||
|
||||
- **RAG retrieves fragments, not meaning**. Chunks of text have no structure. Relationships between facts are invisible. Your agent guesses at the connections.
|
||||
|
||||
- **Context is disposable**. What the agent learned in one session is gone in the next. There is no persistent, structured knowledge layer underneath.
|
||||
|
||||
- **Answers aren't traceable**. You can't explain why the agent said what it said, which means you can't trust it in production.
|
||||
|
||||
- **Knowledge can't be reused**. You rebuild the same context pipelines for every new project, every new agent, every new environment.
|
||||
|
||||
These aren't retrieval problems. They are structural problems. Context needs to be organized, versioned, and composable — exactly the way software infrastructure is.
|
||||
|
||||
## The Solution: A Holonic Context System
|
||||
The philosopher Arthur Koestler coined the word [holon](https://en.wikipedia.org/wiki/Holon_(philosophy)) to describe something that is simultaneously a whole in itself and a part of something larger. A fact is whole. It is also part of a domain. A domain is whole. It is also part of an organization's knowledge.
|
||||
|
||||
AI agents break down because this holonic structure is never built. Context gets shoved into flat text windows, scattered across vector stores, or hardwired into one-off prompts. Facts lose their relationships.
|
||||
|
||||
TrustGraph solves this by organizing your domain into holonic context graphs. Entities, relationships, and evidence are treated as first-class objects. Every agent query is grounded against these holons—marrying symbolic graph structures with vector embeddings. Every answer carries provenance. Every fact is traceable.
|
||||
|
||||
## Context Cores: Knowledge as a First-Class Citizen
|
||||
|
||||
A Context Core is the deployable unit of knowledge in TrustGraph. It packages everything an agent needs to reason reliably over a domain into a single, portable artifact.
|
||||
|
||||
### What's inside a Context Core
|
||||
- **Ontology** — your domain schema and entity mappings
|
||||
- **Holon** — entities, relationships, and supporting evidence
|
||||
- **Embeddings** — vector indexes for fast semantic entry-point lookup
|
||||
- **Provenance** — where every fact came from, when, and how it was derived
|
||||
- **Retrieval policies** — traversal rules, freshness controls, authority ranking
|
||||
|
||||
Context Cores decouple what agents know from how agents are deployed. Build once. Run in Docker locally, Kubernetes in production, or on any cloud. Pin a version. Roll back. Promote across environments. This is context engineering — and it works because knowledge is finally treated like the infrastructure it is.
|
||||
|
||||
## Explainability: Trust Your Agents in Production
|
||||
LLMs are black boxes, and traditional RAG makes it worse. When an agent pulls flat text chunks from a vector store, you have no idea how it connected those fragments to form an answer. You cannot ship agents to production if you can't explain why they said what they said.
|
||||
|
||||
### How TrustGraph makes agents explainable:
|
||||
|
||||
- **Traceable Reasoning Paths**: Instead of guessing at connections between text chunks, TrustGraph traverses explicit relationship paths in the holonic context graph. You can inspect exactly which entities, relationships, and sub-graphs were pulled into the LLM's context window to generate a given response.
|
||||
- **Fact-Level Provenance**: Every node and edge in the graph carries strict provenance. When an agent makes a claim, you can trace it back to the exact source document, the time it was ingested, and the extraction method used to derive it.
|
||||
- **No Black-Box Guesses**: By grounding the LLM in a structured, symbolic graph, you eliminate the hallucinations that occur when models are forced to infer relationships from unstructured text. If a fact isn't in the graph, the agent doesn't use it.
|
||||
|
||||
TrustGraph doesn't just give you answers - it gives you the receipt. Every fact is traceable, every connection is visible, and every output is verifiable.
|
||||
|
||||
## Workspaces, Collections, and Flows
|
||||
|
||||
TrustGraph has a [three-level system](https://docs.trustgraph.ai/overview/workspaces) for organizing and isolating knowledge.
|
||||
|
||||
A `Workspace` is the outermost boundary — a fully isolated tenancy scope where all data, users, configuration, and pipelines live independently from every other workspace. Isolation is structural: enforced at the pub/sub queue, storage, and API gateway layers, not by trusting a field in a message body.
|
||||
|
||||
Within a workspace, a `Collection` groups related holons, graph structures, embeddings, and documents together — think of it as a dedicated shelf in a library, scoped to a specific domain, project, or customer.
|
||||
|
||||
A `Flow` is a running data processing pipeline that defines how raw data moves through ingestion, extraction, structuring, and storage — the assembly line that turns documents into queryable knowledge. Together, the three layers let you run multiple isolated tenants on a single deployment, separate knowledge by domain within each tenant, and process that knowledge through fully configurable pipelines — all without restarting the system or rebuilding your infrastructure.
|
||||
|
||||
## The Full Stack
|
||||
TrustGraph is not a wrapper around a graph database. It is the complete backend for production agentic systems.
|
||||
|
||||
- **Holonic context graph engine**: automated entity and relationship extraction, ontology-driven graph construction, graph-grounded retrieval for explainable outputs
|
||||
- **Multi-model database**: tabular/relational, key-value, document, graph, vectors, images, video, and audio — all managed in Cassandra and S3-compatible Garage
|
||||
- **Out-of-the-box RAG pipelines**: DocumentRAG, GraphRAG, and OntologyRAG ready to deploy
|
||||
- **Fully agentic orchestration**: single or multi-agent, ReAct, Plan-then-Execute, Supervisor patterns, and MCP integration
|
||||
- **3D Knowledge Explorer**: interactive graph visualization with BFS neighborhood extraction and edge pulse animation
|
||||
- **Automated data ingest**: quick ingest with semantic similarity or ontology-structured precision retrieval
|
||||
- **Run anywhere**: Docker/Podman locally, Kubernetes in the cloud
|
||||
|
||||
All major LLMs — Anthropic, Cohere, Gemini, Mistral, OpenAI, and more via API.
|
||||
|
||||
vLLM, Ollama, TGI, LM Studio, and Llamafiles for fully local inferencing.
|
||||
|
||||
Verified cloud deployments for Alibaba Cloud, AWS, Azure, GCP, OVHcloud, and Scaleway.
|
||||
|
||||
## No API Keys Required
|
||||
|
||||
|
|
@ -62,12 +107,12 @@ Everything else is included.
|
|||
- [x] Managed Multi-model storage in [Cassandra](https://cassandra.apache.org/_/index.html)
|
||||
- [x] Managed Vector embedding storage in [Qdrant](https://github.com/qdrant/qdrant)
|
||||
- [x] Managed File and Object storage in [Garage](https://github.com/deuxfleurs-org/garage) (S3 compatible)
|
||||
- [x] Managed High-speed Pub/Sub messaging fabric with [Pulsar](https://github.com/apache/pulsar)
|
||||
- [x] Managed High-speed Pub/Sub messaging fabric with [Pulsar](https://github.com/apache/pulsar) or [RabbitMQ](https://www.rabbitmq.com/)
|
||||
- [x] Complete LLM inferencing stack for open LLMs with [vLLM](https://github.com/vllm-project/vllm), [TGI](https://github.com/huggingface/text-generation-inference), [Ollama](https://github.com/ollama/ollama), [LM Studio](https://github.com/lmstudio-ai), and [Llamafiles](https://github.com/mozilla-ai/llamafile)
|
||||
|
||||
## Quickstart
|
||||
|
||||
There's no need to clone this repo, unless you want to build from source. TrustGraph is a fully containerized app that deploys as a set of Docker containers. To configure TrustGraph on the command line:
|
||||
No need to clone the repo unless you are building from source. TrustGraph deploys as a set of Docker containers. Configure it on the command line in one step:
|
||||
|
||||
```
|
||||
npx @trustgraph/config
|
||||
|
|
@ -78,44 +123,39 @@ The config process will generate an app config that can be run locally with Dock
|
|||
- Deployment instructions as `INSTALLATION.md`
|
||||
|
||||
<p align="center">
|
||||
<video src="https://github.com/user-attachments/assets/2978a6aa-4c9c-4d7c-ad02-8f3d01a1c602"
|
||||
<video src="https://github.com/user-attachments/assets/33434c3c-f586-4610-8bb2-d7b7b586a672"
|
||||
width="80%" controls></video>
|
||||
</p>
|
||||
|
||||
For a browser based configuration, try the [Configuration Terminal](https://config-ui.demo.trustgraph.ai/).
|
||||
|
||||
## Watch What is a Context Graph?
|
||||
## Watch What is a Holonic Context Graph?
|
||||
|
||||
[](https://www.youtube.com/watch?v=gZjlt5WcWB4)
|
||||
|
||||
## Watch Context Graphs in Action
|
||||
## Watch Holonic Context Graphs in Action
|
||||
|
||||
[](https://www.youtube.com/watch?v=sWc7mkhITIo)
|
||||
|
||||
## Getting Started with TrustGraph
|
||||
|
||||
- [**Getting Started Guides**](https://docs.trustgraph.ai/getting-started)
|
||||
- [**Using the Workbench**](#workbench)
|
||||
- [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference)
|
||||
- [**Deployment Guides**](https://docs.trustgraph.ai/deployment)
|
||||
|
||||
## Workbench
|
||||
## TrustGraph UI
|
||||
|
||||
The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default.
|
||||
<img width="1389" height="961" alt="Image" src="https://github.com/user-attachments/assets/35c9250d-0f01-40cb-9294-1ee8fd9a1b56" />
|
||||
|
||||
- **Vector Search**: Search the installed knowledge bases
|
||||
- **Agentic, GraphRAG and LLM Chat**: Chat interface for agents, GraphRAG queries, or direct to LLMs
|
||||
- **Relationships**: Analyze deep relationships in the installed knowledge bases
|
||||
- **Graph Visualizer**: 3D GraphViz of the installed knowledge bases
|
||||
- **Library**: Staging area for installing knowledge bases
|
||||
- **Flow Classes**: Workflow preset configurations
|
||||
- **Flows**: Create custom workflows and adjust LLM parameters during runtime
|
||||
- **Knowledge Cores**: Manage resuable knowledge bases
|
||||
- **Prompts**: Manage and adjust prompts during runtime
|
||||
- **Schemas**: Define custom schemas for structured data knowledge bases
|
||||
- **Ontologies**: Define custom ontologies for unstructured data knowledge bases
|
||||
- **Agent Tools**: Define tools with collections, knowledge cores, MCP connections, and tool groups
|
||||
- **MCP Tools**: Connect to MCP servers
|
||||
The UI provides tools for all major features of TrustGraph. The UI deploys on port `8888` by default.
|
||||
|
||||
- **Agent Console** — Query your agents directly with streaming responses and live explainability event tracking, so you can watch reasoning unfold in real time
|
||||
- **GraphRAG View** — Interactive graph RAG queries with a visual explainability DAG and inline provenance display, making it easy to see exactly where answers came from
|
||||
- **Context Explorer** — An interactive 3D context graph explorer with dynamic graph loading, BFS neighborhood extraction, edge pulse animation, and multiple navigation views
|
||||
- **Document Ingestion** — A complete upload and submission workflow with page and chunk inspection and document structure browsing
|
||||
- **Ontology Workbench** — A full ontology editor with class and property trees, OWL/XML and Turtle import/export with round-trip fidelity, circular dependency detection, and safe-delete confirmation dialogs
|
||||
- **Schema Workbench** — Interactive schema management with list, create, edit, and delete operations including field and index management
|
||||
- **Prompt Editor** — A dedicated prompt editing workflow
|
||||
|
||||
## TypeScript Library for UIs
|
||||
|
||||
|
|
@ -125,134 +165,6 @@ There are 3 libraries for quick UI integration of TrustGraph services.
|
|||
- [@trustgraph/react-state](https://www.npmjs.com/package/@trustgraph/react-state)
|
||||
- [@trustgraph/react-provider](https://www.npmjs.com/package/@trustgraph/react-provider)
|
||||
|
||||
## Context Cores
|
||||
|
||||
Context Cores are how TrustGraph treats context like code. A Context Core is a **portable, versioned bundle of context** that you can ship between projects and environments, pin in production, and reuse across agents. It packages the “stuff agents need to know” (structured knowledge + embeddings + evidence + policies) into a single artifact, so you can treat context like code: build it, test it, version it, promote it, and roll it back. TrustGraph is built to support this kind of end-to-end context engineering and orchestration workflow.
|
||||
|
||||
### What’s inside a Context Core
|
||||
A Context Core typically includes:
|
||||
- Ontology (your domain schema) and mappings
|
||||
- Context Graph (entities, relationships, supporting evidence)
|
||||
- Embeddings / vector indexes for fast semantic entry-point lookup
|
||||
- Source manifests + provenance (where facts came from, when, and how they were derived)
|
||||
- Retrieval policies (traversal rules, freshness, authority ranking)
|
||||
|
||||
## Tech Stack
|
||||
TrustGraph provides component flexibility to optimize agent workflows.
|
||||
|
||||
<details>
|
||||
<summary>LLM APIs</summary>
|
||||
<br>
|
||||
|
||||
- Anthropic<br>
|
||||
- AWS Bedrock<br>
|
||||
- AzureAI<br>
|
||||
- AzureOpenAI<br>
|
||||
- Cohere<br>
|
||||
- Google AI Studio<br>
|
||||
- Google VertexAI<br>
|
||||
- Mistral<br>
|
||||
- OpenAI<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>LLM Orchestration</summary>
|
||||
<br>
|
||||
|
||||
- LM Studio<br>
|
||||
- Llamafiles<br>
|
||||
- Ollama<br>
|
||||
- TGI<br>
|
||||
- vLLM<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>Multi-model storage</summary>
|
||||
<br>
|
||||
|
||||
- Apache Cassandra<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>VectorDB</summary>
|
||||
<br>
|
||||
|
||||
- Qdrant<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>File and Object Storage</summary>
|
||||
<br>
|
||||
|
||||
- Garage<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>Observability</summary>
|
||||
<br>
|
||||
|
||||
- Prometheus<br>
|
||||
- Grafana<br>
|
||||
- Loki<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>Data Streaming</summary>
|
||||
<br>
|
||||
|
||||
- Apache Pulsar<br>
|
||||
- RabbitMQ<br>
|
||||
- Apache Kafka<br>
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>Clouds</summary>
|
||||
<br>
|
||||
|
||||
- AWS<br>
|
||||
- Azure<br>
|
||||
- Google Cloud<br>
|
||||
- OVHcloud<br>
|
||||
- Scaleway<br>
|
||||
|
||||
</details>
|
||||
|
||||
## Observability & Telemetry
|
||||
|
||||
Once the platform is running, access the Grafana dashboard at:
|
||||
|
||||
```
|
||||
http://localhost:3000
|
||||
```
|
||||
|
||||
Default credentials are:
|
||||
|
||||
```
|
||||
user: admin
|
||||
password: admin
|
||||
```
|
||||
|
||||
The default Grafana dashboard tracks the following:
|
||||
|
||||
<details>
|
||||
<summary>Telemetry</summary>
|
||||
<br>
|
||||
|
||||
- LLM Latency<br>
|
||||
- Error Rate<br>
|
||||
- Service Request Rates<br>
|
||||
- Queue Backlogs<br>
|
||||
- Chunking Histogram<br>
|
||||
- Error Source by Service<br>
|
||||
- Rate Limit Events<br>
|
||||
- CPU usage by Service<br>
|
||||
- Memory usage by Service<br>
|
||||
- Models Deployed<br>
|
||||
- Token Throughput (Tokens/second)<br>
|
||||
- Cost Throughput (Cost/second)<br>
|
||||
|
||||
</details>
|
||||
|
||||
## Contributing
|
||||
|
||||
[Developer's Guide](https://docs.trustgraph.ai/guides/building/introduction.html)
|
||||
|
|
@ -261,7 +173,7 @@ The default Grafana dashboard tracks the following:
|
|||
|
||||
**TrustGraph** is licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0).
|
||||
|
||||
Copyright 2024-2025 TrustGraph
|
||||
Copyright 2024-2026 TrustGraph
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
|
|||
|
|
@ -100,7 +100,6 @@ multi-word subsystems.
|
|||
| `users:admin` | Assign / remove roles on users within the workspace |
|
||||
| `keys:self` | Create / revoke / list **own** API keys |
|
||||
| `keys:admin` | Create / revoke / list **any user's** API keys within the workspace |
|
||||
| `workspaces:list-own` | List workspaces the caller has access to |
|
||||
| `workspaces:admin` | Create / delete / disable workspaces (system-level) |
|
||||
| `iam:admin` | JWT signing-key rotation, IAM-level operations |
|
||||
| `metrics:read` | Prometheus metrics proxy |
|
||||
|
|
@ -111,7 +110,7 @@ The open-source edition ships three roles:
|
|||
|
||||
| Role | Capabilities |
|
||||
|---|---|
|
||||
| `reader` | `agent`, `graph:read`, `documents:read`, `rows:read`, `llm`, `embeddings`, `mcp`, `collections:read`, `knowledge:read`, `flows:read`, `config:read`, `keys:self`, `workspaces:list-own` |
|
||||
| `reader` | `agent`, `graph:read`, `documents:read`, `rows:read`, `llm`, `embeddings`, `mcp`, `collections:read`, `knowledge:read`, `flows:read`, `config:read`, `keys:self` |
|
||||
| `writer` | everything in `reader` **+** `graph:write`, `documents:write`, `rows:write`, `collections:write`, `knowledge:write` |
|
||||
| `admin` | everything in `writer` **+** `config:write`, `flows:write`, `users:read`, `users:write`, `users:admin`, `keys:admin`, `workspaces:admin`, `iam:admin`, `metrics:read` |
|
||||
|
||||
|
|
|
|||
541
docs/tech-specs/graph-rag-semantic-filter.md
Normal file
541
docs/tech-specs/graph-rag-semantic-filter.md
Normal file
|
|
@ -0,0 +1,541 @@
|
|||
# GraphRAG Semantic Filter Improvement
|
||||
|
||||
## Problem Statement
|
||||
|
||||
The GraphRAG semantic filter is observed to be ineffective with certain
|
||||
LLM models. Smaller models in particular produce poor-quality edge
|
||||
relevance scores, and there is a suspicion that models trained or
|
||||
evaluated heavily on non-Roman-script datasets offer lower performance
|
||||
on the semantic ranking operation.
|
||||
|
||||
The root cause is that the current implementation delegates edge
|
||||
relevance scoring to the LLM via a prompt that asks the model to
|
||||
assign a 1–10 relevance score to each knowledge-graph edge. This
|
||||
task — ranking structured triples for relevance to a natural-language
|
||||
query — is not well covered in standard LLM evaluation suites, so
|
||||
model benchmark scores are not predictive of performance on this
|
||||
operation. The result is that GraphRAG quality varies unpredictably
|
||||
across model choices, undermining confidence in the pipeline.
|
||||
|
||||
Beyond model variability, the LLM scoring step has further problems:
|
||||
|
||||
- **Cost and latency.** The LLM call consumes tokens and adds
|
||||
latency to every query, yet its output is unreliable. Even when
|
||||
the model performs well, the cost is disproportionate for what is
|
||||
fundamentally a ranking operation.
|
||||
|
||||
- **Subjective scoring scale.** The 1–10 relevance scale gives the
|
||||
model no objective criteria for what constitutes a 5 versus a 7.
|
||||
Different models interpret the scale differently, and even the same
|
||||
model can produce inconsistent scores across runs.
|
||||
|
||||
- **Redundancy with the embedding pre-filter.** The pipeline already
|
||||
contains a cosine-similarity stage that ranks edges by semantic
|
||||
relevance using embeddings. The LLM scoring step is a second
|
||||
filter applied on top of this, and it is not clear that it adds
|
||||
enough value to justify the additional cost and risk of
|
||||
degradation.
|
||||
|
||||
### Industry context
|
||||
|
||||
Semantic ranking is rigorously evaluated on dedicated benchmarks such
|
||||
as MTEB (Massive Text Embedding Benchmark) and BEIR (Benchmarking
|
||||
Information Retrieval), which test retrieval and reranking across
|
||||
diverse domains. The current TrustGraph approach — prompting a
|
||||
general-purpose LLM to score and rank documents (the "listwise"
|
||||
approach) — is known to be poorly optimized for this task. It
|
||||
suffers from positional bias, formatting failures, and
|
||||
inconsistency at scale.
|
||||
|
||||
The industry standard for semantic ranking has moved to
|
||||
cross-encoder models: lightweight, purpose-built models that take a
|
||||
query–document pair as input and produce a single relevance score.
|
||||
These models are fine-tuned on millions of relevance-labelled pairs
|
||||
and dominate retrieval benchmarks. They are fast, deterministic,
|
||||
and do not require an LLM inference call.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Cross-encoder service
|
||||
|
||||
A new request/response service that exposes a generic semantic
|
||||
ranking API. The service is not specific to GraphRAG — it is a
|
||||
reusable building block for any component that needs to rank text
|
||||
by relevance.
|
||||
|
||||
The service interface is pluggable. Alternative implementations
|
||||
can be swapped in behind the same API.
|
||||
|
||||
**Packaging options considered:**
|
||||
|
||||
- *`sentence-transformers`.* Full-featured, widely used.
|
||||
However, it pulls in PyTorch (~2 GB), making containers
|
||||
very large. Tested at ~1.8 seconds for 2200 edges.
|
||||
|
||||
- *`optimum.onnxruntime`.* ONNX-based inference. Still
|
||||
depends on PyTorch at import time despite using ONNX for
|
||||
inference. Tested at ~4.2 seconds for 2200 edges.
|
||||
|
||||
- *`flashrank`.* Lightweight wrapper around ONNX Runtime
|
||||
with a clean API (`Ranker`, `RerankRequest`). No PyTorch
|
||||
dependency. Tested at ~4.4 seconds for 2200 edges.
|
||||
|
||||
- *Pure `onnxruntime` + `tokenizers`.* Leanest option
|
||||
(~200 MB total). Requires manual tokenisation, padding,
|
||||
and numpy array management — more boilerplate to maintain.
|
||||
|
||||
- *External API (e.g. Cohere Rerank).* No local model at
|
||||
all. Adds network latency and an external dependency.
|
||||
|
||||
**Decision:** `flashrank` for the initial implementation.
|
||||
No PyTorch dependency, clean API, comparable performance.
|
||||
The pluggable interface allows swapping to another backend
|
||||
later.
|
||||
|
||||
**Request:**
|
||||
|
||||
- `queries` — list of `{id, text}` objects. In the GraphRAG use
|
||||
case these are the concepts extracted from the user's question.
|
||||
- `documents` — list of `{id, text}` objects. In the GraphRAG
|
||||
use case these are the candidate knowledge-graph edges
|
||||
represented as text.
|
||||
- `limit` — integer. Maximum number of results to return.
|
||||
|
||||
**Scoring:**
|
||||
|
||||
The service produces the cartesian product of all query–document
|
||||
pairs and scores each pair through the cross-encoder model. For
|
||||
each document, the maximum score across all queries is taken as the
|
||||
document's relevance score. Documents are then ranked by this
|
||||
score and the top `limit` results are returned.
|
||||
|
||||
**Response:**
|
||||
|
||||
A list of the top `limit` results, each containing:
|
||||
|
||||
- `document_id` — the ID of the matched document.
|
||||
- `query_id` — the ID of the query (concept) that produced the
|
||||
highest score for this document.
|
||||
- `score` — the relevance score.
|
||||
|
||||
Including `query_id` in the response supports the explainability
|
||||
interface: it records that an edge was selected because it is
|
||||
related to a specific concept.
|
||||
|
||||
### Integration
|
||||
|
||||
The cross-encoder service follows the standard TrustGraph service
|
||||
integration pattern:
|
||||
|
||||
- **Base package (trustgraph-base).** Schema definitions for the
|
||||
cross-encoder request/response messages. A client class that
|
||||
other components (e.g. GraphRAG) can use to call the
|
||||
cross-encoder service. Message translator registration so the
|
||||
pub/sub layer can serialise/deserialise the messages.
|
||||
|
||||
- **Flow package (trustgraph-flow).** The cross-encoder service
|
||||
implementation itself — loads the model, listens for requests,
|
||||
scores pairs, returns results. Flow definition support so the
|
||||
cross-encoder can be introduced into a processing flow via the
|
||||
standard flow configuration. `flashrank` is added as a
|
||||
dependency of `trustgraph-flow`. The service runs in its own
|
||||
container.
|
||||
|
||||
- **API gateway.** A gateway endpoint that routes cross-encoder
|
||||
requests from the HTTP API to the service over pub/sub and
|
||||
returns the response.
|
||||
|
||||
- **CLI tool.** A command-line utility
|
||||
(e.g. `tg-invoke-cross-encoder`) that calls the gateway
|
||||
endpoint for manual testing and debugging.
|
||||
|
||||
### Current GraphRAG pipeline
|
||||
|
||||
The current pipeline follows these steps:
|
||||
|
||||
1. **Concept extraction.** An LLM prompt extracts key concepts
|
||||
from the user's query.
|
||||
|
||||
2. **Graph exploration.** Seed entities are found via embedding
|
||||
similarity. A subgraph is built by multi-hop traversal from
|
||||
the seed entities (up to `max_path_length` hops, capped at
|
||||
`max_subgraph_size` edges).
|
||||
|
||||
3. **Embedding pre-filter.** Each edge is embedded as
|
||||
`"subject, predicate, object"` and scored by cosine similarity
|
||||
against the concept embeddings. The top `edge_score_limit`
|
||||
(default 30) edges are kept.
|
||||
|
||||
4. **LLM edge scoring.** The `kg-edge-scoring` prompt asks the
|
||||
LLM to assign a 1–10 relevance score to each remaining edge.
|
||||
The top `edge_limit` (default 25) edges are kept.
|
||||
|
||||
5. **LLM edge reasoning.** The `kg-edge-reasoning` prompt asks
|
||||
the LLM to explain why each selected edge is relevant to the
|
||||
query. Used for the explainability interface.
|
||||
|
||||
6. **Document tracing.** Selected edges are traced back to their
|
||||
source documents in the librarian. Runs concurrently with
|
||||
step 5.
|
||||
|
||||
7. **Synthesis.** The `kg-synthesis` prompt generates the final
|
||||
answer from the selected edges and source document metadata.
|
||||
|
||||
### Potential improvements
|
||||
|
||||
#### Replace LLM edge scoring with cross-encoder (step 4)
|
||||
|
||||
The LLM edge scoring step is replaced by a call to the
|
||||
cross-encoder service. The candidate edges are the documents and
|
||||
`edge_limit` is the limit. This is a direct substitution: faster,
|
||||
cheaper, deterministic, and more reliable across model choices.
|
||||
The LLM `kg-edge-scoring` prompt is retired.
|
||||
|
||||
**Cross-encoder query input: concepts vs. raw query.** There are
|
||||
two options for what to use as the cross-encoder queries:
|
||||
|
||||
- *Option A: Raw user query.* Pass the original question as a
|
||||
single query string. Simpler, no dependency on concept
|
||||
extraction. However, raw queries contain noise words and
|
||||
conversational phrasing that do not match well against the
|
||||
structured vocabulary of knowledge-graph edges. A single query
|
||||
also means every edge competes against the full question — a
|
||||
partial match on one aspect is diluted.
|
||||
|
||||
- *Option B: Extracted concepts.* Pass the concepts from step 1
|
||||
as separate queries. The concepts are distilled, focused terms
|
||||
that are closer to the language of the edges. With multiple
|
||||
concepts as independent queries, the cross-encoder scores each
|
||||
edge against each concept separately, giving better coverage —
|
||||
an edge only needs to match one concept well to be selected.
|
||||
The trade-off is a dependency on the LLM concept extraction
|
||||
step, but this is already in the pipeline and is a lightweight,
|
||||
reliable LLM call.
|
||||
|
||||
**Decision:** Option B — use extracted concepts. The concept
|
||||
extraction is fast, and the resulting terms produce better
|
||||
cross-encoder matches against structured triples.
|
||||
|
||||
#### Edge text representation
|
||||
|
||||
The current embedding pre-filter represents each edge as
|
||||
`"subject, predicate, object"`. Two changes:
|
||||
|
||||
- **Drop commas.** Commas add tokenisation noise without semantic
|
||||
value.
|
||||
|
||||
- **Direction-aware text.** The reranker text should highlight
|
||||
the *new* information relative to the traversal direction.
|
||||
The frontier entity is already known context — repeating it
|
||||
adds noise and, when traversing from an object node, causes
|
||||
many edges to produce identical reranker text (e.g. 18
|
||||
products sharing the same `hasSubcategory Processors` triple
|
||||
all collapse to the same string when the subject is dropped).
|
||||
|
||||
The text is constructed based on which position the frontier
|
||||
entity occupied in the triple:
|
||||
|
||||
- **From subject** (s=entity): `"{predicate} {object}"` —
|
||||
the subject is known, predicate and object are new.
|
||||
- **From object** (o=entity): `"{subject} {predicate}"` —
|
||||
the object is known, subject and predicate are new.
|
||||
- **From predicate** (p=entity): `"{subject} {object}"` —
|
||||
the predicate is known, subject and object are new.
|
||||
|
||||
This eliminates the duplicate-text problem that arises when
|
||||
traversing inward from a shared object node, and gives the
|
||||
cross-encoder a more informative signal at every hop.
|
||||
|
||||
#### Remove the embedding pre-filter (step 3)
|
||||
|
||||
The embedding pre-filter was introduced to reduce the number of
|
||||
edges before the expensive LLM scoring call. With the
|
||||
cross-encoder replacing the LLM call, this cost equation changes.
|
||||
|
||||
**Arguments for removal:**
|
||||
|
||||
- The cross-encoder is fast enough to score the full subgraph
|
||||
directly. In testing, 2200 edges scored in ~1.8 seconds; at
|
||||
the default `max_subgraph_size` of 150 edges, scoring takes
|
||||
a fraction of a second.
|
||||
|
||||
- The pre-filter is a weaker version of what the cross-encoder
|
||||
does. Bi-encoder cosine similarity embeds the query and
|
||||
document independently and compares vectors; the cross-encoder
|
||||
processes both texts together through the full transformer,
|
||||
giving it much better relevance judgement. Running a weaker
|
||||
filter before a stronger one adds latency without improving
|
||||
quality.
|
||||
|
||||
- Removing it eliminates an embedding service call (two batches:
|
||||
concepts + edges) and the associated latency.
|
||||
|
||||
**Arguments for keeping it:**
|
||||
|
||||
- If the subgraph is very large (thousands of edges), the
|
||||
cross-encoder's linear scaling could become a bottleneck.
|
||||
The pre-filter would act as a safety valve.
|
||||
|
||||
- The embedding call is cheap compared to an LLM call, so the
|
||||
overhead is modest.
|
||||
|
||||
**Decision:** Remove the pre-filter. The `max_subgraph_size`
|
||||
parameter (default 150) already caps the number of edges entering
|
||||
this stage, so the cross-encoder will not face an unbounded
|
||||
workload. If very large subgraphs become a concern in future,
|
||||
the pre-filter can be reintroduced or `max_subgraph_size` can be
|
||||
tuned.
|
||||
|
||||
#### Iterative graph traversal with cross-encoder filtering
|
||||
|
||||
The current pipeline performs graph exploration and edge filtering
|
||||
as separate phases: first build the full subgraph (up to
|
||||
`max_path_length` hops), then score and filter edges. An
|
||||
alternative is to interleave traversal and filtering — at each
|
||||
hop, use the cross-encoder to select relevant edges before
|
||||
expanding further.
|
||||
|
||||
**Option A: Big-bang traversal then filter.** Traverse the full
|
||||
subgraph up to `max_path_length` hops from the seed entities,
|
||||
collecting all edges up to `max_subgraph_size`. Then
|
||||
cross-encode the entire result to select the top edges.
|
||||
|
||||
- Simple to implement — the current traversal logic is largely
|
||||
unchanged.
|
||||
- Produces large, unfocused subgraphs. Irrelevant branches are
|
||||
explored and scored even though they will be discarded.
|
||||
- Poorly suited to multi-hop reasoning. For a query about
|
||||
Voyager 1, the subgraph includes Voyager 2's edges because
|
||||
they are within hop distance, and the filter must then
|
||||
separate them.
|
||||
|
||||
**Option B: Iterative hop-and-filter.** At each hop:
|
||||
|
||||
1. Retrieve all edges one hop from the current frontier nodes.
|
||||
2. Cross-encode these edges against the query concepts.
|
||||
3. Select the top relevant edges.
|
||||
4. The target nodes of the selected edges become the frontier
|
||||
for the next hop.
|
||||
5. Repeat up to `max_path_length` hops.
|
||||
|
||||
The final set of selected edges across all hops is the input to
|
||||
synthesis.
|
||||
|
||||
- **Guided exploration.** Each hop focuses the search by
|
||||
pruning irrelevant branches before expanding further. The
|
||||
working set stays small and relevant at every step.
|
||||
- **Multi-hop reasoning works naturally.** Following
|
||||
"Voyager 1 → has-event → crossed the heliopause" succeeds
|
||||
because each hop is individually relevant and leads to the
|
||||
next.
|
||||
- **Smaller total workload.** Fewer edges are scored overall
|
||||
because irrelevant branches are never expanded.
|
||||
- **Trade-off: greedy pruning.** An edge discarded at hop 1
|
||||
cannot lead to relevant edges at hop 2. This is inherent in
|
||||
any bounded traversal, and the cross-encoder is better
|
||||
equipped to make this relevance judgement than a blind hop
|
||||
limit.
|
||||
- **Trade-off: sequential latency.** Hops cannot be
|
||||
parallelised since each depends on the previous. However,
|
||||
each cross-encoder call on a small edge set is very fast
|
||||
(sub-second for typical working sets).
|
||||
|
||||
**Decision:** Option B — iterative hop-and-filter. The guided
|
||||
traversal produces more focused subgraphs and supports multi-hop
|
||||
reasoning, which is a significant quality improvement over the
|
||||
current approach.
|
||||
|
||||
#### Replace LLM edge reasoning with cross-encoder metadata (step 5)
|
||||
|
||||
The current `kg-edge-reasoning` prompt asks the LLM to explain why
|
||||
each edge is relevant. With the cross-encoder now making the
|
||||
selection, this explanation would be a post-hoc fabrication — the
|
||||
LLM was not involved in the decision.
|
||||
|
||||
- *Option A: Keep LLM reasoning.* Generates natural-language
|
||||
explanations but they are not grounded in the actual selection
|
||||
process. Adds an LLM call per query.
|
||||
|
||||
- *Option B: Record cross-encoder metadata.* The cross-encoder
|
||||
already returns the matched concept and score for each selected
|
||||
edge. Use this directly as the explanation.
|
||||
|
||||
**Decision:** Option B. The cross-encoder metadata is the true
|
||||
reason the edge was selected. The `kg-edge-reasoning` prompt is
|
||||
retired.
|
||||
|
||||
#### Explainability interface update
|
||||
|
||||
The explainability interface uses a `Focus` entity containing
|
||||
`EdgeSelection` sub-entities. Each `EdgeSelection` currently
|
||||
carries an `edge` (the quoted triple) and a `reasoning` field
|
||||
(free-text LLM prose), stored as `tg:reasoning` in the
|
||||
provenance graph.
|
||||
|
||||
With the cross-encoder replacing LLM reasoning, the
|
||||
`EdgeSelection` type gains two new predicates and drops one:
|
||||
|
||||
- **Remove** `tg:reasoning` — no longer produced.
|
||||
- **Add** `tg:concept` — the concept text that produced the
|
||||
highest cross-encoder score for this edge.
|
||||
- **Add** `tg:score` — the cross-encoder relevance score.
|
||||
|
||||
This is an evolution of the existing `EdgeSelection` type, not a
|
||||
new entity type. The edge selection sub-entities currently have
|
||||
no `rdf:type` declared; a new `tg:EdgeSelection` type should be
|
||||
added so that consumers can identify them in the provenance
|
||||
graph. The `Focus` entity and its relationship to `Exploration`
|
||||
are unchanged.
|
||||
|
||||
The `Focus` entity's token-usage metadata (`tg:inToken`,
|
||||
`tg:outToken`, `tg:llmModel`) no longer applies since there is
|
||||
no LLM call. These fields are dropped from the Focus entity.
|
||||
|
||||
### Proposed pipeline
|
||||
|
||||
1. **Concept extraction.** Unchanged — LLM extracts key concepts
|
||||
from the user's query.
|
||||
|
||||
2. **Seed entity lookup.** Find seed entities via embedding
|
||||
similarity against the extracted concepts.
|
||||
|
||||
3. **Iterative hop-and-filter.** For each hop up to
|
||||
`max_path_length`:
|
||||
|
||||
a. Retrieve all edges one hop from the current frontier nodes.
|
||||
|
||||
b. Represent each edge using direction-aware text: from a
|
||||
subject node use `"{predicate} {object}"`, from an object
|
||||
node use `"{subject} {predicate}"`, from a predicate node
|
||||
use `"{subject} {object}"`.
|
||||
|
||||
c. Score edges against the extracted concepts using the
|
||||
cross-encoder service.
|
||||
|
||||
d. Select the top relevant edges. The target nodes of the
|
||||
selected edges become the frontier for the next hop.
|
||||
|
||||
4. **Document tracing.** Selected edges are traced back to source
|
||||
documents.
|
||||
|
||||
5. **Synthesis.** The `kg-synthesis` prompt generates the final
|
||||
answer from the selected edges and source document metadata.
|
||||
|
||||
### Implementation order
|
||||
|
||||
1. Cross-encoder service with full integration (base schema,
|
||||
flow service, gateway endpoint, CLI tool).
|
||||
2. GraphRAG pipeline changes (iterative hop-and-filter,
|
||||
edge representation, remove pre-filter).
|
||||
3. Explainability update (`tg:EdgeSelection` type, concept
|
||||
and score predicates, retire `tg:reasoning`).
|
||||
4. Retire `kg-edge-scoring` and `kg-edge-reasoning` prompts.
|
||||
5. Update `tg-invoke-graph-rag` and `tg-show-explain-trace`
|
||||
to display the new metadata. Use these as the main
|
||||
end-to-end test.
|
||||
6. Fix any failing unit tests, then add new tests as needed.
|
||||
7. Write guidance for UX devs to update the UI for the new
|
||||
explainability predicates.
|
||||
|
||||
## UX developer guidance
|
||||
|
||||
This section describes the changes to the explainability interface
|
||||
that affect frontend rendering of GraphRAG Focus events.
|
||||
|
||||
### What changed
|
||||
|
||||
Edge selection in GraphRAG previously used LLM-based scoring and
|
||||
reasoning. Each selected edge carried a `tg:reasoning` predicate
|
||||
with free-text explanation from the LLM. This has been replaced
|
||||
by a cross-encoder reranker that scores edges against query
|
||||
concepts. The explainability data now carries structured metadata
|
||||
instead of free text.
|
||||
|
||||
### Removed
|
||||
|
||||
- **`tg:reasoning`** is no longer emitted on edge selection
|
||||
entities in GraphRAG Focus events. UX code that reads
|
||||
`edge_sel.reasoning` will get an empty string. Remove any
|
||||
rendering that displays a "Reasoning" or "Reason" field for
|
||||
Focus edges.
|
||||
|
||||
- The **`kg-edge-scoring`**, **`kg-edge-reasoning`**, and
|
||||
**`kg-edge-selection`** prompts are retired. Any UX that
|
||||
references these prompt names should be cleaned up.
|
||||
|
||||
### Added
|
||||
|
||||
Each edge selection entity within a Focus event now has three
|
||||
new properties:
|
||||
|
||||
| RDF predicate | API field | Type | Description |
|
||||
|---|---|---|---|
|
||||
| `rdf:type tg:EdgeSelection` | (type check) | — | Each edge selection entity is now explicitly typed |
|
||||
| `tg:concept` | `edge_sel.concept` | `str` | The query concept that matched this edge |
|
||||
| `tg:score` | `edge_sel.score` | `float` or `None` | Cross-encoder relevance score (0.0–1.0) |
|
||||
|
||||
The `tg:edge` predicate (RDF-star quoted triple) is unchanged.
|
||||
|
||||
### How to render
|
||||
|
||||
The recommended rendering for each selected edge in a Focus event:
|
||||
|
||||
```
|
||||
Edge: (subject_label, predicate_label, object_label)
|
||||
Concept: <concept> Score: <score formatted to 4 decimal places>
|
||||
```
|
||||
|
||||
Scores near 1.0 indicate high relevance; scores near 0.0 indicate
|
||||
low relevance. UX could use the score to drive visual indicators
|
||||
such as colour intensity or a relevance bar.
|
||||
|
||||
Edges are not returned in score order — they arrive in traversal
|
||||
order across hops. If the UX wants to display edges ranked by
|
||||
relevance, sort by `edge_sel.score` descending.
|
||||
|
||||
### API classes (Python)
|
||||
|
||||
The `EdgeSelection` dataclass in `trustgraph.api.explainability`
|
||||
has these fields:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class EdgeSelection:
|
||||
uri: str
|
||||
edge: Optional[Dict[str, str]] # {"s": ..., "p": ..., "o": ...}
|
||||
reasoning: str = "" # Legacy, always empty for new traces
|
||||
concept: str = "" # Query concept that matched
|
||||
score: Optional[float] = None # Cross-encoder relevance score
|
||||
```
|
||||
|
||||
These are populated when calling
|
||||
`ExplainabilityClient.fetch_focus_with_edges()` or when parsing
|
||||
inline provenance triples from the streaming response.
|
||||
|
||||
### WebSocket response format
|
||||
|
||||
For inline explainability via the streaming WebSocket, Focus events
|
||||
arrive as `message_type: "explain"` responses. The `explain_triples`
|
||||
array contains the edge selection triples. The relevant predicates
|
||||
in wire format are:
|
||||
|
||||
```json
|
||||
{"s": {"t": "i", "i": "<edge_sel_uri>"},
|
||||
"p": {"t": "i", "i": "https://trustgraph.ai/ns/concept"},
|
||||
"o": {"t": "l", "v": "flyby event"}}
|
||||
|
||||
{"s": {"t": "i", "i": "<edge_sel_uri>"},
|
||||
"p": {"t": "i", "i": "https://trustgraph.ai/ns/score"},
|
||||
"o": {"t": "l", "v": "0.9962"}}
|
||||
```
|
||||
|
||||
Note that `tg:score` is transmitted as a string literal and must
|
||||
be parsed to a float on the client side.
|
||||
|
||||
### Exploration event
|
||||
|
||||
The Exploration event's `edge_count` field now reports the number
|
||||
of edges selected by the cross-encoder across all hops (previously
|
||||
it reported the total number of edges retrieved before filtering).
|
||||
The `entities` list continues to report the seed entities found
|
||||
by vector search.
|
||||
2603
install_trustgraph.sh
Normal file
2603
install_trustgraph.sh
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -95,10 +95,6 @@ class TestGraphRagIntegration:
|
|||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
|
|
@ -113,14 +109,22 @@ class TestGraphRagIntegration:
|
|||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_reranker_client(self):
|
||||
"""Mock reranker client for cross-encoder edge filtering"""
|
||||
client = AsyncMock()
|
||||
client.rerank.return_value = []
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||
mock_triples_client, mock_prompt_client):
|
||||
mock_triples_client, mock_reranker_client, mock_prompt_client):
|
||||
"""Create GraphRag instance with mocked dependencies"""
|
||||
return GraphRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
prompt_client=mock_prompt_client,
|
||||
verbose=True
|
||||
)
|
||||
|
|
@ -167,8 +171,8 @@ class TestGraphRagIntegration:
|
|||
# 3. Should query triples to build knowledge subgraph
|
||||
assert mock_triples_client.query_stream.call_count > 0
|
||||
|
||||
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 4
|
||||
# 4. Should call prompt twice (extract-concepts + synthesis)
|
||||
assert mock_prompt_client.prompt.call_count == 2
|
||||
|
||||
# Verify final response
|
||||
response, usage = response
|
||||
|
|
|
|||
|
|
@ -63,11 +63,6 @@ class TestGraphRagStreaming:
|
|||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_id == "kg-edge-scoring":
|
||||
# Edge scoring returns JSONL with IDs and scores
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
|
||||
elif prompt_id == "kg-edge-reasoning":
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
|
|
@ -88,14 +83,23 @@ class TestGraphRagStreaming:
|
|||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_reranker_client(self):
|
||||
"""Mock reranker client for cross-encoder edge filtering"""
|
||||
client = AsyncMock()
|
||||
client.rerank.return_value = []
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||
mock_triples_client, mock_streaming_prompt_client):
|
||||
mock_triples_client, mock_reranker_client,
|
||||
mock_streaming_prompt_client):
|
||||
"""Create GraphRag instance with streaming support"""
|
||||
return GraphRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
prompt_client=mock_streaming_prompt_client,
|
||||
verbose=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class TestGraphRagStreamingProtocol:
|
|||
client = AsyncMock()
|
||||
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
|
|
@ -63,14 +63,23 @@ class TestGraphRagStreamingProtocol:
|
|||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_reranker_client(self):
|
||||
"""Mock reranker client for cross-encoder edge filtering"""
|
||||
client = AsyncMock()
|
||||
client.rerank.return_value = []
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
|
||||
mock_triples_client, mock_streaming_prompt_client):
|
||||
mock_triples_client, mock_reranker_client,
|
||||
mock_streaming_prompt_client):
|
||||
"""Create GraphRag instance with mocked dependencies"""
|
||||
return GraphRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
prompt_client=mock_streaming_prompt_client,
|
||||
verbose=False
|
||||
)
|
||||
|
|
@ -327,7 +336,7 @@ class TestStreamingProtocolEdgeCases:
|
|||
client = AsyncMock()
|
||||
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
|
|
@ -342,10 +351,14 @@ class TestStreamingProtocolEdgeCases:
|
|||
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
mock_reranker = AsyncMock()
|
||||
mock_reranker.rerank.return_value = []
|
||||
|
||||
rag = GraphRag(
|
||||
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
|
||||
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
|
||||
reranker_client=mock_reranker,
|
||||
prompt_client=client,
|
||||
verbose=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,11 +15,20 @@ from openai.types.chat.chat_completion import Choice
|
|||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from trustgraph.model.text_completion.openai.llm import Processor
|
||||
from trustgraph.model.text_completion.openai.variants import get_variant
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
from trustgraph.base import LlmResult
|
||||
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
|
||||
|
||||
def _wire_variant(processor):
|
||||
"""Attach variant methods to a MagicMock processor."""
|
||||
processor.variant = get_variant("openai")
|
||||
processor.thinking = "off"
|
||||
processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor)
|
||||
processor._extract_content = Processor._extract_content.__get__(processor, Processor)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestTextCompletionIntegration:
|
||||
"""Integration tests for OpenAI text completion service coordination"""
|
||||
|
|
@ -66,6 +75,7 @@ class TestTextCompletionIntegration:
|
|||
|
||||
# Add the actual generate_content method from Processor class
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
_wire_variant(processor)
|
||||
|
||||
return processor
|
||||
|
||||
|
|
@ -119,6 +129,7 @@ class TestTextCompletionIntegration:
|
|||
|
||||
# Add the actual generate_content method
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act
|
||||
result = await processor.generate_content("System prompt", "User prompt")
|
||||
|
|
@ -247,6 +258,7 @@ class TestTextCompletionIntegration:
|
|||
processor.max_output = processor_config["max_output"]
|
||||
processor.openai = mock_openai_client
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
_wire_variant(processor)
|
||||
processors.append(processor)
|
||||
|
||||
# Simulate multiple concurrent requests
|
||||
|
|
@ -354,6 +366,7 @@ class TestTextCompletionIntegration:
|
|||
processor.max_output = 2048
|
||||
processor.openai = mock_openai_client
|
||||
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act
|
||||
await processor.generate_content("System prompt", "User prompt")
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionChunk
|
|||
from openai.types.chat.chat_completion_chunk import Choice as StreamChoice, ChoiceDelta
|
||||
|
||||
from trustgraph.model.text_completion.openai.llm import Processor
|
||||
from trustgraph.model.text_completion.openai.variants import get_variant
|
||||
from trustgraph.base import LlmChunk
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
|
|
@ -18,6 +19,14 @@ from tests.utils.streaming_assertions import (
|
|||
)
|
||||
|
||||
|
||||
def _wire_variant(processor):
|
||||
"""Attach variant methods to a MagicMock processor."""
|
||||
processor.variant = get_variant("openai")
|
||||
processor.thinking = "off"
|
||||
processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor)
|
||||
processor._extract_content = Processor._extract_content.__get__(processor, Processor)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestTextCompletionStreaming:
|
||||
"""Integration tests for Text Completion streaming"""
|
||||
|
|
@ -69,6 +78,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
return processor
|
||||
|
||||
|
|
@ -190,6 +200,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act
|
||||
chunks = []
|
||||
|
|
@ -223,6 +234,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act
|
||||
chunks = []
|
||||
|
|
@ -258,6 +270,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
|
|
@ -295,6 +308,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
# Act
|
||||
chunks = []
|
||||
|
|
@ -318,6 +332,7 @@ class TestTextCompletionStreaming:
|
|||
processor.generate_content_stream = Processor.generate_content_stream.__get__(
|
||||
processor, Processor
|
||||
)
|
||||
_wire_variant(processor)
|
||||
|
||||
system_prompt = "You are an expert."
|
||||
user_prompt = "Explain quantum physics."
|
||||
|
|
|
|||
|
|
@ -195,38 +195,6 @@ class TestPromptClientStreamingCallback:
|
|||
assert callback.call_args_list[0] == call("test", False)
|
||||
assert callback.call_args_list[1] == call("", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
|
||||
"""Test that kg_prompt correctly passes streaming parameters"""
|
||||
# Arrange
|
||||
async def mock_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
responses = [
|
||||
PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
|
||||
PromptResponse(text="", object=None, error=None, end_of_stream=True),
|
||||
]
|
||||
for resp in responses:
|
||||
should_stop = await recipient(resp)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
prompt_client.request = mock_request
|
||||
|
||||
callback = AsyncMock()
|
||||
|
||||
# Act
|
||||
await prompt_client.kg_prompt(
|
||||
query="What is machine learning?",
|
||||
kg=[("subject", "predicate", "object")],
|
||||
streaming=True,
|
||||
chunk_callback=callback
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert callback.call_count == 2
|
||||
assert callback.call_args_list[0] == call("Answer", False)
|
||||
assert callback.call_args_list[1] == call("", True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
|
||||
"""Test that document_prompt correctly passes streaming parameters"""
|
||||
|
|
|
|||
54
tests/unit/test_bootstrap/test_default_flow_start.py
Normal file
54
tests/unit/test_bootstrap/test_default_flow_start.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
Unit tests for trustgraph.bootstrap.initialisers.DefaultFlowStart
|
||||
|
||||
Verifies the list/start timeouts are configurable and that the
|
||||
configured values actually reach the flow-client request calls.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.bootstrap.initialisers.default_flow_start import (
|
||||
DefaultFlowStart,
|
||||
)
|
||||
|
||||
|
||||
def test_default_timeouts():
|
||||
init = DefaultFlowStart(blueprint="bp")
|
||||
assert init.list_timeout == 10
|
||||
assert init.start_timeout == 30
|
||||
|
||||
|
||||
def test_timeout_overrides_are_stored():
|
||||
init = DefaultFlowStart(blueprint="bp", list_timeout=5, start_timeout=99)
|
||||
assert init.list_timeout == 5
|
||||
assert init.start_timeout == 99
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_forwards_configured_timeouts():
|
||||
init = DefaultFlowStart(blueprint="bp", list_timeout=5, start_timeout=99)
|
||||
|
||||
# Flow client: list-flows returns no error + empty flow list,
|
||||
# start-flow returns no error.
|
||||
flow = MagicMock()
|
||||
flow.start = AsyncMock()
|
||||
flow.stop = AsyncMock()
|
||||
flow.request = AsyncMock(side_effect=[
|
||||
MagicMock(error=None, flow_ids=[]), # list-flows response
|
||||
MagicMock(error=None), # start-flow response
|
||||
])
|
||||
|
||||
# Context: workspace "default" exists, hands back our mock flow client.
|
||||
ctx = MagicMock()
|
||||
ctx.logger = MagicMock()
|
||||
ctx.config.keys = AsyncMock(return_value=["default"])
|
||||
ctx.make_flow_client = MagicMock(return_value=flow)
|
||||
|
||||
await init.run(ctx, None, "v1")
|
||||
|
||||
calls = flow.request.call_args_list
|
||||
assert len(calls) == 2
|
||||
assert calls[0].kwargs["timeout"] == 5
|
||||
assert calls[1].kwargs["timeout"] == 99
|
||||
13
tests/unit/test_bootstrap/test_workspace_init.py
Normal file
13
tests/unit/test_bootstrap/test_workspace_init.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
"""Unit tests for trustgraph.bootstrap.initialisers.WorkspaceInit."""
|
||||
|
||||
from trustgraph.bootstrap.initialisers.workspace_init import WorkspaceInit
|
||||
|
||||
|
||||
def test_default_iam_timeout():
|
||||
init = WorkspaceInit()
|
||||
assert init.iam_timeout == 10
|
||||
|
||||
|
||||
def test_iam_timeout_override_is_stored():
|
||||
init = WorkspaceInit(iam_timeout=42)
|
||||
assert init.iam_timeout == 42
|
||||
|
|
@ -129,6 +129,9 @@ class TestBatchTripleQueries:
|
|||
|
||||
# 3 queries, alternating results
|
||||
assert len(result) == 3
|
||||
# Each result is a (triple, direction) tuple
|
||||
for triple, direction in result:
|
||||
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_in_one_query_does_not_block_others(self):
|
||||
|
|
@ -153,6 +156,8 @@ class TestBatchTripleQueries:
|
|||
|
||||
# 3 queries: 2 succeed, 1 fails → 2 triples
|
||||
assert len(result) == 2
|
||||
for triple, direction in result:
|
||||
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_results_filtered(self):
|
||||
|
|
@ -176,6 +181,8 @@ class TestBatchTripleQueries:
|
|||
|
||||
# 3 queries: 1 returns None, 2 return triples
|
||||
assert len(result) == 2
|
||||
for triple, direction in result:
|
||||
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_entities_no_queries(self):
|
||||
|
|
@ -220,6 +227,80 @@ class TestBatchTripleQueries:
|
|||
assert calls[2].kwargs["p"] is None
|
||||
assert calls[2].kwargs["o"] == "ent-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directions_assigned_correctly(self):
|
||||
"""Each query position should produce the correct direction tag."""
|
||||
triple = _make_triple("s", "p", "o")
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def one_triple_each(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return [triple]
|
||||
|
||||
client = AsyncMock()
|
||||
client.query_stream = one_triple_each
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
result = await query.execute_batch_triple_queries(
|
||||
["e1"], limit_per_entity=10
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
# Order matches query order: s-position, p-position, o-position
|
||||
assert result[0][1] == Query.FROM_S
|
||||
assert result[1][1] == Query.FROM_P
|
||||
assert result[2][1] == Query.FROM_O
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directions_correct_for_multiple_entities(self):
|
||||
"""Direction tags cycle correctly across multiple entities."""
|
||||
triple = _make_triple("s", "p", "o")
|
||||
client = AsyncMock()
|
||||
client.query_stream = AsyncMock(return_value=[triple])
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
result = await query.execute_batch_triple_queries(
|
||||
["e1", "e2"], limit_per_entity=10
|
||||
)
|
||||
|
||||
assert len(result) == 6
|
||||
expected_directions = [
|
||||
Query.FROM_S, Query.FROM_P, Query.FROM_O,
|
||||
Query.FROM_S, Query.FROM_P, Query.FROM_O,
|
||||
]
|
||||
for (_, direction), expected in zip(result, expected_directions):
|
||||
assert direction == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direction_preserved_with_multiple_triples(self):
|
||||
"""All triples from one query share the same direction."""
|
||||
t1 = _make_triple("a", "p1", "b")
|
||||
t2 = _make_triple("a", "p2", "c")
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def multi_results(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return [t1, t2]
|
||||
return []
|
||||
|
||||
client = AsyncMock()
|
||||
client.query_stream = multi_results
|
||||
query = _make_query(triples_client=client)
|
||||
|
||||
result = await query.execute_batch_triple_queries(
|
||||
["e1"], limit_per_entity=10
|
||||
)
|
||||
|
||||
# First query (FROM_S) returns 2 triples, both should be FROM_S
|
||||
assert len(result) == 2
|
||||
assert result[0] == (t1, Query.FROM_S)
|
||||
assert result[1] == (t2, Query.FROM_S)
|
||||
|
||||
|
||||
class TestLRUCacheWithTTL:
|
||||
|
||||
|
|
|
|||
|
|
@ -86,19 +86,19 @@ class TestVerifyJwtEddsa:
|
|||
def test_valid_jwt_passes(self):
|
||||
priv, pub = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"sub": "user-1", "default_workspace": "default",
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 60,
|
||||
}
|
||||
token = sign_jwt(priv, claims)
|
||||
got = _verify_jwt_eddsa(token, pub)
|
||||
assert got["sub"] == "user-1"
|
||||
assert got["workspace"] == "default"
|
||||
assert got["default_workspace"] == "default"
|
||||
|
||||
def test_expired_jwt_rejected(self):
|
||||
priv, pub = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"sub": "user-1", "default_workspace": "default",
|
||||
"iat": int(time.time()) - 3600,
|
||||
"exp": int(time.time()) - 1,
|
||||
}
|
||||
|
|
@ -110,7 +110,7 @@ class TestVerifyJwtEddsa:
|
|||
priv_a, _ = make_keypair()
|
||||
_, pub_b = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"sub": "user-1", "default_workspace": "default",
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 60,
|
||||
}
|
||||
|
|
@ -130,7 +130,7 @@ class TestVerifyJwtEddsa:
|
|||
# since we expect it to bail before verifying.
|
||||
header = {"alg": "HS256", "typ": "JWT", "kid": "x"}
|
||||
payload = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"sub": "user-1", "default_workspace": "default",
|
||||
"iat": int(time.time()), "exp": int(time.time()) + 60,
|
||||
}
|
||||
h = _b64url(json.dumps(header, separators=(",", ":")).encode())
|
||||
|
|
@ -148,11 +148,11 @@ class TestIdentity:
|
|||
|
||||
def test_fields(self):
|
||||
i = Identity(
|
||||
handle="u", workspace="w",
|
||||
handle="u", default_workspace="w",
|
||||
principal_id="u", source="api-key",
|
||||
)
|
||||
assert i.handle == "u"
|
||||
assert i.workspace == "w"
|
||||
assert i.default_workspace == "w"
|
||||
assert i.principal_id == "u"
|
||||
assert i.source == "api-key"
|
||||
|
||||
|
|
@ -208,7 +208,7 @@ class TestIamAuthDispatch:
|
|||
async def test_valid_jwt_resolves_to_identity(self):
|
||||
priv, pub = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"sub": "user-1", "default_workspace": "default",
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + 60,
|
||||
}
|
||||
|
|
@ -221,7 +221,7 @@ class TestIamAuthDispatch:
|
|||
make_request(f"Bearer {token}")
|
||||
)
|
||||
assert ident.handle == "user-1"
|
||||
assert ident.workspace == "default"
|
||||
assert ident.default_workspace == "default"
|
||||
assert ident.principal_id == "user-1"
|
||||
assert ident.source == "jwt"
|
||||
|
||||
|
|
@ -231,7 +231,7 @@ class TestIamAuthDispatch:
|
|||
# must not validate — even ones that would otherwise pass.
|
||||
priv, _ = make_keypair()
|
||||
claims = {
|
||||
"sub": "user-1", "workspace": "default",
|
||||
"sub": "user-1", "default_workspace": "default",
|
||||
"iat": int(time.time()), "exp": int(time.time()) + 60,
|
||||
}
|
||||
token = sign_jwt(priv, claims)
|
||||
|
|
@ -259,7 +259,7 @@ class TestIamAuthDispatch:
|
|||
make_request("Bearer tg_testkey")
|
||||
)
|
||||
assert ident.handle == "user-xyz"
|
||||
assert ident.workspace == "default"
|
||||
assert ident.default_workspace == "default"
|
||||
assert ident.principal_id == "user-xyz"
|
||||
assert ident.source == "api-key"
|
||||
|
||||
|
|
@ -338,9 +338,9 @@ class TestAuthorise:
|
|||
decision for the regime's TTL (clamped above), and raises 403
|
||||
on deny / 401 on regime error (fail closed)."""
|
||||
|
||||
def _make_identity(self, handle="u-1", workspace="default"):
|
||||
def _make_identity(self, handle="u-1", default_workspace="default"):
|
||||
return Identity(
|
||||
handle=handle, workspace=workspace,
|
||||
handle=handle, default_workspace=default_workspace,
|
||||
principal_id=handle, source="api-key",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,11 +25,11 @@ from trustgraph.gateway.capabilities import (
|
|||
|
||||
class _Identity:
|
||||
"""Stand-in for auth.Identity — under the IAM contract it has
|
||||
just ``handle``, ``workspace``, ``principal_id``, ``source``."""
|
||||
just ``handle``, ``default_workspace``, ``principal_id``, ``source``."""
|
||||
|
||||
def __init__(self, handle="user-1", workspace="default"):
|
||||
def __init__(self, handle="user-1", default_workspace="default"):
|
||||
self.handle = handle
|
||||
self.workspace = workspace
|
||||
self.default_workspace = default_workspace
|
||||
self.principal_id = handle
|
||||
self.source = "api-key"
|
||||
|
||||
|
|
@ -105,14 +105,14 @@ class TestEnforceWorkspace:
|
|||
async def test_default_fills_from_identity(self):
|
||||
data = {"operation": "x"}
|
||||
auth = _allow_auth()
|
||||
await enforce_workspace(data, _Identity(workspace="default"), auth)
|
||||
await enforce_workspace(data, _Identity(default_workspace="default"), auth)
|
||||
assert data["workspace"] == "default"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caller_supplied_workspace_kept(self):
|
||||
data = {"workspace": "acme", "operation": "x"}
|
||||
auth = _allow_auth()
|
||||
await enforce_workspace(data, _Identity(workspace="default"), auth)
|
||||
await enforce_workspace(data, _Identity(default_workspace="default"), auth)
|
||||
assert data["workspace"] == "acme"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ TEST_CAP = "graph:write"
|
|||
def _valid_identity():
|
||||
return Identity(
|
||||
handle="test-user",
|
||||
workspace="default",
|
||||
default_workspace="default",
|
||||
principal_id="test-user",
|
||||
source="api-key",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class TestAuthenticateAnonymous:
|
|||
)
|
||||
assert resp.error is None
|
||||
assert resp.resolved_user_id == "anon"
|
||||
assert resp.resolved_workspace == "ws"
|
||||
assert resp.resolved_default_workspace == "ws"
|
||||
assert "admin" in list(resp.resolved_roles)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -44,7 +44,7 @@ class TestAuthenticateAnonymous:
|
|||
_make_request(operation="authenticate-anonymous")
|
||||
)
|
||||
assert resp.resolved_user_id == "dev-user"
|
||||
assert resp.resolved_workspace == "dev-ws"
|
||||
assert resp.resolved_default_workspace == "dev-ws"
|
||||
|
||||
|
||||
class TestResolveApiKey:
|
||||
|
|
@ -57,7 +57,7 @@ class TestResolveApiKey:
|
|||
)
|
||||
assert resp.error is None
|
||||
assert resp.resolved_user_id == "anonymous"
|
||||
assert resp.resolved_workspace == "default"
|
||||
assert resp.resolved_default_workspace == "default"
|
||||
|
||||
|
||||
class TestAuthorise:
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ class TestGraphRagDagStructure:
|
|||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
reranker_client = AsyncMock()
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
graph_embeddings_client.query.return_value = [
|
||||
|
|
@ -121,27 +122,22 @@ class TestGraphRagDagStructure:
|
|||
]
|
||||
triples_client.query.return_value = []
|
||||
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.95
|
||||
reranker_client.rerank.return_value = [result]
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="concept")
|
||||
elif template_id == "kg-edge-scoring":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "score": 10} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text="Answer.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
return (prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, reranker_client)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self, mock_clients):
|
||||
|
|
@ -152,7 +148,7 @@ class TestGraphRagDagStructure:
|
|||
events.append({"explain_id": explain_id, "triples": triples})
|
||||
|
||||
await rag.query(
|
||||
query="test", explain_callback=explain_cb, edge_score_limit=0,
|
||||
query="test", explain_callback=explain_cb,
|
||||
)
|
||||
|
||||
dag = _collect_events(events)
|
||||
|
|
|
|||
|
|
@ -101,27 +101,27 @@ class TestQuery:
|
|||
assert query.rag == mock_rag
|
||||
assert query.collection == "test_collection"
|
||||
assert query.verbose is False
|
||||
assert query.doc_limit == 20 # Default value
|
||||
assert query.fetch_limit == 20 # Default value
|
||||
|
||||
def test_query_initialization_with_custom_doc_limit(self):
|
||||
"""Test Query initialization with custom doc_limit"""
|
||||
def test_query_initialization_with_custom_fetch_limit(self):
|
||||
"""Test Query initialization with custom fetch_limit"""
|
||||
# Create mock DocumentRag
|
||||
mock_rag = MagicMock()
|
||||
|
||||
# Initialize Query with custom doc_limit
|
||||
# Initialize Query with custom fetch_limit
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
workspace="test_workspace",
|
||||
collection="custom_collection",
|
||||
verbose=True,
|
||||
doc_limit=50
|
||||
fetch_limit=50
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert query.rag == mock_rag
|
||||
assert query.collection == "custom_collection"
|
||||
assert query.verbose is True
|
||||
assert query.doc_limit == 50
|
||||
assert query.fetch_limit == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_concepts(self):
|
||||
|
|
@ -224,7 +224,7 @@ class TestQuery:
|
|||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=15
|
||||
fetch_limit=15
|
||||
)
|
||||
|
||||
# Call get_docs with concepts list
|
||||
|
|
@ -377,7 +377,7 @@ class TestQuery:
|
|||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=True,
|
||||
doc_limit=5
|
||||
fetch_limit=5
|
||||
)
|
||||
|
||||
# Call get_docs with concepts
|
||||
|
|
@ -615,7 +615,7 @@ class TestQuery:
|
|||
workspace="test_workspace",
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
doc_limit=10
|
||||
fetch_limit=10
|
||||
)
|
||||
|
||||
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])
|
||||
|
|
|
|||
|
|
@ -0,0 +1,92 @@
|
|||
from trustgraph.retrieval.document_rag.rerank import (
|
||||
RerankCandidate, normalize_candidate_scores, mmr_select,
|
||||
_pair_diversity_penalty
|
||||
)
|
||||
|
||||
def candidate(index, chunk_id, text, score):
|
||||
return RerankCandidate(
|
||||
index=index,
|
||||
chunk_id=chunk_id,
|
||||
text=text,
|
||||
reranker_score=score,
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_candidate_scores_min_max_scales_raw_scores():
|
||||
candidates = [
|
||||
candidate(0, "a", "alpha", -2.0),
|
||||
candidate(1, "b", "beta", 0.0),
|
||||
candidate(2, "c", "gamma", 4.0),
|
||||
]
|
||||
|
||||
normalized = normalize_candidate_scores(candidates)
|
||||
|
||||
assert normalized[0].normalized_score == 0.0
|
||||
assert normalized[1].normalized_score == 1.0 / 3.0
|
||||
assert normalized[2].normalized_score == 1.0
|
||||
|
||||
|
||||
def test_normalize_candidate_scores_handles_equal_scores():
|
||||
candidates = [
|
||||
candidate(0, "a", "alpha", 3.0),
|
||||
candidate(1, "b", "beta", 3.0),
|
||||
candidate(2, "c", "gamma", 3.0),
|
||||
]
|
||||
|
||||
normalized = normalize_candidate_scores(candidates)
|
||||
|
||||
assert [c.normalized_score for c in normalized] == [0.5, 0.5, 0.5]
|
||||
|
||||
|
||||
def test_mmr_select_limits_results():
|
||||
candidates = [
|
||||
candidate(0, "a", "alpha policy", 0.9),
|
||||
candidate(1, "b", "beta refund", 0.8),
|
||||
candidate(2, "c", "gamma shipping", 0.7),
|
||||
]
|
||||
|
||||
selected = mmr_select(candidates, limit=2)
|
||||
|
||||
assert len(selected) == 2
|
||||
|
||||
|
||||
def test_mmr_select_prefers_highest_reranker_score_first():
|
||||
candidates = [
|
||||
candidate(0, "a", "weakly relevant text", 0.1),
|
||||
candidate(1, "b", "strongly relevant answer", 10.0),
|
||||
candidate(2, "c", "medium relevant text", 5.0),
|
||||
]
|
||||
|
||||
selected = mmr_select(candidates, limit=1)
|
||||
|
||||
assert selected[0].chunk_id == "b"
|
||||
|
||||
|
||||
def test_mmr_select_penalizes_near_duplicate_chunks():
|
||||
candidates = [
|
||||
candidate(0, "a", "apple banana fruit return policy", 1.00),
|
||||
candidate(1, "b", "apple banana fruit return policy duplicate", 0.95),
|
||||
candidate(2, "c", "engine motor vehicle warranty", 0.90),
|
||||
]
|
||||
|
||||
selected = mmr_select(
|
||||
candidates,
|
||||
limit=2,
|
||||
lambda_mult=0.2,
|
||||
token_overlap_weight=1.0,
|
||||
)
|
||||
|
||||
assert [c.chunk_id for c in selected] == ["a", "c"]
|
||||
|
||||
|
||||
def test_pair_diversity_penalty_is_clamped():
|
||||
left = candidate(0, "a", "same same same", 1.0)
|
||||
right = candidate(1, "b", "same same same", 0.9)
|
||||
|
||||
penalty = _pair_diversity_penalty(
|
||||
left,
|
||||
right,
|
||||
token_overlap_weight=10.0,
|
||||
)
|
||||
|
||||
assert penalty == 1.0
|
||||
550
tests/unit/test_retrieval/test_document_rag_rerank.py
Normal file
550
tests/unit/test_retrieval/test_document_rag_rerank.py
Normal file
|
|
@ -0,0 +1,550 @@
|
|||
"""
|
||||
Tests for the optional cross-encoder reranking pass in DocumentRag.query().
|
||||
|
||||
Two behaviours are covered:
|
||||
|
||||
1. No-op: when no reranker_client is wired (the default), query() must feed
|
||||
the LLM the exact same chunks, in the same order, that retrieval produced
|
||||
- byte-identical to the pre-reranker behaviour - and must NOT emit a
|
||||
chunk-selection provenance event.
|
||||
|
||||
2. Rerank: when a reranker_client is wired, the retrieved chunks are reordered
|
||||
and truncated according to the reranker's results, the LLM receives the
|
||||
reranked top-N, and a tg:ChunkSelection (focus) provenance event is emitted
|
||||
carrying the per-surviving-chunk scores and chunk references.
|
||||
|
||||
These are pure orchestration tests - the reranker is a stub, so there is no
|
||||
torch / network dependency.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.base import PromptResult
|
||||
from trustgraph.schema import RerankerResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_WAS_DERIVED_FROM,
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_CHUNK_SELECTION, TG_SELECTED_CHUNK, TG_SCORE, TG_DOCUMENT,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_triple(triples, predicate, subject=None):
|
||||
for t in triples:
|
||||
if t.p.iri == predicate:
|
||||
if subject is None or t.s.iri == subject:
|
||||
return t
|
||||
return None
|
||||
|
||||
|
||||
def find_triples(triples, predicate, subject=None):
|
||||
return [
|
||||
t for t in triples
|
||||
if t.p.iri == predicate
|
||||
and (subject is None or t.s.iri == subject)
|
||||
]
|
||||
|
||||
|
||||
def has_type(triples, subject, rdf_type):
|
||||
return any(
|
||||
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
|
||||
for t in triples
|
||||
)
|
||||
|
||||
|
||||
def derived_from(triples, subject):
|
||||
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
|
||||
return t.o.iri if t else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkMatch:
|
||||
"""Mimics the result from doc_embeddings_client.query()."""
|
||||
chunk_id: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: three retrievable chunks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
|
||||
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
|
||||
CHUNK_C = "urn:chunk:policy-doc-1:chunk-2"
|
||||
|
||||
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
|
||||
CHUNK_B_CONTENT = "Our stores are open from 9am to 5pm on weekdays."
|
||||
CHUNK_C_CONTENT = "Refunds are processed to the original payment method."
|
||||
|
||||
# Retrieval (post-dedupe) order is A, B, C.
|
||||
ORDERED_CONTENT = [CHUNK_A_CONTENT, CHUNK_B_CONTENT, CHUNK_C_CONTENT]
|
||||
ORDERED_CHUNK_IDS = [CHUNK_A, CHUNK_B, CHUNK_C]
|
||||
|
||||
|
||||
def build_mock_clients():
|
||||
"""
|
||||
Build mock subsidiary clients for a document-rag query returning three
|
||||
distinct chunks (A, B, C) in that order.
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
doc_embeddings_client = AsyncMock()
|
||||
fetch_chunk = AsyncMock()
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="return policy\nrefund")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
# Each concept query returns the same three chunks; dedupe keeps A, B, C.
|
||||
doc_embeddings_client.query.return_value = [
|
||||
ChunkMatch(chunk_id=CHUNK_A),
|
||||
ChunkMatch(chunk_id=CHUNK_B),
|
||||
ChunkMatch(chunk_id=CHUNK_C),
|
||||
]
|
||||
|
||||
async def mock_fetch(chunk_id):
|
||||
return {
|
||||
CHUNK_A: CHUNK_A_CONTENT,
|
||||
CHUNK_B: CHUNK_B_CONTENT,
|
||||
CHUNK_C: CHUNK_C_CONTENT,
|
||||
}[chunk_id]
|
||||
|
||||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Items can be returned within 30 days for a full refund.",
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
||||
|
||||
class StubReranker:
|
||||
"""
|
||||
Stub reranker_client mirroring RerankerClient.rerank(): returns a fixed,
|
||||
pre-sorted, truncated list of RerankerResult - exactly the contract the
|
||||
flashrank service guarantees (sorted desc by score, truncated to limit).
|
||||
"""
|
||||
|
||||
def __init__(self, results):
|
||||
self._results = results
|
||||
self.calls = []
|
||||
|
||||
async def rerank(self, queries, documents, limit=10, timeout=300):
|
||||
self.calls.append(
|
||||
{"queries": queries, "documents": documents, "limit": limit}
|
||||
)
|
||||
return self._results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. No-op: reranker_client=None must not change anything
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRerankNoOp:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_documents_passed_to_llm_are_unchanged(self):
|
||||
"""
|
||||
With no reranker wired, document_prompt must receive the retrieved
|
||||
chunks in the original order and length.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients) # reranker_client defaults to None
|
||||
|
||||
await rag.query(query="What is the return policy?")
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
assert passed_docs == ORDERED_CONTENT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_chunk_selection_event_emitted(self):
|
||||
"""
|
||||
Without a reranker, the provenance chain is the original 4 stages:
|
||||
question, grounding, exploration, synthesis - no focus stage.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
assert len(events) == 4
|
||||
types = [
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_SYNTHESIS,
|
||||
]
|
||||
for i, expected in enumerate(types):
|
||||
assert has_type(events[i]["triples"], events[i]["explain_id"], expected)
|
||||
|
||||
# No chunk-selection entity anywhere.
|
||||
for e in events:
|
||||
assert not any(
|
||||
t.o.iri == TG_CHUNK_SELECTION
|
||||
for t in e["triples"]
|
||||
if t.p.iri == RDF_TYPE
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_derives_from_exploration_when_no_rerank(self):
|
||||
"""
|
||||
No-op lineage is unchanged: synthesis derives from exploration
|
||||
(there is no focus stage). Guards the conditional synthesis parent.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
# events: question, grounding, exploration, synthesis
|
||||
exp_uri = events[2]["explain_id"]
|
||||
syn_event = events[3]
|
||||
assert derived_from(syn_event["triples"], syn_event["explain_id"]) == exp_uri
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Rerank: reorder + truncate + provenance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRerankActive:
|
||||
|
||||
def _reranker_keeping_C_then_A(self):
|
||||
# Reranker says chunk index 2 (C) is best, then index 0 (A); B dropped.
|
||||
# Pre-sorted desc by score and truncated to limit, per the contract.
|
||||
return StubReranker([
|
||||
RerankerResult(document_id="2", query_id="0", score=0.95),
|
||||
RerankerResult(document_id="0", query_id="0", score=0.42),
|
||||
])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_documents_reordered_and_truncated(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?")
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
assert passed_docs == [CHUNK_C_CONTENT, CHUNK_A_CONTENT]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reranker_called_with_single_query_and_all_docs(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=2)
|
||||
|
||||
assert len(reranker.calls) == 1
|
||||
c = reranker.calls[0]
|
||||
assert c["queries"] == [{"id": "0", "text": "What is the return policy?"}]
|
||||
assert c["documents"] == [
|
||||
{"id": "0", "text": CHUNK_A_CONTENT},
|
||||
{"id": "1", "text": CHUNK_B_CONTENT},
|
||||
{"id": "2", "text": CHUNK_C_CONTENT},
|
||||
]
|
||||
# The rerank narrows down to the final doc_limit, NOT fetch_limit
|
||||
# (fetch_limit is the over-fetched candidate pool size).
|
||||
assert c["limit"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_explicit_fetch_limit_over_fetches_then_narrows(self):
|
||||
"""
|
||||
Semantic guard for the value of reranking AND the maintainer's two-limit
|
||||
contract: an explicit fetch_limit makes retrieval OVER-FETCH a wider
|
||||
candidate pool so the cross-encoder can surface chunks the bi-encoder
|
||||
ranked outside the final doc_limit, then the rerank narrows the pool back
|
||||
down to doc_limit. The fetch_limit is honoured directly (caller controls
|
||||
how hard the reranker works), not overridden by any heuristic.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
# Candidate pool (fetch_limit=60) >> final doc_limit (6).
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?", doc_limit=6, fetch_limit=60,
|
||||
)
|
||||
|
||||
# Over-fetch: the embeddings store is queried with the fetch_limit
|
||||
# budget (60 // 2 concept-vectors = 30 per concept), NOT the doc_limit
|
||||
# budget (6 // 2 = 3). This is the bug guard.
|
||||
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
|
||||
assert q_limit == 30
|
||||
|
||||
# Narrow: the rerank keeps the final doc_limit (6), not fetch_limit.
|
||||
assert reranker.calls[0]["limit"] == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_fetch_limit_derives_overfetch_from_doc_limit(self):
|
||||
"""
|
||||
With no fetch_limit passed to query(), the candidate pool falls back to
|
||||
the OVERFETCH_FACTOR x doc_limit heuristic, so over-fetch scales with
|
||||
doc_limit and reranking keeps its recall benefit out of the box.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
# No fetch_limit -> heuristic default.
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=20)
|
||||
|
||||
# fetch = 3 x 20 = 60 -> 60 // 2 concept-vectors = 30 per concept.
|
||||
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
|
||||
assert q_limit == 30
|
||||
# Rerank narrows to the final doc_limit (20).
|
||||
assert reranker.calls[0]["limit"] == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_limit_floored_at_doc_limit(self):
|
||||
"""
|
||||
A fetch_limit below doc_limit is floored up to doc_limit: retrieval must
|
||||
never fetch fewer candidates than the rerank is asked to keep, else the
|
||||
prompt could not be filled.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?", doc_limit=10, fetch_limit=4,
|
||||
)
|
||||
|
||||
# fetch = max(4, 10) = 10 -> 10 // 2 concept-vectors = 5 per concept.
|
||||
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
|
||||
assert q_limit == 5
|
||||
assert reranker.calls[0]["limit"] == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_selection_event_emitted(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
# Now 5 stages: question, grounding, exploration, focus, synthesis.
|
||||
assert len(events) == 5
|
||||
ordered_types = [
|
||||
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS,
|
||||
]
|
||||
for i, expected in enumerate(ordered_types):
|
||||
assert has_type(events[i]["triples"], events[i]["explain_id"], expected)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_selection_carries_scores_and_chunk_refs(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
focus_event = events[3]
|
||||
foc_uri = focus_event["explain_id"]
|
||||
triples = focus_event["triples"]
|
||||
|
||||
# focus is derived from exploration
|
||||
exp_uri = events[2]["explain_id"]
|
||||
assert derived_from(triples, foc_uri) == exp_uri
|
||||
|
||||
# Two ChunkSelection sub-entities, linked from focus.
|
||||
sel_links = find_triples(triples, TG_SELECTED_CHUNK, foc_uri)
|
||||
assert len(sel_links) == 2
|
||||
|
||||
# Each selection has a ChunkSelection type, a chunk document ref and a score.
|
||||
chunk_refs = set()
|
||||
scores = set()
|
||||
for link in sel_links:
|
||||
sel_uri = link.o.iri
|
||||
assert has_type(triples, sel_uri, TG_CHUNK_SELECTION)
|
||||
doc_ref = find_triple(triples, TG_DOCUMENT, sel_uri)
|
||||
assert doc_ref is not None
|
||||
chunk_refs.add(doc_ref.o.iri)
|
||||
score_t = find_triple(triples, TG_SCORE, sel_uri)
|
||||
assert score_t is not None
|
||||
scores.add(score_t.o.value)
|
||||
|
||||
# Surviving chunks are C and A (B dropped), with the reranker scores.
|
||||
assert chunk_refs == {CHUNK_C, CHUNK_A}
|
||||
assert scores == {"0.95", "0.42"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_focus_triples_in_retrieval_graph(self):
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
for t in events[3]["triples"]:
|
||||
assert t.g == "urn:graph:retrieval"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_derives_from_focus_when_reranking(self):
|
||||
"""
|
||||
When reranking runs, synthesis must derive from the focus node (the
|
||||
reranked chunks actually fed to the LLM), mirroring GraphRAG - not from
|
||||
exploration, which would leave focus as a dangling branch and
|
||||
misrepresent what fed the answer.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
events = []
|
||||
|
||||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
await rag.query(
|
||||
query="What is the return policy?",
|
||||
doc_limit=2,
|
||||
explain_callback=explain_callback,
|
||||
)
|
||||
|
||||
# events: question, grounding, exploration, focus, synthesis
|
||||
foc_uri = events[3]["explain_id"]
|
||||
syn_event = events[4]
|
||||
assert derived_from(syn_event["triples"], syn_event["explain_id"]) == foc_uri
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_docs_skips_reranker(self):
|
||||
"""If retrieval returns no chunks, the reranker is never called."""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
doc_embeddings_client.query.return_value = [] # no matches
|
||||
|
||||
reranker = self._reranker_keeping_C_then_A()
|
||||
rag = DocumentRag(*clients, reranker_client=reranker)
|
||||
|
||||
await rag.query(query="What is the return policy?")
|
||||
|
||||
assert reranker.calls == []
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Diversity selection: optional MMR after cross-encoder scoring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diversity_mode_scores_full_candidate_pool_before_selecting(self):
|
||||
"""
|
||||
With diversity selection enabled, the cross-encoder should score the full
|
||||
fetched candidate pool before MMR narrows it down to doc_limit.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
reranker = StubReranker([
|
||||
RerankerResult(document_id="0", query_id="0", score=1.00),
|
||||
RerankerResult(document_id="1", query_id="0", score=0.95),
|
||||
RerankerResult(document_id="2", query_id="0", score=0.90),
|
||||
])
|
||||
rag = DocumentRag(
|
||||
*clients,
|
||||
reranker_client=reranker,
|
||||
rerank_diversity_mode="mmr",
|
||||
)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=2)
|
||||
|
||||
assert reranker.calls[0]["limit"] == len(ORDERED_CONTENT)
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
assert len(passed_docs) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diversity_mode_selects_less_redundant_context_set(self):
|
||||
"""
|
||||
MMR should use cross-encoder scores as relevance while penalizing redundant
|
||||
chunks, so a slightly lower-scored but less redundant chunk can be selected.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
|
||||
|
||||
duplicate_a = "apple banana fruit return policy"
|
||||
duplicate_b = "apple banana fruit return policy duplicate"
|
||||
diverse_c = "engine motor vehicle warranty"
|
||||
|
||||
async def mock_fetch(chunk_id):
|
||||
return {
|
||||
CHUNK_A: duplicate_a,
|
||||
CHUNK_B: duplicate_b,
|
||||
CHUNK_C: diverse_c,
|
||||
}[chunk_id]
|
||||
|
||||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
reranker = StubReranker([
|
||||
RerankerResult(document_id="0", query_id="0", score=1.00),
|
||||
RerankerResult(document_id="1", query_id="0", score=0.95),
|
||||
RerankerResult(document_id="2", query_id="0", score=0.90),
|
||||
])
|
||||
rag = DocumentRag(
|
||||
*clients,
|
||||
reranker_client=reranker,
|
||||
rerank_diversity_mode="mmr",
|
||||
rerank_diversity_lambda=0.2,
|
||||
)
|
||||
|
||||
await rag.query(query="What is the return policy?", doc_limit=2)
|
||||
|
||||
call = rag.prompt_client.document_prompt.call_args
|
||||
passed_docs = call.kwargs["documents"]
|
||||
|
||||
assert passed_docs == [duplicate_a, diverse_c]
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
"""
|
||||
Cross-layer wiring contract for the Document-RAG reranker (issue #878).
|
||||
|
||||
The Document-RAG processor registers a ``RerankerClientSpec`` for the
|
||||
``reranker-request`` / ``reranker-response`` roles (see
|
||||
``retrieval/document_rag/rag.py``). At flow construction every spec runs
|
||||
``spec.add(flow, processor, definition)``, and ``RequestResponseSpec.add``
|
||||
resolves its topics via ``definition["topics"][name]`` - which raises
|
||||
``KeyError`` if the flow blueprint does not provide those topics.
|
||||
|
||||
This means the monorepo code change is only safe to deploy together with the
|
||||
companion ``trustgraph-templates`` change that wires ``reranker-request`` /
|
||||
``reranker-response`` into the Document-RAG flow (mirroring what templates
|
||||
PR #279 did for GraphRAG via ``graph-store.jsonnet``). These tests pin that
|
||||
contract from the monorepo side:
|
||||
|
||||
* with the reranker topics present (as the updated templates compile them),
|
||||
the spec binds cleanly and registers the client;
|
||||
* without them (the pre-companion blueprint), construction fails fast with a
|
||||
KeyError naming the missing role - documenting exactly why the templates
|
||||
change is required.
|
||||
|
||||
No broker/network: the pub/sub backend is mocked (topics are bound at add()
|
||||
time, connections happen later at start()).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.base import RerankerClientSpec
|
||||
|
||||
|
||||
def _flow():
|
||||
f = MagicMock()
|
||||
f.workspace = "ws"
|
||||
f.name = "document-rag"
|
||||
f.id = "proc1"
|
||||
f.consumer = {}
|
||||
return f
|
||||
|
||||
|
||||
def _processor():
|
||||
p = MagicMock()
|
||||
p.pubsub = MagicMock()
|
||||
p.id = "proc1"
|
||||
p.taskgroup = MagicMock()
|
||||
return p
|
||||
|
||||
|
||||
def _spec():
|
||||
return RerankerClientSpec(
|
||||
request_name="reranker-request",
|
||||
response_name="reranker-response",
|
||||
)
|
||||
|
||||
|
||||
# Topics dict as the UPDATED document-store.jsonnet compiles them
|
||||
# (verified by compiling the template: reranker-request -> request:tg:reranker:{workspace}:{id}).
|
||||
DEFINITION_WITH_RERANKER = {
|
||||
"topics": {
|
||||
"request": "request:tg:document-rag:ws:id",
|
||||
"response": "response:tg:document-rag:ws:id",
|
||||
"reranker-request": "request:tg:reranker:ws:id",
|
||||
"reranker-response": "response:tg:reranker:ws:id",
|
||||
}
|
||||
}
|
||||
|
||||
# Pre-companion blueprint: no reranker topics (document-rag before the templates change).
|
||||
DEFINITION_WITHOUT_RERANKER = {
|
||||
"topics": {
|
||||
"request": "request:tg:document-rag:ws:id",
|
||||
"response": "response:tg:document-rag:ws:id",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_reranker_client_binds_when_flow_provides_topics():
|
||||
flow = _flow()
|
||||
_spec().add(flow, _processor(), DEFINITION_WITH_RERANKER)
|
||||
# The client consumer is registered against the reranker role.
|
||||
assert "reranker-request" in flow.consumer
|
||||
|
||||
|
||||
def test_reranker_client_keyerrors_without_companion_template_topics():
|
||||
with pytest.raises(KeyError) as exc:
|
||||
_spec().add(_flow(), _processor(), DEFINITION_WITHOUT_RERANKER)
|
||||
# Fails fast naming the missing role -> the trustgraph-templates companion
|
||||
# change (wire reranker-request/response into the document-rag flow) is required.
|
||||
assert "reranker-request" in str(exc.value)
|
||||
|
|
@ -66,6 +66,7 @@ class TestDocumentRagService:
|
|||
workspace=ANY, # Workspace comes from flow.workspace (mock)
|
||||
collection="test_coll_1", # Must be from message, not hardcoded default
|
||||
doc_limit=5,
|
||||
fetch_limit=0, # Unset -> core derives the candidate pool
|
||||
explain_callback=ANY, # Explainability callback is always passed
|
||||
save_answer_callback=ANY, # Librarian save callback is always passed
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,54 +15,52 @@ class TestGraphRag:
|
|||
|
||||
def test_graph_rag_initialization_with_defaults(self):
|
||||
"""Test GraphRag initialization with default verbose setting"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
mock_reranker_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.verbose is False # Default value
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
# Create mock clients
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
|
||||
# Initialize GraphRag with verbose=True
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=True
|
||||
reranker_client=mock_reranker_client,
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.reranker_client == mock_reranker_client
|
||||
assert graph_rag.verbose is False
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
def test_graph_rag_initialization_with_verbose(self):
|
||||
"""Test GraphRag initialization with verbose enabled"""
|
||||
mock_prompt_client = MagicMock()
|
||||
mock_embeddings_client = MagicMock()
|
||||
mock_graph_embeddings_client = MagicMock()
|
||||
mock_triples_client = MagicMock()
|
||||
mock_reranker_client = MagicMock()
|
||||
|
||||
graph_rag = GraphRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
reranker_client=mock_reranker_client,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert graph_rag.prompt_client == mock_prompt_client
|
||||
assert graph_rag.embeddings_client == mock_embeddings_client
|
||||
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
|
||||
assert graph_rag.triples_client == mock_triples_client
|
||||
assert graph_rag.reranker_client == mock_reranker_client
|
||||
assert graph_rag.verbose is True
|
||||
# Verify label_cache is an LRUCacheWithTTL instance
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
|
||||
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
|
||||
|
||||
|
|
@ -365,244 +363,162 @@ class TestQuery:
|
|||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_never_passes_workspace(self):
|
||||
"""Verify follow_edges never passes workspace to query_stream."""
|
||||
async def test_hop_and_filter_never_passes_workspace(self):
|
||||
"""Verify hop_and_filter never passes workspace to query_stream."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
|
||||
mock_triple.s = "e1"
|
||||
mock_triple.p = "p1"
|
||||
mock_triple.o = "o1"
|
||||
mock_triples_client.query_stream.return_value = [mock_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.9
|
||||
mock_reranker_client.rerank.return_value = [result]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
triple_limit=10,
|
||||
)
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("e1", subgraph, path_length=1)
|
||||
await query.hop_and_filter(["e1"], ["concept"])
|
||||
|
||||
for c in mock_triples_client.query_stream.call_args_list:
|
||||
assert "workspace" not in c.kwargs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_basic_functionality(self):
|
||||
"""Test Query.follow_edges method basic triple discovery"""
|
||||
async def test_hop_and_filter_basic_functionality(self):
|
||||
"""Test hop_and_filter retrieves edges and scores them with reranker."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
mock_triple1 = MagicMock()
|
||||
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
|
||||
mock_triple = MagicMock()
|
||||
mock_triple.s = "entity1"
|
||||
mock_triple.p = "predicate1"
|
||||
mock_triple.o = "object1"
|
||||
mock_triples_client.query_stream.return_value = [mock_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
mock_triple2 = MagicMock()
|
||||
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
|
||||
|
||||
mock_triple3 = MagicMock()
|
||||
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
|
||||
|
||||
mock_triples_client.query_stream.side_effect = [
|
||||
[mock_triple1], # s=ent
|
||||
[mock_triple2], # p=ent
|
||||
[mock_triple3], # o=ent
|
||||
]
|
||||
result = MagicMock()
|
||||
result.document_id = "0"
|
||||
result.query_id = "0"
|
||||
result.score = 0.95
|
||||
mock_reranker_client.rerank.return_value = [result]
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
triple_limit=10
|
||||
triple_limit=10,
|
||||
edge_limit=25,
|
||||
)
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
assert mock_triples_client.query_stream.call_count == 3
|
||||
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s="entity1", p=None, o=None, limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p="entity1", o=None, limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
)
|
||||
mock_triples_client.query_stream.assert_any_call(
|
||||
s=None, p=None, o="entity1", limit=10,
|
||||
collection="test_collection", batch_size=20, g=""
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter(
|
||||
["entity1"], ["test concept"],
|
||||
)
|
||||
|
||||
expected_subgraph = {
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "entity1", "object2"),
|
||||
("subject3", "predicate3", "entity1")
|
||||
}
|
||||
assert subgraph == expected_subgraph
|
||||
assert len(selected) == 1
|
||||
assert len(uri_map) == 1
|
||||
assert len(edge_meta) == 1
|
||||
|
||||
mock_reranker_client.rerank.assert_called_once()
|
||||
call_kwargs = mock_reranker_client.rerank.call_args
|
||||
assert call_kwargs.kwargs["limit"] == 25
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_path_length_zero(self):
|
||||
"""Test Query.follow_edges method with path_length=0"""
|
||||
async def test_hop_and_filter_with_empty_frontier(self):
|
||||
"""Test hop_and_filter with no seed entities returns empty."""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"])
|
||||
|
||||
assert selected == []
|
||||
assert uri_map == {}
|
||||
assert edge_meta == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hop_and_filter_filters_label_triples(self):
|
||||
"""Test hop_and_filter skips rdfs:label edges."""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_rag.reranker_client = mock_reranker_client
|
||||
mock_rag.label_cache = MagicMock()
|
||||
mock_rag.label_cache.get.return_value = None
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False
|
||||
)
|
||||
label_triple = MagicMock()
|
||||
label_triple.s = "entity1"
|
||||
label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
label_triple.o = "Entity One"
|
||||
|
||||
subgraph = set()
|
||||
await query.follow_edges("entity1", subgraph, path_length=0)
|
||||
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
assert subgraph == set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_edges_with_max_subgraph_size_limit(self):
|
||||
"""Test Query.follow_edges method respects max_subgraph_size"""
|
||||
mock_rag = MagicMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_rag.triples_client = mock_triples_client
|
||||
mock_triples_client.query_stream.return_value = [label_triple]
|
||||
mock_triples_client.query.return_value = []
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=2
|
||||
triple_limit=10,
|
||||
)
|
||||
|
||||
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
|
||||
|
||||
await query.follow_edges("entity1", subgraph, path_length=1)
|
||||
|
||||
mock_triples_client.query_stream.assert_not_called()
|
||||
assert len(subgraph) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_subgraph_method(self):
|
||||
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_path_length=1
|
||||
selected, uri_map, edge_meta = await query.hop_and_filter(
|
||||
["entity1"], ["concept"],
|
||||
)
|
||||
|
||||
# Mock get_entities to return (entities, concepts) tuple
|
||||
query.get_entities = AsyncMock(
|
||||
return_value=(["entity1", "entity2"], ["concept1"])
|
||||
)
|
||||
|
||||
query.follow_edges_batch = AsyncMock(return_value=(
|
||||
{
|
||||
("entity1", "predicate1", "object1"),
|
||||
("entity2", "predicate2", "object2")
|
||||
},
|
||||
{}
|
||||
))
|
||||
|
||||
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
|
||||
|
||||
query.get_entities.assert_called_once_with("test query")
|
||||
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
|
||||
|
||||
assert isinstance(subgraph, list)
|
||||
assert len(subgraph) == 2
|
||||
assert ("entity1", "predicate1", "object1") in subgraph
|
||||
assert ("entity2", "predicate2", "object2") in subgraph
|
||||
assert entities == ["entity1", "entity2"]
|
||||
assert concepts == ["concept1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_labelgraph_method(self):
|
||||
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
|
||||
mock_rag = MagicMock()
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
collection="test_collection",
|
||||
verbose=False,
|
||||
max_subgraph_size=100
|
||||
)
|
||||
|
||||
test_subgraph = [
|
||||
("entity1", "predicate1", "object1"),
|
||||
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
|
||||
("entity3", "predicate3", "object3")
|
||||
]
|
||||
test_entities = ["entity1", "entity3"]
|
||||
test_concepts = ["concept1"]
|
||||
query.get_subgraph = AsyncMock(
|
||||
return_value=(test_subgraph, {}, test_entities, test_concepts)
|
||||
)
|
||||
|
||||
async def mock_maybe_label(entity):
|
||||
label_map = {
|
||||
"entity1": "Human Entity One",
|
||||
"predicate1": "Human Predicate One",
|
||||
"object1": "Human Object One",
|
||||
"entity3": "Human Entity Three",
|
||||
"predicate3": "Human Predicate Three",
|
||||
"object3": "Human Object Three"
|
||||
}
|
||||
return label_map.get(entity, entity)
|
||||
|
||||
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
|
||||
|
||||
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
|
||||
|
||||
query.get_subgraph.assert_called_once_with("test query")
|
||||
|
||||
# Label triples filtered out
|
||||
assert len(labeled_edges) == 2
|
||||
|
||||
# maybe_label called for non-label triples
|
||||
assert query.maybe_label.call_count == 6
|
||||
|
||||
expected_edges = [
|
||||
("Human Entity One", "Human Predicate One", "Human Object One"),
|
||||
("Human Entity Three", "Human Predicate Three", "Human Object Three")
|
||||
]
|
||||
assert labeled_edges == expected_edges
|
||||
|
||||
assert len(uri_map) == 2
|
||||
assert entities == test_entities
|
||||
assert concepts == test_concepts
|
||||
assert selected == []
|
||||
mock_reranker_client.rerank.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_query_method(self):
|
||||
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
|
||||
import json
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_embeddings_client = AsyncMock()
|
||||
mock_graph_embeddings_client = AsyncMock()
|
||||
mock_triples_client = AsyncMock()
|
||||
mock_reranker_client = AsyncMock()
|
||||
|
||||
expected_response = "This is the RAG response"
|
||||
test_labelgraph = [("Subject", "Predicate", "Object")]
|
||||
test_edge_id = edge_id("Subject", "Predicate", "Object")
|
||||
test_selected_edges = [("Subject", "Predicate", "Object")]
|
||||
test_eid = edge_id("Subject", "Predicate", "Object")
|
||||
test_uri_map = {
|
||||
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
|
||||
}
|
||||
test_edge_metadata = {
|
||||
test_eid: {"concept": "test concept", "score": 0.95}
|
||||
}
|
||||
test_entities = ["http://example.org/subject"]
|
||||
test_concepts = ["test concept"]
|
||||
|
||||
# Mock prompt responses for the multi-step process
|
||||
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
mock_graph_embeddings_client.query.return_value = []
|
||||
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
|
||||
return PromptResult(response_type="text", text="test concept")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text=expected_response)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
|
@ -614,16 +530,16 @@ class TestQuery:
|
|||
embeddings_client=mock_embeddings_client,
|
||||
graph_embeddings_client=mock_graph_embeddings_client,
|
||||
triples_client=mock_triples_client,
|
||||
verbose=False
|
||||
reranker_client=mock_reranker_client,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Patch Query.get_labelgraph to return test data
|
||||
original_get_labelgraph = Query.get_labelgraph
|
||||
original_hop_and_filter = Query.hop_and_filter
|
||||
|
||||
async def mock_get_labelgraph(self, query_text):
|
||||
return test_labelgraph, test_uri_map, test_entities, test_concepts
|
||||
async def mock_hop_and_filter(self, seed_entities, concepts):
|
||||
return test_selected_edges, test_uri_map, test_edge_metadata
|
||||
|
||||
Query.get_labelgraph = mock_get_labelgraph
|
||||
Query.hop_and_filter = mock_hop_and_filter
|
||||
|
||||
provenance_events = []
|
||||
|
||||
|
|
@ -636,7 +552,7 @@ class TestQuery:
|
|||
collection="test_collection",
|
||||
entity_limit=25,
|
||||
triple_limit=15,
|
||||
explain_callback=collect_provenance
|
||||
explain_callback=collect_provenance,
|
||||
)
|
||||
|
||||
response_text, usage = response
|
||||
|
|
@ -650,7 +566,6 @@ class TestQuery:
|
|||
assert len(triples) > 0
|
||||
assert prov_id.startswith("urn:trustgraph:")
|
||||
|
||||
# Verify order
|
||||
assert "question" in provenance_events[0][1]
|
||||
assert "grounding" in provenance_events[1][1]
|
||||
assert "exploration" in provenance_events[2][1]
|
||||
|
|
@ -658,4 +573,4 @@ class TestQuery:
|
|||
assert "synthesis" in provenance_events[4][1]
|
||||
|
||||
finally:
|
||||
Query.get_labelgraph = original_get_labelgraph
|
||||
Query.hop_and_filter = original_hop_and_filter
|
||||
|
|
|
|||
353
tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py
Normal file
353
tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
"""
|
||||
Tests for direction-aware reranker text in GraphRAG hop-and-filter.
|
||||
|
||||
The reranker document text varies by traversal direction:
|
||||
- From S (subject is the frontier entity): text = "{p} {o}"
|
||||
- From O (object is the frontier entity): text = "{s} {p}"
|
||||
- From P (predicate is the frontier entity): text = "{s} {o}"
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
|
||||
from trustgraph.schema import Term, IRI, LITERAL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_rag(reranker_results=None):
|
||||
"""Create a mock GraphRag with all clients stubbed."""
|
||||
rag = MagicMock()
|
||||
rag.label_cache = LRUCacheWithTTL()
|
||||
rag.triples_client = AsyncMock()
|
||||
rag.reranker_client = AsyncMock()
|
||||
|
||||
# Label lookups return empty (fall back to URI)
|
||||
rag.triples_client.query.return_value = []
|
||||
|
||||
if reranker_results is not None:
|
||||
rag.reranker_client.rerank.return_value = reranker_results
|
||||
else:
|
||||
rag.reranker_client.rerank.return_value = []
|
||||
|
||||
return rag
|
||||
|
||||
|
||||
def _make_query(rag, max_path_length=1, edge_limit=25):
|
||||
return Query(
|
||||
rag=rag,
|
||||
collection="test",
|
||||
verbose=False,
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
max_subgraph_size=1000,
|
||||
max_path_length=max_path_length,
|
||||
edge_limit=edge_limit,
|
||||
)
|
||||
|
||||
|
||||
def _make_schema_triple(s, p, o):
|
||||
"""Create a mock triple matching the schema interface."""
|
||||
t = MagicMock()
|
||||
t.s = s
|
||||
t.p = p
|
||||
t.o = o
|
||||
return t
|
||||
|
||||
|
||||
def _reranker_result(document_id, query_id="0", score=0.9):
|
||||
r = MagicMock()
|
||||
r.document_id = str(document_id)
|
||||
r.query_id = str(query_id)
|
||||
r.score = score
|
||||
return r
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: execute_batch_triple_queries direction tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDirectionTracking:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_s_direction(self):
|
||||
"""Triples from s=entity queries are tagged FROM_S."""
|
||||
triple = _make_schema_triple("ent1", "pred", "obj")
|
||||
rag = _make_rag()
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if s is not None:
|
||||
return [triple]
|
||||
return []
|
||||
|
||||
rag.triples_client.query_stream.side_effect = query_stream
|
||||
q = _make_query(rag)
|
||||
|
||||
result = await q.execute_batch_triple_queries(["ent1"], 10)
|
||||
|
||||
from_s = [(t, d) for t, d in result if d == Query.FROM_S]
|
||||
assert len(from_s) == 1
|
||||
assert from_s[0][0] is triple
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_o_direction(self):
|
||||
"""Triples from o=entity queries are tagged FROM_O."""
|
||||
triple = _make_schema_triple("subj", "pred", "ent1")
|
||||
rag = _make_rag()
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if o is not None:
|
||||
return [triple]
|
||||
return []
|
||||
|
||||
rag.triples_client.query_stream.side_effect = query_stream
|
||||
q = _make_query(rag)
|
||||
|
||||
result = await q.execute_batch_triple_queries(["ent1"], 10)
|
||||
|
||||
from_o = [(t, d) for t, d in result if d == Query.FROM_O]
|
||||
assert len(from_o) == 1
|
||||
assert from_o[0][0] is triple
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_p_direction(self):
|
||||
"""Triples from p=entity queries are tagged FROM_P."""
|
||||
triple = _make_schema_triple("subj", "ent1", "obj")
|
||||
rag = _make_rag()
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if p is not None:
|
||||
return [triple]
|
||||
return []
|
||||
|
||||
rag.triples_client.query_stream.side_effect = query_stream
|
||||
q = _make_query(rag)
|
||||
|
||||
result = await q.execute_batch_triple_queries(["ent1"], 10)
|
||||
|
||||
from_p = [(t, d) for t, d in result if d == Query.FROM_P]
|
||||
assert len(from_p) == 1
|
||||
assert from_p[0][0] is triple
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: hop_and_filter reranker document text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDirectionAwareRerankerText:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_s_uses_predicate_object(self):
|
||||
"""From-S traversal: reranker text should be '{p} {o}'."""
|
||||
triple = _make_schema_triple(
|
||||
"http://ex/entity-A",
|
||||
"http://ex/likes",
|
||||
"http://ex/entity-B",
|
||||
)
|
||||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if s is not None:
|
||||
return [triple]
|
||||
return []
|
||||
|
||||
rag.triples_client.query_stream.side_effect = query_stream
|
||||
|
||||
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||
|
||||
await q.hop_and_filter(
|
||||
seed_entities=["http://ex/entity-A"],
|
||||
concepts=["likes"],
|
||||
)
|
||||
|
||||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
# Text should be "{p} {o}" — the URIs since no labels found
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["text"] == "http://ex/likes http://ex/entity-B"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_o_uses_subject_predicate(self):
|
||||
"""From-O traversal: reranker text should be '{s} {p}'."""
|
||||
triple = _make_schema_triple(
|
||||
"http://ex/entity-A",
|
||||
"http://ex/likes",
|
||||
"http://ex/entity-B",
|
||||
)
|
||||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if o is not None:
|
||||
return [triple]
|
||||
return []
|
||||
|
||||
rag.triples_client.query_stream.side_effect = query_stream
|
||||
|
||||
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||
|
||||
await q.hop_and_filter(
|
||||
seed_entities=["http://ex/entity-B"],
|
||||
concepts=["likes"],
|
||||
)
|
||||
|
||||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["text"] == "http://ex/entity-A http://ex/likes"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_from_p_uses_subject_object(self):
|
||||
"""From-P traversal: reranker text should be '{s} {o}'."""
|
||||
triple = _make_schema_triple(
|
||||
"http://ex/entity-A",
|
||||
"http://ex/likes",
|
||||
"http://ex/entity-B",
|
||||
)
|
||||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if p is not None:
|
||||
return [triple]
|
||||
return []
|
||||
|
||||
rag.triples_client.query_stream.side_effect = query_stream
|
||||
|
||||
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||
|
||||
await q.hop_and_filter(
|
||||
seed_entities=["http://ex/likes"],
|
||||
concepts=["entity"],
|
||||
)
|
||||
|
||||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
assert len(documents) == 1
|
||||
assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_directions_produce_different_text(self):
|
||||
"""Edges from different directions use different text formats."""
|
||||
triple_from_s = _make_schema_triple(
|
||||
"http://ex/seed", "http://ex/rel", "http://ex/target",
|
||||
)
|
||||
triple_from_o = _make_schema_triple(
|
||||
"http://ex/other", "http://ex/ref", "http://ex/seed",
|
||||
)
|
||||
|
||||
rag = _make_rag(reranker_results=[
|
||||
_reranker_result(0), _reranker_result(1),
|
||||
])
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if s == "http://ex/seed":
|
||||
return [triple_from_s]
|
||||
if o == "http://ex/seed":
|
||||
return [triple_from_o]
|
||||
return []
|
||||
|
||||
rag.triples_client.query_stream.side_effect = query_stream
|
||||
|
||||
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||
|
||||
await q.hop_and_filter(
|
||||
seed_entities=["http://ex/seed"],
|
||||
concepts=["test"],
|
||||
)
|
||||
|
||||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
texts = {d["text"] for d in documents}
|
||||
|
||||
# From S: "{p} {o}" = "http://ex/rel http://ex/target"
|
||||
assert "http://ex/rel http://ex/target" in texts
|
||||
# From O: "{s} {p}" = "http://ex/other http://ex/ref"
|
||||
assert "http://ex/other http://ex/ref" in texts
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_labels_applied_to_direction_text(self):
|
||||
"""Labels should be resolved and used in the direction-aware text."""
|
||||
triple = _make_schema_triple(
|
||||
"http://ex/entity-A",
|
||||
"http://ex/likes",
|
||||
"http://ex/entity-B",
|
||||
)
|
||||
reranker_result = _reranker_result(0)
|
||||
rag = _make_rag(reranker_results=[reranker_result])
|
||||
|
||||
LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if s is not None and p is None:
|
||||
return [triple]
|
||||
return []
|
||||
|
||||
async def label_query(s=None, p=None, o=None, limit=1, **kwargs):
|
||||
if p == LABEL:
|
||||
labels = {
|
||||
"http://ex/entity-A": "Alice",
|
||||
"http://ex/likes": "likes",
|
||||
"http://ex/entity-B": "Bob",
|
||||
}
|
||||
if s in labels:
|
||||
return [MagicMock(o=labels[s])]
|
||||
return []
|
||||
|
||||
rag.triples_client.query_stream.side_effect = query_stream
|
||||
rag.triples_client.query.side_effect = label_query
|
||||
|
||||
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||
|
||||
await q.hop_and_filter(
|
||||
seed_entities=["http://ex/entity-A"],
|
||||
concepts=["friendship"],
|
||||
)
|
||||
|
||||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
assert len(documents) == 1
|
||||
# From S with labels: "{p_label} {o_label}"
|
||||
assert documents[0]["text"] == "likes Bob"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_duplicate_text_from_shared_object(self):
|
||||
"""Multiple edges sharing an object should produce distinct texts."""
|
||||
triple_a = _make_schema_triple(
|
||||
"http://ex/cpu-A", "http://ex/hasCategory", "http://ex/Processors",
|
||||
)
|
||||
triple_b = _make_schema_triple(
|
||||
"http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors",
|
||||
)
|
||||
|
||||
rag = _make_rag(reranker_results=[
|
||||
_reranker_result(0), _reranker_result(1),
|
||||
])
|
||||
|
||||
async def query_stream(s=None, p=None, o=None, **kwargs):
|
||||
if o == "http://ex/Processors":
|
||||
return [triple_a, triple_b]
|
||||
return []
|
||||
|
||||
rag.triples_client.query_stream.side_effect = query_stream
|
||||
|
||||
q = _make_query(rag, max_path_length=1, edge_limit=10)
|
||||
|
||||
await q.hop_and_filter(
|
||||
seed_entities=["http://ex/Processors"],
|
||||
concepts=["CPUs"],
|
||||
)
|
||||
|
||||
call_args = rag.reranker_client.rerank.call_args
|
||||
documents = call_args.kwargs["documents"]
|
||||
texts = [d["text"] for d in documents]
|
||||
|
||||
assert len(texts) == 2
|
||||
# From O: "{s} {p}" — subjects differ, so texts differ
|
||||
assert texts[0] != texts[1]
|
||||
assert "http://ex/cpu-A" in texts[0]
|
||||
assert "http://ex/cpu-B" in texts[1]
|
||||
|
|
@ -20,7 +20,7 @@ from trustgraph.provenance.namespaces import (
|
|||
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
|
||||
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_SELECTED_EDGE, TG_EDGE, TG_SCORE, TG_EDGE_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -91,17 +91,17 @@ def build_mock_clients():
|
|||
1. prompt_client.prompt("extract-concepts", ...) -> concepts
|
||||
2. embeddings_client.embed(concepts) -> vectors
|
||||
3. graph_embeddings_client.query(vector, ...) -> entity matches
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
|
||||
4. triples_client.query_stream(s/p/o, ...) -> edges (hop_and_filter)
|
||||
5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
|
||||
6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
|
||||
7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
|
||||
8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
9. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
6. reranker_client.rerank(queries, documents, limit) -> scored edges
|
||||
7. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
|
||||
8. prompt_client.prompt("kg-synthesis", ...) -> final answer
|
||||
"""
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
reranker_client = AsyncMock()
|
||||
|
||||
# 1. Concept extraction
|
||||
prompt_responses = {}
|
||||
|
|
@ -116,7 +116,7 @@ def build_mock_clients():
|
|||
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
|
||||
]
|
||||
|
||||
# 4. Triple queries (follow_edges_batch) - return our edges
|
||||
# 4. Triple queries (hop_and_filter) - return our edges
|
||||
kg_triples = [
|
||||
make_schema_triple(*EDGE_1),
|
||||
make_schema_triple(*EDGE_2),
|
||||
|
|
@ -130,9 +130,18 @@ def build_mock_clients():
|
|||
return [] # No labels found, will fall back to URI
|
||||
triples_client.query.side_effect = mock_label_query
|
||||
|
||||
# 6+7. Edge scoring and reasoning: dynamically score/reason about
|
||||
# whatever edges the query method sends us, since edge IDs are computed
|
||||
# from str(Term) representations which include the full dataclass repr.
|
||||
# 6. Reranker: select all documents with high scores
|
||||
async def mock_rerank(queries, documents, limit):
|
||||
results = []
|
||||
for i, doc in enumerate(documents):
|
||||
result = MagicMock()
|
||||
result.document_id = doc["id"]
|
||||
result.query_id = queries[0]["id"] if queries else "0"
|
||||
result.score = 0.9 - (i * 0.1)
|
||||
results.append(result)
|
||||
return results[:limit]
|
||||
reranker_client.rerank.side_effect = mock_rerank
|
||||
|
||||
synthesis_answer = "Quantum computing applies physics principles to computation."
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
|
|
@ -141,26 +150,6 @@ def build_mock_clients():
|
|||
response_type="text",
|
||||
text=prompt_responses["extract-concepts"],
|
||||
)
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
|
|
@ -170,7 +159,8 @@ def build_mock_clients():
|
|||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
return (prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, reranker_client)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -197,7 +187,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0, # skip semantic pre-filter for simplicity
|
||||
|
||||
)
|
||||
|
||||
assert len(events) == 5, (
|
||||
|
|
@ -222,7 +212,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
expected_types = [
|
||||
|
|
@ -260,7 +250,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
uris = [e["explain_id"] for e in events]
|
||||
|
|
@ -297,7 +287,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
q_uri = events[0]["explain_id"]
|
||||
|
|
@ -320,7 +310,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
gnd_uri = events[1]["explain_id"]
|
||||
|
|
@ -344,7 +334,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
exp_uri = events[2]["explain_id"]
|
||||
|
|
@ -355,10 +345,10 @@ class TestGraphRagQueryProvenance:
|
|||
assert int(t.o.value) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_focus_has_selected_edges_with_reasoning(self):
|
||||
async def test_focus_has_selected_edges_with_concept_and_score(self):
|
||||
"""
|
||||
The focus event should carry selected edges as quoted triples
|
||||
with reasoning text.
|
||||
with cross-encoder concept and score metadata.
|
||||
"""
|
||||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
|
@ -371,7 +361,6 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
foc_uri = events[3]["explain_id"]
|
||||
|
|
@ -387,11 +376,19 @@ class TestGraphRagQueryProvenance:
|
|||
for t in edge_t:
|
||||
assert t.o.triple is not None, "tg:edge object must be a quoted triple"
|
||||
|
||||
# Should have reasoning
|
||||
reasoning = find_triples(foc_triples, TG_REASONING)
|
||||
assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
|
||||
reasoning_texts = {t.o.value for t in reasoning}
|
||||
assert any(r for r in reasoning_texts), "Reasoning should not be empty"
|
||||
# Edge selections should be typed as EdgeSelection
|
||||
edge_sel_uris = [t.o.iri for t in selected]
|
||||
for uri in edge_sel_uris:
|
||||
assert has_type(foc_triples, uri, TG_EDGE_SELECTION)
|
||||
|
||||
# Should have concept and score
|
||||
concepts = find_triples(foc_triples, TG_CONCEPT)
|
||||
assert len(concepts) > 0, "Focus should have tg:concept for selected edges"
|
||||
|
||||
scores = find_triples(foc_triples, TG_SCORE)
|
||||
assert len(scores) > 0, "Focus should have tg:score for selected edges"
|
||||
for t in scores:
|
||||
float(t.o.value) # Should be parseable as float
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_synthesis_is_answer_type(self):
|
||||
|
|
@ -407,7 +404,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
syn_uri = events[4]["explain_id"]
|
||||
|
|
@ -429,7 +426,7 @@ class TestGraphRagQueryProvenance:
|
|||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
|
@ -449,7 +446,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
parent_uri=parent,
|
||||
)
|
||||
|
||||
|
|
@ -465,7 +462,7 @@ class TestGraphRagQueryProvenance:
|
|||
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
|
@ -484,7 +481,7 @@ class TestGraphRagQueryProvenance:
|
|||
await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
|
||||
)
|
||||
|
||||
for event in events:
|
||||
|
|
|
|||
|
|
@ -527,7 +527,8 @@ class AsyncFlowInstance:
|
|||
return result.get("response", "")
|
||||
|
||||
async def document_rag(self, query: str, collection: str,
|
||||
doc_limit: int = 10, **kwargs: Any) -> str:
|
||||
doc_limit: int = 10, fetch_limit: int = 0,
|
||||
**kwargs: Any) -> str:
|
||||
"""
|
||||
Execute document-based RAG query (non-streaming).
|
||||
|
||||
|
|
@ -541,7 +542,9 @@ class AsyncFlowInstance:
|
|||
Args:
|
||||
query: User query text
|
||||
collection: Collection identifier containing documents
|
||||
doc_limit: Maximum number of document chunks to retrieve (default: 10)
|
||||
doc_limit: Document chunks selected into the prompt (default: 10)
|
||||
fetch_limit: Candidate chunks fetched from the vector store before
|
||||
reranking (default: 0 = derive from doc_limit)
|
||||
**kwargs: Additional service-specific parameters
|
||||
|
||||
Returns:
|
||||
|
|
@ -564,6 +567,7 @@ class AsyncFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": False
|
||||
}
|
||||
request_data.update(kwargs)
|
||||
|
|
@ -646,6 +650,16 @@ class AsyncFlowInstance:
|
|||
|
||||
return await self.request("embeddings", request_data)
|
||||
|
||||
async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any):
|
||||
request_data = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request_data.update(kwargs)
|
||||
|
||||
return await self.request("reranker", request_data)
|
||||
|
||||
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any):
|
||||
"""
|
||||
Query RDF triples using pattern matching.
|
||||
|
|
|
|||
|
|
@ -94,7 +94,9 @@ class AsyncSocketClient:
|
|||
|
||||
if resp.get("type") == "auth-ok":
|
||||
if not self._workspace_explicit:
|
||||
self.workspace = resp.get("workspace", self.workspace)
|
||||
self.workspace = resp.get(
|
||||
"default_workspace", self.workspace,
|
||||
)
|
||||
elif resp.get("type") == "auth-failed":
|
||||
await self._socket.close()
|
||||
raise ProtocolException(
|
||||
|
|
@ -377,12 +379,14 @@ class AsyncSocketFlowInstance:
|
|||
yield chunk.content
|
||||
|
||||
async def document_rag(self, query: str, collection: str,
|
||||
doc_limit: int = 10, streaming: bool = False, **kwargs):
|
||||
doc_limit: int = 10, fetch_limit: int = 0,
|
||||
streaming: bool = False, **kwargs):
|
||||
"""Document RAG with optional streaming"""
|
||||
request = {
|
||||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
@ -441,6 +445,19 @@ class AsyncSocketFlowInstance:
|
|||
|
||||
return await self.client._send_request("embeddings", self.flow_id, request)
|
||||
|
||||
async def rerank(self, queries: list, documents: list, limit: int = 10,
|
||||
**kwargs):
|
||||
request = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return await self.client._send_request(
|
||||
"reranker", self.flow_id, request,
|
||||
)
|
||||
|
||||
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs):
|
||||
"""Triple pattern query"""
|
||||
request = {"limit": limit}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount"
|
|||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_SCORE = TG + "score"
|
||||
TG_DOCUMENT = TG + "document"
|
||||
TG_CONCEPT = TG + "concept"
|
||||
TG_ENTITY = TG + "entity"
|
||||
|
|
@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
|
|||
|
||||
@dataclass
|
||||
class EdgeSelection:
|
||||
"""A selected edge with reasoning from GraphRAG Focus step."""
|
||||
"""A selected edge with cross-encoder metadata from GraphRAG Focus step."""
|
||||
uri: str
|
||||
edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...}
|
||||
reasoning: str = ""
|
||||
concept: str = ""
|
||||
score: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -209,7 +212,7 @@ class Exploration(ExplainEntity):
|
|||
|
||||
@dataclass
|
||||
class Focus(ExplainEntity):
|
||||
"""Focus entity - selected edges with LLM reasoning (GraphRAG only)."""
|
||||
"""Focus entity - selected edges with cross-encoder scoring (GraphRAG only)."""
|
||||
selected_edge_uris: List[str] = field(default_factory=list)
|
||||
edge_selections: List[EdgeSelection] = field(default_factory=list)
|
||||
|
||||
|
|
@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel
|
|||
uri = triples[0][0] if triples else ""
|
||||
edge = None
|
||||
reasoning = ""
|
||||
concept = ""
|
||||
score = None
|
||||
|
||||
for s, p, o in triples:
|
||||
if p == TG_EDGE and isinstance(o, dict):
|
||||
edge = o
|
||||
elif p == TG_REASONING:
|
||||
reasoning = o
|
||||
elif p == TG_CONCEPT:
|
||||
concept = o
|
||||
elif p == TG_SCORE:
|
||||
try:
|
||||
score = float(o)
|
||||
except (ValueError, TypeError):
|
||||
score = None
|
||||
|
||||
return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning)
|
||||
return EdgeSelection(
|
||||
uri=uri, edge=edge, reasoning=reasoning,
|
||||
concept=concept, score=score,
|
||||
)
|
||||
|
||||
|
||||
def extract_term_value(term: Dict[str, Any]) -> Any:
|
||||
|
|
|
|||
|
|
@ -415,7 +415,7 @@ class FlowInstance:
|
|||
|
||||
def document_rag(
|
||||
self, query,collection="default",
|
||||
doc_limit=10,
|
||||
doc_limit=10, fetch_limit=0,
|
||||
):
|
||||
"""
|
||||
Execute document-based Retrieval-Augmented Generation (RAG) query.
|
||||
|
|
@ -426,7 +426,9 @@ class FlowInstance:
|
|||
Args:
|
||||
query: Natural language query
|
||||
collection: Collection identifier (default: "default")
|
||||
doc_limit: Maximum document chunks to retrieve (default: 10)
|
||||
doc_limit: Document chunks selected into the prompt (default: 10)
|
||||
fetch_limit: Candidate chunks fetched from the vector store before
|
||||
reranking (default: 0 = derive from doc_limit)
|
||||
|
||||
Returns:
|
||||
str: Generated response incorporating document context
|
||||
|
|
@ -447,6 +449,7 @@ class FlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
}
|
||||
|
||||
result = self.request(
|
||||
|
|
@ -491,6 +494,19 @@ class FlowInstance:
|
|||
input
|
||||
)["vectors"]
|
||||
|
||||
def rerank(self, queries, documents, limit=10):
|
||||
|
||||
input = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
return self.request(
|
||||
"service/reranker",
|
||||
input
|
||||
)
|
||||
|
||||
def graph_embeddings_query(self, text, collection, limit=10):
|
||||
"""
|
||||
Query knowledge graph entities using semantic similarity.
|
||||
|
|
|
|||
|
|
@ -168,7 +168,9 @@ class SocketClient:
|
|||
|
||||
if resp.get("type") == "auth-ok":
|
||||
if self.workspace == "default":
|
||||
self.workspace = resp.get("workspace", self.workspace)
|
||||
self.workspace = resp.get(
|
||||
"default_workspace", self.workspace,
|
||||
)
|
||||
elif resp.get("type") == "auth-failed":
|
||||
await self._socket.close()
|
||||
raise ProtocolException(
|
||||
|
|
@ -750,6 +752,7 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
collection: str,
|
||||
doc_limit: int = 10,
|
||||
fetch_limit: int = 0,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
|
|
@ -762,6 +765,7 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": streaming
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
|
@ -783,6 +787,7 @@ class SocketFlowInstance:
|
|||
query: str,
|
||||
collection: str,
|
||||
doc_limit: int = 10,
|
||||
fetch_limit: int = 0,
|
||||
**kwargs: Any
|
||||
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
|
||||
"""Execute document-based RAG query with explainability support."""
|
||||
|
|
@ -790,6 +795,7 @@ class SocketFlowInstance:
|
|||
"query": query,
|
||||
"collection": collection,
|
||||
"doc-limit": doc_limit,
|
||||
"fetch-limit": fetch_limit,
|
||||
"streaming": True,
|
||||
"explainable": True,
|
||||
}
|
||||
|
|
@ -883,6 +889,19 @@ class SocketFlowInstance:
|
|||
|
||||
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
|
||||
|
||||
def rerank(self, queries: list, documents: list, limit: int = 10,
|
||||
**kwargs: Any) -> Dict[str, Any]:
|
||||
request = {
|
||||
"queries": queries,
|
||||
"documents": documents,
|
||||
"limit": limit,
|
||||
}
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync(
|
||||
"reranker", self.flow_id, request, False,
|
||||
)
|
||||
|
||||
def triples_query(
|
||||
self,
|
||||
s: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService
|
|||
from . tool_service_client import ToolServiceClientSpec
|
||||
from . agent_client import AgentClientSpec
|
||||
from . structured_query_client import StructuredQueryClientSpec
|
||||
from . reranker_client import RerankerClientSpec
|
||||
from . reranker_service import RerankerService
|
||||
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
|
||||
from . collection_config_handler import CollectionConfigHandler
|
||||
|
||||
|
|
|
|||
|
|
@ -65,31 +65,25 @@ class IamClient(RequestResponse):
|
|||
async def authenticate_anonymous(self, timeout=IAM_TIMEOUT):
|
||||
"""Request anonymous access from the IAM regime.
|
||||
|
||||
Returns ``(user_id, workspace, roles)`` if the regime permits
|
||||
anonymous access, or raises ``RuntimeError`` with error type
|
||||
``auth-failed`` if it does not."""
|
||||
Returns ``(user_id, default_workspace, roles)`` if the regime
|
||||
permits anonymous access, or raises ``RuntimeError`` with
|
||||
error type ``auth-failed`` if it does not."""
|
||||
resp = await self._request(
|
||||
operation="authenticate-anonymous",
|
||||
timeout=timeout,
|
||||
)
|
||||
return (
|
||||
resp.resolved_user_id,
|
||||
resp.resolved_workspace,
|
||||
resp.resolved_default_workspace,
|
||||
list(resp.resolved_roles),
|
||||
)
|
||||
|
||||
async def resolve_api_key(self, api_key, timeout=IAM_TIMEOUT):
|
||||
"""Resolve a plaintext API key to its identity triple.
|
||||
|
||||
Returns ``(user_id, workspace, roles)`` or raises
|
||||
Returns ``(user_id, default_workspace, roles)`` or raises
|
||||
``RuntimeError`` with error type ``auth-failed`` if the key is
|
||||
unknown / expired / revoked.
|
||||
|
||||
Note: the ``roles`` value is a regime-internal hint and is
|
||||
not used by the gateway directly under the IAM contract;
|
||||
all authorisation decisions go through ``authorise()``.
|
||||
Returned here only for backward compatibility with callers
|
||||
that haven't migrated."""
|
||||
unknown / expired / revoked."""
|
||||
resp = await self._request(
|
||||
operation="resolve-api-key",
|
||||
api_key=api_key,
|
||||
|
|
@ -97,7 +91,7 @@ class IamClient(RequestResponse):
|
|||
)
|
||||
return (
|
||||
resp.resolved_user_id,
|
||||
resp.resolved_workspace,
|
||||
resp.resolved_default_workspace,
|
||||
list(resp.resolved_roles),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -157,21 +157,6 @@ class PromptClient(RequestResponse):
|
|||
timeout = timeout,
|
||||
)
|
||||
|
||||
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "kg-prompt",
|
||||
variables = {
|
||||
"query": query,
|
||||
"knowledge": [
|
||||
{ "s": v[0], "p": v[1], "o": v[2] }
|
||||
for v in kg
|
||||
]
|
||||
},
|
||||
timeout = timeout,
|
||||
streaming = streaming,
|
||||
chunk_callback = chunk_callback,
|
||||
)
|
||||
|
||||
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
|
||||
return await self.prompt(
|
||||
id = "document-prompt",
|
||||
|
|
|
|||
43
trustgraph-base/trustgraph/base/reranker_client.py
Normal file
43
trustgraph-base/trustgraph/base/reranker_client.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import (
|
||||
RerankerRequest, RerankerResponse,
|
||||
RerankerQuery, RerankerDocument,
|
||||
)
|
||||
|
||||
class RerankerClient(RequestResponse):
|
||||
async def rerank(self, queries, documents, limit=10, timeout=300):
|
||||
|
||||
resp = await self.request(
|
||||
RerankerRequest(
|
||||
queries=[
|
||||
RerankerQuery(query_id=q["id"], query_text=q["text"])
|
||||
for q in queries
|
||||
],
|
||||
documents=[
|
||||
RerankerDocument(
|
||||
document_id=d["id"], document_text=d["text"]
|
||||
)
|
||||
for d in documents
|
||||
],
|
||||
limit=limit,
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return resp.results
|
||||
|
||||
class RerankerClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
self, request_name, response_name,
|
||||
):
|
||||
super(RerankerClientSpec, self).__init__(
|
||||
request_name = request_name,
|
||||
request_schema = RerankerRequest,
|
||||
response_name = response_name,
|
||||
response_schema = RerankerResponse,
|
||||
impl = RerankerClient,
|
||||
)
|
||||
109
trustgraph-base/trustgraph/base/reranker_service.py
Normal file
109
trustgraph-base/trustgraph/base/reranker_service.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import logging
|
||||
|
||||
from .. schema import (
|
||||
RerankerRequest, RerankerResponse, RerankerResult, Error,
|
||||
)
|
||||
from .. exceptions import TooManyRequests
|
||||
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "reranker"
|
||||
default_concurrency = 1
|
||||
|
||||
class RerankerService(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
|
||||
super(RerankerService, self).__init__(**params | {
|
||||
"id": id,
|
||||
"concurrency": concurrency,
|
||||
})
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = RerankerRequest,
|
||||
handler = self.on_request,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
schema = RerankerResponse
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ParameterSpec(
|
||||
name = "model",
|
||||
)
|
||||
)
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
try:
|
||||
|
||||
request = msg.value()
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
logger.debug(f"Handling reranker request {id}...")
|
||||
|
||||
model = flow("model")
|
||||
results = await self.on_rerank(
|
||||
request.queries, request.documents,
|
||||
request.limit, model=model,
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
RerankerResponse(
|
||||
error = None,
|
||||
results = results,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
logger.debug("Reranker request handled successfully")
|
||||
|
||||
except TooManyRequests as e:
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"Exception in reranker service: {e}", exc_info=True)
|
||||
|
||||
logger.info("Sending error response...")
|
||||
|
||||
await flow.producer["response"].send(
|
||||
RerankerResponse(
|
||||
error=Error(
|
||||
type = "reranker-error",
|
||||
message = str(e),
|
||||
),
|
||||
results=[],
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser: ArgumentParser) -> None:
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
|
@ -140,20 +140,6 @@ class PromptClient(BaseClient):
|
|||
timeout=timeout
|
||||
)
|
||||
|
||||
def request_kg_prompt(self, query, kg, timeout=300):
|
||||
|
||||
return self.request(
|
||||
id="kg-prompt",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": [
|
||||
{ "s": v[0], "p": v[1], "o": v[2] }
|
||||
for v in kg
|
||||
]
|
||||
},
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
def request_document_prompt(self, query, documents, timeout=300):
|
||||
|
||||
return self.request(
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons
|
|||
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
|
||||
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
|
||||
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator
|
||||
from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator
|
||||
from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator
|
||||
|
||||
|
|
@ -163,6 +164,12 @@ TranslatorRegistry.register_service(
|
|||
SparqlQueryResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"reranker",
|
||||
RerankerRequestTranslator(),
|
||||
RerankerResponseTranslator()
|
||||
)
|
||||
|
||||
# Register single-direction translators for document loading
|
||||
TranslatorRegistry.register_request("document", DocumentTranslator())
|
||||
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())
|
||||
|
|
|
|||
|
|
@ -20,3 +20,4 @@ from .embeddings_query import (
|
|||
)
|
||||
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
|
||||
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
|
||||
from .reranker import RerankerRequestTranslator, RerankerResponseTranslator
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from ...schema import (
|
|||
UserInput, UserRecord,
|
||||
WorkspaceInput, WorkspaceRecord,
|
||||
ApiKeyInput, ApiKeyRecord,
|
||||
GroupInput, GrantInput,
|
||||
)
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
|
@ -43,12 +44,31 @@ def _api_key_input_from_dict(d):
|
|||
)
|
||||
|
||||
|
||||
def _group_input_from_dict(d):
|
||||
if d is None:
|
||||
return None
|
||||
return GroupInput(
|
||||
name=d.get("name", ""),
|
||||
description=d.get("description", ""),
|
||||
enabled=d.get("enabled", True),
|
||||
)
|
||||
|
||||
|
||||
def _grant_input_from_dict(d):
|
||||
if d is None:
|
||||
return None
|
||||
return GrantInput(
|
||||
capability=d.get("capability", ""),
|
||||
workspace=d.get("workspace", ""),
|
||||
)
|
||||
|
||||
|
||||
def _user_record_to_dict(r):
|
||||
if r is None:
|
||||
return None
|
||||
return {
|
||||
"id": r.id,
|
||||
"workspace": r.workspace,
|
||||
"default_workspace": r.default_workspace,
|
||||
"username": r.username,
|
||||
"name": r.name,
|
||||
"email": r.email,
|
||||
|
|
@ -102,6 +122,15 @@ class IamRequestTranslator(MessageTranslator):
|
|||
data.get("workspace_record")
|
||||
),
|
||||
key=_api_key_input_from_dict(data.get("key")),
|
||||
group_id=data.get("group_id", ""),
|
||||
member_type=data.get("member_type", ""),
|
||||
member_id=data.get("member_id", ""),
|
||||
group=_group_input_from_dict(data.get("group")),
|
||||
grant=_grant_input_from_dict(data.get("grant")),
|
||||
capability=data.get("capability", ""),
|
||||
resource_json=data.get("resource_json", ""),
|
||||
parameters_json=data.get("parameters_json", ""),
|
||||
authorise_checks=data.get("authorise_checks", ""),
|
||||
)
|
||||
|
||||
def encode(self, obj: IamRequest) -> Dict[str, Any]:
|
||||
|
|
@ -109,6 +138,9 @@ class IamRequestTranslator(MessageTranslator):
|
|||
for fname in (
|
||||
"workspace", "actor", "user_id", "username", "key_id",
|
||||
"api_key", "password", "new_password",
|
||||
"group_id", "member_type", "member_id",
|
||||
"capability", "resource_json", "parameters_json",
|
||||
"authorise_checks",
|
||||
):
|
||||
v = getattr(obj, fname, "")
|
||||
if v:
|
||||
|
|
@ -135,6 +167,17 @@ class IamRequestTranslator(MessageTranslator):
|
|||
"name": obj.key.name,
|
||||
"expires": obj.key.expires,
|
||||
}
|
||||
if obj.group is not None:
|
||||
result["group"] = {
|
||||
"name": obj.group.name,
|
||||
"description": obj.group.description,
|
||||
"enabled": obj.group.enabled,
|
||||
}
|
||||
if obj.grant is not None:
|
||||
result["grant"] = {
|
||||
"capability": obj.grant.capability,
|
||||
"workspace": obj.grant.workspace,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -175,8 +218,8 @@ class IamResponseTranslator(MessageTranslator):
|
|||
result["signing_key_public"] = obj.signing_key_public
|
||||
if obj.resolved_user_id:
|
||||
result["resolved_user_id"] = obj.resolved_user_id
|
||||
if obj.resolved_workspace:
|
||||
result["resolved_workspace"] = obj.resolved_workspace
|
||||
if obj.resolved_default_workspace:
|
||||
result["resolved_default_workspace"] = obj.resolved_default_workspace
|
||||
if obj.resolved_roles:
|
||||
result["resolved_roles"] = list(obj.resolved_roles)
|
||||
if obj.temporary_password:
|
||||
|
|
@ -190,6 +233,23 @@ class IamResponseTranslator(MessageTranslator):
|
|||
# setup, so it can't be dropped by a truthy-only filter.
|
||||
result["bootstrap_available"] = bool(obj.bootstrap_available)
|
||||
|
||||
# authorise / authorise-many outputs.
|
||||
if obj.decision_allow:
|
||||
result["decision_allow"] = obj.decision_allow
|
||||
if obj.decision_ttl_seconds:
|
||||
result["decision_ttl_seconds"] = obj.decision_ttl_seconds
|
||||
if obj.decisions_json:
|
||||
result["decisions_json"] = obj.decisions_json
|
||||
|
||||
# Enterprise IAM outputs.
|
||||
for fname in (
|
||||
"group_json", "groups_json", "members_json",
|
||||
"grants_json", "effective_permissions_json",
|
||||
):
|
||||
v = getattr(obj, fname, "")
|
||||
if v:
|
||||
result[fname] = v
|
||||
|
||||
return result
|
||||
|
||||
def encode_with_completion(
|
||||
|
|
|
|||
73
trustgraph-base/trustgraph/messaging/translators/reranker.py
Normal file
73
trustgraph-base/trustgraph/messaging/translators/reranker.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import (
|
||||
RerankerRequest, RerankerResponse,
|
||||
RerankerQuery, RerankerDocument, RerankerResult,
|
||||
)
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class RerankerRequestTranslator(MessageTranslator):
|
||||
|
||||
def decode(self, data: Dict[str, Any]) -> RerankerRequest:
|
||||
return RerankerRequest(
|
||||
queries=[
|
||||
RerankerQuery(
|
||||
query_id=q["query_id"],
|
||||
query_text=q["query_text"],
|
||||
)
|
||||
for q in data.get("queries", [])
|
||||
],
|
||||
documents=[
|
||||
RerankerDocument(
|
||||
document_id=d["document_id"],
|
||||
document_text=d["document_text"],
|
||||
)
|
||||
for d in data.get("documents", [])
|
||||
],
|
||||
limit=data.get("limit", 10),
|
||||
)
|
||||
|
||||
def encode(self, obj: RerankerRequest) -> Dict[str, Any]:
|
||||
return {
|
||||
"queries": [
|
||||
{"query_id": q.query_id, "query_text": q.query_text}
|
||||
for q in obj.queries
|
||||
],
|
||||
"documents": [
|
||||
{"document_id": d.document_id, "document_text": d.document_text}
|
||||
for d in obj.documents
|
||||
],
|
||||
"limit": obj.limit,
|
||||
}
|
||||
|
||||
|
||||
class RerankerResponseTranslator(MessageTranslator):
|
||||
|
||||
def decode(self, data: Dict[str, Any]) -> RerankerResponse:
|
||||
return RerankerResponse(
|
||||
results=[
|
||||
RerankerResult(
|
||||
document_id=r["document_id"],
|
||||
query_id=r["query_id"],
|
||||
score=r["score"],
|
||||
)
|
||||
for r in data.get("results", [])
|
||||
],
|
||||
)
|
||||
|
||||
def encode(self, obj: RerankerResponse) -> Dict[str, Any]:
|
||||
return {
|
||||
"results": [
|
||||
{
|
||||
"document_id": r.document_id,
|
||||
"query_id": r.query_id,
|
||||
"score": r.score,
|
||||
}
|
||||
for r in obj.results
|
||||
],
|
||||
}
|
||||
|
||||
def encode_with_completion(
|
||||
self, obj: RerankerResponse
|
||||
) -> Tuple[Dict[str, Any], bool]:
|
||||
return self.encode(obj), True
|
||||
|
|
@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
|||
query=data["query"],
|
||||
collection=data.get("collection", "default"),
|
||||
doc_limit=int(data.get("doc-limit", 20)),
|
||||
fetch_limit=int(data.get("fetch-limit", 0)),
|
||||
streaming=data.get("streaming", False)
|
||||
)
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
|
|||
"query": obj.query,
|
||||
"collection": obj.collection,
|
||||
"doc-limit": obj.doc_limit,
|
||||
"fetch-limit": obj.fetch_limit,
|
||||
"streaming": getattr(obj, "streaming", False)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,8 @@ from . uris import (
|
|||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_focus_uri,
|
||||
chunk_selection_uri,
|
||||
docrag_synthesis_uri,
|
||||
)
|
||||
|
||||
|
|
@ -89,9 +91,13 @@ from . namespaces import (
|
|||
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE,
|
||||
# Edge selection entity type
|
||||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Chunk selection entity type
|
||||
TG_CHUNK_SELECTION,
|
||||
# Explainability entity types
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION,
|
||||
|
|
@ -130,6 +136,7 @@ from . triples import (
|
|||
# Query-time provenance triple builders (DocumentRAG)
|
||||
docrag_question_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_chunk_selection_triples,
|
||||
docrag_synthesis_triples,
|
||||
# Utility
|
||||
set_graph,
|
||||
|
|
@ -194,6 +201,8 @@ __all__ = [
|
|||
"docrag_question_uri",
|
||||
"docrag_grounding_uri",
|
||||
"docrag_exploration_uri",
|
||||
"docrag_focus_uri",
|
||||
"chunk_selection_uri",
|
||||
"docrag_synthesis_uri",
|
||||
# Namespaces
|
||||
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
|
||||
|
|
@ -212,9 +221,13 @@ __all__ = [
|
|||
"TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE",
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
|
||||
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
|
||||
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE",
|
||||
# Edge selection entity type
|
||||
"TG_EDGE_SELECTION",
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
|
||||
# Chunk selection entity type
|
||||
"TG_CHUNK_SELECTION",
|
||||
# Explainability entity types
|
||||
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
|
||||
"TG_ANALYSIS", "TG_CONCLUSION",
|
||||
|
|
@ -250,6 +263,7 @@ __all__ = [
|
|||
# Query-time provenance triple builders (DocumentRAG)
|
||||
"docrag_question_triples",
|
||||
"docrag_exploration_triples",
|
||||
"docrag_chunk_selection_triples",
|
||||
"docrag_synthesis_triples",
|
||||
# Agent provenance triple builders
|
||||
"agent_session_triples",
|
||||
|
|
|
|||
|
|
@ -66,12 +66,19 @@ TG_EDGE_COUNT = TG + "edgeCount"
|
|||
TG_SELECTED_EDGE = TG + "selectedEdge"
|
||||
TG_EDGE = TG + "edge"
|
||||
TG_REASONING = TG + "reasoning"
|
||||
TG_SCORE = TG + "score"
|
||||
TG_DOCUMENT = TG + "document" # Reference to document in librarian
|
||||
|
||||
# Edge selection entity type (cross-encoder scored edge in Focus)
|
||||
TG_EDGE_SELECTION = TG + "EdgeSelection"
|
||||
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT = TG + "chunkCount"
|
||||
TG_SELECTED_CHUNK = TG + "selectedChunk"
|
||||
|
||||
# Chunk selection entity type (cross-encoder reranked chunk in Focus)
|
||||
TG_CHUNK_SELECTION = TG + "ChunkSelection"
|
||||
|
||||
# Extraction provenance entity types
|
||||
TG_DOCUMENT_TYPE = TG + "Document"
|
||||
TG_PAGE_TYPE = TG + "Page"
|
||||
|
|
|
|||
|
|
@ -24,10 +24,14 @@ from . namespaces import (
|
|||
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
|
||||
# Query-time provenance predicates (GraphRAG)
|
||||
TG_QUERY, TG_CONCEPT, TG_ENTITY,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
|
||||
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE,
|
||||
TG_DOCUMENT,
|
||||
# Edge selection entity type
|
||||
TG_EDGE_SELECTION,
|
||||
# Query-time provenance predicates (DocumentRAG)
|
||||
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
|
||||
# Chunk selection entity type
|
||||
TG_CHUNK_SELECTION,
|
||||
# Explainability entity types
|
||||
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
# Unifying types
|
||||
|
|
@ -38,7 +42,10 @@ from . namespaces import (
|
|||
TG_IN_TOKEN, TG_OUT_TOKEN,
|
||||
)
|
||||
|
||||
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri
|
||||
from . uris import (
|
||||
activity_uri, agent_uri, subgraph_uri, edge_selection_uri,
|
||||
chunk_selection_uri,
|
||||
)
|
||||
|
||||
|
||||
def set_graph(triples: List[Triple], graph: str) -> List[Triple]:
|
||||
|
|
@ -536,10 +543,9 @@ def focus_triples(
|
|||
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
|
||||
]
|
||||
|
||||
# Add each selected edge with its reasoning via intermediate entity
|
||||
# Add each selected edge with metadata via intermediate entity
|
||||
for idx, edge_info in enumerate(selected_edges_with_reasoning):
|
||||
edge = edge_info.get("edge")
|
||||
reasoning = edge_info.get("reasoning", "")
|
||||
|
||||
if edge:
|
||||
s, p, o = edge
|
||||
|
|
@ -552,13 +558,32 @@ def focus_triples(
|
|||
_triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri))
|
||||
)
|
||||
|
||||
# Type the edge selection entity
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION))
|
||||
)
|
||||
|
||||
# Attach quoted triple to edge selection entity
|
||||
quoted = _quoted_triple(s, p, o)
|
||||
triples.append(
|
||||
Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted)
|
||||
)
|
||||
|
||||
# Attach reasoning to edge selection entity
|
||||
# Structured cross-encoder metadata
|
||||
concept = edge_info.get("concept")
|
||||
if concept:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_CONCEPT, _literal(concept))
|
||||
)
|
||||
|
||||
score = edge_info.get("score")
|
||||
if score is not None:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_SCORE, _literal(str(score)))
|
||||
)
|
||||
|
||||
# Legacy reasoning text (for non-cross-encoder callers)
|
||||
reasoning = edge_info.get("reasoning", "")
|
||||
if reasoning:
|
||||
triples.append(
|
||||
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning))
|
||||
|
|
@ -698,6 +723,75 @@ def docrag_exploration_triples(
|
|||
return triples
|
||||
|
||||
|
||||
def docrag_chunk_selection_triples(
|
||||
focus_uri: str,
|
||||
exploration_uri: str,
|
||||
selected_chunks_with_scores: List[dict],
|
||||
session_id: str,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for a document RAG focus entity (chunks selected by the
|
||||
cross-encoder reranker).
|
||||
|
||||
Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity
|
||||
derived from exploration, with one ChunkSelection sub-entity per surviving
|
||||
chunk carrying the chunk reference and the reranker score.
|
||||
|
||||
Structure:
|
||||
<focus> a tg:Focus ; prov:wasDerivedFrom <exploration> .
|
||||
<focus> tg:selectedChunk <chunk_sel_0> .
|
||||
<chunk_sel_0> a tg:ChunkSelection .
|
||||
<chunk_sel_0> tg:document <chunk_id> .
|
||||
<chunk_sel_0> tg:score "0.97" .
|
||||
|
||||
Args:
|
||||
focus_uri: URI of the focus entity (from docrag_focus_uri)
|
||||
exploration_uri: URI of the parent exploration entity
|
||||
selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score'
|
||||
session_id: Session UUID for generating chunk selection URIs
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
"""
|
||||
triples = [
|
||||
_triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)),
|
||||
_triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")),
|
||||
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
|
||||
]
|
||||
|
||||
for idx, chunk_info in enumerate(selected_chunks_with_scores):
|
||||
chunk_id = chunk_info.get("chunk_id")
|
||||
if not chunk_id:
|
||||
continue
|
||||
|
||||
chunk_sel_uri = chunk_selection_uri(session_id, idx)
|
||||
|
||||
# Link focus to chunk selection entity
|
||||
triples.append(
|
||||
_triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri))
|
||||
)
|
||||
|
||||
# Type the chunk selection entity
|
||||
triples.append(
|
||||
_triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION))
|
||||
)
|
||||
|
||||
# Reference the actual chunk (in librarian)
|
||||
triples.append(
|
||||
_triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id))
|
||||
)
|
||||
|
||||
# Cross-encoder score
|
||||
score = chunk_info.get("score")
|
||||
if score is not None:
|
||||
triples.append(
|
||||
_triple(chunk_sel_uri, TG_SCORE, _literal(str(score)))
|
||||
)
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
def docrag_synthesis_triples(
|
||||
synthesis_uri: str,
|
||||
exploration_uri: str,
|
||||
|
|
|
|||
|
|
@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str:
|
|||
return f"urn:trustgraph:docrag:{session_id}/exploration"
|
||||
|
||||
|
||||
def docrag_focus_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for a document RAG focus entity (chunks selected by the
|
||||
cross-encoder reranker).
|
||||
|
||||
Args:
|
||||
session_id: The session UUID.
|
||||
|
||||
Returns:
|
||||
URN in format: urn:trustgraph:docrag:{uuid}/focus
|
||||
"""
|
||||
return f"urn:trustgraph:docrag:{session_id}/focus"
|
||||
|
||||
|
||||
def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
|
||||
"""
|
||||
Generate URI for a chunk selection item (links a reranked chunk to its
|
||||
score). Mirrors edge_selection_uri for GraphRAG.
|
||||
|
||||
Args:
|
||||
session_id: The session UUID.
|
||||
chunk_index: Index of this chunk in the selection (0-based).
|
||||
|
||||
Returns:
|
||||
URN in format: urn:trustgraph:prov:chunk:{uuid}:{index}
|
||||
"""
|
||||
return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"
|
||||
|
||||
|
||||
def docrag_synthesis_uri(session_id: str) -> str:
|
||||
"""
|
||||
Generate URI for a document RAG synthesis entity (final answer).
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ from . namespaces import (
|
|||
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
|
||||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
TG_EDGE_SELECTION, TG_SCORE,
|
||||
TG_CHUNK_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -93,6 +95,8 @@ TG_CLASS_LABELS = [
|
|||
_label_triple(TG_FINDING, "Finding"),
|
||||
_label_triple(TG_PLAN_TYPE, "Plan"),
|
||||
_label_triple(TG_STEP_RESULT, "Step Result"),
|
||||
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
|
||||
_label_triple(TG_CHUNK_SELECTION, "Chunk Selection"),
|
||||
]
|
||||
|
||||
# TrustGraph predicate labels
|
||||
|
|
@ -117,6 +121,7 @@ TG_PREDICATE_LABELS = [
|
|||
_label_triple(TG_ENTITY, "entity"),
|
||||
_label_triple(TG_SUBAGENT_GOAL, "subagent goal"),
|
||||
_label_triple(TG_PLAN_STEP, "plan step"),
|
||||
_label_triple(TG_SCORE, "score"),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,3 +16,4 @@ from .collection import *
|
|||
from .storage import *
|
||||
from .tool_service import *
|
||||
from .sparql_query import *
|
||||
from .reranker import *
|
||||
|
|
@ -29,7 +29,7 @@ class UserInput:
|
|||
@dataclass
|
||||
class UserRecord:
|
||||
id: str = ""
|
||||
workspace: str = ""
|
||||
default_workspace: str = ""
|
||||
username: str = ""
|
||||
name: str = ""
|
||||
email: str = ""
|
||||
|
|
@ -74,6 +74,21 @@ class ApiKeyRecord:
|
|||
last_used: str = ""
|
||||
|
||||
|
||||
# ---- Enterprise IAM types (additive) ----
|
||||
|
||||
@dataclass
|
||||
class GroupInput:
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class GrantInput:
|
||||
capability: str = ""
|
||||
workspace: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class IamRequest:
|
||||
operation: str = ""
|
||||
|
|
@ -99,6 +114,13 @@ class IamRequest:
|
|||
workspace_record: WorkspaceInput | None = None
|
||||
key: ApiKeyInput | None = None
|
||||
|
||||
# ---- Enterprise IAM inputs (additive) ----
|
||||
group_id: str = ""
|
||||
member_type: str = ""
|
||||
member_id: str = ""
|
||||
group: GroupInput | None = None
|
||||
grant: GrantInput | None = None
|
||||
|
||||
# ---- authorise / authorise-many inputs ----
|
||||
# Capability string from the vocabulary in capabilities.md.
|
||||
capability: str = ""
|
||||
|
|
@ -138,7 +160,7 @@ class IamResponse:
|
|||
|
||||
# resolve-api-key
|
||||
resolved_user_id: str = ""
|
||||
resolved_workspace: str = ""
|
||||
resolved_default_workspace: str = ""
|
||||
resolved_roles: list[str] = field(default_factory=list)
|
||||
|
||||
# reset-password
|
||||
|
|
@ -164,6 +186,14 @@ class IamResponse:
|
|||
# authorise_checks.
|
||||
decisions_json: str = ""
|
||||
|
||||
# ---- Enterprise IAM outputs (additive) ----
|
||||
# JSON-serialised payloads for enterprise group/grant operations.
|
||||
group_json: str = ""
|
||||
groups_json: str = ""
|
||||
members_json: str = ""
|
||||
grants_json: str = ""
|
||||
effective_permissions_json: str = ""
|
||||
|
||||
error: Error | None = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,17 +6,6 @@ from ..core.primitives import Error
|
|||
|
||||
# Prompt services, abstract the prompt generation
|
||||
|
||||
# extract-definitions:
|
||||
# chunk -> definitions
|
||||
# extract-relationships:
|
||||
# chunk -> relationships
|
||||
# kg-prompt:
|
||||
# query, triples -> answer
|
||||
# document-prompt:
|
||||
# query, documents -> answer
|
||||
# extract-rows
|
||||
# schema, chunk -> rows
|
||||
|
||||
@dataclass
|
||||
class PromptRequest:
|
||||
id: str = ""
|
||||
|
|
|
|||
35
trustgraph-base/trustgraph/schema/services/reranker.py
Normal file
35
trustgraph-base/trustgraph/schema/services/reranker.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..core.primitives import Error
|
||||
|
||||
############################################################################
|
||||
|
||||
# Cross-encoder reranker
|
||||
|
||||
@dataclass
|
||||
class RerankerQuery:
|
||||
query_id: str = ""
|
||||
query_text: str = ""
|
||||
|
||||
@dataclass
|
||||
class RerankerDocument:
|
||||
document_id: str = ""
|
||||
document_text: str = ""
|
||||
|
||||
@dataclass
|
||||
class RerankerRequest:
|
||||
queries: list[RerankerQuery] = field(default_factory=list)
|
||||
documents: list[RerankerDocument] = field(default_factory=list)
|
||||
limit: int = 10
|
||||
|
||||
@dataclass
|
||||
class RerankerResult:
|
||||
document_id: str = ""
|
||||
query_id: str = ""
|
||||
score: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class RerankerResponse:
|
||||
error: Error | None = None
|
||||
results: list[RerankerResult] = field(default_factory=list)
|
||||
|
|
@ -40,7 +40,10 @@ class GraphRagResponse:
|
|||
class DocumentRagQuery:
|
||||
query: str = ""
|
||||
collection: str = ""
|
||||
doc_limit: int = 0
|
||||
doc_limit: int = 0 # docs selected into the synthesis prompt
|
||||
fetch_limit: int = 0 # candidate pool fetched from the vector store
|
||||
# before reranking (0 = derive from doc_limit;
|
||||
# values below doc_limit are raised to it)
|
||||
streaming: bool = False
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.5,<2.6",
|
||||
"trustgraph-base>=2.6,<2.7",
|
||||
"pulsar-client",
|
||||
"prometheus-client",
|
||||
"boto3",
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.5,<2.6",
|
||||
"trustgraph-base>=2.6,<2.7",
|
||||
"requests",
|
||||
"pulsar-client",
|
||||
"aiohttp",
|
||||
|
|
@ -71,6 +71,7 @@ tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
|
|||
tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main"
|
||||
tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main"
|
||||
tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
|
||||
tg-invoke-reranker = "trustgraph.cli.invoke_reranker:main"
|
||||
tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
|
||||
tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"
|
||||
tg-load-kg-core = "trustgraph.cli.load_kg_core:main"
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ def main():
|
|||
help="Auth token (default: $TRUSTGRAPH_TOKEN)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--username", required=True, help="Username (unique in workspace)",
|
||||
"--username", required=True, help="Username (globally unique)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--password", default=None,
|
||||
|
|
@ -75,10 +75,7 @@ def main():
|
|||
)
|
||||
parser.add_argument(
|
||||
"-w", "--workspace", default=None,
|
||||
help=(
|
||||
"Target workspace (admin only; defaults to caller's "
|
||||
"assigned workspace)"
|
||||
),
|
||||
help="Default workspace for the new user",
|
||||
)
|
||||
run_main(do_create_user, parser)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,10 +21,12 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
|||
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||
default_collection = 'default'
|
||||
default_doc_limit = 10
|
||||
default_fetch_limit = 0
|
||||
|
||||
|
||||
def question_explainable(
|
||||
url, flow_id, question_text, collection, doc_limit, token=None, debug=False,
|
||||
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
|
||||
token=None, debug=False,
|
||||
workspace="default",
|
||||
):
|
||||
"""Execute document RAG with explainability - shows provenance events inline."""
|
||||
|
|
@ -39,6 +41,7 @@ def question_explainable(
|
|||
query=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
):
|
||||
if isinstance(item, RAGChunk):
|
||||
# Print response content
|
||||
|
|
@ -97,7 +100,7 @@ def question_explainable(
|
|||
|
||||
|
||||
def question(
|
||||
url, flow_id, question_text, collection, doc_limit,
|
||||
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
|
||||
streaming=True, token=None, explainable=False, debug=False,
|
||||
show_usage=False, workspace="default",
|
||||
):
|
||||
|
|
@ -109,6 +112,7 @@ def question(
|
|||
question_text=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
token=token,
|
||||
debug=debug,
|
||||
workspace=workspace,
|
||||
|
|
@ -128,6 +132,7 @@ def question(
|
|||
query=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
|
|
@ -155,6 +160,7 @@ def question(
|
|||
query=question_text,
|
||||
collection=collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
)
|
||||
print(result.text)
|
||||
|
||||
|
|
@ -214,7 +220,15 @@ def main():
|
|||
'-d', '--doc-limit',
|
||||
type=int,
|
||||
default=default_doc_limit,
|
||||
help=f'Document limit (default: {default_doc_limit})'
|
||||
help=f'Documents selected into the prompt (default: {default_doc_limit})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--fetch-limit',
|
||||
type=int,
|
||||
default=default_fetch_limit,
|
||||
help='Candidate documents fetched from the vector store before '
|
||||
'reranking (default: derive from doc-limit)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -251,6 +265,7 @@ def main():
|
|||
question_text=args.question,
|
||||
collection=args.collection,
|
||||
doc_limit=args.doc_limit,
|
||||
fetch_limit=args.fetch_limit,
|
||||
streaming=not args.no_streaming,
|
||||
token=args.token,
|
||||
explainable=args.explainable,
|
||||
|
|
|
|||
|
|
@ -112,14 +112,13 @@ def _question_explainable_api(
|
|||
if focus_full and focus_full.edge_selections:
|
||||
for edge_sel in focus_full.edge_selections:
|
||||
if edge_sel.edge:
|
||||
# Resolve labels for edge components
|
||||
s_label, p_label, o_label = explain_client.resolve_edge_labels(
|
||||
edge_sel.edge, collection
|
||||
)
|
||||
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
|
||||
if edge_sel.reasoning:
|
||||
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
|
||||
print(f" Reason: {r_short}", file=sys.stderr)
|
||||
if edge_sel.concept or edge_sel.score is not None:
|
||||
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
|
||||
print(f" Concept: {edge_sel.concept} Score: {score_str}", file=sys.stderr)
|
||||
|
||||
elif isinstance(entity, Synthesis):
|
||||
print(f"\n [synthesis] {prov_id}", file=sys.stderr)
|
||||
|
|
|
|||
127
trustgraph-cli/trustgraph/cli/invoke_reranker.py
Normal file
127
trustgraph-cli/trustgraph/cli/invoke_reranker.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""
|
||||
Invokes the reranker service to score and rank documents by relevance
|
||||
to one or more queries.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from trustgraph.api import Api
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
|
||||
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
|
||||
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
|
||||
|
||||
def query(url, flow_id, queries, documents, limit, token=None,
|
||||
workspace="default"):
|
||||
|
||||
api = Api(url=url, token=token, workspace=workspace)
|
||||
socket = api.socket()
|
||||
flow = socket.flow(flow_id)
|
||||
|
||||
try:
|
||||
|
||||
query_objects = [
|
||||
{"query_id": str(i), "query_text": q}
|
||||
for i, q in enumerate(queries)
|
||||
]
|
||||
|
||||
document_objects = [
|
||||
{"document_id": str(i), "document_text": d}
|
||||
for i, d in enumerate(documents)
|
||||
]
|
||||
|
||||
result = flow.rerank(
|
||||
queries=query_objects,
|
||||
documents=document_objects,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if "error" in result and result["error"]:
|
||||
err = result["error"]
|
||||
print(f"Error: [{err.get('type', '')}] {err.get('message', '')}")
|
||||
return
|
||||
|
||||
for r in result.get("results", []):
|
||||
doc_idx = int(r["document_id"])
|
||||
query_idx = int(r["query_id"])
|
||||
print(
|
||||
f" {r['score']:.4f} | "
|
||||
f"query: {queries[query_idx]} | "
|
||||
f"doc: {documents[doc_idx]}"
|
||||
)
|
||||
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='tg-invoke-reranker',
|
||||
description=__doc__,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-u', '--url',
|
||||
default=default_url,
|
||||
help=f'API URL (default: {default_url})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--token',
|
||||
default=default_token,
|
||||
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-w', '--workspace',
|
||||
default=default_workspace,
|
||||
help=f'Workspace (default: {default_workspace})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-f', '--flow-id',
|
||||
default="default",
|
||||
help=f'Flow ID (default: default)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-l', '--limit',
|
||||
type=int,
|
||||
default=10,
|
||||
help='Maximum number of results (default: 10)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-q', '--query',
|
||||
action='append',
|
||||
required=True,
|
||||
help='Query text (can be specified multiple times)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'documents',
|
||||
nargs='+',
|
||||
help='Documents to rerank',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
query(
|
||||
url=args.url,
|
||||
flow_id=args.flow_id,
|
||||
queries=args.query,
|
||||
documents=args.documents,
|
||||
limit=args.limit,
|
||||
token=args.token,
|
||||
workspace=args.workspace,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -51,8 +51,8 @@ def main():
|
|||
parser.add_argument(
|
||||
"-w", "--workspace", default=None,
|
||||
help=(
|
||||
"Optional workspace to log in against. Defaults to "
|
||||
"the user's assigned workspace."
|
||||
"Override the default workspace for this session's JWT. "
|
||||
"If omitted, uses the user's stored default workspace."
|
||||
),
|
||||
)
|
||||
run_main(do_login, parser)
|
||||
|
|
|
|||
|
|
@ -203,9 +203,9 @@ def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_
|
|||
)
|
||||
print(f" {i}. ({s_label}, {p_label}, {o_label})")
|
||||
|
||||
if edge_sel.reasoning:
|
||||
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
|
||||
print(f" Reasoning: {r_short}")
|
||||
if edge_sel.concept or edge_sel.score is not None:
|
||||
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
|
||||
print(f" Concept: {edge_sel.concept} Score: {score_str}")
|
||||
|
||||
if show_provenance and edge_sel.edge:
|
||||
provenance = trace_edge_provenance(
|
||||
|
|
@ -519,7 +519,8 @@ def trace_to_dict(trace, trace_type):
|
|||
"selected_edges": [
|
||||
{
|
||||
"edge": edge_sel.edge,
|
||||
"reasoning": edge_sel.reasoning,
|
||||
"concept": edge_sel.concept,
|
||||
"score": edge_sel.score,
|
||||
}
|
||||
for edge_sel in focus.edge_selections
|
||||
],
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ def do_update_user(args):
|
|||
print(f"username : {rec.get('username', '')}")
|
||||
print(f"name : {rec.get('name', '')}")
|
||||
print(f"email : {rec.get('email', '')}")
|
||||
print(f"workspace : {rec.get('workspace', '')}")
|
||||
print(f"default_ws: {rec.get('default_workspace', '')}")
|
||||
print(f"roles : {', '.join(rec.get('roles', []))}")
|
||||
print(f"enabled : {'yes' if rec.get('enabled') else 'no'}")
|
||||
print(
|
||||
|
|
@ -114,7 +114,7 @@ def main():
|
|||
"-w", "--workspace", default=None,
|
||||
help=(
|
||||
"Optional workspace integrity check — when supplied, "
|
||||
"iam-svc verifies the target user's home workspace "
|
||||
"iam-svc verifies the target user's default workspace "
|
||||
"matches"
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph."
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.5,<2.6",
|
||||
"trustgraph-flow>=2.5,<2.6",
|
||||
"trustgraph-base>=2.6,<2.7",
|
||||
"trustgraph-flow>=2.6,<2.7",
|
||||
"torch",
|
||||
"urllib3",
|
||||
"transformers",
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.5,<2.6",
|
||||
"trustgraph-base>=2.6,<2.7",
|
||||
"aiohttp",
|
||||
"anthropic",
|
||||
"scylla-driver",
|
||||
|
|
@ -19,6 +19,7 @@ dependencies = [
|
|||
"faiss-cpu",
|
||||
"falkordb",
|
||||
"fastembed",
|
||||
"flashrank",
|
||||
"ibis",
|
||||
"jsonschema",
|
||||
"langchain",
|
||||
|
|
@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:
|
|||
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
|
||||
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
|
||||
graph-rag = "trustgraph.retrieval.graph_rag:run"
|
||||
reranker-flashrank = "trustgraph.reranker.flashrank:run"
|
||||
kg-extract-agent = "trustgraph.extract.kg.agent:run"
|
||||
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
|
||||
kg-extract-rows = "trustgraph.extract.kg.rows:run"
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ description : str (default "Default")
|
|||
Human-readable description passed to flow-svc.
|
||||
parameters : dict (optional)
|
||||
Optional parameter overrides passed to start-flow.
|
||||
list_timeout : int (default 10)
|
||||
Timeout in seconds for the list-flows request.
|
||||
start_timeout : int (default 30)
|
||||
Timeout in seconds for the start-flow request.
|
||||
"""
|
||||
|
||||
from trustgraph.schema import FlowRequest
|
||||
|
|
@ -34,6 +38,8 @@ class DefaultFlowStart(Initialiser):
|
|||
blueprint=None,
|
||||
description="Default",
|
||||
parameters=None,
|
||||
list_timeout=10,
|
||||
start_timeout=30,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
|
@ -46,6 +52,8 @@ class DefaultFlowStart(Initialiser):
|
|||
self.blueprint = blueprint
|
||||
self.description = description
|
||||
self.parameters = dict(parameters) if parameters else {}
|
||||
self.list_timeout = list_timeout
|
||||
self.start_timeout = start_timeout
|
||||
|
||||
async def run(self, ctx, old_flag, new_flag):
|
||||
|
||||
|
|
@ -70,7 +78,7 @@ class DefaultFlowStart(Initialiser):
|
|||
FlowRequest(
|
||||
operation="list-flows",
|
||||
),
|
||||
timeout=10,
|
||||
timeout=self.list_timeout,
|
||||
)
|
||||
if list_resp.error:
|
||||
raise RuntimeError(
|
||||
|
|
@ -99,7 +107,7 @@ class DefaultFlowStart(Initialiser):
|
|||
description=self.description,
|
||||
parameters=self.parameters,
|
||||
),
|
||||
timeout=30,
|
||||
timeout=self.start_timeout,
|
||||
)
|
||||
if resp.error:
|
||||
raise RuntimeError(
|
||||
|
|
|
|||
|
|
@ -14,7 +14,9 @@ seed_file : str (required when source=="seed-file")
|
|||
Path to a JSON seed file with the same shape TemplateSeed consumes.
|
||||
overwrite : bool (default False)
|
||||
On re-run (flag change), if True overwrite all keys; if False,
|
||||
upsert-missing-only (preserves in-workspace customisations).
|
||||
upsert-missing-only (preserves in-workspace customisations)
|
||||
iam_timeout : int (default 10)
|
||||
Timeout in seconds for the IAM create-workspace request.
|
||||
|
||||
Raises (in ``run``)
|
||||
-------------------
|
||||
|
|
@ -41,7 +43,9 @@ class WorkspaceInit(Initialiser):
|
|||
source="template",
|
||||
seed_file=None,
|
||||
overwrite=False,
|
||||
iam_timeout=10,
|
||||
**kwargs,
|
||||
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
|
@ -59,6 +63,7 @@ class WorkspaceInit(Initialiser):
|
|||
self.source = source
|
||||
self.seed_file = seed_file
|
||||
self.overwrite = overwrite
|
||||
self.iam_timeout = iam_timeout
|
||||
|
||||
async def run(self, ctx, old_flag, new_flag):
|
||||
await self._create_workspace(ctx)
|
||||
|
|
@ -123,7 +128,7 @@ class WorkspaceInit(Initialiser):
|
|||
enabled=True,
|
||||
),
|
||||
),
|
||||
timeout=10,
|
||||
timeout=self.iam_timeout,
|
||||
)
|
||||
if resp.error:
|
||||
if resp.error.type == "duplicate":
|
||||
|
|
|
|||
|
|
@ -57,16 +57,17 @@ class Identity:
|
|||
# the OSS regime this is the user record's id; the gateway
|
||||
# treats it as a string with no semantic content.
|
||||
handle: str
|
||||
# The workspace this credential authenticates to. Used by the
|
||||
# gateway as the default-fill-in for operations that omit a
|
||||
# workspace. Never used as policy input.
|
||||
workspace: str
|
||||
# The user's default workspace. Used by the gateway as the
|
||||
# default-fill-in for operations that omit a workspace. Not a
|
||||
# permission boundary — workspace access is controlled by the
|
||||
# IAM regime's authorise() decision, not by this field.
|
||||
default_workspace: str
|
||||
# Stable identifier for audit logs. In OSS this is the same
|
||||
# value as ``handle``; not assumed equal in the contract.
|
||||
principal_id: str
|
||||
# How the credential was presented. Non-policy; useful for
|
||||
# logs / metrics only.
|
||||
source: str # "api-key" | "jwt"
|
||||
source: str # "api-key" | "jwt" | "anonymous"
|
||||
|
||||
|
||||
def _auth_failure():
|
||||
|
|
@ -256,21 +257,22 @@ class IamAuth:
|
|||
raise _auth_failure()
|
||||
|
||||
sub = claims.get("sub", "")
|
||||
ws = claims.get("workspace", "")
|
||||
ws = claims.get("default_workspace", "")
|
||||
if not sub or not ws:
|
||||
raise _auth_failure()
|
||||
|
||||
# JWT carries no policy state under the IAM contract;
|
||||
# any roles / claims field is ignored here.
|
||||
return Identity(
|
||||
handle=sub, workspace=ws, principal_id=sub, source="jwt",
|
||||
handle=sub, default_workspace=ws,
|
||||
principal_id=sub, source="jwt",
|
||||
)
|
||||
|
||||
async def _authenticate_anonymous(self):
|
||||
try:
|
||||
async def _call(client):
|
||||
return await client.authenticate_anonymous()
|
||||
user_id, workspace, _roles = await self._with_client(_call)
|
||||
user_id, default_workspace, _roles = await self._with_client(
|
||||
_call,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Anonymous authentication rejected: "
|
||||
|
|
@ -278,11 +280,11 @@ class IamAuth:
|
|||
)
|
||||
raise _auth_failure()
|
||||
|
||||
if not user_id or not workspace:
|
||||
if not user_id or not default_workspace:
|
||||
raise _auth_failure()
|
||||
|
||||
return Identity(
|
||||
handle=user_id, workspace=workspace,
|
||||
handle=user_id, default_workspace=default_workspace,
|
||||
principal_id=user_id, source="anonymous",
|
||||
)
|
||||
|
||||
|
|
@ -305,7 +307,9 @@ class IamAuth:
|
|||
# ``roles`` is returned by the OSS regime as a hint
|
||||
# but is not consulted by the gateway; all policy
|
||||
# decisions go through ``authorise``.
|
||||
user_id, workspace, _roles = await self._with_client(_call)
|
||||
user_id, default_workspace, _roles = await self._with_client(
|
||||
_call,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"API key resolution failed: "
|
||||
|
|
@ -313,11 +317,11 @@ class IamAuth:
|
|||
)
|
||||
raise _auth_failure()
|
||||
|
||||
if not user_id or not workspace:
|
||||
if not user_id or not default_workspace:
|
||||
raise _auth_failure()
|
||||
|
||||
identity = Identity(
|
||||
handle=user_id, workspace=workspace,
|
||||
handle=user_id, default_workspace=default_workspace,
|
||||
principal_id=user_id, source="api-key",
|
||||
)
|
||||
self._key_cache[h] = (identity, now + API_KEY_CACHE_TTL)
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ async def enforce_workspace(data, identity, auth, capability=None):
|
|||
return data
|
||||
|
||||
requested = data.get("workspace", "")
|
||||
target = requested or identity.workspace
|
||||
target = requested or identity.default_workspace
|
||||
data["workspace"] = target
|
||||
|
||||
if target not in auth.known_workspaces:
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
|||
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
|
||||
from . row_embeddings_query import RowEmbeddingsQueryRequestor
|
||||
from . mcp_tool import McpToolRequestor
|
||||
from . reranker import RerankerRequestor
|
||||
from . text_load import TextLoad
|
||||
from . document_load import DocumentLoad
|
||||
|
||||
|
|
@ -74,6 +75,7 @@ request_response_dispatchers = {
|
|||
"structured-diag": StructuredDiagRequestor,
|
||||
"row-embeddings": RowEmbeddingsQueryRequestor,
|
||||
"sparql": SparqlQueryRequestor,
|
||||
"reranker": RerankerRequestor,
|
||||
}
|
||||
|
||||
system_dispatchers = {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import uuid
|
|||
import logging
|
||||
|
||||
from ..capabilities import PUBLIC, AUTHENTICATED
|
||||
from ..registry import ResourceLevel
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -79,7 +80,7 @@ class Mux:
|
|||
self.identity = identity
|
||||
await self.ws.send_json({
|
||||
"type": "auth-ok",
|
||||
"workspace": identity.workspace,
|
||||
"default_workspace": identity.default_workspace,
|
||||
})
|
||||
|
||||
async def receive(self, msg):
|
||||
|
|
@ -159,12 +160,14 @@ class Mux:
|
|||
return
|
||||
|
||||
# Resolve workspace (default-fill from the caller's
|
||||
# bound workspace). Workspace resolution applies to all
|
||||
# operations regardless of capability level.
|
||||
# bound workspace). The envelope workspace is the
|
||||
# single canonical workspace for routing AND
|
||||
# authorisation. The inner request body's workspace
|
||||
# field is not consulted — workspace-scoped services
|
||||
# receive workspace from the queue identity, not the
|
||||
# message body.
|
||||
try:
|
||||
await enforce_workspace(data, self.identity, self.auth)
|
||||
if isinstance(inner, dict):
|
||||
await enforce_workspace(inner, self.identity, self.auth)
|
||||
|
||||
# Authorisation: capability sentinels short-circuit
|
||||
# the regime call; capability strings go through
|
||||
|
|
@ -176,11 +179,18 @@ class Mux:
|
|||
"flow": data.get("flow", ""),
|
||||
}
|
||||
parameters = {}
|
||||
elif op.resource_level == ResourceLevel.WORKSPACE:
|
||||
# Workspace-scoped services (config, flow,
|
||||
# librarian, etc.) — workspace comes from the
|
||||
# envelope, same as flow-level services.
|
||||
resource = {
|
||||
"workspace": data.get("workspace", ""),
|
||||
}
|
||||
parameters = {}
|
||||
else:
|
||||
# Build a minimal RequestContext so the matched
|
||||
# operation's own extractors decide resource
|
||||
# and parameters — same path the HTTP
|
||||
# endpoints take.
|
||||
# System-level services (IAM) — resource is
|
||||
# {} and parameters come from the inner body
|
||||
# (e.g. user.workspace, workspace_record.id).
|
||||
from ..registry import RequestContext
|
||||
ctx = RequestContext(
|
||||
body=inner if isinstance(inner, dict) else {},
|
||||
|
|
|
|||
31
trustgraph-flow/trustgraph/gateway/dispatch/reranker.py
Normal file
31
trustgraph-flow/trustgraph/gateway/dispatch/reranker.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
from ... schema import RerankerRequest, RerankerResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class RerankerRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, backend, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(RerankerRequestor, self).__init__(
|
||||
backend=backend,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=RerankerRequest,
|
||||
response_schema=RerankerResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("reranker")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("reranker")
|
||||
|
||||
def to_request(self, body):
|
||||
return self.request_translator.decode(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return self.response_translator.encode_with_completion(message)
|
||||
|
|
@ -92,7 +92,7 @@ class Operation:
|
|||
# Returns a dict with the appropriate components for the
|
||||
# resource level: {} for SYSTEM, {workspace} for WORKSPACE,
|
||||
# {workspace, flow} for FLOW. Default-fill-in of workspace
|
||||
# from identity.workspace happens here when applicable.
|
||||
# from identity.default_workspace happens here when applicable.
|
||||
extract_resource: Callable[[RequestContext], dict]
|
||||
|
||||
# Build the parameters dict — decision-relevant fields the
|
||||
|
|
@ -141,7 +141,7 @@ def _workspace_from_body(ctx: RequestContext) -> dict:
|
|||
workspace field, defaulting to the caller's bound workspace."""
|
||||
ws = (ctx.body.get("workspace") if isinstance(ctx.body, dict) else "")
|
||||
if not ws and ctx.identity is not None:
|
||||
ws = ctx.identity.workspace
|
||||
ws = ctx.identity.default_workspace
|
||||
return {"workspace": ws}
|
||||
|
||||
|
||||
|
|
@ -188,7 +188,7 @@ def _workspace_param_only(ctx: RequestContext) -> dict:
|
|||
or body.get("workspace")
|
||||
)
|
||||
if not ws and ctx.identity is not None:
|
||||
ws = ctx.identity.workspace
|
||||
ws = ctx.identity.default_workspace
|
||||
return {"workspace": ws or ""}
|
||||
|
||||
|
||||
|
|
@ -311,7 +311,7 @@ register(Operation(
|
|||
))
|
||||
register(Operation(
|
||||
name="list-my-workspaces",
|
||||
capability="workspaces:list-own",
|
||||
capability=AUTHENTICATED,
|
||||
resource_level=ResourceLevel.SYSTEM,
|
||||
extract_resource=_empty_resource,
|
||||
extract_parameters=_no_parameters,
|
||||
|
|
@ -506,18 +506,19 @@ _FLOW_SERVICES = {
|
|||
"text-completion": "llm",
|
||||
"prompt": "llm",
|
||||
"mcp-tool": "mcp",
|
||||
"graph-rag": "graph:read",
|
||||
"document-rag": "documents:read",
|
||||
"graph-rag": "graph-rag:read",
|
||||
"document-rag": "document-rag:read",
|
||||
"embeddings": "embeddings",
|
||||
"graph-embeddings": "graph:read",
|
||||
"document-embeddings": "documents:read",
|
||||
"triples": "graph:read",
|
||||
"graph-embeddings": "graph-embeddings:read",
|
||||
"document-embeddings": "document-embeddings:read",
|
||||
"triples": "triples:read",
|
||||
"rows": "rows:read",
|
||||
"nlp-query": "rows:read",
|
||||
"structured-query": "rows:read",
|
||||
"structured-diag": "rows:read",
|
||||
"row-embeddings": "rows:read",
|
||||
"sparql": "graph:read",
|
||||
"nlp-query": "nlp-query:read",
|
||||
"structured-query": "structured-query:read",
|
||||
"structured-diag": "structured-query:read",
|
||||
"row-embeddings": "row-embeddings:read",
|
||||
"sparql": "sparql:read",
|
||||
"reranker": "reranker",
|
||||
}
|
||||
for _kind, _cap in _FLOW_SERVICES.items():
|
||||
_register_flow_kind("flow-service", _kind, _cap)
|
||||
|
|
@ -525,10 +526,10 @@ for _kind, _cap in _FLOW_SERVICES.items():
|
|||
|
||||
# Streaming import socket endpoints.
|
||||
_FLOW_IMPORTS = {
|
||||
"triples": "graph:write",
|
||||
"graph-embeddings": "graph:write",
|
||||
"document-embeddings": "documents:write",
|
||||
"entity-contexts": "documents:write",
|
||||
"triples": "triples:write",
|
||||
"graph-embeddings": "graph-embeddings:write",
|
||||
"document-embeddings": "document-embeddings:write",
|
||||
"entity-contexts": "entity-contexts:write",
|
||||
"rows": "rows:write",
|
||||
}
|
||||
for _kind, _cap in _FLOW_IMPORTS.items():
|
||||
|
|
@ -537,10 +538,35 @@ for _kind, _cap in _FLOW_IMPORTS.items():
|
|||
|
||||
# Streaming export socket endpoints.
|
||||
_FLOW_EXPORTS = {
|
||||
"triples": "graph:read",
|
||||
"graph-embeddings": "graph:read",
|
||||
"document-embeddings": "documents:read",
|
||||
"entity-contexts": "documents:read",
|
||||
"triples": "triples:read",
|
||||
"graph-embeddings": "graph-embeddings:read",
|
||||
"document-embeddings": "document-embeddings:read",
|
||||
"entity-contexts": "entity-contexts:read",
|
||||
}
|
||||
for _kind, _cap in _FLOW_EXPORTS.items():
|
||||
_register_flow_kind("flow-export", _kind, _cap)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enterprise IAM operations.
|
||||
#
|
||||
# These are additive — they register alongside the OSS IAM operations.
|
||||
# When the OSS regime receives an unknown operation it returns an error;
|
||||
# when the enterprise regime is running, it handles them.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
for _op in (
|
||||
"create-group", "get-group", "list-groups",
|
||||
"update-group", "delete-group",
|
||||
"add-group-member", "remove-group-member", "list-group-members",
|
||||
"add-group-grant", "remove-group-grant", "list-group-grants",
|
||||
"add-user-grant", "remove-user-grant", "list-user-grants",
|
||||
"resolve-effective-permissions",
|
||||
):
|
||||
register(Operation(
|
||||
name=_op,
|
||||
capability="iam:admin",
|
||||
resource_level=ResourceLevel.SYSTEM,
|
||||
extract_resource=_empty_resource,
|
||||
extract_parameters=_no_parameters,
|
||||
))
|
||||
|
|
|
|||
|
|
@ -28,14 +28,14 @@ class NoAuthHandler:
|
|||
def _default_identity_response(self):
|
||||
return IamResponse(
|
||||
resolved_user_id=self.default_user_id,
|
||||
resolved_workspace=self.default_workspace,
|
||||
resolved_default_workspace=self.default_workspace,
|
||||
resolved_roles=["admin"],
|
||||
)
|
||||
|
||||
def _default_user_record(self):
|
||||
return UserRecord(
|
||||
id=self.default_user_id,
|
||||
workspace=self.default_workspace,
|
||||
default_workspace=self.default_workspace,
|
||||
username=self.default_user_id,
|
||||
name="Anonymous User",
|
||||
roles=["admin"],
|
||||
|
|
|
|||
|
|
@ -58,21 +58,35 @@ AUTHZ_CACHE_TTL_SECONDS = 60
|
|||
_READER_CAPS = {
|
||||
"agent",
|
||||
"graph:read",
|
||||
"triples:read",
|
||||
"sparql:read",
|
||||
"graph-rag:read",
|
||||
"graph-embeddings:read",
|
||||
"documents:read",
|
||||
"document-rag:read",
|
||||
"document-embeddings:read",
|
||||
"entity-contexts:read",
|
||||
"rows:read",
|
||||
"nlp-query:read",
|
||||
"structured-query:read",
|
||||
"row-embeddings:read",
|
||||
"llm",
|
||||
"embeddings",
|
||||
"reranker",
|
||||
"mcp",
|
||||
"config:read",
|
||||
"flows:read",
|
||||
"collections:read",
|
||||
"knowledge:read",
|
||||
"keys:self",
|
||||
"workspaces:list-own",
|
||||
}
|
||||
|
||||
_WRITER_CAPS = _READER_CAPS | {
|
||||
"graph:write",
|
||||
"triples:write",
|
||||
"graph-embeddings:write",
|
||||
"document-embeddings:write",
|
||||
"entity-contexts:write",
|
||||
"documents:write",
|
||||
"rows:write",
|
||||
"collections:write",
|
||||
|
|
@ -369,7 +383,7 @@ class IamService:
|
|||
) = row
|
||||
return UserRecord(
|
||||
id=id or "",
|
||||
workspace=workspace or "",
|
||||
default_workspace=workspace or "",
|
||||
username=username or "",
|
||||
name=name or "",
|
||||
email=email or "",
|
||||
|
|
@ -582,14 +596,8 @@ class IamService:
|
|||
if not v.password:
|
||||
return _err("auth-failed", "password required")
|
||||
|
||||
# Login accepts an optional workspace parameter. If omitted
|
||||
# we use the default workspace (OSS single-workspace
|
||||
# assumption). Multi-workspace enterprise editions swap in a
|
||||
# resolver that looks across the caller's permitted set.
|
||||
workspace = v.workspace or DEFAULT_WORKSPACE
|
||||
|
||||
user_id = await self.table_store.get_user_id_by_username(
|
||||
workspace, v.username,
|
||||
v.username,
|
||||
)
|
||||
if not user_id:
|
||||
return _err("auth-failed", "no such user")
|
||||
|
|
@ -610,7 +618,10 @@ class IamService:
|
|||
):
|
||||
return _err("auth-failed", "bad credentials")
|
||||
|
||||
ws_row = await self.table_store.get_workspace(ws)
|
||||
# JWT workspace: login request override, or the user's default.
|
||||
jwt_workspace = v.workspace or ws
|
||||
|
||||
ws_row = await self.table_store.get_workspace(jwt_workspace)
|
||||
if ws_row is None or not ws_row[2]:
|
||||
return _err("auth-failed", "workspace disabled")
|
||||
|
||||
|
|
@ -618,14 +629,10 @@ class IamService:
|
|||
|
||||
now_ts = int(_now_dt().timestamp())
|
||||
exp_ts = now_ts + JWT_TTL_SECONDS
|
||||
# Per the IAM contract the gateway never reads policy state
|
||||
# from the credential — roles stay server-side, reachable
|
||||
# only via authorise(). JWT carries identity + workspace
|
||||
# binding only.
|
||||
claims = {
|
||||
"iss": JWT_ISSUER,
|
||||
"sub": id,
|
||||
"workspace": ws,
|
||||
"default_workspace": jwt_workspace,
|
||||
"iat": now_ts,
|
||||
"exp": exp_ts,
|
||||
}
|
||||
|
|
@ -864,20 +871,15 @@ class IamService:
|
|||
|
||||
# user_row indices match get_user columns. Username is [2].
|
||||
username = user_row[2]
|
||||
record_workspace = user_row[1]
|
||||
|
||||
# Revoke all API keys.
|
||||
key_rows = await self.table_store.list_api_keys_by_user(v.user_id)
|
||||
for kr in key_rows:
|
||||
await self.table_store.delete_api_key(kr[0])
|
||||
|
||||
# Remove username lookup — keyed on (workspace, username),
|
||||
# so use the resolved workspace from the user record rather
|
||||
# than relying on the caller-supplied filter.
|
||||
# Remove global username lookup.
|
||||
if username:
|
||||
await self.table_store.delete_username_lookup(
|
||||
record_workspace, username,
|
||||
)
|
||||
await self.table_store.delete_username_lookup(username)
|
||||
|
||||
# Remove user record.
|
||||
await self.table_store.delete_user(v.user_id)
|
||||
|
|
@ -1096,13 +1098,15 @@ class IamService:
|
|||
return _err("auth-failed", "owning user disabled")
|
||||
|
||||
# Workspace-disabled check.
|
||||
ws_row = await self.table_store.get_workspace(user.workspace)
|
||||
ws_row = await self.table_store.get_workspace(
|
||||
user.default_workspace,
|
||||
)
|
||||
if ws_row is None or not ws_row[2]:
|
||||
return _err("auth-failed", "owning workspace disabled")
|
||||
|
||||
return IamResponse(
|
||||
resolved_user_id=user.id,
|
||||
resolved_workspace=user.workspace,
|
||||
resolved_default_workspace=user.default_workspace,
|
||||
resolved_roles=list(user.roles),
|
||||
)
|
||||
|
||||
|
|
@ -1129,9 +1133,9 @@ class IamService:
|
|||
if ws is None or not ws[2]:
|
||||
return _err("not-found", "workspace not found or disabled")
|
||||
|
||||
# Uniqueness on username within workspace.
|
||||
# Global username uniqueness.
|
||||
existing = await self.table_store.get_user_id_by_username(
|
||||
v.workspace, v.user.username,
|
||||
v.user.username,
|
||||
)
|
||||
if existing:
|
||||
return _err("duplicate", "username already exists")
|
||||
|
|
@ -1303,8 +1307,9 @@ class IamService:
|
|||
return False, AUTHZ_CACHE_TTL_SECONDS
|
||||
|
||||
# user_row layout:
|
||||
# 0:id 1:workspace 2:username 3:name 4:email 5:password_hash
|
||||
# 6:roles 7:enabled 8:must_change_password 9:created
|
||||
# 0:id 1:default_workspace 2:username 3:name 4:email
|
||||
# 5:password_hash 6:roles 7:enabled 8:must_change_password
|
||||
# 9:created
|
||||
if not user_row[7]: # disabled
|
||||
return False, AUTHZ_CACHE_TTL_SECONDS
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import logging
|
|||
|
||||
from .... exceptions import TooManyRequests, LlmError
|
||||
from .... base import LlmService, LlmResult, LlmChunk
|
||||
from . variants import get_variant, DEFAULT_VARIANT, VARIANTS
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -21,6 +22,7 @@ default_temperature = 0.0
|
|||
default_max_output = 4096
|
||||
default_api_key = os.getenv("OPENAI_TOKEN")
|
||||
default_base_url = os.getenv("OPENAI_BASE_URL")
|
||||
default_thinking = "off"
|
||||
|
||||
if default_base_url is None or default_base_url == "":
|
||||
default_base_url = "https://api.openai.com/v1"
|
||||
|
|
@ -34,10 +36,15 @@ class Processor(LlmService):
|
|||
base_url = params.get("url", default_base_url)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
thinking = params.get("thinking", default_thinking)
|
||||
variant_name = params.get("variant", DEFAULT_VARIANT)
|
||||
|
||||
if not api_key:
|
||||
api_key = "not-set"
|
||||
|
||||
self.variant = get_variant(variant_name)
|
||||
self.thinking = thinking
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"model": model,
|
||||
|
|
@ -56,13 +63,28 @@ class Processor(LlmService):
|
|||
else:
|
||||
self.openai = OpenAI(api_key=api_key)
|
||||
|
||||
logger.info("OpenAI LLM service initialized")
|
||||
logger.info(
|
||||
f"OpenAI LLM service initialized "
|
||||
f"(variant={self.variant.name}, thinking={self.thinking})"
|
||||
)
|
||||
|
||||
def _build_kwargs(self, model_name, temperature):
|
||||
"""Build API call kwargs using the active variant."""
|
||||
return self.variant.completion_kwargs(
|
||||
max_output=self.max_output,
|
||||
temperature=temperature,
|
||||
thinking=self.thinking,
|
||||
)
|
||||
|
||||
def _extract_content(self, message):
|
||||
"""Extract visible content from a response message."""
|
||||
if hasattr(self.variant, "extract_content"):
|
||||
return self.variant.extract_content(message)
|
||||
return message.content
|
||||
|
||||
async def generate_content(self, system, prompt, model=None, temperature=None):
|
||||
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model: {model_name}")
|
||||
|
|
@ -72,8 +94,8 @@ class Processor(LlmService):
|
|||
|
||||
try:
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=model_name,
|
||||
api_kwargs = self._build_kwargs(model_name, effective_temperature)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -84,19 +106,26 @@ class Processor(LlmService):
|
|||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
temperature=effective_temperature,
|
||||
max_completion_tokens=self.max_output,
|
||||
]
|
||||
|
||||
resp = self.variant.create_completion(
|
||||
self.openai, model_name, messages, **api_kwargs,
|
||||
)
|
||||
|
||||
inputtokens = resp.usage.prompt_tokens
|
||||
outputtokens = resp.usage.completion_tokens
|
||||
logger.debug(f"LLM response: {resp.choices[0].message.content}")
|
||||
|
||||
content = self._extract_content(resp.choices[0].message)
|
||||
thinking = self.variant.extract_thinking(resp.choices[0].message)
|
||||
|
||||
logger.debug(f"LLM response: {content}")
|
||||
if thinking:
|
||||
logger.debug(f"LLM thinking: {thinking[:200]}...")
|
||||
logger.info(f"Input Tokens: {inputtokens}")
|
||||
logger.info(f"Output Tokens: {outputtokens}")
|
||||
|
||||
resp = LlmResult(
|
||||
text = resp.choices[0].message.content,
|
||||
text = content,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = model_name
|
||||
|
|
@ -136,9 +165,7 @@ class Processor(LlmService):
|
|||
Stream content generation from OpenAI.
|
||||
Yields LlmChunk objects with is_final=True on the last chunk.
|
||||
"""
|
||||
# Use provided model or fall back to default
|
||||
model_name = model or self.default_model
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
|
||||
logger.debug(f"Using model (streaming): {model_name}")
|
||||
|
|
@ -147,8 +174,8 @@ class Processor(LlmService):
|
|||
prompt = system + "\n\n" + prompt
|
||||
|
||||
try:
|
||||
response = self.openai.chat.completions.create(
|
||||
model=model_name,
|
||||
api_kwargs = self._build_kwargs(model_name, effective_temperature)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -159,18 +186,14 @@ class Processor(LlmService):
|
|||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
temperature=effective_temperature,
|
||||
max_completion_tokens=self.max_output,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True}
|
||||
)
|
||||
]
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
# Stream chunks
|
||||
for chunk in response:
|
||||
async for chunk in self.variant.create_completion_stream(
|
||||
self.openai, model_name, messages, **api_kwargs,
|
||||
):
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield LlmChunk(
|
||||
text=chunk.choices[0].delta.content,
|
||||
|
|
@ -254,6 +277,20 @@ class Processor(LlmService):
|
|||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--thinking',
|
||||
choices=["off", "low", "medium", "high"],
|
||||
default=default_thinking,
|
||||
help=f'Thinking/reasoning effort level (default: {default_thinking})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--variant',
|
||||
choices=sorted(VARIANTS.keys()),
|
||||
default=DEFAULT_VARIANT,
|
||||
help=f'API variant (default: {DEFAULT_VARIANT})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,219 @@
|
|||
"""
|
||||
OpenAI API variant profiles.
|
||||
|
||||
Different providers expose OpenAI-compatible APIs with subtle differences
|
||||
in parameter names, thinking/reasoning support, and temperature handling.
|
||||
Each variant encapsulates those quirks so the processor doesn't need
|
||||
provider-specific conditionals.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Variant:
|
||||
"""Base variant — defines the interface all variants implement."""
|
||||
|
||||
name = None
|
||||
token_param = "max_completion_tokens"
|
||||
temperature_with_thinking = False
|
||||
|
||||
def completion_kwargs(self, max_output, temperature, thinking):
|
||||
"""Build provider-specific kwargs for chat.completions.create().
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_output : int
|
||||
Configured max output tokens.
|
||||
temperature : float
|
||||
Configured temperature.
|
||||
thinking : str
|
||||
Thinking effort level: "off", "low", "medium", "high".
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Extra kwargs to spread into the API call.
|
||||
"""
|
||||
kwargs = {self.token_param: max_output}
|
||||
|
||||
if thinking != "off":
|
||||
kwargs.update(self.thinking_kwargs(thinking))
|
||||
if not self.temperature_with_thinking:
|
||||
kwargs["temperature"] = 1.0
|
||||
else:
|
||||
kwargs["temperature"] = temperature
|
||||
else:
|
||||
kwargs["temperature"] = temperature
|
||||
|
||||
return kwargs
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
"""Return kwargs to enable thinking at the given effort level."""
|
||||
return {}
|
||||
|
||||
def extract_thinking(self, message):
|
||||
"""Extract thinking/reasoning content from a response message."""
|
||||
return getattr(message, "reasoning_content", None)
|
||||
|
||||
def extract_thinking_stream(self, delta):
|
||||
"""Extract thinking content from a streaming delta."""
|
||||
return getattr(delta, "reasoning_content", None)
|
||||
|
||||
def create_completion(self, client, model, messages, **kwargs):
|
||||
"""Call the completions API. Override for non-standard SDKs."""
|
||||
return client.chat.completions.create(
|
||||
model=model, messages=messages, **kwargs,
|
||||
)
|
||||
|
||||
async def create_completion_stream(self, client, model, messages, **kwargs):
|
||||
"""Call the streaming completions API. Override for non-standard SDKs."""
|
||||
for chunk in client.chat.completions.create(
|
||||
model=model, messages=messages, stream=True,
|
||||
stream_options={"include_usage": True}, **kwargs,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
class OpenAIVariant(Variant):
|
||||
"""Standard OpenAI API (GPT-4o, o1, o3, etc.)."""
|
||||
|
||||
name = "openai"
|
||||
token_param = "max_completion_tokens"
|
||||
temperature_with_thinking = False
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {"reasoning_effort": effort}
|
||||
|
||||
|
||||
class DeepSeekVariant(Variant):
|
||||
"""DeepSeek API (R1, V3, etc.)."""
|
||||
|
||||
name = "deepseek"
|
||||
token_param = "max_completion_tokens"
|
||||
temperature_with_thinking = True
|
||||
|
||||
def completion_kwargs(self, max_output, temperature, thinking):
|
||||
enabled = "enabled" if thinking != "off" else "disabled"
|
||||
kwargs = {
|
||||
self.token_param: max_output,
|
||||
"temperature": temperature,
|
||||
"extra_body": {
|
||||
"thinking": {"type": enabled},
|
||||
},
|
||||
}
|
||||
return kwargs
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {}
|
||||
|
||||
|
||||
class DashScopeVariant(Variant):
|
||||
"""Alibaba Cloud DashScope API (Qwen models)."""
|
||||
|
||||
name = "dashscope"
|
||||
token_param = "max_completion_tokens"
|
||||
temperature_with_thinking = True
|
||||
|
||||
def completion_kwargs(self, max_output, temperature, thinking):
|
||||
enabled = thinking != "off"
|
||||
return {
|
||||
self.token_param: max_output,
|
||||
"temperature": temperature,
|
||||
"extra_body": {
|
||||
"enable_thinking": enabled,
|
||||
},
|
||||
}
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {}
|
||||
|
||||
|
||||
class QwenVariant(DashScopeVariant):
|
||||
"""Qwen — alias for DashScope."""
|
||||
|
||||
name = "qwen"
|
||||
|
||||
|
||||
class MistralVariant(Variant):
|
||||
"""Mistral API (Mistral Large, etc.)."""
|
||||
|
||||
name = "mistral"
|
||||
token_param = "max_tokens"
|
||||
temperature_with_thinking = False
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {"reasoning_effort": effort}
|
||||
|
||||
|
||||
class GlmVariant(Variant):
|
||||
"""GLM / Zhipu AI API (GLM-4, GLM-4.7, etc.)."""
|
||||
|
||||
name = "glm"
|
||||
token_param = "max_tokens"
|
||||
temperature_with_thinking = True
|
||||
|
||||
def completion_kwargs(self, max_output, temperature, thinking):
|
||||
enabled = "enabled" if thinking != "off" else "disabled"
|
||||
kwargs = {
|
||||
self.token_param: max_output,
|
||||
"temperature": temperature,
|
||||
"extra_body": {
|
||||
"thinking": {"type": enabled},
|
||||
},
|
||||
}
|
||||
return kwargs
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {}
|
||||
|
||||
|
||||
class LlamaVariant(Variant):
|
||||
"""Llama models via OpenAI-compatible servers (vLLM, Ollama, etc.).
|
||||
|
||||
Thinking is typically always-on or always-off depending on the model.
|
||||
When present, thinking appears inline as <think>...</think> tags.
|
||||
"""
|
||||
|
||||
name = "llama"
|
||||
token_param = "max_tokens"
|
||||
temperature_with_thinking = True
|
||||
|
||||
def thinking_kwargs(self, effort):
|
||||
return {}
|
||||
|
||||
def extract_thinking(self, message):
|
||||
content = message.content or ""
|
||||
match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
def extract_content(self, message):
|
||||
"""Strip think tags from visible content."""
|
||||
content = message.content or ""
|
||||
return re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
|
||||
|
||||
|
||||
VARIANTS = {
|
||||
"openai": OpenAIVariant,
|
||||
"deepseek": DeepSeekVariant,
|
||||
"qwen": QwenVariant,
|
||||
"mistral": MistralVariant,
|
||||
"dashscope": DashScopeVariant,
|
||||
"glm": GlmVariant,
|
||||
"llama": LlamaVariant,
|
||||
}
|
||||
|
||||
DEFAULT_VARIANT = "openai"
|
||||
|
||||
|
||||
def get_variant(name):
|
||||
"""Look up a variant by name, raising ValueError if unknown."""
|
||||
cls = VARIANTS.get(name)
|
||||
if cls is None:
|
||||
raise ValueError(
|
||||
f"Unknown variant {name!r}. "
|
||||
f"Available: {', '.join(sorted(VARIANTS))}"
|
||||
)
|
||||
return cls()
|
||||
1
trustgraph-flow/trustgraph/reranker/__init__.py
Normal file
1
trustgraph-flow/trustgraph/reranker/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
from . processor import *
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . processor import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
109
trustgraph-flow/trustgraph/reranker/flashrank/processor.py
Normal file
109
trustgraph-flow/trustgraph/reranker/flashrank/processor.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
|
||||
"""
|
||||
Reranker service using flashrank.
|
||||
Scores query-document pairs and returns the top results ranked by
|
||||
relevance.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from ... base import RerankerService
|
||||
from ... schema import RerankerResult
|
||||
|
||||
from flashrank import Ranker, RerankRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "reranker"
|
||||
|
||||
default_model = "ms-marco-MiniLM-L-12-v2"
|
||||
|
||||
class Processor(RerankerService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
model = params.get("model", default_model)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | { "model": model }
|
||||
)
|
||||
|
||||
self.default_model = model
|
||||
|
||||
self.cached_model_name = None
|
||||
self.ranker = None
|
||||
|
||||
self._load_model(model)
|
||||
|
||||
def _load_model(self, model_name):
|
||||
if self.cached_model_name != model_name:
|
||||
logger.info(f"Loading flashrank model: {model_name}")
|
||||
self.ranker = Ranker(model_name=model_name)
|
||||
self.cached_model_name = model_name
|
||||
logger.info(f"flashrank model {model_name} loaded successfully")
|
||||
else:
|
||||
logger.debug(f"Using cached model: {model_name}")
|
||||
|
||||
def _run_rerank(self, query, passages):
|
||||
request = RerankRequest(query=query, passages=passages)
|
||||
return self.ranker.rerank(request)
|
||||
|
||||
async def on_rerank(self, queries, documents, limit, model=None):
|
||||
|
||||
if not queries or not documents:
|
||||
return []
|
||||
|
||||
use_model = model or self.default_model
|
||||
|
||||
if self.cached_model_name != use_model:
|
||||
await asyncio.to_thread(self._load_model, use_model)
|
||||
|
||||
passages = [
|
||||
{"id": d.document_id, "text": d.document_text}
|
||||
for d in documents
|
||||
]
|
||||
|
||||
best_scores = {}
|
||||
|
||||
for q in queries:
|
||||
ranked = await asyncio.to_thread(
|
||||
self._run_rerank, q.query_text, passages,
|
||||
)
|
||||
|
||||
for r in ranked:
|
||||
doc_id = r["id"]
|
||||
score = r["score"]
|
||||
score = float(score)
|
||||
if doc_id not in best_scores or score > best_scores[doc_id][1]:
|
||||
best_scores[doc_id] = (q.query_id, score)
|
||||
|
||||
results = sorted(
|
||||
best_scores.items(),
|
||||
key=lambda x: x[1][1],
|
||||
reverse=True,
|
||||
)[:limit]
|
||||
|
||||
return [
|
||||
RerankerResult(
|
||||
document_id=doc_id,
|
||||
query_id=query_id,
|
||||
score=score,
|
||||
)
|
||||
for doc_id, (query_id, score) in results
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
RerankerService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default=default_model,
|
||||
help=f'Reranker model (default: {default_model})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -9,31 +9,41 @@ from trustgraph.provenance import (
|
|||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_focus_uri,
|
||||
docrag_synthesis_uri,
|
||||
docrag_question_triples,
|
||||
grounding_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_chunk_selection_triples,
|
||||
docrag_synthesis_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
from .rerank import RerankCandidate, mmr_select
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# When the caller does not specify a fetch_limit, reranking over-fetches this
|
||||
# many times the final doc_limit as the candidate pool, so the cross-encoder can
|
||||
# recover relevant chunks the bi-encoder ranked just outside the top doc_limit.
|
||||
# This is only the fallback default: an explicit fetch_limit overrides it.
|
||||
OVERFETCH_FACTOR = 3
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
class Query:
|
||||
|
||||
def __init__(
|
||||
self, rag, workspace, collection, verbose,
|
||||
doc_limit=20, track_usage=None,
|
||||
fetch_limit=20, track_usage=None,
|
||||
):
|
||||
self.rag = rag
|
||||
self.workspace = workspace
|
||||
self.collection = collection
|
||||
self.verbose = verbose
|
||||
self.doc_limit = doc_limit
|
||||
self.fetch_limit = fetch_limit
|
||||
self.track_usage = track_usage
|
||||
|
||||
async def extract_concepts(self, query):
|
||||
|
|
@ -91,7 +101,7 @@ class Query:
|
|||
|
||||
# Query chunk matches for each concept concurrently
|
||||
per_concept_limit = max(
|
||||
1, self.doc_limit // len(vectors)
|
||||
1, self.fetch_limit // len(vectors)
|
||||
)
|
||||
|
||||
async def query_concept(vec):
|
||||
|
|
@ -140,7 +150,10 @@ class DocumentRag:
|
|||
def __init__(
|
||||
self, prompt_client, embeddings_client, doc_embeddings_client,
|
||||
fetch_chunk,
|
||||
reranker_client=None,
|
||||
verbose=False,
|
||||
rerank_diversity_mode="none",
|
||||
rerank_diversity_lambda=0.7,
|
||||
):
|
||||
|
||||
self.verbose = verbose
|
||||
|
|
@ -150,12 +163,18 @@ class DocumentRag:
|
|||
self.doc_embeddings_client = doc_embeddings_client
|
||||
self.fetch_chunk = fetch_chunk
|
||||
|
||||
# Optional cross-encoder reranker. When None, the retrieval path is
|
||||
# byte-identical to the pre-reranker behaviour.
|
||||
self.reranker_client = reranker_client
|
||||
self.rerank_diversity_mode = rerank_diversity_mode
|
||||
self.rerank_diversity_lambda = rerank_diversity_lambda
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("DocumentRag initialized")
|
||||
|
||||
async def query(
|
||||
self, query, workspace="default", collection="default",
|
||||
doc_limit=20, streaming=False, chunk_callback=None,
|
||||
doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None,
|
||||
explain_callback=None, save_answer_callback=None,
|
||||
):
|
||||
"""
|
||||
|
|
@ -165,7 +184,10 @@ class DocumentRag:
|
|||
query: The query string
|
||||
workspace: Workspace for isolation (also scopes chunk lookup)
|
||||
collection: Collection identifier
|
||||
doc_limit: Max chunks to retrieve
|
||||
doc_limit: Chunks selected into the synthesis prompt (after rerank)
|
||||
fetch_limit: Candidate pool fetched from the vector store before
|
||||
reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a
|
||||
reranker is wired, else doc_limit).
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for explainability
|
||||
|
|
@ -197,6 +219,7 @@ class DocumentRag:
|
|||
q_uri = docrag_question_uri(session_id)
|
||||
gnd_uri = docrag_grounding_uri(session_id)
|
||||
exp_uri = docrag_exploration_uri(session_id)
|
||||
foc_uri = docrag_focus_uri(session_id)
|
||||
syn_uri = docrag_synthesis_uri(session_id)
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
|
@ -209,10 +232,21 @@ class DocumentRag:
|
|||
)
|
||||
await explain_callback(q_triples, q_uri)
|
||||
|
||||
# Resolve the candidate-pool size fetched from the vector store. When a
|
||||
# reranker is wired, honour an explicit fetch_limit; if unset, fall back
|
||||
# to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit,
|
||||
# else the rerank could not fill the prompt. Without a reranker, fetch
|
||||
# doc_limit as before (byte-identical behaviour).
|
||||
if self.reranker_client is not None:
|
||||
fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit)
|
||||
fetch_count = max(fl, doc_limit)
|
||||
else:
|
||||
fetch_count = doc_limit
|
||||
|
||||
q = Query(
|
||||
rag=self, workspace=workspace, collection=collection,
|
||||
verbose=self.verbose,
|
||||
doc_limit=doc_limit, track_usage=track_usage,
|
||||
fetch_limit=fetch_count, track_usage=track_usage,
|
||||
)
|
||||
|
||||
# Extract concepts from query (grounding step)
|
||||
|
|
@ -235,6 +269,7 @@ class DocumentRag:
|
|||
docs, chunk_ids = await q.get_docs(concepts)
|
||||
|
||||
# Emit exploration explainability after chunks retrieved
|
||||
# (full candidate set, before any reranking)
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
|
||||
|
|
@ -242,6 +277,89 @@ class DocumentRag:
|
|||
)
|
||||
await explain_callback(exp_triples, exp_uri)
|
||||
|
||||
# Optional cross-encoder reranking pass between retrieval and
|
||||
# synthesis. Mirrors GraphRAG's reranker usage but with a single
|
||||
# query (the question). When no reranker is wired, this block is
|
||||
# skipped entirely and behaviour is byte-identical to before.
|
||||
reranked = False
|
||||
if self.reranker_client is not None and docs:
|
||||
use_diversity = self.rerank_diversity_mode == "mmr"
|
||||
|
||||
# Without diversity selection, preserve the existing #1011
|
||||
# behavior: ask the reranker for exactly doc_limit results.
|
||||
#
|
||||
# With diversity selection enabled, ask the reranker to score the
|
||||
# full fetched candidate pool first, then let MMR choose the final
|
||||
# doc_limit context set.
|
||||
rerank_limit = len(docs) if use_diversity else doc_limit
|
||||
|
||||
results = await self.reranker_client.rerank(
|
||||
queries=[{"id": "0", "text": query}],
|
||||
documents=[
|
||||
{"id": str(i), "text": d} for i, d in enumerate(docs)
|
||||
],
|
||||
limit=rerank_limit,
|
||||
)
|
||||
|
||||
source_docs = docs
|
||||
source_chunk_ids = chunk_ids
|
||||
|
||||
if use_diversity:
|
||||
candidates = [
|
||||
RerankCandidate(
|
||||
index=int(r.document_id),
|
||||
chunk_id=source_chunk_ids[int(r.document_id)],
|
||||
text=source_docs[int(r.document_id)],
|
||||
reranker_score=r.score,
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
selected_candidates = mmr_select(
|
||||
candidates,
|
||||
limit=doc_limit,
|
||||
lambda_mult=self.rerank_diversity_lambda,
|
||||
)
|
||||
|
||||
docs = [candidate.text for candidate in selected_candidates]
|
||||
chunk_ids = [
|
||||
candidate.chunk_id for candidate in selected_candidates
|
||||
]
|
||||
|
||||
selected_chunks_with_scores = [
|
||||
{
|
||||
"chunk_id": candidate.chunk_id,
|
||||
"score": candidate.reranker_score,
|
||||
}
|
||||
for candidate in selected_candidates
|
||||
]
|
||||
|
||||
else:
|
||||
# results are sorted desc by score and truncated to limit by the
|
||||
# reranker service, so order gives the surviving top-N directly.
|
||||
order = [int(r.document_id) for r in results]
|
||||
docs = [source_docs[i] for i in order]
|
||||
chunk_ids = [source_chunk_ids[i] for i in order]
|
||||
|
||||
selected_chunks_with_scores = [
|
||||
{"chunk_id": chunk_ids[i], "score": r.score}
|
||||
for i, r in enumerate(results)
|
||||
]
|
||||
|
||||
reranked = True
|
||||
|
||||
# Emit chunk-selection (focus) explainability: surviving chunks
|
||||
# with their cross-encoder scores, derived from exploration.
|
||||
if explain_callback:
|
||||
foc_triples = set_graph(
|
||||
docrag_chunk_selection_triples(
|
||||
foc_uri, exp_uri,
|
||||
selected_chunks_with_scores, session_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(foc_triples, foc_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
logger.debug(f"Documents: {docs}")
|
||||
|
|
@ -291,9 +409,15 @@ class DocumentRag:
|
|||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
synthesis_doc_id = None
|
||||
|
||||
# When reranking ran, synthesis derives from the focus (the
|
||||
# reranked chunks actually fed to the LLM), as GraphRAG always does.
|
||||
# When no reranker is wired, there is no focus stage, so synthesis
|
||||
# derives from exploration (the unchanged no-op lineage) - a
|
||||
# deliberate divergence from GraphRAG's always-on focus.
|
||||
syn_parent = foc_uri if reranked else exp_uri
|
||||
syn_triples = set_graph(
|
||||
docrag_synthesis_triples(
|
||||
syn_uri, exp_uri,
|
||||
syn_uri, syn_parent,
|
||||
document_id=synthesis_doc_id,
|
||||
in_token=synthesis_result.in_token if synthesis_result else None,
|
||||
out_token=synthesis_result.out_token if synthesis_result else None,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from . document_rag import DocumentRag
|
|||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import DocumentEmbeddingsClientSpec
|
||||
from ... base import RerankerClientSpec
|
||||
from ... base import LibrarianSpec
|
||||
|
||||
# Module logger
|
||||
|
|
@ -28,14 +29,27 @@ class Processor(FlowProcessor):
|
|||
|
||||
doc_limit = params.get("doc_limit", 5)
|
||||
|
||||
# Instance-default candidate-pool size fetched before cross-encoder
|
||||
# reranking; the rerank step narrows it back down to doc_limit for the
|
||||
# LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit).
|
||||
fetch_limit = params.get("fetch_limit", 0)
|
||||
rerank_diversity_mode = params.get("rerank_diversity_mode", "none")
|
||||
rerank_diversity_lambda = params.get("rerank_diversity_lambda", 0.7)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"doc_limit": doc_limit,
|
||||
"fetch_limit": fetch_limit,
|
||||
"rerank_diversity_mode": rerank_diversity_mode,
|
||||
"rerank_diversity_lambda": rerank_diversity_lambda,
|
||||
}
|
||||
)
|
||||
|
||||
self.doc_limit = doc_limit
|
||||
self.fetch_limit = fetch_limit
|
||||
self.rerank_diversity_mode = rerank_diversity_mode
|
||||
self.rerank_diversity_lambda = rerank_diversity_lambda
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
|
|
@ -66,6 +80,13 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RerankerClientSpec(
|
||||
request_name = "reranker-request",
|
||||
response_name = "reranker-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
|
|
@ -105,7 +126,10 @@ class Processor(FlowProcessor):
|
|||
doc_embeddings_client = flow("document-embeddings-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
fetch_chunk = fetch_chunk,
|
||||
reranker_client = flow("reranker-request"),
|
||||
verbose=True,
|
||||
rerank_diversity_mode=self.rerank_diversity_mode,
|
||||
rerank_diversity_lambda=self.rerank_diversity_lambda,
|
||||
)
|
||||
|
||||
if v.doc_limit:
|
||||
|
|
@ -113,6 +137,13 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
doc_limit = self.doc_limit
|
||||
|
||||
# Candidate-pool size: per-request override, else the instance
|
||||
# default; 0 lets the core derive it from doc_limit.
|
||||
if v.fetch_limit:
|
||||
fetch_limit = v.fetch_limit
|
||||
else:
|
||||
fetch_limit = self.fetch_limit
|
||||
|
||||
async def send_explainability(triples, explain_id):
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
|
|
@ -163,6 +194,7 @@ class Processor(FlowProcessor):
|
|||
workspace=flow.workspace,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
streaming=True,
|
||||
chunk_callback=send_chunk,
|
||||
explain_callback=send_explainability,
|
||||
|
|
@ -188,6 +220,7 @@ class Processor(FlowProcessor):
|
|||
workspace=flow.workspace,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit,
|
||||
fetch_limit=fetch_limit,
|
||||
explain_callback=send_explainability,
|
||||
save_answer_callback=save_answer,
|
||||
)
|
||||
|
|
@ -243,6 +276,29 @@ class Processor(FlowProcessor):
|
|||
help=f'Default document fetch limit (default: 10)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--fetch-limit',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Candidate chunks to fetch from the vector store and rerank '
|
||||
'before keeping the top doc-limit for the LLM '
|
||||
'(default: derive from doc-limit)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rerank-diversity-mode',
|
||||
choices=['none', 'mmr'],
|
||||
default='none',
|
||||
help='Optional diversity-aware selection after reranking (default: none)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rerank-diversity-lambda',
|
||||
type=float,
|
||||
default=0.7,
|
||||
help='MMR relevance/diversity tradeoff, higher values prefer relevance'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
142
trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py
Normal file
142
trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
import re
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import List, Sequence, Set
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RerankCandidate:
|
||||
"""
|
||||
Candidate chunk after cross-encoder reranking.
|
||||
|
||||
reranker_score is the raw score returned by the reranker backend. It may
|
||||
not be normalized, so MMR should use normalized_score instead.
|
||||
"""
|
||||
index: int
|
||||
chunk_id: str
|
||||
text: str
|
||||
reranker_score: float
|
||||
normalized_score: float = 0.0
|
||||
|
||||
|
||||
_TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")
|
||||
|
||||
|
||||
def _clamp01(value: float) -> float:
|
||||
return max(0.0, min(1.0, value))
|
||||
|
||||
|
||||
def _token_set(text: str) -> Set[str]:
|
||||
return set(token.lower() for token in _TOKEN_RE.findall(text or ""))
|
||||
|
||||
|
||||
def _jaccard(a: str, b: str) -> float:
|
||||
a_tokens = _token_set(a)
|
||||
b_tokens = _token_set(b)
|
||||
|
||||
if not a_tokens or not b_tokens:
|
||||
return 0.0
|
||||
|
||||
return len(a_tokens & b_tokens) / len(a_tokens | b_tokens)
|
||||
|
||||
|
||||
def normalize_candidate_scores(
|
||||
candidates: Sequence[RerankCandidate],
|
||||
) -> List[RerankCandidate]:
|
||||
"""
|
||||
Min-max normalize reranker scores within the current candidate set.
|
||||
|
||||
Reranker backends may return different score scales: probabilities,
|
||||
logits, or prompt-defined scores. MMR needs a stable [0, 1] relevance
|
||||
signal, so normalize per candidate set instead of assuming a global range.
|
||||
"""
|
||||
if not candidates:
|
||||
return []
|
||||
|
||||
scores = [float(candidate.reranker_score) for candidate in candidates]
|
||||
min_score = min(scores)
|
||||
max_score = max(scores)
|
||||
|
||||
if max_score == min_score:
|
||||
return [
|
||||
replace(candidate, normalized_score=0.5)
|
||||
for candidate in candidates
|
||||
]
|
||||
|
||||
score_range = max_score - min_score
|
||||
|
||||
return [
|
||||
replace(
|
||||
candidate,
|
||||
normalized_score=(float(candidate.reranker_score) - min_score) / score_range,
|
||||
)
|
||||
for candidate in candidates
|
||||
]
|
||||
|
||||
|
||||
def _pair_diversity_penalty(
|
||||
candidate: RerankCandidate,
|
||||
selected: RerankCandidate,
|
||||
token_overlap_weight: float,
|
||||
) -> float:
|
||||
"""
|
||||
Pairwise diversity penalty between two candidate chunks.
|
||||
|
||||
The first revision only uses token overlap because the current Document-RAG
|
||||
reranker document_id is the candidate index, not a source document id.
|
||||
"""
|
||||
penalty = token_overlap_weight * _jaccard(candidate.text, selected.text)
|
||||
return _clamp01(penalty)
|
||||
|
||||
|
||||
def mmr_select(
|
||||
candidates: Sequence[RerankCandidate],
|
||||
limit: int,
|
||||
lambda_mult: float = 0.7,
|
||||
token_overlap_weight: float = 1.0,
|
||||
) -> List[RerankCandidate]:
|
||||
"""
|
||||
Select a diverse final context set using MMR.
|
||||
|
||||
Relevance comes from normalized cross-encoder reranker scores.
|
||||
Diversity comes from token overlap against already selected chunks.
|
||||
"""
|
||||
if limit <= 0:
|
||||
return []
|
||||
|
||||
lambda_mult = _clamp01(lambda_mult)
|
||||
token_overlap_weight = max(0.0, token_overlap_weight)
|
||||
|
||||
remaining = normalize_candidate_scores(candidates)
|
||||
selected: List[RerankCandidate] = []
|
||||
|
||||
while remaining and len(selected) < limit:
|
||||
best_idx = 0
|
||||
best_score = None
|
||||
|
||||
for idx, candidate in enumerate(remaining):
|
||||
relevance = candidate.normalized_score
|
||||
|
||||
if selected:
|
||||
diversity_penalty = max(
|
||||
_pair_diversity_penalty(
|
||||
candidate,
|
||||
chosen,
|
||||
token_overlap_weight=token_overlap_weight,
|
||||
)
|
||||
for chosen in selected
|
||||
)
|
||||
else:
|
||||
diversity_penalty = 0.0
|
||||
|
||||
mmr_score = (
|
||||
lambda_mult * relevance
|
||||
- (1.0 - lambda_mult) * diversity_penalty
|
||||
)
|
||||
|
||||
if best_score is None or mmr_score > best_score:
|
||||
best_score = mmr_score
|
||||
best_idx = idx
|
||||
|
||||
selected.append(remaining.pop(best_idx))
|
||||
|
||||
return selected
|
||||
|
|
@ -120,7 +120,7 @@ class Query:
|
|||
def __init__(
|
||||
self, rag, collection, verbose,
|
||||
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
|
||||
max_path_length=2, track_usage=None,
|
||||
max_path_length=2, edge_limit=25, track_usage=None,
|
||||
):
|
||||
self.rag = rag
|
||||
self.collection = collection
|
||||
|
|
@ -129,6 +129,7 @@ class Query:
|
|||
self.triple_limit = triple_limit
|
||||
self.max_subgraph_size = max_subgraph_size
|
||||
self.max_path_length = max_path_length
|
||||
self.edge_limit = edge_limit
|
||||
self.track_usage = track_usage
|
||||
|
||||
async def extract_concepts(self, query):
|
||||
|
|
@ -220,9 +221,6 @@ class Query:
|
|||
|
||||
async def maybe_label(self, e):
|
||||
|
||||
# The label cache lives on a per-request GraphRag instance — no
|
||||
# cross-request isolation concern. The collection prefix keeps
|
||||
# entries from different collections distinct within one request.
|
||||
cache_key = f"{self.collection}:{e}"
|
||||
|
||||
cached_label = self.rag.label_cache.get(cache_key)
|
||||
|
|
@ -243,206 +241,217 @@ class Query:
|
|||
self.rag.label_cache.put(cache_key, label)
|
||||
return label
|
||||
|
||||
FROM_S = "from_s"
|
||||
FROM_P = "from_p"
|
||||
FROM_O = "from_o"
|
||||
|
||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
||||
"""Execute triple queries for multiple entities concurrently using streaming"""
|
||||
"""Execute triple queries for multiple entities concurrently.
|
||||
|
||||
Returns a list of (triple, direction) tuples where direction
|
||||
indicates which position the frontier entity occupied.
|
||||
"""
|
||||
tasks = []
|
||||
directions = []
|
||||
|
||||
for entity in entities:
|
||||
# Create concurrent streaming tasks for all 3 query types per entity
|
||||
tasks.extend([
|
||||
tasks.append(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=entity, p=None, o=None,
|
||||
limit=limit_per_entity,
|
||||
collection=self.collection,
|
||||
batch_size=20, g="",
|
||||
),
|
||||
)
|
||||
directions.append(self.FROM_S)
|
||||
|
||||
tasks.append(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=entity, o=None,
|
||||
limit=limit_per_entity,
|
||||
collection=self.collection,
|
||||
batch_size=20, g="",
|
||||
),
|
||||
)
|
||||
directions.append(self.FROM_P)
|
||||
|
||||
tasks.append(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=None, o=entity,
|
||||
limit=limit_per_entity,
|
||||
collection=self.collection,
|
||||
batch_size=20, g="",
|
||||
),
|
||||
)
|
||||
])
|
||||
directions.append(self.FROM_O)
|
||||
|
||||
# Execute all queries concurrently
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Combine all results
|
||||
all_triples = []
|
||||
for result in results:
|
||||
for direction, result in zip(directions, results):
|
||||
if not isinstance(result, Exception) and result is not None:
|
||||
all_triples.extend(result)
|
||||
all_triples.extend((triple, direction) for triple in result)
|
||||
|
||||
return all_triples
|
||||
|
||||
async def follow_edges_batch(self, entities, max_depth):
|
||||
"""Optimized iterative graph traversal with batching.
|
||||
|
||||
Returns:
|
||||
tuple: (subgraph, term_map) where subgraph is a set of
|
||||
(str, str, str) tuples and term_map maps each string tuple
|
||||
to its original (Term, Term, Term) for type-preserving
|
||||
provenance.
|
||||
"""
|
||||
visited = set()
|
||||
current_level = set(entities)
|
||||
subgraph = set()
|
||||
term_map = {} # (str, str, str) -> (Term, Term, Term)
|
||||
|
||||
for depth in range(max_depth):
|
||||
if not current_level or len(subgraph) >= self.max_subgraph_size:
|
||||
break
|
||||
|
||||
# Filter out already visited entities
|
||||
unvisited_entities = [e for e in current_level if e not in visited]
|
||||
if not unvisited_entities:
|
||||
break
|
||||
|
||||
# Batch query all unvisited entities at current level
|
||||
triples = await self.execute_batch_triple_queries(
|
||||
unvisited_entities, self.triple_limit
|
||||
)
|
||||
|
||||
# Process results and collect next level entities
|
||||
next_level = set()
|
||||
for triple in triples:
|
||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||
subgraph.add(triple_tuple)
|
||||
term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o))
|
||||
|
||||
# Collect entities for next level (only from s and o positions)
|
||||
if depth < max_depth - 1: # Don't collect for final depth
|
||||
s, p, o = triple_tuple
|
||||
if s not in visited:
|
||||
next_level.add(s)
|
||||
if o not in visited:
|
||||
next_level.add(o)
|
||||
|
||||
# Stop if subgraph size limit reached
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return subgraph, term_map
|
||||
|
||||
# Update for next iteration
|
||||
visited.update(current_level)
|
||||
current_level = next_level
|
||||
|
||||
return subgraph, term_map
|
||||
|
||||
async def follow_edges(self, ent, subgraph, path_length):
|
||||
"""Legacy method - replaced by follow_edges_batch"""
|
||||
# Maintain backward compatibility with early termination checks
|
||||
if path_length <= 0:
|
||||
return
|
||||
|
||||
if len(subgraph) >= self.max_subgraph_size:
|
||||
return
|
||||
|
||||
# For backward compatibility, convert to new approach
|
||||
batch_result, _ = await self.follow_edges_batch([ent], path_length)
|
||||
subgraph.update(batch_result)
|
||||
|
||||
async def get_subgraph(self, query):
|
||||
"""
|
||||
Get subgraph by extracting concepts, finding entities, and traversing.
|
||||
|
||||
Returns:
|
||||
tuple: (subgraph, term_map, entities, concepts) where subgraph is
|
||||
a list of (s, p, o) string tuples, term_map maps each string
|
||||
tuple to its original (Term, Term, Term), entities is the seed
|
||||
entity list, and concepts is the extracted concept list.
|
||||
"""
|
||||
|
||||
entities, concepts = await self.get_entities(query)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting subgraph...")
|
||||
|
||||
# Use optimized batch traversal instead of sequential processing
|
||||
subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length)
|
||||
|
||||
return list(subgraph), term_map, entities, concepts
|
||||
|
||||
async def resolve_labels_batch(self, entities):
|
||||
"""Resolve labels for multiple entities in parallel"""
|
||||
tasks = []
|
||||
for entity in entities:
|
||||
tasks.append(self.maybe_label(entity))
|
||||
|
||||
"""Resolve labels for multiple entities in parallel."""
|
||||
tasks = [self.maybe_label(entity) for entity in entities]
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def get_labelgraph(self, query):
|
||||
"""
|
||||
Get subgraph with labels resolved for display.
|
||||
async def hop_and_filter(self, seed_entities, concepts):
|
||||
"""Iterative hop-and-filter graph traversal with cross-encoder.
|
||||
|
||||
At each hop:
|
||||
1. Retrieve all edges one hop from the frontier.
|
||||
2. Resolve labels and represent each edge as "{p} {o}".
|
||||
3. Score edges against concepts using the cross-encoder.
|
||||
4. Select the top edges; their target nodes become the next
|
||||
frontier.
|
||||
|
||||
Returns:
|
||||
tuple: (labeled_edges, uri_map, entities, concepts) where:
|
||||
- labeled_edges: list of (label_s, label_p, label_o) tuples
|
||||
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
|
||||
- entities: list of seed entity URI strings
|
||||
- concepts: list of concept strings extracted from query
|
||||
tuple: (selected_edges, uri_map, edge_metadata) where:
|
||||
- selected_edges: list of (label_s, label_p, label_o)
|
||||
- uri_map: dict mapping edge_id -> (Term, Term, Term)
|
||||
- edge_metadata: dict mapping edge_id -> {concept, score}
|
||||
"""
|
||||
subgraph, term_map, entities, concepts = await self.get_subgraph(query)
|
||||
all_selected_edges = []
|
||||
uri_map = {}
|
||||
edge_metadata = {}
|
||||
frontier = set(seed_entities)
|
||||
visited_entities = set()
|
||||
seen_edges = set()
|
||||
|
||||
# Filter out label triples
|
||||
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
|
||||
for hop in range(self.max_path_length):
|
||||
if not frontier:
|
||||
break
|
||||
|
||||
# Collect all unique entities that need label resolution
|
||||
unvisited = [e for e in frontier if e not in visited_entities]
|
||||
if not unvisited:
|
||||
break
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Hop {hop + 1}: {len(unvisited)} frontier entities"
|
||||
)
|
||||
|
||||
# Retrieve edges one hop from frontier
|
||||
triples = await self.execute_batch_triple_queries(
|
||||
unvisited, self.triple_limit,
|
||||
)
|
||||
|
||||
# Deduplicate and filter already-seen edges
|
||||
hop_triples = []
|
||||
hop_term_map = {}
|
||||
hop_directions = {}
|
||||
for triple, direction in triples:
|
||||
triple_tuple = (str(triple.s), str(triple.p), str(triple.o))
|
||||
if triple_tuple[1] == LABEL:
|
||||
continue
|
||||
if triple_tuple in seen_edges:
|
||||
continue
|
||||
seen_edges.add(triple_tuple)
|
||||
hop_triples.append(triple_tuple)
|
||||
hop_term_map[triple_tuple] = (
|
||||
to_term(triple.s), to_term(triple.p), to_term(triple.o),
|
||||
)
|
||||
hop_directions[triple_tuple] = direction
|
||||
|
||||
if not hop_triples:
|
||||
visited_entities.update(frontier)
|
||||
break
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Hop {hop + 1}: {len(hop_triples)} candidate edges"
|
||||
)
|
||||
|
||||
# Resolve labels for all entities in hop edges
|
||||
entities_to_resolve = set()
|
||||
for s, p, o in filtered_subgraph:
|
||||
for s, p, o in hop_triples:
|
||||
entities_to_resolve.update([s, p, o])
|
||||
|
||||
# Batch resolve labels for all entities in parallel
|
||||
entity_list = list(entities_to_resolve)
|
||||
resolved_labels = await self.resolve_labels_batch(entity_list)
|
||||
resolved = await self.resolve_labels_batch(entity_list)
|
||||
|
||||
# Create entity-to-label mapping
|
||||
label_map = {}
|
||||
for entity, label in zip(entity_list, resolved_labels):
|
||||
for entity, label in zip(entity_list, resolved):
|
||||
if not isinstance(label, Exception):
|
||||
label_map[entity] = label
|
||||
else:
|
||||
label_map[entity] = entity # Fallback to entity itself
|
||||
label_map[entity] = entity
|
||||
|
||||
# Apply labels to subgraph and build URI mapping
|
||||
labeled_edges = []
|
||||
uri_map = {} # Maps edge_id of labeled edge -> original Term triple
|
||||
# Build labeled edges and documents for cross-encoder.
|
||||
# The reranker text highlights the NEW information relative
|
||||
# to the traversal direction: arriving from S means p,o are
|
||||
# new; from O means s,p are new; from P means s,o are new.
|
||||
labeled_hop = []
|
||||
for s, p, o in hop_triples:
|
||||
ls = label_map.get(s, s)
|
||||
lp = label_map.get(p, p)
|
||||
lo = label_map.get(o, o)
|
||||
labeled_hop.append((ls, lp, lo))
|
||||
|
||||
for s, p, o in filtered_subgraph:
|
||||
labeled_triple = (
|
||||
label_map.get(s, s),
|
||||
label_map.get(p, p),
|
||||
label_map.get(o, o)
|
||||
documents = []
|
||||
for i, (triple_tuple, (ls, lp, lo)) in enumerate(
|
||||
zip(hop_triples, labeled_hop)
|
||||
):
|
||||
direction = hop_directions[triple_tuple]
|
||||
if direction == self.FROM_S:
|
||||
text = f"{lp} {lo}"
|
||||
elif direction == self.FROM_O:
|
||||
text = f"{ls} {lp}"
|
||||
else:
|
||||
text = f"{ls} {lo}"
|
||||
documents.append({"id": str(i), "text": text})
|
||||
|
||||
queries = [
|
||||
{"id": str(i), "text": c}
|
||||
for i, c in enumerate(concepts)
|
||||
]
|
||||
|
||||
# Score with cross-encoder
|
||||
results = await self.rag.reranker_client.rerank(
|
||||
queries=queries,
|
||||
documents=documents,
|
||||
limit=self.edge_limit,
|
||||
)
|
||||
labeled_edges.append(labeled_triple)
|
||||
|
||||
# Map from labeled edge ID to original Terms (preserving types)
|
||||
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
|
||||
uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o))
|
||||
# Collect selected edges and metadata
|
||||
next_frontier = set()
|
||||
for r in results:
|
||||
idx = int(r.document_id)
|
||||
ls, lp, lo = labeled_hop[idx]
|
||||
s, p, o = hop_triples[idx]
|
||||
eid = edge_id(ls, lp, lo)
|
||||
|
||||
labeled_edges = labeled_edges[0:self.max_subgraph_size]
|
||||
all_selected_edges.append((ls, lp, lo))
|
||||
uri_map[eid] = hop_term_map[(s, p, o)]
|
||||
edge_metadata[eid] = {
|
||||
"concept": concepts[int(r.query_id)],
|
||||
"score": r.score,
|
||||
}
|
||||
|
||||
# Target nodes become next frontier
|
||||
next_frontier.add(s)
|
||||
next_frontier.add(o)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Subgraph:")
|
||||
for edge in labeled_edges:
|
||||
logger.debug(f" {str(edge)}")
|
||||
logger.debug(
|
||||
f"Hop {hop + 1}: selected {len(results)} edges"
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
visited_entities.update(frontier)
|
||||
frontier = next_frontier - visited_entities
|
||||
|
||||
return labeled_edges, uri_map, entities, concepts
|
||||
return all_selected_edges, uri_map, edge_metadata
|
||||
|
||||
async def trace_source_documents(self, edge_uris):
|
||||
"""
|
||||
Trace selected edges back to their source documents via provenance.
|
||||
|
||||
Follows the chain: edge → subgraph (via tg:contains) → chunk →
|
||||
page → document (via prov:wasDerivedFrom), all in urn:graph:source.
|
||||
Follows the chain: edge -> subgraph (via tg:contains) -> chunk ->
|
||||
page -> document (via prov:wasDerivedFrom), all in urn:graph:source.
|
||||
|
||||
Args:
|
||||
edge_uris: List of (s, p, o) URI string tuples
|
||||
|
|
@ -453,7 +462,6 @@ class Query:
|
|||
# Step 1: Find subgraphs containing these edges via tg:contains
|
||||
subgraph_tasks = []
|
||||
for s, p, o in edge_uris:
|
||||
# s, p, o may be Term objects (preserving types) or strings
|
||||
s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s)
|
||||
p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p)
|
||||
o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o)
|
||||
|
|
@ -487,12 +495,10 @@ class Query:
|
|||
return []
|
||||
|
||||
# Step 2: Walk prov:wasDerivedFrom chain to find documents
|
||||
# Each level: query ?entity prov:wasDerivedFrom ?parent
|
||||
# Stop when we find entities typed tg:Document
|
||||
current_uris = subgraph_uris
|
||||
doc_uris = set()
|
||||
|
||||
for depth in range(4): # Max depth: subgraph → chunk → page → doc
|
||||
for depth in range(4):
|
||||
if not current_uris:
|
||||
break
|
||||
|
||||
|
|
@ -509,7 +515,6 @@ class Query:
|
|||
*derivation_tasks, return_exceptions=True
|
||||
)
|
||||
|
||||
# URIs with no parent are root documents
|
||||
next_uris = set()
|
||||
for uri, result in zip(current_uris, derivation_results):
|
||||
if isinstance(result, Exception) or not result:
|
||||
|
|
@ -524,7 +529,6 @@ class Query:
|
|||
return []
|
||||
|
||||
# Step 3: Get all document metadata properties
|
||||
# Skip structural predicates that aren't useful context
|
||||
SKIP_PREDICATES = {
|
||||
PROV_WAS_DERIVED_FROM,
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
|
|
@ -565,7 +569,7 @@ class GraphRag:
|
|||
|
||||
def __init__(
|
||||
self, prompt_client, embeddings_client, graph_embeddings_client,
|
||||
triples_client, verbose=False,
|
||||
triples_client, reranker_client, verbose=False,
|
||||
):
|
||||
|
||||
self.verbose = verbose
|
||||
|
|
@ -574,9 +578,8 @@ class GraphRag:
|
|||
self.embeddings_client = embeddings_client
|
||||
self.graph_embeddings_client = graph_embeddings_client
|
||||
self.triples_client = triples_client
|
||||
self.reranker_client = reranker_client
|
||||
|
||||
# Replace simple dict with LRU cache with TTL
|
||||
# CRITICAL: This cache only lives for one request due to per-request instantiation
|
||||
self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300)
|
||||
|
||||
if self.verbose:
|
||||
|
|
@ -585,33 +588,12 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, edge_score_limit = 30, edge_limit = 25,
|
||||
max_path_length = 2, edge_limit = 25,
|
||||
streaming = False,
|
||||
chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
parent_uri = "",
|
||||
):
|
||||
"""
|
||||
Execute a GraphRAG query with real-time explainability tracking.
|
||||
|
||||
Args:
|
||||
query: The query string
|
||||
collection: Collection identifier
|
||||
entity_limit: Max entities to retrieve
|
||||
triple_limit: Max triples per entity
|
||||
max_subgraph_size: Max edges in subgraph
|
||||
max_path_length: Max hops from seed entities
|
||||
edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter)
|
||||
edge_limit: Max edges after LLM scoring
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for real-time explainability
|
||||
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
|
||||
|
||||
Returns:
|
||||
tuple: (answer_text, usage) where usage is a dict with
|
||||
in_token, out_token, model
|
||||
"""
|
||||
# Accumulate token usage across all prompt calls
|
||||
total_in = 0
|
||||
total_out = 0
|
||||
|
|
@ -638,7 +620,9 @@ class GraphRag:
|
|||
foc_uri = make_focus_uri(session_id)
|
||||
syn_uri = make_synthesis_uri(session_id)
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
timestamp = datetime.now(timezone.utc).isoformat().replace(
|
||||
"+00:00", "Z",
|
||||
)
|
||||
|
||||
# Emit question explainability immediately
|
||||
if explain_callback:
|
||||
|
|
@ -657,10 +641,12 @@ class GraphRag:
|
|||
triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_limit = edge_limit,
|
||||
track_usage = track_usage,
|
||||
)
|
||||
|
||||
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
|
||||
# Step 1: Extract concepts and find seed entities
|
||||
seed_entities, concepts = await q.get_entities(query)
|
||||
|
||||
# Emit grounding explain after concept extraction
|
||||
if explain_callback:
|
||||
|
|
@ -676,11 +662,16 @@ class GraphRag:
|
|||
)
|
||||
await explain_callback(gnd_triples, gnd_uri)
|
||||
|
||||
# Emit exploration explain after graph retrieval completes
|
||||
# Step 2: Iterative hop-and-filter with cross-encoder
|
||||
selected_edges, uri_map, edge_metadata = await q.hop_and_filter(
|
||||
seed_entities, concepts,
|
||||
)
|
||||
|
||||
# Emit exploration explain
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
exploration_triples(
|
||||
exp_uri, gnd_uri, len(kg),
|
||||
exp_uri, gnd_uri, len(selected_edges),
|
||||
entities=seed_entities,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
|
|
@ -688,235 +679,63 @@ class GraphRag:
|
|||
await explain_callback(exp_triples, exp_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
logger.debug(f"Knowledge graph: {kg}")
|
||||
logger.debug(f"Query: {query}")
|
||||
|
||||
# Semantic pre-filter: reduce edges before expensive LLM scoring
|
||||
if edge_score_limit > 0 and len(kg) > edge_score_limit:
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Semantic pre-filter: {len(kg)} edges > "
|
||||
f"limit {edge_score_limit}, filtering..."
|
||||
)
|
||||
|
||||
# Embed edge descriptions: "subject, predicate, object"
|
||||
edge_descriptions = [
|
||||
f"{s}, {p}, {o}" for s, p, o in kg
|
||||
]
|
||||
|
||||
# Embed concepts and edge descriptions concurrently
|
||||
concept_embed_task = self.embeddings_client.embed(concepts)
|
||||
edge_embed_task = self.embeddings_client.embed(edge_descriptions)
|
||||
|
||||
concept_vectors, edge_vectors = await asyncio.gather(
|
||||
concept_embed_task, edge_embed_task
|
||||
)
|
||||
|
||||
# Score each edge by max cosine similarity to any concept
|
||||
def cosine_similarity(a, b):
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = math.sqrt(sum(x * x for x in a))
|
||||
norm_b = math.sqrt(sum(x * x for x in b))
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
edge_scores = []
|
||||
for i, edge_vec in enumerate(edge_vectors):
|
||||
max_sim = max(
|
||||
cosine_similarity(edge_vec, cv)
|
||||
for cv in concept_vectors
|
||||
)
|
||||
edge_scores.append((max_sim, i))
|
||||
|
||||
# Sort by similarity descending and keep top edge_score_limit
|
||||
edge_scores.sort(reverse=True)
|
||||
keep_indices = set(
|
||||
idx for _, idx in edge_scores[:edge_score_limit]
|
||||
)
|
||||
|
||||
# Filter kg and rebuild uri_map
|
||||
filtered_kg = []
|
||||
filtered_uri_map = {}
|
||||
for i, (s, p, o) in enumerate(kg):
|
||||
if i in keep_indices:
|
||||
filtered_kg.append((s, p, o))
|
||||
logger.debug(f"Selected {len(selected_edges)} edges")
|
||||
for s, p, o in selected_edges:
|
||||
eid = edge_id(s, p, o)
|
||||
if eid in uri_map:
|
||||
filtered_uri_map[eid] = uri_map[eid]
|
||||
|
||||
if self.verbose:
|
||||
meta = edge_metadata.get(eid, {})
|
||||
logger.debug(
|
||||
f"Semantic pre-filter kept {len(filtered_kg)} "
|
||||
f"of {len(kg)} edges"
|
||||
f" {meta.get('score', 0):.4f} "
|
||||
f"[{meta.get('concept', '')}] "
|
||||
f"{s} | {p} | {o}"
|
||||
)
|
||||
|
||||
kg = filtered_kg
|
||||
uri_map = filtered_uri_map
|
||||
|
||||
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
|
||||
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
|
||||
edge_map = {}
|
||||
edges_with_ids = []
|
||||
for s, p, o in kg:
|
||||
eid = edge_id(s, p, o)
|
||||
edge_map[eid] = (s, p, o)
|
||||
edges_with_ids.append({
|
||||
"id": eid,
|
||||
"s": s,
|
||||
"p": p,
|
||||
"o": o
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Built edge map with {len(edge_map)} edges")
|
||||
|
||||
# Step 1a: Edge Scoring - LLM scores edges for relevance
|
||||
scoring_result = await self.prompt_client.prompt(
|
||||
"kg-edge-scoring",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": edges_with_ids
|
||||
}
|
||||
)
|
||||
track_usage(scoring_result)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge scoring result: {scoring_result}")
|
||||
|
||||
# Parse scoring response (jsonl) to get edge IDs with scores
|
||||
scored_edges = []
|
||||
|
||||
for obj in scoring_result.objects or []:
|
||||
if isinstance(obj, dict) and "id" in obj and "score" in obj:
|
||||
try:
|
||||
score = int(obj["score"])
|
||||
except (ValueError, TypeError):
|
||||
score = 0
|
||||
scored_edges.append({"id": obj["id"], "score": score})
|
||||
|
||||
# Select top N edges by score
|
||||
scored_edges.sort(key=lambda x: x["score"], reverse=True)
|
||||
top_edges = scored_edges[:edge_limit]
|
||||
selected_ids = {e["id"] for e in top_edges}
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Scored {len(scored_edges)} edges, "
|
||||
f"selected top {len(selected_ids)}"
|
||||
)
|
||||
|
||||
# Filter to selected edges
|
||||
selected_edges = []
|
||||
for eid in selected_ids:
|
||||
if eid in edge_map:
|
||||
selected_edges.append(edge_map[eid])
|
||||
|
||||
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
|
||||
selected_edges_with_ids = [
|
||||
{"id": eid, "s": s, "p": p, "o": o}
|
||||
for eid in selected_ids
|
||||
if eid in edge_map
|
||||
for s, p, o in [edge_map[eid]]
|
||||
]
|
||||
|
||||
# Collect selected edge URIs for document tracing
|
||||
# Step 3: Document tracing
|
||||
selected_edge_uris = [
|
||||
uri_map[eid]
|
||||
for eid in selected_ids
|
||||
if eid in uri_map
|
||||
uri_map[edge_id(s, p, o)]
|
||||
for s, p, o in selected_edges
|
||||
if edge_id(s, p, o) in uri_map
|
||||
]
|
||||
|
||||
# Run reasoning and document tracing concurrently
|
||||
async def _get_reasoning():
|
||||
result = await self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
track_usage(result)
|
||||
return result
|
||||
|
||||
reasoning_task = _get_reasoning()
|
||||
doc_trace_task = q.trace_source_documents(selected_edge_uris)
|
||||
|
||||
reasoning_result, source_documents = await asyncio.gather(
|
||||
reasoning_task, doc_trace_task, return_exceptions=True
|
||||
source_documents = await q.trace_source_documents(
|
||||
selected_edge_uris,
|
||||
)
|
||||
|
||||
# Handle exceptions from gather
|
||||
if isinstance(reasoning_result, Exception):
|
||||
logger.warning(
|
||||
f"Edge reasoning failed: {reasoning_result}"
|
||||
)
|
||||
reasoning_result = None
|
||||
if isinstance(source_documents, Exception):
|
||||
logger.warning(
|
||||
f"Document tracing failed: {source_documents}"
|
||||
)
|
||||
source_documents = []
|
||||
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge reasoning result: {reasoning_result}")
|
||||
|
||||
# Parse reasoning response (jsonl) and build explainability data
|
||||
reasoning_map = {}
|
||||
|
||||
if reasoning_result is not None:
|
||||
for obj in reasoning_result.objects or []:
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
reasoning_map[obj["id"]] = obj.get("reasoning", "")
|
||||
|
||||
# Build focus explainability data with cross-encoder metadata
|
||||
selected_edges_with_reasoning = []
|
||||
for eid in selected_ids:
|
||||
for s, p, o in selected_edges:
|
||||
eid = edge_id(s, p, o)
|
||||
if eid in uri_map:
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
meta = edge_metadata.get(eid, {})
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": reasoning_map.get(eid, ""),
|
||||
"concept": meta.get("concept", ""),
|
||||
"score": meta.get("score", 0),
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Filtered to {len(selected_edges)} edges")
|
||||
|
||||
# Emit focus explain after edge selection completes
|
||||
# Emit focus explain
|
||||
if explain_callback:
|
||||
# Sum scoring + reasoning token usage for focus event
|
||||
focus_in = 0
|
||||
focus_out = 0
|
||||
focus_model = None
|
||||
for r in [scoring_result, reasoning_result]:
|
||||
if r is not None:
|
||||
if r.in_token is not None:
|
||||
focus_in += r.in_token
|
||||
if r.out_token is not None:
|
||||
focus_out += r.out_token
|
||||
if r.model is not None:
|
||||
focus_model = r.model
|
||||
|
||||
foc_triples = set_graph(
|
||||
focus_triples(
|
||||
foc_uri, exp_uri, selected_edges_with_reasoning, session_id,
|
||||
in_token=focus_in or None,
|
||||
out_token=focus_out or None,
|
||||
model=focus_model,
|
||||
foc_uri, exp_uri,
|
||||
selected_edges_with_reasoning, session_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(foc_triples, foc_uri)
|
||||
|
||||
# Step 2: Synthesis - LLM generates answer from selected edges only
|
||||
# Step 4: Synthesis
|
||||
selected_edge_dicts = [
|
||||
{"s": s, "p": p, "o": o}
|
||||
for s, p, o in selected_edges
|
||||
]
|
||||
|
||||
# Add source document metadata as knowledge edges
|
||||
for s, p, o in source_documents:
|
||||
selected_edge_dicts.append({
|
||||
"s": s, "p": p, "o": o,
|
||||
|
|
@ -928,7 +747,6 @@ class GraphRag:
|
|||
}
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Accumulate chunks for answer storage while forwarding to callback
|
||||
accumulated_chunks = []
|
||||
|
||||
async def accumulating_callback(chunk, end_of_stream):
|
||||
|
|
@ -942,7 +760,6 @@ class GraphRag:
|
|||
chunk_callback=accumulating_callback
|
||||
)
|
||||
track_usage(synthesis_result)
|
||||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
synthesis_result = await self.prompt_client.prompt(
|
||||
|
|
@ -955,29 +772,42 @@ class GraphRag:
|
|||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
||||
# Emit synthesis explain after synthesis completes
|
||||
# Emit synthesis explain
|
||||
if explain_callback:
|
||||
synthesis_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian
|
||||
if save_answer_callback and answer_text:
|
||||
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
|
||||
try:
|
||||
await save_answer_callback(synthesis_doc_id, answer_text)
|
||||
if self.verbose:
|
||||
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
|
||||
logger.debug(
|
||||
f"Saved answer to librarian: "
|
||||
f"{synthesis_doc_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
logger.warning(
|
||||
f"Failed to save answer to librarian: {e}"
|
||||
)
|
||||
synthesis_doc_id = None
|
||||
|
||||
syn_triples = set_graph(
|
||||
synthesis_triples(
|
||||
syn_uri, foc_uri,
|
||||
document_id=synthesis_doc_id,
|
||||
in_token=synthesis_result.in_token if synthesis_result else None,
|
||||
out_token=synthesis_result.out_token if synthesis_result else None,
|
||||
model=synthesis_result.model if synthesis_result else None,
|
||||
in_token=(
|
||||
synthesis_result.in_token
|
||||
if synthesis_result else None
|
||||
),
|
||||
out_token=(
|
||||
synthesis_result.out_token
|
||||
if synthesis_result else None
|
||||
),
|
||||
model=(
|
||||
synthesis_result.model
|
||||
if synthesis_result else None
|
||||
),
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
|
|
@ -993,4 +823,3 @@ class GraphRag:
|
|||
}
|
||||
|
||||
return resp, usage
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from . graph_rag import GraphRag
|
|||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
|
||||
from ... base import RerankerClientSpec
|
||||
from ... base import LibrarianSpec
|
||||
|
||||
# Module logger
|
||||
|
|
@ -32,7 +33,6 @@ class Processor(FlowProcessor):
|
|||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
edge_score_limit = params.get("edge_score_limit", 30)
|
||||
edge_limit = params.get("edge_limit", 25)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
|
|
@ -43,7 +43,6 @@ class Processor(FlowProcessor):
|
|||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"edge_score_limit": edge_score_limit,
|
||||
"edge_limit": edge_limit,
|
||||
}
|
||||
)
|
||||
|
|
@ -52,7 +51,6 @@ class Processor(FlowProcessor):
|
|||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.default_edge_score_limit = edge_score_limit
|
||||
self.default_edge_limit = edge_limit
|
||||
|
||||
# Workspace isolation is enforced by the flow layer (flow.workspace).
|
||||
|
|
@ -96,6 +94,13 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
RerankerClientSpec(
|
||||
request_name = "reranker-request",
|
||||
response_name = "reranker-response",
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "response",
|
||||
|
|
@ -163,6 +168,7 @@ class Processor(FlowProcessor):
|
|||
graph_embeddings_client=flow("graph-embeddings-request"),
|
||||
triples_client=flow("triples-request"),
|
||||
prompt_client=flow("prompt-request"),
|
||||
reranker_client=flow("reranker-request"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
@ -186,11 +192,6 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
if v.edge_score_limit:
|
||||
edge_score_limit = v.edge_score_limit
|
||||
else:
|
||||
edge_score_limit = self.default_edge_score_limit
|
||||
|
||||
if v.edge_limit:
|
||||
edge_limit = v.edge_limit
|
||||
else:
|
||||
|
|
@ -225,7 +226,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
|
||||
edge_limit = edge_limit,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
|
|
@ -241,7 +242,7 @@ class Processor(FlowProcessor):
|
|||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_score_limit = edge_score_limit,
|
||||
|
||||
edge_limit = edge_limit,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
|
|
@ -338,18 +339,11 @@ class Processor(FlowProcessor):
|
|||
help=f'Default max path length (default: 2)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-score-limit',
|
||||
type=int,
|
||||
default=30,
|
||||
help=f'Semantic pre-filter limit before LLM scoring (default: 30)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--edge-limit',
|
||||
type=int,
|
||||
default=25,
|
||||
help=f'Max edges after LLM scoring (default: 25)'
|
||||
help=f'Max edges selected per hop by cross-encoder (default: 25)'
|
||||
)
|
||||
|
||||
# Note: Explainability triples are now stored in the request's collection
|
||||
|
|
|
|||
|
|
@ -94,10 +94,8 @@ class IamTableStore:
|
|||
|
||||
self.cassandra.execute("""
|
||||
CREATE TABLE IF NOT EXISTS iam_users_by_username (
|
||||
workspace text,
|
||||
username text,
|
||||
user_id text,
|
||||
PRIMARY KEY ((workspace), username)
|
||||
username text PRIMARY KEY,
|
||||
user_id text
|
||||
);
|
||||
""")
|
||||
|
||||
|
|
@ -175,16 +173,16 @@ class IamTableStore:
|
|||
""")
|
||||
|
||||
self.put_username_lookup_stmt = c.prepare("""
|
||||
INSERT INTO iam_users_by_username (workspace, username, user_id)
|
||||
VALUES (?, ?, ?)
|
||||
INSERT INTO iam_users_by_username (username, user_id)
|
||||
VALUES (?, ?)
|
||||
""")
|
||||
self.get_user_id_by_username_stmt = c.prepare("""
|
||||
SELECT user_id FROM iam_users_by_username
|
||||
WHERE workspace = ? AND username = ?
|
||||
WHERE username = ?
|
||||
""")
|
||||
self.delete_username_lookup_stmt = c.prepare("""
|
||||
DELETE FROM iam_users_by_username
|
||||
WHERE workspace = ? AND username = ?
|
||||
WHERE username = ?
|
||||
""")
|
||||
self.delete_user_stmt = c.prepare("""
|
||||
DELETE FROM iam_users WHERE id = ?
|
||||
|
|
@ -289,7 +287,7 @@ class IamTableStore:
|
|||
)
|
||||
await async_execute(
|
||||
self.cassandra, self.put_username_lookup_stmt,
|
||||
(workspace, username, id),
|
||||
(username, id),
|
||||
)
|
||||
|
||||
async def get_user(self, id):
|
||||
|
|
@ -298,10 +296,10 @@ class IamTableStore:
|
|||
)
|
||||
return rows[0] if rows else None
|
||||
|
||||
async def get_user_id_by_username(self, workspace, username):
|
||||
async def get_user_id_by_username(self, username):
|
||||
rows = await async_execute(
|
||||
self.cassandra, self.get_user_id_by_username_stmt,
|
||||
(workspace, username),
|
||||
(username,),
|
||||
)
|
||||
return rows[0][0] if rows else None
|
||||
|
||||
|
|
@ -324,10 +322,10 @@ class IamTableStore:
|
|||
self.cassandra, self.delete_user_stmt, (id,),
|
||||
)
|
||||
|
||||
async def delete_username_lookup(self, workspace, username):
|
||||
async def delete_username_lookup(self, username):
|
||||
await async_execute(
|
||||
self.cassandra, self.delete_username_lookup_stmt,
|
||||
(workspace, username),
|
||||
(username,),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,49 +1,110 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from websockets.asyncio.client import connect
|
||||
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
import uuid
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _token_key(token):
|
||||
"""Derive a dict key from a token without storing the raw secret."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""Manages an authenticated WebSocket connection to the TrustGraph
|
||||
gateway on behalf of a single caller.
|
||||
|
||||
def __init__(self, url, token=None):
|
||||
Each caller token gets its own WebSocketManager so that gateway-side
|
||||
identity, workspace, and capability scoping are preserved end-to-end.
|
||||
"""
|
||||
|
||||
def __init__(self, url, token):
|
||||
self.url = url
|
||||
# ── Security boundary: token storage ──
|
||||
# This is the MCP caller's Bearer token, forwarded verbatim to
|
||||
# the gateway. It MUST NOT be logged, persisted, or shared
|
||||
# across callers. It is held only for the lifetime of this
|
||||
# connection so that re-auth (e.g. after a reconnect) is
|
||||
# possible.
|
||||
self.token = token
|
||||
self.socket = None
|
||||
|
||||
# FIXME: authentication is broken. The /api/v1/socket endpoint uses
|
||||
# in-band auth (first-frame protocol via the Mux dispatcher), not
|
||||
# query-parameter tokens. This query-string token is silently ignored.
|
||||
# Fix: after connect(), send an auth frame with the bearer token as
|
||||
# the first message, matching the gateway's in-band auth protocol.
|
||||
def _build_url(self):
|
||||
if not self.token:
|
||||
return self.url
|
||||
parsed = urlparse(self.url)
|
||||
params = parse_qs(parsed.query)
|
||||
params["token"] = [self.token]
|
||||
new_query = urlencode(params, doseq=True)
|
||||
return urlunparse(parsed._replace(query=new_query))
|
||||
self.identity = None
|
||||
self.last_used = None
|
||||
|
||||
async def start(self):
|
||||
self.socket = await connect(self._build_url())
|
||||
"""Connect and authenticate via the gateway's in-band auth
|
||||
protocol. Raises on auth failure."""
|
||||
|
||||
# ── Security boundary: MCP server → gateway ──
|
||||
# The WebSocket connects to the gateway and authenticates using
|
||||
# the caller's Bearer token via the in-band first-frame auth
|
||||
# protocol. The token belongs to the MCP client — we forward
|
||||
# it as-is and never interpret its contents.
|
||||
self.socket = await connect(self.url)
|
||||
self.pending_requests = {}
|
||||
self.running = True
|
||||
|
||||
await self._authenticate()
|
||||
|
||||
self.reader_task = asyncio.create_task(self.reader())
|
||||
|
||||
async def _authenticate(self):
|
||||
"""Send in-band auth frame and wait for auth-ok / auth-failed.
|
||||
|
||||
The gateway expects ``{"type": "auth", "token": "..."}`` as the
|
||||
first frame on a new WebSocket. Any service frame sent before
|
||||
auth-ok is rejected.
|
||||
"""
|
||||
await self.socket.send(json.dumps({
|
||||
"type": "auth",
|
||||
"token": self.token,
|
||||
}))
|
||||
|
||||
response_text = await asyncio.wait_for(self.socket.recv(), 10)
|
||||
response = json.loads(response_text)
|
||||
|
||||
if response.get("type") == "auth-ok":
|
||||
logger.info(
|
||||
"WebSocket authenticated, default workspace: %s",
|
||||
response.get("workspace"),
|
||||
)
|
||||
return
|
||||
|
||||
# Auth failed — close immediately, do not leave an
|
||||
# unauthenticated socket open.
|
||||
await self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
if response.get("type") == "auth-failed":
|
||||
raise RuntimeError(
|
||||
"Gateway rejected the authentication token"
|
||||
)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Unexpected auth response type: {response.get('type')}"
|
||||
)
|
||||
|
||||
async def whoami(self):
|
||||
"""Verify the token by calling the gateway's whoami endpoint.
|
||||
Returns the identity dict and caches it on ``self.identity``.
|
||||
"""
|
||||
gen = self.request("iam", {"operation": "whoami"}, flow_id=None)
|
||||
async for response in gen:
|
||||
self.identity = response
|
||||
return response
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
if hasattr(self, "reader_task"):
|
||||
await self.reader_task
|
||||
|
||||
async def reader(self):
|
||||
"""
|
||||
Background task to read websocket responses and route to correct
|
||||
request
|
||||
"""
|
||||
"""Background task: read WebSocket frames and route them to the
|
||||
correct pending-request queue by ``id``."""
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
|
|
@ -59,23 +120,21 @@ class WebSocketManager:
|
|||
|
||||
request_id = response.get("id")
|
||||
if request_id and request_id in self.pending_requests:
|
||||
# Put the response in the queue
|
||||
queue = self.pending_requests[request_id]
|
||||
await queue.put(response)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Response for unknown request ID: {request_id}"
|
||||
logger.warning(
|
||||
"Response for unknown request ID: %s", request_id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logging.error(f"Error in websocket reader: {e}")
|
||||
logger.error("Error in websocket reader: %s", e)
|
||||
|
||||
# Put error in all pending queues
|
||||
for queue in self.pending_requests.values():
|
||||
try:
|
||||
await queue.put({"error": str(e)})
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.pending_requests.clear()
|
||||
|
|
@ -86,25 +145,29 @@ class WebSocketManager:
|
|||
|
||||
async def request(
|
||||
self, service, request_data, flow_id="default",
|
||||
workspace=None,
|
||||
):
|
||||
"""
|
||||
Send a request via websocket and handle single or streaming responses
|
||||
"""Send a request via WebSocket and yield responses.
|
||||
|
||||
Args:
|
||||
service: Gateway service name (e.g. "graph-rag", "config").
|
||||
request_data: Inner request payload.
|
||||
flow_id: Optional flow identifier. ``None`` omits the field
|
||||
(workspace-level services don't use flows).
|
||||
workspace: Optional workspace override. When ``None`` the
|
||||
gateway uses the caller's default workspace.
|
||||
"""
|
||||
|
||||
# Generate unique request ID
|
||||
import time
|
||||
self.last_used = time.monotonic()
|
||||
|
||||
request_id = f"{uuid.uuid4()}"
|
||||
|
||||
# Determine if this service streams responses
|
||||
streaming_services = {"agent"}
|
||||
is_streaming = service in streaming_services
|
||||
|
||||
# Create a queue for all responses (streaming and single)
|
||||
response_queue = asyncio.Queue()
|
||||
self.pending_requests[request_id] = response_queue
|
||||
|
||||
try:
|
||||
|
||||
# Build request message
|
||||
message = {
|
||||
"id": request_id,
|
||||
"service": service,
|
||||
|
|
@ -114,7 +177,16 @@ class WebSocketManager:
|
|||
if flow_id is not None:
|
||||
message["flow"] = flow_id
|
||||
|
||||
# Send request
|
||||
# ── Security boundary: workspace scoping ──
|
||||
# When the caller supplies a workspace, we set it on the
|
||||
# message envelope. The gateway's enforce_workspace()
|
||||
# validates that the authenticated identity is permitted
|
||||
# to access the target workspace — we MUST NOT skip or
|
||||
# override that check. When workspace is None, the
|
||||
# gateway default-fills from the identity's bound workspace.
|
||||
if workspace is not None:
|
||||
message["workspace"] = workspace
|
||||
|
||||
await self.socket.send(json.dumps(message))
|
||||
|
||||
while self.running:
|
||||
|
|
@ -127,19 +199,17 @@ class WebSocketManager:
|
|||
continue
|
||||
|
||||
if "error" in response:
|
||||
if "message" in response["error"]:
|
||||
raise RuntimeError(response["error"]["text"])
|
||||
if isinstance(response["error"], dict):
|
||||
raise RuntimeError(
|
||||
response["error"].get("message", str(response["error"]))
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(str(response["error"]))
|
||||
|
||||
yield response["response"]
|
||||
|
||||
if "complete" in response:
|
||||
if response["complete"]:
|
||||
if response.get("complete"):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
# Clean up on error
|
||||
finally:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise e
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.5,<2.6",
|
||||
"trustgraph-base>=2.6,<2.7",
|
||||
"pulsar-client",
|
||||
"prometheus-client",
|
||||
"boto3",
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.5,<2.6",
|
||||
"trustgraph-base>=2.6,<2.7",
|
||||
"pulsar-client",
|
||||
"prometheus-client",
|
||||
"python-magic",
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.5,<2.6",
|
||||
"trustgraph-base>=2.6,<2.7",
|
||||
"pulsar-client",
|
||||
"google-genai",
|
||||
"google-api-core",
|
||||
|
|
|
|||
|
|
@ -10,13 +10,13 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.5,<2.6",
|
||||
"trustgraph-bedrock>=2.5,<2.6",
|
||||
"trustgraph-cli>=2.5,<2.6",
|
||||
"trustgraph-embeddings-hf>=2.5,<2.6",
|
||||
"trustgraph-flow>=2.5,<2.6",
|
||||
"trustgraph-unstructured>=2.5,<2.6",
|
||||
"trustgraph-vertexai>=2.5,<2.6",
|
||||
"trustgraph-base>=2.6,<2.7",
|
||||
"trustgraph-bedrock>=2.6,<2.7",
|
||||
"trustgraph-cli>=2.6,<2.7",
|
||||
"trustgraph-embeddings-hf>=2.6,<2.7",
|
||||
"trustgraph-flow>=2.6,<2.7",
|
||||
"trustgraph-unstructured>=2.6,<2.7",
|
||||
"trustgraph-vertexai>=2.6,<2.7",
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue