Compare commits

..

47 commits

Author SHA1 Message Date
Cyber MacGeddon
508d0bb5c1 Merge branch 'release/v2.6' 2026-07-03 13:45:02 +01:00
cybermaggedon
c05296376e
fix: remove test import (#1017)
Convention in the tests is to just import the libraries as production
code would. This is fragile, and could possibly be used to inject
malicious code in the CI environment.
2026-07-03 13:44:09 +01:00
YingzuoLiu
f04ae5331d
Add diversity-aware selection after Document-RAG reranking (#1014)
* Add Document-RAG diversity selection helper

* Add optional MMR diversity selection after reranking

* Fix Document-RAG diversity test method signatures
2026-07-03 13:35:42 +01:00
cybermaggedon
db7fdbc652
feat: direction-aware reranker text in GraphRAG hop-and-filter (#1016)
The reranker document text now reflects the traversal direction,
showing only the new information relative to the frontier entity:
- From S (subject is frontier): text = "{predicate} {object}"
- From O (object is frontier): text = "{subject} {predicate}"
- From P (predicate is frontier): text = "{subject} {object}"

This eliminates duplicate reranker texts when traversing inward
from shared object nodes (e.g. 18 CPUs all producing identical
"hasSubcategory Processors" text when the subject was dropped).

execute_batch_triple_queries now returns (triple, direction)
tuples so hop_and_filter can select the appropriate text format.

Updates tech spec to document the direction-aware approach.
Adds unit tests for direction tracking and reranker text
construction.
2026-07-02 21:14:47 +01:00
Cyber MacGeddon
4aaa1ce915 Merge branch 'release/v2.6' into master to keep sync. 2026-07-02 14:58:37 +01:00
cybermaggedon
9cf7dcb578
fix: wire variant into remaining streaming integration test mocks (#1013)
Three more streaming tests were missing _wire_variant after the
async for change in create_completion_stream.
2026-07-02 11:14:54 +01:00
Sunny
6c9a545a06
feat: add cross-encoder reranking to Document-RAG with two-limit control (#878) (#1011)
Wire the FlashRank reranker subsystem from #1005 into Document-RAG: after
vector retrieval, over-fetch a wider candidate pool, rerank with the
cross-encoder, and keep the top doc_limit chunks for synthesis.

Per maintainer review, the fetch and select sizes are two caller-controlled
limits rather than one internal heuristic:

- doc_limit:   chunks selected into the synthesis prompt (unchanged meaning).
- fetch_limit: candidate pool pulled from the vector store before reranking.
  0 = derive (OVERFETCH_FACTOR x doc_limit); values below doc_limit are
  raised to it. Lets the caller control how hard the reranker has to work.

Details:
- schema: DocumentRagQuery.fetch_limit (additive, backward compatible).
- document_rag.py / rag.py: fetch_limit resolved in the processor (mirrors
  doc_limit); the core applies the heuristic default and derives synthesis
  provenance from the chunk-selection focus when reranking ran.
- provenance: tg:ChunkSelection focus stage (mirrors tg:EdgeSelection).
- request translator + client SDKs + CLI: fetch-limit / --fetch-limit,
  threaded exactly like doc_limit and the GraphRAG limits.
- tests: no-op identity, over-fetch/narrow, explicit fetch_limit, heuristic
  default, floor-at-doc_limit, provenance lineage, cross-repo topic wiring.

Reranking is skipped byte-identically when no reranker role is wired.
Requires the companion trustgraph-templates change wiring the reranker
topics into the document-rag flow (mirrors #279 for GraphRAG).
2026-07-02 09:50:13 +01:00
cybermaggedon
f18d48dc39
fix: simplify dashscope variant and route API calls through variants (#1012)
Replace the client.post()/httpx bypass with standard SDK extra_body,
confirmed working against DashScope. Make DashScope the base variant
with Qwen as a subclass alias. Route all API calls through variant
create_completion/create_completion_stream methods.
2026-07-02 09:12:55 +01:00
cybermaggedon
6887076ce0
feat: add dashscope variant for Alibaba Cloud DashScope API (#1010)
DashScope uses enable_thinking as a top-level parameter rather than
inside extra_body as the Qwen docs suggest.
2026-07-01 16:50:47 +01:00
cybermaggedon
55e2a2a3ce
feat: add guided macOS installer and developer install guide (#1003)
Interactive bash installer (install_trustgraph.sh) that detects hardware,
recommends an LLM mode (OpenAI or Ollama), installs missing prerequisites
via Homebrew, sets up a Python venv, runs the test suite, generates a
deployment via npx @trustgraph/config, starts the Docker Compose stack,
health-checks the API gateway, and opens the Workbench UI.

Includes README.dev-install.md with usage documentation covering CLI
options, environment variables, LLM mode selection, non-interactive/CI
usage, uninstall, and troubleshooting. Currently macOS only.
2026-07-01 16:50:14 +01:00
cybermaggedon
11ca7c89c4
feat: add GLM (Zhipu AI) variant for OpenAI processor (#1009) 2026-07-01 16:20:43 +01:00
cybermaggedon
656ca430b9
fix: wire variant into text-completion integration test mocks (#1008)
Tests using MagicMock processors need the variant, thinking mode,
and _build_kwargs/_extract_content methods bound to work with the
new variant-based API kwargs construction.
2026-07-01 15:40:23 +01:00
cybermaggedon
f20b50cfb2
feat: add API variant profiles and thinking support to OpenAI processor (#1007)
Add a --variant flag (openai, deepseek, qwen, mistral, llama) that
encapsulates provider-specific API differences: output token parameter
names, thinking/reasoning toggles, temperature rules, and thinking
output extraction. Add --thinking flag (off, low, medium, high) to
control reasoning effort.
2026-07-01 14:48:32 +01:00
Jack Colquitt
04c5921687
Fix Discord link in README (#1006)
Updated Discord link in README.md
2026-06-30 19:39:33 -07:00
cybermaggedon
01cc8dbc64
feat: replace LLM edge scoring with cross-encoder reranker in GraphRAG (#1005)
Replace the three-prompt LLM scoring pipeline (kg-edge-scoring,
kg-edge-reasoning, kg-edge-selection) with a cross-encoder reranker
service backed by FlashRank. The new hop_and_filter() method performs
iterative graph traversal with semantic scoring at each hop, replacing
the previous follow_edges/get_subgraph approach.

- Add reranker service (trustgraph-base client/service, FlashRank processor)
- Add gateway dispatch for reranker via API and WebSocket
- Rewrite GraphRAG pipeline: hop_and_filter() with per-hop cross-encoder scoring
- Remove kg_prompt() and edge_score_limit from prompt client
- Update provenance: add tg:EdgeSelection type, tg:concept, tg:score predicates
- Update CLIs (tg-invoke-graph-rag, tg-show-explain-trace) for new metadata
- Add tg-invoke-reranker CLI tool
- Add tech spec and UX developer guidance
- Update all unit and integration tests
2026-06-30 14:36:37 +01:00
corvus-0x
1aa9549912
feat: make bootstrapper initialiser timeouts configurable (#999)
* feat: make bootstrapper initialiser timeouts configurable

DefaultFlowStart and WorkspaceInit hardcoded the request timeouts for
their flow-svc and IAM calls, leaving operators no way to tune them for
high-latency environments (#874).

Expose them as constructor parameters threaded through the existing
initialiser `params:` mechanism, defaulting to the current values so
behaviour is unchanged unless explicitly overridden:

- DefaultFlowStart: list_timeout=10 (list-flows), start_timeout=30 (start-flow)
- WorkspaceInit: iam_timeout=10 (create-workspace)

Add unit tests for the defaults, override storage, and that configured
values reach the underlying request calls.

* test: mark async bootstrap test with @pytest.mark.asyncio

Addresses review feedback on PR #999: add the explicit
@pytest.mark.asyncio decorator to test_run_forwards_configured_timeouts
so it does not rely on asyncio_mode=auto and stays consistent with the
rest of the suite.
2026-06-30 09:37:22 +01:00
cybermaggedon
5cb4f83afa
fix: list-my-workspaces permissions were broken (#1002)
list-my-workspaces has AUTHENTICATED scope, so anyone is permitted
to run the operation.  No specific permission grant is needed.
2026-06-29 09:13:05 +01:00
cybermaggedon
0a828379be
feat: global usernames and rename workspace to default_workspace (#1001)
Users are global entities, not scoped to workspaces. This change:

Track A — Global usernames:
- Change iam_users_by_username to PRIMARY KEY (username), removing
  workspace from the lookup key
- Login looks up username globally, no workspace required
- Username uniqueness is enforced globally, not per-workspace
- Login -w now overrides the JWT workspace (session workspace)
  rather than selecting which user registry to search

Track B — Rename workspace to default_workspace:
- UserRecord.workspace → UserRecord.default_workspace
- Identity.workspace → Identity.default_workspace
- JWT claim "workspace" → "default_workspace"
- IamResponse.resolved_workspace → resolved_default_workspace
- WebSocket auth-ok frame field → default_workspace
- Socket clients read default_workspace from auth-ok
- _user_record_to_dict wire key → default_workspace
- CLI help text and output updated throughout
- Test files updated for renamed fields
2026-06-25 16:34:31 +01:00
cybermaggedon
16f8cfd972
fix: use envelope workspace for mux authorisation, not inner request body (#1000)
The mux was extracting the authorisation resource workspace from the
inner request body via registry extractors. But workspace-scoped
services (config, flow, librarian, etc.) receive workspace from the
queue identity, not the message body — the inner workspace field is
a dead field that no service handler reads.

This caused access-denied errors when the inner body's workspace
(e.g. CLI default "default") disagreed with the caller's assigned
workspace, even though the envelope workspace was correct.

Fix: resolve workspace from the envelope only. Split the non-flow
authorisation path by resource level — WORKSPACE ops use the envelope
workspace directly; SYSTEM ops (IAM) still use registry extractors
since they legitimately read operation-specific body fields.
2026-06-25 13:44:57 +01:00
Cyber MacGeddon
a3df4f62bb Merge branch 'master' into release/v2.6 2026-06-22 21:20:29 +01:00
cybermaggedon
09b8a1d347
feat: fine-grained capabilities and enterprise IAM schema extensions (#996)
Split coarse gateway capabilities into fine-grained variants to
support per-operation access control in the enterprise IAM regime.
Add additive schema fields for enterprise group and grant management.

Capability split (gateway registry):
- graph:read -> triples:read, sparql:read, graph-rag:read,
  graph-embeddings:read
- graph:write -> triples:write, graph-embeddings:write,
  entity-contexts:write
- documents:read -> documents:read, document-rag:read,
  document-embeddings:read, entity-contexts:read
- documents:write -> documents:write, document-embeddings:write
- rows:read -> rows:read, nlp-query:read, structured-query:read,
  row-embeddings:read

OSS role definitions expanded to include all new fine-grained
capability names — no behavioral change for OSS deployments.

Schema additions (IamRequest):
- group_id, member_type, member_id for group membership operations
- group (GroupInput), grant (GrantInput) for create/update payloads
- Decoder now handles capability, resource_json, parameters_json,
  authorise_checks (previously missing from translator)

Schema additions (IamResponse):
- group_json, groups_json, members_json, grants_json,
  effective_permissions_json for enterprise operation responses
- Encoder now emits authorise decision fields

Gateway registry:
- 16 enterprise IAM operations registered (create-group,
  add-group-member, add-user-grant, etc.) under iam:admin capability
2026-06-22 20:23:34 +01:00
Jack Colquitt
fa264ded46
Update section titles for Holonic Context Graph (#995) 2026-06-18 19:56:44 -07:00
Jack Colquitt
cae931409a
Update TrustGraph description in README (#994)
Clarified the description of TrustGraph's capabilities and API integrations.
2026-06-18 19:46:25 -07:00
Jack Colquitt
6b0475e315
Revise README for clarity on TrustGraph features (#993)
Updated the README to clarify the concept of holons and the functionality of TrustGraph. Improved the structure and flow of information regarding context management and agent explainability.
2026-06-18 19:42:16 -07:00
Jack Colquitt
cb0ad1a450
Change video link in README (#992)
Updated video source link in README.md.
2026-06-17 17:52:30 -07:00
Jack Colquitt
fc0ecc770a
Format terms as code in README.md (#991) 2026-06-17 16:53:22 -07:00
Jack Colquitt
345da375b1
Document Workspaces, Collections, and Flows in README (#990)
Added section on Workspaces, Collections, and Flows to explain the organizational structure of TrustGraph.
2026-06-17 16:48:22 -07:00
Jack Colquitt
0ba1eeeda0
Enhance README with token consumption details (#989)
Added a note about reducing token consumption in context management.
2026-06-17 16:27:16 -07:00
Jack Colquitt
eb1e38d7d0
Add hyperlink to 'holon' in README.md (#988) 2026-06-17 16:13:41 -07:00
Jack Colquitt
b8770a6005
Update README with new context and features (#987) 2026-06-17 16:08:22 -07:00
Jack Colquitt
28802a644a
Update license badge in README.md (#986) 2026-06-11 20:45:39 -07:00
cybermaggedon
8797d9d9ff feat: per-caller Bearer token auth and new query tools for MCP server (#984)
Replace the broken GATEWAY_SECRET auth (token was sent as a query
parameter, silently ignored by the gateway) with end-to-end Bearer
token forwarding.  Each MCP caller gets a dedicated WebSocket
authenticated via the gateway's in-band first-frame protocol, with
whoami verification on first connect.

Also fix and extend the tool surface:
- embeddings: accept list of texts (was single string)
- triples_query: use Term wire format with compact keys (was legacy
  Value format), add collection and graph parameters
- sparql_query: new tool for SPARQL SELECT/ASK/CONSTRUCT/DESCRIBE
- graphql_query: new tool for structured data (rows) GraphQL queries
- all tools: add optional workspace parameter
2026-06-10 14:11:49 +01:00
cybermaggedon
627c669097
feat: per-caller Bearer token auth and new query tools for MCP server (#984)
Replace the broken GATEWAY_SECRET auth (token was sent as a query
parameter, silently ignored by the gateway) with end-to-end Bearer
token forwarding.  Each MCP caller gets a dedicated WebSocket
authenticated via the gateway's in-band first-frame protocol, with
whoami verification on first connect.

Also fix and extend the tool surface:
- embeddings: accept list of texts (was single string)
- triples_query: use Term wire format with compact keys (was legacy
  Value format), add collection and graph parameters
- sparql_query: new tool for SPARQL SELECT/ASK/CONSTRUCT/DESCRIBE
- graphql_query: new tool for structured data (rows) GraphQL queries
- all tools: add optional workspace parameter
2026-06-10 14:10:43 +01:00
cybermaggedon
8b0619e5d8
Bump version numbers to 2.6 (#983) 2026-06-09 20:03:14 +01:00
cybermaggedon
e3f9f8c357
Merge pull request #982 from trustgraph-ai/master
master -> release/v2.6
2026-06-09 19:46:50 +01:00
Cyber MacGeddon
81d57826c8 Merge branch 'release/v2.5' 2026-06-09 19:43:31 +01:00
Jacob Molz
79d7ef6a90 fix: reject invalid PDF decoder input (#977) 2026-06-09 16:37:39 +01:00
Jacob Molz
28a51c244f
fix: reject invalid PDF decoder input (#977) 2026-06-09 16:37:10 +01:00
Cyber MacGeddon
fa5ebe2393 Merge branch 'release/v2.5' 2026-06-09 16:34:20 +01:00
cybermaggedon
e1c9351454
fix: update row query tests to mock async_execute_paged and async_scan (#979)
The query service now uses async_execute_paged (indexed path) and
async_scan (scan path) instead of async_execute. Tests were mocking
the old function, causing them to hang indefinitely.
2026-06-09 16:29:32 +01:00
cybermaggedon
dbc21c0bb9
fix: structured data query and auth fixes (#978)
- Pass auth token to schema discovery and descriptor generation in
  tg-load-structured-data, fixing 401 errors with IAM enabled
- Fix row query pagination: replace single-page async_execute with
  async_scan that streams pages and applies filters without
  materialising the full result set (OOM on large datasets)
- Add missing filter operators (not, startsWith, endsWith, not_in)
  to row query post-filter matching
- Fall back to scan path when an indexed field is queried with an
  empty string value, since empty index values are not stored
- Revert top-level indexes array support — the current table schema
  overwrites rows with duplicate index values, so only primary_key
  fields are safe to index until the schema is redesigned
2026-06-08 15:22:11 +01:00
cybermaggedon
08bfec1539
fix: wire replication params through YAML/params path for Cassandra and Qdrant (#976)
resolve_cassandra_config did not accept replication_factor as a kwarg,
so cassandra_replication_factor from YAML params was silently ignored
by all 6 callers. Add the kwarg and pass it from every caller.

Same fix for Qdrant: 3 writers now pass qdrant_replication_factor and
qdrant_shard_number from params.

Add tests covering the params path for both helpers.
2026-06-04 12:36:36 +01:00
cybermaggedon
4913f8c2eb
feat: data store replication configuration and TLS upgrade (#975)
- Add centralised qdrant_config.py helper with env-var fallback for
  QDRANT_URL, QDRANT_API_KEY, QDRANT_REPLICATION_FACTOR, QDRANT_SHARD_NUMBER
- Update all 6 Qdrant processors to use the helper; writers pass
  replication_factor and shard_number to create_collection
- Fix hardcoded Cassandra replication_factor=1 in cassandra_kg.py,
  write.py, and sparql_cassandra.py to respect CASSANDRA_REPLICATION_FACTOR
- Upgrade Cassandra TLS from deprecated PROTOCOL_TLSv1_2 to
  ssl.create_default_context() across all connectors
2026-06-04 11:49:29 +01:00
Jack Colquitt
97453d9b83
Change project title to 'The semantic deployment platform' (#968)
Updated the project title in the README.
2026-06-01 14:08:30 -07:00
Jack Colquitt
6dfa47aac8
Revise README for semantic infrastructure terminology (#962)
Updated the README to reflect changes in terminology and improve clarity regarding the platform's features.
2026-05-30 17:07:19 -07:00
Cyber MacGeddon
dcee842455 Merge branch 'release/v2.5' 2026-05-28 11:26:43 +01:00
cybermaggedon
36eadbda3a
Merge pull request #953 from trustgraph-ai/release/v2.5
release/v2.5 -> master
2026-05-26 15:01:44 +01:00
121 changed files with 8577 additions and 2446 deletions

View file

@ -22,7 +22,7 @@ jobs:
uses: actions/checkout@v3
- name: Setup packages
run: make update-package-versions VERSION=2.5.999
run: make update-package-versions VERSION=2.6.999
- name: Setup environment
run: python3 -m venv env

218
README.dev-install.md Normal file
View file

@ -0,0 +1,218 @@
# TrustGraph Developer Install Guide
A guided installer that gets TrustGraph running locally in a single
command. It detects your hardware, recommends an LLM backend, installs
missing prerequisites, runs the test suite, generates a compose deployment,
starts the stack, and opens the Workbench UI.
> **macOS only.** This installer has only been tested on macOS. If you are
> on Linux or Windows, use the standard docker-compose / podman-compose
> installation instructions instead.
## Quick start
```bash
./install_trustgraph.sh
```
The installer walks you through each step interactively. When it finishes,
the Workbench UI opens at `http://localhost:8888` and the API gateway is
available at `http://localhost:8088/`.
## Prerequisites
The installer checks for these and offers to install any that are missing
(via Homebrew):
- **Python 3** with venv support
- **Node.js / npx** (drives the `@trustgraph/config` deployment generator)
- **Docker** (with Compose) or **Podman** (with podman-compose)
- **curl** and **unzip**
- **Ollama** (only if you choose local LLMs)
The installer can also launch Docker Desktop or the Ollama app for you if
they are installed but not running.
## What the installer does
1. **Detects hardware** -- OS, architecture, CPU cores, memory, and GPU.
2. **Recommends an LLM mode** -- `ollama` for machines with >= 16 GB RAM and
a GPU or >= 8 cores; `openai` otherwise.
3. **Collects configuration** -- API key, LLM provider, model choices,
install directory. Answers are saved to
`<install-dir>/trustgraph-installer.env` and reused on subsequent runs.
4. **Checks and installs prerequisites** -- Python, Node/npx, Docker or
Podman, Ollama (if selected).
5. **Downloads Ollama models** (if using Ollama) -- chat model
(`granite4:350m` by default) and embeddings model (`mxbai-embed-large`).
6. **Creates a Python venv** and installs the local TrustGraph packages into
it, along with NLTK data and tiktoken caches.
7. **Runs the full pytest suite** against the local source tree.
8. **Runs `npx @trustgraph/config`** -- the existing interactive config
wizard that produces a `deploy.zip` with a compose file.
9. **Starts the compose stack** and waits for the API gateway to respond.
10. **Bootstraps IAM** and verifies the API key authenticates.
11. **Opens the Workbench UI** in your default browser.
## Command-line options
| Option | Description |
|---|---|
| `--install-dir PATH` | Directory for deployment files (default: `./trustgraph-deploy`) |
| `--api-url URL` | API gateway URL for health checks (default: `http://localhost:8088/`) |
| `--ui-url URL` | Workbench UI URL to open (default: `http://localhost:8888`) |
| `--use-existing-compose FILE` | Skip config generation and start this compose file directly |
| `--skip-tests` | Do not run the pytest suite |
| `--no-launch` | Do not open the Workbench UI at the end |
| `--non-interactive` | Accept all defaults without prompting |
| `--yes` | Auto-accept confirmation prompts |
| `--fresh` | Remove installer-managed files before generating a new deployment |
| `--remove-all` | Uninstall: stop containers, remove compose volumes, delete installer files |
| `--dry-run` | Print detected hardware and planned defaults, then exit |
| `-h`, `--help` | Show the built-in help text |
## Environment variables
These override the interactive prompts when set:
| Variable | Purpose |
|---|---|
| `TRUSTGRAPH_TOKEN` | Admin/bootstrap API key (must start with `tg_`) |
| `TRUSTGRAPH_URL` | API gateway URL |
| `TRUSTGRAPH_UI_URL` | Workbench UI URL |
| `OPENAI_TOKEN` | OpenAI-compatible API key |
| `OPENAI_BASE_URL` | OpenAI-compatible base URL |
| `OLLAMA_HOST` / `OLLAMA_BASE_URL` | Ollama service URL |
| `OLLAMA_MODEL` | Ollama chat model (default: `granite4:350m`) |
| `OLLAMA_EMBEDDINGS_MODEL` | Ollama embeddings model (default: `mxbai-embed-large`) |
| `TG_INSTALL_DIR` | Override the install directory |
| `TG_VENV_DIR` | Override the Python venv location |
| `TG_NLTK_DATA_DIR` | Override the NLTK data directory |
| `TIKTOKEN_CACHE_DIR` | Override the tiktoken cache directory |
| `TG_HEALTH_TIMEOUT` | Seconds to wait for the API gateway (default: 240) |
## Choosing an LLM mode
### OpenAI (or any OpenAI-compatible provider)
Best when you already have an API key or are running against a remote
endpoint. The installer asks for a base URL and an API key.
```bash
OPENAI_TOKEN=sk-... ./install_trustgraph.sh
```
### Ollama (local models)
Best on machines with enough RAM to run a small model. The installer detects
locally installed Ollama models and offers to pull missing ones. It uses
`host.docker.internal` so the Docker containers can reach the host-side
Ollama service.
```bash
./install_trustgraph.sh # choose "ollama" when prompted
```
### None
Start the platform without an LLM. Agent and RAG features will not work
until you configure one later through the Workbench.
## Saved answers and re-running
The installer saves your answers to
`<install-dir>/trustgraph-installer.env`. On the next run it loads those
answers as defaults, so you can re-run with a single Enter through each
prompt.
To start completely fresh:
```bash
./install_trustgraph.sh --fresh
```
This stops any running containers (keeping Docker volumes), removes
installer-managed files, and re-runs the full flow.
## Using an existing compose file
If you already have a compose file from the config tool or another source:
```bash
./install_trustgraph.sh --use-existing-compose path/to/docker-compose.yaml
```
This skips the config wizard and `npx` prerequisite check, and goes straight
to starting the stack.
## Non-interactive / CI usage
```bash
TRUSTGRAPH_TOKEN=tg_my-token \
OPENAI_TOKEN=sk-... \
./install_trustgraph.sh --non-interactive --yes --skip-tests
```
In non-interactive mode the installer uses defaults for every prompt. Pair
with `--yes` to auto-accept confirmation prompts and `--skip-tests` if you
want a faster run.
## Dry run
Preview what the installer would do without making any changes:
```bash
./install_trustgraph.sh --dry-run
```
This prints the detected hardware, recommended LLM mode, and planned
install paths, then exits.
## Uninstalling
```bash
./install_trustgraph.sh --remove-all
```
This stops containers, removes compose-managed volumes, and deletes
installer-managed files (venv, deploy output, logs, saved answers). It does
**not** remove Docker/Podman itself, container images, Ollama, or Ollama
models.
## Troubleshooting
### Logs
All long-running operations write logs to `<install-dir>/logs/`. Key files:
- `pytest.log` -- test suite output
- `compose-up.log` -- docker compose output
- `iam-bootstrap.log` -- IAM bootstrap output
- `ollama-pull-*.log` -- Ollama model downloads
- `pip-*.log` -- Python package installs
- `brew-install-*.log` -- Homebrew installs
### API key rejected after reinstall
If the API gateway returns 401/403 with your saved key, the compose volumes
likely contain IAM data from a previous install with a different key. Run:
```bash
./install_trustgraph.sh --remove-all
./install_trustgraph.sh
```
This clears the old volumes and starts fresh.
### Ollama not reachable from containers
The Ollama base URL should use `host.docker.internal` instead of
`localhost` so that containers running in Docker Desktop can reach the
host-side Ollama service. The installer sets this automatically; if you
override `OLLAMA_HOST`, make sure the URL is reachable from inside the
container network.
### Docker daemon not running
The installer detects Docker Desktop and offers to start it. If that
doesn't work, start Docker Desktop manually and re-run the installer.

284
README.md
View file

@ -3,52 +3,97 @@
<img src="TG-fullname-logo.svg" width=100% />
[![PyPI version](https://img.shields.io/pypi/v/trustgraph.svg)](https://pypi.org/project/trustgraph/) [![License](https://img.shields.io/github/license/trustgraph-ai/trustgraph?color=blue)](LICENSE) ![E2E Tests](https://github.com/trustgraph-ai/trustgraph/actions/workflows/release.yaml/badge.svg)
[![PyPI version](https://img.shields.io/pypi/v/trustgraph.svg)](https://pypi.org/project/trustgraph/) ![License](https://img.shields.io/badge/license-Apache%202.0-blue) ![E2E Tests](https://github.com/trustgraph-ai/trustgraph/actions/workflows/release.yaml/badge.svg)
[![Discord](https://img.shields.io/discord/1251652173201149994
)](https://discord.gg/sQMwkRz5GX) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/trustgraph-ai/trustgraph)
[**Website**](https://trustgraph.ai) | [**Docs**](https://docs.trustgraph.ai) | [**YouTube**](https://www.youtube.com/@TrustGraphAI?sub_confirmation=1) | [**Configuration Terminal**](https://config-ui.demo.trustgraph.ai/) | [**Discord**](https://discord.gg/sQMwkRz5GX) | [**Blog**](https://blog.trustgraph.ai/subscribe)
[**Website**](https://trustgraph.ai) | [**Docs**](https://docs.trustgraph.ai) | [**YouTube**](https://www.youtube.com/@TrustGraphAI?sub_confirmation=1) | [**Configuration Terminal**](https://config-ui.demo.trustgraph.ai/) | [**Discord**](https://discord.gg/yUWRkfbD) | [**Blog**](https://blog.trustgraph.ai/subscribe)
<a href="https://trendshift.io/repositories/17291" target="_blank"><img src="https://trendshift.io/api/badge/repositories/17291" alt="trustgraph-ai%2Ftrustgraph | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
# The agent runtime platform
# Write context once. Run agents anywhere.
</div>
TrustGraph is an agent runtime platform built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for precision-critical agent workloads.
Stop rebuilding context from scratch. TrustGraph treats context as a holon — a modular, independent whole that naturally snaps into a larger domain-wide intelligence layer. By deploying context as holonic context graphs, TrustGraph powers multi-tenant agent workflows, dramatically reduces token consumption, and aligns with semantic web standards (RDF, OWL, SKOS, SHACL). Version your context, share it across teams, and scale with full provenance.
The platform:
- [x] Multi-model and multimodal database system
- [x] Tabular/relational, key-value
- [x] Document, graph, and vectors
- [x] Images, video, and audio
- [x] Context Graph engine
- [x] Automated entity and relationship extraction
- [x] Ontology-driven graph construction
- [x] Graph-grounded retrieval for explainable outputs
- [x] Automated data ingest and loading
- [x] Quick ingest with semantic similarity retrieval
- [x] Ontology structuring for precision retrieval
- [x] Out-of-the-box RAG pipelines
- [x] DocumentRAG
- [x] GraphRAG
- [x] OntologyRAG
- [x] 3D GraphViz for exploring context
- [x] Fully Agentic System
- [x] Single or Multi Agent
- [x] ReAct, Plan-then-Execute, and Supervisor patterns
- [x] MCP integration
- [x] Run anywhere
- [x] Deploy locally with Docker
- [x] Deploy in cloud with Kubernetes
- [x] Support for all major LLMs
- [x] API support for Anthropic, Cohere, Gemini, Mistral, OpenAI, and others
- [x] Model inferencing with vLLM, Ollama, TGI, LM Studio, and Llamafiles
- [x] Developer friendly
- [x] REST API [Docs](https://docs.trustgraph.ai/reference/apis/rest.html)
- [x] Websocket API [Docs](https://docs.trustgraph.ai/reference/apis/websocket.html)
- [x] Python API [Docs](https://docs.trustgraph.ai/reference/apis/python)
- [x] CLI [Docs](https://docs.trustgraph.ai/reference/cli/)
## What TrustGraph Does
TrustGraph is a complete holonic context harness for all LLMs. It provides the full infrastructure layer underneath your agents: knowledge ingestion, structured storage, graph-grounded retrieval, agent orchestration, and a full LLM inferencing stack.
TrustGraph relies on absolutely no 3rd party services aside from optional API integrations to cloud-hosted LLMs. Whether you are using Anthropic's or OpenAI's API, or self-hosting Qwen3.7 via vLLM, TrustGraph handles it all with pre-built API connectors and a full LLM inferencing stack to enrich the models with a sovereign, private holonic system that grounds your agents in reality.
## The Problem: Why Agents Break
When you build an AI agent today, you spend most of your time fighting context:
- **RAG retrieves fragments, not meaning**. Chunks of text have no structure. Relationships between facts are invisible. Your agent guesses at the connections.
- **Context is disposable**. What the agent learned in one session is gone in the next. There is no persistent, structured knowledge layer underneath.
- **Answers aren't traceable**. You can't explain why the agent said what it said, which means you can't trust it in production.
- **Knowledge can't be reused**. You rebuild the same context pipelines for every new project, every new agent, every new environment.
These aren't retrieval problems. They are structural problems. Context needs to be organized, versioned, and composable — exactly the way software infrastructure is.
## The Solution: A Holonic Context System
The philosopher Arthur Koestler coined the word [holon](https://en.wikipedia.org/wiki/Holon_(philosophy)) to describe something that is simultaneously a whole in itself and a part of something larger. A fact is whole. It is also part of a domain. A domain is whole. It is also part of an organization's knowledge.
AI agents break down because this holonic structure is never built. Context gets shoved into flat text windows, scattered across vector stores, or hardwired into one-off prompts. Facts lose their relationships.
TrustGraph solves this by organizing your domain into holonic context graphs. Entities, relationships, and evidence are treated as first-class objects. Every agent query is grounded against these holons—marrying symbolic graph structures with vector embeddings. Every answer carries provenance. Every fact is traceable.
## Context Cores: Knowledge as a First-Class Citizen
A Context Core is the deployable unit of knowledge in TrustGraph. It packages everything an agent needs to reason reliably over a domain into a single, portable artifact.
### What's inside a Context Core
- **Ontology** — your domain schema and entity mappings
- **Holon** — entities, relationships, and supporting evidence
- **Embeddings** — vector indexes for fast semantic entry-point lookup
- **Provenance** — where every fact came from, when, and how it was derived
- **Retrieval policies** — traversal rules, freshness controls, authority ranking
Context Cores decouple what agents know from how agents are deployed. Build once. Run in Docker locally, Kubernetes in production, or on any cloud. Pin a version. Roll back. Promote across environments. This is context engineering — and it works because knowledge is finally treated like the infrastructure it is.
## Explainability: Trust Your Agents in Production
LLMs are black boxes, and traditional RAG makes it worse. When an agent pulls flat text chunks from a vector store, you have no idea how it connected those fragments to form an answer. You cannot ship agents to production if you can't explain why they said what they said.
### How TrustGraph makes agents explainable:
- **Traceable Reasoning Paths**: Instead of guessing at connections between text chunks, TrustGraph traverses explicit relationship paths in the holonic context graph. You can inspect exactly which entities, relationships, and sub-graphs were pulled into the LLM's context window to generate a given response.
- **Fact-Level Provenance**: Every node and edge in the graph carries strict provenance. When an agent makes a claim, you can trace it back to the exact source document, the time it was ingested, and the extraction method used to derive it.
- **No Black-Box Guesses**: By grounding the LLM in a structured, symbolic graph, you eliminate the hallucinations that occur when models are forced to infer relationships from unstructured text. If a fact isn't in the graph, the agent doesn't use it.
TrustGraph doesn't just give you answers - it gives you the receipt. Every fact is traceable, every connection is visible, and every output is verifiable.
## Workspaces, Collections, and Flows
TrustGraph has a [three-level system](https://docs.trustgraph.ai/overview/workspaces) for organizing and isolating knowledge.
A `Workspace` is the outermost boundary — a fully isolated tenancy scope where all data, users, configuration, and pipelines live independently from every other workspace. Isolation is structural: enforced at the pub/sub queue, storage, and API gateway layers, not by trusting a field in a message body.
Within a workspace, a `Collection` groups related holons, graph structures, embeddings, and documents together — think of it as a dedicated shelf in a library, scoped to a specific domain, project, or customer.
A `Flow` is a running data processing pipeline that defines how raw data moves through ingestion, extraction, structuring, and storage — the assembly line that turns documents into queryable knowledge. Together, the three layers let you run multiple isolated tenants on a single deployment, separate knowledge by domain within each tenant, and process that knowledge through fully configurable pipelines — all without restarting the system or rebuilding your infrastructure.
## The Full Stack
TrustGraph is not a wrapper around a graph database. It is the complete backend for production agentic systems.
- **Holonic context graph engine**: automated entity and relationship extraction, ontology-driven graph construction, graph-grounded retrieval for explainable outputs
- **Multi-model database**: tabular/relational, key-value, document, graph, vectors, images, video, and audio — all managed in Cassandra and S3-compatible Garage
- **Out-of-the-box RAG pipelines**: DocumentRAG, GraphRAG, and OntologyRAG ready to deploy
- **Fully agentic orchestration**: single or multi-agent, ReAct, Plan-then-Execute, Supervisor patterns, and MCP integration
- **3D Knowledge Explorer**: interactive graph visualization with BFS neighborhood extraction and edge pulse animation
- **Automated data ingest**: quick ingest with semantic similarity or ontology-structured precision retrieval
- **Run anywhere**: Docker/Podman locally, Kubernetes in the cloud
All major LLMs — Anthropic, Cohere, Gemini, Mistral, OpenAI, and more via API.
vLLM, Ollama, TGI, LM Studio, and Llamafiles for fully local inferencing.
Verified cloud deployments for Alibaba Cloud, AWS, Azure, GCP, OVHcloud, and Scaleway.
## No API Keys Required
@ -62,12 +107,12 @@ Everything else is included.
- [x] Managed Multi-model storage in [Cassandra](https://cassandra.apache.org/_/index.html)
- [x] Managed Vector embedding storage in [Qdrant](https://github.com/qdrant/qdrant)
- [x] Managed File and Object storage in [Garage](https://github.com/deuxfleurs-org/garage) (S3 compatible)
- [x] Managed High-speed Pub/Sub messaging fabric with [Pulsar](https://github.com/apache/pulsar)
- [x] Managed High-speed Pub/Sub messaging fabric with [Pulsar](https://github.com/apache/pulsar) or [RabbitMQ](https://www.rabbitmq.com/)
- [x] Complete LLM inferencing stack for open LLMs with [vLLM](https://github.com/vllm-project/vllm), [TGI](https://github.com/huggingface/text-generation-inference), [Ollama](https://github.com/ollama/ollama), [LM Studio](https://github.com/lmstudio-ai), and [Llamafiles](https://github.com/mozilla-ai/llamafile)
## Quickstart
There's no need to clone this repo, unless you want to build from source. TrustGraph is a fully containerized app that deploys as a set of Docker containers. To configure TrustGraph on the command line:
No need to clone the repo unless you are building from source. TrustGraph deploys as a set of Docker containers. Configure it on the command line in one step:
```
npx @trustgraph/config
@ -78,44 +123,39 @@ The config process will generate an app config that can be run locally with Dock
- Deployment instructions as `INSTALLATION.md`
<p align="center">
<video src="https://github.com/user-attachments/assets/2978a6aa-4c9c-4d7c-ad02-8f3d01a1c602"
<video src="https://github.com/user-attachments/assets/33434c3c-f586-4610-8bb2-d7b7b586a672"
width="80%" controls></video>
</p>
For a browser based configuration, try the [Configuration Terminal](https://config-ui.demo.trustgraph.ai/).
## Watch What is a Context Graph?
## Watch What is a Holonic Context Graph?
[![What is a Context Graph?](https://img.youtube.com/vi/gZjlt5WcWB4/maxresdefault.jpg)](https://www.youtube.com/watch?v=gZjlt5WcWB4)
## Watch Context Graphs in Action
## Watch Holonic Context Graphs in Action
[![Context Graphs in Action with TrustGraph](https://img.youtube.com/vi/sWc7mkhITIo/maxresdefault.jpg)](https://www.youtube.com/watch?v=sWc7mkhITIo)
## Getting Started with TrustGraph
- [**Getting Started Guides**](https://docs.trustgraph.ai/getting-started)
- [**Using the Workbench**](#workbench)
- [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference)
- [**Deployment Guides**](https://docs.trustgraph.ai/deployment)
## Workbench
## TrustGraph UI
The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default.
<img width="1389" height="961" alt="Image" src="https://github.com/user-attachments/assets/35c9250d-0f01-40cb-9294-1ee8fd9a1b56" />
- **Vector Search**: Search the installed knowledge bases
- **Agentic, GraphRAG and LLM Chat**: Chat interface for agents, GraphRAG queries, or direct to LLMs
- **Relationships**: Analyze deep relationships in the installed knowledge bases
- **Graph Visualizer**: 3D GraphViz of the installed knowledge bases
- **Library**: Staging area for installing knowledge bases
- **Flow Classes**: Workflow preset configurations
- **Flows**: Create custom workflows and adjust LLM parameters during runtime
- **Knowledge Cores**: Manage resuable knowledge bases
- **Prompts**: Manage and adjust prompts during runtime
- **Schemas**: Define custom schemas for structured data knowledge bases
- **Ontologies**: Define custom ontologies for unstructured data knowledge bases
- **Agent Tools**: Define tools with collections, knowledge cores, MCP connections, and tool groups
- **MCP Tools**: Connect to MCP servers
The UI provides tools for all major features of TrustGraph. The UI deploys on port `8888` by default.
- **Agent Console** — Query your agents directly with streaming responses and live explainability event tracking, so you can watch reasoning unfold in real time
- **GraphRAG View** — Interactive graph RAG queries with a visual explainability DAG and inline provenance display, making it easy to see exactly where answers came from
- **Context Explorer** — An interactive 3D context graph explorer with dynamic graph loading, BFS neighborhood extraction, edge pulse animation, and multiple navigation views
- **Document Ingestion** — A complete upload and submission workflow with page and chunk inspection and document structure browsing
- **Ontology Workbench** — A full ontology editor with class and property trees, OWL/XML and Turtle import/export with round-trip fidelity, circular dependency detection, and safe-delete confirmation dialogs
- **Schema Workbench** — Interactive schema management with list, create, edit, and delete operations including field and index management
- **Prompt Editor** — A dedicated prompt editing workflow
## TypeScript Library for UIs
@ -125,134 +165,6 @@ There are 3 libraries for quick UI integration of TrustGraph services.
- [@trustgraph/react-state](https://www.npmjs.com/package/@trustgraph/react-state)
- [@trustgraph/react-provider](https://www.npmjs.com/package/@trustgraph/react-provider)
## Context Cores
Context Cores are how TrustGraph treats context like code. A Context Core is a **portable, versioned bundle of context** that you can ship between projects and environments, pin in production, and reuse across agents. It packages the “stuff agents need to know” (structured knowledge + embeddings + evidence + policies) into a single artifact, so you can treat context like code: build it, test it, version it, promote it, and roll it back. TrustGraph is built to support this kind of end-to-end context engineering and orchestration workflow.
### Whats inside a Context Core
A Context Core typically includes:
- Ontology (your domain schema) and mappings
- Context Graph (entities, relationships, supporting evidence)
- Embeddings / vector indexes for fast semantic entry-point lookup
- Source manifests + provenance (where facts came from, when, and how they were derived)
- Retrieval policies (traversal rules, freshness, authority ranking)
## Tech Stack
TrustGraph provides component flexibility to optimize agent workflows.
<details>
<summary>LLM APIs</summary>
<br>
- Anthropic<br>
- AWS Bedrock<br>
- AzureAI<br>
- AzureOpenAI<br>
- Cohere<br>
- Google AI Studio<br>
- Google VertexAI<br>
- Mistral<br>
- OpenAI<br>
</details>
<details>
<summary>LLM Orchestration</summary>
<br>
- LM Studio<br>
- Llamafiles<br>
- Ollama<br>
- TGI<br>
- vLLM<br>
</details>
<details>
<summary>Multi-model storage</summary>
<br>
- Apache Cassandra<br>
</details>
<details>
<summary>VectorDB</summary>
<br>
- Qdrant<br>
</details>
<details>
<summary>File and Object Storage</summary>
<br>
- Garage<br>
</details>
<details>
<summary>Observability</summary>
<br>
- Prometheus<br>
- Grafana<br>
- Loki<br>
</details>
<details>
<summary>Data Streaming</summary>
<br>
- Apache Pulsar<br>
- RabbitMQ<br>
- Apache Kafka<br>
</details>
<details>
<summary>Clouds</summary>
<br>
- AWS<br>
- Azure<br>
- Google Cloud<br>
- OVHcloud<br>
- Scaleway<br>
</details>
## Observability & Telemetry
Once the platform is running, access the Grafana dashboard at:
```
http://localhost:3000
```
Default credentials are:
```
user: admin
password: admin
```
The default Grafana dashboard tracks the following:
<details>
<summary>Telemetry</summary>
<br>
- LLM Latency<br>
- Error Rate<br>
- Service Request Rates<br>
- Queue Backlogs<br>
- Chunking Histogram<br>
- Error Source by Service<br>
- Rate Limit Events<br>
- CPU usage by Service<br>
- Memory usage by Service<br>
- Models Deployed<br>
- Token Throughput (Tokens/second)<br>
- Cost Throughput (Cost/second)<br>
</details>
## Contributing
[Developer's Guide](https://docs.trustgraph.ai/guides/building/introduction.html)
@ -261,7 +173,7 @@ The default Grafana dashboard tracks the following:
**TrustGraph** is licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0).
Copyright 2024-2025 TrustGraph
Copyright 2024-2026 TrustGraph
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View file

@ -100,7 +100,6 @@ multi-word subsystems.
| `users:admin` | Assign / remove roles on users within the workspace |
| `keys:self` | Create / revoke / list **own** API keys |
| `keys:admin` | Create / revoke / list **any user's** API keys within the workspace |
| `workspaces:list-own` | List workspaces the caller has access to |
| `workspaces:admin` | Create / delete / disable workspaces (system-level) |
| `iam:admin` | JWT signing-key rotation, IAM-level operations |
| `metrics:read` | Prometheus metrics proxy |
@ -111,7 +110,7 @@ The open-source edition ships three roles:
| Role | Capabilities |
|---|---|
| `reader` | `agent`, `graph:read`, `documents:read`, `rows:read`, `llm`, `embeddings`, `mcp`, `collections:read`, `knowledge:read`, `flows:read`, `config:read`, `keys:self`, `workspaces:list-own` |
| `reader` | `agent`, `graph:read`, `documents:read`, `rows:read`, `llm`, `embeddings`, `mcp`, `collections:read`, `knowledge:read`, `flows:read`, `config:read`, `keys:self` |
| `writer` | everything in `reader` **+** `graph:write`, `documents:write`, `rows:write`, `collections:write`, `knowledge:write` |
| `admin` | everything in `writer` **+** `config:write`, `flows:write`, `users:read`, `users:write`, `users:admin`, `keys:admin`, `workspaces:admin`, `iam:admin`, `metrics:read` |

View file

@ -0,0 +1,541 @@
# GraphRAG Semantic Filter Improvement
## Problem Statement
The GraphRAG semantic filter is observed to be ineffective with certain
LLM models. Smaller models in particular produce poor-quality edge
relevance scores, and there is a suspicion that models trained or
evaluated heavily on non-Roman-script datasets offer lower performance
on the semantic ranking operation.
The root cause is that the current implementation delegates edge
relevance scoring to the LLM via a prompt that asks the model to
assign a 110 relevance score to each knowledge-graph edge. This
task — ranking structured triples for relevance to a natural-language
query — is not well covered in standard LLM evaluation suites, so
model benchmark scores are not predictive of performance on this
operation. The result is that GraphRAG quality varies unpredictably
across model choices, undermining confidence in the pipeline.
Beyond model variability, the LLM scoring step has further problems:
- **Cost and latency.** The LLM call consumes tokens and adds
latency to every query, yet its output is unreliable. Even when
the model performs well, the cost is disproportionate for what is
fundamentally a ranking operation.
- **Subjective scoring scale.** The 110 relevance scale gives the
model no objective criteria for what constitutes a 5 versus a 7.
Different models interpret the scale differently, and even the same
model can produce inconsistent scores across runs.
- **Redundancy with the embedding pre-filter.** The pipeline already
contains a cosine-similarity stage that ranks edges by semantic
relevance using embeddings. The LLM scoring step is a second
filter applied on top of this, and it is not clear that it adds
enough value to justify the additional cost and risk of
degradation.
### Industry context
Semantic ranking is rigorously evaluated on dedicated benchmarks such
as MTEB (Massive Text Embedding Benchmark) and BEIR (Benchmarking
Information Retrieval), which test retrieval and reranking across
diverse domains. The current TrustGraph approach — prompting a
general-purpose LLM to score and rank documents (the "listwise"
approach) — is known to be poorly optimized for this task. It
suffers from positional bias, formatting failures, and
inconsistency at scale.
The industry standard for semantic ranking has moved to
cross-encoder models: lightweight, purpose-built models that take a
querydocument pair as input and produce a single relevance score.
These models are fine-tuned on millions of relevance-labelled pairs
and dominate retrieval benchmarks. They are fast, deterministic,
and do not require an LLM inference call.
## Architecture
### Cross-encoder service
A new request/response service that exposes a generic semantic
ranking API. The service is not specific to GraphRAG — it is a
reusable building block for any component that needs to rank text
by relevance.
The service interface is pluggable. Alternative implementations
can be swapped in behind the same API.
**Packaging options considered:**
- *`sentence-transformers`.* Full-featured, widely used.
However, it pulls in PyTorch (~2 GB), making containers
very large. Tested at ~1.8 seconds for 2200 edges.
- *`optimum.onnxruntime`.* ONNX-based inference. Still
depends on PyTorch at import time despite using ONNX for
inference. Tested at ~4.2 seconds for 2200 edges.
- *`flashrank`.* Lightweight wrapper around ONNX Runtime
with a clean API (`Ranker`, `RerankRequest`). No PyTorch
dependency. Tested at ~4.4 seconds for 2200 edges.
- *Pure `onnxruntime` + `tokenizers`.* Leanest option
(~200 MB total). Requires manual tokenisation, padding,
and numpy array management — more boilerplate to maintain.
- *External API (e.g. Cohere Rerank).* No local model at
all. Adds network latency and an external dependency.
**Decision:** `flashrank` for the initial implementation.
No PyTorch dependency, clean API, comparable performance.
The pluggable interface allows swapping to another backend
later.
**Request:**
- `queries` — list of `{id, text}` objects. In the GraphRAG use
case these are the concepts extracted from the user's question.
- `documents` — list of `{id, text}` objects. In the GraphRAG
use case these are the candidate knowledge-graph edges
represented as text.
- `limit` — integer. Maximum number of results to return.
**Scoring:**
The service produces the cartesian product of all querydocument
pairs and scores each pair through the cross-encoder model. For
each document, the maximum score across all queries is taken as the
document's relevance score. Documents are then ranked by this
score and the top `limit` results are returned.
**Response:**
A list of the top `limit` results, each containing:
- `document_id` — the ID of the matched document.
- `query_id` — the ID of the query (concept) that produced the
highest score for this document.
- `score` — the relevance score.
Including `query_id` in the response supports the explainability
interface: it records that an edge was selected because it is
related to a specific concept.
### Integration
The cross-encoder service follows the standard TrustGraph service
integration pattern:
- **Base package (trustgraph-base).** Schema definitions for the
cross-encoder request/response messages. A client class that
other components (e.g. GraphRAG) can use to call the
cross-encoder service. Message translator registration so the
pub/sub layer can serialise/deserialise the messages.
- **Flow package (trustgraph-flow).** The cross-encoder service
implementation itself — loads the model, listens for requests,
scores pairs, returns results. Flow definition support so the
cross-encoder can be introduced into a processing flow via the
standard flow configuration. `flashrank` is added as a
dependency of `trustgraph-flow`. The service runs in its own
container.
- **API gateway.** A gateway endpoint that routes cross-encoder
requests from the HTTP API to the service over pub/sub and
returns the response.
- **CLI tool.** A command-line utility
(e.g. `tg-invoke-cross-encoder`) that calls the gateway
endpoint for manual testing and debugging.
### Current GraphRAG pipeline
The current pipeline follows these steps:
1. **Concept extraction.** An LLM prompt extracts key concepts
from the user's query.
2. **Graph exploration.** Seed entities are found via embedding
similarity. A subgraph is built by multi-hop traversal from
the seed entities (up to `max_path_length` hops, capped at
`max_subgraph_size` edges).
3. **Embedding pre-filter.** Each edge is embedded as
`"subject, predicate, object"` and scored by cosine similarity
against the concept embeddings. The top `edge_score_limit`
(default 30) edges are kept.
4. **LLM edge scoring.** The `kg-edge-scoring` prompt asks the
LLM to assign a 110 relevance score to each remaining edge.
The top `edge_limit` (default 25) edges are kept.
5. **LLM edge reasoning.** The `kg-edge-reasoning` prompt asks
the LLM to explain why each selected edge is relevant to the
query. Used for the explainability interface.
6. **Document tracing.** Selected edges are traced back to their
source documents in the librarian. Runs concurrently with
step 5.
7. **Synthesis.** The `kg-synthesis` prompt generates the final
answer from the selected edges and source document metadata.
### Potential improvements
#### Replace LLM edge scoring with cross-encoder (step 4)
The LLM edge scoring step is replaced by a call to the
cross-encoder service. The candidate edges are the documents and
`edge_limit` is the limit. This is a direct substitution: faster,
cheaper, deterministic, and more reliable across model choices.
The LLM `kg-edge-scoring` prompt is retired.
**Cross-encoder query input: concepts vs. raw query.** There are
two options for what to use as the cross-encoder queries:
- *Option A: Raw user query.* Pass the original question as a
single query string. Simpler, no dependency on concept
extraction. However, raw queries contain noise words and
conversational phrasing that do not match well against the
structured vocabulary of knowledge-graph edges. A single query
also means every edge competes against the full question — a
partial match on one aspect is diluted.
- *Option B: Extracted concepts.* Pass the concepts from step 1
as separate queries. The concepts are distilled, focused terms
that are closer to the language of the edges. With multiple
concepts as independent queries, the cross-encoder scores each
edge against each concept separately, giving better coverage —
an edge only needs to match one concept well to be selected.
The trade-off is a dependency on the LLM concept extraction
step, but this is already in the pipeline and is a lightweight,
reliable LLM call.
**Decision:** Option B — use extracted concepts. The concept
extraction is fast, and the resulting terms produce better
cross-encoder matches against structured triples.
#### Edge text representation
The current embedding pre-filter represents each edge as
`"subject, predicate, object"`. Two changes:
- **Drop commas.** Commas add tokenisation noise without semantic
value.
- **Direction-aware text.** The reranker text should highlight
the *new* information relative to the traversal direction.
The frontier entity is already known context — repeating it
adds noise and, when traversing from an object node, causes
many edges to produce identical reranker text (e.g. 18
products sharing the same `hasSubcategory Processors` triple
all collapse to the same string when the subject is dropped).
The text is constructed based on which position the frontier
entity occupied in the triple:
- **From subject** (s=entity): `"{predicate} {object}"`
the subject is known, predicate and object are new.
- **From object** (o=entity): `"{subject} {predicate}"`
the object is known, subject and predicate are new.
- **From predicate** (p=entity): `"{subject} {object}"`
the predicate is known, subject and object are new.
This eliminates the duplicate-text problem that arises when
traversing inward from a shared object node, and gives the
cross-encoder a more informative signal at every hop.
#### Remove the embedding pre-filter (step 3)
The embedding pre-filter was introduced to reduce the number of
edges before the expensive LLM scoring call. With the
cross-encoder replacing the LLM call, this cost equation changes.
**Arguments for removal:**
- The cross-encoder is fast enough to score the full subgraph
directly. In testing, 2200 edges scored in ~1.8 seconds; at
the default `max_subgraph_size` of 150 edges, scoring takes
a fraction of a second.
- The pre-filter is a weaker version of what the cross-encoder
does. Bi-encoder cosine similarity embeds the query and
document independently and compares vectors; the cross-encoder
processes both texts together through the full transformer,
giving it much better relevance judgement. Running a weaker
filter before a stronger one adds latency without improving
quality.
- Removing it eliminates an embedding service call (two batches:
concepts + edges) and the associated latency.
**Arguments for keeping it:**
- If the subgraph is very large (thousands of edges), the
cross-encoder's linear scaling could become a bottleneck.
The pre-filter would act as a safety valve.
- The embedding call is cheap compared to an LLM call, so the
overhead is modest.
**Decision:** Remove the pre-filter. The `max_subgraph_size`
parameter (default 150) already caps the number of edges entering
this stage, so the cross-encoder will not face an unbounded
workload. If very large subgraphs become a concern in future,
the pre-filter can be reintroduced or `max_subgraph_size` can be
tuned.
#### Iterative graph traversal with cross-encoder filtering
The current pipeline performs graph exploration and edge filtering
as separate phases: first build the full subgraph (up to
`max_path_length` hops), then score and filter edges. An
alternative is to interleave traversal and filtering — at each
hop, use the cross-encoder to select relevant edges before
expanding further.
**Option A: Big-bang traversal then filter.** Traverse the full
subgraph up to `max_path_length` hops from the seed entities,
collecting all edges up to `max_subgraph_size`. Then
cross-encode the entire result to select the top edges.
- Simple to implement — the current traversal logic is largely
unchanged.
- Produces large, unfocused subgraphs. Irrelevant branches are
explored and scored even though they will be discarded.
- Poorly suited to multi-hop reasoning. For a query about
Voyager 1, the subgraph includes Voyager 2's edges because
they are within hop distance, and the filter must then
separate them.
**Option B: Iterative hop-and-filter.** At each hop:
1. Retrieve all edges one hop from the current frontier nodes.
2. Cross-encode these edges against the query concepts.
3. Select the top relevant edges.
4. The target nodes of the selected edges become the frontier
for the next hop.
5. Repeat up to `max_path_length` hops.
The final set of selected edges across all hops is the input to
synthesis.
- **Guided exploration.** Each hop focuses the search by
pruning irrelevant branches before expanding further. The
working set stays small and relevant at every step.
- **Multi-hop reasoning works naturally.** Following
"Voyager 1 → has-event → crossed the heliopause" succeeds
because each hop is individually relevant and leads to the
next.
- **Smaller total workload.** Fewer edges are scored overall
because irrelevant branches are never expanded.
- **Trade-off: greedy pruning.** An edge discarded at hop 1
cannot lead to relevant edges at hop 2. This is inherent in
any bounded traversal, and the cross-encoder is better
equipped to make this relevance judgement than a blind hop
limit.
- **Trade-off: sequential latency.** Hops cannot be
parallelised since each depends on the previous. However,
each cross-encoder call on a small edge set is very fast
(sub-second for typical working sets).
**Decision:** Option B — iterative hop-and-filter. The guided
traversal produces more focused subgraphs and supports multi-hop
reasoning, which is a significant quality improvement over the
current approach.
#### Replace LLM edge reasoning with cross-encoder metadata (step 5)
The current `kg-edge-reasoning` prompt asks the LLM to explain why
each edge is relevant. With the cross-encoder now making the
selection, this explanation would be a post-hoc fabrication — the
LLM was not involved in the decision.
- *Option A: Keep LLM reasoning.* Generates natural-language
explanations but they are not grounded in the actual selection
process. Adds an LLM call per query.
- *Option B: Record cross-encoder metadata.* The cross-encoder
already returns the matched concept and score for each selected
edge. Use this directly as the explanation.
**Decision:** Option B. The cross-encoder metadata is the true
reason the edge was selected. The `kg-edge-reasoning` prompt is
retired.
#### Explainability interface update
The explainability interface uses a `Focus` entity containing
`EdgeSelection` sub-entities. Each `EdgeSelection` currently
carries an `edge` (the quoted triple) and a `reasoning` field
(free-text LLM prose), stored as `tg:reasoning` in the
provenance graph.
With the cross-encoder replacing LLM reasoning, the
`EdgeSelection` type gains two new predicates and drops one:
- **Remove** `tg:reasoning` — no longer produced.
- **Add** `tg:concept` — the concept text that produced the
highest cross-encoder score for this edge.
- **Add** `tg:score` — the cross-encoder relevance score.
This is an evolution of the existing `EdgeSelection` type, not a
new entity type. The edge selection sub-entities currently have
no `rdf:type` declared; a new `tg:EdgeSelection` type should be
added so that consumers can identify them in the provenance
graph. The `Focus` entity and its relationship to `Exploration`
are unchanged.
The `Focus` entity's token-usage metadata (`tg:inToken`,
`tg:outToken`, `tg:llmModel`) no longer applies since there is
no LLM call. These fields are dropped from the Focus entity.
### Proposed pipeline
1. **Concept extraction.** Unchanged — LLM extracts key concepts
from the user's query.
2. **Seed entity lookup.** Find seed entities via embedding
similarity against the extracted concepts.
3. **Iterative hop-and-filter.** For each hop up to
`max_path_length`:
a. Retrieve all edges one hop from the current frontier nodes.
b. Represent each edge using direction-aware text: from a
subject node use `"{predicate} {object}"`, from an object
node use `"{subject} {predicate}"`, from a predicate node
use `"{subject} {object}"`.
c. Score edges against the extracted concepts using the
cross-encoder service.
d. Select the top relevant edges. The target nodes of the
selected edges become the frontier for the next hop.
4. **Document tracing.** Selected edges are traced back to source
documents.
5. **Synthesis.** The `kg-synthesis` prompt generates the final
answer from the selected edges and source document metadata.
### Implementation order
1. Cross-encoder service with full integration (base schema,
flow service, gateway endpoint, CLI tool).
2. GraphRAG pipeline changes (iterative hop-and-filter,
edge representation, remove pre-filter).
3. Explainability update (`tg:EdgeSelection` type, concept
and score predicates, retire `tg:reasoning`).
4. Retire `kg-edge-scoring` and `kg-edge-reasoning` prompts.
5. Update `tg-invoke-graph-rag` and `tg-show-explain-trace`
to display the new metadata. Use these as the main
end-to-end test.
6. Fix any failing unit tests, then add new tests as needed.
7. Write guidance for UX devs to update the UI for the new
explainability predicates.
## UX developer guidance
This section describes the changes to the explainability interface
that affect frontend rendering of GraphRAG Focus events.
### What changed
Edge selection in GraphRAG previously used LLM-based scoring and
reasoning. Each selected edge carried a `tg:reasoning` predicate
with free-text explanation from the LLM. This has been replaced
by a cross-encoder reranker that scores edges against query
concepts. The explainability data now carries structured metadata
instead of free text.
### Removed
- **`tg:reasoning`** is no longer emitted on edge selection
entities in GraphRAG Focus events. UX code that reads
`edge_sel.reasoning` will get an empty string. Remove any
rendering that displays a "Reasoning" or "Reason" field for
Focus edges.
- The **`kg-edge-scoring`**, **`kg-edge-reasoning`**, and
**`kg-edge-selection`** prompts are retired. Any UX that
references these prompt names should be cleaned up.
### Added
Each edge selection entity within a Focus event now has three
new properties:
| RDF predicate | API field | Type | Description |
|---|---|---|---|
| `rdf:type tg:EdgeSelection` | (type check) | — | Each edge selection entity is now explicitly typed |
| `tg:concept` | `edge_sel.concept` | `str` | The query concept that matched this edge |
| `tg:score` | `edge_sel.score` | `float` or `None` | Cross-encoder relevance score (0.01.0) |
The `tg:edge` predicate (RDF-star quoted triple) is unchanged.
### How to render
The recommended rendering for each selected edge in a Focus event:
```
Edge: (subject_label, predicate_label, object_label)
Concept: <concept> Score: <score formatted to 4 decimal places>
```
Scores near 1.0 indicate high relevance; scores near 0.0 indicate
low relevance. UX could use the score to drive visual indicators
such as colour intensity or a relevance bar.
Edges are not returned in score order — they arrive in traversal
order across hops. If the UX wants to display edges ranked by
relevance, sort by `edge_sel.score` descending.
### API classes (Python)
The `EdgeSelection` dataclass in `trustgraph.api.explainability`
has these fields:
```python
@dataclass
class EdgeSelection:
uri: str
edge: Optional[Dict[str, str]] # {"s": ..., "p": ..., "o": ...}
reasoning: str = "" # Legacy, always empty for new traces
concept: str = "" # Query concept that matched
score: Optional[float] = None # Cross-encoder relevance score
```
These are populated when calling
`ExplainabilityClient.fetch_focus_with_edges()` or when parsing
inline provenance triples from the streaming response.
### WebSocket response format
For inline explainability via the streaming WebSocket, Focus events
arrive as `message_type: "explain"` responses. The `explain_triples`
array contains the edge selection triples. The relevant predicates
in wire format are:
```json
{"s": {"t": "i", "i": "<edge_sel_uri>"},
"p": {"t": "i", "i": "https://trustgraph.ai/ns/concept"},
"o": {"t": "l", "v": "flyby event"}}
{"s": {"t": "i", "i": "<edge_sel_uri>"},
"p": {"t": "i", "i": "https://trustgraph.ai/ns/score"},
"o": {"t": "l", "v": "0.9962"}}
```
Note that `tg:score` is transmitted as a string literal and must
be parsed to a float on the client side.
### Exploration event
The Exploration event's `edge_count` field now reports the number
of edges selected by the cross-encoder across all hops (previously
it reported the total number of edges retrieved before filtering).
The `entities` list continues to report the seed entities found
by vector search.

2603
install_trustgraph.sh Normal file

File diff suppressed because it is too large Load diff

View file

@ -95,10 +95,6 @@ class TestGraphRagIntegration:
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-reasoning":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
return PromptResult(
response_type="text",
@ -113,14 +109,22 @@ class TestGraphRagIntegration:
client.prompt.side_effect = mock_prompt
return client
@pytest.fixture
def mock_reranker_client(self):
"""Mock reranker client for cross-encoder edge filtering"""
client = AsyncMock()
client.rerank.return_value = []
return client
@pytest.fixture
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_prompt_client):
mock_triples_client, mock_reranker_client, mock_prompt_client):
"""Create GraphRag instance with mocked dependencies"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
prompt_client=mock_prompt_client,
verbose=True
)
@ -167,8 +171,8 @@ class TestGraphRagIntegration:
# 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query_stream.call_count > 0
# 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis)
assert mock_prompt_client.prompt.call_count == 4
# 4. Should call prompt twice (extract-concepts + synthesis)
assert mock_prompt_client.prompt.call_count == 2
# Verify final response
response, usage = response

View file

@ -63,11 +63,6 @@ class TestGraphRagStreaming:
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
elif prompt_id == "kg-edge-reasoning":
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
@ -88,14 +83,23 @@ class TestGraphRagStreaming:
client.prompt.side_effect = prompt_side_effect
return client
@pytest.fixture
def mock_reranker_client(self):
"""Mock reranker client for cross-encoder edge filtering"""
client = AsyncMock()
client.rerank.return_value = []
return client
@pytest.fixture
def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_streaming_prompt_client):
mock_triples_client, mock_reranker_client,
mock_streaming_prompt_client):
"""Create GraphRag instance with streaming support"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
prompt_client=mock_streaming_prompt_client,
verbose=True
)

View file

@ -46,7 +46,7 @@ class TestGraphRagStreamingProtocol:
client = AsyncMock()
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
@ -63,14 +63,23 @@ class TestGraphRagStreamingProtocol:
client.prompt.side_effect = prompt_side_effect
return client
@pytest.fixture
def mock_reranker_client(self):
"""Mock reranker client for cross-encoder edge filtering"""
client = AsyncMock()
client.rerank.return_value = []
return client
@pytest.fixture
def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client,
mock_triples_client, mock_streaming_prompt_client):
mock_triples_client, mock_reranker_client,
mock_streaming_prompt_client):
"""Create GraphRag instance with mocked dependencies"""
return GraphRag(
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
prompt_client=mock_streaming_prompt_client,
verbose=False
)
@ -327,7 +336,7 @@ class TestStreamingProtocolEdgeCases:
client = AsyncMock()
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
@ -342,10 +351,14 @@ class TestStreamingProtocolEdgeCases:
client.prompt.side_effect = prompt_with_empties
mock_reranker = AsyncMock()
mock_reranker.rerank.return_value = []
rag = GraphRag(
embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])),
graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])),
triples_client=AsyncMock(query=AsyncMock(return_value=[])),
reranker_client=mock_reranker,
prompt_client=client,
verbose=False
)

View file

@ -15,11 +15,20 @@ from openai.types.chat.chat_completion import Choice
from openai.types.completion_usage import CompletionUsage
from trustgraph.model.text_completion.openai.llm import Processor
from trustgraph.model.text_completion.openai.variants import get_variant
from trustgraph.exceptions import TooManyRequests
from trustgraph.base import LlmResult
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse, Error
def _wire_variant(processor):
"""Attach variant methods to a MagicMock processor."""
processor.variant = get_variant("openai")
processor.thinking = "off"
processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor)
processor._extract_content = Processor._extract_content.__get__(processor, Processor)
@pytest.mark.integration
class TestTextCompletionIntegration:
"""Integration tests for OpenAI text completion service coordination"""
@ -66,6 +75,7 @@ class TestTextCompletionIntegration:
# Add the actual generate_content method from Processor class
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
_wire_variant(processor)
return processor
@ -119,6 +129,7 @@ class TestTextCompletionIntegration:
# Add the actual generate_content method
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
_wire_variant(processor)
# Act
result = await processor.generate_content("System prompt", "User prompt")
@ -247,6 +258,7 @@ class TestTextCompletionIntegration:
processor.max_output = processor_config["max_output"]
processor.openai = mock_openai_client
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
_wire_variant(processor)
processors.append(processor)
# Simulate multiple concurrent requests
@ -354,6 +366,7 @@ class TestTextCompletionIntegration:
processor.max_output = 2048
processor.openai = mock_openai_client
processor.generate_content = Processor.generate_content.__get__(processor, Processor)
_wire_variant(processor)
# Act
await processor.generate_content("System prompt", "User prompt")

View file

@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as StreamChoice, ChoiceDelta
from trustgraph.model.text_completion.openai.llm import Processor
from trustgraph.model.text_completion.openai.variants import get_variant
from trustgraph.base import LlmChunk
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
@ -18,6 +19,14 @@ from tests.utils.streaming_assertions import (
)
def _wire_variant(processor):
"""Attach variant methods to a MagicMock processor."""
processor.variant = get_variant("openai")
processor.thinking = "off"
processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor)
processor._extract_content = Processor._extract_content.__get__(processor, Processor)
@pytest.mark.integration
class TestTextCompletionStreaming:
"""Integration tests for Text Completion streaming"""
@ -69,6 +78,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
return processor
@ -190,6 +200,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
# Act
chunks = []
@ -223,6 +234,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
# Act
chunks = []
@ -258,6 +270,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
# Act & Assert
with pytest.raises(Exception) as exc_info:
@ -295,6 +308,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
# Act
chunks = []
@ -318,6 +332,7 @@ class TestTextCompletionStreaming:
processor.generate_content_stream = Processor.generate_content_stream.__get__(
processor, Processor
)
_wire_variant(processor)
system_prompt = "You are an expert."
user_prompt = "Explain quantum physics."

View file

@ -410,3 +410,56 @@ class TestEdgeCases:
assert hosts == ['mixed-host']
assert username is None # Stays None
assert password == 'mixed-pass'
class TestReplicationFactorParamPath:
def test_explicit_kwarg(self):
with patch.dict(os.environ, {}, clear=True):
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=3,
)
assert rf == 3
def test_kwarg_overrides_env(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=3,
)
assert rf == 3
def test_env_fallback_when_kwarg_none(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=None,
)
assert rf == 5
def test_default_when_no_kwarg_no_env(self):
with patch.dict(os.environ, {}, clear=True):
_, _, _, _, rf = resolve_cassandra_config()
assert rf == 1
def test_params_dict_path(self):
with patch.dict(os.environ, {}, clear=True):
params = {'cassandra_replication_factor': 3}
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=params.get('cassandra_replication_factor'),
)
assert rf == 3
def test_params_dict_overrides_env(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
params = {'cassandra_replication_factor': 3}
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=params.get('cassandra_replication_factor'),
)
assert rf == 3
def test_params_dict_missing_falls_to_env(self):
with patch.dict(os.environ, {'CASSANDRA_REPLICATION_FACTOR': '5'}, clear=True):
params = {}
_, _, _, _, rf = resolve_cassandra_config(
replication_factor=params.get('cassandra_replication_factor'),
)
assert rf == 5

View file

@ -195,38 +195,6 @@ class TestPromptClientStreamingCallback:
assert callback.call_args_list[0] == call("test", False)
assert callback.call_args_list[1] == call("", True)
@pytest.mark.asyncio
async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client):
"""Test that kg_prompt correctly passes streaming parameters"""
# Arrange
async def mock_request(request, recipient=None, timeout=600):
if recipient:
responses = [
PromptResponse(text="Answer", object=None, error=None, end_of_stream=False),
PromptResponse(text="", object=None, error=None, end_of_stream=True),
]
for resp in responses:
should_stop = await recipient(resp)
if should_stop:
break
prompt_client.request = mock_request
callback = AsyncMock()
# Act
await prompt_client.kg_prompt(
query="What is machine learning?",
kg=[("subject", "predicate", "object")],
streaming=True,
chunk_callback=callback
)
# Assert
assert callback.call_count == 2
assert callback.call_args_list[0] == call("Answer", False)
assert callback.call_args_list[1] == call("", True)
@pytest.mark.asyncio
async def test_document_prompt_passes_parameters_to_callback(self, prompt_client):
"""Test that document_prompt correctly passes streaming parameters"""

View file

@ -0,0 +1,136 @@
import os
import pytest
from unittest.mock import patch
from trustgraph.base.qdrant_config import (
get_qdrant_defaults,
resolve_qdrant_config,
)
class TestGetQdrantDefaults:
def test_defaults_with_no_env_vars(self):
with patch.dict(os.environ, {}, clear=True):
defaults = get_qdrant_defaults()
assert defaults['url'] == 'http://localhost:6333'
assert defaults['api_key'] is None
assert defaults['replication_factor'] == 1
assert defaults['shard_number'] == 1
def test_defaults_from_env(self):
env = {
'QDRANT_URL': 'http://qdrant:6333',
'QDRANT_API_KEY': 'secret',
'QDRANT_REPLICATION_FACTOR': '3',
'QDRANT_SHARD_NUMBER': '5',
}
with patch.dict(os.environ, env, clear=True):
defaults = get_qdrant_defaults()
assert defaults['url'] == 'http://qdrant:6333'
assert defaults['api_key'] == 'secret'
assert defaults['replication_factor'] == 3
assert defaults['shard_number'] == 5
class TestResolveQdrantConfig:
def test_defaults(self):
with patch.dict(os.environ, {}, clear=True):
url, api_key, rf, sn = resolve_qdrant_config()
assert url == 'http://localhost:6333'
assert api_key is None
assert rf == 1
assert sn == 1
def test_explicit_kwargs(self):
with patch.dict(os.environ, {}, clear=True):
url, api_key, rf, sn = resolve_qdrant_config(
url='http://custom:6333',
api_key='key',
replication_factor=3,
shard_number=5,
)
assert url == 'http://custom:6333'
assert api_key == 'key'
assert rf == 3
assert sn == 5
def test_kwargs_override_env(self):
env = {
'QDRANT_URL': 'http://env:6333',
'QDRANT_REPLICATION_FACTOR': '10',
'QDRANT_SHARD_NUMBER': '10',
}
with patch.dict(os.environ, env, clear=True):
url, _, rf, sn = resolve_qdrant_config(
url='http://explicit:6333',
replication_factor=3,
shard_number=5,
)
assert url == 'http://explicit:6333'
assert rf == 3
assert sn == 5
def test_env_fallback_when_kwargs_none(self):
env = {
'QDRANT_URL': 'http://env:6333',
'QDRANT_REPLICATION_FACTOR': '3',
'QDRANT_SHARD_NUMBER': '5',
}
with patch.dict(os.environ, env, clear=True):
url, _, rf, sn = resolve_qdrant_config()
assert url == 'http://env:6333'
assert rf == 3
assert sn == 5
def test_params_dict_path(self):
with patch.dict(os.environ, {}, clear=True):
params = {
'store_uri': 'http://params:6333',
'api_key': 'pkey',
'qdrant_replication_factor': 3,
'qdrant_shard_number': 5,
}
url, api_key, rf, sn = resolve_qdrant_config(
url=params.get('store_uri'),
api_key=params.get('api_key'),
replication_factor=params.get('qdrant_replication_factor'),
shard_number=params.get('qdrant_shard_number'),
)
assert url == 'http://params:6333'
assert api_key == 'pkey'
assert rf == 3
assert sn == 5
def test_params_dict_overrides_env(self):
env = {
'QDRANT_REPLICATION_FACTOR': '10',
'QDRANT_SHARD_NUMBER': '10',
}
with patch.dict(os.environ, env, clear=True):
params = {
'qdrant_replication_factor': 3,
'qdrant_shard_number': 5,
}
_, _, rf, sn = resolve_qdrant_config(
replication_factor=params.get('qdrant_replication_factor'),
shard_number=params.get('qdrant_shard_number'),
)
assert rf == 3
assert sn == 5
def test_params_dict_missing_falls_to_env(self):
env = {
'QDRANT_REPLICATION_FACTOR': '3',
'QDRANT_SHARD_NUMBER': '5',
}
with patch.dict(os.environ, env, clear=True):
params = {}
_, _, rf, sn = resolve_qdrant_config(
replication_factor=params.get('qdrant_replication_factor'),
shard_number=params.get('qdrant_shard_number'),
)
assert rf == 3
assert sn == 5

View file

@ -0,0 +1,54 @@
"""
Unit tests for trustgraph.bootstrap.initialisers.DefaultFlowStart
Verifies the list/start timeouts are configurable and that the
configured values actually reach the flow-client request calls.
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from trustgraph.bootstrap.initialisers.default_flow_start import (
DefaultFlowStart,
)
def test_default_timeouts():
init = DefaultFlowStart(blueprint="bp")
assert init.list_timeout == 10
assert init.start_timeout == 30
def test_timeout_overrides_are_stored():
init = DefaultFlowStart(blueprint="bp", list_timeout=5, start_timeout=99)
assert init.list_timeout == 5
assert init.start_timeout == 99
@pytest.mark.asyncio
async def test_run_forwards_configured_timeouts():
init = DefaultFlowStart(blueprint="bp", list_timeout=5, start_timeout=99)
# Flow client: list-flows returns no error + empty flow list,
# start-flow returns no error.
flow = MagicMock()
flow.start = AsyncMock()
flow.stop = AsyncMock()
flow.request = AsyncMock(side_effect=[
MagicMock(error=None, flow_ids=[]), # list-flows response
MagicMock(error=None), # start-flow response
])
# Context: workspace "default" exists, hands back our mock flow client.
ctx = MagicMock()
ctx.logger = MagicMock()
ctx.config.keys = AsyncMock(return_value=["default"])
ctx.make_flow_client = MagicMock(return_value=flow)
await init.run(ctx, None, "v1")
calls = flow.request.call_args_list
assert len(calls) == 2
assert calls[0].kwargs["timeout"] == 5
assert calls[1].kwargs["timeout"] == 99

View file

@ -0,0 +1,13 @@
"""Unit tests for trustgraph.bootstrap.initialisers.WorkspaceInit."""
from trustgraph.bootstrap.initialisers.workspace_init import WorkspaceInit
def test_default_iam_timeout():
init = WorkspaceInit()
assert init.iam_timeout == 10
def test_iam_timeout_override_is_stored():
init = WorkspaceInit(iam_timeout=42)
assert init.iam_timeout == 42

View file

@ -129,6 +129,9 @@ class TestBatchTripleQueries:
# 3 queries, alternating results
assert len(result) == 3
# Each result is a (triple, direction) tuple
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_exception_in_one_query_does_not_block_others(self):
@ -153,6 +156,8 @@ class TestBatchTripleQueries:
# 3 queries: 2 succeed, 1 fails → 2 triples
assert len(result) == 2
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_none_results_filtered(self):
@ -176,6 +181,8 @@ class TestBatchTripleQueries:
# 3 queries: 1 returns None, 2 return triples
assert len(result) == 2
for triple, direction in result:
assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O)
@pytest.mark.asyncio
async def test_empty_entities_no_queries(self):
@ -220,6 +227,80 @@ class TestBatchTripleQueries:
assert calls[2].kwargs["p"] is None
assert calls[2].kwargs["o"] == "ent-1"
@pytest.mark.asyncio
async def test_directions_assigned_correctly(self):
"""Each query position should produce the correct direction tag."""
triple = _make_triple("s", "p", "o")
call_count = 0
async def one_triple_each(**kwargs):
nonlocal call_count
call_count += 1
return [triple]
client = AsyncMock()
client.query_stream = one_triple_each
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1"], limit_per_entity=10
)
assert len(result) == 3
# Order matches query order: s-position, p-position, o-position
assert result[0][1] == Query.FROM_S
assert result[1][1] == Query.FROM_P
assert result[2][1] == Query.FROM_O
@pytest.mark.asyncio
async def test_directions_correct_for_multiple_entities(self):
"""Direction tags cycle correctly across multiple entities."""
triple = _make_triple("s", "p", "o")
client = AsyncMock()
client.query_stream = AsyncMock(return_value=[triple])
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1", "e2"], limit_per_entity=10
)
assert len(result) == 6
expected_directions = [
Query.FROM_S, Query.FROM_P, Query.FROM_O,
Query.FROM_S, Query.FROM_P, Query.FROM_O,
]
for (_, direction), expected in zip(result, expected_directions):
assert direction == expected
@pytest.mark.asyncio
async def test_direction_preserved_with_multiple_triples(self):
"""All triples from one query share the same direction."""
t1 = _make_triple("a", "p1", "b")
t2 = _make_triple("a", "p2", "c")
call_count = 0
async def multi_results(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return [t1, t2]
return []
client = AsyncMock()
client.query_stream = multi_results
query = _make_query(triples_client=client)
result = await query.execute_batch_triple_queries(
["e1"], limit_per_entity=10
)
# First query (FROM_S) returns 2 triples, both should be FROM_S
assert len(result) == 2
assert result[0] == (t1, Query.FROM_S)
assert result[1] == (t2, Query.FROM_S)
class TestLRUCacheWithTTL:

View file

@ -49,7 +49,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test successful PDF processing"""
# Mock PDF content
pdf_content = b"fake pdf content"
pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
# Mock PyPDFLoader
@ -88,13 +88,55 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
# Verify triples were sent for each page (provenance)
assert mock_triples_flow.send.call_count == 2
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_rejects_librarian_content_that_is_not_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test rejecting non-PDF content before invoking the PDF loader"""
html_content = b"<html><body>Not found</body></html>"
html_base64 = base64.b64encode(html_content)
mock_metadata = Metadata(id="test-doc")
mock_document = Document(metadata=mock_metadata, document_id="doc-123")
mock_msg = MagicMock()
mock_msg.value.return_value = mock_document
mock_output_flow = AsyncMock()
mock_triples_flow = AsyncMock()
mock_flow = MagicMock(side_effect=lambda name: {
"output": mock_output_flow,
"triples": mock_triples_flow,
}.get(name))
mock_flow.librarian.fetch_document_metadata = AsyncMock(
return_value=MagicMock(kind="application/pdf")
)
mock_flow.librarian.fetch_document_content = AsyncMock(
return_value=html_base64
)
mock_flow.librarian.save_child_document = AsyncMock()
config = {
'id': 'test-pdf-decoder',
'taskgroup': AsyncMock()
}
processor = Processor(**config)
await processor.on_message(mock_msg, None, mock_flow)
mock_pdf_loader_class.assert_not_called()
mock_output_flow.send.assert_not_called()
mock_triples_flow.send.assert_not_called()
mock_flow.librarian.save_child_document.assert_not_called()
@patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of empty PDF"""
pdf_content = b"fake pdf content"
pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock()
@ -126,7 +168,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of unicode content in PDF"""
pdf_content = b"fake pdf content"
pdf_content = b"%PDF-1.7\nfake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
mock_loader = MagicMock()

View file

@ -86,19 +86,19 @@ class TestVerifyJwtEddsa:
def test_valid_jwt_passes(self):
priv, pub = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"sub": "user-1", "default_workspace": "default",
"iat": int(time.time()),
"exp": int(time.time()) + 60,
}
token = sign_jwt(priv, claims)
got = _verify_jwt_eddsa(token, pub)
assert got["sub"] == "user-1"
assert got["workspace"] == "default"
assert got["default_workspace"] == "default"
def test_expired_jwt_rejected(self):
priv, pub = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"sub": "user-1", "default_workspace": "default",
"iat": int(time.time()) - 3600,
"exp": int(time.time()) - 1,
}
@ -110,7 +110,7 @@ class TestVerifyJwtEddsa:
priv_a, _ = make_keypair()
_, pub_b = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"sub": "user-1", "default_workspace": "default",
"iat": int(time.time()),
"exp": int(time.time()) + 60,
}
@ -130,7 +130,7 @@ class TestVerifyJwtEddsa:
# since we expect it to bail before verifying.
header = {"alg": "HS256", "typ": "JWT", "kid": "x"}
payload = {
"sub": "user-1", "workspace": "default",
"sub": "user-1", "default_workspace": "default",
"iat": int(time.time()), "exp": int(time.time()) + 60,
}
h = _b64url(json.dumps(header, separators=(",", ":")).encode())
@ -148,11 +148,11 @@ class TestIdentity:
def test_fields(self):
i = Identity(
handle="u", workspace="w",
handle="u", default_workspace="w",
principal_id="u", source="api-key",
)
assert i.handle == "u"
assert i.workspace == "w"
assert i.default_workspace == "w"
assert i.principal_id == "u"
assert i.source == "api-key"
@ -208,7 +208,7 @@ class TestIamAuthDispatch:
async def test_valid_jwt_resolves_to_identity(self):
priv, pub = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"sub": "user-1", "default_workspace": "default",
"iat": int(time.time()),
"exp": int(time.time()) + 60,
}
@ -221,7 +221,7 @@ class TestIamAuthDispatch:
make_request(f"Bearer {token}")
)
assert ident.handle == "user-1"
assert ident.workspace == "default"
assert ident.default_workspace == "default"
assert ident.principal_id == "user-1"
assert ident.source == "jwt"
@ -231,7 +231,7 @@ class TestIamAuthDispatch:
# must not validate — even ones that would otherwise pass.
priv, _ = make_keypair()
claims = {
"sub": "user-1", "workspace": "default",
"sub": "user-1", "default_workspace": "default",
"iat": int(time.time()), "exp": int(time.time()) + 60,
}
token = sign_jwt(priv, claims)
@ -259,7 +259,7 @@ class TestIamAuthDispatch:
make_request("Bearer tg_testkey")
)
assert ident.handle == "user-xyz"
assert ident.workspace == "default"
assert ident.default_workspace == "default"
assert ident.principal_id == "user-xyz"
assert ident.source == "api-key"
@ -338,9 +338,9 @@ class TestAuthorise:
decision for the regime's TTL (clamped above), and raises 403
on deny / 401 on regime error (fail closed)."""
def _make_identity(self, handle="u-1", workspace="default"):
def _make_identity(self, handle="u-1", default_workspace="default"):
return Identity(
handle=handle, workspace=workspace,
handle=handle, default_workspace=default_workspace,
principal_id=handle, source="api-key",
)

View file

@ -25,11 +25,11 @@ from trustgraph.gateway.capabilities import (
class _Identity:
"""Stand-in for auth.Identity — under the IAM contract it has
just ``handle``, ``workspace``, ``principal_id``, ``source``."""
just ``handle``, ``default_workspace``, ``principal_id``, ``source``."""
def __init__(self, handle="user-1", workspace="default"):
def __init__(self, handle="user-1", default_workspace="default"):
self.handle = handle
self.workspace = workspace
self.default_workspace = default_workspace
self.principal_id = handle
self.source = "api-key"
@ -105,14 +105,14 @@ class TestEnforceWorkspace:
async def test_default_fills_from_identity(self):
data = {"operation": "x"}
auth = _allow_auth()
await enforce_workspace(data, _Identity(workspace="default"), auth)
await enforce_workspace(data, _Identity(default_workspace="default"), auth)
assert data["workspace"] == "default"
@pytest.mark.asyncio
async def test_caller_supplied_workspace_kept(self):
data = {"workspace": "acme", "operation": "x"}
auth = _allow_auth()
await enforce_workspace(data, _Identity(workspace="default"), auth)
await enforce_workspace(data, _Identity(default_workspace="default"), auth)
assert data["workspace"] == "acme"
@pytest.mark.asyncio

View file

@ -28,7 +28,7 @@ TEST_CAP = "graph:write"
def _valid_identity():
return Identity(
handle="test-user",
workspace="default",
default_workspace="default",
principal_id="test-user",
source="api-key",
)

View file

@ -32,7 +32,7 @@ class TestAuthenticateAnonymous:
)
assert resp.error is None
assert resp.resolved_user_id == "anon"
assert resp.resolved_workspace == "ws"
assert resp.resolved_default_workspace == "ws"
assert "admin" in list(resp.resolved_roles)
@pytest.mark.asyncio
@ -44,7 +44,7 @@ class TestAuthenticateAnonymous:
_make_request(operation="authenticate-anonymous")
)
assert resp.resolved_user_id == "dev-user"
assert resp.resolved_workspace == "dev-ws"
assert resp.resolved_default_workspace == "dev-ws"
class TestResolveApiKey:
@ -57,7 +57,7 @@ class TestResolveApiKey:
)
assert resp.error is None
assert resp.resolved_user_id == "anonymous"
assert resp.resolved_workspace == "default"
assert resp.resolved_default_workspace == "default"
class TestAuthorise:

View file

@ -107,6 +107,7 @@ class TestGraphRagDagStructure:
embeddings_client = AsyncMock()
graph_embeddings_client = AsyncMock()
triples_client = AsyncMock()
reranker_client = AsyncMock()
embeddings_client.embed.return_value = [[0.1, 0.2]]
graph_embeddings_client.query.return_value = [
@ -121,27 +122,22 @@ class TestGraphRagDagStructure:
]
triples_client.query.return_value = []
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.95
reranker_client.rerank.return_value = [result]
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return PromptResult(response_type="text", text="concept")
elif template_id == "kg-edge-scoring":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "score": 10} for e in edges],
)
elif template_id == "kg-edge-reasoning":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
)
elif template_id == "kg-synthesis":
return PromptResult(response_type="text", text="Answer.")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
return (prompt_client, embeddings_client, graph_embeddings_client,
triples_client, reranker_client)
@pytest.mark.asyncio
async def test_dag_chain(self, mock_clients):
@ -152,7 +148,7 @@ class TestGraphRagDagStructure:
events.append({"explain_id": explain_id, "triples": triples})
await rag.query(
query="test", explain_callback=explain_cb, edge_score_limit=0,
query="test", explain_callback=explain_cb,
)
dag = _collect_events(events)

View file

@ -333,8 +333,8 @@ class TestUnifiedTableQueries:
"""Test queries against the unified rows table"""
@pytest.mark.asyncio
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
async def test_query_with_index_match(self, mock_async_execute):
@patch('trustgraph.query.rows.cassandra.service.async_execute_paged', new_callable=AsyncMock)
async def test_query_with_index_match(self, mock_async_execute_paged):
"""Test query execution with matching index"""
processor = MagicMock()
processor.session = MagicMock()
@ -344,10 +344,10 @@ class TestUnifiedTableQueries:
processor.find_matching_index = Processor.find_matching_index.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock async_execute to return test data
# Mock async_execute_paged to return test data (list of pages)
mock_row = MagicMock()
mock_row.data = {"id": "123", "name": "Test Product", "category": "electronics"}
mock_async_execute.return_value = [mock_row]
mock_async_execute_paged.return_value = [[mock_row]]
schema = RowSchema(
name="products",
@ -370,10 +370,10 @@ class TestUnifiedTableQueries:
# Verify Cassandra was connected and queried
processor.connect_cassandra.assert_called_once()
mock_async_execute.assert_called_once()
mock_async_execute_paged.assert_called_once()
# Verify query structure - should query unified rows table
call_args = mock_async_execute.call_args
call_args = mock_async_execute_paged.call_args
query = call_args[0][1]
params = call_args[0][2]
@ -394,8 +394,8 @@ class TestUnifiedTableQueries:
assert results[0]["category"] == "electronics"
@pytest.mark.asyncio
@patch('trustgraph.query.rows.cassandra.service.async_execute', new_callable=AsyncMock)
async def test_query_without_index_match(self, mock_async_execute):
@patch('trustgraph.query.rows.cassandra.service.async_scan', new_callable=AsyncMock)
async def test_query_without_index_match(self, mock_async_scan):
"""Test query execution without matching index (scan mode)"""
processor = MagicMock()
processor.session = MagicMock()
@ -406,12 +406,10 @@ class TestUnifiedTableQueries:
processor._matches_filters = Processor._matches_filters.__get__(processor, Processor)
processor.query_cassandra = Processor.query_cassandra.__get__(processor, Processor)
# Mock async_execute to return test data
# Mock async_scan to return filtered test data
mock_row1 = MagicMock()
mock_row1.data = {"id": "1", "name": "Product A", "price": "100"}
mock_row2 = MagicMock()
mock_row2.data = {"id": "2", "name": "Product B", "price": "200"}
mock_async_execute.return_value = [mock_row1, mock_row2]
mock_async_scan.return_value = [mock_row1]
schema = RowSchema(
name="products",
@ -432,13 +430,16 @@ class TestUnifiedTableQueries:
limit=10
)
# Query should use ALLOW FILTERING for scan
call_args = mock_async_execute.call_args
# Verify async_scan was called
mock_async_scan.assert_called_once()
# Verify query structure
call_args = mock_async_scan.call_args
query = call_args[0][1]
assert "ALLOW FILTERING" in query
# Should post-filter results
# Should return filtered results
assert len(results) == 1
assert results[0]["name"] == "Product A"

View file

@ -259,6 +259,8 @@ class TestGraphEmbeddingsNullProtection:
proc.collection_exists = MagicMock(return_value=True)
proc._cache_lock = asyncio.Lock()
proc._known_collections = set()
proc.replication_factor = 1
proc.shard_number = 1
msg = MagicMock()
msg.metadata.collection = "graphs"

View file

@ -101,27 +101,27 @@ class TestQuery:
assert query.rag == mock_rag
assert query.collection == "test_collection"
assert query.verbose is False
assert query.doc_limit == 20 # Default value
assert query.fetch_limit == 20 # Default value
def test_query_initialization_with_custom_doc_limit(self):
"""Test Query initialization with custom doc_limit"""
def test_query_initialization_with_custom_fetch_limit(self):
"""Test Query initialization with custom fetch_limit"""
# Create mock DocumentRag
mock_rag = MagicMock()
# Initialize Query with custom doc_limit
# Initialize Query with custom fetch_limit
query = Query(
rag=mock_rag,
workspace="test_workspace",
collection="custom_collection",
verbose=True,
doc_limit=50
fetch_limit=50
)
# Verify initialization
assert query.rag == mock_rag
assert query.collection == "custom_collection"
assert query.verbose is True
assert query.doc_limit == 50
assert query.fetch_limit == 50
@pytest.mark.asyncio
async def test_extract_concepts(self):
@ -224,7 +224,7 @@ class TestQuery:
workspace="test_workspace",
collection="test_collection",
verbose=False,
doc_limit=15
fetch_limit=15
)
# Call get_docs with concepts list
@ -377,7 +377,7 @@ class TestQuery:
workspace="test_workspace",
collection="test_collection",
verbose=True,
doc_limit=5
fetch_limit=5
)
# Call get_docs with concepts
@ -615,7 +615,7 @@ class TestQuery:
workspace="test_workspace",
collection="test_collection",
verbose=False,
doc_limit=10
fetch_limit=10
)
docs, chunk_ids = await query.get_docs(["concept A", "concept B"])

View file

@ -0,0 +1,92 @@
from trustgraph.retrieval.document_rag.rerank import (
RerankCandidate, normalize_candidate_scores, mmr_select,
_pair_diversity_penalty
)
def candidate(index, chunk_id, text, score):
return RerankCandidate(
index=index,
chunk_id=chunk_id,
text=text,
reranker_score=score,
)
def test_normalize_candidate_scores_min_max_scales_raw_scores():
candidates = [
candidate(0, "a", "alpha", -2.0),
candidate(1, "b", "beta", 0.0),
candidate(2, "c", "gamma", 4.0),
]
normalized = normalize_candidate_scores(candidates)
assert normalized[0].normalized_score == 0.0
assert normalized[1].normalized_score == 1.0 / 3.0
assert normalized[2].normalized_score == 1.0
def test_normalize_candidate_scores_handles_equal_scores():
candidates = [
candidate(0, "a", "alpha", 3.0),
candidate(1, "b", "beta", 3.0),
candidate(2, "c", "gamma", 3.0),
]
normalized = normalize_candidate_scores(candidates)
assert [c.normalized_score for c in normalized] == [0.5, 0.5, 0.5]
def test_mmr_select_limits_results():
candidates = [
candidate(0, "a", "alpha policy", 0.9),
candidate(1, "b", "beta refund", 0.8),
candidate(2, "c", "gamma shipping", 0.7),
]
selected = mmr_select(candidates, limit=2)
assert len(selected) == 2
def test_mmr_select_prefers_highest_reranker_score_first():
candidates = [
candidate(0, "a", "weakly relevant text", 0.1),
candidate(1, "b", "strongly relevant answer", 10.0),
candidate(2, "c", "medium relevant text", 5.0),
]
selected = mmr_select(candidates, limit=1)
assert selected[0].chunk_id == "b"
def test_mmr_select_penalizes_near_duplicate_chunks():
candidates = [
candidate(0, "a", "apple banana fruit return policy", 1.00),
candidate(1, "b", "apple banana fruit return policy duplicate", 0.95),
candidate(2, "c", "engine motor vehicle warranty", 0.90),
]
selected = mmr_select(
candidates,
limit=2,
lambda_mult=0.2,
token_overlap_weight=1.0,
)
assert [c.chunk_id for c in selected] == ["a", "c"]
def test_pair_diversity_penalty_is_clamped():
left = candidate(0, "a", "same same same", 1.0)
right = candidate(1, "b", "same same same", 0.9)
penalty = _pair_diversity_penalty(
left,
right,
token_overlap_weight=10.0,
)
assert penalty == 1.0

View file

@ -0,0 +1,550 @@
"""
Tests for the optional cross-encoder reranking pass in DocumentRag.query().
Two behaviours are covered:
1. No-op: when no reranker_client is wired (the default), query() must feed
the LLM the exact same chunks, in the same order, that retrieval produced
- byte-identical to the pre-reranker behaviour - and must NOT emit a
chunk-selection provenance event.
2. Rerank: when a reranker_client is wired, the retrieved chunks are reordered
and truncated according to the reranker's results, the LLM receives the
reranked top-N, and a tg:ChunkSelection (focus) provenance event is emitted
carrying the per-surviving-chunk scores and chunk references.
These are pure orchestration tests - the reranker is a stub, so there is no
torch / network dependency.
"""
import pytest
from unittest.mock import AsyncMock
from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.base import PromptResult
from trustgraph.schema import RerankerResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_WAS_DERIVED_FROM,
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_FOCUS, TG_SYNTHESIS,
TG_CHUNK_SELECTION, TG_SELECTED_CHUNK, TG_SCORE, TG_DOCUMENT,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def find_triple(triples, predicate, subject=None):
for t in triples:
if t.p.iri == predicate:
if subject is None or t.s.iri == subject:
return t
return None
def find_triples(triples, predicate, subject=None):
return [
t for t in triples
if t.p.iri == predicate
and (subject is None or t.s.iri == subject)
]
def has_type(triples, subject, rdf_type):
return any(
t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type
for t in triples
)
def derived_from(triples, subject):
t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject)
return t.o.iri if t else None
@dataclass
class ChunkMatch:
"""Mimics the result from doc_embeddings_client.query()."""
chunk_id: str
# ---------------------------------------------------------------------------
# Fixtures: three retrievable chunks
# ---------------------------------------------------------------------------
CHUNK_A = "urn:chunk:policy-doc-1:chunk-0"
CHUNK_B = "urn:chunk:policy-doc-1:chunk-1"
CHUNK_C = "urn:chunk:policy-doc-1:chunk-2"
CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase."
CHUNK_B_CONTENT = "Our stores are open from 9am to 5pm on weekdays."
CHUNK_C_CONTENT = "Refunds are processed to the original payment method."
# Retrieval (post-dedupe) order is A, B, C.
ORDERED_CONTENT = [CHUNK_A_CONTENT, CHUNK_B_CONTENT, CHUNK_C_CONTENT]
ORDERED_CHUNK_IDS = [CHUNK_A, CHUNK_B, CHUNK_C]
def build_mock_clients():
"""
Build mock subsidiary clients for a document-rag query returning three
distinct chunks (A, B, C) in that order.
"""
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
doc_embeddings_client = AsyncMock()
fetch_chunk = AsyncMock()
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return PromptResult(response_type="text", text="return policy\nrefund")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
# Each concept query returns the same three chunks; dedupe keeps A, B, C.
doc_embeddings_client.query.return_value = [
ChunkMatch(chunk_id=CHUNK_A),
ChunkMatch(chunk_id=CHUNK_B),
ChunkMatch(chunk_id=CHUNK_C),
]
async def mock_fetch(chunk_id):
return {
CHUNK_A: CHUNK_A_CONTENT,
CHUNK_B: CHUNK_B_CONTENT,
CHUNK_C: CHUNK_C_CONTENT,
}[chunk_id]
fetch_chunk.side_effect = mock_fetch
prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="Items can be returned within 30 days for a full refund.",
)
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
class StubReranker:
"""
Stub reranker_client mirroring RerankerClient.rerank(): returns a fixed,
pre-sorted, truncated list of RerankerResult - exactly the contract the
flashrank service guarantees (sorted desc by score, truncated to limit).
"""
def __init__(self, results):
self._results = results
self.calls = []
async def rerank(self, queries, documents, limit=10, timeout=300):
self.calls.append(
{"queries": queries, "documents": documents, "limit": limit}
)
return self._results
# ---------------------------------------------------------------------------
# 1. No-op: reranker_client=None must not change anything
# ---------------------------------------------------------------------------
class TestRerankNoOp:
@pytest.mark.asyncio
async def test_documents_passed_to_llm_are_unchanged(self):
"""
With no reranker wired, document_prompt must receive the retrieved
chunks in the original order and length.
"""
clients = build_mock_clients()
rag = DocumentRag(*clients) # reranker_client defaults to None
await rag.query(query="What is the return policy?")
call = rag.prompt_client.document_prompt.call_args
passed_docs = call.kwargs["documents"]
assert passed_docs == ORDERED_CONTENT
@pytest.mark.asyncio
async def test_no_chunk_selection_event_emitted(self):
"""
Without a reranker, the provenance chain is the original 4 stages:
question, grounding, exploration, synthesis - no focus stage.
"""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
assert len(events) == 4
types = [
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_SYNTHESIS,
]
for i, expected in enumerate(types):
assert has_type(events[i]["triples"], events[i]["explain_id"], expected)
# No chunk-selection entity anywhere.
for e in events:
assert not any(
t.o.iri == TG_CHUNK_SELECTION
for t in e["triples"]
if t.p.iri == RDF_TYPE
)
@pytest.mark.asyncio
async def test_synthesis_derives_from_exploration_when_no_rerank(self):
"""
No-op lineage is unchanged: synthesis derives from exploration
(there is no focus stage). Guards the conditional synthesis parent.
"""
clients = build_mock_clients()
rag = DocumentRag(*clients)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
# events: question, grounding, exploration, synthesis
exp_uri = events[2]["explain_id"]
syn_event = events[3]
assert derived_from(syn_event["triples"], syn_event["explain_id"]) == exp_uri
# ---------------------------------------------------------------------------
# 2. Rerank: reorder + truncate + provenance
# ---------------------------------------------------------------------------
class TestRerankActive:
def _reranker_keeping_C_then_A(self):
# Reranker says chunk index 2 (C) is best, then index 0 (A); B dropped.
# Pre-sorted desc by score and truncated to limit, per the contract.
return StubReranker([
RerankerResult(document_id="2", query_id="0", score=0.95),
RerankerResult(document_id="0", query_id="0", score=0.42),
])
@pytest.mark.asyncio
async def test_documents_reordered_and_truncated(self):
clients = build_mock_clients()
reranker = self._reranker_keeping_C_then_A()
rag = DocumentRag(*clients, reranker_client=reranker)
await rag.query(query="What is the return policy?")
call = rag.prompt_client.document_prompt.call_args
passed_docs = call.kwargs["documents"]
assert passed_docs == [CHUNK_C_CONTENT, CHUNK_A_CONTENT]
@pytest.mark.asyncio
async def test_reranker_called_with_single_query_and_all_docs(self):
clients = build_mock_clients()
reranker = self._reranker_keeping_C_then_A()
rag = DocumentRag(*clients, reranker_client=reranker)
await rag.query(query="What is the return policy?", doc_limit=2)
assert len(reranker.calls) == 1
c = reranker.calls[0]
assert c["queries"] == [{"id": "0", "text": "What is the return policy?"}]
assert c["documents"] == [
{"id": "0", "text": CHUNK_A_CONTENT},
{"id": "1", "text": CHUNK_B_CONTENT},
{"id": "2", "text": CHUNK_C_CONTENT},
]
# The rerank narrows down to the final doc_limit, NOT fetch_limit
# (fetch_limit is the over-fetched candidate pool size).
assert c["limit"] == 2
@pytest.mark.asyncio
async def test_explicit_fetch_limit_over_fetches_then_narrows(self):
"""
Semantic guard for the value of reranking AND the maintainer's two-limit
contract: an explicit fetch_limit makes retrieval OVER-FETCH a wider
candidate pool so the cross-encoder can surface chunks the bi-encoder
ranked outside the final doc_limit, then the rerank narrows the pool back
down to doc_limit. The fetch_limit is honoured directly (caller controls
how hard the reranker works), not overridden by any heuristic.
"""
clients = build_mock_clients()
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
reranker = self._reranker_keeping_C_then_A()
# Candidate pool (fetch_limit=60) >> final doc_limit (6).
rag = DocumentRag(*clients, reranker_client=reranker)
await rag.query(
query="What is the return policy?", doc_limit=6, fetch_limit=60,
)
# Over-fetch: the embeddings store is queried with the fetch_limit
# budget (60 // 2 concept-vectors = 30 per concept), NOT the doc_limit
# budget (6 // 2 = 3). This is the bug guard.
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
assert q_limit == 30
# Narrow: the rerank keeps the final doc_limit (6), not fetch_limit.
assert reranker.calls[0]["limit"] == 6
@pytest.mark.asyncio
async def test_default_fetch_limit_derives_overfetch_from_doc_limit(self):
"""
With no fetch_limit passed to query(), the candidate pool falls back to
the OVERFETCH_FACTOR x doc_limit heuristic, so over-fetch scales with
doc_limit and reranking keeps its recall benefit out of the box.
"""
clients = build_mock_clients()
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
reranker = self._reranker_keeping_C_then_A()
# No fetch_limit -> heuristic default.
rag = DocumentRag(*clients, reranker_client=reranker)
await rag.query(query="What is the return policy?", doc_limit=20)
# fetch = 3 x 20 = 60 -> 60 // 2 concept-vectors = 30 per concept.
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
assert q_limit == 30
# Rerank narrows to the final doc_limit (20).
assert reranker.calls[0]["limit"] == 20
@pytest.mark.asyncio
async def test_fetch_limit_floored_at_doc_limit(self):
"""
A fetch_limit below doc_limit is floored up to doc_limit: retrieval must
never fetch fewer candidates than the rerank is asked to keep, else the
prompt could not be filled.
"""
clients = build_mock_clients()
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
reranker = self._reranker_keeping_C_then_A()
rag = DocumentRag(*clients, reranker_client=reranker)
await rag.query(
query="What is the return policy?", doc_limit=10, fetch_limit=4,
)
# fetch = max(4, 10) = 10 -> 10 // 2 concept-vectors = 5 per concept.
q_limit = doc_embeddings_client.query.call_args.kwargs["limit"]
assert q_limit == 5
assert reranker.calls[0]["limit"] == 10
@pytest.mark.asyncio
async def test_chunk_selection_event_emitted(self):
clients = build_mock_clients()
reranker = self._reranker_keeping_C_then_A()
rag = DocumentRag(*clients, reranker_client=reranker)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
# Now 5 stages: question, grounding, exploration, focus, synthesis.
assert len(events) == 5
ordered_types = [
TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_FOCUS, TG_SYNTHESIS,
]
for i, expected in enumerate(ordered_types):
assert has_type(events[i]["triples"], events[i]["explain_id"], expected)
@pytest.mark.asyncio
async def test_chunk_selection_carries_scores_and_chunk_refs(self):
clients = build_mock_clients()
reranker = self._reranker_keeping_C_then_A()
rag = DocumentRag(*clients, reranker_client=reranker)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
focus_event = events[3]
foc_uri = focus_event["explain_id"]
triples = focus_event["triples"]
# focus is derived from exploration
exp_uri = events[2]["explain_id"]
assert derived_from(triples, foc_uri) == exp_uri
# Two ChunkSelection sub-entities, linked from focus.
sel_links = find_triples(triples, TG_SELECTED_CHUNK, foc_uri)
assert len(sel_links) == 2
# Each selection has a ChunkSelection type, a chunk document ref and a score.
chunk_refs = set()
scores = set()
for link in sel_links:
sel_uri = link.o.iri
assert has_type(triples, sel_uri, TG_CHUNK_SELECTION)
doc_ref = find_triple(triples, TG_DOCUMENT, sel_uri)
assert doc_ref is not None
chunk_refs.add(doc_ref.o.iri)
score_t = find_triple(triples, TG_SCORE, sel_uri)
assert score_t is not None
scores.add(score_t.o.value)
# Surviving chunks are C and A (B dropped), with the reranker scores.
assert chunk_refs == {CHUNK_C, CHUNK_A}
assert scores == {"0.95", "0.42"}
@pytest.mark.asyncio
async def test_all_focus_triples_in_retrieval_graph(self):
clients = build_mock_clients()
reranker = self._reranker_keeping_C_then_A()
rag = DocumentRag(*clients, reranker_client=reranker)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
explain_callback=explain_callback,
)
for t in events[3]["triples"]:
assert t.g == "urn:graph:retrieval"
@pytest.mark.asyncio
async def test_synthesis_derives_from_focus_when_reranking(self):
"""
When reranking runs, synthesis must derive from the focus node (the
reranked chunks actually fed to the LLM), mirroring GraphRAG - not from
exploration, which would leave focus as a dangling branch and
misrepresent what fed the answer.
"""
clients = build_mock_clients()
reranker = self._reranker_keeping_C_then_A()
rag = DocumentRag(*clients, reranker_client=reranker)
events = []
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
await rag.query(
query="What is the return policy?",
doc_limit=2,
explain_callback=explain_callback,
)
# events: question, grounding, exploration, focus, synthesis
foc_uri = events[3]["explain_id"]
syn_event = events[4]
assert derived_from(syn_event["triples"], syn_event["explain_id"]) == foc_uri
@pytest.mark.asyncio
async def test_empty_docs_skips_reranker(self):
"""If retrieval returns no chunks, the reranker is never called."""
clients = build_mock_clients()
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
doc_embeddings_client.query.return_value = [] # no matches
reranker = self._reranker_keeping_C_then_A()
rag = DocumentRag(*clients, reranker_client=reranker)
await rag.query(query="What is the return policy?")
assert reranker.calls == []
# ---------------------------------------------------------------------------
# 3. Diversity selection: optional MMR after cross-encoder scoring
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_diversity_mode_scores_full_candidate_pool_before_selecting(self):
"""
With diversity selection enabled, the cross-encoder should score the full
fetched candidate pool before MMR narrows it down to doc_limit.
"""
clients = build_mock_clients()
reranker = StubReranker([
RerankerResult(document_id="0", query_id="0", score=1.00),
RerankerResult(document_id="1", query_id="0", score=0.95),
RerankerResult(document_id="2", query_id="0", score=0.90),
])
rag = DocumentRag(
*clients,
reranker_client=reranker,
rerank_diversity_mode="mmr",
)
await rag.query(query="What is the return policy?", doc_limit=2)
assert reranker.calls[0]["limit"] == len(ORDERED_CONTENT)
call = rag.prompt_client.document_prompt.call_args
passed_docs = call.kwargs["documents"]
assert len(passed_docs) == 2
@pytest.mark.asyncio
async def test_diversity_mode_selects_less_redundant_context_set(self):
"""
MMR should use cross-encoder scores as relevance while penalizing redundant
chunks, so a slightly lower-scored but less redundant chunk can be selected.
"""
clients = build_mock_clients()
prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients
duplicate_a = "apple banana fruit return policy"
duplicate_b = "apple banana fruit return policy duplicate"
diverse_c = "engine motor vehicle warranty"
async def mock_fetch(chunk_id):
return {
CHUNK_A: duplicate_a,
CHUNK_B: duplicate_b,
CHUNK_C: diverse_c,
}[chunk_id]
fetch_chunk.side_effect = mock_fetch
reranker = StubReranker([
RerankerResult(document_id="0", query_id="0", score=1.00),
RerankerResult(document_id="1", query_id="0", score=0.95),
RerankerResult(document_id="2", query_id="0", score=0.90),
])
rag = DocumentRag(
*clients,
reranker_client=reranker,
rerank_diversity_mode="mmr",
rerank_diversity_lambda=0.2,
)
await rag.query(query="What is the return policy?", doc_limit=2)
call = rag.prompt_client.document_prompt.call_args
passed_docs = call.kwargs["documents"]
assert passed_docs == [duplicate_a, diverse_c]

View file

@ -0,0 +1,89 @@
"""
Cross-layer wiring contract for the Document-RAG reranker (issue #878).
The Document-RAG processor registers a ``RerankerClientSpec`` for the
``reranker-request`` / ``reranker-response`` roles (see
``retrieval/document_rag/rag.py``). At flow construction every spec runs
``spec.add(flow, processor, definition)``, and ``RequestResponseSpec.add``
resolves its topics via ``definition["topics"][name]`` - which raises
``KeyError`` if the flow blueprint does not provide those topics.
This means the monorepo code change is only safe to deploy together with the
companion ``trustgraph-templates`` change that wires ``reranker-request`` /
``reranker-response`` into the Document-RAG flow (mirroring what templates
PR #279 did for GraphRAG via ``graph-store.jsonnet``). These tests pin that
contract from the monorepo side:
* with the reranker topics present (as the updated templates compile them),
the spec binds cleanly and registers the client;
* without them (the pre-companion blueprint), construction fails fast with a
KeyError naming the missing role - documenting exactly why the templates
change is required.
No broker/network: the pub/sub backend is mocked (topics are bound at add()
time, connections happen later at start()).
"""
import pytest
from unittest.mock import MagicMock
from trustgraph.base import RerankerClientSpec
def _flow():
f = MagicMock()
f.workspace = "ws"
f.name = "document-rag"
f.id = "proc1"
f.consumer = {}
return f
def _processor():
p = MagicMock()
p.pubsub = MagicMock()
p.id = "proc1"
p.taskgroup = MagicMock()
return p
def _spec():
return RerankerClientSpec(
request_name="reranker-request",
response_name="reranker-response",
)
# Topics dict as the UPDATED document-store.jsonnet compiles them
# (verified by compiling the template: reranker-request -> request:tg:reranker:{workspace}:{id}).
DEFINITION_WITH_RERANKER = {
"topics": {
"request": "request:tg:document-rag:ws:id",
"response": "response:tg:document-rag:ws:id",
"reranker-request": "request:tg:reranker:ws:id",
"reranker-response": "response:tg:reranker:ws:id",
}
}
# Pre-companion blueprint: no reranker topics (document-rag before the templates change).
DEFINITION_WITHOUT_RERANKER = {
"topics": {
"request": "request:tg:document-rag:ws:id",
"response": "response:tg:document-rag:ws:id",
}
}
def test_reranker_client_binds_when_flow_provides_topics():
flow = _flow()
_spec().add(flow, _processor(), DEFINITION_WITH_RERANKER)
# The client consumer is registered against the reranker role.
assert "reranker-request" in flow.consumer
def test_reranker_client_keyerrors_without_companion_template_topics():
with pytest.raises(KeyError) as exc:
_spec().add(_flow(), _processor(), DEFINITION_WITHOUT_RERANKER)
# Fails fast naming the missing role -> the trustgraph-templates companion
# change (wire reranker-request/response into the document-rag flow) is required.
assert "reranker-request" in str(exc.value)

View file

@ -66,6 +66,7 @@ class TestDocumentRagService:
workspace=ANY, # Workspace comes from flow.workspace (mock)
collection="test_coll_1", # Must be from message, not hardcoded default
doc_limit=5,
fetch_limit=0, # Unset -> core derives the candidate pool
explain_callback=ANY, # Explainability callback is always passed
save_answer_callback=ANY, # Librarian save callback is always passed
)

View file

@ -15,54 +15,52 @@ class TestGraphRag:
def test_graph_rag_initialization_with_defaults(self):
"""Test GraphRag initialization with default verbose setting"""
# Create mock clients
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
mock_reranker_client = MagicMock()
# Initialize GraphRag
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.verbose is False # Default value
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
# Create mock clients
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
# Initialize GraphRag with verbose=True
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
verbose=True
reranker_client=mock_reranker_client,
)
# Verify initialization
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.reranker_client == mock_reranker_client
assert graph_rag.verbose is False
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
def test_graph_rag_initialization_with_verbose(self):
"""Test GraphRag initialization with verbose enabled"""
mock_prompt_client = MagicMock()
mock_embeddings_client = MagicMock()
mock_graph_embeddings_client = MagicMock()
mock_triples_client = MagicMock()
mock_reranker_client = MagicMock()
graph_rag = GraphRag(
prompt_client=mock_prompt_client,
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
reranker_client=mock_reranker_client,
verbose=True,
)
assert graph_rag.prompt_client == mock_prompt_client
assert graph_rag.embeddings_client == mock_embeddings_client
assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client
assert graph_rag.triples_client == mock_triples_client
assert graph_rag.reranker_client == mock_reranker_client
assert graph_rag.verbose is True
# Verify label_cache is an LRUCacheWithTTL instance
from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL
assert isinstance(graph_rag.label_cache, LRUCacheWithTTL)
@ -365,244 +363,162 @@ class TestQuery:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_never_passes_workspace(self):
"""Verify follow_edges never passes workspace to query_stream."""
async def test_hop_and_filter_never_passes_workspace(self):
"""Verify hop_and_filter never passes workspace to query_stream."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
mock_triple = MagicMock()
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
mock_triple.s = "e1"
mock_triple.p = "p1"
mock_triple.o = "o1"
mock_triples_client.query_stream.return_value = [mock_triple]
mock_triples_client.query.return_value = []
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.9
mock_reranker_client.rerank.return_value = [result]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
triple_limit=10,
)
subgraph = set()
await query.follow_edges("e1", subgraph, path_length=1)
await query.hop_and_filter(["e1"], ["concept"])
for c in mock_triples_client.query_stream.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self):
"""Test Query.follow_edges method basic triple discovery"""
async def test_hop_and_filter_basic_functionality(self):
"""Test hop_and_filter retrieves edges and scores them with reranker."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
mock_triple1 = MagicMock()
mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1"
mock_triple = MagicMock()
mock_triple.s = "entity1"
mock_triple.p = "predicate1"
mock_triple.o = "object1"
mock_triples_client.query_stream.return_value = [mock_triple]
mock_triples_client.query.return_value = []
mock_triple2 = MagicMock()
mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2"
mock_triple3 = MagicMock()
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
mock_triples_client.query_stream.side_effect = [
[mock_triple1], # s=ent
[mock_triple2], # p=ent
[mock_triple3], # o=ent
]
result = MagicMock()
result.document_id = "0"
result.query_id = "0"
result.score = 0.95
mock_reranker_client.rerank.return_value = [result]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
triple_limit=10,
edge_limit=25,
)
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=1)
assert mock_triples_client.query_stream.call_count == 3
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
collection="test_collection", batch_size=20, g=""
)
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
collection="test_collection", batch_size=20, g=""
selected, uri_map, edge_meta = await query.hop_and_filter(
["entity1"], ["test concept"],
)
expected_subgraph = {
("entity1", "predicate1", "object1"),
("subject2", "entity1", "object2"),
("subject3", "predicate3", "entity1")
}
assert subgraph == expected_subgraph
assert len(selected) == 1
assert len(uri_map) == 1
assert len(edge_meta) == 1
mock_reranker_client.rerank.assert_called_once()
call_kwargs = mock_reranker_client.rerank.call_args
assert call_kwargs.kwargs["limit"] == 25
@pytest.mark.asyncio
async def test_follow_edges_with_path_length_zero(self):
"""Test Query.follow_edges method with path_length=0"""
async def test_hop_and_filter_with_empty_frontier(self):
"""Test hop_and_filter with no seed entities returns empty."""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
)
selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"])
assert selected == []
assert uri_map == {}
assert edge_meta == {}
@pytest.mark.asyncio
async def test_hop_and_filter_filters_label_triples(self):
"""Test hop_and_filter skips rdfs:label edges."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_rag.reranker_client = mock_reranker_client
mock_rag.label_cache = MagicMock()
mock_rag.label_cache.get.return_value = None
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False
)
label_triple = MagicMock()
label_triple.s = "entity1"
label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label"
label_triple.o = "Entity One"
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=0)
mock_triples_client.query_stream.assert_not_called()
assert subgraph == set()
@pytest.mark.asyncio
async def test_follow_edges_with_max_subgraph_size_limit(self):
"""Test Query.follow_edges method respects max_subgraph_size"""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_triples_client.query_stream.return_value = [label_triple]
mock_triples_client.query.return_value = []
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_subgraph_size=2
triple_limit=10,
)
subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")}
await query.follow_edges("entity1", subgraph, path_length=1)
mock_triples_client.query_stream.assert_not_called()
assert len(subgraph) == 3
@pytest.mark.asyncio
async def test_get_subgraph_method(self):
"""Test Query.get_subgraph returns (subgraph, entities, concepts) tuple"""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_path_length=1
selected, uri_map, edge_meta = await query.hop_and_filter(
["entity1"], ["concept"],
)
# Mock get_entities to return (entities, concepts) tuple
query.get_entities = AsyncMock(
return_value=(["entity1", "entity2"], ["concept1"])
)
query.follow_edges_batch = AsyncMock(return_value=(
{
("entity1", "predicate1", "object1"),
("entity2", "predicate2", "object2")
},
{}
))
subgraph, term_map, entities, concepts = await query.get_subgraph("test query")
query.get_entities.assert_called_once_with("test query")
query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1)
assert isinstance(subgraph, list)
assert len(subgraph) == 2
assert ("entity1", "predicate1", "object1") in subgraph
assert ("entity2", "predicate2", "object2") in subgraph
assert entities == ["entity1", "entity2"]
assert concepts == ["concept1"]
@pytest.mark.asyncio
async def test_get_labelgraph_method(self):
"""Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)"""
mock_rag = MagicMock()
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
max_subgraph_size=100
)
test_subgraph = [
("entity1", "predicate1", "object1"),
("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"),
("entity3", "predicate3", "object3")
]
test_entities = ["entity1", "entity3"]
test_concepts = ["concept1"]
query.get_subgraph = AsyncMock(
return_value=(test_subgraph, {}, test_entities, test_concepts)
)
async def mock_maybe_label(entity):
label_map = {
"entity1": "Human Entity One",
"predicate1": "Human Predicate One",
"object1": "Human Object One",
"entity3": "Human Entity Three",
"predicate3": "Human Predicate Three",
"object3": "Human Object Three"
}
return label_map.get(entity, entity)
query.maybe_label = AsyncMock(side_effect=mock_maybe_label)
labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query")
query.get_subgraph.assert_called_once_with("test query")
# Label triples filtered out
assert len(labeled_edges) == 2
# maybe_label called for non-label triples
assert query.maybe_label.call_count == 6
expected_edges = [
("Human Entity One", "Human Predicate One", "Human Object One"),
("Human Entity Three", "Human Predicate Three", "Human Object Three")
]
assert labeled_edges == expected_edges
assert len(uri_map) == 2
assert entities == test_entities
assert concepts == test_concepts
assert selected == []
mock_reranker_client.rerank.assert_not_called()
@pytest.mark.asyncio
async def test_graph_rag_query_method(self):
"""Test GraphRag.query method orchestrates full RAG pipeline with provenance"""
import json
from trustgraph.retrieval.graph_rag.graph_rag import edge_id
mock_prompt_client = AsyncMock()
mock_embeddings_client = AsyncMock()
mock_graph_embeddings_client = AsyncMock()
mock_triples_client = AsyncMock()
mock_reranker_client = AsyncMock()
expected_response = "This is the RAG response"
test_labelgraph = [("Subject", "Predicate", "Object")]
test_edge_id = edge_id("Subject", "Predicate", "Object")
test_selected_edges = [("Subject", "Predicate", "Object")]
test_eid = edge_id("Subject", "Predicate", "Object")
test_uri_map = {
test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object")
}
test_edge_metadata = {
test_eid: {"concept": "test concept", "score": 0.95}
}
test_entities = ["http://example.org/subject"]
test_concepts = ["test concept"]
# Mock prompt responses for the multi-step process
mock_embeddings_client.embed.return_value = [[0.1, 0.2]]
mock_graph_embeddings_client.query.return_value = []
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
elif prompt_name == "kg-edge-reasoning":
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
return PromptResult(response_type="text", text="test concept")
elif prompt_name == "kg-synthesis":
return PromptResult(response_type="text", text=expected_response)
return PromptResult(response_type="text", text="")
@ -614,16 +530,16 @@ class TestQuery:
embeddings_client=mock_embeddings_client,
graph_embeddings_client=mock_graph_embeddings_client,
triples_client=mock_triples_client,
verbose=False
reranker_client=mock_reranker_client,
verbose=False,
)
# Patch Query.get_labelgraph to return test data
original_get_labelgraph = Query.get_labelgraph
original_hop_and_filter = Query.hop_and_filter
async def mock_get_labelgraph(self, query_text):
return test_labelgraph, test_uri_map, test_entities, test_concepts
async def mock_hop_and_filter(self, seed_entities, concepts):
return test_selected_edges, test_uri_map, test_edge_metadata
Query.get_labelgraph = mock_get_labelgraph
Query.hop_and_filter = mock_hop_and_filter
provenance_events = []
@ -636,7 +552,7 @@ class TestQuery:
collection="test_collection",
entity_limit=25,
triple_limit=15,
explain_callback=collect_provenance
explain_callback=collect_provenance,
)
response_text, usage = response
@ -650,7 +566,6 @@ class TestQuery:
assert len(triples) > 0
assert prov_id.startswith("urn:trustgraph:")
# Verify order
assert "question" in provenance_events[0][1]
assert "grounding" in provenance_events[1][1]
assert "exploration" in provenance_events[2][1]
@ -658,4 +573,4 @@ class TestQuery:
assert "synthesis" in provenance_events[4][1]
finally:
Query.get_labelgraph = original_get_labelgraph
Query.hop_and_filter = original_hop_and_filter

View file

@ -0,0 +1,353 @@
"""
Tests for direction-aware reranker text in GraphRAG hop-and-filter.
The reranker document text varies by traversal direction:
- From S (subject is the frontier entity): text = "{p} {o}"
- From O (object is the frontier entity): text = "{s} {p}"
- From P (predicate is the frontier entity): text = "{s} {o}"
"""
import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL
from trustgraph.schema import Term, IRI, LITERAL
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_rag(reranker_results=None):
"""Create a mock GraphRag with all clients stubbed."""
rag = MagicMock()
rag.label_cache = LRUCacheWithTTL()
rag.triples_client = AsyncMock()
rag.reranker_client = AsyncMock()
# Label lookups return empty (fall back to URI)
rag.triples_client.query.return_value = []
if reranker_results is not None:
rag.reranker_client.rerank.return_value = reranker_results
else:
rag.reranker_client.rerank.return_value = []
return rag
def _make_query(rag, max_path_length=1, edge_limit=25):
return Query(
rag=rag,
collection="test",
verbose=False,
entity_limit=50,
triple_limit=30,
max_subgraph_size=1000,
max_path_length=max_path_length,
edge_limit=edge_limit,
)
def _make_schema_triple(s, p, o):
"""Create a mock triple matching the schema interface."""
t = MagicMock()
t.s = s
t.p = p
t.o = o
return t
def _reranker_result(document_id, query_id="0", score=0.9):
r = MagicMock()
r.document_id = str(document_id)
r.query_id = str(query_id)
r.score = score
return r
# ---------------------------------------------------------------------------
# Tests: execute_batch_triple_queries direction tracking
# ---------------------------------------------------------------------------
class TestDirectionTracking:
@pytest.mark.asyncio
async def test_from_s_direction(self):
"""Triples from s=entity queries are tagged FROM_S."""
triple = _make_schema_triple("ent1", "pred", "obj")
rag = _make_rag()
async def query_stream(s=None, p=None, o=None, **kwargs):
if s is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag)
result = await q.execute_batch_triple_queries(["ent1"], 10)
from_s = [(t, d) for t, d in result if d == Query.FROM_S]
assert len(from_s) == 1
assert from_s[0][0] is triple
@pytest.mark.asyncio
async def test_from_o_direction(self):
"""Triples from o=entity queries are tagged FROM_O."""
triple = _make_schema_triple("subj", "pred", "ent1")
rag = _make_rag()
async def query_stream(s=None, p=None, o=None, **kwargs):
if o is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag)
result = await q.execute_batch_triple_queries(["ent1"], 10)
from_o = [(t, d) for t, d in result if d == Query.FROM_O]
assert len(from_o) == 1
assert from_o[0][0] is triple
@pytest.mark.asyncio
async def test_from_p_direction(self):
"""Triples from p=entity queries are tagged FROM_P."""
triple = _make_schema_triple("subj", "ent1", "obj")
rag = _make_rag()
async def query_stream(s=None, p=None, o=None, **kwargs):
if p is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag)
result = await q.execute_batch_triple_queries(["ent1"], 10)
from_p = [(t, d) for t, d in result if d == Query.FROM_P]
assert len(from_p) == 1
assert from_p[0][0] is triple
# ---------------------------------------------------------------------------
# Tests: hop_and_filter reranker document text
# ---------------------------------------------------------------------------
class TestDirectionAwareRerankerText:
@pytest.mark.asyncio
async def test_from_s_uses_predicate_object(self):
"""From-S traversal: reranker text should be '{p} {o}'."""
triple = _make_schema_triple(
"http://ex/entity-A",
"http://ex/likes",
"http://ex/entity-B",
)
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
async def query_stream(s=None, p=None, o=None, **kwargs):
if s is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/entity-A"],
concepts=["likes"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
# Text should be "{p} {o}" — the URIs since no labels found
assert len(documents) == 1
assert documents[0]["text"] == "http://ex/likes http://ex/entity-B"
@pytest.mark.asyncio
async def test_from_o_uses_subject_predicate(self):
"""From-O traversal: reranker text should be '{s} {p}'."""
triple = _make_schema_triple(
"http://ex/entity-A",
"http://ex/likes",
"http://ex/entity-B",
)
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
async def query_stream(s=None, p=None, o=None, **kwargs):
if o is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/entity-B"],
concepts=["likes"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
assert len(documents) == 1
assert documents[0]["text"] == "http://ex/entity-A http://ex/likes"
@pytest.mark.asyncio
async def test_from_p_uses_subject_object(self):
"""From-P traversal: reranker text should be '{s} {o}'."""
triple = _make_schema_triple(
"http://ex/entity-A",
"http://ex/likes",
"http://ex/entity-B",
)
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
async def query_stream(s=None, p=None, o=None, **kwargs):
if p is not None:
return [triple]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/likes"],
concepts=["entity"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
assert len(documents) == 1
assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B"
@pytest.mark.asyncio
async def test_mixed_directions_produce_different_text(self):
"""Edges from different directions use different text formats."""
triple_from_s = _make_schema_triple(
"http://ex/seed", "http://ex/rel", "http://ex/target",
)
triple_from_o = _make_schema_triple(
"http://ex/other", "http://ex/ref", "http://ex/seed",
)
rag = _make_rag(reranker_results=[
_reranker_result(0), _reranker_result(1),
])
async def query_stream(s=None, p=None, o=None, **kwargs):
if s == "http://ex/seed":
return [triple_from_s]
if o == "http://ex/seed":
return [triple_from_o]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/seed"],
concepts=["test"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
texts = {d["text"] for d in documents}
# From S: "{p} {o}" = "http://ex/rel http://ex/target"
assert "http://ex/rel http://ex/target" in texts
# From O: "{s} {p}" = "http://ex/other http://ex/ref"
assert "http://ex/other http://ex/ref" in texts
@pytest.mark.asyncio
async def test_labels_applied_to_direction_text(self):
"""Labels should be resolved and used in the direction-aware text."""
triple = _make_schema_triple(
"http://ex/entity-A",
"http://ex/likes",
"http://ex/entity-B",
)
reranker_result = _reranker_result(0)
rag = _make_rag(reranker_results=[reranker_result])
LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
async def query_stream(s=None, p=None, o=None, **kwargs):
if s is not None and p is None:
return [triple]
return []
async def label_query(s=None, p=None, o=None, limit=1, **kwargs):
if p == LABEL:
labels = {
"http://ex/entity-A": "Alice",
"http://ex/likes": "likes",
"http://ex/entity-B": "Bob",
}
if s in labels:
return [MagicMock(o=labels[s])]
return []
rag.triples_client.query_stream.side_effect = query_stream
rag.triples_client.query.side_effect = label_query
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/entity-A"],
concepts=["friendship"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
assert len(documents) == 1
# From S with labels: "{p_label} {o_label}"
assert documents[0]["text"] == "likes Bob"
@pytest.mark.asyncio
async def test_no_duplicate_text_from_shared_object(self):
"""Multiple edges sharing an object should produce distinct texts."""
triple_a = _make_schema_triple(
"http://ex/cpu-A", "http://ex/hasCategory", "http://ex/Processors",
)
triple_b = _make_schema_triple(
"http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors",
)
rag = _make_rag(reranker_results=[
_reranker_result(0), _reranker_result(1),
])
async def query_stream(s=None, p=None, o=None, **kwargs):
if o == "http://ex/Processors":
return [triple_a, triple_b]
return []
rag.triples_client.query_stream.side_effect = query_stream
q = _make_query(rag, max_path_length=1, edge_limit=10)
await q.hop_and_filter(
seed_entities=["http://ex/Processors"],
concepts=["CPUs"],
)
call_args = rag.reranker_client.rerank.call_args
documents = call_args.kwargs["documents"]
texts = [d["text"] for d in documents]
assert len(texts) == 2
# From O: "{s} {p}" — subjects differ, so texts differ
assert texts[0] != texts[1]
assert "http://ex/cpu-A" in texts[0]
assert "http://ex/cpu-B" in texts[1]

View file

@ -20,7 +20,7 @@ from trustgraph.provenance.namespaces import (
TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION,
TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE,
TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT,
TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_SELECTED_EDGE, TG_EDGE, TG_SCORE, TG_EDGE_SELECTION,
)
@ -91,17 +91,17 @@ def build_mock_clients():
1. prompt_client.prompt("extract-concepts", ...) -> concepts
2. embeddings_client.embed(concepts) -> vectors
3. graph_embeddings_client.query(vector, ...) -> entity matches
4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch)
4. triples_client.query_stream(s/p/o, ...) -> edges (hop_and_filter)
5. triples_client.query(s, LABEL, ...) -> labels (maybe_label)
6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges
7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning
8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
9. prompt_client.prompt("kg-synthesis", ...) -> final answer
6. reranker_client.rerank(queries, documents, limit) -> scored edges
7. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns [])
8. prompt_client.prompt("kg-synthesis", ...) -> final answer
"""
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
graph_embeddings_client = AsyncMock()
triples_client = AsyncMock()
reranker_client = AsyncMock()
# 1. Concept extraction
prompt_responses = {}
@ -116,7 +116,7 @@ def build_mock_clients():
EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)),
]
# 4. Triple queries (follow_edges_batch) - return our edges
# 4. Triple queries (hop_and_filter) - return our edges
kg_triples = [
make_schema_triple(*EDGE_1),
make_schema_triple(*EDGE_2),
@ -130,9 +130,18 @@ def build_mock_clients():
return [] # No labels found, will fall back to URI
triples_client.query.side_effect = mock_label_query
# 6+7. Edge scoring and reasoning: dynamically score/reason about
# whatever edges the query method sends us, since edge IDs are computed
# from str(Term) representations which include the full dataclass repr.
# 6. Reranker: select all documents with high scores
async def mock_rerank(queries, documents, limit):
results = []
for i, doc in enumerate(documents):
result = MagicMock()
result.document_id = doc["id"]
result.query_id = queries[0]["id"] if queries else "0"
result.score = 0.9 - (i * 0.1)
results.append(result)
return results[:limit]
reranker_client.rerank.side_effect = mock_rerank
synthesis_answer = "Quantum computing applies physics principles to computation."
async def mock_prompt(template_id, variables=None, **kwargs):
@ -141,26 +150,6 @@ def build_mock_clients():
response_type="text",
text=prompt_responses["extract-concepts"],
)
elif template_id == "kg-edge-scoring":
# Score all edges highly, using the IDs that GraphRag computed
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-edge-reasoning":
# Provide reasoning for each edge
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-synthesis":
return PromptResult(
response_type="text",
@ -170,7 +159,8 @@ def build_mock_clients():
prompt_client.prompt.side_effect = mock_prompt
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
return (prompt_client, embeddings_client, graph_embeddings_client,
triples_client, reranker_client)
# ---------------------------------------------------------------------------
@ -197,7 +187,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0, # skip semantic pre-filter for simplicity
)
assert len(events) == 5, (
@ -222,7 +212,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
expected_types = [
@ -260,7 +250,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
uris = [e["explain_id"] for e in events]
@ -297,7 +287,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
q_uri = events[0]["explain_id"]
@ -320,7 +310,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
gnd_uri = events[1]["explain_id"]
@ -344,7 +334,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
exp_uri = events[2]["explain_id"]
@ -355,10 +345,10 @@ class TestGraphRagQueryProvenance:
assert int(t.o.value) > 0
@pytest.mark.asyncio
async def test_focus_has_selected_edges_with_reasoning(self):
async def test_focus_has_selected_edges_with_concept_and_score(self):
"""
The focus event should carry selected edges as quoted triples
with reasoning text.
with cross-encoder concept and score metadata.
"""
clients = build_mock_clients()
rag = GraphRag(*clients)
@ -371,7 +361,6 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
foc_uri = events[3]["explain_id"]
@ -387,11 +376,19 @@ class TestGraphRagQueryProvenance:
for t in edge_t:
assert t.o.triple is not None, "tg:edge object must be a quoted triple"
# Should have reasoning
reasoning = find_triples(foc_triples, TG_REASONING)
assert len(reasoning) > 0, "Focus should have reasoning for selected edges"
reasoning_texts = {t.o.value for t in reasoning}
assert any(r for r in reasoning_texts), "Reasoning should not be empty"
# Edge selections should be typed as EdgeSelection
edge_sel_uris = [t.o.iri for t in selected]
for uri in edge_sel_uris:
assert has_type(foc_triples, uri, TG_EDGE_SELECTION)
# Should have concept and score
concepts = find_triples(foc_triples, TG_CONCEPT)
assert len(concepts) > 0, "Focus should have tg:concept for selected edges"
scores = find_triples(foc_triples, TG_SCORE)
assert len(scores) > 0, "Focus should have tg:score for selected edges"
for t in scores:
float(t.o.value) # Should be parseable as float
@pytest.mark.asyncio
async def test_synthesis_is_answer_type(self):
@ -407,7 +404,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
syn_uri = events[4]["explain_id"]
@ -429,7 +426,7 @@ class TestGraphRagQueryProvenance:
result_text, usage = await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
assert result_text == "Quantum computing applies physics principles to computation."
@ -449,7 +446,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
parent_uri=parent,
)
@ -465,7 +462,7 @@ class TestGraphRagQueryProvenance:
result_text, usage = await rag.query(
query="What is quantum computing?",
edge_score_limit=0,
)
assert result_text == "Quantum computing applies physics principles to computation."
@ -484,7 +481,7 @@ class TestGraphRagQueryProvenance:
await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
for event in events:

View file

@ -527,7 +527,8 @@ class AsyncFlowInstance:
return result.get("response", "")
async def document_rag(self, query: str, collection: str,
doc_limit: int = 10, **kwargs: Any) -> str:
doc_limit: int = 10, fetch_limit: int = 0,
**kwargs: Any) -> str:
"""
Execute document-based RAG query (non-streaming).
@ -541,7 +542,9 @@ class AsyncFlowInstance:
Args:
query: User query text
collection: Collection identifier containing documents
doc_limit: Maximum number of document chunks to retrieve (default: 10)
doc_limit: Document chunks selected into the prompt (default: 10)
fetch_limit: Candidate chunks fetched from the vector store before
reranking (default: 0 = derive from doc_limit)
**kwargs: Additional service-specific parameters
Returns:
@ -564,6 +567,7 @@ class AsyncFlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": False
}
request_data.update(kwargs)
@ -646,6 +650,16 @@ class AsyncFlowInstance:
return await self.request("embeddings", request_data)
async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any):
request_data = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request_data.update(kwargs)
return await self.request("reranker", request_data)
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any):
"""
Query RDF triples using pattern matching.

View file

@ -94,7 +94,9 @@ class AsyncSocketClient:
if resp.get("type") == "auth-ok":
if not self._workspace_explicit:
self.workspace = resp.get("workspace", self.workspace)
self.workspace = resp.get(
"default_workspace", self.workspace,
)
elif resp.get("type") == "auth-failed":
await self._socket.close()
raise ProtocolException(
@ -377,12 +379,14 @@ class AsyncSocketFlowInstance:
yield chunk.content
async def document_rag(self, query: str, collection: str,
doc_limit: int = 10, streaming: bool = False, **kwargs):
doc_limit: int = 10, fetch_limit: int = 0,
streaming: bool = False, **kwargs):
"""Document RAG with optional streaming"""
request = {
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": streaming
}
request.update(kwargs)
@ -441,6 +445,19 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("embeddings", self.flow_id, request)
async def rerank(self, queries: list, documents: list, limit: int = 10,
**kwargs):
request = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request.update(kwargs)
return await self.client._send_request(
"reranker", self.flow_id, request,
)
async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs):
"""Triple pattern query"""
request = {"limit": limit}

View file

@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_SCORE = TG + "score"
TG_DOCUMENT = TG + "document"
TG_CONCEPT = TG + "concept"
TG_ENTITY = TG + "entity"
@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
@dataclass
class EdgeSelection:
"""A selected edge with reasoning from GraphRAG Focus step."""
"""A selected edge with cross-encoder metadata from GraphRAG Focus step."""
uri: str
edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...}
reasoning: str = ""
concept: str = ""
score: Optional[float] = None
@dataclass
@ -209,7 +212,7 @@ class Exploration(ExplainEntity):
@dataclass
class Focus(ExplainEntity):
"""Focus entity - selected edges with LLM reasoning (GraphRAG only)."""
"""Focus entity - selected edges with cross-encoder scoring (GraphRAG only)."""
selected_edge_uris: List[str] = field(default_factory=list)
edge_selections: List[EdgeSelection] = field(default_factory=list)
@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel
uri = triples[0][0] if triples else ""
edge = None
reasoning = ""
concept = ""
score = None
for s, p, o in triples:
if p == TG_EDGE and isinstance(o, dict):
edge = o
elif p == TG_REASONING:
reasoning = o
elif p == TG_CONCEPT:
concept = o
elif p == TG_SCORE:
try:
score = float(o)
except (ValueError, TypeError):
score = None
return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning)
return EdgeSelection(
uri=uri, edge=edge, reasoning=reasoning,
concept=concept, score=score,
)
def extract_term_value(term: Dict[str, Any]) -> Any:

View file

@ -415,7 +415,7 @@ class FlowInstance:
def document_rag(
self, query,collection="default",
doc_limit=10,
doc_limit=10, fetch_limit=0,
):
"""
Execute document-based Retrieval-Augmented Generation (RAG) query.
@ -426,7 +426,9 @@ class FlowInstance:
Args:
query: Natural language query
collection: Collection identifier (default: "default")
doc_limit: Maximum document chunks to retrieve (default: 10)
doc_limit: Document chunks selected into the prompt (default: 10)
fetch_limit: Candidate chunks fetched from the vector store before
reranking (default: 0 = derive from doc_limit)
Returns:
str: Generated response incorporating document context
@ -447,6 +449,7 @@ class FlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
}
result = self.request(
@ -491,6 +494,19 @@ class FlowInstance:
input
)["vectors"]
def rerank(self, queries, documents, limit=10):
input = {
"queries": queries,
"documents": documents,
"limit": limit,
}
return self.request(
"service/reranker",
input
)
def graph_embeddings_query(self, text, collection, limit=10):
"""
Query knowledge graph entities using semantic similarity.

View file

@ -168,7 +168,9 @@ class SocketClient:
if resp.get("type") == "auth-ok":
if self.workspace == "default":
self.workspace = resp.get("workspace", self.workspace)
self.workspace = resp.get(
"default_workspace", self.workspace,
)
elif resp.get("type") == "auth-failed":
await self._socket.close()
raise ProtocolException(
@ -750,6 +752,7 @@ class SocketFlowInstance:
query: str,
collection: str,
doc_limit: int = 10,
fetch_limit: int = 0,
streaming: bool = False,
**kwargs: Any
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
@ -762,6 +765,7 @@ class SocketFlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": streaming
}
request.update(kwargs)
@ -783,6 +787,7 @@ class SocketFlowInstance:
query: str,
collection: str,
doc_limit: int = 10,
fetch_limit: int = 0,
**kwargs: Any
) -> Iterator[Union[RAGChunk, ProvenanceEvent]]:
"""Execute document-based RAG query with explainability support."""
@ -790,6 +795,7 @@ class SocketFlowInstance:
"query": query,
"collection": collection,
"doc-limit": doc_limit,
"fetch-limit": fetch_limit,
"streaming": True,
"explainable": True,
}
@ -883,6 +889,19 @@ class SocketFlowInstance:
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
def rerank(self, queries: list, documents: list, limit: int = 10,
**kwargs: Any) -> Dict[str, Any]:
request = {
"queries": queries,
"documents": documents,
"limit": limit,
}
request.update(kwargs)
return self.client._send_request_sync(
"reranker", self.flow_id, request, False,
)
def triples_query(
self,
s: Optional[Union[str, Dict[str, Any]]] = None,

View file

@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService
from . tool_service_client import ToolServiceClientSpec
from . agent_client import AgentClientSpec
from . structured_query_client import StructuredQueryClientSpec
from . reranker_client import RerankerClientSpec
from . reranker_service import RerankerService
from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec
from . collection_config_handler import CollectionConfigHandler

View file

@ -103,35 +103,19 @@ def resolve_cassandra_config(
host: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
default_keyspace: Optional[str] = None
default_keyspace: Optional[str] = None,
replication_factor: Optional[int] = None,
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
"""
Resolve Cassandra configuration from various sources.
Can accept either argparse args object or explicit parameters.
Converts host string to list format for Cassandra driver.
Args:
args: Optional argparse namespace with cassandra_host, cassandra_username, cassandra_password, cassandra_keyspace, cassandra_replication_factor
host: Optional explicit host parameter (overrides args)
username: Optional explicit username parameter (overrides args)
password: Optional explicit password parameter (overrides args)
default_keyspace: Optional default keyspace if not specified elsewhere
Returns:
tuple: (hosts_list, username, password, keyspace, replication_factor)
"""
# If args provided, extract values
keyspace = None
replication_factor = 1
if args is not None:
host = host or getattr(args, 'cassandra_host', None)
username = username or getattr(args, 'cassandra_username', None)
password = password or getattr(args, 'cassandra_password', None)
keyspace = getattr(args, 'cassandra_keyspace', None)
replication_factor = getattr(args, 'cassandra_replication_factor', 1)
replication_factor = replication_factor or getattr(
args, 'cassandra_replication_factor', None
)
# Apply defaults if still None
defaults = get_cassandra_defaults()
host = host or defaults['host']
username = username or defaults['username']

View file

@ -65,31 +65,25 @@ class IamClient(RequestResponse):
async def authenticate_anonymous(self, timeout=IAM_TIMEOUT):
"""Request anonymous access from the IAM regime.
Returns ``(user_id, workspace, roles)`` if the regime permits
anonymous access, or raises ``RuntimeError`` with error type
``auth-failed`` if it does not."""
Returns ``(user_id, default_workspace, roles)`` if the regime
permits anonymous access, or raises ``RuntimeError`` with
error type ``auth-failed`` if it does not."""
resp = await self._request(
operation="authenticate-anonymous",
timeout=timeout,
)
return (
resp.resolved_user_id,
resp.resolved_workspace,
resp.resolved_default_workspace,
list(resp.resolved_roles),
)
async def resolve_api_key(self, api_key, timeout=IAM_TIMEOUT):
"""Resolve a plaintext API key to its identity triple.
Returns ``(user_id, workspace, roles)`` or raises
Returns ``(user_id, default_workspace, roles)`` or raises
``RuntimeError`` with error type ``auth-failed`` if the key is
unknown / expired / revoked.
Note: the ``roles`` value is a regime-internal hint and is
not used by the gateway directly under the IAM contract;
all authorisation decisions go through ``authorise()``.
Returned here only for backward compatibility with callers
that haven't migrated."""
unknown / expired / revoked."""
resp = await self._request(
operation="resolve-api-key",
api_key=api_key,
@ -97,7 +91,7 @@ class IamClient(RequestResponse):
)
return (
resp.resolved_user_id,
resp.resolved_workspace,
resp.resolved_default_workspace,
list(resp.resolved_roles),
)

View file

@ -157,21 +157,6 @@ class PromptClient(RequestResponse):
timeout = timeout,
)
async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "kg-prompt",
variables = {
"query": query,
"knowledge": [
{ "s": v[0], "p": v[1], "o": v[2] }
for v in kg
]
},
timeout = timeout,
streaming = streaming,
chunk_callback = chunk_callback,
)
async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None):
return await self.prompt(
id = "document-prompt",

View file

@ -0,0 +1,87 @@
import os
import argparse
from typing import Optional, Any, Tuple
def get_qdrant_defaults() -> dict:
return {
'url': os.getenv('QDRANT_URL', 'http://localhost:6333'),
'api_key': os.getenv('QDRANT_API_KEY'),
'replication_factor': int(os.getenv('QDRANT_REPLICATION_FACTOR', '1')),
'shard_number': int(os.getenv('QDRANT_SHARD_NUMBER', '1')),
}
def add_qdrant_args(parser: argparse.ArgumentParser) -> None:
defaults = get_qdrant_defaults()
url_help = f"Qdrant URL (default: {defaults['url']})"
if 'QDRANT_URL' in os.environ:
url_help += " [from QDRANT_URL]"
api_key_help = "Qdrant API key"
if defaults['api_key']:
api_key_help += " (default: <set>)"
if 'QDRANT_API_KEY' in os.environ:
api_key_help += " [from QDRANT_API_KEY]"
replication_help = f"Qdrant collection replication factor (default: {defaults['replication_factor']})"
if 'QDRANT_REPLICATION_FACTOR' in os.environ:
replication_help += " [from QDRANT_REPLICATION_FACTOR]"
shard_help = f"Qdrant collection shard number (default: {defaults['shard_number']})"
if 'QDRANT_SHARD_NUMBER' in os.environ:
shard_help += " [from QDRANT_SHARD_NUMBER]"
parser.add_argument(
'--store-uri',
default=defaults['url'],
help=url_help,
)
parser.add_argument(
'--api-key',
default=defaults['api_key'],
help=api_key_help,
)
parser.add_argument(
'--qdrant-replication-factor',
type=int,
default=defaults['replication_factor'],
help=replication_help,
)
parser.add_argument(
'--qdrant-shard-number',
type=int,
default=defaults['shard_number'],
help=shard_help,
)
def resolve_qdrant_config(
args: Optional[Any] = None,
url: Optional[str] = None,
api_key: Optional[str] = None,
replication_factor: Optional[int] = None,
shard_number: Optional[int] = None,
) -> Tuple[str, Optional[str], int, int]:
if args is not None:
url = url or getattr(args, 'store_uri', None)
api_key = api_key or getattr(args, 'api_key', None)
replication_factor = replication_factor or getattr(
args, 'qdrant_replication_factor', None
)
shard_number = shard_number or getattr(
args, 'qdrant_shard_number', None
)
defaults = get_qdrant_defaults()
url = url or defaults['url']
api_key = api_key or defaults['api_key']
replication_factor = replication_factor or defaults['replication_factor']
shard_number = shard_number or defaults['shard_number']
return url, api_key, replication_factor, shard_number

View file

@ -0,0 +1,43 @@
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import (
RerankerRequest, RerankerResponse,
RerankerQuery, RerankerDocument,
)
class RerankerClient(RequestResponse):
async def rerank(self, queries, documents, limit=10, timeout=300):
resp = await self.request(
RerankerRequest(
queries=[
RerankerQuery(query_id=q["id"], query_text=q["text"])
for q in queries
],
documents=[
RerankerDocument(
document_id=d["id"], document_text=d["text"]
)
for d in documents
],
limit=limit,
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
return resp.results
class RerankerClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,
):
super(RerankerClientSpec, self).__init__(
request_name = request_name,
request_schema = RerankerRequest,
response_name = response_name,
response_schema = RerankerResponse,
impl = RerankerClient,
)

View file

@ -0,0 +1,109 @@
from __future__ import annotations
from argparse import ArgumentParser
import logging
from .. schema import (
RerankerRequest, RerankerResponse, RerankerResult, Error,
)
from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec
logger = logging.getLogger(__name__)
default_ident = "reranker"
default_concurrency = 1
class RerankerService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
concurrency = params.get("concurrency", 1)
super(RerankerService, self).__init__(**params | {
"id": id,
"concurrency": concurrency,
})
self.register_specification(
ConsumerSpec(
name = "request",
schema = RerankerRequest,
handler = self.on_request,
concurrency = concurrency,
)
)
self.register_specification(
ProducerSpec(
name = "response",
schema = RerankerResponse
)
)
self.register_specification(
ParameterSpec(
name = "model",
)
)
async def on_request(self, msg, consumer, flow):
try:
request = msg.value()
id = msg.properties()["id"]
logger.debug(f"Handling reranker request {id}...")
model = flow("model")
results = await self.on_rerank(
request.queries, request.documents,
request.limit, model=model,
)
await flow("response").send(
RerankerResponse(
error = None,
results = results,
),
properties={"id": id}
)
logger.debug("Reranker request handled successfully")
except TooManyRequests as e:
raise e
except Exception as e:
logger.error(f"Exception in reranker service: {e}", exc_info=True)
logger.info("Sending error response...")
await flow.producer["response"].send(
RerankerResponse(
error=Error(
type = "reranker-error",
message = str(e),
),
results=[],
),
properties={"id": id}
)
@staticmethod
def add_args(parser: ArgumentParser) -> None:
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)

View file

@ -140,20 +140,6 @@ class PromptClient(BaseClient):
timeout=timeout
)
def request_kg_prompt(self, query, kg, timeout=300):
return self.request(
id="kg-prompt",
variables={
"query": query,
"knowledge": [
{ "s": v[0], "p": v[1], "o": v[2] }
for v in kg
]
},
timeout=timeout
)
def request_document_prompt(self, query, documents, timeout=300):
return self.request(

View file

@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons
from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator
from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator
from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator
from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator
from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator
@ -163,6 +164,12 @@ TranslatorRegistry.register_service(
SparqlQueryResponseTranslator()
)
TranslatorRegistry.register_service(
"reranker",
RerankerRequestTranslator(),
RerankerResponseTranslator()
)
# Register single-direction translators for document loading
TranslatorRegistry.register_request("document", DocumentTranslator())
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())

View file

@ -20,3 +20,4 @@ from .embeddings_query import (
)
from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator
from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator
from .reranker import RerankerRequestTranslator, RerankerResponseTranslator

View file

@ -5,6 +5,7 @@ from ...schema import (
UserInput, UserRecord,
WorkspaceInput, WorkspaceRecord,
ApiKeyInput, ApiKeyRecord,
GroupInput, GrantInput,
)
from .base import MessageTranslator
@ -43,12 +44,31 @@ def _api_key_input_from_dict(d):
)
def _group_input_from_dict(d):
if d is None:
return None
return GroupInput(
name=d.get("name", ""),
description=d.get("description", ""),
enabled=d.get("enabled", True),
)
def _grant_input_from_dict(d):
if d is None:
return None
return GrantInput(
capability=d.get("capability", ""),
workspace=d.get("workspace", ""),
)
def _user_record_to_dict(r):
if r is None:
return None
return {
"id": r.id,
"workspace": r.workspace,
"default_workspace": r.default_workspace,
"username": r.username,
"name": r.name,
"email": r.email,
@ -102,6 +122,15 @@ class IamRequestTranslator(MessageTranslator):
data.get("workspace_record")
),
key=_api_key_input_from_dict(data.get("key")),
group_id=data.get("group_id", ""),
member_type=data.get("member_type", ""),
member_id=data.get("member_id", ""),
group=_group_input_from_dict(data.get("group")),
grant=_grant_input_from_dict(data.get("grant")),
capability=data.get("capability", ""),
resource_json=data.get("resource_json", ""),
parameters_json=data.get("parameters_json", ""),
authorise_checks=data.get("authorise_checks", ""),
)
def encode(self, obj: IamRequest) -> Dict[str, Any]:
@ -109,6 +138,9 @@ class IamRequestTranslator(MessageTranslator):
for fname in (
"workspace", "actor", "user_id", "username", "key_id",
"api_key", "password", "new_password",
"group_id", "member_type", "member_id",
"capability", "resource_json", "parameters_json",
"authorise_checks",
):
v = getattr(obj, fname, "")
if v:
@ -135,6 +167,17 @@ class IamRequestTranslator(MessageTranslator):
"name": obj.key.name,
"expires": obj.key.expires,
}
if obj.group is not None:
result["group"] = {
"name": obj.group.name,
"description": obj.group.description,
"enabled": obj.group.enabled,
}
if obj.grant is not None:
result["grant"] = {
"capability": obj.grant.capability,
"workspace": obj.grant.workspace,
}
return result
@ -175,8 +218,8 @@ class IamResponseTranslator(MessageTranslator):
result["signing_key_public"] = obj.signing_key_public
if obj.resolved_user_id:
result["resolved_user_id"] = obj.resolved_user_id
if obj.resolved_workspace:
result["resolved_workspace"] = obj.resolved_workspace
if obj.resolved_default_workspace:
result["resolved_default_workspace"] = obj.resolved_default_workspace
if obj.resolved_roles:
result["resolved_roles"] = list(obj.resolved_roles)
if obj.temporary_password:
@ -190,6 +233,23 @@ class IamResponseTranslator(MessageTranslator):
# setup, so it can't be dropped by a truthy-only filter.
result["bootstrap_available"] = bool(obj.bootstrap_available)
# authorise / authorise-many outputs.
if obj.decision_allow:
result["decision_allow"] = obj.decision_allow
if obj.decision_ttl_seconds:
result["decision_ttl_seconds"] = obj.decision_ttl_seconds
if obj.decisions_json:
result["decisions_json"] = obj.decisions_json
# Enterprise IAM outputs.
for fname in (
"group_json", "groups_json", "members_json",
"grants_json", "effective_permissions_json",
):
v = getattr(obj, fname, "")
if v:
result[fname] = v
return result
def encode_with_completion(

View file

@ -0,0 +1,73 @@
from typing import Dict, Any, Tuple
from ...schema import (
RerankerRequest, RerankerResponse,
RerankerQuery, RerankerDocument, RerankerResult,
)
from .base import MessageTranslator
class RerankerRequestTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> RerankerRequest:
return RerankerRequest(
queries=[
RerankerQuery(
query_id=q["query_id"],
query_text=q["query_text"],
)
for q in data.get("queries", [])
],
documents=[
RerankerDocument(
document_id=d["document_id"],
document_text=d["document_text"],
)
for d in data.get("documents", [])
],
limit=data.get("limit", 10),
)
def encode(self, obj: RerankerRequest) -> Dict[str, Any]:
return {
"queries": [
{"query_id": q.query_id, "query_text": q.query_text}
for q in obj.queries
],
"documents": [
{"document_id": d.document_id, "document_text": d.document_text}
for d in obj.documents
],
"limit": obj.limit,
}
class RerankerResponseTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> RerankerResponse:
return RerankerResponse(
results=[
RerankerResult(
document_id=r["document_id"],
query_id=r["query_id"],
score=r["score"],
)
for r in data.get("results", [])
],
)
def encode(self, obj: RerankerResponse) -> Dict[str, Any]:
return {
"results": [
{
"document_id": r.document_id,
"query_id": r.query_id,
"score": r.score,
}
for r in obj.results
],
}
def encode_with_completion(
self, obj: RerankerResponse
) -> Tuple[Dict[str, Any], bool]:
return self.encode(obj), True

View file

@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
query=data["query"],
collection=data.get("collection", "default"),
doc_limit=int(data.get("doc-limit", 20)),
fetch_limit=int(data.get("fetch-limit", 0)),
streaming=data.get("streaming", False)
)
@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator):
"query": obj.query,
"collection": obj.collection,
"doc-limit": obj.doc_limit,
"fetch-limit": obj.fetch_limit,
"streaming": getattr(obj, "streaming", False)
}

View file

@ -64,6 +64,8 @@ from . uris import (
docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri,
docrag_focus_uri,
chunk_selection_uri,
docrag_synthesis_uri,
)
@ -89,9 +91,13 @@ from . namespaces import (
TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE,
# Edge selection entity type
TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Chunk selection entity type
TG_CHUNK_SELECTION,
# Explainability entity types
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION,
@ -130,6 +136,7 @@ from . triples import (
# Query-time provenance triple builders (DocumentRAG)
docrag_question_triples,
docrag_exploration_triples,
docrag_chunk_selection_triples,
docrag_synthesis_triples,
# Utility
set_graph,
@ -194,6 +201,8 @@ __all__ = [
"docrag_question_uri",
"docrag_grounding_uri",
"docrag_exploration_uri",
"docrag_focus_uri",
"chunk_selection_uri",
"docrag_synthesis_uri",
# Namespaces
"PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT",
@ -212,9 +221,13 @@ __all__ = [
"TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE",
# Query-time provenance predicates (GraphRAG)
"TG_QUERY", "TG_CONCEPT", "TG_ENTITY",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING",
"TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE",
# Edge selection entity type
"TG_EDGE_SELECTION",
# Query-time provenance predicates (DocumentRAG)
"TG_CHUNK_COUNT", "TG_SELECTED_CHUNK",
# Chunk selection entity type
"TG_CHUNK_SELECTION",
# Explainability entity types
"TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS",
"TG_ANALYSIS", "TG_CONCLUSION",
@ -250,6 +263,7 @@ __all__ = [
# Query-time provenance triple builders (DocumentRAG)
"docrag_question_triples",
"docrag_exploration_triples",
"docrag_chunk_selection_triples",
"docrag_synthesis_triples",
# Agent provenance triple builders
"agent_session_triples",

View file

@ -66,12 +66,19 @@ TG_EDGE_COUNT = TG + "edgeCount"
TG_SELECTED_EDGE = TG + "selectedEdge"
TG_EDGE = TG + "edge"
TG_REASONING = TG + "reasoning"
TG_SCORE = TG + "score"
TG_DOCUMENT = TG + "document" # Reference to document in librarian
# Edge selection entity type (cross-encoder scored edge in Focus)
TG_EDGE_SELECTION = TG + "EdgeSelection"
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT = TG + "chunkCount"
TG_SELECTED_CHUNK = TG + "selectedChunk"
# Chunk selection entity type (cross-encoder reranked chunk in Focus)
TG_CHUNK_SELECTION = TG + "ChunkSelection"
# Extraction provenance entity types
TG_DOCUMENT_TYPE = TG + "Document"
TG_PAGE_TYPE = TG + "Page"

View file

@ -24,10 +24,14 @@ from . namespaces import (
TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT,
# Query-time provenance predicates (GraphRAG)
TG_QUERY, TG_CONCEPT, TG_ENTITY,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING,
TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE,
TG_DOCUMENT,
# Edge selection entity type
TG_EDGE_SELECTION,
# Query-time provenance predicates (DocumentRAG)
TG_CHUNK_COUNT, TG_SELECTED_CHUNK,
# Chunk selection entity type
TG_CHUNK_SELECTION,
# Explainability entity types
TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
# Unifying types
@ -38,7 +42,10 @@ from . namespaces import (
TG_IN_TOKEN, TG_OUT_TOKEN,
)
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri
from . uris import (
activity_uri, agent_uri, subgraph_uri, edge_selection_uri,
chunk_selection_uri,
)
def set_graph(triples: List[Triple], graph: str) -> List[Triple]:
@ -536,10 +543,9 @@ def focus_triples(
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
]
# Add each selected edge with its reasoning via intermediate entity
# Add each selected edge with metadata via intermediate entity
for idx, edge_info in enumerate(selected_edges_with_reasoning):
edge = edge_info.get("edge")
reasoning = edge_info.get("reasoning", "")
if edge:
s, p, o = edge
@ -552,13 +558,32 @@ def focus_triples(
_triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri))
)
# Type the edge selection entity
triples.append(
_triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION))
)
# Attach quoted triple to edge selection entity
quoted = _quoted_triple(s, p, o)
triples.append(
Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted)
)
# Attach reasoning to edge selection entity
# Structured cross-encoder metadata
concept = edge_info.get("concept")
if concept:
triples.append(
_triple(edge_sel_uri, TG_CONCEPT, _literal(concept))
)
score = edge_info.get("score")
if score is not None:
triples.append(
_triple(edge_sel_uri, TG_SCORE, _literal(str(score)))
)
# Legacy reasoning text (for non-cross-encoder callers)
reasoning = edge_info.get("reasoning", "")
if reasoning:
triples.append(
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning))
@ -698,6 +723,75 @@ def docrag_exploration_triples(
return triples
def docrag_chunk_selection_triples(
focus_uri: str,
exploration_uri: str,
selected_chunks_with_scores: List[dict],
session_id: str,
) -> List[Triple]:
"""
Build triples for a document RAG focus entity (chunks selected by the
cross-encoder reranker).
Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity
derived from exploration, with one ChunkSelection sub-entity per surviving
chunk carrying the chunk reference and the reranker score.
Structure:
<focus> a tg:Focus ; prov:wasDerivedFrom <exploration> .
<focus> tg:selectedChunk <chunk_sel_0> .
<chunk_sel_0> a tg:ChunkSelection .
<chunk_sel_0> tg:document <chunk_id> .
<chunk_sel_0> tg:score "0.97" .
Args:
focus_uri: URI of the focus entity (from docrag_focus_uri)
exploration_uri: URI of the parent exploration entity
selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score'
session_id: Session UUID for generating chunk selection URIs
Returns:
List of Triple objects
"""
triples = [
_triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)),
_triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")),
_triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)),
]
for idx, chunk_info in enumerate(selected_chunks_with_scores):
chunk_id = chunk_info.get("chunk_id")
if not chunk_id:
continue
chunk_sel_uri = chunk_selection_uri(session_id, idx)
# Link focus to chunk selection entity
triples.append(
_triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri))
)
# Type the chunk selection entity
triples.append(
_triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION))
)
# Reference the actual chunk (in librarian)
triples.append(
_triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id))
)
# Cross-encoder score
score = chunk_info.get("score")
if score is not None:
triples.append(
_triple(chunk_sel_uri, TG_SCORE, _literal(str(score)))
)
return triples
def docrag_synthesis_triples(
synthesis_uri: str,
exploration_uri: str,

View file

@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str:
return f"urn:trustgraph:docrag:{session_id}/exploration"
def docrag_focus_uri(session_id: str) -> str:
"""
Generate URI for a document RAG focus entity (chunks selected by the
cross-encoder reranker).
Args:
session_id: The session UUID.
Returns:
URN in format: urn:trustgraph:docrag:{uuid}/focus
"""
return f"urn:trustgraph:docrag:{session_id}/focus"
def chunk_selection_uri(session_id: str, chunk_index: int) -> str:
"""
Generate URI for a chunk selection item (links a reranked chunk to its
score). Mirrors edge_selection_uri for GraphRAG.
Args:
session_id: The session UUID.
chunk_index: Index of this chunk in the selection (0-based).
Returns:
URN in format: urn:trustgraph:prov:chunk:{uuid}:{index}
"""
return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}"
def docrag_synthesis_uri(session_id: str) -> str:
"""
Generate URI for a document RAG synthesis entity (final answer).

View file

@ -29,6 +29,8 @@ from . namespaces import (
TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE,
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
TG_EDGE_SELECTION, TG_SCORE,
TG_CHUNK_SELECTION,
)
@ -93,6 +95,8 @@ TG_CLASS_LABELS = [
_label_triple(TG_FINDING, "Finding"),
_label_triple(TG_PLAN_TYPE, "Plan"),
_label_triple(TG_STEP_RESULT, "Step Result"),
_label_triple(TG_EDGE_SELECTION, "Edge Selection"),
_label_triple(TG_CHUNK_SELECTION, "Chunk Selection"),
]
# TrustGraph predicate labels
@ -117,6 +121,7 @@ TG_PREDICATE_LABELS = [
_label_triple(TG_ENTITY, "entity"),
_label_triple(TG_SUBAGENT_GOAL, "subagent goal"),
_label_triple(TG_PLAN_STEP, "plan step"),
_label_triple(TG_SCORE, "score"),
]

View file

@ -16,3 +16,4 @@ from .collection import *
from .storage import *
from .tool_service import *
from .sparql_query import *
from .reranker import *

View file

@ -29,7 +29,7 @@ class UserInput:
@dataclass
class UserRecord:
id: str = ""
workspace: str = ""
default_workspace: str = ""
username: str = ""
name: str = ""
email: str = ""
@ -74,6 +74,21 @@ class ApiKeyRecord:
last_used: str = ""
# ---- Enterprise IAM types (additive) ----
@dataclass
class GroupInput:
name: str = ""
description: str = ""
enabled: bool = True
@dataclass
class GrantInput:
capability: str = ""
workspace: str = ""
@dataclass
class IamRequest:
operation: str = ""
@ -99,6 +114,13 @@ class IamRequest:
workspace_record: WorkspaceInput | None = None
key: ApiKeyInput | None = None
# ---- Enterprise IAM inputs (additive) ----
group_id: str = ""
member_type: str = ""
member_id: str = ""
group: GroupInput | None = None
grant: GrantInput | None = None
# ---- authorise / authorise-many inputs ----
# Capability string from the vocabulary in capabilities.md.
capability: str = ""
@ -138,7 +160,7 @@ class IamResponse:
# resolve-api-key
resolved_user_id: str = ""
resolved_workspace: str = ""
resolved_default_workspace: str = ""
resolved_roles: list[str] = field(default_factory=list)
# reset-password
@ -164,6 +186,14 @@ class IamResponse:
# authorise_checks.
decisions_json: str = ""
# ---- Enterprise IAM outputs (additive) ----
# JSON-serialised payloads for enterprise group/grant operations.
group_json: str = ""
groups_json: str = ""
members_json: str = ""
grants_json: str = ""
effective_permissions_json: str = ""
error: Error | None = None

View file

@ -6,17 +6,6 @@ from ..core.primitives import Error
# Prompt services, abstract the prompt generation
# extract-definitions:
# chunk -> definitions
# extract-relationships:
# chunk -> relationships
# kg-prompt:
# query, triples -> answer
# document-prompt:
# query, documents -> answer
# extract-rows
# schema, chunk -> rows
@dataclass
class PromptRequest:
id: str = ""

View file

@ -0,0 +1,35 @@
from dataclasses import dataclass, field
from ..core.primitives import Error
############################################################################
# Cross-encoder reranker
@dataclass
class RerankerQuery:
query_id: str = ""
query_text: str = ""
@dataclass
class RerankerDocument:
document_id: str = ""
document_text: str = ""
@dataclass
class RerankerRequest:
queries: list[RerankerQuery] = field(default_factory=list)
documents: list[RerankerDocument] = field(default_factory=list)
limit: int = 10
@dataclass
class RerankerResult:
document_id: str = ""
query_id: str = ""
score: float = 0.0
@dataclass
class RerankerResponse:
error: Error | None = None
results: list[RerankerResult] = field(default_factory=list)

View file

@ -40,7 +40,10 @@ class GraphRagResponse:
class DocumentRagQuery:
query: str = ""
collection: str = ""
doc_limit: int = 0
doc_limit: int = 0 # docs selected into the synthesis prompt
fetch_limit: int = 0 # candidate pool fetched from the vector store
# before reranking (0 = derive from doc_limit;
# values below doc_limit are raised to it)
streaming: bool = False
@dataclass

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.5,<2.6",
"trustgraph-base>=2.6,<2.7",
"pulsar-client",
"prometheus-client",
"boto3",

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.5,<2.6",
"trustgraph-base>=2.6,<2.7",
"requests",
"pulsar-client",
"aiohttp",
@ -71,6 +71,7 @@ tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main"
tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main"
tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main"
tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main"
tg-invoke-reranker = "trustgraph.cli.invoke_reranker:main"
tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main"
tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main"
tg-load-kg-core = "trustgraph.cli.load_kg_core:main"

View file

@ -53,7 +53,7 @@ def main():
help="Auth token (default: $TRUSTGRAPH_TOKEN)",
)
parser.add_argument(
"--username", required=True, help="Username (unique in workspace)",
"--username", required=True, help="Username (globally unique)",
)
parser.add_argument(
"--password", default=None,
@ -75,10 +75,7 @@ def main():
)
parser.add_argument(
"-w", "--workspace", default=None,
help=(
"Target workspace (admin only; defaults to caller's "
"assigned workspace)"
),
help="Default workspace for the new user",
)
run_main(do_create_user, parser)

View file

@ -21,10 +21,12 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
default_collection = 'default'
default_doc_limit = 10
default_fetch_limit = 0
def question_explainable(
url, flow_id, question_text, collection, doc_limit, token=None, debug=False,
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
token=None, debug=False,
workspace="default",
):
"""Execute document RAG with explainability - shows provenance events inline."""
@ -39,6 +41,7 @@ def question_explainable(
query=question_text,
collection=collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
):
if isinstance(item, RAGChunk):
# Print response content
@ -97,7 +100,7 @@ def question_explainable(
def question(
url, flow_id, question_text, collection, doc_limit,
url, flow_id, question_text, collection, doc_limit, fetch_limit=0,
streaming=True, token=None, explainable=False, debug=False,
show_usage=False, workspace="default",
):
@ -109,6 +112,7 @@ def question(
question_text=question_text,
collection=collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
token=token,
debug=debug,
workspace=workspace,
@ -128,6 +132,7 @@ def question(
query=question_text,
collection=collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
streaming=True
)
@ -155,6 +160,7 @@ def question(
query=question_text,
collection=collection,
doc_limit=doc_limit,
fetch_limit=fetch_limit,
)
print(result.text)
@ -214,7 +220,15 @@ def main():
'-d', '--doc-limit',
type=int,
default=default_doc_limit,
help=f'Document limit (default: {default_doc_limit})'
help=f'Documents selected into the prompt (default: {default_doc_limit})'
)
parser.add_argument(
'--fetch-limit',
type=int,
default=default_fetch_limit,
help='Candidate documents fetched from the vector store before '
'reranking (default: derive from doc-limit)'
)
parser.add_argument(
@ -251,6 +265,7 @@ def main():
question_text=args.question,
collection=args.collection,
doc_limit=args.doc_limit,
fetch_limit=args.fetch_limit,
streaming=not args.no_streaming,
token=args.token,
explainable=args.explainable,

View file

@ -112,14 +112,13 @@ def _question_explainable_api(
if focus_full and focus_full.edge_selections:
for edge_sel in focus_full.edge_selections:
if edge_sel.edge:
# Resolve labels for edge components
s_label, p_label, o_label = explain_client.resolve_edge_labels(
edge_sel.edge, collection
)
print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr)
if edge_sel.reasoning:
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reason: {r_short}", file=sys.stderr)
if edge_sel.concept or edge_sel.score is not None:
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
print(f" Concept: {edge_sel.concept} Score: {score_str}", file=sys.stderr)
elif isinstance(entity, Synthesis):
print(f"\n [synthesis] {prov_id}", file=sys.stderr)

View file

@ -0,0 +1,127 @@
"""
Invokes the reranker service to score and rank documents by relevance
to one or more queries.
"""
import argparse
import json
import os
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default")
def query(url, flow_id, queries, documents, limit, token=None,
workspace="default"):
api = Api(url=url, token=token, workspace=workspace)
socket = api.socket()
flow = socket.flow(flow_id)
try:
query_objects = [
{"query_id": str(i), "query_text": q}
for i, q in enumerate(queries)
]
document_objects = [
{"document_id": str(i), "document_text": d}
for i, d in enumerate(documents)
]
result = flow.rerank(
queries=query_objects,
documents=document_objects,
limit=limit,
)
if "error" in result and result["error"]:
err = result["error"]
print(f"Error: [{err.get('type', '')}] {err.get('message', '')}")
return
for r in result.get("results", []):
doc_idx = int(r["document_id"])
query_idx = int(r["query_id"])
print(
f" {r['score']:.4f} | "
f"query: {queries[query_idx]} | "
f"doc: {documents[doc_idx]}"
)
finally:
socket.close()
def main():
parser = argparse.ArgumentParser(
prog='tg-invoke-reranker',
description=__doc__,
)
parser.add_argument(
'-u', '--url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-w', '--workspace',
default=default_workspace,
help=f'Workspace (default: {default_workspace})',
)
parser.add_argument(
'-f', '--flow-id',
default="default",
help=f'Flow ID (default: default)'
)
parser.add_argument(
'-l', '--limit',
type=int,
default=10,
help='Maximum number of results (default: 10)',
)
parser.add_argument(
'-q', '--query',
action='append',
required=True,
help='Query text (can be specified multiple times)',
)
parser.add_argument(
'documents',
nargs='+',
help='Documents to rerank',
)
args = parser.parse_args()
try:
query(
url=args.url,
flow_id=args.flow_id,
queries=args.query,
documents=args.documents,
limit=args.limit,
token=args.token,
workspace=args.workspace,
)
except Exception as e:
print("Exception:", e, flush=True)
if __name__ == "__main__":
main()

View file

@ -78,7 +78,7 @@ def load_structured_data(
logger.info("Step 1: Analyzing data to discover best matching schema...")
# Step 1: Auto-discover schema (reuse discover_schema logic)
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
discovered_schema = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
if not discovered_schema:
logger.error("Failed to discover suitable schema automatically")
print("❌ Could not automatically determine the best schema for your data.")
@ -90,7 +90,7 @@ def load_structured_data(
# Step 2: Auto-generate descriptor
logger.info("Step 2: Generating descriptor configuration...")
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, workspace=workspace)
auto_descriptor = _auto_generate_descriptor(api_url, input_file, discovered_schema, sample_chars, flow, logger, token=token, workspace=workspace)
if not auto_descriptor:
logger.error("Failed to generate descriptor automatically")
print("❌ Could not automatically generate descriptor configuration.")
@ -172,7 +172,7 @@ def load_structured_data(
logger.info(f"Sample chars: {sample_chars} characters")
# Use the helper function to discover schema (get raw response for display)
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, workspace=workspace)
response = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=True, token=token, workspace=workspace)
if response:
# Debug: print response type and content
@ -203,7 +203,7 @@ def load_structured_data(
# If no schema specified, discover it first
if not schema_name:
logger.info("No schema specified, auto-discovering...")
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, workspace=workspace)
schema_name = _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, token=token, workspace=workspace)
if not schema_name:
print("Error: Could not determine schema automatically.")
print("Please specify a schema using --schema-name or run --discover-schema first.")
@ -213,7 +213,7 @@ def load_structured_data(
logger.info(f"Target schema: {schema_name}")
# Generate descriptor using helper function
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace=workspace)
descriptor = _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=token, workspace=workspace)
if descriptor:
# Output the generated descriptor
@ -603,7 +603,7 @@ def _send_to_trustgraph(rows, api_url, flow, batch_size=1000, token=None, worksp
# Helper functions for auto mode
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, workspace="default"):
def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, return_raw_response=False, token=None, workspace="default"):
"""Auto-discover the best matching schema for the input data
Args:
@ -626,7 +626,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
api = Api(api_url, workspace=workspace)
api = Api(api_url, token=token, workspace=workspace)
config_api = api.config()
# Get available schemas
@ -707,7 +707,7 @@ def _auto_discover_schema(api_url, input_file, sample_chars, flow, logger, retur
return None
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, workspace="default"):
def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, flow, logger, token=None, workspace="default"):
"""Auto-generate descriptor configuration for the discovered schema"""
try:
# Read sample data
@ -717,7 +717,7 @@ def _auto_generate_descriptor(api_url, input_file, schema_name, sample_chars, fl
# Import API modules
from trustgraph.api import Api
from trustgraph.api.types import ConfigKey
api = Api(api_url, workspace=workspace)
api = Api(api_url, token=token, workspace=workspace)
config_api = api.config()
# Get schema definition

View file

@ -51,8 +51,8 @@ def main():
parser.add_argument(
"-w", "--workspace", default=None,
help=(
"Optional workspace to log in against. Defaults to "
"the user's assigned workspace."
"Override the default workspace for this session's JWT. "
"If omitted, uses the user's stored default workspace."
),
)
run_main(do_login, parser)

View file

@ -203,9 +203,9 @@ def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_
)
print(f" {i}. ({s_label}, {p_label}, {o_label})")
if edge_sel.reasoning:
r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning
print(f" Reasoning: {r_short}")
if edge_sel.concept or edge_sel.score is not None:
score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?"
print(f" Concept: {edge_sel.concept} Score: {score_str}")
if show_provenance and edge_sel.edge:
provenance = trace_edge_provenance(
@ -519,7 +519,8 @@ def trace_to_dict(trace, trace_type):
"selected_edges": [
{
"edge": edge_sel.edge,
"reasoning": edge_sel.reasoning,
"concept": edge_sel.concept,
"score": edge_sel.score,
}
for edge_sel in focus.edge_selections
],

View file

@ -68,7 +68,7 @@ def do_update_user(args):
print(f"username : {rec.get('username', '')}")
print(f"name : {rec.get('name', '')}")
print(f"email : {rec.get('email', '')}")
print(f"workspace : {rec.get('workspace', '')}")
print(f"default_ws: {rec.get('default_workspace', '')}")
print(f"roles : {', '.join(rec.get('roles', []))}")
print(f"enabled : {'yes' if rec.get('enabled') else 'no'}")
print(
@ -114,7 +114,7 @@ def main():
"-w", "--workspace", default=None,
help=(
"Optional workspace integrity check — when supplied, "
"iam-svc verifies the target user's home workspace "
"iam-svc verifies the target user's default workspace "
"matches"
),
)

View file

@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph."
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.5,<2.6",
"trustgraph-flow>=2.5,<2.6",
"trustgraph-base>=2.6,<2.7",
"trustgraph-flow>=2.6,<2.7",
"torch",
"urllib3",
"transformers",

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.5,<2.6",
"trustgraph-base>=2.6,<2.7",
"aiohttp",
"anthropic",
"scylla-driver",
@ -19,6 +19,7 @@ dependencies = [
"faiss-cpu",
"falkordb",
"fastembed",
"flashrank",
"ibis",
"jsonschema",
"langchain",
@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone:
graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run"
graph-embeddings = "trustgraph.embeddings.graph_embeddings:run"
graph-rag = "trustgraph.retrieval.graph_rag:run"
reranker-flashrank = "trustgraph.reranker.flashrank:run"
kg-extract-agent = "trustgraph.extract.kg.agent:run"
kg-extract-definitions = "trustgraph.extract.kg.definitions:run"
kg-extract-rows = "trustgraph.extract.kg.rows:run"

View file

@ -18,6 +18,10 @@ description : str (default "Default")
Human-readable description passed to flow-svc.
parameters : dict (optional)
Optional parameter overrides passed to start-flow.
list_timeout : int (default 10)
Timeout in seconds for the list-flows request.
start_timeout : int (default 30)
Timeout in seconds for the start-flow request.
"""
from trustgraph.schema import FlowRequest
@ -34,6 +38,8 @@ class DefaultFlowStart(Initialiser):
blueprint=None,
description="Default",
parameters=None,
list_timeout=10,
start_timeout=30,
**kwargs,
):
super().__init__(**kwargs)
@ -46,6 +52,8 @@ class DefaultFlowStart(Initialiser):
self.blueprint = blueprint
self.description = description
self.parameters = dict(parameters) if parameters else {}
self.list_timeout = list_timeout
self.start_timeout = start_timeout
async def run(self, ctx, old_flag, new_flag):
@ -70,7 +78,7 @@ class DefaultFlowStart(Initialiser):
FlowRequest(
operation="list-flows",
),
timeout=10,
timeout=self.list_timeout,
)
if list_resp.error:
raise RuntimeError(
@ -99,7 +107,7 @@ class DefaultFlowStart(Initialiser):
description=self.description,
parameters=self.parameters,
),
timeout=30,
timeout=self.start_timeout,
)
if resp.error:
raise RuntimeError(

View file

@ -14,7 +14,9 @@ seed_file : str (required when source=="seed-file")
Path to a JSON seed file with the same shape TemplateSeed consumes.
overwrite : bool (default False)
On re-run (flag change), if True overwrite all keys; if False,
upsert-missing-only (preserves in-workspace customisations).
upsert-missing-only (preserves in-workspace customisations)
iam_timeout : int (default 10)
Timeout in seconds for the IAM create-workspace request.
Raises (in ``run``)
-------------------
@ -41,7 +43,9 @@ class WorkspaceInit(Initialiser):
source="template",
seed_file=None,
overwrite=False,
iam_timeout=10,
**kwargs,
):
super().__init__(**kwargs)
@ -59,6 +63,7 @@ class WorkspaceInit(Initialiser):
self.source = source
self.seed_file = seed_file
self.overwrite = overwrite
self.iam_timeout = iam_timeout
async def run(self, ctx, old_flag, new_flag):
await self._create_workspace(ctx)
@ -123,7 +128,7 @@ class WorkspaceInit(Initialiser):
enabled=True,
),
),
timeout=10,
timeout=self.iam_timeout,
)
if resp.error:
if resp.error.type == "duplicate":

View file

@ -83,7 +83,8 @@ class Processor(AsyncProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
default_keyspace="config"
default_keyspace="config",
replication_factor=params.get("cassandra_replication_factor"),
)
# Store resolved configuration

View file

@ -61,7 +61,8 @@ class Processor(WorkspaceProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
default_keyspace="knowledge"
default_keyspace="knowledge",
replication_factor=params.get("cassandra_replication_factor"),
)
self.cassandra_host = hosts

View file

@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder"
def _looks_like_pdf(content):
return content.lstrip().startswith(b"%PDF-")
class Processor(FlowProcessor):
def __init__(self, **params):
@ -94,14 +98,10 @@ class Processor(FlowProcessor):
)
return
with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
temp_path = fp.name
# Check if we should fetch from librarian or use inline data
if v.document_id:
# Fetch from librarian via Pulsar
logger.info(f"Fetching document {v.document_id} from librarian...")
fp.close()
content = await flow.librarian.fetch_document_content(
document_id=v.document_id,
@ -113,13 +113,21 @@ class Processor(FlowProcessor):
content = content.encode('utf-8')
decoded_content = base64.b64decode(content)
with open(temp_path, 'wb') as f:
f.write(decoded_content)
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
else:
# Use inline data (backward compatibility)
fp.write(base64.b64decode(v.data))
decoded_content = base64.b64decode(v.data)
if not _looks_like_pdf(decoded_content):
logger.error(
f"Document {v.metadata.id} is not valid PDF content. "
f"Ignoring document."
)
return
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as fp:
temp_path = fp.name
fp.write(decoded_content)
fp.close()
global PyPDFLoader

View file

@ -6,7 +6,7 @@ import logging
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement, SimpleStatement
from ssl import SSLContext, PROTOCOL_TLSv1_2
import ssl
from ..tables.cassandra_async import async_execute
@ -41,13 +41,15 @@ class KnowledgeGraph:
def __init__(
self, hosts=None,
keyspace="trustgraph", username=None, password=None
keyspace="trustgraph", username=None, password=None,
replication_factor=1,
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.replication_factor = replication_factor
self.username = username
# 7-table schema for quads with full query pattern support
@ -68,7 +70,7 @@ class KnowledgeGraph:
self.collection_metadata_table = "collection_metadata"
if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
@ -92,7 +94,7 @@ class KnowledgeGraph:
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
'replication_factor' : {self.replication_factor}
}};
""")
@ -539,13 +541,15 @@ class EntityCentricKnowledgeGraph:
def __init__(
self, hosts=None,
keyspace="trustgraph", username=None, password=None
keyspace="trustgraph", username=None, password=None,
replication_factor=1,
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.replication_factor = replication_factor
self.username = username
# 2-table entity-centric schema
@ -556,7 +560,7 @@ class EntityCentricKnowledgeGraph:
self.collection_metadata_table = "collection_metadata"
if username and password:
ssl_context = SSLContext(PROTOCOL_TLSv1_2)
ssl_context = ssl.create_default_context()
auth_provider = PlainTextAuthProvider(username=username, password=password)
self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context)
else:
@ -580,7 +584,7 @@ class EntityCentricKnowledgeGraph:
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
'replication_factor' : {self.replication_factor}
}};
""")

View file

@ -57,16 +57,17 @@ class Identity:
# the OSS regime this is the user record's id; the gateway
# treats it as a string with no semantic content.
handle: str
# The workspace this credential authenticates to. Used by the
# gateway as the default-fill-in for operations that omit a
# workspace. Never used as policy input.
workspace: str
# The user's default workspace. Used by the gateway as the
# default-fill-in for operations that omit a workspace. Not a
# permission boundary — workspace access is controlled by the
# IAM regime's authorise() decision, not by this field.
default_workspace: str
# Stable identifier for audit logs. In OSS this is the same
# value as ``handle``; not assumed equal in the contract.
principal_id: str
# How the credential was presented. Non-policy; useful for
# logs / metrics only.
source: str # "api-key" | "jwt"
source: str # "api-key" | "jwt" | "anonymous"
def _auth_failure():
@ -256,21 +257,22 @@ class IamAuth:
raise _auth_failure()
sub = claims.get("sub", "")
ws = claims.get("workspace", "")
ws = claims.get("default_workspace", "")
if not sub or not ws:
raise _auth_failure()
# JWT carries no policy state under the IAM contract;
# any roles / claims field is ignored here.
return Identity(
handle=sub, workspace=ws, principal_id=sub, source="jwt",
handle=sub, default_workspace=ws,
principal_id=sub, source="jwt",
)
async def _authenticate_anonymous(self):
try:
async def _call(client):
return await client.authenticate_anonymous()
user_id, workspace, _roles = await self._with_client(_call)
user_id, default_workspace, _roles = await self._with_client(
_call,
)
except Exception as e:
logger.debug(
f"Anonymous authentication rejected: "
@ -278,11 +280,11 @@ class IamAuth:
)
raise _auth_failure()
if not user_id or not workspace:
if not user_id or not default_workspace:
raise _auth_failure()
return Identity(
handle=user_id, workspace=workspace,
handle=user_id, default_workspace=default_workspace,
principal_id=user_id, source="anonymous",
)
@ -305,7 +307,9 @@ class IamAuth:
# ``roles`` is returned by the OSS regime as a hint
# but is not consulted by the gateway; all policy
# decisions go through ``authorise``.
user_id, workspace, _roles = await self._with_client(_call)
user_id, default_workspace, _roles = await self._with_client(
_call,
)
except Exception as e:
logger.debug(
f"API key resolution failed: "
@ -313,11 +317,11 @@ class IamAuth:
)
raise _auth_failure()
if not user_id or not workspace:
if not user_id or not default_workspace:
raise _auth_failure()
identity = Identity(
handle=user_id, workspace=workspace,
handle=user_id, default_workspace=default_workspace,
principal_id=user_id, source="api-key",
)
self._key_cache[h] = (identity, now + API_KEY_CACHE_TTL)

View file

@ -99,7 +99,7 @@ async def enforce_workspace(data, identity, auth, capability=None):
return data
requested = data.get("workspace", "")
target = requested or identity.workspace
target = requested or identity.default_workspace
data["workspace"] = target
if target not in auth.known_workspaces:

View file

@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . document_embeddings_query import DocumentEmbeddingsQueryRequestor
from . row_embeddings_query import RowEmbeddingsQueryRequestor
from . mcp_tool import McpToolRequestor
from . reranker import RerankerRequestor
from . text_load import TextLoad
from . document_load import DocumentLoad
@ -74,6 +75,7 @@ request_response_dispatchers = {
"structured-diag": StructuredDiagRequestor,
"row-embeddings": RowEmbeddingsQueryRequestor,
"sparql": SparqlQueryRequestor,
"reranker": RerankerRequestor,
}
system_dispatchers = {

View file

@ -5,6 +5,7 @@ import uuid
import logging
from ..capabilities import PUBLIC, AUTHENTICATED
from ..registry import ResourceLevel
# Module logger
logger = logging.getLogger(__name__)
@ -79,7 +80,7 @@ class Mux:
self.identity = identity
await self.ws.send_json({
"type": "auth-ok",
"workspace": identity.workspace,
"default_workspace": identity.default_workspace,
})
async def receive(self, msg):
@ -159,12 +160,14 @@ class Mux:
return
# Resolve workspace (default-fill from the caller's
# bound workspace). Workspace resolution applies to all
# operations regardless of capability level.
# bound workspace). The envelope workspace is the
# single canonical workspace for routing AND
# authorisation. The inner request body's workspace
# field is not consulted — workspace-scoped services
# receive workspace from the queue identity, not the
# message body.
try:
await enforce_workspace(data, self.identity, self.auth)
if isinstance(inner, dict):
await enforce_workspace(inner, self.identity, self.auth)
# Authorisation: capability sentinels short-circuit
# the regime call; capability strings go through
@ -176,11 +179,18 @@ class Mux:
"flow": data.get("flow", ""),
}
parameters = {}
elif op.resource_level == ResourceLevel.WORKSPACE:
# Workspace-scoped services (config, flow,
# librarian, etc.) — workspace comes from the
# envelope, same as flow-level services.
resource = {
"workspace": data.get("workspace", ""),
}
parameters = {}
else:
# Build a minimal RequestContext so the matched
# operation's own extractors decide resource
# and parameters — same path the HTTP
# endpoints take.
# System-level services (IAM) — resource is
# {} and parameters come from the inner body
# (e.g. user.workspace, workspace_record.id).
from ..registry import RequestContext
ctx = RequestContext(
body=inner if isinstance(inner, dict) else {},

View file

@ -0,0 +1,31 @@
from ... schema import RerankerRequest, RerankerResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
class RerankerRequestor(ServiceRequestor):
def __init__(
self, backend, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(RerankerRequestor, self).__init__(
backend=backend,
request_queue=request_queue,
response_queue=response_queue,
request_schema=RerankerRequest,
response_schema=RerankerResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("reranker")
self.response_translator = TranslatorRegistry.get_response_translator("reranker")
def to_request(self, body):
return self.request_translator.decode(body)
def from_response(self, message):
return self.response_translator.encode_with_completion(message)

View file

@ -92,7 +92,7 @@ class Operation:
# Returns a dict with the appropriate components for the
# resource level: {} for SYSTEM, {workspace} for WORKSPACE,
# {workspace, flow} for FLOW. Default-fill-in of workspace
# from identity.workspace happens here when applicable.
# from identity.default_workspace happens here when applicable.
extract_resource: Callable[[RequestContext], dict]
# Build the parameters dict — decision-relevant fields the
@ -141,7 +141,7 @@ def _workspace_from_body(ctx: RequestContext) -> dict:
workspace field, defaulting to the caller's bound workspace."""
ws = (ctx.body.get("workspace") if isinstance(ctx.body, dict) else "")
if not ws and ctx.identity is not None:
ws = ctx.identity.workspace
ws = ctx.identity.default_workspace
return {"workspace": ws}
@ -188,7 +188,7 @@ def _workspace_param_only(ctx: RequestContext) -> dict:
or body.get("workspace")
)
if not ws and ctx.identity is not None:
ws = ctx.identity.workspace
ws = ctx.identity.default_workspace
return {"workspace": ws or ""}
@ -311,7 +311,7 @@ register(Operation(
))
register(Operation(
name="list-my-workspaces",
capability="workspaces:list-own",
capability=AUTHENTICATED,
resource_level=ResourceLevel.SYSTEM,
extract_resource=_empty_resource,
extract_parameters=_no_parameters,
@ -506,18 +506,19 @@ _FLOW_SERVICES = {
"text-completion": "llm",
"prompt": "llm",
"mcp-tool": "mcp",
"graph-rag": "graph:read",
"document-rag": "documents:read",
"graph-rag": "graph-rag:read",
"document-rag": "document-rag:read",
"embeddings": "embeddings",
"graph-embeddings": "graph:read",
"document-embeddings": "documents:read",
"triples": "graph:read",
"graph-embeddings": "graph-embeddings:read",
"document-embeddings": "document-embeddings:read",
"triples": "triples:read",
"rows": "rows:read",
"nlp-query": "rows:read",
"structured-query": "rows:read",
"structured-diag": "rows:read",
"row-embeddings": "rows:read",
"sparql": "graph:read",
"nlp-query": "nlp-query:read",
"structured-query": "structured-query:read",
"structured-diag": "structured-query:read",
"row-embeddings": "row-embeddings:read",
"sparql": "sparql:read",
"reranker": "reranker",
}
for _kind, _cap in _FLOW_SERVICES.items():
_register_flow_kind("flow-service", _kind, _cap)
@ -525,10 +526,10 @@ for _kind, _cap in _FLOW_SERVICES.items():
# Streaming import socket endpoints.
_FLOW_IMPORTS = {
"triples": "graph:write",
"graph-embeddings": "graph:write",
"document-embeddings": "documents:write",
"entity-contexts": "documents:write",
"triples": "triples:write",
"graph-embeddings": "graph-embeddings:write",
"document-embeddings": "document-embeddings:write",
"entity-contexts": "entity-contexts:write",
"rows": "rows:write",
}
for _kind, _cap in _FLOW_IMPORTS.items():
@ -537,10 +538,35 @@ for _kind, _cap in _FLOW_IMPORTS.items():
# Streaming export socket endpoints.
_FLOW_EXPORTS = {
"triples": "graph:read",
"graph-embeddings": "graph:read",
"document-embeddings": "documents:read",
"entity-contexts": "documents:read",
"triples": "triples:read",
"graph-embeddings": "graph-embeddings:read",
"document-embeddings": "document-embeddings:read",
"entity-contexts": "entity-contexts:read",
}
for _kind, _cap in _FLOW_EXPORTS.items():
_register_flow_kind("flow-export", _kind, _cap)
# ---------------------------------------------------------------------------
# Enterprise IAM operations.
#
# These are additive — they register alongside the OSS IAM operations.
# When the OSS regime receives an unknown operation it returns an error;
# when the enterprise regime is running, it handles them.
# ---------------------------------------------------------------------------
for _op in (
"create-group", "get-group", "list-groups",
"update-group", "delete-group",
"add-group-member", "remove-group-member", "list-group-members",
"add-group-grant", "remove-group-grant", "list-group-grants",
"add-user-grant", "remove-user-grant", "list-user-grants",
"resolve-effective-permissions",
):
register(Operation(
name=_op,
capability="iam:admin",
resource_level=ResourceLevel.SYSTEM,
extract_resource=_empty_resource,
extract_parameters=_no_parameters,
))

View file

@ -28,14 +28,14 @@ class NoAuthHandler:
def _default_identity_response(self):
return IamResponse(
resolved_user_id=self.default_user_id,
resolved_workspace=self.default_workspace,
resolved_default_workspace=self.default_workspace,
resolved_roles=["admin"],
)
def _default_user_record(self):
return UserRecord(
id=self.default_user_id,
workspace=self.default_workspace,
default_workspace=self.default_workspace,
username=self.default_user_id,
name="Anonymous User",
roles=["admin"],

View file

@ -58,21 +58,35 @@ AUTHZ_CACHE_TTL_SECONDS = 60
_READER_CAPS = {
"agent",
"graph:read",
"triples:read",
"sparql:read",
"graph-rag:read",
"graph-embeddings:read",
"documents:read",
"document-rag:read",
"document-embeddings:read",
"entity-contexts:read",
"rows:read",
"nlp-query:read",
"structured-query:read",
"row-embeddings:read",
"llm",
"embeddings",
"reranker",
"mcp",
"config:read",
"flows:read",
"collections:read",
"knowledge:read",
"keys:self",
"workspaces:list-own",
}
_WRITER_CAPS = _READER_CAPS | {
"graph:write",
"triples:write",
"graph-embeddings:write",
"document-embeddings:write",
"entity-contexts:write",
"documents:write",
"rows:write",
"collections:write",
@ -369,7 +383,7 @@ class IamService:
) = row
return UserRecord(
id=id or "",
workspace=workspace or "",
default_workspace=workspace or "",
username=username or "",
name=name or "",
email=email or "",
@ -582,14 +596,8 @@ class IamService:
if not v.password:
return _err("auth-failed", "password required")
# Login accepts an optional workspace parameter. If omitted
# we use the default workspace (OSS single-workspace
# assumption). Multi-workspace enterprise editions swap in a
# resolver that looks across the caller's permitted set.
workspace = v.workspace or DEFAULT_WORKSPACE
user_id = await self.table_store.get_user_id_by_username(
workspace, v.username,
v.username,
)
if not user_id:
return _err("auth-failed", "no such user")
@ -610,7 +618,10 @@ class IamService:
):
return _err("auth-failed", "bad credentials")
ws_row = await self.table_store.get_workspace(ws)
# JWT workspace: login request override, or the user's default.
jwt_workspace = v.workspace or ws
ws_row = await self.table_store.get_workspace(jwt_workspace)
if ws_row is None or not ws_row[2]:
return _err("auth-failed", "workspace disabled")
@ -618,14 +629,10 @@ class IamService:
now_ts = int(_now_dt().timestamp())
exp_ts = now_ts + JWT_TTL_SECONDS
# Per the IAM contract the gateway never reads policy state
# from the credential — roles stay server-side, reachable
# only via authorise(). JWT carries identity + workspace
# binding only.
claims = {
"iss": JWT_ISSUER,
"sub": id,
"workspace": ws,
"default_workspace": jwt_workspace,
"iat": now_ts,
"exp": exp_ts,
}
@ -864,20 +871,15 @@ class IamService:
# user_row indices match get_user columns. Username is [2].
username = user_row[2]
record_workspace = user_row[1]
# Revoke all API keys.
key_rows = await self.table_store.list_api_keys_by_user(v.user_id)
for kr in key_rows:
await self.table_store.delete_api_key(kr[0])
# Remove username lookup — keyed on (workspace, username),
# so use the resolved workspace from the user record rather
# than relying on the caller-supplied filter.
# Remove global username lookup.
if username:
await self.table_store.delete_username_lookup(
record_workspace, username,
)
await self.table_store.delete_username_lookup(username)
# Remove user record.
await self.table_store.delete_user(v.user_id)
@ -1096,13 +1098,15 @@ class IamService:
return _err("auth-failed", "owning user disabled")
# Workspace-disabled check.
ws_row = await self.table_store.get_workspace(user.workspace)
ws_row = await self.table_store.get_workspace(
user.default_workspace,
)
if ws_row is None or not ws_row[2]:
return _err("auth-failed", "owning workspace disabled")
return IamResponse(
resolved_user_id=user.id,
resolved_workspace=user.workspace,
resolved_default_workspace=user.default_workspace,
resolved_roles=list(user.roles),
)
@ -1129,9 +1133,9 @@ class IamService:
if ws is None or not ws[2]:
return _err("not-found", "workspace not found or disabled")
# Uniqueness on username within workspace.
# Global username uniqueness.
existing = await self.table_store.get_user_id_by_username(
v.workspace, v.user.username,
v.user.username,
)
if existing:
return _err("duplicate", "username already exists")
@ -1303,8 +1307,9 @@ class IamService:
return False, AUTHZ_CACHE_TTL_SECONDS
# user_row layout:
# 0:id 1:workspace 2:username 3:name 4:email 5:password_hash
# 6:roles 7:enabled 8:must_change_password 9:created
# 0:id 1:default_workspace 2:username 3:name 4:email
# 5:password_hash 6:roles 7:enabled 8:must_change_password
# 9:created
if not user_row[7]: # disabled
return False, AUTHZ_CACHE_TTL_SECONDS

View file

@ -101,6 +101,7 @@ class Processor(AsyncProcessor):
username=cassandra_username,
password=cassandra_password,
default_keyspace="iam",
replication_factor=params.get("cassandra_replication_factor"),
)
self.cassandra_host = hosts

View file

@ -146,7 +146,8 @@ class Processor(WorkspaceProcessor):
host=cassandra_host,
username=cassandra_username,
password=cassandra_password,
default_keyspace="librarian"
default_keyspace="librarian",
replication_factor=params.get("cassandra_replication_factor"),
)
# Store resolved configuration

View file

@ -10,6 +10,7 @@ import logging
from .... exceptions import TooManyRequests, LlmError
from .... base import LlmService, LlmResult, LlmChunk
from . variants import get_variant, DEFAULT_VARIANT, VARIANTS
# Module logger
logger = logging.getLogger(__name__)
@ -21,6 +22,7 @@ default_temperature = 0.0
default_max_output = 4096
default_api_key = os.getenv("OPENAI_TOKEN")
default_base_url = os.getenv("OPENAI_BASE_URL")
default_thinking = "off"
if default_base_url is None or default_base_url == "":
default_base_url = "https://api.openai.com/v1"
@ -34,10 +36,15 @@ class Processor(LlmService):
base_url = params.get("url", default_base_url)
temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output)
thinking = params.get("thinking", default_thinking)
variant_name = params.get("variant", DEFAULT_VARIANT)
if not api_key:
api_key = "not-set"
self.variant = get_variant(variant_name)
self.thinking = thinking
super(Processor, self).__init__(
**params | {
"model": model,
@ -56,13 +63,28 @@ class Processor(LlmService):
else:
self.openai = OpenAI(api_key=api_key)
logger.info("OpenAI LLM service initialized")
logger.info(
f"OpenAI LLM service initialized "
f"(variant={self.variant.name}, thinking={self.thinking})"
)
def _build_kwargs(self, model_name, temperature):
"""Build API call kwargs using the active variant."""
return self.variant.completion_kwargs(
max_output=self.max_output,
temperature=temperature,
thinking=self.thinking,
)
def _extract_content(self, message):
"""Extract visible content from a response message."""
if hasattr(self.variant, "extract_content"):
return self.variant.extract_content(message)
return message.content
async def generate_content(self, system, prompt, model=None, temperature=None):
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model: {model_name}")
@ -72,8 +94,8 @@ class Processor(LlmService):
try:
resp = self.openai.chat.completions.create(
model=model_name,
api_kwargs = self._build_kwargs(model_name, effective_temperature)
messages = [
{
"role": "user",
@ -84,19 +106,26 @@ class Processor(LlmService):
}
]
}
],
temperature=effective_temperature,
max_completion_tokens=self.max_output,
]
resp = self.variant.create_completion(
self.openai, model_name, messages, **api_kwargs,
)
inputtokens = resp.usage.prompt_tokens
outputtokens = resp.usage.completion_tokens
logger.debug(f"LLM response: {resp.choices[0].message.content}")
content = self._extract_content(resp.choices[0].message)
thinking = self.variant.extract_thinking(resp.choices[0].message)
logger.debug(f"LLM response: {content}")
if thinking:
logger.debug(f"LLM thinking: {thinking[:200]}...")
logger.info(f"Input Tokens: {inputtokens}")
logger.info(f"Output Tokens: {outputtokens}")
resp = LlmResult(
text = resp.choices[0].message.content,
text = content,
in_token = inputtokens,
out_token = outputtokens,
model = model_name
@ -136,9 +165,7 @@ class Processor(LlmService):
Stream content generation from OpenAI.
Yields LlmChunk objects with is_final=True on the last chunk.
"""
# Use provided model or fall back to default
model_name = model or self.default_model
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
logger.debug(f"Using model (streaming): {model_name}")
@ -147,8 +174,8 @@ class Processor(LlmService):
prompt = system + "\n\n" + prompt
try:
response = self.openai.chat.completions.create(
model=model_name,
api_kwargs = self._build_kwargs(model_name, effective_temperature)
messages = [
{
"role": "user",
@ -159,18 +186,14 @@ class Processor(LlmService):
}
]
}
],
temperature=effective_temperature,
max_completion_tokens=self.max_output,
stream=True,
stream_options={"include_usage": True}
)
]
total_input_tokens = 0
total_output_tokens = 0
# Stream chunks
for chunk in response:
async for chunk in self.variant.create_completion_stream(
self.openai, model_name, messages, **api_kwargs,
):
if chunk.choices and chunk.choices[0].delta.content:
yield LlmChunk(
text=chunk.choices[0].delta.content,
@ -254,6 +277,20 @@ class Processor(LlmService):
help=f'LLM max output tokens (default: {default_max_output})'
)
parser.add_argument(
'--thinking',
choices=["off", "low", "medium", "high"],
default=default_thinking,
help=f'Thinking/reasoning effort level (default: {default_thinking})'
)
parser.add_argument(
'--variant',
choices=sorted(VARIANTS.keys()),
default=DEFAULT_VARIANT,
help=f'API variant (default: {DEFAULT_VARIANT})'
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -0,0 +1,219 @@
"""
OpenAI API variant profiles.
Different providers expose OpenAI-compatible APIs with subtle differences
in parameter names, thinking/reasoning support, and temperature handling.
Each variant encapsulates those quirks so the processor doesn't need
provider-specific conditionals.
"""
import re
import logging
logger = logging.getLogger(__name__)
class Variant:
"""Base variant — defines the interface all variants implement."""
name = None
token_param = "max_completion_tokens"
temperature_with_thinking = False
def completion_kwargs(self, max_output, temperature, thinking):
"""Build provider-specific kwargs for chat.completions.create().
Parameters
----------
max_output : int
Configured max output tokens.
temperature : float
Configured temperature.
thinking : str
Thinking effort level: "off", "low", "medium", "high".
Returns
-------
dict
Extra kwargs to spread into the API call.
"""
kwargs = {self.token_param: max_output}
if thinking != "off":
kwargs.update(self.thinking_kwargs(thinking))
if not self.temperature_with_thinking:
kwargs["temperature"] = 1.0
else:
kwargs["temperature"] = temperature
else:
kwargs["temperature"] = temperature
return kwargs
def thinking_kwargs(self, effort):
"""Return kwargs to enable thinking at the given effort level."""
return {}
def extract_thinking(self, message):
"""Extract thinking/reasoning content from a response message."""
return getattr(message, "reasoning_content", None)
def extract_thinking_stream(self, delta):
"""Extract thinking content from a streaming delta."""
return getattr(delta, "reasoning_content", None)
def create_completion(self, client, model, messages, **kwargs):
"""Call the completions API. Override for non-standard SDKs."""
return client.chat.completions.create(
model=model, messages=messages, **kwargs,
)
async def create_completion_stream(self, client, model, messages, **kwargs):
"""Call the streaming completions API. Override for non-standard SDKs."""
for chunk in client.chat.completions.create(
model=model, messages=messages, stream=True,
stream_options={"include_usage": True}, **kwargs,
):
yield chunk
class OpenAIVariant(Variant):
"""Standard OpenAI API (GPT-4o, o1, o3, etc.)."""
name = "openai"
token_param = "max_completion_tokens"
temperature_with_thinking = False
def thinking_kwargs(self, effort):
return {"reasoning_effort": effort}
class DeepSeekVariant(Variant):
"""DeepSeek API (R1, V3, etc.)."""
name = "deepseek"
token_param = "max_completion_tokens"
temperature_with_thinking = True
def completion_kwargs(self, max_output, temperature, thinking):
enabled = "enabled" if thinking != "off" else "disabled"
kwargs = {
self.token_param: max_output,
"temperature": temperature,
"extra_body": {
"thinking": {"type": enabled},
},
}
return kwargs
def thinking_kwargs(self, effort):
return {}
class DashScopeVariant(Variant):
"""Alibaba Cloud DashScope API (Qwen models)."""
name = "dashscope"
token_param = "max_completion_tokens"
temperature_with_thinking = True
def completion_kwargs(self, max_output, temperature, thinking):
enabled = thinking != "off"
return {
self.token_param: max_output,
"temperature": temperature,
"extra_body": {
"enable_thinking": enabled,
},
}
def thinking_kwargs(self, effort):
return {}
class QwenVariant(DashScopeVariant):
"""Qwen — alias for DashScope."""
name = "qwen"
class MistralVariant(Variant):
"""Mistral API (Mistral Large, etc.)."""
name = "mistral"
token_param = "max_tokens"
temperature_with_thinking = False
def thinking_kwargs(self, effort):
return {"reasoning_effort": effort}
class GlmVariant(Variant):
"""GLM / Zhipu AI API (GLM-4, GLM-4.7, etc.)."""
name = "glm"
token_param = "max_tokens"
temperature_with_thinking = True
def completion_kwargs(self, max_output, temperature, thinking):
enabled = "enabled" if thinking != "off" else "disabled"
kwargs = {
self.token_param: max_output,
"temperature": temperature,
"extra_body": {
"thinking": {"type": enabled},
},
}
return kwargs
def thinking_kwargs(self, effort):
return {}
class LlamaVariant(Variant):
"""Llama models via OpenAI-compatible servers (vLLM, Ollama, etc.).
Thinking is typically always-on or always-off depending on the model.
When present, thinking appears inline as <think>...</think> tags.
"""
name = "llama"
token_param = "max_tokens"
temperature_with_thinking = True
def thinking_kwargs(self, effort):
return {}
def extract_thinking(self, message):
content = message.content or ""
match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
return match.group(1).strip() if match else None
def extract_content(self, message):
"""Strip think tags from visible content."""
content = message.content or ""
return re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
VARIANTS = {
"openai": OpenAIVariant,
"deepseek": DeepSeekVariant,
"qwen": QwenVariant,
"mistral": MistralVariant,
"dashscope": DashScopeVariant,
"glm": GlmVariant,
"llama": LlamaVariant,
}
DEFAULT_VARIANT = "openai"
def get_variant(name):
"""Look up a variant by name, raising ValueError if unknown."""
cls = VARIANTS.get(name)
if cls is None:
raise ValueError(
f"Unknown variant {name!r}. "
f"Available: {', '.join(sorted(VARIANTS))}"
)
return cls()

View file

@ -12,31 +12,33 @@ from qdrant_client import QdrantClient
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "doc-embeddings-query"
default_store_uri = 'http://localhost:6333'
class Processor(DocumentEmbeddingsQueryService):
def __init__(self, **params):
store_uri = params.get("store_uri", default_store_uri)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
#optional api key
api_key = params.get("api_key", None)
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri,
api_key=api_key,
)
super(Processor, self).__init__(
**params | {
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)
async def query_document_embeddings(self, workspace, msg):
@ -85,18 +87,7 @@ class Processor(DocumentEmbeddingsQueryService):
def add_args(parser):
DocumentEmbeddingsQueryService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help=f'API key for qdrant (default: None)'
)
add_qdrant_args(parser)
def run():

View file

@ -12,31 +12,32 @@ from qdrant_client import QdrantClient
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "graph-embeddings-query"
default_store_uri = 'http://localhost:6333'
class Processor(GraphEmbeddingsQueryService):
def __init__(self, **params):
store_uri = params.get("store_uri", default_store_uri)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
#optional api key
api_key = params.get("api_key", None)
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)
super(Processor, self).__init__(
**params | {
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@ -104,18 +105,7 @@ class Processor(GraphEmbeddingsQueryService):
def add_args(parser):
GraphEmbeddingsQueryService.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help=f'API key for qdrant (default: None)'
)
add_qdrant_args(parser)
def run():

View file

@ -116,7 +116,7 @@ class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object):
# Create keyspace
self.session.execute(f"""
CREATE KEYSPACE IF NOT EXISTS {self.keyspace}
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}
WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': {self.cassandra_config.get('replication_factor', 1)}}}
""")
# Create triples table optimized for SPARQL queries

View file

@ -19,12 +19,12 @@ from .... schema import (
RowIndexMatch, Error
)
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.qdrant_config import add_qdrant_args, resolve_qdrant_config
# Module logger
logger = logging.getLogger(__name__)
default_ident = "row-embeddings-query"
default_store_uri = 'http://localhost:6333'
default_concurrency = 10
@ -35,13 +35,17 @@ class Processor(FlowProcessor):
id = params.get("id", default_ident)
concurrency = params.get("concurrency", default_concurrency)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
store_uri = params.get("store_uri")
api_key = params.get("api_key")
url, api_key, _, _ = resolve_qdrant_config(
url=store_uri, api_key=api_key,
)
super(Processor, self).__init__(
**params | {
"id": id,
"store_uri": store_uri,
"store_uri": url,
"api_key": api_key,
}
)
@ -62,7 +66,7 @@ class Processor(FlowProcessor):
)
)
self.qdrant = QdrantClient(url=store_uri, api_key=api_key)
self.qdrant = QdrantClient(url=url, api_key=api_key)
def sanitize_name(self, name: str) -> str:
"""Sanitize names for Qdrant collection naming"""
@ -192,21 +196,9 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
"""Add command-line arguments"""
FlowProcessor.add_args(parser)
parser.add_argument(
'-t', '--store-uri',
default=default_store_uri,
help=f'Qdrant store URI (default: {default_store_uri})'
)
parser.add_argument(
'-k', '--api-key',
default=None,
help='API key for Qdrant (default: None)'
)
add_qdrant_args(parser)
parser.add_argument(
'-c', '--concurrency',

View file

@ -24,7 +24,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from .... tables.cassandra_async import async_execute
from .... tables.cassandra_async import async_execute, async_execute_paged, async_scan
from ... graphql import GraphQLSchemaBuilder, SortDirection
@ -180,7 +180,7 @@ class Processor(FlowProcessor):
description=field_def.get("description", ""),
required=field_def.get("required", False),
enum_values=field_def.get("enum", []),
indexed=field_def.get("indexed", False)
indexed=field_def.get("indexed", False),
)
fields.append(field)
@ -232,6 +232,8 @@ class Processor(FlowProcessor):
for index_name in index_names:
if index_name in filters:
value = filters[index_name]
if value == "" or value is None:
continue
# Single field index -> single element list
index_value = [str(value)]
return (index_name, index_value)
@ -282,9 +284,11 @@ class Processor(FlowProcessor):
query += f" LIMIT {limit}"
try:
rows = await async_execute(self.session, query, params)
for row in rows:
# Convert data map to dict with proper field names
pages = await async_execute_paged(
self.session, query, params
)
for page in pages:
for row in page:
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
except Exception as e:
@ -308,8 +312,6 @@ class Processor(FlowProcessor):
# Query using the first index (arbitrary choice for scan)
primary_index = index_names[0]
# We need to scan all values for this index
# This requires ALLOW FILTERING or a different approach
query = f"""
SELECT data, source FROM {safe_keyspace}.rows
WHERE collection = %s
@ -320,18 +322,19 @@ class Processor(FlowProcessor):
params = [collection, schema_name, primary_index]
try:
rows = await async_execute(self.session, query, params)
for row in rows:
def row_filter(row):
row_dict = dict(row.data) if row.data else {}
return self._matches_filters(row_dict, filters, row_schema)
# Apply post-filters
if self._matches_filters(row_dict, filters, row_schema):
matched_rows = await async_scan(
self.session, query, params,
row_filter=row_filter,
limit=limit,
)
for row in matched_rows:
row_dict = dict(row.data) if row.data else {}
results.append(row_dict)
if limit and len(results) >= limit:
break
except Exception as e:
logger.error(f"Failed to scan rows: {e}", exc_info=True)
raise
@ -363,7 +366,7 @@ class Processor(FlowProcessor):
# Parse filter key for operator
if '_' in filter_key:
parts = filter_key.rsplit('_', 1)
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in']:
if parts[1] in ['gt', 'gte', 'lt', 'lte', 'contains', 'in', 'not', 'startsWith', 'endsWith', 'not_in']:
field_name = parts[0]
operator = parts[1]
else:
@ -400,6 +403,18 @@ class Processor(FlowProcessor):
elif operator == 'in':
if str(row_value) not in [str(v) for v in filter_value]:
return False
elif operator == 'not':
if str(row_value) == str(filter_value):
return False
elif operator == 'startsWith':
if not str(row_value).startswith(str(filter_value)):
return False
elif operator == 'endsWith':
if not str(row_value).endswith(str(filter_value)):
return False
elif operator == 'not_in':
if str(row_value) in [str(v) for v in filter_value]:
return False
except (ValueError, TypeError):
return False

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,2 @@
from . processor import *

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from . processor import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,109 @@
"""
Reranker service using flashrank.
Scores query-document pairs and returns the top results ranked by
relevance.
"""
import asyncio
import logging
from ... base import RerankerService
from ... schema import RerankerResult
from flashrank import Ranker, RerankRequest
logger = logging.getLogger(__name__)
default_ident = "reranker"
default_model = "ms-marco-MiniLM-L-12-v2"
class Processor(RerankerService):
def __init__(self, **params):
model = params.get("model", default_model)
super(Processor, self).__init__(
**params | { "model": model }
)
self.default_model = model
self.cached_model_name = None
self.ranker = None
self._load_model(model)
def _load_model(self, model_name):
if self.cached_model_name != model_name:
logger.info(f"Loading flashrank model: {model_name}")
self.ranker = Ranker(model_name=model_name)
self.cached_model_name = model_name
logger.info(f"flashrank model {model_name} loaded successfully")
else:
logger.debug(f"Using cached model: {model_name}")
def _run_rerank(self, query, passages):
request = RerankRequest(query=query, passages=passages)
return self.ranker.rerank(request)
async def on_rerank(self, queries, documents, limit, model=None):
if not queries or not documents:
return []
use_model = model or self.default_model
if self.cached_model_name != use_model:
await asyncio.to_thread(self._load_model, use_model)
passages = [
{"id": d.document_id, "text": d.document_text}
for d in documents
]
best_scores = {}
for q in queries:
ranked = await asyncio.to_thread(
self._run_rerank, q.query_text, passages,
)
for r in ranked:
doc_id = r["id"]
score = r["score"]
score = float(score)
if doc_id not in best_scores or score > best_scores[doc_id][1]:
best_scores[doc_id] = (q.query_id, score)
results = sorted(
best_scores.items(),
key=lambda x: x[1][1],
reverse=True,
)[:limit]
return [
RerankerResult(
document_id=doc_id,
query_id=query_id,
score=score,
)
for doc_id, (query_id, score) in results
]
@staticmethod
def add_args(parser):
RerankerService.add_args(parser)
parser.add_argument(
'-m', '--model',
default=default_model,
help=f'Reranker model (default: {default_model})'
)
def run():
Processor.launch(default_ident, __doc__)

Some files were not shown because too many files have changed in this diff Show more