diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 88d2b79e..dcf93946 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -22,7 +22,7 @@ jobs: uses: actions/checkout@v3 - name: Setup packages - run: make update-package-versions VERSION=2.5.999 + run: make update-package-versions VERSION=2.6.999 - name: Setup environment run: python3 -m venv env diff --git a/README.dev-install.md b/README.dev-install.md new file mode 100644 index 00000000..d57cb1f3 --- /dev/null +++ b/README.dev-install.md @@ -0,0 +1,218 @@ +# TrustGraph Developer Install Guide + +A guided installer that gets TrustGraph running locally in a single +command. It detects your hardware, recommends an LLM backend, installs +missing prerequisites, runs the test suite, generates a compose deployment, +starts the stack, and opens the Workbench UI. + +> **macOS only.** This installer has only been tested on macOS. If you are +> on Linux or Windows, use the standard docker-compose / podman-compose +> installation instructions instead. + +## Quick start + +```bash +./install_trustgraph.sh +``` + +The installer walks you through each step interactively. When it finishes, +the Workbench UI opens at `http://localhost:8888` and the API gateway is +available at `http://localhost:8088/`. + +## Prerequisites + +The installer checks for these and offers to install any that are missing +(via Homebrew): + +- **Python 3** with venv support +- **Node.js / npx** (drives the `@trustgraph/config` deployment generator) +- **Docker** (with Compose) or **Podman** (with podman-compose) +- **curl** and **unzip** +- **Ollama** (only if you choose local LLMs) + +The installer can also launch Docker Desktop or the Ollama app for you if +they are installed but not running. + +## What the installer does + +1. **Detects hardware** -- OS, architecture, CPU cores, memory, and GPU. +2. **Recommends an LLM mode** -- `ollama` for machines with >= 16 GB RAM and + a GPU or >= 8 cores; `openai` otherwise. +3. **Collects configuration** -- API key, LLM provider, model choices, + install directory. Answers are saved to + `/trustgraph-installer.env` and reused on subsequent runs. +4. **Checks and installs prerequisites** -- Python, Node/npx, Docker or + Podman, Ollama (if selected). +5. **Downloads Ollama models** (if using Ollama) -- chat model + (`granite4:350m` by default) and embeddings model (`mxbai-embed-large`). +6. **Creates a Python venv** and installs the local TrustGraph packages into + it, along with NLTK data and tiktoken caches. +7. **Runs the full pytest suite** against the local source tree. +8. **Runs `npx @trustgraph/config`** -- the existing interactive config + wizard that produces a `deploy.zip` with a compose file. +9. **Starts the compose stack** and waits for the API gateway to respond. +10. **Bootstraps IAM** and verifies the API key authenticates. +11. **Opens the Workbench UI** in your default browser. + +## Command-line options + +| Option | Description | +|---|---| +| `--install-dir PATH` | Directory for deployment files (default: `./trustgraph-deploy`) | +| `--api-url URL` | API gateway URL for health checks (default: `http://localhost:8088/`) | +| `--ui-url URL` | Workbench UI URL to open (default: `http://localhost:8888`) | +| `--use-existing-compose FILE` | Skip config generation and start this compose file directly | +| `--skip-tests` | Do not run the pytest suite | +| `--no-launch` | Do not open the Workbench UI at the end | +| `--non-interactive` | Accept all defaults without prompting | +| `--yes` | Auto-accept confirmation prompts | +| `--fresh` | Remove installer-managed files before generating a new deployment | +| `--remove-all` | Uninstall: stop containers, remove compose volumes, delete installer files | +| `--dry-run` | Print detected hardware and planned defaults, then exit | +| `-h`, `--help` | Show the built-in help text | + +## Environment variables + +These override the interactive prompts when set: + +| Variable | Purpose | +|---|---| +| `TRUSTGRAPH_TOKEN` | Admin/bootstrap API key (must start with `tg_`) | +| `TRUSTGRAPH_URL` | API gateway URL | +| `TRUSTGRAPH_UI_URL` | Workbench UI URL | +| `OPENAI_TOKEN` | OpenAI-compatible API key | +| `OPENAI_BASE_URL` | OpenAI-compatible base URL | +| `OLLAMA_HOST` / `OLLAMA_BASE_URL` | Ollama service URL | +| `OLLAMA_MODEL` | Ollama chat model (default: `granite4:350m`) | +| `OLLAMA_EMBEDDINGS_MODEL` | Ollama embeddings model (default: `mxbai-embed-large`) | +| `TG_INSTALL_DIR` | Override the install directory | +| `TG_VENV_DIR` | Override the Python venv location | +| `TG_NLTK_DATA_DIR` | Override the NLTK data directory | +| `TIKTOKEN_CACHE_DIR` | Override the tiktoken cache directory | +| `TG_HEALTH_TIMEOUT` | Seconds to wait for the API gateway (default: 240) | + +## Choosing an LLM mode + +### OpenAI (or any OpenAI-compatible provider) + +Best when you already have an API key or are running against a remote +endpoint. The installer asks for a base URL and an API key. + +```bash +OPENAI_TOKEN=sk-... ./install_trustgraph.sh +``` + +### Ollama (local models) + +Best on machines with enough RAM to run a small model. The installer detects +locally installed Ollama models and offers to pull missing ones. It uses +`host.docker.internal` so the Docker containers can reach the host-side +Ollama service. + +```bash +./install_trustgraph.sh # choose "ollama" when prompted +``` + +### None + +Start the platform without an LLM. Agent and RAG features will not work +until you configure one later through the Workbench. + +## Saved answers and re-running + +The installer saves your answers to +`/trustgraph-installer.env`. On the next run it loads those +answers as defaults, so you can re-run with a single Enter through each +prompt. + +To start completely fresh: + +```bash +./install_trustgraph.sh --fresh +``` + +This stops any running containers (keeping Docker volumes), removes +installer-managed files, and re-runs the full flow. + +## Using an existing compose file + +If you already have a compose file from the config tool or another source: + +```bash +./install_trustgraph.sh --use-existing-compose path/to/docker-compose.yaml +``` + +This skips the config wizard and `npx` prerequisite check, and goes straight +to starting the stack. + +## Non-interactive / CI usage + +```bash +TRUSTGRAPH_TOKEN=tg_my-token \ +OPENAI_TOKEN=sk-... \ +./install_trustgraph.sh --non-interactive --yes --skip-tests +``` + +In non-interactive mode the installer uses defaults for every prompt. Pair +with `--yes` to auto-accept confirmation prompts and `--skip-tests` if you +want a faster run. + +## Dry run + +Preview what the installer would do without making any changes: + +```bash +./install_trustgraph.sh --dry-run +``` + +This prints the detected hardware, recommended LLM mode, and planned +install paths, then exits. + +## Uninstalling + +```bash +./install_trustgraph.sh --remove-all +``` + +This stops containers, removes compose-managed volumes, and deletes +installer-managed files (venv, deploy output, logs, saved answers). It does +**not** remove Docker/Podman itself, container images, Ollama, or Ollama +models. + +## Troubleshooting + +### Logs + +All long-running operations write logs to `/logs/`. Key files: + +- `pytest.log` -- test suite output +- `compose-up.log` -- docker compose output +- `iam-bootstrap.log` -- IAM bootstrap output +- `ollama-pull-*.log` -- Ollama model downloads +- `pip-*.log` -- Python package installs +- `brew-install-*.log` -- Homebrew installs + +### API key rejected after reinstall + +If the API gateway returns 401/403 with your saved key, the compose volumes +likely contain IAM data from a previous install with a different key. Run: + +```bash +./install_trustgraph.sh --remove-all +./install_trustgraph.sh +``` + +This clears the old volumes and starts fresh. + +### Ollama not reachable from containers + +The Ollama base URL should use `host.docker.internal` instead of +`localhost` so that containers running in Docker Desktop can reach the +host-side Ollama service. The installer sets this automatically; if you +override `OLLAMA_HOST`, make sure the URL is reachable from inside the +container network. + +### Docker daemon not running + +The installer detects Docker Desktop and offers to start it. If that +doesn't work, start Docker Desktop manually and re-run the installer. diff --git a/README.md b/README.md index c366a3d9..69fcfb92 100644 --- a/README.md +++ b/README.md @@ -3,52 +3,97 @@ -[![PyPI version](https://img.shields.io/pypi/v/trustgraph.svg)](https://pypi.org/project/trustgraph/) [![License](https://img.shields.io/github/license/trustgraph-ai/trustgraph?color=blue)](LICENSE) ![E2E Tests](https://github.com/trustgraph-ai/trustgraph/actions/workflows/release.yaml/badge.svg) +[![PyPI version](https://img.shields.io/pypi/v/trustgraph.svg)](https://pypi.org/project/trustgraph/) ![License](https://img.shields.io/badge/license-Apache%202.0-blue) ![E2E Tests](https://github.com/trustgraph-ai/trustgraph/actions/workflows/release.yaml/badge.svg) [![Discord](https://img.shields.io/discord/1251652173201149994 )](https://discord.gg/sQMwkRz5GX) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/trustgraph-ai/trustgraph) -[**Website**](https://trustgraph.ai) | [**Docs**](https://docs.trustgraph.ai) | [**YouTube**](https://www.youtube.com/@TrustGraphAI?sub_confirmation=1) | [**Configuration Terminal**](https://config-ui.demo.trustgraph.ai/) | [**Discord**](https://discord.gg/sQMwkRz5GX) | [**Blog**](https://blog.trustgraph.ai/subscribe) +[**Website**](https://trustgraph.ai) | [**Docs**](https://docs.trustgraph.ai) | [**YouTube**](https://www.youtube.com/@TrustGraphAI?sub_confirmation=1) | [**Configuration Terminal**](https://config-ui.demo.trustgraph.ai/) | [**Discord**](https://discord.gg/yUWRkfbD) | [**Blog**](https://blog.trustgraph.ai/subscribe) trustgraph-ai%2Ftrustgraph | Trendshift -# The agent runtime platform +# Write context once. Run agents anywhere. -TrustGraph is an agent runtime platform built around context graphs — structured, queryable representations of your domain knowledge that ground every agent query in verified, explainable facts in private deployments with sovereign control. The platform is the full stack for agentic systems: context graphs, memory, retrieval, orchestration, and inference for precision-critical agent workloads. +Stop rebuilding context from scratch. TrustGraph treats context as a holon — a modular, independent whole that naturally snaps into a larger domain-wide intelligence layer. By deploying context as holonic context graphs, TrustGraph powers multi-tenant agent workflows, dramatically reduces token consumption, and aligns with semantic web standards (RDF, OWL, SKOS, SHACL). Version your context, share it across teams, and scale with full provenance. -The platform: -- [x] Multi-model and multimodal database system - - [x] Tabular/relational, key-value - - [x] Document, graph, and vectors - - [x] Images, video, and audio -- [x] Context Graph engine - - [x] Automated entity and relationship extraction - - [x] Ontology-driven graph construction - - [x] Graph-grounded retrieval for explainable outputs -- [x] Automated data ingest and loading - - [x] Quick ingest with semantic similarity retrieval - - [x] Ontology structuring for precision retrieval -- [x] Out-of-the-box RAG pipelines - - [x] DocumentRAG - - [x] GraphRAG - - [x] OntologyRAG -- [x] 3D GraphViz for exploring context -- [x] Fully Agentic System - - [x] Single or Multi Agent - - [x] ReAct, Plan-then-Execute, and Supervisor patterns - - [x] MCP integration -- [x] Run anywhere - - [x] Deploy locally with Docker - - [x] Deploy in cloud with Kubernetes -- [x] Support for all major LLMs - - [x] API support for Anthropic, Cohere, Gemini, Mistral, OpenAI, and others - - [x] Model inferencing with vLLM, Ollama, TGI, LM Studio, and Llamafiles -- [x] Developer friendly - - [x] REST API [Docs](https://docs.trustgraph.ai/reference/apis/rest.html) - - [x] Websocket API [Docs](https://docs.trustgraph.ai/reference/apis/websocket.html) - - [x] Python API [Docs](https://docs.trustgraph.ai/reference/apis/python) - - [x] CLI [Docs](https://docs.trustgraph.ai/reference/cli/) +## What TrustGraph Does + +TrustGraph is a complete holonic context harness for all LLMs. It provides the full infrastructure layer underneath your agents: knowledge ingestion, structured storage, graph-grounded retrieval, agent orchestration, and a full LLM inferencing stack. + +TrustGraph relies on absolutely no 3rd party services aside from optional API integrations to cloud-hosted LLMs. Whether you are using Anthropic's or OpenAI's API, or self-hosting Qwen3.7 via vLLM, TrustGraph handles it all with pre-built API connectors and a full LLM inferencing stack to enrich the models with a sovereign, private holonic system that grounds your agents in reality. + +## The Problem: Why Agents Break + +When you build an AI agent today, you spend most of your time fighting context: + +- **RAG retrieves fragments, not meaning**. Chunks of text have no structure. Relationships between facts are invisible. Your agent guesses at the connections. + +- **Context is disposable**. What the agent learned in one session is gone in the next. There is no persistent, structured knowledge layer underneath. + +- **Answers aren't traceable**. You can't explain why the agent said what it said, which means you can't trust it in production. + +- **Knowledge can't be reused**. You rebuild the same context pipelines for every new project, every new agent, every new environment. + +These aren't retrieval problems. They are structural problems. Context needs to be organized, versioned, and composable — exactly the way software infrastructure is. + +## The Solution: A Holonic Context System +The philosopher Arthur Koestler coined the word [holon](https://en.wikipedia.org/wiki/Holon_(philosophy)) to describe something that is simultaneously a whole in itself and a part of something larger. A fact is whole. It is also part of a domain. A domain is whole. It is also part of an organization's knowledge. + +AI agents break down because this holonic structure is never built. Context gets shoved into flat text windows, scattered across vector stores, or hardwired into one-off prompts. Facts lose their relationships. + +TrustGraph solves this by organizing your domain into holonic context graphs. Entities, relationships, and evidence are treated as first-class objects. Every agent query is grounded against these holons—marrying symbolic graph structures with vector embeddings. Every answer carries provenance. Every fact is traceable. + +## Context Cores: Knowledge as a First-Class Citizen + +A Context Core is the deployable unit of knowledge in TrustGraph. It packages everything an agent needs to reason reliably over a domain into a single, portable artifact. + +### What's inside a Context Core +- **Ontology** — your domain schema and entity mappings +- **Holon** — entities, relationships, and supporting evidence +- **Embeddings** — vector indexes for fast semantic entry-point lookup +- **Provenance** — where every fact came from, when, and how it was derived +- **Retrieval policies** — traversal rules, freshness controls, authority ranking + +Context Cores decouple what agents know from how agents are deployed. Build once. Run in Docker locally, Kubernetes in production, or on any cloud. Pin a version. Roll back. Promote across environments. This is context engineering — and it works because knowledge is finally treated like the infrastructure it is. + +## Explainability: Trust Your Agents in Production +LLMs are black boxes, and traditional RAG makes it worse. When an agent pulls flat text chunks from a vector store, you have no idea how it connected those fragments to form an answer. You cannot ship agents to production if you can't explain why they said what they said. + +### How TrustGraph makes agents explainable: + +- **Traceable Reasoning Paths**: Instead of guessing at connections between text chunks, TrustGraph traverses explicit relationship paths in the holonic context graph. You can inspect exactly which entities, relationships, and sub-graphs were pulled into the LLM's context window to generate a given response. +- **Fact-Level Provenance**: Every node and edge in the graph carries strict provenance. When an agent makes a claim, you can trace it back to the exact source document, the time it was ingested, and the extraction method used to derive it. +- **No Black-Box Guesses**: By grounding the LLM in a structured, symbolic graph, you eliminate the hallucinations that occur when models are forced to infer relationships from unstructured text. If a fact isn't in the graph, the agent doesn't use it. + +TrustGraph doesn't just give you answers - it gives you the receipt. Every fact is traceable, every connection is visible, and every output is verifiable. + +## Workspaces, Collections, and Flows + +TrustGraph has a [three-level system](https://docs.trustgraph.ai/overview/workspaces) for organizing and isolating knowledge. + +A `Workspace` is the outermost boundary — a fully isolated tenancy scope where all data, users, configuration, and pipelines live independently from every other workspace. Isolation is structural: enforced at the pub/sub queue, storage, and API gateway layers, not by trusting a field in a message body. + +Within a workspace, a `Collection` groups related holons, graph structures, embeddings, and documents together — think of it as a dedicated shelf in a library, scoped to a specific domain, project, or customer. + +A `Flow` is a running data processing pipeline that defines how raw data moves through ingestion, extraction, structuring, and storage — the assembly line that turns documents into queryable knowledge. Together, the three layers let you run multiple isolated tenants on a single deployment, separate knowledge by domain within each tenant, and process that knowledge through fully configurable pipelines — all without restarting the system or rebuilding your infrastructure. + +## The Full Stack +TrustGraph is not a wrapper around a graph database. It is the complete backend for production agentic systems. + +- **Holonic context graph engine**: automated entity and relationship extraction, ontology-driven graph construction, graph-grounded retrieval for explainable outputs +- **Multi-model database**: tabular/relational, key-value, document, graph, vectors, images, video, and audio — all managed in Cassandra and S3-compatible Garage +- **Out-of-the-box RAG pipelines**: DocumentRAG, GraphRAG, and OntologyRAG ready to deploy +- **Fully agentic orchestration**: single or multi-agent, ReAct, Plan-then-Execute, Supervisor patterns, and MCP integration +- **3D Knowledge Explorer**: interactive graph visualization with BFS neighborhood extraction and edge pulse animation +- **Automated data ingest**: quick ingest with semantic similarity or ontology-structured precision retrieval +- **Run anywhere**: Docker/Podman locally, Kubernetes in the cloud + +All major LLMs — Anthropic, Cohere, Gemini, Mistral, OpenAI, and more via API. + +vLLM, Ollama, TGI, LM Studio, and Llamafiles for fully local inferencing. + +Verified cloud deployments for Alibaba Cloud, AWS, Azure, GCP, OVHcloud, and Scaleway. ## No API Keys Required @@ -62,12 +107,12 @@ Everything else is included. - [x] Managed Multi-model storage in [Cassandra](https://cassandra.apache.org/_/index.html) - [x] Managed Vector embedding storage in [Qdrant](https://github.com/qdrant/qdrant) - [x] Managed File and Object storage in [Garage](https://github.com/deuxfleurs-org/garage) (S3 compatible) -- [x] Managed High-speed Pub/Sub messaging fabric with [Pulsar](https://github.com/apache/pulsar) +- [x] Managed High-speed Pub/Sub messaging fabric with [Pulsar](https://github.com/apache/pulsar) or [RabbitMQ](https://www.rabbitmq.com/) - [x] Complete LLM inferencing stack for open LLMs with [vLLM](https://github.com/vllm-project/vllm), [TGI](https://github.com/huggingface/text-generation-inference), [Ollama](https://github.com/ollama/ollama), [LM Studio](https://github.com/lmstudio-ai), and [Llamafiles](https://github.com/mozilla-ai/llamafile) ## Quickstart -There's no need to clone this repo, unless you want to build from source. TrustGraph is a fully containerized app that deploys as a set of Docker containers. To configure TrustGraph on the command line: +No need to clone the repo unless you are building from source. TrustGraph deploys as a set of Docker containers. Configure it on the command line in one step: ``` npx @trustgraph/config @@ -78,44 +123,39 @@ The config process will generate an app config that can be run locally with Dock - Deployment instructions as `INSTALLATION.md`

-

For a browser based configuration, try the [Configuration Terminal](https://config-ui.demo.trustgraph.ai/). -## Watch What is a Context Graph? +## Watch What is a Holonic Context Graph? [![What is a Context Graph?](https://img.youtube.com/vi/gZjlt5WcWB4/maxresdefault.jpg)](https://www.youtube.com/watch?v=gZjlt5WcWB4) -## Watch Context Graphs in Action +## Watch Holonic Context Graphs in Action [![Context Graphs in Action with TrustGraph](https://img.youtube.com/vi/sWc7mkhITIo/maxresdefault.jpg)](https://www.youtube.com/watch?v=sWc7mkhITIo) ## Getting Started with TrustGraph - [**Getting Started Guides**](https://docs.trustgraph.ai/getting-started) -- [**Using the Workbench**](#workbench) - [**Developer APIs and CLI**](https://docs.trustgraph.ai/reference) - [**Deployment Guides**](https://docs.trustgraph.ai/deployment) -## Workbench +## TrustGraph UI -The **Workbench** provides tools for all major features of TrustGraph. The **Workbench** is on port `8888` by default. +Image -- **Vector Search**: Search the installed knowledge bases -- **Agentic, GraphRAG and LLM Chat**: Chat interface for agents, GraphRAG queries, or direct to LLMs -- **Relationships**: Analyze deep relationships in the installed knowledge bases -- **Graph Visualizer**: 3D GraphViz of the installed knowledge bases -- **Library**: Staging area for installing knowledge bases -- **Flow Classes**: Workflow preset configurations -- **Flows**: Create custom workflows and adjust LLM parameters during runtime -- **Knowledge Cores**: Manage resuable knowledge bases -- **Prompts**: Manage and adjust prompts during runtime -- **Schemas**: Define custom schemas for structured data knowledge bases -- **Ontologies**: Define custom ontologies for unstructured data knowledge bases -- **Agent Tools**: Define tools with collections, knowledge cores, MCP connections, and tool groups -- **MCP Tools**: Connect to MCP servers +The UI provides tools for all major features of TrustGraph. The UI deploys on port `8888` by default. + +- **Agent Console** — Query your agents directly with streaming responses and live explainability event tracking, so you can watch reasoning unfold in real time +- **GraphRAG View** — Interactive graph RAG queries with a visual explainability DAG and inline provenance display, making it easy to see exactly where answers came from +- **Context Explorer** — An interactive 3D context graph explorer with dynamic graph loading, BFS neighborhood extraction, edge pulse animation, and multiple navigation views +- **Document Ingestion** — A complete upload and submission workflow with page and chunk inspection and document structure browsing +- **Ontology Workbench** — A full ontology editor with class and property trees, OWL/XML and Turtle import/export with round-trip fidelity, circular dependency detection, and safe-delete confirmation dialogs +- **Schema Workbench** — Interactive schema management with list, create, edit, and delete operations including field and index management +- **Prompt Editor** — A dedicated prompt editing workflow ## TypeScript Library for UIs @@ -125,134 +165,6 @@ There are 3 libraries for quick UI integration of TrustGraph services. - [@trustgraph/react-state](https://www.npmjs.com/package/@trustgraph/react-state) - [@trustgraph/react-provider](https://www.npmjs.com/package/@trustgraph/react-provider) -## Context Cores - -Context Cores are how TrustGraph treats context like code. A Context Core is a **portable, versioned bundle of context** that you can ship between projects and environments, pin in production, and reuse across agents. It packages the “stuff agents need to know” (structured knowledge + embeddings + evidence + policies) into a single artifact, so you can treat context like code: build it, test it, version it, promote it, and roll it back. TrustGraph is built to support this kind of end-to-end context engineering and orchestration workflow. - -### What’s inside a Context Core -A Context Core typically includes: -- Ontology (your domain schema) and mappings -- Context Graph (entities, relationships, supporting evidence) -- Embeddings / vector indexes for fast semantic entry-point lookup -- Source manifests + provenance (where facts came from, when, and how they were derived) -- Retrieval policies (traversal rules, freshness, authority ranking) - -## Tech Stack -TrustGraph provides component flexibility to optimize agent workflows. - -
-LLM APIs -
- -- Anthropic
-- AWS Bedrock
-- AzureAI
-- AzureOpenAI
-- Cohere
-- Google AI Studio
-- Google VertexAI
-- Mistral
-- OpenAI
- -
-
-LLM Orchestration -
- -- LM Studio
-- Llamafiles
-- Ollama
-- TGI
-- vLLM
- -
-
-Multi-model storage -
- -- Apache Cassandra
- -
-
-VectorDB -
- -- Qdrant
- -
-
-File and Object Storage -
- -- Garage
- -
-
-Observability -
- -- Prometheus
-- Grafana
-- Loki
- -
-
-Data Streaming -
- -- Apache Pulsar
-- RabbitMQ
-- Apache Kafka
- -
-
-Clouds -
- -- AWS
-- Azure
-- Google Cloud
-- OVHcloud
-- Scaleway
- -
- -## Observability & Telemetry - -Once the platform is running, access the Grafana dashboard at: - -``` -http://localhost:3000 -``` - -Default credentials are: - -``` -user: admin -password: admin -``` - -The default Grafana dashboard tracks the following: - -
-Telemetry -
- -- LLM Latency
-- Error Rate
-- Service Request Rates
-- Queue Backlogs
-- Chunking Histogram
-- Error Source by Service
-- Rate Limit Events
-- CPU usage by Service
-- Memory usage by Service
-- Models Deployed
-- Token Throughput (Tokens/second)
-- Cost Throughput (Cost/second)
- -
- ## Contributing [Developer's Guide](https://docs.trustgraph.ai/guides/building/introduction.html) @@ -261,7 +173,7 @@ The default Grafana dashboard tracks the following: **TrustGraph** is licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0). - Copyright 2024-2025 TrustGraph + Copyright 2024-2026 TrustGraph Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/docs/tech-specs/capabilities.md b/docs/tech-specs/capabilities.md index ba27c738..7717cbc9 100644 --- a/docs/tech-specs/capabilities.md +++ b/docs/tech-specs/capabilities.md @@ -100,7 +100,6 @@ multi-word subsystems. | `users:admin` | Assign / remove roles on users within the workspace | | `keys:self` | Create / revoke / list **own** API keys | | `keys:admin` | Create / revoke / list **any user's** API keys within the workspace | -| `workspaces:list-own` | List workspaces the caller has access to | | `workspaces:admin` | Create / delete / disable workspaces (system-level) | | `iam:admin` | JWT signing-key rotation, IAM-level operations | | `metrics:read` | Prometheus metrics proxy | @@ -111,7 +110,7 @@ The open-source edition ships three roles: | Role | Capabilities | |---|---| -| `reader` | `agent`, `graph:read`, `documents:read`, `rows:read`, `llm`, `embeddings`, `mcp`, `collections:read`, `knowledge:read`, `flows:read`, `config:read`, `keys:self`, `workspaces:list-own` | +| `reader` | `agent`, `graph:read`, `documents:read`, `rows:read`, `llm`, `embeddings`, `mcp`, `collections:read`, `knowledge:read`, `flows:read`, `config:read`, `keys:self` | | `writer` | everything in `reader` **+** `graph:write`, `documents:write`, `rows:write`, `collections:write`, `knowledge:write` | | `admin` | everything in `writer` **+** `config:write`, `flows:write`, `users:read`, `users:write`, `users:admin`, `keys:admin`, `workspaces:admin`, `iam:admin`, `metrics:read` | diff --git a/docs/tech-specs/graph-rag-semantic-filter.md b/docs/tech-specs/graph-rag-semantic-filter.md new file mode 100644 index 00000000..58497d10 --- /dev/null +++ b/docs/tech-specs/graph-rag-semantic-filter.md @@ -0,0 +1,541 @@ +# GraphRAG Semantic Filter Improvement + +## Problem Statement + +The GraphRAG semantic filter is observed to be ineffective with certain +LLM models. Smaller models in particular produce poor-quality edge +relevance scores, and there is a suspicion that models trained or +evaluated heavily on non-Roman-script datasets offer lower performance +on the semantic ranking operation. + +The root cause is that the current implementation delegates edge +relevance scoring to the LLM via a prompt that asks the model to +assign a 1–10 relevance score to each knowledge-graph edge. This +task — ranking structured triples for relevance to a natural-language +query — is not well covered in standard LLM evaluation suites, so +model benchmark scores are not predictive of performance on this +operation. The result is that GraphRAG quality varies unpredictably +across model choices, undermining confidence in the pipeline. + +Beyond model variability, the LLM scoring step has further problems: + +- **Cost and latency.** The LLM call consumes tokens and adds + latency to every query, yet its output is unreliable. Even when + the model performs well, the cost is disproportionate for what is + fundamentally a ranking operation. + +- **Subjective scoring scale.** The 1–10 relevance scale gives the + model no objective criteria for what constitutes a 5 versus a 7. + Different models interpret the scale differently, and even the same + model can produce inconsistent scores across runs. + +- **Redundancy with the embedding pre-filter.** The pipeline already + contains a cosine-similarity stage that ranks edges by semantic + relevance using embeddings. The LLM scoring step is a second + filter applied on top of this, and it is not clear that it adds + enough value to justify the additional cost and risk of + degradation. + +### Industry context + +Semantic ranking is rigorously evaluated on dedicated benchmarks such +as MTEB (Massive Text Embedding Benchmark) and BEIR (Benchmarking +Information Retrieval), which test retrieval and reranking across +diverse domains. The current TrustGraph approach — prompting a +general-purpose LLM to score and rank documents (the "listwise" +approach) — is known to be poorly optimized for this task. It +suffers from positional bias, formatting failures, and +inconsistency at scale. + +The industry standard for semantic ranking has moved to +cross-encoder models: lightweight, purpose-built models that take a +query–document pair as input and produce a single relevance score. +These models are fine-tuned on millions of relevance-labelled pairs +and dominate retrieval benchmarks. They are fast, deterministic, +and do not require an LLM inference call. + +## Architecture + +### Cross-encoder service + +A new request/response service that exposes a generic semantic +ranking API. The service is not specific to GraphRAG — it is a +reusable building block for any component that needs to rank text +by relevance. + +The service interface is pluggable. Alternative implementations +can be swapped in behind the same API. + +**Packaging options considered:** + +- *`sentence-transformers`.* Full-featured, widely used. + However, it pulls in PyTorch (~2 GB), making containers + very large. Tested at ~1.8 seconds for 2200 edges. + +- *`optimum.onnxruntime`.* ONNX-based inference. Still + depends on PyTorch at import time despite using ONNX for + inference. Tested at ~4.2 seconds for 2200 edges. + +- *`flashrank`.* Lightweight wrapper around ONNX Runtime + with a clean API (`Ranker`, `RerankRequest`). No PyTorch + dependency. Tested at ~4.4 seconds for 2200 edges. + +- *Pure `onnxruntime` + `tokenizers`.* Leanest option + (~200 MB total). Requires manual tokenisation, padding, + and numpy array management — more boilerplate to maintain. + +- *External API (e.g. Cohere Rerank).* No local model at + all. Adds network latency and an external dependency. + +**Decision:** `flashrank` for the initial implementation. +No PyTorch dependency, clean API, comparable performance. +The pluggable interface allows swapping to another backend +later. + +**Request:** + +- `queries` — list of `{id, text}` objects. In the GraphRAG use + case these are the concepts extracted from the user's question. +- `documents` — list of `{id, text}` objects. In the GraphRAG + use case these are the candidate knowledge-graph edges + represented as text. +- `limit` — integer. Maximum number of results to return. + +**Scoring:** + +The service produces the cartesian product of all query–document +pairs and scores each pair through the cross-encoder model. For +each document, the maximum score across all queries is taken as the +document's relevance score. Documents are then ranked by this +score and the top `limit` results are returned. + +**Response:** + +A list of the top `limit` results, each containing: + +- `document_id` — the ID of the matched document. +- `query_id` — the ID of the query (concept) that produced the + highest score for this document. +- `score` — the relevance score. + +Including `query_id` in the response supports the explainability +interface: it records that an edge was selected because it is +related to a specific concept. + +### Integration + +The cross-encoder service follows the standard TrustGraph service +integration pattern: + +- **Base package (trustgraph-base).** Schema definitions for the + cross-encoder request/response messages. A client class that + other components (e.g. GraphRAG) can use to call the + cross-encoder service. Message translator registration so the + pub/sub layer can serialise/deserialise the messages. + +- **Flow package (trustgraph-flow).** The cross-encoder service + implementation itself — loads the model, listens for requests, + scores pairs, returns results. Flow definition support so the + cross-encoder can be introduced into a processing flow via the + standard flow configuration. `flashrank` is added as a + dependency of `trustgraph-flow`. The service runs in its own + container. + +- **API gateway.** A gateway endpoint that routes cross-encoder + requests from the HTTP API to the service over pub/sub and + returns the response. + +- **CLI tool.** A command-line utility + (e.g. `tg-invoke-cross-encoder`) that calls the gateway + endpoint for manual testing and debugging. + +### Current GraphRAG pipeline + +The current pipeline follows these steps: + +1. **Concept extraction.** An LLM prompt extracts key concepts + from the user's query. + +2. **Graph exploration.** Seed entities are found via embedding + similarity. A subgraph is built by multi-hop traversal from + the seed entities (up to `max_path_length` hops, capped at + `max_subgraph_size` edges). + +3. **Embedding pre-filter.** Each edge is embedded as + `"subject, predicate, object"` and scored by cosine similarity + against the concept embeddings. The top `edge_score_limit` + (default 30) edges are kept. + +4. **LLM edge scoring.** The `kg-edge-scoring` prompt asks the + LLM to assign a 1–10 relevance score to each remaining edge. + The top `edge_limit` (default 25) edges are kept. + +5. **LLM edge reasoning.** The `kg-edge-reasoning` prompt asks + the LLM to explain why each selected edge is relevant to the + query. Used for the explainability interface. + +6. **Document tracing.** Selected edges are traced back to their + source documents in the librarian. Runs concurrently with + step 5. + +7. **Synthesis.** The `kg-synthesis` prompt generates the final + answer from the selected edges and source document metadata. + +### Potential improvements + +#### Replace LLM edge scoring with cross-encoder (step 4) + +The LLM edge scoring step is replaced by a call to the +cross-encoder service. The candidate edges are the documents and +`edge_limit` is the limit. This is a direct substitution: faster, +cheaper, deterministic, and more reliable across model choices. +The LLM `kg-edge-scoring` prompt is retired. + +**Cross-encoder query input: concepts vs. raw query.** There are +two options for what to use as the cross-encoder queries: + +- *Option A: Raw user query.* Pass the original question as a + single query string. Simpler, no dependency on concept + extraction. However, raw queries contain noise words and + conversational phrasing that do not match well against the + structured vocabulary of knowledge-graph edges. A single query + also means every edge competes against the full question — a + partial match on one aspect is diluted. + +- *Option B: Extracted concepts.* Pass the concepts from step 1 + as separate queries. The concepts are distilled, focused terms + that are closer to the language of the edges. With multiple + concepts as independent queries, the cross-encoder scores each + edge against each concept separately, giving better coverage — + an edge only needs to match one concept well to be selected. + The trade-off is a dependency on the LLM concept extraction + step, but this is already in the pipeline and is a lightweight, + reliable LLM call. + +**Decision:** Option B — use extracted concepts. The concept +extraction is fast, and the resulting terms produce better +cross-encoder matches against structured triples. + +#### Edge text representation + +The current embedding pre-filter represents each edge as +`"subject, predicate, object"`. Two changes: + +- **Drop commas.** Commas add tokenisation noise without semantic + value. + +- **Direction-aware text.** The reranker text should highlight + the *new* information relative to the traversal direction. + The frontier entity is already known context — repeating it + adds noise and, when traversing from an object node, causes + many edges to produce identical reranker text (e.g. 18 + products sharing the same `hasSubcategory Processors` triple + all collapse to the same string when the subject is dropped). + + The text is constructed based on which position the frontier + entity occupied in the triple: + + - **From subject** (s=entity): `"{predicate} {object}"` — + the subject is known, predicate and object are new. + - **From object** (o=entity): `"{subject} {predicate}"` — + the object is known, subject and predicate are new. + - **From predicate** (p=entity): `"{subject} {object}"` — + the predicate is known, subject and object are new. + + This eliminates the duplicate-text problem that arises when + traversing inward from a shared object node, and gives the + cross-encoder a more informative signal at every hop. + +#### Remove the embedding pre-filter (step 3) + +The embedding pre-filter was introduced to reduce the number of +edges before the expensive LLM scoring call. With the +cross-encoder replacing the LLM call, this cost equation changes. + +**Arguments for removal:** + +- The cross-encoder is fast enough to score the full subgraph + directly. In testing, 2200 edges scored in ~1.8 seconds; at + the default `max_subgraph_size` of 150 edges, scoring takes + a fraction of a second. + +- The pre-filter is a weaker version of what the cross-encoder + does. Bi-encoder cosine similarity embeds the query and + document independently and compares vectors; the cross-encoder + processes both texts together through the full transformer, + giving it much better relevance judgement. Running a weaker + filter before a stronger one adds latency without improving + quality. + +- Removing it eliminates an embedding service call (two batches: + concepts + edges) and the associated latency. + +**Arguments for keeping it:** + +- If the subgraph is very large (thousands of edges), the + cross-encoder's linear scaling could become a bottleneck. + The pre-filter would act as a safety valve. + +- The embedding call is cheap compared to an LLM call, so the + overhead is modest. + +**Decision:** Remove the pre-filter. The `max_subgraph_size` +parameter (default 150) already caps the number of edges entering +this stage, so the cross-encoder will not face an unbounded +workload. If very large subgraphs become a concern in future, +the pre-filter can be reintroduced or `max_subgraph_size` can be +tuned. + +#### Iterative graph traversal with cross-encoder filtering + +The current pipeline performs graph exploration and edge filtering +as separate phases: first build the full subgraph (up to +`max_path_length` hops), then score and filter edges. An +alternative is to interleave traversal and filtering — at each +hop, use the cross-encoder to select relevant edges before +expanding further. + +**Option A: Big-bang traversal then filter.** Traverse the full +subgraph up to `max_path_length` hops from the seed entities, +collecting all edges up to `max_subgraph_size`. Then +cross-encode the entire result to select the top edges. + +- Simple to implement — the current traversal logic is largely + unchanged. +- Produces large, unfocused subgraphs. Irrelevant branches are + explored and scored even though they will be discarded. +- Poorly suited to multi-hop reasoning. For a query about + Voyager 1, the subgraph includes Voyager 2's edges because + they are within hop distance, and the filter must then + separate them. + +**Option B: Iterative hop-and-filter.** At each hop: + +1. Retrieve all edges one hop from the current frontier nodes. +2. Cross-encode these edges against the query concepts. +3. Select the top relevant edges. +4. The target nodes of the selected edges become the frontier + for the next hop. +5. Repeat up to `max_path_length` hops. + +The final set of selected edges across all hops is the input to +synthesis. + +- **Guided exploration.** Each hop focuses the search by + pruning irrelevant branches before expanding further. The + working set stays small and relevant at every step. +- **Multi-hop reasoning works naturally.** Following + "Voyager 1 → has-event → crossed the heliopause" succeeds + because each hop is individually relevant and leads to the + next. +- **Smaller total workload.** Fewer edges are scored overall + because irrelevant branches are never expanded. +- **Trade-off: greedy pruning.** An edge discarded at hop 1 + cannot lead to relevant edges at hop 2. This is inherent in + any bounded traversal, and the cross-encoder is better + equipped to make this relevance judgement than a blind hop + limit. +- **Trade-off: sequential latency.** Hops cannot be + parallelised since each depends on the previous. However, + each cross-encoder call on a small edge set is very fast + (sub-second for typical working sets). + +**Decision:** Option B — iterative hop-and-filter. The guided +traversal produces more focused subgraphs and supports multi-hop +reasoning, which is a significant quality improvement over the +current approach. + +#### Replace LLM edge reasoning with cross-encoder metadata (step 5) + +The current `kg-edge-reasoning` prompt asks the LLM to explain why +each edge is relevant. With the cross-encoder now making the +selection, this explanation would be a post-hoc fabrication — the +LLM was not involved in the decision. + +- *Option A: Keep LLM reasoning.* Generates natural-language + explanations but they are not grounded in the actual selection + process. Adds an LLM call per query. + +- *Option B: Record cross-encoder metadata.* The cross-encoder + already returns the matched concept and score for each selected + edge. Use this directly as the explanation. + +**Decision:** Option B. The cross-encoder metadata is the true +reason the edge was selected. The `kg-edge-reasoning` prompt is +retired. + +#### Explainability interface update + +The explainability interface uses a `Focus` entity containing +`EdgeSelection` sub-entities. Each `EdgeSelection` currently +carries an `edge` (the quoted triple) and a `reasoning` field +(free-text LLM prose), stored as `tg:reasoning` in the +provenance graph. + +With the cross-encoder replacing LLM reasoning, the +`EdgeSelection` type gains two new predicates and drops one: + +- **Remove** `tg:reasoning` — no longer produced. +- **Add** `tg:concept` — the concept text that produced the + highest cross-encoder score for this edge. +- **Add** `tg:score` — the cross-encoder relevance score. + +This is an evolution of the existing `EdgeSelection` type, not a +new entity type. The edge selection sub-entities currently have +no `rdf:type` declared; a new `tg:EdgeSelection` type should be +added so that consumers can identify them in the provenance +graph. The `Focus` entity and its relationship to `Exploration` +are unchanged. + +The `Focus` entity's token-usage metadata (`tg:inToken`, +`tg:outToken`, `tg:llmModel`) no longer applies since there is +no LLM call. These fields are dropped from the Focus entity. + +### Proposed pipeline + +1. **Concept extraction.** Unchanged — LLM extracts key concepts + from the user's query. + +2. **Seed entity lookup.** Find seed entities via embedding + similarity against the extracted concepts. + +3. **Iterative hop-and-filter.** For each hop up to + `max_path_length`: + + a. Retrieve all edges one hop from the current frontier nodes. + + b. Represent each edge using direction-aware text: from a + subject node use `"{predicate} {object}"`, from an object + node use `"{subject} {predicate}"`, from a predicate node + use `"{subject} {object}"`. + + c. Score edges against the extracted concepts using the + cross-encoder service. + + d. Select the top relevant edges. The target nodes of the + selected edges become the frontier for the next hop. + +4. **Document tracing.** Selected edges are traced back to source + documents. + +5. **Synthesis.** The `kg-synthesis` prompt generates the final + answer from the selected edges and source document metadata. + +### Implementation order + +1. Cross-encoder service with full integration (base schema, + flow service, gateway endpoint, CLI tool). +2. GraphRAG pipeline changes (iterative hop-and-filter, + edge representation, remove pre-filter). +3. Explainability update (`tg:EdgeSelection` type, concept + and score predicates, retire `tg:reasoning`). +4. Retire `kg-edge-scoring` and `kg-edge-reasoning` prompts. +5. Update `tg-invoke-graph-rag` and `tg-show-explain-trace` + to display the new metadata. Use these as the main + end-to-end test. +6. Fix any failing unit tests, then add new tests as needed. +7. Write guidance for UX devs to update the UI for the new + explainability predicates. + +## UX developer guidance + +This section describes the changes to the explainability interface +that affect frontend rendering of GraphRAG Focus events. + +### What changed + +Edge selection in GraphRAG previously used LLM-based scoring and +reasoning. Each selected edge carried a `tg:reasoning` predicate +with free-text explanation from the LLM. This has been replaced +by a cross-encoder reranker that scores edges against query +concepts. The explainability data now carries structured metadata +instead of free text. + +### Removed + +- **`tg:reasoning`** is no longer emitted on edge selection + entities in GraphRAG Focus events. UX code that reads + `edge_sel.reasoning` will get an empty string. Remove any + rendering that displays a "Reasoning" or "Reason" field for + Focus edges. + +- The **`kg-edge-scoring`**, **`kg-edge-reasoning`**, and + **`kg-edge-selection`** prompts are retired. Any UX that + references these prompt names should be cleaned up. + +### Added + +Each edge selection entity within a Focus event now has three +new properties: + +| RDF predicate | API field | Type | Description | +|---|---|---|---| +| `rdf:type tg:EdgeSelection` | (type check) | — | Each edge selection entity is now explicitly typed | +| `tg:concept` | `edge_sel.concept` | `str` | The query concept that matched this edge | +| `tg:score` | `edge_sel.score` | `float` or `None` | Cross-encoder relevance score (0.0–1.0) | + +The `tg:edge` predicate (RDF-star quoted triple) is unchanged. + +### How to render + +The recommended rendering for each selected edge in a Focus event: + +``` +Edge: (subject_label, predicate_label, object_label) + Concept: Score: +``` + +Scores near 1.0 indicate high relevance; scores near 0.0 indicate +low relevance. UX could use the score to drive visual indicators +such as colour intensity or a relevance bar. + +Edges are not returned in score order — they arrive in traversal +order across hops. If the UX wants to display edges ranked by +relevance, sort by `edge_sel.score` descending. + +### API classes (Python) + +The `EdgeSelection` dataclass in `trustgraph.api.explainability` +has these fields: + +```python +@dataclass +class EdgeSelection: + uri: str + edge: Optional[Dict[str, str]] # {"s": ..., "p": ..., "o": ...} + reasoning: str = "" # Legacy, always empty for new traces + concept: str = "" # Query concept that matched + score: Optional[float] = None # Cross-encoder relevance score +``` + +These are populated when calling +`ExplainabilityClient.fetch_focus_with_edges()` or when parsing +inline provenance triples from the streaming response. + +### WebSocket response format + +For inline explainability via the streaming WebSocket, Focus events +arrive as `message_type: "explain"` responses. The `explain_triples` +array contains the edge selection triples. The relevant predicates +in wire format are: + +```json +{"s": {"t": "i", "i": ""}, + "p": {"t": "i", "i": "https://trustgraph.ai/ns/concept"}, + "o": {"t": "l", "v": "flyby event"}} + +{"s": {"t": "i", "i": ""}, + "p": {"t": "i", "i": "https://trustgraph.ai/ns/score"}, + "o": {"t": "l", "v": "0.9962"}} +``` + +Note that `tg:score` is transmitted as a string literal and must +be parsed to a float on the client side. + +### Exploration event + +The Exploration event's `edge_count` field now reports the number +of edges selected by the cross-encoder across all hops (previously +it reported the total number of edges retrieved before filtering). +The `entities` list continues to report the seed entities found +by vector search. diff --git a/install_trustgraph.sh b/install_trustgraph.sh new file mode 100644 index 00000000..b3919791 --- /dev/null +++ b/install_trustgraph.sh @@ -0,0 +1,2603 @@ +#!/usr/bin/env bash + +set -Eeuo pipefail + +APP_NAME="TrustGraph" +DEFAULT_API_URL="http://localhost:8088/" +DEFAULT_UI_URL="http://localhost:8888" +DEFAULT_INSTALL_DIR="trustgraph-deploy" +DEFAULT_OLLAMA_MODEL="granite4:350m" +DEFAULT_OLLAMA_EMBEDDINGS_MODEL="mxbai-embed-large" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +INSTALL_DIR="${TG_INSTALL_DIR:-$SCRIPT_DIR/$DEFAULT_INSTALL_DIR}" +VENV_DIR="${TG_VENV_DIR:-$INSTALL_DIR/.venv}" +NLTK_DATA_DIR="${TG_NLTK_DATA_DIR:-$INSTALL_DIR/nltk_data}" +TIKTOKEN_CACHE_DIR_VALUE="${TIKTOKEN_CACHE_DIR:-$INSTALL_DIR/tiktoken_cache}" +PYTHON_BIN="python3" +API_URL="${TRUSTGRAPH_URL:-$DEFAULT_API_URL}" +UI_URL="${TRUSTGRAPH_UI_URL:-$DEFAULT_UI_URL}" + +RUN_TESTS=1 +AUTO_LAUNCH=1 +NON_INTERACTIVE=0 +DRY_RUN=0 +YES=0 +FRESH_INSTALL=0 +REMOVE_ALL=0 +USE_EXISTING_COMPOSE="" +HEALTH_TIMEOUT="${TG_HEALTH_TIMEOUT:-240}" +AUTH_CHECK_TIMEOUT="${TG_AUTH_CHECK_TIMEOUT:-45}" + +AUTH_TOKEN="${TRUSTGRAPH_TOKEN:-}" +LLM_MODE="" +OPENAI_TOKEN_VALUE="${OPENAI_TOKEN:-}" +OPENAI_BASE_URL_VALUE="${OPENAI_BASE_URL:-https://api.openai.com/v1}" +OLLAMA_BASE_URL_VALUE="${OLLAMA_HOST:-${OLLAMA_BASE_URL:-}}" +OLLAMA_MODEL="${OLLAMA_MODEL:-$DEFAULT_OLLAMA_MODEL}" +OLLAMA_EMBEDDINGS_MODEL="${OLLAMA_EMBEDDINGS_MODEL:-$DEFAULT_OLLAMA_EMBEDDINGS_MODEL}" + +HW_OS="" +HW_ARCH="" +HW_CPU_CORES="unknown" +HW_MEMORY_GB="unknown" +HW_GPU="none detected" +HW_CONTAINER_HINT="" +RECOMMENDED_LLM_MODE="openai" +RECOMMENDATION_REASON="" +COMPOSE_CMD=() +COLOR_RESET="" +COLOR_HEADING="" +COLOR_INFO="" +COLOR_WARN="" +COLOR_ERROR="" +COLOR_ACCENT="" + +usage() { + cat <<'USAGE' +Usage: ./install_trustgraph.sh [options] + +Guided local installer for TrustGraph. It detects the machine hardware, +recommends a local or hosted LLM path, asks for the few required values, +enumerates local Ollama models when relevant, runs the repo tests, generates +a deployment with the existing config tool, starts the stack, checks health, +and opens the Workbench UI. + +Options: + --install-dir PATH Directory for generated deployment files. + --api-url URL API gateway URL for health checks. + --ui-url URL Workbench UI URL to open. + --use-existing-compose F Skip config generation and start this compose file. + --skip-tests Do not run the full pytest suite. + --no-launch Do not open the Workbench UI at the end. + --non-interactive Use defaults where possible. Best with --dry-run or + --use-existing-compose. + --yes Accept confirmation prompts. + --fresh Remove installer-managed files in --install-dir + before generating a new deployment. + --remove-all Uninstall the installer-managed deployment: + stop containers, remove compose volumes, and + delete only installer-managed files. + --dry-run Show detected hardware and planned defaults only. + -h, --help Show this help. + +Environment defaults: + TRUSTGRAPH_TOKEN, TRUSTGRAPH_URL, OPENAI_TOKEN, OPENAI_BASE_URL, + OLLAMA_HOST, OLLAMA_BASE_URL, OLLAMA_MODEL, OLLAMA_EMBEDDINGS_MODEL, + TG_INSTALL_DIR, TG_VENV_DIR, TG_NLTK_DATA_DIR, TIKTOKEN_CACHE_DIR, + TG_HEALTH_TIMEOUT +USAGE +} + +say() { + printf '\n%b%s%b\n' "$COLOR_HEADING" "$*" "$COLOR_RESET" +} + +info() { + printf ' %b%s%b\n' "$COLOR_INFO" "$*" "$COLOR_RESET" +} + +warn() { + printf '%bWarning:%b %s\n' "$COLOR_WARN" "$COLOR_RESET" "$*" >&2 +} + +die() { + printf '%bError:%b %s\n' "$COLOR_ERROR" "$COLOR_RESET" "$*" >&2 + exit 1 +} + +command_exists() { + command -v "$1" >/dev/null 2>&1 +} + +spinner_enabled() { + [[ "${TG_NO_SPINNER:-0}" != "1" ]] && { [[ -t 2 ]] || [[ "${TG_FORCE_SPINNER:-0}" == "1" ]]; } +} + +clear_spinner_line() { + printf '\r\033[K' >&2 +} + +run_with_spinner() { + local message="$1" + shift + local frames=('|' '/' '-' '\') + local frame=0 + local pid + local status + + if ! spinner_enabled; then + "$@" + return + fi + + "$@" & + pid=$! + while kill -0 "$pid" 2>/dev/null; do + printf '\r %b%s%b %s' "$COLOR_ACCENT" "${frames[$frame]}" "$COLOR_RESET" "$message" >&2 + frame=$(((frame + 1) % ${#frames[@]})) + sleep 0.2 + done + + if wait "$pid"; then + status=0 + else + status=$? + fi + + clear_spinner_line + if [[ "$status" -eq 0 ]]; then + info "Done: $message" + else + warn "Failed: $message" + fi + return "$status" +} + +run_with_spinner_logged() { + local message="$1" + local log_file="$2" + shift 2 + local frames=('|' '/' '-' '\') + local frame=0 + local pid + local status + + if ! spinner_enabled; then + "$@" + return + fi + + mkdir -p "$(dirname "$log_file")" + "$@" >"$log_file" 2>&1 & + pid=$! + while kill -0 "$pid" 2>/dev/null; do + printf '\r %b%s%b %s' "$COLOR_ACCENT" "${frames[$frame]}" "$COLOR_RESET" "$message" >&2 + frame=$(((frame + 1) % ${#frames[@]})) + sleep 0.2 + done + + if wait "$pid"; then + status=0 + else + status=$? + fi + + clear_spinner_line + if [[ "$status" -eq 0 ]]; then + info "Done: $message" + else + warn "Failed: $message" + warn "Last log lines from $log_file:" + tail -n 40 "$log_file" >&2 || true + fi + return "$status" +} + +installer_log_file() { + local name="$1" + mkdir -p "$INSTALL_DIR/logs" + printf '%s/logs/%s.log\n' "$INSTALL_DIR" "$name" +} + +command_to_text() { + local arg + local out="" + + for arg in "$@"; do + if [[ -n "$out" ]]; then + out="$out " + fi + out="$out$(printf '%q' "$arg")" + done + + printf '%s\n' "$out" +} + +root_command_to_text() { + if [[ "${EUID:-$(id -u)}" -eq 0 ]]; then + command_to_text "$@" + elif command_exists sudo; then + command_to_text sudo "$@" + else + command_to_text "$@" + fi +} + +run_root_command() { + if [[ "${EUID:-$(id -u)}" -eq 0 ]]; then + "$@" + elif command_exists sudo; then + sudo "$@" + else + warn "Could not find sudo. Run this installer as an administrator or install the prerequisite manually." + return 1 + fi +} + +confirm_install_command() { + local question="$1" + local command_text="$2" + + info "Command: $command_text" + + if [[ "$YES" -eq 1 ]]; then + return 0 + fi + + if [[ "$NON_INTERACTIVE" -eq 1 ]]; then + return 1 + fi + + confirm "$question" 1 +} + +init_colors() { + if [[ -n "${NO_COLOR:-}" || ! -t 1 ]]; then + return + fi + + if command_exists tput && tput colors >/dev/null 2>&1 && [[ "$(tput colors)" -ge 8 ]]; then + COLOR_RESET="$(tput sgr0)" + COLOR_HEADING="$(tput bold)$(tput setaf 6)" + COLOR_INFO="$(tput setaf 2)" + COLOR_WARN="$(tput setaf 3)" + COLOR_ERROR="$(tput bold)$(tput setaf 1)" + COLOR_ACCENT="$(tput bold)$(tput setaf 5)" + fi +} + +print_banner() { + printf '\n%b+---------------------------+%b\n' "$COLOR_ACCENT" "$COLOR_RESET" + printf '%b| Touchgraph Easy Installer |%b\n' "$COLOR_ACCENT" "$COLOR_RESET" + printf '%b+---------------------------+%b\n' "$COLOR_ACCENT" "$COLOR_RESET" +} + +parse_args() { + while [[ $# -gt 0 ]]; do + case "$1" in + --install-dir) + [[ $# -ge 2 ]] || die "--install-dir needs a path" + INSTALL_DIR="$2" + shift 2 + ;; + --api-url) + [[ $# -ge 2 ]] || die "--api-url needs a URL" + API_URL="$2" + shift 2 + ;; + --ui-url) + [[ $# -ge 2 ]] || die "--ui-url needs a URL" + UI_URL="$2" + shift 2 + ;; + --use-existing-compose) + [[ $# -ge 2 ]] || die "--use-existing-compose needs a file path" + USE_EXISTING_COMPOSE="$2" + shift 2 + ;; + --skip-tests) + RUN_TESTS=0 + shift + ;; + --no-launch) + AUTO_LAUNCH=0 + shift + ;; + --non-interactive) + NON_INTERACTIVE=1 + shift + ;; + --yes) + YES=1 + shift + ;; + --fresh) + FRESH_INSTALL=1 + shift + ;; + --remove-all) + REMOVE_ALL=1 + shift + ;; + --dry-run) + DRY_RUN=1 + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + die "Unknown option: $1" + ;; + esac + done + + case "$API_URL" in + */) ;; + *) API_URL="$API_URL/" ;; + esac +} + +prompt_value() { + local label="$1" + local default="$2" + local helper="$3" + local answer="" + + if [[ -n "$helper" ]]; then + printf ' %s\n' "$helper" >&2 + fi + + if [[ "$NON_INTERACTIVE" -eq 1 ]]; then + printf '%s\n' "$default" + return + fi + + if [[ -n "$default" ]]; then + read -r -p "$label [$default]: " answer + printf '%s\n' "${answer:-$default}" + else + read -r -p "$label: " answer + printf '%s\n' "$answer" + fi +} + +looks_like_embedding_ollama_model() { + local model + model="$(printf '%s' "$1" | tr '[:upper:]' '[:lower:]')" + + case "$model" in + *embed*|*embedding*|*nomic*|*mxbai*|*bge*|*e5*|*gte*|*minilm*|*snowflake-arctic*) + return 0 + ;; + *) + return 1 + ;; + esac +} + +ollama_model_candidates() { + local kind="$1" + shift + local model + local selected=() + + for model in "$@"; do + case "$kind" in + embeddings) + if looks_like_embedding_ollama_model "$model"; then + selected+=("$model") + fi + ;; + chat) + if ! looks_like_embedding_ollama_model "$model"; then + selected+=("$model") + fi + ;; + *) + selected+=("$model") + ;; + esac + done + + if [[ "${#selected[@]}" -eq 0 ]]; then + selected=("$@") + fi + + for model in "${selected[@]}"; do + printf '%s\n' "$model" + done +} + +ollama_api_bases_for_host() { + local base="${OLLAMA_BASE_URL_VALUE%/}" + base="${base%/v1}" + + [[ -n "$base" ]] || base="http://localhost:11434" + printf '%s\n' "$base" + + case "$base" in + *host.docker.internal*) + printf '%s\n' "${base//host.docker.internal/localhost}" + ;; + *0.0.0.0*) + printf '%s\n' "${base//0.0.0.0/localhost}" + ;; + esac +} + +list_ollama_models_from_cli_for_host() { + local host="${1:-}" + + command_exists ollama || return 0 + + if [[ -n "$host" ]]; then + OLLAMA_HOST="$host" ollama list 2>/dev/null | awk 'NR > 1 && $1 != "" { print $1 }' || true + else + ollama list 2>/dev/null | awk 'NR > 1 && $1 != "" { print $1 }' || true + fi +} + +list_ollama_models_from_cli() { + local base + + list_ollama_models_from_cli_for_host + + if [[ -n "$OLLAMA_BASE_URL_VALUE" ]]; then + while IFS= read -r base; do + [[ -n "$base" ]] || continue + list_ollama_models_from_cli_for_host "$base" + done < <(ollama_api_bases_for_host) + fi +} + +list_ollama_models_from_api() { + command_exists curl || return 0 + command_exists python3 || return 0 + + local base + local response + + while IFS= read -r base; do + [[ -n "$base" ]] || continue + response="$(curl -fsS --max-time 2 "${base%/}/api/tags" 2>/dev/null || true)" + [[ -n "$response" ]] || continue + + printf '%s' "$response" | python3 -c 'import json, sys +try: + data = json.load(sys.stdin) +except Exception: + raise SystemExit(0) +for model in data.get("models", []): + name = model.get("name") or model.get("model") + if name: + print(name) +' 2>/dev/null || true + done < <(ollama_api_bases_for_host) +} + +list_ollama_models() { + { + list_ollama_models_from_cli + list_ollama_models_from_api + } | awk 'NF && !seen[$0]++' +} + +ollama_model_name_matches() { + local installed="$1" + local target="$2" + + [[ "$installed" == "$target" ]] && return 0 + [[ "$target" != *:* && "$installed" == "$target:latest" ]] && return 0 + [[ "$installed" != *:* && "$target" == "$installed:latest" ]] && return 0 + + return 1 +} + +find_reachable_ollama_cli_host() { + local base + + command_exists ollama || return 1 + + if ollama list >/dev/null 2>&1; then + printf '\n' + return 0 + fi + + while IFS= read -r base; do + [[ -n "$base" ]] || continue + if OLLAMA_HOST="$base" ollama list >/dev/null 2>&1; then + printf '%s\n' "$base" + return 0 + fi + done < <(ollama_api_bases_for_host) + + return 1 +} + +ollama_model_available_via_cli_host() { + local host="$1" + local target="$2" + local model + + while IFS= read -r model; do + ollama_model_name_matches "$model" "$target" && return 0 + done < <(list_ollama_models_from_cli_for_host "$host") + + return 1 +} + +pull_ollama_model() { + local host="$1" + local model="$2" + local log_file + log_file="$(installer_log_file "ollama-pull-${model//\//-}")" + + if [[ -n "$host" ]]; then + run_with_spinner_logged "Downloading Ollama model $model" "$log_file" env OLLAMA_HOST="$host" ollama pull "$model" + else + run_with_spinner_logged "Downloading Ollama model $model" "$log_file" ollama pull "$model" + fi +} + +wait_for_ollama_service() { + local timeout="${1:-30}" + local deadline=$((SECONDS + timeout)) + + while (( SECONDS < deadline )); do + if find_reachable_ollama_cli_host >/dev/null 2>&1; then + return 0 + fi + sleep 2 + done + + return 1 +} + +start_ollama_service_if_possible() { + local command_text + local log_file="$INSTALL_DIR/ollama.log" + + say "Ollama service is not running" + + if [[ "$HW_OS" == "Darwin" ]] && command_exists open && [[ -d /Applications/Ollama.app ]]; then + command_text="$(command_to_text open -a Ollama)" + if confirm_install_command "Start the Ollama app now?" "$command_text"; then + open -a Ollama + wait_for_ollama_service 45 + return + fi + fi + + if command_exists brew; then + command_text="$(command_to_text brew services start ollama)" + if confirm_install_command "Start the Ollama service with Homebrew now?" "$command_text"; then + brew services start ollama + wait_for_ollama_service 45 + return + fi + fi + + command_text="$(command_to_text ollama serve) > $(printf '%q' "$log_file") 2>&1 &" + if confirm_install_command "Start Ollama in the background now?" "$command_text"; then + mkdir -p "$INSTALL_DIR" + nohup ollama serve > "$log_file" 2>&1 & + wait_for_ollama_service 45 + return + fi + + return 1 +} + +offer_single_ollama_model_download() { + local kind="$1" + local default_model="$2" + local selected_model="$3" + local processor_label="$4" + local cli_host="$5" + local question + + say "Preparing Ollama $kind model" + info "TrustGraph's Ollama $processor_label default is $default_model." + + if ollama_model_available_via_cli_host "$cli_host" "$selected_model"; then + info "Ollama $kind model already available: $selected_model" + return 0 + fi + + if [[ "$selected_model" == "$default_model" ]]; then + question="Download TrustGraph's preferred Ollama $kind model ($selected_model) now?" + else + question="Download the selected Ollama $kind model ($selected_model) now?" + fi + + if confirm "$question" 1; then + info "Downloading $selected_model with Ollama. This may take a while." + if ! pull_ollama_model "$cli_host" "$selected_model"; then + die "Ollama could not download $selected_model. Try running: ollama pull $selected_model" + fi + else + warn "Skipping Ollama $kind model download. TrustGraph's Ollama processor will try to pull $selected_model on first use." + fi +} + +offer_ollama_model_downloads() { + local cli_host + + [[ "$LLM_MODE" == "ollama" ]] || return 0 + + if ! command_exists ollama; then + warn "Ollama was selected, but the ollama CLI was not found. Install Ollama and run: ollama pull $OLLAMA_MODEL && ollama pull $OLLAMA_EMBEDDINGS_MODEL" + return 0 + fi + + if ! cli_host="$(find_reachable_ollama_cli_host)"; then + start_ollama_service_if_possible || true + if ! cli_host="$(find_reachable_ollama_cli_host)"; then + warn "Ollama CLI is installed, but the Ollama service is not reachable. Start Ollama and run: ollama pull $OLLAMA_MODEL && ollama pull $OLLAMA_EMBEDDINGS_MODEL" + return 0 + fi + fi + + if [[ -n "$cli_host" ]]; then + info "Ollama service: $cli_host" + else + info "Ollama service: local Ollama default" + fi + + offer_single_ollama_model_download \ + "chat" \ + "$DEFAULT_OLLAMA_MODEL" \ + "$OLLAMA_MODEL" \ + "text-completion" \ + "$cli_host" + + offer_single_ollama_model_download \ + "embeddings" \ + "$DEFAULT_OLLAMA_EMBEDDINGS_MODEL" \ + "$OLLAMA_EMBEDDINGS_MODEL" \ + "embeddings" \ + "$cli_host" +} + +prompt_ollama_model_choice() { + local label="$1" + local default="$2" + local kind="$3" + local helper="$4" + shift 4 + local all_models=("$@") + local candidates=() + local options=() + local model + local answer + local idx + local found_default=0 + local detected_default="" + + if [[ "$NON_INTERACTIVE" -eq 1 ]]; then + printf '%s\n' "$default" + return + fi + + if [[ -n "$helper" ]]; then + printf ' %s\n' "$helper" >&2 + fi + + if [[ "${#all_models[@]}" -eq 0 ]]; then + prompt_value \ + "$label" \ + "$default" \ + "No local Ollama models were detected. Pull the recommended default with: ollama pull $default" + return + fi + + while IFS= read -r model; do + [[ -n "$model" ]] && candidates+=("$model") + done < <(ollama_model_candidates "$kind" "${all_models[@]}") + + if [[ "${#candidates[@]}" -eq 0 ]]; then + candidates=("${all_models[@]}") + fi + + for model in "${candidates[@]}"; do + if ollama_model_name_matches "$model" "$default"; then + found_default=1 + detected_default="$model" + break + fi + done + + options+=("$default") + for model in "${candidates[@]}"; do + ollama_model_name_matches "$model" "$default" && continue + options+=("$model") + done + + say "Local Ollama ${kind} model choices" >&2 + if [[ "$found_default" -eq 1 ]]; then + if [[ "$detected_default" != "$default" ]]; then + info "1) $default (recommended, detected as $detected_default)" >&2 + else + info "1) $default (recommended, detected)" >&2 + fi + else + info "1) $default (recommended default, not detected locally)" >&2 + fi + + idx=2 + for model in "${options[@]:1}"; do + info "$idx) $model" >&2 + idx=$((idx + 1)) + done + if [[ "$found_default" -eq 0 ]]; then + info "If you choose a missing model, the installer will offer to download it before startup." >&2 + fi + info "Or type another model name, for example one you plan to pull before startup." >&2 + + read -r -p "$label [1: $default]: " answer + answer="${answer:-1}" + + if [[ "$answer" =~ ^[0-9]+$ ]]; then + if (( answer >= 1 && answer <= ${#options[@]} )); then + printf '%s\n' "${options[$((answer - 1))]}" + return + fi + warn "Selection '$answer' is not in the list; using $default." + printf '%s\n' "$default" + return + fi + + printf '%s\n' "$answer" +} + +prompt_secret() { + local label="$1" + local default="$2" + local helper="$3" + local answer="" + local masked="${4:-}" + + if [[ -z "$masked" && -n "$default" ]]; then + masked="set in environment" + elif [[ -z "$masked" ]]; then + masked="blank" + fi + + if [[ -n "$helper" ]]; then + printf ' %s\n' "$helper" >&2 + fi + + if [[ "$NON_INTERACTIVE" -eq 1 ]]; then + printf '%s\n' "$default" + return + fi + + read -r -s -p "$label [$masked]: " answer + printf '\n' >&2 + printf '%s\n' "${answer:-$default}" +} + +confirm() { + local question="$1" + local default_yes="$2" + local answer="" + local prompt="[y/N]" + + if [[ "$YES" -eq 1 ]]; then + return 0 + fi + + if [[ "$NON_INTERACTIVE" -eq 1 ]]; then + [[ "$default_yes" -eq 1 ]] + return + fi + + if [[ "$default_yes" -eq 1 ]]; then + prompt="[Y/n]" + fi + + read -r -p "$question $prompt " answer + answer="${answer:-}" + if [[ -z "$answer" ]]; then + [[ "$default_yes" -eq 1 ]] + return + fi + [[ "$answer" =~ ^[Yy] ]] +} + +path_within_install_dir() { + local path="$1" + case "$path" in + "$INSTALL_DIR"/*) return 0 ;; + *) return 1 ;; + esac +} + +safe_existing_path_within_install_dir() { + local path="$1" + local resolved_install + local resolved_parent + local resolved_path + local parent + local base + + [[ -d "$INSTALL_DIR" ]] || return 1 + [[ -e "$path" || -L "$path" ]] || return 1 + + resolved_install="$(cd "$INSTALL_DIR" && pwd -P)" + parent="$(dirname "$path")" + base="$(basename "$path")" + [[ -d "$parent" ]] || return 1 + resolved_parent="$(cd "$parent" && pwd -P)" + resolved_path="$resolved_parent/$base" + + case "$resolved_path" in + "$resolved_install"/*) return 0 ;; + *) return 1 ;; + esac +} + +installer_artifact_paths() { + local candidates=( + "$INSTALL_DIR/deploy.zip" + "$INSTALL_DIR/deploy" + "$INSTALL_DIR/INSTALLATION.md" + "$INSTALL_DIR/trustgraph-installer.env" + "$INSTALL_DIR/iam-bootstrap.log" + "$INSTALL_DIR/ollama.log" + "$INSTALL_DIR/logs" + "$INSTALL_DIR/pip_cache" + ) + local path + + for path in "$VENV_DIR" "$NLTK_DATA_DIR" "$TIKTOKEN_CACHE_DIR_VALUE"; do + if path_within_install_dir "$path"; then + candidates+=("$path") + fi + done + + for path in "${candidates[@]}"; do + if [[ -e "$path" || -L "$path" ]]; then + printf '%s\n' "$path" + fi + done +} + +installer_artifacts_present() { + local path + while IFS= read -r path; do + [[ -n "$path" ]] && return 0 + done < <(installer_artifact_paths) + return 1 +} + +assert_safe_cleanup_target() { + local resolved_install + local resolved_script + local resolved_home="" + + [[ -n "$INSTALL_DIR" ]] || die "Install directory is empty; refusing cleanup." + case "$INSTALL_DIR" in + /|.|..) die "Install directory '$INSTALL_DIR' is too broad; refusing cleanup." ;; + esac + + if [[ -d "$INSTALL_DIR" ]]; then + resolved_install="$(cd "$INSTALL_DIR" && pwd -P)" + else + return 0 + fi + resolved_script="$(cd "$SCRIPT_DIR" && pwd -P)" + if [[ -n "${HOME:-}" && -d "$HOME" ]]; then + resolved_home="$(cd "$HOME" && pwd -P)" + fi + + [[ "$resolved_install" != "$resolved_script" ]] || die "Install directory resolves to the source checkout; refusing cleanup." + [[ -z "$resolved_home" || "$resolved_install" != "$resolved_home" ]] || die "Install directory resolves to your home directory; refusing cleanup." +} + +find_existing_compose_file() { + [[ -d "$INSTALL_DIR" ]] || return 0 + + find "$INSTALL_DIR/deploy" "$INSTALL_DIR" \ + \( -name 'docker-compose.yaml' -o -name 'docker-compose.yml' -o -name 'compose.yaml' -o -name 'compose.yml' \) \ + -type f 2>/dev/null | head -n 1 +} + +stop_previous_stack_if_possible() { + local compose_file + compose_file="$(find_existing_compose_file || true)" + [[ -n "$compose_file" ]] || return 0 + + if ! confirm "Stop any containers from the previous deployment first? Docker volumes will be kept." 1; then + info "Leaving any previous containers untouched." + return + fi + + if ! detect_compose_command; then + warn "Could not find Docker Compose or podman-compose, so previous containers were not stopped." + return + fi + + if "${COMPOSE_CMD[@]}" -f "$compose_file" down --remove-orphans; then + info "Stopped previous compose deployment using $compose_file" + else + warn "Could not stop previous compose deployment. Continuing with file cleanup only." + fi +} + +remove_installer_artifacts_only() { + local path + + say "Removing installer-managed files" + while IFS= read -r path; do + [[ -n "$path" ]] || continue + if ! safe_existing_path_within_install_dir "$path"; then + warn "Skipping cleanup path outside install directory: $path" + continue + fi + info "Removing $path" + rm -rf -- "$path" + done < <(installer_artifact_paths) + + rmdir "$INSTALL_DIR" 2>/dev/null || true +} + +cleanup_installer_artifacts() { + assert_safe_cleanup_target + stop_previous_stack_if_possible + remove_installer_artifacts_only +} + +find_uninstall_compose_file() { + if [[ -n "$USE_EXISTING_COMPOSE" ]]; then + [[ -f "$USE_EXISTING_COMPOSE" ]] || die "Compose file does not exist: $USE_EXISTING_COMPOSE" + printf '%s\n' "$USE_EXISTING_COMPOSE" + return + fi + + find_existing_compose_file || true +} + +print_uninstall_plan() { + local compose_file="$1" + local path + local found=0 + + say "Uninstall plan" + info "Install directory: $INSTALL_DIR" + if [[ -n "$compose_file" ]]; then + info "Compose file: $compose_file" + info "Will stop containers and remove compose-managed volumes for this deployment." + else + info "No compose file was found, so no containers or volumes can be removed automatically." + fi + + info "Installer-managed files to remove:" + while IFS= read -r path; do + [[ -n "$path" ]] || continue + found=1 + info "- $path" + done < <(installer_artifact_paths) + if [[ "$found" -eq 0 ]]; then + info "- none found" + fi + + info "Will not remove Docker/Podman, container images, external volumes, Ollama, Ollama models, or this source checkout." +} + +remove_compose_stack_for_uninstall() { + local compose_file="$1" + + [[ -n "$compose_file" ]] || return 0 + + say "Stopping TrustGraph containers and volumes" + if ! detect_compose_command; then + warn "Could not find Docker Compose or podman-compose, so containers and volumes were not removed." + return + fi + + if "${COMPOSE_CMD[@]}" -f "$compose_file" down --remove-orphans --volumes; then + info "Removed compose containers, networks, and compose-managed volumes." + return + fi + + warn "Compose did not accept the volume removal command; trying to stop containers without removing volumes." + if "${COMPOSE_CMD[@]}" -f "$compose_file" down --remove-orphans; then + warn "Containers were stopped, but compose volumes may remain." + else + warn "Could not stop the compose deployment. Installer-managed files will still be removed." + fi +} + +remove_all_installation() { + local compose_file + + assert_safe_cleanup_target + compose_file="$(find_uninstall_compose_file || true)" + + print_uninstall_plan "$compose_file" + + if [[ "$DRY_RUN" -eq 1 ]]; then + say "Dry run complete" + return 0 + fi + + if [[ -z "$compose_file" ]] && ! installer_artifacts_present; then + say "Nothing to remove" + info "No installer-managed files or compose deployment were found." + return 0 + fi + + if ! confirm "Remove the TrustGraph deployment listed above?" 0; then + die "Uninstall cancelled." + fi + + remove_compose_stack_for_uninstall "$compose_file" + remove_installer_artifacts_only + + say "TrustGraph installer-managed deployment removed" + info "Ollama models were left in place because they may be shared with other tools." +} + +handle_existing_install() { + local path + local found=0 + + [[ -z "$USE_EXISTING_COMPOSE" ]] || return 0 + installer_artifacts_present || return 0 + + say "Existing installer output detected" + info "Install directory: $INSTALL_DIR" + while IFS= read -r path; do + [[ -n "$path" ]] || continue + found=1 + info "Found: $path" + done < <(installer_artifact_paths) + + [[ "$found" -eq 1 ]] || return 0 + + if [[ "$DRY_RUN" -eq 1 ]]; then + if [[ "$FRESH_INSTALL" -eq 1 ]]; then + info "Dry run: --fresh would remove the files listed above." + else + info "Dry run: existing files would be kept unless you choose --fresh." + fi + return 0 + fi + + if [[ "$FRESH_INSTALL" -eq 1 ]]; then + cleanup_installer_artifacts + return 0 + fi + + if confirm "Treat this as a fresh install and delete only the installer-managed files listed above?" 0; then + cleanup_installer_artifacts + else + info "Continuing with the existing installer output." + fi +} + +load_saved_answers() { + local env_file="$INSTALL_DIR/trustgraph-installer.env" + + [[ -f "$env_file" ]] || return 0 + + local current_api_url="$API_URL" + local current_ui_url="$UI_URL" + local current_auth_token="$AUTH_TOKEN" + local current_venv_dir="$VENV_DIR" + local current_nltk_data_dir="$NLTK_DATA_DIR" + local current_tiktoken_cache_dir="$TIKTOKEN_CACHE_DIR_VALUE" + local current_llm_mode="$LLM_MODE" + local current_openai_base_url="$OPENAI_BASE_URL_VALUE" + local current_openai_token="$OPENAI_TOKEN_VALUE" + local current_ollama_base_url="$OLLAMA_BASE_URL_VALUE" + local current_ollama_model="$OLLAMA_MODEL" + local current_ollama_embeddings_model="$OLLAMA_EMBEDDINGS_MODEL" + + # The file is generated by this installer with shell-escaped exports and 0600 permissions. + # shellcheck disable=SC1090 + source "$env_file" + + if [[ "$current_api_url" == "$DEFAULT_API_URL" && -n "${TRUSTGRAPH_URL:-}" ]]; then + API_URL="$TRUSTGRAPH_URL" + else + API_URL="$current_api_url" + fi + + if [[ "$current_ui_url" == "$DEFAULT_UI_URL" && -n "${TRUSTGRAPH_UI_URL:-}" ]]; then + UI_URL="$TRUSTGRAPH_UI_URL" + else + UI_URL="$current_ui_url" + fi + + if [[ -z "$current_auth_token" && -n "${TRUSTGRAPH_TOKEN:-}" ]]; then + AUTH_TOKEN="$TRUSTGRAPH_TOKEN" + elif [[ -z "$current_auth_token" && -n "${IAM_BOOTSTRAP_TOKEN:-}" ]]; then + AUTH_TOKEN="$IAM_BOOTSTRAP_TOKEN" + else + AUTH_TOKEN="$current_auth_token" + fi + + if [[ -z "${TG_VENV_DIR:-}" && -n "${current_venv_dir:-}" ]]; then + VENV_DIR="$current_venv_dir" + elif [[ -n "${TG_VENV_DIR:-}" ]]; then + VENV_DIR="$TG_VENV_DIR" + fi + + if [[ -z "${TG_NLTK_DATA_DIR:-}" && -n "$current_nltk_data_dir" ]]; then + NLTK_DATA_DIR="$current_nltk_data_dir" + elif [[ -n "${TG_NLTK_DATA_DIR:-}" ]]; then + NLTK_DATA_DIR="$TG_NLTK_DATA_DIR" + fi + + if [[ -z "${TIKTOKEN_CACHE_DIR:-}" && -n "$current_tiktoken_cache_dir" ]]; then + TIKTOKEN_CACHE_DIR_VALUE="$current_tiktoken_cache_dir" + elif [[ -n "${TIKTOKEN_CACHE_DIR:-}" ]]; then + TIKTOKEN_CACHE_DIR_VALUE="$TIKTOKEN_CACHE_DIR" + fi + + if [[ -z "$current_llm_mode" && -n "${TRUSTGRAPH_LLM_MODE:-}" ]]; then + LLM_MODE="$TRUSTGRAPH_LLM_MODE" + else + LLM_MODE="$current_llm_mode" + fi + + if [[ "$current_openai_base_url" == "https://api.openai.com/v1" && -n "${OPENAI_BASE_URL:-}" ]]; then + OPENAI_BASE_URL_VALUE="$OPENAI_BASE_URL" + else + OPENAI_BASE_URL_VALUE="$current_openai_base_url" + fi + + if [[ -z "$current_openai_token" && -n "${OPENAI_TOKEN:-}" ]]; then + OPENAI_TOKEN_VALUE="$OPENAI_TOKEN" + else + OPENAI_TOKEN_VALUE="$current_openai_token" + fi + + if [[ -z "$current_ollama_base_url" ]]; then + OLLAMA_BASE_URL_VALUE="${OLLAMA_HOST:-${OLLAMA_BASE_URL:-}}" + else + OLLAMA_BASE_URL_VALUE="$current_ollama_base_url" + fi + + if [[ "$current_ollama_model" == "$DEFAULT_OLLAMA_MODEL" && -n "${OLLAMA_MODEL:-}" ]]; then + OLLAMA_MODEL="$OLLAMA_MODEL" + else + OLLAMA_MODEL="$current_ollama_model" + fi + + if [[ "$current_ollama_embeddings_model" == "$DEFAULT_OLLAMA_EMBEDDINGS_MODEL" && -n "${OLLAMA_EMBEDDINGS_MODEL:-}" ]]; then + OLLAMA_EMBEDDINGS_MODEL="$OLLAMA_EMBEDDINGS_MODEL" + else + OLLAMA_EMBEDDINGS_MODEL="$current_ollama_embeddings_model" + fi + + case "$API_URL" in + */) ;; + *) API_URL="$API_URL/" ;; + esac + + info "Loaded saved answers from $env_file" +} + +bytes_to_gb() { + local bytes="$1" + awk "BEGIN { printf \"%.0f\", $bytes / 1024 / 1024 / 1024 }" +} + +detect_hardware() { + HW_OS="$(uname -s 2>/dev/null || printf 'unknown')" + HW_ARCH="$(uname -m 2>/dev/null || printf 'unknown')" + + if [[ "$HW_OS" == "Darwin" ]]; then + HW_CPU_CORES="$(sysctl -n hw.logicalcpu 2>/dev/null || getconf _NPROCESSORS_ONLN 2>/dev/null || python3 -c 'import os; print(os.cpu_count() or "unknown")' 2>/dev/null || printf 'unknown')" + local mem_bytes + mem_bytes="$(sysctl -n hw.memsize 2>/dev/null || true)" + if [[ -z "$mem_bytes" ]] && command_exists python3; then + mem_bytes="$(python3 -c 'import os; print(os.sysconf("SC_PHYS_PAGES") * os.sysconf("SC_PAGE_SIZE"))' 2>/dev/null || true)" + fi + if [[ -n "$mem_bytes" ]]; then + HW_MEMORY_GB="$(bytes_to_gb "$mem_bytes")" + fi + if [[ "$HW_ARCH" == "arm64" ]]; then + HW_GPU="Apple Silicon unified GPU" + fi + HW_CONTAINER_HINT="Docker Desktop or Podman Desktop works well on macOS." + elif [[ "$HW_OS" == "Linux" ]]; then + HW_CPU_CORES="$(nproc 2>/dev/null || getconf _NPROCESSORS_ONLN 2>/dev/null || python3 -c 'import os; print(os.cpu_count() or "unknown")' 2>/dev/null || printf 'unknown')" + if [[ -r /proc/meminfo ]]; then + local mem_kb + mem_kb="$(awk '/MemTotal/ { print $2 }' /proc/meminfo)" + if [[ -n "$mem_kb" ]]; then + HW_MEMORY_GB="$(awk "BEGIN { printf \"%.0f\", $mem_kb / 1024 / 1024 }")" + fi + fi + if command_exists nvidia-smi; then + HW_GPU="$(nvidia-smi --query-gpu=name,memory.total --format=csv,noheader 2>/dev/null | head -n 1 || true)" + [[ -n "$HW_GPU" ]] || HW_GPU="NVIDIA GPU detected" + elif command_exists lspci; then + HW_GPU="$(lspci 2>/dev/null | awk 'BEGIN{IGNORECASE=1} /VGA|3D|Display/ {print; exit}')" + [[ -n "$HW_GPU" ]] || HW_GPU="none detected" + fi + HW_CONTAINER_HINT="Docker Engine, Docker Desktop, or Podman can run the compose stack." + else + HW_CONTAINER_HINT="Use Docker or Podman with compose support." + fi +} + +is_number() { + [[ "$1" =~ ^[0-9]+$ ]] +} + +choose_recommendations() { + local mem=0 + local cores=0 + + if is_number "$HW_MEMORY_GB"; then + mem="$HW_MEMORY_GB" + fi + if is_number "$HW_CPU_CORES"; then + cores="$HW_CPU_CORES" + fi + + if [[ -z "$OLLAMA_BASE_URL_VALUE" ]]; then + if [[ "$HW_OS" == "Darwin" ]]; then + OLLAMA_BASE_URL_VALUE="http://host.docker.internal:11434" + else + OLLAMA_BASE_URL_VALUE="http://localhost:11434" + fi + fi + + if [[ -n "$LLM_MODE" ]]; then + RECOMMENDED_LLM_MODE="$LLM_MODE" + RECOMMENDATION_REASON="Using the LLM provider saved from the previous installer run." + return + fi + + if (( mem >= 16 )) && { [[ "$HW_GPU" != "none detected" ]] || (( cores >= 8 )); }; then + RECOMMENDED_LLM_MODE="ollama" + RECOMMENDATION_REASON="This machine looks comfortable for a small local Ollama model." + elif (( mem >= 8 )); then + RECOMMENDED_LLM_MODE="openai" + RECOMMENDATION_REASON="Local Ollama may work with a small model, but a hosted OpenAI-compatible endpoint is smoother on this hardware." + else + RECOMMENDED_LLM_MODE="openai" + RECOMMENDATION_REASON="Memory looks tight for local LLMs, so a hosted OpenAI-compatible endpoint is the friendlier default." + fi + + if [[ -n "${OPENAI_TOKEN:-}" && "${OPENAI_TOKEN:-}" != "ollama" ]]; then + RECOMMENDED_LLM_MODE="openai" + RECOMMENDATION_REASON="OPENAI_TOKEN is already set, so the hosted/OpenAI-compatible path is ready to use." + fi +} + +print_hardware_summary() { + say "Detected hardware" + info "OS: $HW_OS" + info "Architecture: $HW_ARCH" + info "CPU cores: $HW_CPU_CORES" + info "Memory: $HW_MEMORY_GB GB" + info "GPU: $HW_GPU" + info "$HW_CONTAINER_HINT" + + say "Recommended install shape" + info "LLM path: $RECOMMENDED_LLM_MODE" + info "$RECOMMENDATION_REASON" + info "Default Workbench UI: $UI_URL" + info "Default API gateway: $API_URL" +} + +generate_token() { + if command_exists openssl; then + printf 'tg_%s\n' "$(openssl rand -base64 24 | tr '+/' '-_' | tr -d '=')" + elif command_exists python3; then + python3 -c 'import secrets; print("tg_" + secrets.token_urlsafe(24))' + else + die "Need openssl or python3 to generate a secure TrustGraph API key." + fi +} + +ensure_compliant_api_key() { + local token="$1" + + if [[ "$token" == tg_* ]]; then + printf '%s\n' "$token" + return + fi + + warn "TrustGraph API keys must start with 'tg_'; the provided value will not authenticate at the gateway." + + if [[ "$NON_INTERACTIVE" -eq 1 ]]; then + warn "Non-interactive mode: replacing the non-compliant key with a generated TrustGraph API key." + generate_token + return + fi + + if confirm "Generate a compliant TrustGraph API key now?" 1; then + generate_token + return + fi + + die "TrustGraph API key must start with 'tg_'." +} + +collect_answers() { + local generated_token + local token_default + local token_mask + generated_token="$(generate_token)" + if [[ -n "$AUTH_TOKEN" ]]; then + token_default="$AUTH_TOKEN" + token_mask="set in environment" + else + token_default="$generated_token" + token_mask="generated tg_ key" + fi + AUTH_TOKEN="$(prompt_secret \ + "TrustGraph admin/bootstrap API key" \ + "$token_default" \ + "Recommendation: press Enter to use a generated TrustGraph API key beginning with tg_; it will be stored in the installer env file with restricted permissions." \ + "$token_mask")" + AUTH_TOKEN="$(ensure_compliant_api_key "$AUTH_TOKEN")" + + LLM_MODE="$(prompt_value \ + "LLM provider: ollama, openai, or none" \ + "${LLM_MODE:-$RECOMMENDED_LLM_MODE}" \ + "Recommendation: $RECOMMENDED_LLM_MODE. $RECOMMENDATION_REASON")" + LLM_MODE="$(printf '%s' "$LLM_MODE" | tr '[:upper:]' '[:lower:]')" + + case "$LLM_MODE" in + ollama) + OLLAMA_BASE_URL_VALUE="$(prompt_value \ + "Ollama base URL" \ + "$OLLAMA_BASE_URL_VALUE" \ + "If Ollama runs on your laptop and TrustGraph runs in Docker, host.docker.internal is usually the right host on macOS/Windows.")" + local ollama_models=() + local ollama_model + if [[ "$NON_INTERACTIVE" -ne 1 ]]; then + while IFS= read -r ollama_model; do + [[ -n "$ollama_model" ]] && ollama_models+=("$ollama_model") + done < <(list_ollama_models) + fi + if [[ "${#ollama_models[@]}" -gt 0 ]]; then + OLLAMA_MODEL="$(prompt_ollama_model_choice \ + "Ollama chat model" \ + "$OLLAMA_MODEL" \ + "chat" \ + "Recommendation from the local Ollama processor defaults: $DEFAULT_OLLAMA_MODEL for a quick first run." \ + "${ollama_models[@]}")" + OLLAMA_EMBEDDINGS_MODEL="$(prompt_ollama_model_choice \ + "Ollama embeddings model" \ + "$OLLAMA_EMBEDDINGS_MODEL" \ + "embeddings" \ + "Recommendation from the local Ollama embeddings defaults: $DEFAULT_OLLAMA_EMBEDDINGS_MODEL." \ + "${ollama_models[@]}")" + else + OLLAMA_MODEL="$(prompt_ollama_model_choice \ + "Ollama chat model" \ + "$OLLAMA_MODEL" \ + "chat" \ + "Recommendation from the local Ollama processor defaults: $DEFAULT_OLLAMA_MODEL for a quick first run.")" + OLLAMA_EMBEDDINGS_MODEL="$(prompt_ollama_model_choice \ + "Ollama embeddings model" \ + "$OLLAMA_EMBEDDINGS_MODEL" \ + "embeddings" \ + "Recommendation from the local Ollama embeddings defaults: $DEFAULT_OLLAMA_EMBEDDINGS_MODEL.")" + fi + OPENAI_BASE_URL_VALUE="${OLLAMA_BASE_URL_VALUE%/}/v1" + OPENAI_TOKEN_VALUE="${OPENAI_TOKEN_VALUE:-ollama}" + ;; + openai) + OPENAI_BASE_URL_VALUE="$(prompt_value \ + "OpenAI-compatible base URL" \ + "$OPENAI_BASE_URL_VALUE" \ + "Use https://api.openai.com/v1 for OpenAI, or your provider's OpenAI-compatible /v1 endpoint.")" + OPENAI_TOKEN_VALUE="$(prompt_secret \ + "OpenAI-compatible API key" \ + "$OPENAI_TOKEN_VALUE" \ + "Press Enter to reuse OPENAI_TOKEN if set; leave blank only if your endpoint does not require a key.")" + ;; + none|skip) + LLM_MODE="none" + warn "Continuing without an LLM key. The platform can start, but agent/RAG calls will need an LLM configured later." + ;; + *) + warn "Unknown LLM provider '$LLM_MODE'; using '$RECOMMENDED_LLM_MODE'." + LLM_MODE="$RECOMMENDED_LLM_MODE" + ;; + esac + + INSTALL_DIR="$(prompt_value \ + "Installer output directory" \ + "$INSTALL_DIR" \ + "This keeps deploy.zip, compose files, logs, and saved answers together.")" + + if [[ -z "${TG_VENV_DIR:-}" ]]; then + VENV_DIR="$INSTALL_DIR/.venv" + fi + if [[ -z "${TG_NLTK_DATA_DIR:-}" ]]; then + NLTK_DATA_DIR="$INSTALL_DIR/nltk_data" + fi + if [[ -z "${TIKTOKEN_CACHE_DIR:-}" ]]; then + TIKTOKEN_CACHE_DIR_VALUE="$INSTALL_DIR/tiktoken_cache" + fi +} + +print_plan_summary() { + say "Install plan" + info "Install directory: $INSTALL_DIR" + info "Python venv: $VENV_DIR" + info "NLTK data: $NLTK_DATA_DIR" + info "Tokenizer cache: $TIKTOKEN_CACHE_DIR_VALUE" + info "Run all tests: $([[ "$RUN_TESTS" -eq 1 ]] && printf yes || printf no)" + if [[ -n "$USE_EXISTING_COMPOSE" ]]; then + info "Compose file: $USE_EXISTING_COMPOSE" + else + info "Config generator: npx @trustgraph/config" + fi + info "LLM provider: $LLM_MODE" + if [[ "$LLM_MODE" == "ollama" ]]; then + info "Ollama URL: $OLLAMA_BASE_URL_VALUE" + info "Ollama model: $OLLAMA_MODEL" + info "Ollama embeddings model: $OLLAMA_EMBEDDINGS_MODEL" + elif [[ "$LLM_MODE" == "openai" ]]; then + info "OpenAI-compatible URL: $OPENAI_BASE_URL_VALUE" + fi + info "Health check timeout: ${HEALTH_TIMEOUT}s" + info "Autolaunch UI: $([[ "$AUTO_LAUNCH" -eq 1 ]] && printf yes || printf no)" +} + +detect_compose_command() { + if command_exists docker && docker compose version >/dev/null 2>&1; then + COMPOSE_CMD=(docker compose) + elif command_exists docker-compose; then + COMPOSE_CMD=(docker-compose) + elif command_exists podman-compose; then + COMPOSE_CMD=(podman-compose) + else + return 1 + fi +} + +wait_for_docker_ready() { + local timeout="${1:-60}" + local deadline=$((SECONDS + timeout)) + + while (( SECONDS < deadline )); do + if docker info >/dev/null 2>&1; then + return 0 + fi + sleep 2 + done + + return 1 +} + +wait_for_podman_ready() { + local timeout="${1:-60}" + local deadline=$((SECONDS + timeout)) + + while (( SECONDS < deadline )); do + if podman info >/dev/null 2>&1; then + return 0 + fi + sleep 2 + done + + return 1 +} + +start_docker_runtime_if_possible() { + local command_text + + say "Docker is installed but not running" + + if [[ "$HW_OS" == "Darwin" ]] && command_exists open && [[ -d /Applications/Docker.app ]]; then + command_text="$(command_to_text open -a Docker)" + if confirm_install_command "Start Docker Desktop now?" "$command_text"; then + open -a Docker + wait_for_docker_ready 90 + return + fi + fi + + if command_exists systemctl; then + command_text="$(root_command_to_text systemctl start docker)" + if confirm_install_command "Start the Docker service now?" "$command_text"; then + run_root_command systemctl start docker + wait_for_docker_ready 60 + return + fi + fi + + return 1 +} + +start_podman_runtime_if_possible() { + local command_text + + say "Podman is installed but not running" + + if [[ "$HW_OS" == "Darwin" ]] && command_exists podman; then + command_text="$(command_to_text podman machine init) && $(command_to_text podman machine start)" + if confirm_install_command "Start a local Podman machine now?" "$command_text"; then + podman machine init >/dev/null 2>&1 || true + podman machine start + wait_for_podman_ready 90 + return + fi + fi + + if command_exists systemctl; then + command_text="$(command_to_text systemctl --user start podman.socket)" + if confirm_install_command "Start the user Podman socket now?" "$command_text"; then + systemctl --user start podman.socket + wait_for_podman_ready 30 + return + fi + fi + + return 1 +} + +check_container_runtime_ready() { + case "${COMPOSE_CMD[0]}" in + docker|docker-compose) + if ! docker info >/dev/null 2>&1; then + start_docker_runtime_if_possible || true + docker info >/dev/null 2>&1 || die "Docker is installed, but the Docker daemon is not reachable. Start Docker Desktop or Docker Engine and run this installer again." + fi + ;; + podman-compose) + if ! podman info >/dev/null 2>&1; then + start_podman_runtime_if_possible || true + podman info >/dev/null 2>&1 || die "Podman is installed, but the Podman service is not reachable. Start Podman Desktop or the Podman machine and run this installer again." + fi + ;; + esac +} + +install_with_brew() { + local label="$1" + shift + local command_text + local log_file + command_text="$(command_to_text brew install "$@")" + log_file="$(installer_log_file "brew-install-${label// /-}")" + + if confirm_install_command "Install $label with Homebrew now?" "$command_text"; then + run_with_spinner_logged "Installing $label with Homebrew" "$log_file" brew install "$@" + else + return 1 + fi +} + +install_with_apt() { + local label="$1" + shift + local command_text + command_text="$(root_command_to_text apt-get update) && $(root_command_to_text apt-get install -y "$@")" + + if confirm_install_command "Install $label with apt now?" "$command_text"; then + run_root_command apt-get update + run_root_command apt-get install -y "$@" + else + return 1 + fi +} + +install_with_dnf() { + local label="$1" + shift + local command_text + command_text="$(root_command_to_text dnf install -y "$@")" + + if confirm_install_command "Install $label with dnf now?" "$command_text"; then + run_root_command dnf install -y "$@" + else + return 1 + fi +} + +install_with_yum() { + local label="$1" + shift + local command_text + command_text="$(root_command_to_text yum install -y "$@")" + + if confirm_install_command "Install $label with yum now?" "$command_text"; then + run_root_command yum install -y "$@" + else + return 1 + fi +} + +install_with_pacman() { + local label="$1" + shift + local command_text + command_text="$(root_command_to_text pacman -Sy --noconfirm "$@")" + + if confirm_install_command "Install $label with pacman now?" "$command_text"; then + run_root_command pacman -Sy --noconfirm "$@" + else + return 1 + fi +} + +install_with_zypper() { + local label="$1" + shift + local command_text + command_text="$(root_command_to_text zypper install -y "$@")" + + if confirm_install_command "Install $label with zypper now?" "$command_text"; then + run_root_command zypper install -y "$@" + else + return 1 + fi +} + +install_python3_prerequisite() { + if command_exists brew; then + install_with_brew "Python 3" python + elif command_exists apt-get; then + install_with_apt "Python 3" python3 python3-venv python3-pip + elif command_exists dnf; then + install_with_dnf "Python 3" python3 python3-pip + elif command_exists yum; then + install_with_yum "Python 3" python3 python3-pip + elif command_exists pacman; then + install_with_pacman "Python 3" python + elif command_exists zypper; then + install_with_zypper "Python 3" python3 python3-pip python3-venv + else + warn "No supported package manager was found. Install Python 3 manually, then run this installer again." + return 1 + fi +} + +install_python_venv_prerequisite() { + if command_exists apt-get; then + install_with_apt "Python venv support" python3-venv + elif command_exists zypper; then + install_with_zypper "Python venv support" python3-venv + elif command_exists brew || command_exists dnf || command_exists yum || command_exists pacman; then + info "Python venv support is usually bundled with the Python package on this platform." + return 1 + else + warn "Install Python's venv module manually, then run this installer again." + return 1 + fi +} + +install_basic_tool_prerequisite() { + local tool="$1" + + if command_exists brew; then + install_with_brew "$tool" "$tool" + elif command_exists apt-get; then + install_with_apt "$tool" "$tool" + elif command_exists dnf; then + install_with_dnf "$tool" "$tool" + elif command_exists yum; then + install_with_yum "$tool" "$tool" + elif command_exists pacman; then + install_with_pacman "$tool" "$tool" + elif command_exists zypper; then + install_with_zypper "$tool" "$tool" + else + warn "No supported package manager was found. Install $tool manually, then run this installer again." + return 1 + fi +} + +install_node_prerequisite() { + if command_exists brew; then + install_with_brew "Node.js and npx" node + elif command_exists apt-get; then + install_with_apt "Node.js and npx" nodejs npm + elif command_exists dnf; then + install_with_dnf "Node.js and npx" nodejs npm + elif command_exists yum; then + install_with_yum "Node.js and npx" nodejs npm + elif command_exists pacman; then + install_with_pacman "Node.js and npx" nodejs npm + elif command_exists zypper; then + install_with_zypper "Node.js and npx" nodejs npm + else + warn "No supported package manager was found. Install Node.js/npm manually, then run this installer again." + return 1 + fi +} + +start_podman_machine_if_needed() { + [[ "$HW_OS" == "Darwin" ]] || return 0 + command_exists podman || return 0 + + if podman info >/dev/null 2>&1; then + return 0 + fi + + if ! confirm_install_command \ + "Start a local Podman machine now?" \ + "$(command_to_text podman machine init) && $(command_to_text podman machine start)"; then + return 1 + fi + + podman machine init >/dev/null 2>&1 || true + podman machine start +} + +install_compose_prerequisite() { + if command_exists docker && ! docker compose version >/dev/null 2>&1; then + if command_exists brew; then + install_with_brew "Docker Compose" docker-compose + elif command_exists apt-get; then + install_with_apt "Docker Compose plugin" docker-compose-plugin + elif command_exists dnf; then + install_with_dnf "Docker Compose plugin" docker-compose-plugin + elif command_exists yum; then + install_with_yum "Docker Compose plugin" docker-compose-plugin + elif command_exists pacman; then + install_with_pacman "Docker Compose" docker-compose + elif command_exists zypper; then + install_with_zypper "Docker Compose" docker-compose + else + warn "Install Docker Compose manually, then run this installer again." + return 1 + fi + return + fi + + if command_exists podman && ! command_exists podman-compose; then + if command_exists brew; then + install_with_brew "podman-compose" podman-compose + elif command_exists apt-get; then + install_with_apt "podman-compose" podman-compose + elif command_exists dnf; then + install_with_dnf "podman-compose" podman-compose + elif command_exists yum; then + install_with_yum "podman-compose" podman-compose + elif command_exists pacman; then + install_with_pacman "podman-compose" podman-compose + elif command_exists zypper; then + install_with_zypper "podman-compose" podman-compose + else + warn "Install podman-compose manually, then run this installer again." + return 1 + fi + start_podman_machine_if_needed || true + return + fi + + if command_exists brew; then + info "Docker Desktop also works well. The CLI-friendly fallback is Podman plus podman-compose." + install_with_brew "Podman and podman-compose" podman podman-compose + start_podman_machine_if_needed || true + elif command_exists apt-get; then + install_with_apt "Podman and podman-compose" podman podman-compose + elif command_exists dnf; then + install_with_dnf "Podman and podman-compose" podman podman-compose + elif command_exists yum; then + install_with_yum "Podman and podman-compose" podman podman-compose + elif command_exists pacman; then + install_with_pacman "Podman and podman-compose" podman podman-compose + elif command_exists zypper; then + install_with_zypper "Podman and podman-compose" podman podman-compose + else + warn "Install Docker Desktop, Docker Engine with Compose, or Podman with podman-compose, then run this installer again." + return 1 + fi +} + +install_ollama_prerequisite() { + local command_text + + if command_exists brew; then + install_with_brew "Ollama" ollama + elif [[ "$HW_OS" == "Linux" ]] && command_exists curl; then + command_text="curl -fsSL https://ollama.com/install.sh | sh" + info "This uses Ollama's official Linux install script." + if confirm_install_command "Install Ollama now?" "$command_text"; then + sh -c "$command_text" + else + return 1 + fi + else + warn "Install Ollama from https://ollama.com/download, then run this installer again." + return 1 + fi +} + +ensure_python3_available() { + command_exists python3 && return 0 + + say "Python 3 is missing" + install_python3_prerequisite || die "Python 3 is required to run tests and helper CLIs." + command_exists python3 || die "Python 3 was not found after installation. Open a new terminal or add it to PATH, then run this installer again." +} + +ensure_python_venv_available() { + python3 -m venv --help >/dev/null 2>&1 && return 0 + + say "Python venv support is missing" + install_python_venv_prerequisite || die "Python venv support is required to create the installer environment." + python3 -m venv --help >/dev/null 2>&1 || die "Python venv support is still unavailable. Open a new terminal or install python3-venv manually." +} + +ensure_basic_tool_available() { + local tool="$1" + local reason="$2" + + command_exists "$tool" && return 0 + + say "$tool is missing" + info "$reason" + install_basic_tool_prerequisite "$tool" || die "$tool is required. Install it manually, then run this installer again." + command_exists "$tool" || die "$tool was not found after installation. Open a new terminal or add it to PATH, then run this installer again." +} + +ensure_npx_available() { + [[ -n "$USE_EXISTING_COMPOSE" ]] && return 0 + command_exists npx && return 0 + + say "npx is missing" + info "npx is required for the existing TrustGraph config generator: npx @trustgraph/config." + install_node_prerequisite || die "npx is required. Install Node.js/npm manually, then run this installer again." + command_exists npx || die "npx was not found after installation. Open a new terminal or add it to PATH, then run this installer again." +} + +ensure_compose_available() { + detect_compose_command && return 0 + + say "Container compose support is missing" + info "TrustGraph runs as a compose stack. Docker Compose or podman-compose is required." + install_compose_prerequisite || die "Docker Compose or podman-compose is required to start TrustGraph." + detect_compose_command || die "Compose support was not found after installation. Open a new terminal or add it to PATH, then run this installer again." +} + +ensure_ollama_available_if_needed() { + [[ "$LLM_MODE" == "ollama" ]] || return 0 + command_exists ollama && return 0 + + say "Ollama is missing" + info "Ollama was selected for local LLMs, so the Ollama CLI and service are needed before model setup." + install_ollama_prerequisite || die "Ollama is required for the selected local LLM path. Install it manually, then run this installer again." + command_exists ollama || die "Ollama was not found after installation. Open a new terminal or add it to PATH, then run this installer again." +} + +preflight() { + say "Checking prerequisites" + + ensure_python3_available + ensure_python_venv_available + ensure_basic_tool_available unzip "unzip is required to unpack deploy.zip from the config generator." + ensure_basic_tool_available curl "curl is required for startup health checks and local service probes." + ensure_npx_available + ensure_compose_available + ensure_ollama_available_if_needed + check_container_runtime_ready + + info "Compose command: ${COMPOSE_CMD[*]}" + info "Python: $(python3 --version 2>&1)" + if command_exists npx; then + info "npx: $(npx --version 2>/dev/null || printf unknown)" + fi +} + +write_env_file() { + mkdir -p "$INSTALL_DIR" + local env_file="$INSTALL_DIR/trustgraph-installer.env" + local grafana_admin_password="${GF_SECURITY_ADMIN_PASSWORD:-${GRAFANA_ADMIN_PASSWORD:-$AUTH_TOKEN}}" + + umask 077 + { + printf 'export TRUSTGRAPH_URL=%q\n' "$API_URL" + printf 'export TRUSTGRAPH_UI_URL=%q\n' "$UI_URL" + printf 'export TRUSTGRAPH_TOKEN=%q\n' "$AUTH_TOKEN" + printf 'export TRUSTGRAPH_BOOTSTRAP_TOKEN=%q\n' "$AUTH_TOKEN" + printf 'export IAM_BOOTSTRAP_TOKEN=%q\n' "$AUTH_TOKEN" + printf 'export GF_SECURITY_ADMIN_PASSWORD=%q\n' "$grafana_admin_password" + printf 'export TG_VENV_DIR=%q\n' "$VENV_DIR" + printf 'export TG_NLTK_DATA_DIR=%q\n' "$NLTK_DATA_DIR" + printf 'export NLTK_DATA=%q\n' "$NLTK_DATA_DIR${NLTK_DATA:+:$NLTK_DATA}" + printf 'export TIKTOKEN_CACHE_DIR=%q\n' "$TIKTOKEN_CACHE_DIR_VALUE" + printf 'export TRUSTGRAPH_LLM_MODE=%q\n' "$LLM_MODE" + printf 'export OPENAI_BASE_URL=%q\n' "$OPENAI_BASE_URL_VALUE" + printf 'export OPENAI_TOKEN=%q\n' "$OPENAI_TOKEN_VALUE" + printf 'export OLLAMA_HOST=%q\n' "$OLLAMA_BASE_URL_VALUE" + printf 'export OLLAMA_BASE_URL=%q\n' "$OLLAMA_BASE_URL_VALUE" + printf 'export OLLAMA_MODEL=%q\n' "$OLLAMA_MODEL" + printf 'export OLLAMA_EMBEDDINGS_MODEL=%q\n' "$OLLAMA_EMBEDDINGS_MODEL" + } > "$env_file" + chmod 600 "$env_file" + + info "Saved answers to $env_file" +} + +prepare_python_env() { + say "Preparing Python environment" + mkdir -p "$INSTALL_DIR" + + if [[ ! -x "$VENV_DIR/bin/python" ]]; then + info "Creating venv at $VENV_DIR" + run_with_spinner "Creating Python venv" python3 -m venv "$VENV_DIR" + else + info "Using existing venv at $VENV_DIR" + fi + + PYTHON_BIN="$VENV_DIR/bin/python" + export PATH="$VENV_DIR/bin:$PATH" + info "Python venv: $($PYTHON_BIN --version 2>&1)" +} + +ensure_version_files() { + local version="${TRUSTGRAPH_LOCAL_VERSION:-2.5.0}" + local specs=( + "trustgraph-base/trustgraph/base_version.py:trustgraph.base_version" + "trustgraph-flow/trustgraph/flow_version.py:trustgraph.flow_version" + "trustgraph-vertexai/trustgraph/vertexai_version.py:trustgraph.vertexai_version" + "trustgraph-bedrock/trustgraph/bedrock_version.py:trustgraph.bedrock_version" + "trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py:trustgraph.embeddings_hf_version" + "trustgraph-cli/trustgraph/cli_version.py:trustgraph.cli_version" + "trustgraph-ocr/trustgraph/ocr_version.py:trustgraph.ocr_version" + "trustgraph-unstructured/trustgraph/unstructured_version.py:trustgraph.unstructured_version" + "trustgraph-mcp/trustgraph/mcp_version.py:trustgraph.mcp_version" + "trustgraph/trustgraph/trustgraph_version.py:trustgraph.trustgraph_version" + ) + + say "Ensuring local package version files" + for spec in "${specs[@]}"; do + local file="${spec%%:*}" + mkdir -p "$(dirname "$SCRIPT_DIR/$file")" + printf '__version__ = "%s"\n' "$version" > "$SCRIPT_DIR/$file" + info "Set $file to $version" + done +} + +local_package_pythonpath() { + local package_dirs=( + "$SCRIPT_DIR/trustgraph-flow" + "$SCRIPT_DIR/trustgraph-embeddings-hf" + "$SCRIPT_DIR/trustgraph-base" + "$SCRIPT_DIR/trustgraph-cli" + "$SCRIPT_DIR/trustgraph-bedrock" + "$SCRIPT_DIR/trustgraph-ocr" + "$SCRIPT_DIR/trustgraph-unstructured" + "$SCRIPT_DIR/trustgraph-mcp" + "$SCRIPT_DIR/trustgraph-vertexai" + "$SCRIPT_DIR/trustgraph" + ) + local joined="" + local dir + + for dir in "${package_dirs[@]}"; do + if [[ -d "$dir" ]]; then + if [[ -n "$joined" ]]; then + joined="$joined:$dir" + else + joined="$dir" + fi + fi + done + + printf '%s\n' "$joined" +} + +ensure_python_build_tools() { + say "Preparing Python build tools" + local pip_cache_dir="$INSTALL_DIR/pip_cache" + mkdir -p "$pip_cache_dir" + + if ! "$PYTHON_BIN" -m pip --version >/dev/null 2>&1; then + local ensurepip_log + ensurepip_log="$(installer_log_file "python-ensurepip")" + info "Installing pip into the Python venv" + run_with_spinner_logged \ + "Installing pip" \ + "$ensurepip_log" \ + "$PYTHON_BIN" -m ensurepip --upgrade \ + || die "Could not install pip into the Python venv." + fi + + if "$PYTHON_BIN" - <<'PY' >/dev/null 2>&1 +import setuptools.build_meta +PY + then + info "Python build backend available: setuptools.build_meta" + return + fi + + local log_file + log_file="$(installer_log_file "pip-build-tools")" + info "Installing setuptools and wheel into the Python venv" + run_with_spinner_logged \ + "Installing Python build tools" \ + "$log_file" \ + env \ + PIP_CACHE_DIR="$pip_cache_dir" \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + "$PYTHON_BIN" -m pip install "setuptools>=61" wheel \ + || die "Could not install setuptools/wheel. Check $log_file, then re-run the installer." +} + +install_test_packages() { + say "Installing local Python packages for tests" + local pip_cache_dir="$INSTALL_DIR/pip_cache" + mkdir -p "$pip_cache_dir" + ensure_python_build_tools + + local package_dirs=( + trustgraph-base + trustgraph-cli + trustgraph-flow + trustgraph-vertexai + trustgraph-bedrock + trustgraph-embeddings-hf + trustgraph-ocr + trustgraph-unstructured + trustgraph-mcp + ) + + for package_dir in "${package_dirs[@]}"; do + if [[ -d "$SCRIPT_DIR/$package_dir" ]]; then + local log_file + log_file="$(installer_log_file "pip-${package_dir}")" + info "Installing $package_dir" + run_with_spinner_logged \ + "Installing $package_dir" \ + "$log_file" \ + env \ + PIP_CACHE_DIR="$pip_cache_dir" \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + "$PYTHON_BIN" -m pip install --no-build-isolation "$SCRIPT_DIR/$package_dir" + fi + done + + if [[ -f "$SCRIPT_DIR/tests/requirements.txt" ]]; then + local log_file + log_file="$(installer_log_file "pip-test-requirements")" + info "Installing test requirements" + run_with_spinner_logged \ + "Installing test requirements" \ + "$log_file" \ + env \ + PIP_CACHE_DIR="$pip_cache_dir" \ + PIP_DISABLE_PIP_VERSION_CHECK=1 \ + "$PYTHON_BIN" -m pip install -r "$SCRIPT_DIR/tests/requirements.txt" + fi +} + +ensure_tokenizer_cache() { + say "Preparing tokenizer cache" + mkdir -p "$TIKTOKEN_CACHE_DIR_VALUE" + info "tiktoken cache: $TIKTOKEN_CACHE_DIR_VALUE" + + TIKTOKEN_CACHE_DIR="$TIKTOKEN_CACHE_DIR_VALUE" "$PYTHON_BIN" - <<'PY' +import tiktoken + +tiktoken.get_encoding("cl100k_base") +print(" Cached tiktoken encoding: cl100k_base") +PY +} + +ensure_nltk_data() { + say "Preparing NLTK tokenizer data" + mkdir -p "$NLTK_DATA_DIR" + info "NLTK data: $NLTK_DATA_DIR" + + TG_NLTK_DATA_DIR="$NLTK_DATA_DIR" \ + NLTK_DATA="$NLTK_DATA_DIR${NLTK_DATA:+:$NLTK_DATA}" \ + "$PYTHON_BIN" - <<'PY' +import os +import nltk + +target = os.environ["TG_NLTK_DATA_DIR"] +if target not in nltk.data.path: + nltk.data.path.insert(0, target) + +resources = ( + ("punkt", "tokenizers/punkt"), + ("punkt_tab", "tokenizers/punkt_tab"), + ("averaged_perceptron_tagger_eng", "taggers/averaged_perceptron_tagger_eng"), +) + +for package, resource in resources: + try: + nltk.data.find(resource) + except LookupError: + print(f" Downloading NLTK resource: {package}") + if not nltk.download(package, download_dir=target, quiet=True): + raise SystemExit(f"Could not download NLTK resource: {package}") + else: + print(f" NLTK resource already available: {package}") +PY +} + +run_all_tests() { + if [[ "$RUN_TESTS" -ne 1 ]]; then + warn "Skipping tests because --skip-tests was supplied." + return + fi + + prepare_python_env + ensure_version_files + install_test_packages + ensure_tokenizer_cache + ensure_nltk_data + + say "Running all tests" + info "Command: $PYTHON_BIN -m pytest tests" + local test_log + test_log="$(installer_log_file "pytest")" + if spinner_enabled; then + info "Test output log: $test_log" + fi + ( + cd "$SCRIPT_DIR" + run_with_spinner_logged \ + "Running pytest tests" \ + "$test_log" \ + env \ + INSTALL_TRUSTGRAPH_SOURCE_ONLY= \ + TG_NO_SPINNER= \ + TG_FORCE_SPINNER= \ + NLTK_DATA="$NLTK_DATA_DIR${NLTK_DATA:+:$NLTK_DATA}" \ + TIKTOKEN_CACHE_DIR="$TIKTOKEN_CACHE_DIR_VALUE" \ + TRUSTGRAPH_CASSANDRA_SKIP_ON_UNREADY=1 \ + "$PYTHON_BIN" -m pytest tests + ) +} + +show_config_guidance() { + say "Before the config wizard starts" + info "Choose a Docker/Podman compose deployment for local installation." + info "Keep the Workbench UI enabled; the existing UI default is port 8888." + info "Use the bundled infrastructure defaults: Cassandra, Qdrant, Garage, and RabbitMQ/Pulsar as offered." + if [[ -n "$AUTH_TOKEN" ]]; then + info "For IAM/auth, use token/bootstrap-token mode when offered." + info "Admin/bootstrap API key to enter if asked: $AUTH_TOKEN" + else + info "For IAM/auth, use token/bootstrap-token mode when offered and paste the API key saved by this installer." + fi + if [[ "$LLM_MODE" == "ollama" ]]; then + info "For LLMs, choose Ollama or an OpenAI-compatible endpoint and use $OLLAMA_BASE_URL_VALUE." + elif [[ "$LLM_MODE" == "openai" ]]; then + info "For LLMs, choose OpenAI/OpenAI-compatible and use $OPENAI_BASE_URL_VALUE." + else + info "You can skip LLM configuration now and add it later in the Workbench." + fi +} + +run_config_generator() { + if [[ -n "$USE_EXISTING_COMPOSE" ]]; then + return + fi + + mkdir -p "$INSTALL_DIR" + + if [[ -f "$INSTALL_DIR/deploy.zip" ]]; then + if confirm "Existing deploy.zip found in $INSTALL_DIR. Reuse it and skip the config wizard?" 1; then + info "Using existing deployment archive: $INSTALL_DIR/deploy.zip" + return + fi + fi + + show_config_guidance + + if ! confirm "Start the TrustGraph config wizard now?" 1; then + die "Config generation cancelled." + fi + + say "Running TrustGraph config generator" + ( + cd "$INSTALL_DIR" + TRUSTGRAPH_TOKEN="$AUTH_TOKEN" \ + TRUSTGRAPH_BOOTSTRAP_TOKEN="$AUTH_TOKEN" \ + OPENAI_TOKEN="$OPENAI_TOKEN_VALUE" \ + OPENAI_BASE_URL="$OPENAI_BASE_URL_VALUE" \ + OLLAMA_HOST="$OLLAMA_BASE_URL_VALUE" \ + OLLAMA_BASE_URL="$OLLAMA_BASE_URL_VALUE" \ + OLLAMA_MODEL="$OLLAMA_MODEL" \ + OLLAMA_EMBEDDINGS_MODEL="$OLLAMA_EMBEDDINGS_MODEL" \ + NLTK_DATA="$NLTK_DATA_DIR${NLTK_DATA:+:$NLTK_DATA}" \ + TIKTOKEN_CACHE_DIR="$TIKTOKEN_CACHE_DIR_VALUE" \ + npx @trustgraph/config + ) +} + +find_compose_file() { + if [[ -n "$USE_EXISTING_COMPOSE" ]]; then + [[ -f "$USE_EXISTING_COMPOSE" ]] || die "Compose file does not exist: $USE_EXISTING_COMPOSE" + printf '%s\n' "$USE_EXISTING_COMPOSE" + return + fi + + local deploy_zip="$INSTALL_DIR/deploy.zip" + local unpack_dir="$INSTALL_DIR/deploy" + + [[ -f "$deploy_zip" ]] || die "The config generator did not create $deploy_zip" + + rm -rf "$unpack_dir" + mkdir -p "$unpack_dir" + unzip -oq "$deploy_zip" -d "$unpack_dir" + + local compose_file + compose_file="$(find "$unpack_dir" "$INSTALL_DIR" \ + \( -name 'docker-compose.yaml' -o -name 'docker-compose.yml' -o -name 'compose.yaml' -o -name 'compose.yml' \) \ + -type f | head -n 1)" + + [[ -n "$compose_file" ]] || die "Could not find a compose file in $deploy_zip" + printf '%s\n' "$compose_file" +} + +compose_dir_for() { + local compose_file="$1" + (cd "$(dirname "$compose_file")" && pwd -P) +} + +compose_env_file_for() { + local compose_file="$1" + local compose_dir + + compose_dir="$(compose_dir_for "$compose_file")" + printf '%s/.env\n' "$compose_dir" +} + +write_compose_env_file() { + local compose_file="$1" + local compose_env_file + local grafana_admin_password="${GF_SECURITY_ADMIN_PASSWORD:-${GRAFANA_ADMIN_PASSWORD:-$AUTH_TOKEN}}" + + [[ -n "$AUTH_TOKEN" ]] || die "TrustGraph API key is empty; cannot create compose environment." + + compose_env_file="$(compose_env_file_for "$compose_file")" + umask 077 + { + printf 'TRUSTGRAPH_TOKEN=%s\n' "$AUTH_TOKEN" + printf 'TRUSTGRAPH_BOOTSTRAP_TOKEN=%s\n' "$AUTH_TOKEN" + printf 'IAM_BOOTSTRAP_TOKEN=%s\n' "$AUTH_TOKEN" + printf 'GF_SECURITY_ADMIN_PASSWORD=%s\n' "$grafana_admin_password" + printf 'OLLAMA_HOST=%s\n' "$OLLAMA_BASE_URL_VALUE" + printf 'OLLAMA_BASE_URL=%s\n' "$OLLAMA_BASE_URL_VALUE" + printf 'OLLAMA_MODEL=%s\n' "$OLLAMA_MODEL" + printf 'OLLAMA_EMBEDDINGS_MODEL=%s\n' "$OLLAMA_EMBEDDINGS_MODEL" + } > "$compose_env_file" + chmod 600 "$compose_env_file" + + info "Compose environment: $compose_env_file" +} + +start_stack() { + local compose_file="$1" + local compose_dir + local compose_name + local log_file + + say "Starting TrustGraph" + info "Compose file: $compose_file" + write_compose_env_file "$compose_file" + + compose_dir="$(compose_dir_for "$compose_file")" + compose_name="$(basename "$compose_file")" + log_file="$(installer_log_file "compose-up")" + case "$log_file" in + /*) ;; + *) log_file="$SCRIPT_DIR/$log_file" ;; + esac + + ( + cd "$compose_dir" + run_with_spinner_logged \ + "Starting TrustGraph containers" \ + "$log_file" \ + "${COMPOSE_CMD[@]}" -f "$compose_name" up -d + ) +} + +http_status() { + local url="$1" + curl -sS -o /dev/null -w '%{http_code}' --max-time 5 "$url" 2>/dev/null || true +} + +http_status_with_bearer() { + local url="$1" + local token="$2" + + curl -sS -o /dev/null -w '%{http_code}' --max-time 5 \ + -H "Authorization: Bearer $token" \ + "$url" 2>/dev/null || true +} + +sha256_text() { + local value="$1" + + printf '%s' "$value" | python3 -c 'import hashlib, sys; print(hashlib.sha256(sys.stdin.buffer.read()).hexdigest())' +} + +repair_local_iam_api_key() { + local compose_file="$1" + local compose_dir + local compose_name + local key_hash + local key_suffix + local user_id + local username + local key_id + local prefix + local cql + local log_file + + [[ -n "$compose_file" && -f "$compose_file" ]] || return 1 + [[ -n "$AUTH_TOKEN" ]] || return 1 + command_exists python3 || return 1 + [[ "${#COMPOSE_CMD[@]}" -gt 0 ]] || return 1 + + key_hash="$(sha256_text "$AUTH_TOKEN")" + key_suffix="${key_hash:0:12}" + user_id="installer-admin-$key_suffix" + username="installer-admin-$key_suffix" + key_id="installer-key-$key_suffix" + prefix="$(printf '%s' "${AUTH_TOKEN:0:7}" | tr -cd 'a-zA-Z0-9_-')" + compose_dir="$(compose_dir_for "$compose_file")" + compose_name="$(basename "$compose_file")" + log_file="$(installer_log_file "iam-key-repair")" + + cql=" +USE iam; +INSERT INTO iam_users (id, workspace, username, name, email, password_hash, roles, enabled, must_change_password, created) +VALUES ('$user_id', 'default', '$username', 'Installer Admin', '', 'installer-repair', {'admin'}, true, false, toTimestamp(now())); +INSERT INTO iam_users_by_username (workspace, username, user_id) +VALUES ('default', '$username', '$user_id'); +INSERT INTO iam_api_keys (key_hash, id, user_id, name, prefix, expires, created, last_used) +VALUES ('$key_hash', '$key_id', '$user_id', 'installer-repair', '$prefix', null, toTimestamp(now()), null); +" + + say "Repairing local IAM API key" + info "Adding the saved installer key to the local installer-managed IAM database." + mkdir -p "$(dirname "$log_file")" + if ( + cd "$compose_dir" + printf '%s\n' "$cql" | "${COMPOSE_CMD[@]}" -f "$compose_name" exec -T cassandra cqlsh + ) >"$log_file" 2>&1; then + info "Local IAM API key repair completed." + return 0 + fi + + warn "Local IAM API key repair failed. Last log lines from $log_file:" + tail -n 40 "$log_file" >&2 || true + return 1 +} + +wait_for_gateway() { + local deadline=$((SECONDS + HEALTH_TIMEOUT)) + local next_notice=$((SECONDS + 15)) + local status="" + + say "Waiting for API gateway" + info "Checking $API_URL for up to ${HEALTH_TIMEOUT}s." + while (( SECONDS < deadline )); do + status="$(http_status "$API_URL")" + if [[ "$status" == "200" || "$status" == "401" || "$status" == "404" ]]; then + info "API gateway is responding with HTTP $status" + return 0 + fi + if (( SECONDS >= next_notice )); then + info "Still waiting; last HTTP status was ${status:-connection failed}." + next_notice=$((SECONDS + 15)) + fi + sleep 3 + done + + die "API gateway did not respond at $API_URL within ${HEALTH_TIMEOUT}s" +} + +verify_api_key_authentication() { + local compose_file="${1:-}" + local deadline=$((SECONDS + AUTH_CHECK_TIMEOUT)) + local metrics_url="${API_URL%/}/api/metrics/query?query=processor_info" + local status="" + + [[ -n "$AUTH_TOKEN" ]] || return 0 + + say "Checking API key authentication" + info "The API gateway root can return HTTP 404; that is normal. This checks an authenticated endpoint." + + while :; do + status="$(http_status_with_bearer "$metrics_url" "$AUTH_TOKEN")" + case "$status" in + 200) + info "Installer API key authenticated at the API gateway." + return 0 + ;; + 401|403) + ;; + "") + ;; + *) + info "Authentication probe returned HTTP $status; continuing to the full health checks." + return 0 + ;; + esac + + (( SECONDS >= deadline )) && break + sleep 3 + done + + if [[ "$status" == "401" || "$status" == "403" ]]; then + if [[ -n "$compose_file" ]] && repair_local_iam_api_key "$compose_file"; then + status="$(http_status_with_bearer "$metrics_url" "$AUTH_TOKEN")" + if [[ "$status" == "200" ]]; then + info "Installer API key authenticated after local IAM repair." + return 0 + fi + fi + warn "The API gateway is running, but it rejected the installer API key." + info "Configured installer API key: $AUTH_TOKEN" + info "Saved environment: $INSTALL_DIR/trustgraph-installer.env" + warn "This usually means compose volumes contain IAM data from an earlier install. Run ./install_trustgraph.sh --remove-all to remove the installer-managed deployment and compose volumes, then reinstall; or rerun with the original TRUSTGRAPH_TOKEN if you know it." + return 1 + fi + + warn "Could not confirm API key authentication yet; continuing to the full health checks." +} + +bootstrap_iam_if_available() { + local bootstrap_output="" + local log_file="$INSTALL_DIR/iam-bootstrap.log" + + if ! command_exists tg-bootstrap-iam; then + warn "tg-bootstrap-iam is not on PATH; using the installer API key for health checks." + return + fi + + say "Checking IAM bootstrap" + if bootstrap_output="$(tg-bootstrap-iam --api-url "$API_URL" 2>"$log_file")"; then + if [[ -n "$bootstrap_output" ]]; then + AUTH_TOKEN="$bootstrap_output" + info "Captured the first-run admin API key from IAM bootstrap." + write_env_file + fi + else + info "IAM bootstrap did not issue a new key; this is normal for token mode or an already-bootstrapped system." + info "Details: $log_file" + fi +} + +verify_system() { + local verify_cmd=() + + if command_exists tg-verify-system-status; then + verify_cmd=(tg-verify-system-status) + elif "$PYTHON_BIN" -c 'import trustgraph.cli.verify_system_status' >/dev/null 2>&1; then + verify_cmd=("$PYTHON_BIN" -m trustgraph.cli.verify_system_status) + else + say "Verifying TrustGraph health" + info "API gateway: $API_URL" + info "Workbench UI: $UI_URL" + [[ "$(http_status "$API_URL")" =~ ^(200|401|404)$ ]] || die "API gateway health check failed." + [[ "$(http_status "${UI_URL%/}/index.html")" == "200" ]] || warn "Workbench UI did not return HTTP 200 yet." + return + fi + + say "Verifying TrustGraph health" + verify_cmd+=( + --api-url "$API_URL" + --ui-url "$UI_URL" + --global-timeout "$HEALTH_TIMEOUT" + ) + if [[ -n "$AUTH_TOKEN" ]]; then + verify_cmd+=(--token "$AUTH_TOKEN") + fi + + "${verify_cmd[@]}" +} + +launch_ui() { + if [[ "$AUTO_LAUNCH" -ne 1 ]]; then + info "Workbench UI autolaunch disabled." + return + fi + + say "Opening Workbench UI" + if command_exists open; then + open "$UI_URL" + elif command_exists xdg-open; then + xdg-open "$UI_URL" + elif command_exists wslview; then + wslview "$UI_URL" + else + warn "Could not find a browser launcher. Open this URL manually: $UI_URL" + return + fi + info "Workbench UI: $UI_URL" +} + +print_ready_summary() { + local auth_status="${1:-0}" + + if [[ "$auth_status" -eq 0 ]]; then + say "TrustGraph is ready" + else + say "TrustGraph started with an authentication warning" + fi + info "Workbench UI: $UI_URL" + info "API gateway: $API_URL" + if [[ -n "$AUTH_TOKEN" ]]; then + info "Admin/bootstrap API key: $AUTH_TOKEN" + fi + info "Saved environment: $INSTALL_DIR/trustgraph-installer.env" +} + +main() { + parse_args "$@" + cd "$SCRIPT_DIR" + init_colors + + print_banner + if [[ "$REMOVE_ALL" -eq 1 ]]; then + say "$APP_NAME guided uninstaller" + load_saved_answers + remove_all_installation + return 0 + fi + + say "$APP_NAME guided installer" + handle_existing_install + load_saved_answers + detect_hardware + choose_recommendations + print_hardware_summary + + collect_answers + print_plan_summary + + if [[ "$DRY_RUN" -eq 1 ]]; then + say "Dry run complete" + return 0 + fi + + if ! confirm "Proceed with this install plan?" 1; then + die "Install cancelled." + fi + + preflight + offer_ollama_model_downloads + write_env_file + run_all_tests + run_config_generator + + local compose_file + compose_file="$(find_compose_file)" + start_stack "$compose_file" + wait_for_gateway + bootstrap_iam_if_available + local auth_status=0 + local verify_status=0 + + verify_api_key_authentication "$compose_file" || auth_status=$? + if [[ "$auth_status" -eq 0 ]]; then + verify_system || verify_status=$? + else + warn "Skipping authenticated health checks because the configured API key was rejected." + fi + launch_ui + + print_ready_summary "$auth_status" + + if [[ "$auth_status" -ne 0 ]]; then + return "$auth_status" + fi + if [[ "$verify_status" -ne 0 ]]; then + return "$verify_status" + fi +} + +if [[ "${INSTALL_TRUSTGRAPH_SOURCE_ONLY:-0}" != "1" ]]; then + main "$@" +fi diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 696df7ec..8930d159 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -95,10 +95,6 @@ class TestGraphRagIntegration: async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "extract-concepts": return PromptResult(response_type="text", text="") - elif prompt_name == "kg-edge-scoring": - return PromptResult(response_type="text", text="") - elif prompt_name == "kg-edge-reasoning": - return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": return PromptResult( response_type="text", @@ -113,14 +109,22 @@ class TestGraphRagIntegration: client.prompt.side_effect = mock_prompt return client + @pytest.fixture + def mock_reranker_client(self): + """Mock reranker client for cross-encoder edge filtering""" + client = AsyncMock() + client.rerank.return_value = [] + return client + @pytest.fixture def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client, - mock_triples_client, mock_prompt_client): + mock_triples_client, mock_reranker_client, mock_prompt_client): """Create GraphRag instance with mocked dependencies""" return GraphRag( embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, + reranker_client=mock_reranker_client, prompt_client=mock_prompt_client, verbose=True ) @@ -167,8 +171,8 @@ class TestGraphRagIntegration: # 3. Should query triples to build knowledge subgraph assert mock_triples_client.query_stream.call_count > 0 - # 4. Should call prompt four times (extract-concepts + edge-scoring + edge-reasoning + synthesis) - assert mock_prompt_client.prompt.call_count == 4 + # 4. Should call prompt twice (extract-concepts + synthesis) + assert mock_prompt_client.prompt.call_count == 2 # Verify final response response, usage = response diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index 48e26618..8dfd9c2b 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -63,11 +63,6 @@ class TestGraphRagStreaming: async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs): if prompt_id == "extract-concepts": return PromptResult(response_type="text", text="") - elif prompt_id == "kg-edge-scoring": - # Edge scoring returns JSONL with IDs and scores - return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n') - elif prompt_id == "kg-edge-reasoning": - return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n') elif prompt_id == "kg-synthesis": if streaming and chunk_callback: # Simulate streaming chunks with end_of_stream flags @@ -88,14 +83,23 @@ class TestGraphRagStreaming: client.prompt.side_effect = prompt_side_effect return client + @pytest.fixture + def mock_reranker_client(self): + """Mock reranker client for cross-encoder edge filtering""" + client = AsyncMock() + client.rerank.return_value = [] + return client + @pytest.fixture def graph_rag_streaming(self, mock_embeddings_client, mock_graph_embeddings_client, - mock_triples_client, mock_streaming_prompt_client): + mock_triples_client, mock_reranker_client, + mock_streaming_prompt_client): """Create GraphRag instance with streaming support""" return GraphRag( embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, + reranker_client=mock_reranker_client, prompt_client=mock_streaming_prompt_client, verbose=True ) diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py index 279c81ef..efce5922 100644 --- a/tests/integration/test_rag_streaming_protocol.py +++ b/tests/integration/test_rag_streaming_protocol.py @@ -46,7 +46,7 @@ class TestGraphRagStreamingProtocol: client = AsyncMock() async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None): - if prompt_name == "kg-edge-selection": + if prompt_name == "extract-concepts": return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": if streaming and chunk_callback: @@ -63,14 +63,23 @@ class TestGraphRagStreamingProtocol: client.prompt.side_effect = prompt_side_effect return client + @pytest.fixture + def mock_reranker_client(self): + """Mock reranker client for cross-encoder edge filtering""" + client = AsyncMock() + client.rerank.return_value = [] + return client + @pytest.fixture def graph_rag(self, mock_embeddings_client, mock_graph_embeddings_client, - mock_triples_client, mock_streaming_prompt_client): + mock_triples_client, mock_reranker_client, + mock_streaming_prompt_client): """Create GraphRag instance with mocked dependencies""" return GraphRag( embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, + reranker_client=mock_reranker_client, prompt_client=mock_streaming_prompt_client, verbose=False ) @@ -327,7 +336,7 @@ class TestStreamingProtocolEdgeCases: client = AsyncMock() async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None): - if prompt_name == "kg-edge-selection": + if prompt_name == "extract-concepts": return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": if streaming and chunk_callback: @@ -342,10 +351,14 @@ class TestStreamingProtocolEdgeCases: client.prompt.side_effect = prompt_with_empties + mock_reranker = AsyncMock() + mock_reranker.rerank.return_value = [] + rag = GraphRag( embeddings_client=AsyncMock(embed=AsyncMock(return_value=[[[0.1]]])), graph_embeddings_client=AsyncMock(query=AsyncMock(return_value=[])), triples_client=AsyncMock(query=AsyncMock(return_value=[])), + reranker_client=mock_reranker, prompt_client=client, verbose=False ) diff --git a/tests/integration/test_text_completion_integration.py b/tests/integration/test_text_completion_integration.py index 6615bf84..521f7d74 100644 --- a/tests/integration/test_text_completion_integration.py +++ b/tests/integration/test_text_completion_integration.py @@ -15,11 +15,20 @@ from openai.types.chat.chat_completion import Choice from openai.types.completion_usage import CompletionUsage from trustgraph.model.text_completion.openai.llm import Processor +from trustgraph.model.text_completion.openai.variants import get_variant from trustgraph.exceptions import TooManyRequests from trustgraph.base import LlmResult from trustgraph.schema import TextCompletionRequest, TextCompletionResponse, Error +def _wire_variant(processor): + """Attach variant methods to a MagicMock processor.""" + processor.variant = get_variant("openai") + processor.thinking = "off" + processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor) + processor._extract_content = Processor._extract_content.__get__(processor, Processor) + + @pytest.mark.integration class TestTextCompletionIntegration: """Integration tests for OpenAI text completion service coordination""" @@ -66,6 +75,7 @@ class TestTextCompletionIntegration: # Add the actual generate_content method from Processor class processor.generate_content = Processor.generate_content.__get__(processor, Processor) + _wire_variant(processor) return processor @@ -119,6 +129,7 @@ class TestTextCompletionIntegration: # Add the actual generate_content method processor.generate_content = Processor.generate_content.__get__(processor, Processor) + _wire_variant(processor) # Act result = await processor.generate_content("System prompt", "User prompt") @@ -129,7 +140,7 @@ class TestTextCompletionIntegration: assert result.in_token == 50 assert result.out_token == 100 # Note: result.model comes from mock response, not processor config - + # Verify configuration was applied call_args = mock_openai_client.chat.completions.create.call_args assert call_args.kwargs['model'] == config['model'] @@ -247,6 +258,7 @@ class TestTextCompletionIntegration: processor.max_output = processor_config["max_output"] processor.openai = mock_openai_client processor.generate_content = Processor.generate_content.__get__(processor, Processor) + _wire_variant(processor) processors.append(processor) # Simulate multiple concurrent requests @@ -354,6 +366,7 @@ class TestTextCompletionIntegration: processor.max_output = 2048 processor.openai = mock_openai_client processor.generate_content = Processor.generate_content.__get__(processor, Processor) + _wire_variant(processor) # Act await processor.generate_content("System prompt", "User prompt") diff --git a/tests/integration/test_text_completion_streaming_integration.py b/tests/integration/test_text_completion_streaming_integration.py index 6968affa..7d514522 100644 --- a/tests/integration/test_text_completion_streaming_integration.py +++ b/tests/integration/test_text_completion_streaming_integration.py @@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice as StreamChoice, ChoiceDelta from trustgraph.model.text_completion.openai.llm import Processor +from trustgraph.model.text_completion.openai.variants import get_variant from trustgraph.base import LlmChunk from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, @@ -18,6 +19,14 @@ from tests.utils.streaming_assertions import ( ) +def _wire_variant(processor): + """Attach variant methods to a MagicMock processor.""" + processor.variant = get_variant("openai") + processor.thinking = "off" + processor._build_kwargs = Processor._build_kwargs.__get__(processor, Processor) + processor._extract_content = Processor._extract_content.__get__(processor, Processor) + + @pytest.mark.integration class TestTextCompletionStreaming: """Integration tests for Text Completion streaming""" @@ -69,6 +78,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) return processor @@ -190,6 +200,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) # Act chunks = [] @@ -223,6 +234,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) # Act chunks = [] @@ -258,6 +270,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) # Act & Assert with pytest.raises(Exception) as exc_info: @@ -295,6 +308,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) # Act chunks = [] @@ -318,6 +332,7 @@ class TestTextCompletionStreaming: processor.generate_content_stream = Processor.generate_content_stream.__get__( processor, Processor ) + _wire_variant(processor) system_prompt = "You are an expert." user_prompt = "Explain quantum physics." diff --git a/tests/unit/test_base/test_prompt_client_streaming.py b/tests/unit/test_base/test_prompt_client_streaming.py index 83a4b90e..fecf6095 100644 --- a/tests/unit/test_base/test_prompt_client_streaming.py +++ b/tests/unit/test_base/test_prompt_client_streaming.py @@ -195,38 +195,6 @@ class TestPromptClientStreamingCallback: assert callback.call_args_list[0] == call("test", False) assert callback.call_args_list[1] == call("", True) - @pytest.mark.asyncio - async def test_kg_prompt_passes_parameters_to_callback(self, prompt_client): - """Test that kg_prompt correctly passes streaming parameters""" - # Arrange - async def mock_request(request, recipient=None, timeout=600): - if recipient: - responses = [ - PromptResponse(text="Answer", object=None, error=None, end_of_stream=False), - PromptResponse(text="", object=None, error=None, end_of_stream=True), - ] - for resp in responses: - should_stop = await recipient(resp) - if should_stop: - break - - prompt_client.request = mock_request - - callback = AsyncMock() - - # Act - await prompt_client.kg_prompt( - query="What is machine learning?", - kg=[("subject", "predicate", "object")], - streaming=True, - chunk_callback=callback - ) - - # Assert - assert callback.call_count == 2 - assert callback.call_args_list[0] == call("Answer", False) - assert callback.call_args_list[1] == call("", True) - @pytest.mark.asyncio async def test_document_prompt_passes_parameters_to_callback(self, prompt_client): """Test that document_prompt correctly passes streaming parameters""" diff --git a/tests/unit/test_bootstrap/test_default_flow_start.py b/tests/unit/test_bootstrap/test_default_flow_start.py new file mode 100644 index 00000000..7846bee7 --- /dev/null +++ b/tests/unit/test_bootstrap/test_default_flow_start.py @@ -0,0 +1,54 @@ +""" +Unit tests for trustgraph.bootstrap.initialisers.DefaultFlowStart + +Verifies the list/start timeouts are configurable and that the +configured values actually reach the flow-client request calls. +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from trustgraph.bootstrap.initialisers.default_flow_start import ( + DefaultFlowStart, +) + + +def test_default_timeouts(): + init = DefaultFlowStart(blueprint="bp") + assert init.list_timeout == 10 + assert init.start_timeout == 30 + + +def test_timeout_overrides_are_stored(): + init = DefaultFlowStart(blueprint="bp", list_timeout=5, start_timeout=99) + assert init.list_timeout == 5 + assert init.start_timeout == 99 + + +@pytest.mark.asyncio +async def test_run_forwards_configured_timeouts(): + init = DefaultFlowStart(blueprint="bp", list_timeout=5, start_timeout=99) + + # Flow client: list-flows returns no error + empty flow list, + # start-flow returns no error. + flow = MagicMock() + flow.start = AsyncMock() + flow.stop = AsyncMock() + flow.request = AsyncMock(side_effect=[ + MagicMock(error=None, flow_ids=[]), # list-flows response + MagicMock(error=None), # start-flow response + ]) + + # Context: workspace "default" exists, hands back our mock flow client. + ctx = MagicMock() + ctx.logger = MagicMock() + ctx.config.keys = AsyncMock(return_value=["default"]) + ctx.make_flow_client = MagicMock(return_value=flow) + + await init.run(ctx, None, "v1") + + calls = flow.request.call_args_list + assert len(calls) == 2 + assert calls[0].kwargs["timeout"] == 5 + assert calls[1].kwargs["timeout"] == 99 diff --git a/tests/unit/test_bootstrap/test_workspace_init.py b/tests/unit/test_bootstrap/test_workspace_init.py new file mode 100644 index 00000000..aa819904 --- /dev/null +++ b/tests/unit/test_bootstrap/test_workspace_init.py @@ -0,0 +1,13 @@ +"""Unit tests for trustgraph.bootstrap.initialisers.WorkspaceInit.""" + +from trustgraph.bootstrap.initialisers.workspace_init import WorkspaceInit + + +def test_default_iam_timeout(): + init = WorkspaceInit() + assert init.iam_timeout == 10 + + +def test_iam_timeout_override_is_stored(): + init = WorkspaceInit(iam_timeout=42) + assert init.iam_timeout == 42 diff --git a/tests/unit/test_concurrency/test_graph_rag_concurrency.py b/tests/unit/test_concurrency/test_graph_rag_concurrency.py index 1b35a238..ed567962 100644 --- a/tests/unit/test_concurrency/test_graph_rag_concurrency.py +++ b/tests/unit/test_concurrency/test_graph_rag_concurrency.py @@ -129,6 +129,9 @@ class TestBatchTripleQueries: # 3 queries, alternating results assert len(result) == 3 + # Each result is a (triple, direction) tuple + for triple, direction in result: + assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O) @pytest.mark.asyncio async def test_exception_in_one_query_does_not_block_others(self): @@ -153,6 +156,8 @@ class TestBatchTripleQueries: # 3 queries: 2 succeed, 1 fails → 2 triples assert len(result) == 2 + for triple, direction in result: + assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O) @pytest.mark.asyncio async def test_none_results_filtered(self): @@ -176,6 +181,8 @@ class TestBatchTripleQueries: # 3 queries: 1 returns None, 2 return triples assert len(result) == 2 + for triple, direction in result: + assert direction in (Query.FROM_S, Query.FROM_P, Query.FROM_O) @pytest.mark.asyncio async def test_empty_entities_no_queries(self): @@ -220,6 +227,80 @@ class TestBatchTripleQueries: assert calls[2].kwargs["p"] is None assert calls[2].kwargs["o"] == "ent-1" + @pytest.mark.asyncio + async def test_directions_assigned_correctly(self): + """Each query position should produce the correct direction tag.""" + triple = _make_triple("s", "p", "o") + + call_count = 0 + + async def one_triple_each(**kwargs): + nonlocal call_count + call_count += 1 + return [triple] + + client = AsyncMock() + client.query_stream = one_triple_each + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + assert len(result) == 3 + # Order matches query order: s-position, p-position, o-position + assert result[0][1] == Query.FROM_S + assert result[1][1] == Query.FROM_P + assert result[2][1] == Query.FROM_O + + @pytest.mark.asyncio + async def test_directions_correct_for_multiple_entities(self): + """Direction tags cycle correctly across multiple entities.""" + triple = _make_triple("s", "p", "o") + client = AsyncMock() + client.query_stream = AsyncMock(return_value=[triple]) + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1", "e2"], limit_per_entity=10 + ) + + assert len(result) == 6 + expected_directions = [ + Query.FROM_S, Query.FROM_P, Query.FROM_O, + Query.FROM_S, Query.FROM_P, Query.FROM_O, + ] + for (_, direction), expected in zip(result, expected_directions): + assert direction == expected + + @pytest.mark.asyncio + async def test_direction_preserved_with_multiple_triples(self): + """All triples from one query share the same direction.""" + t1 = _make_triple("a", "p1", "b") + t2 = _make_triple("a", "p2", "c") + + call_count = 0 + + async def multi_results(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [t1, t2] + return [] + + client = AsyncMock() + client.query_stream = multi_results + query = _make_query(triples_client=client) + + result = await query.execute_batch_triple_queries( + ["e1"], limit_per_entity=10 + ) + + # First query (FROM_S) returns 2 triples, both should be FROM_S + assert len(result) == 2 + assert result[0] == (t1, Query.FROM_S) + assert result[1] == (t2, Query.FROM_S) + class TestLRUCacheWithTTL: diff --git a/tests/unit/test_gateway/test_auth.py b/tests/unit/test_gateway/test_auth.py index 8ffcafa1..2775b0b0 100644 --- a/tests/unit/test_gateway/test_auth.py +++ b/tests/unit/test_gateway/test_auth.py @@ -86,19 +86,19 @@ class TestVerifyJwtEddsa: def test_valid_jwt_passes(self): priv, pub = make_keypair() claims = { - "sub": "user-1", "workspace": "default", + "sub": "user-1", "default_workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } token = sign_jwt(priv, claims) got = _verify_jwt_eddsa(token, pub) assert got["sub"] == "user-1" - assert got["workspace"] == "default" + assert got["default_workspace"] == "default" def test_expired_jwt_rejected(self): priv, pub = make_keypair() claims = { - "sub": "user-1", "workspace": "default", + "sub": "user-1", "default_workspace": "default", "iat": int(time.time()) - 3600, "exp": int(time.time()) - 1, } @@ -110,7 +110,7 @@ class TestVerifyJwtEddsa: priv_a, _ = make_keypair() _, pub_b = make_keypair() claims = { - "sub": "user-1", "workspace": "default", + "sub": "user-1", "default_workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } @@ -130,7 +130,7 @@ class TestVerifyJwtEddsa: # since we expect it to bail before verifying. header = {"alg": "HS256", "typ": "JWT", "kid": "x"} payload = { - "sub": "user-1", "workspace": "default", + "sub": "user-1", "default_workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } h = _b64url(json.dumps(header, separators=(",", ":")).encode()) @@ -148,11 +148,11 @@ class TestIdentity: def test_fields(self): i = Identity( - handle="u", workspace="w", + handle="u", default_workspace="w", principal_id="u", source="api-key", ) assert i.handle == "u" - assert i.workspace == "w" + assert i.default_workspace == "w" assert i.principal_id == "u" assert i.source == "api-key" @@ -208,7 +208,7 @@ class TestIamAuthDispatch: async def test_valid_jwt_resolves_to_identity(self): priv, pub = make_keypair() claims = { - "sub": "user-1", "workspace": "default", + "sub": "user-1", "default_workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } @@ -221,7 +221,7 @@ class TestIamAuthDispatch: make_request(f"Bearer {token}") ) assert ident.handle == "user-1" - assert ident.workspace == "default" + assert ident.default_workspace == "default" assert ident.principal_id == "user-1" assert ident.source == "jwt" @@ -231,7 +231,7 @@ class TestIamAuthDispatch: # must not validate — even ones that would otherwise pass. priv, _ = make_keypair() claims = { - "sub": "user-1", "workspace": "default", + "sub": "user-1", "default_workspace": "default", "iat": int(time.time()), "exp": int(time.time()) + 60, } token = sign_jwt(priv, claims) @@ -259,7 +259,7 @@ class TestIamAuthDispatch: make_request("Bearer tg_testkey") ) assert ident.handle == "user-xyz" - assert ident.workspace == "default" + assert ident.default_workspace == "default" assert ident.principal_id == "user-xyz" assert ident.source == "api-key" @@ -338,9 +338,9 @@ class TestAuthorise: decision for the regime's TTL (clamped above), and raises 403 on deny / 401 on regime error (fail closed).""" - def _make_identity(self, handle="u-1", workspace="default"): + def _make_identity(self, handle="u-1", default_workspace="default"): return Identity( - handle=handle, workspace=workspace, + handle=handle, default_workspace=default_workspace, principal_id=handle, source="api-key", ) diff --git a/tests/unit/test_gateway/test_capabilities.py b/tests/unit/test_gateway/test_capabilities.py index 4f781b16..04bac133 100644 --- a/tests/unit/test_gateway/test_capabilities.py +++ b/tests/unit/test_gateway/test_capabilities.py @@ -25,11 +25,11 @@ from trustgraph.gateway.capabilities import ( class _Identity: """Stand-in for auth.Identity — under the IAM contract it has - just ``handle``, ``workspace``, ``principal_id``, ``source``.""" + just ``handle``, ``default_workspace``, ``principal_id``, ``source``.""" - def __init__(self, handle="user-1", workspace="default"): + def __init__(self, handle="user-1", default_workspace="default"): self.handle = handle - self.workspace = workspace + self.default_workspace = default_workspace self.principal_id = handle self.source = "api-key" @@ -105,14 +105,14 @@ class TestEnforceWorkspace: async def test_default_fills_from_identity(self): data = {"operation": "x"} auth = _allow_auth() - await enforce_workspace(data, _Identity(workspace="default"), auth) + await enforce_workspace(data, _Identity(default_workspace="default"), auth) assert data["workspace"] == "default" @pytest.mark.asyncio async def test_caller_supplied_workspace_kept(self): data = {"workspace": "acme", "operation": "x"} auth = _allow_auth() - await enforce_workspace(data, _Identity(workspace="default"), auth) + await enforce_workspace(data, _Identity(default_workspace="default"), auth) assert data["workspace"] == "acme" @pytest.mark.asyncio diff --git a/tests/unit/test_gateway/test_socket_graceful_shutdown.py b/tests/unit/test_gateway/test_socket_graceful_shutdown.py index 6c3e323b..8116aa51 100644 --- a/tests/unit/test_gateway/test_socket_graceful_shutdown.py +++ b/tests/unit/test_gateway/test_socket_graceful_shutdown.py @@ -28,7 +28,7 @@ TEST_CAP = "graph:write" def _valid_identity(): return Identity( handle="test-user", - workspace="default", + default_workspace="default", principal_id="test-user", source="api-key", ) diff --git a/tests/unit/test_iam/test_noauth_handler.py b/tests/unit/test_iam/test_noauth_handler.py index 38461b62..7bfdac0c 100644 --- a/tests/unit/test_iam/test_noauth_handler.py +++ b/tests/unit/test_iam/test_noauth_handler.py @@ -32,7 +32,7 @@ class TestAuthenticateAnonymous: ) assert resp.error is None assert resp.resolved_user_id == "anon" - assert resp.resolved_workspace == "ws" + assert resp.resolved_default_workspace == "ws" assert "admin" in list(resp.resolved_roles) @pytest.mark.asyncio @@ -44,7 +44,7 @@ class TestAuthenticateAnonymous: _make_request(operation="authenticate-anonymous") ) assert resp.resolved_user_id == "dev-user" - assert resp.resolved_workspace == "dev-ws" + assert resp.resolved_default_workspace == "dev-ws" class TestResolveApiKey: @@ -57,7 +57,7 @@ class TestResolveApiKey: ) assert resp.error is None assert resp.resolved_user_id == "anonymous" - assert resp.resolved_workspace == "default" + assert resp.resolved_default_workspace == "default" class TestAuthorise: diff --git a/tests/unit/test_provenance/test_dag_structure.py b/tests/unit/test_provenance/test_dag_structure.py index e65ef2e3..d1ce097a 100644 --- a/tests/unit/test_provenance/test_dag_structure.py +++ b/tests/unit/test_provenance/test_dag_structure.py @@ -107,6 +107,7 @@ class TestGraphRagDagStructure: embeddings_client = AsyncMock() graph_embeddings_client = AsyncMock() triples_client = AsyncMock() + reranker_client = AsyncMock() embeddings_client.embed.return_value = [[0.1, 0.2]] graph_embeddings_client.query.return_value = [ @@ -121,27 +122,22 @@ class TestGraphRagDagStructure: ] triples_client.query.return_value = [] + result = MagicMock() + result.document_id = "0" + result.query_id = "0" + result.score = 0.95 + reranker_client.rerank.return_value = [result] + async def mock_prompt(template_id, variables=None, **kwargs): if template_id == "extract-concepts": return PromptResult(response_type="text", text="concept") - elif template_id == "kg-edge-scoring": - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[{"id": e["id"], "score": 10} for e in edges], - ) - elif template_id == "kg-edge-reasoning": - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges], - ) elif template_id == "kg-synthesis": return PromptResult(response_type="text", text="Answer.") return PromptResult(response_type="text", text="") prompt_client.prompt.side_effect = mock_prompt - return prompt_client, embeddings_client, graph_embeddings_client, triples_client + return (prompt_client, embeddings_client, graph_embeddings_client, + triples_client, reranker_client) @pytest.mark.asyncio async def test_dag_chain(self, mock_clients): @@ -152,7 +148,7 @@ class TestGraphRagDagStructure: events.append({"explain_id": explain_id, "triples": triples}) await rag.query( - query="test", explain_callback=explain_cb, edge_score_limit=0, + query="test", explain_callback=explain_cb, ) dag = _collect_events(events) diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 7762b543..a08bc718 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -101,27 +101,27 @@ class TestQuery: assert query.rag == mock_rag assert query.collection == "test_collection" assert query.verbose is False - assert query.doc_limit == 20 # Default value + assert query.fetch_limit == 20 # Default value - def test_query_initialization_with_custom_doc_limit(self): - """Test Query initialization with custom doc_limit""" + def test_query_initialization_with_custom_fetch_limit(self): + """Test Query initialization with custom fetch_limit""" # Create mock DocumentRag mock_rag = MagicMock() - # Initialize Query with custom doc_limit + # Initialize Query with custom fetch_limit query = Query( rag=mock_rag, workspace="test_workspace", collection="custom_collection", verbose=True, - doc_limit=50 + fetch_limit=50 ) # Verify initialization assert query.rag == mock_rag assert query.collection == "custom_collection" assert query.verbose is True - assert query.doc_limit == 50 + assert query.fetch_limit == 50 @pytest.mark.asyncio async def test_extract_concepts(self): @@ -224,7 +224,7 @@ class TestQuery: workspace="test_workspace", collection="test_collection", verbose=False, - doc_limit=15 + fetch_limit=15 ) # Call get_docs with concepts list @@ -377,7 +377,7 @@ class TestQuery: workspace="test_workspace", collection="test_collection", verbose=True, - doc_limit=5 + fetch_limit=5 ) # Call get_docs with concepts @@ -615,7 +615,7 @@ class TestQuery: workspace="test_workspace", collection="test_collection", verbose=False, - doc_limit=10 + fetch_limit=10 ) docs, chunk_ids = await query.get_docs(["concept A", "concept B"]) diff --git a/tests/unit/test_retrieval/test_document_rag_diversity_selection.py b/tests/unit/test_retrieval/test_document_rag_diversity_selection.py new file mode 100644 index 00000000..6dcd9458 --- /dev/null +++ b/tests/unit/test_retrieval/test_document_rag_diversity_selection.py @@ -0,0 +1,92 @@ +from trustgraph.retrieval.document_rag.rerank import ( + RerankCandidate, normalize_candidate_scores, mmr_select, + _pair_diversity_penalty +) + +def candidate(index, chunk_id, text, score): + return RerankCandidate( + index=index, + chunk_id=chunk_id, + text=text, + reranker_score=score, + ) + + +def test_normalize_candidate_scores_min_max_scales_raw_scores(): + candidates = [ + candidate(0, "a", "alpha", -2.0), + candidate(1, "b", "beta", 0.0), + candidate(2, "c", "gamma", 4.0), + ] + + normalized = normalize_candidate_scores(candidates) + + assert normalized[0].normalized_score == 0.0 + assert normalized[1].normalized_score == 1.0 / 3.0 + assert normalized[2].normalized_score == 1.0 + + +def test_normalize_candidate_scores_handles_equal_scores(): + candidates = [ + candidate(0, "a", "alpha", 3.0), + candidate(1, "b", "beta", 3.0), + candidate(2, "c", "gamma", 3.0), + ] + + normalized = normalize_candidate_scores(candidates) + + assert [c.normalized_score for c in normalized] == [0.5, 0.5, 0.5] + + +def test_mmr_select_limits_results(): + candidates = [ + candidate(0, "a", "alpha policy", 0.9), + candidate(1, "b", "beta refund", 0.8), + candidate(2, "c", "gamma shipping", 0.7), + ] + + selected = mmr_select(candidates, limit=2) + + assert len(selected) == 2 + + +def test_mmr_select_prefers_highest_reranker_score_first(): + candidates = [ + candidate(0, "a", "weakly relevant text", 0.1), + candidate(1, "b", "strongly relevant answer", 10.0), + candidate(2, "c", "medium relevant text", 5.0), + ] + + selected = mmr_select(candidates, limit=1) + + assert selected[0].chunk_id == "b" + + +def test_mmr_select_penalizes_near_duplicate_chunks(): + candidates = [ + candidate(0, "a", "apple banana fruit return policy", 1.00), + candidate(1, "b", "apple banana fruit return policy duplicate", 0.95), + candidate(2, "c", "engine motor vehicle warranty", 0.90), + ] + + selected = mmr_select( + candidates, + limit=2, + lambda_mult=0.2, + token_overlap_weight=1.0, + ) + + assert [c.chunk_id for c in selected] == ["a", "c"] + + +def test_pair_diversity_penalty_is_clamped(): + left = candidate(0, "a", "same same same", 1.0) + right = candidate(1, "b", "same same same", 0.9) + + penalty = _pair_diversity_penalty( + left, + right, + token_overlap_weight=10.0, + ) + + assert penalty == 1.0 diff --git a/tests/unit/test_retrieval/test_document_rag_rerank.py b/tests/unit/test_retrieval/test_document_rag_rerank.py new file mode 100644 index 00000000..67b3a2b1 --- /dev/null +++ b/tests/unit/test_retrieval/test_document_rag_rerank.py @@ -0,0 +1,550 @@ +""" +Tests for the optional cross-encoder reranking pass in DocumentRag.query(). + +Two behaviours are covered: + + 1. No-op: when no reranker_client is wired (the default), query() must feed + the LLM the exact same chunks, in the same order, that retrieval produced + - byte-identical to the pre-reranker behaviour - and must NOT emit a + chunk-selection provenance event. + + 2. Rerank: when a reranker_client is wired, the retrieved chunks are reordered + and truncated according to the reranker's results, the LLM receives the + reranked top-N, and a tg:ChunkSelection (focus) provenance event is emitted + carrying the per-surviving-chunk scores and chunk references. + +These are pure orchestration tests - the reranker is a stub, so there is no +torch / network dependency. +""" + +import pytest +from unittest.mock import AsyncMock +from dataclasses import dataclass + +from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.base import PromptResult +from trustgraph.schema import RerankerResult + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_WAS_DERIVED_FROM, + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, + TG_FOCUS, TG_SYNTHESIS, + TG_CHUNK_SELECTION, TG_SELECTED_CHUNK, TG_SCORE, TG_DOCUMENT, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def find_triple(triples, predicate, subject=None): + for t in triples: + if t.p.iri == predicate: + if subject is None or t.s.iri == subject: + return t + return None + + +def find_triples(triples, predicate, subject=None): + return [ + t for t in triples + if t.p.iri == predicate + and (subject is None or t.s.iri == subject) + ] + + +def has_type(triples, subject, rdf_type): + return any( + t.s.iri == subject and t.p.iri == RDF_TYPE and t.o.iri == rdf_type + for t in triples + ) + + +def derived_from(triples, subject): + t = find_triple(triples, PROV_WAS_DERIVED_FROM, subject) + return t.o.iri if t else None + + +@dataclass +class ChunkMatch: + """Mimics the result from doc_embeddings_client.query().""" + chunk_id: str + + +# --------------------------------------------------------------------------- +# Fixtures: three retrievable chunks +# --------------------------------------------------------------------------- + +CHUNK_A = "urn:chunk:policy-doc-1:chunk-0" +CHUNK_B = "urn:chunk:policy-doc-1:chunk-1" +CHUNK_C = "urn:chunk:policy-doc-1:chunk-2" + +CHUNK_A_CONTENT = "Customers may return items within 30 days of purchase." +CHUNK_B_CONTENT = "Our stores are open from 9am to 5pm on weekdays." +CHUNK_C_CONTENT = "Refunds are processed to the original payment method." + +# Retrieval (post-dedupe) order is A, B, C. +ORDERED_CONTENT = [CHUNK_A_CONTENT, CHUNK_B_CONTENT, CHUNK_C_CONTENT] +ORDERED_CHUNK_IDS = [CHUNK_A, CHUNK_B, CHUNK_C] + + +def build_mock_clients(): + """ + Build mock subsidiary clients for a document-rag query returning three + distinct chunks (A, B, C) in that order. + """ + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + doc_embeddings_client = AsyncMock() + fetch_chunk = AsyncMock() + + async def mock_prompt(template_id, variables=None, **kwargs): + if template_id == "extract-concepts": + return PromptResult(response_type="text", text="return policy\nrefund") + return PromptResult(response_type="text", text="") + + prompt_client.prompt.side_effect = mock_prompt + + embeddings_client.embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + + # Each concept query returns the same three chunks; dedupe keeps A, B, C. + doc_embeddings_client.query.return_value = [ + ChunkMatch(chunk_id=CHUNK_A), + ChunkMatch(chunk_id=CHUNK_B), + ChunkMatch(chunk_id=CHUNK_C), + ] + + async def mock_fetch(chunk_id): + return { + CHUNK_A: CHUNK_A_CONTENT, + CHUNK_B: CHUNK_B_CONTENT, + CHUNK_C: CHUNK_C_CONTENT, + }[chunk_id] + + fetch_chunk.side_effect = mock_fetch + + prompt_client.document_prompt.return_value = PromptResult( + response_type="text", + text="Items can be returned within 30 days for a full refund.", + ) + + return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk + + +class StubReranker: + """ + Stub reranker_client mirroring RerankerClient.rerank(): returns a fixed, + pre-sorted, truncated list of RerankerResult - exactly the contract the + flashrank service guarantees (sorted desc by score, truncated to limit). + """ + + def __init__(self, results): + self._results = results + self.calls = [] + + async def rerank(self, queries, documents, limit=10, timeout=300): + self.calls.append( + {"queries": queries, "documents": documents, "limit": limit} + ) + return self._results + + +# --------------------------------------------------------------------------- +# 1. No-op: reranker_client=None must not change anything +# --------------------------------------------------------------------------- + +class TestRerankNoOp: + + @pytest.mark.asyncio + async def test_documents_passed_to_llm_are_unchanged(self): + """ + With no reranker wired, document_prompt must receive the retrieved + chunks in the original order and length. + """ + clients = build_mock_clients() + rag = DocumentRag(*clients) # reranker_client defaults to None + + await rag.query(query="What is the return policy?") + + call = rag.prompt_client.document_prompt.call_args + passed_docs = call.kwargs["documents"] + assert passed_docs == ORDERED_CONTENT + + @pytest.mark.asyncio + async def test_no_chunk_selection_event_emitted(self): + """ + Without a reranker, the provenance chain is the original 4 stages: + question, grounding, exploration, synthesis - no focus stage. + """ + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + assert len(events) == 4 + types = [ + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_SYNTHESIS, + ] + for i, expected in enumerate(types): + assert has_type(events[i]["triples"], events[i]["explain_id"], expected) + + # No chunk-selection entity anywhere. + for e in events: + assert not any( + t.o.iri == TG_CHUNK_SELECTION + for t in e["triples"] + if t.p.iri == RDF_TYPE + ) + + @pytest.mark.asyncio + async def test_synthesis_derives_from_exploration_when_no_rerank(self): + """ + No-op lineage is unchanged: synthesis derives from exploration + (there is no focus stage). Guards the conditional synthesis parent. + """ + clients = build_mock_clients() + rag = DocumentRag(*clients) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + # events: question, grounding, exploration, synthesis + exp_uri = events[2]["explain_id"] + syn_event = events[3] + assert derived_from(syn_event["triples"], syn_event["explain_id"]) == exp_uri + + +# --------------------------------------------------------------------------- +# 2. Rerank: reorder + truncate + provenance +# --------------------------------------------------------------------------- + +class TestRerankActive: + + def _reranker_keeping_C_then_A(self): + # Reranker says chunk index 2 (C) is best, then index 0 (A); B dropped. + # Pre-sorted desc by score and truncated to limit, per the contract. + return StubReranker([ + RerankerResult(document_id="2", query_id="0", score=0.95), + RerankerResult(document_id="0", query_id="0", score=0.42), + ]) + + @pytest.mark.asyncio + async def test_documents_reordered_and_truncated(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query(query="What is the return policy?") + + call = rag.prompt_client.document_prompt.call_args + passed_docs = call.kwargs["documents"] + assert passed_docs == [CHUNK_C_CONTENT, CHUNK_A_CONTENT] + + @pytest.mark.asyncio + async def test_reranker_called_with_single_query_and_all_docs(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query(query="What is the return policy?", doc_limit=2) + + assert len(reranker.calls) == 1 + c = reranker.calls[0] + assert c["queries"] == [{"id": "0", "text": "What is the return policy?"}] + assert c["documents"] == [ + {"id": "0", "text": CHUNK_A_CONTENT}, + {"id": "1", "text": CHUNK_B_CONTENT}, + {"id": "2", "text": CHUNK_C_CONTENT}, + ] + # The rerank narrows down to the final doc_limit, NOT fetch_limit + # (fetch_limit is the over-fetched candidate pool size). + assert c["limit"] == 2 + + @pytest.mark.asyncio + async def test_explicit_fetch_limit_over_fetches_then_narrows(self): + """ + Semantic guard for the value of reranking AND the maintainer's two-limit + contract: an explicit fetch_limit makes retrieval OVER-FETCH a wider + candidate pool so the cross-encoder can surface chunks the bi-encoder + ranked outside the final doc_limit, then the rerank narrows the pool back + down to doc_limit. The fetch_limit is honoured directly (caller controls + how hard the reranker works), not overridden by any heuristic. + """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + reranker = self._reranker_keeping_C_then_A() + # Candidate pool (fetch_limit=60) >> final doc_limit (6). + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query( + query="What is the return policy?", doc_limit=6, fetch_limit=60, + ) + + # Over-fetch: the embeddings store is queried with the fetch_limit + # budget (60 // 2 concept-vectors = 30 per concept), NOT the doc_limit + # budget (6 // 2 = 3). This is the bug guard. + q_limit = doc_embeddings_client.query.call_args.kwargs["limit"] + assert q_limit == 30 + + # Narrow: the rerank keeps the final doc_limit (6), not fetch_limit. + assert reranker.calls[0]["limit"] == 6 + + @pytest.mark.asyncio + async def test_default_fetch_limit_derives_overfetch_from_doc_limit(self): + """ + With no fetch_limit passed to query(), the candidate pool falls back to + the OVERFETCH_FACTOR x doc_limit heuristic, so over-fetch scales with + doc_limit and reranking keeps its recall benefit out of the box. + """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + reranker = self._reranker_keeping_C_then_A() + # No fetch_limit -> heuristic default. + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query(query="What is the return policy?", doc_limit=20) + + # fetch = 3 x 20 = 60 -> 60 // 2 concept-vectors = 30 per concept. + q_limit = doc_embeddings_client.query.call_args.kwargs["limit"] + assert q_limit == 30 + # Rerank narrows to the final doc_limit (20). + assert reranker.calls[0]["limit"] == 20 + + @pytest.mark.asyncio + async def test_fetch_limit_floored_at_doc_limit(self): + """ + A fetch_limit below doc_limit is floored up to doc_limit: retrieval must + never fetch fewer candidates than the rerank is asked to keep, else the + prompt could not be filled. + """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query( + query="What is the return policy?", doc_limit=10, fetch_limit=4, + ) + + # fetch = max(4, 10) = 10 -> 10 // 2 concept-vectors = 5 per concept. + q_limit = doc_embeddings_client.query.call_args.kwargs["limit"] + assert q_limit == 5 + assert reranker.calls[0]["limit"] == 10 + + @pytest.mark.asyncio + async def test_chunk_selection_event_emitted(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + # Now 5 stages: question, grounding, exploration, focus, synthesis. + assert len(events) == 5 + ordered_types = [ + TG_DOC_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, + TG_FOCUS, TG_SYNTHESIS, + ] + for i, expected in enumerate(ordered_types): + assert has_type(events[i]["triples"], events[i]["explain_id"], expected) + + @pytest.mark.asyncio + async def test_chunk_selection_carries_scores_and_chunk_refs(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + focus_event = events[3] + foc_uri = focus_event["explain_id"] + triples = focus_event["triples"] + + # focus is derived from exploration + exp_uri = events[2]["explain_id"] + assert derived_from(triples, foc_uri) == exp_uri + + # Two ChunkSelection sub-entities, linked from focus. + sel_links = find_triples(triples, TG_SELECTED_CHUNK, foc_uri) + assert len(sel_links) == 2 + + # Each selection has a ChunkSelection type, a chunk document ref and a score. + chunk_refs = set() + scores = set() + for link in sel_links: + sel_uri = link.o.iri + assert has_type(triples, sel_uri, TG_CHUNK_SELECTION) + doc_ref = find_triple(triples, TG_DOCUMENT, sel_uri) + assert doc_ref is not None + chunk_refs.add(doc_ref.o.iri) + score_t = find_triple(triples, TG_SCORE, sel_uri) + assert score_t is not None + scores.add(score_t.o.value) + + # Surviving chunks are C and A (B dropped), with the reranker scores. + assert chunk_refs == {CHUNK_C, CHUNK_A} + assert scores == {"0.95", "0.42"} + + @pytest.mark.asyncio + async def test_all_focus_triples_in_retrieval_graph(self): + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + explain_callback=explain_callback, + ) + + for t in events[3]["triples"]: + assert t.g == "urn:graph:retrieval" + + @pytest.mark.asyncio + async def test_synthesis_derives_from_focus_when_reranking(self): + """ + When reranking runs, synthesis must derive from the focus node (the + reranked chunks actually fed to the LLM), mirroring GraphRAG - not from + exploration, which would leave focus as a dangling branch and + misrepresent what fed the answer. + """ + clients = build_mock_clients() + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + events = [] + + async def explain_callback(triples, explain_id): + events.append({"triples": triples, "explain_id": explain_id}) + + await rag.query( + query="What is the return policy?", + doc_limit=2, + explain_callback=explain_callback, + ) + + # events: question, grounding, exploration, focus, synthesis + foc_uri = events[3]["explain_id"] + syn_event = events[4] + assert derived_from(syn_event["triples"], syn_event["explain_id"]) == foc_uri + + @pytest.mark.asyncio + async def test_empty_docs_skips_reranker(self): + """If retrieval returns no chunks, the reranker is never called.""" + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + doc_embeddings_client.query.return_value = [] # no matches + + reranker = self._reranker_keeping_C_then_A() + rag = DocumentRag(*clients, reranker_client=reranker) + + await rag.query(query="What is the return policy?") + + assert reranker.calls == [] + +# --------------------------------------------------------------------------- +# 3. Diversity selection: optional MMR after cross-encoder scoring +# --------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_diversity_mode_scores_full_candidate_pool_before_selecting(self): + """ + With diversity selection enabled, the cross-encoder should score the full + fetched candidate pool before MMR narrows it down to doc_limit. + """ + clients = build_mock_clients() + reranker = StubReranker([ + RerankerResult(document_id="0", query_id="0", score=1.00), + RerankerResult(document_id="1", query_id="0", score=0.95), + RerankerResult(document_id="2", query_id="0", score=0.90), + ]) + rag = DocumentRag( + *clients, + reranker_client=reranker, + rerank_diversity_mode="mmr", + ) + + await rag.query(query="What is the return policy?", doc_limit=2) + + assert reranker.calls[0]["limit"] == len(ORDERED_CONTENT) + + call = rag.prompt_client.document_prompt.call_args + passed_docs = call.kwargs["documents"] + assert len(passed_docs) == 2 + + + @pytest.mark.asyncio + async def test_diversity_mode_selects_less_redundant_context_set(self): + """ + MMR should use cross-encoder scores as relevance while penalizing redundant + chunks, so a slightly lower-scored but less redundant chunk can be selected. + """ + clients = build_mock_clients() + prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk = clients + + duplicate_a = "apple banana fruit return policy" + duplicate_b = "apple banana fruit return policy duplicate" + diverse_c = "engine motor vehicle warranty" + + async def mock_fetch(chunk_id): + return { + CHUNK_A: duplicate_a, + CHUNK_B: duplicate_b, + CHUNK_C: diverse_c, + }[chunk_id] + + fetch_chunk.side_effect = mock_fetch + + reranker = StubReranker([ + RerankerResult(document_id="0", query_id="0", score=1.00), + RerankerResult(document_id="1", query_id="0", score=0.95), + RerankerResult(document_id="2", query_id="0", score=0.90), + ]) + rag = DocumentRag( + *clients, + reranker_client=reranker, + rerank_diversity_mode="mmr", + rerank_diversity_lambda=0.2, + ) + + await rag.query(query="What is the return policy?", doc_limit=2) + + call = rag.prompt_client.document_prompt.call_args + passed_docs = call.kwargs["documents"] + + assert passed_docs == [duplicate_a, diverse_c] \ No newline at end of file diff --git a/tests/unit/test_retrieval/test_document_rag_reranker_wiring.py b/tests/unit/test_retrieval/test_document_rag_reranker_wiring.py new file mode 100644 index 00000000..bf4337b4 --- /dev/null +++ b/tests/unit/test_retrieval/test_document_rag_reranker_wiring.py @@ -0,0 +1,89 @@ +""" +Cross-layer wiring contract for the Document-RAG reranker (issue #878). + +The Document-RAG processor registers a ``RerankerClientSpec`` for the +``reranker-request`` / ``reranker-response`` roles (see +``retrieval/document_rag/rag.py``). At flow construction every spec runs +``spec.add(flow, processor, definition)``, and ``RequestResponseSpec.add`` +resolves its topics via ``definition["topics"][name]`` - which raises +``KeyError`` if the flow blueprint does not provide those topics. + +This means the monorepo code change is only safe to deploy together with the +companion ``trustgraph-templates`` change that wires ``reranker-request`` / +``reranker-response`` into the Document-RAG flow (mirroring what templates +PR #279 did for GraphRAG via ``graph-store.jsonnet``). These tests pin that +contract from the monorepo side: + + * with the reranker topics present (as the updated templates compile them), + the spec binds cleanly and registers the client; + * without them (the pre-companion blueprint), construction fails fast with a + KeyError naming the missing role - documenting exactly why the templates + change is required. + +No broker/network: the pub/sub backend is mocked (topics are bound at add() +time, connections happen later at start()). +""" + +import pytest +from unittest.mock import MagicMock + +from trustgraph.base import RerankerClientSpec + + +def _flow(): + f = MagicMock() + f.workspace = "ws" + f.name = "document-rag" + f.id = "proc1" + f.consumer = {} + return f + + +def _processor(): + p = MagicMock() + p.pubsub = MagicMock() + p.id = "proc1" + p.taskgroup = MagicMock() + return p + + +def _spec(): + return RerankerClientSpec( + request_name="reranker-request", + response_name="reranker-response", + ) + + +# Topics dict as the UPDATED document-store.jsonnet compiles them +# (verified by compiling the template: reranker-request -> request:tg:reranker:{workspace}:{id}). +DEFINITION_WITH_RERANKER = { + "topics": { + "request": "request:tg:document-rag:ws:id", + "response": "response:tg:document-rag:ws:id", + "reranker-request": "request:tg:reranker:ws:id", + "reranker-response": "response:tg:reranker:ws:id", + } +} + +# Pre-companion blueprint: no reranker topics (document-rag before the templates change). +DEFINITION_WITHOUT_RERANKER = { + "topics": { + "request": "request:tg:document-rag:ws:id", + "response": "response:tg:document-rag:ws:id", + } +} + + +def test_reranker_client_binds_when_flow_provides_topics(): + flow = _flow() + _spec().add(flow, _processor(), DEFINITION_WITH_RERANKER) + # The client consumer is registered against the reranker role. + assert "reranker-request" in flow.consumer + + +def test_reranker_client_keyerrors_without_companion_template_topics(): + with pytest.raises(KeyError) as exc: + _spec().add(_flow(), _processor(), DEFINITION_WITHOUT_RERANKER) + # Fails fast naming the missing role -> the trustgraph-templates companion + # change (wire reranker-request/response into the document-rag flow) is required. + assert "reranker-request" in str(exc.value) diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index dde3acc1..2bdf3959 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -66,6 +66,7 @@ class TestDocumentRagService: workspace=ANY, # Workspace comes from flow.workspace (mock) collection="test_coll_1", # Must be from message, not hardcoded default doc_limit=5, + fetch_limit=0, # Unset -> core derives the candidate pool explain_callback=ANY, # Explainability callback is always passed save_answer_callback=ANY, # Librarian save callback is always passed ) diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index d1979211..15ffdc9d 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -15,54 +15,52 @@ class TestGraphRag: def test_graph_rag_initialization_with_defaults(self): """Test GraphRag initialization with default verbose setting""" - # Create mock clients mock_prompt_client = MagicMock() mock_embeddings_client = MagicMock() mock_graph_embeddings_client = MagicMock() mock_triples_client = MagicMock() + mock_reranker_client = MagicMock() - # Initialize GraphRag - graph_rag = GraphRag( - prompt_client=mock_prompt_client, - embeddings_client=mock_embeddings_client, - graph_embeddings_client=mock_graph_embeddings_client, - triples_client=mock_triples_client - ) - - # Verify initialization - assert graph_rag.prompt_client == mock_prompt_client - assert graph_rag.embeddings_client == mock_embeddings_client - assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client - assert graph_rag.triples_client == mock_triples_client - assert graph_rag.verbose is False # Default value - # Verify label_cache is an LRUCacheWithTTL instance - from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL - assert isinstance(graph_rag.label_cache, LRUCacheWithTTL) - - def test_graph_rag_initialization_with_verbose(self): - """Test GraphRag initialization with verbose enabled""" - # Create mock clients - mock_prompt_client = MagicMock() - mock_embeddings_client = MagicMock() - mock_graph_embeddings_client = MagicMock() - mock_triples_client = MagicMock() - - # Initialize GraphRag with verbose=True graph_rag = GraphRag( prompt_client=mock_prompt_client, embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, - verbose=True + reranker_client=mock_reranker_client, ) - # Verify initialization assert graph_rag.prompt_client == mock_prompt_client assert graph_rag.embeddings_client == mock_embeddings_client assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client assert graph_rag.triples_client == mock_triples_client + assert graph_rag.reranker_client == mock_reranker_client + assert graph_rag.verbose is False + from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL + assert isinstance(graph_rag.label_cache, LRUCacheWithTTL) + + def test_graph_rag_initialization_with_verbose(self): + """Test GraphRag initialization with verbose enabled""" + mock_prompt_client = MagicMock() + mock_embeddings_client = MagicMock() + mock_graph_embeddings_client = MagicMock() + mock_triples_client = MagicMock() + mock_reranker_client = MagicMock() + + graph_rag = GraphRag( + prompt_client=mock_prompt_client, + embeddings_client=mock_embeddings_client, + graph_embeddings_client=mock_graph_embeddings_client, + triples_client=mock_triples_client, + reranker_client=mock_reranker_client, + verbose=True, + ) + + assert graph_rag.prompt_client == mock_prompt_client + assert graph_rag.embeddings_client == mock_embeddings_client + assert graph_rag.graph_embeddings_client == mock_graph_embeddings_client + assert graph_rag.triples_client == mock_triples_client + assert graph_rag.reranker_client == mock_reranker_client assert graph_rag.verbose is True - # Verify label_cache is an LRUCacheWithTTL instance from trustgraph.retrieval.graph_rag.graph_rag import LRUCacheWithTTL assert isinstance(graph_rag.label_cache, LRUCacheWithTTL) @@ -365,244 +363,162 @@ class TestQuery: assert "workspace" not in c.kwargs @pytest.mark.asyncio - async def test_follow_edges_never_passes_workspace(self): - """Verify follow_edges never passes workspace to query_stream.""" + async def test_hop_and_filter_never_passes_workspace(self): + """Verify hop_and_filter never passes workspace to query_stream.""" mock_rag = MagicMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() mock_rag.triples_client = mock_triples_client + mock_rag.reranker_client = mock_reranker_client + mock_rag.label_cache = MagicMock() + mock_rag.label_cache.get.return_value = None mock_triple = MagicMock() - mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1" + mock_triple.s = "e1" + mock_triple.p = "p1" + mock_triple.o = "o1" mock_triples_client.query_stream.return_value = [mock_triple] + mock_triples_client.query.return_value = [] + + result = MagicMock() + result.document_id = "0" + result.query_id = "0" + result.score = 0.9 + mock_reranker_client.rerank.return_value = [result] query = Query( rag=mock_rag, collection="test_collection", verbose=False, - triple_limit=10 + triple_limit=10, ) - subgraph = set() - await query.follow_edges("e1", subgraph, path_length=1) + await query.hop_and_filter(["e1"], ["concept"]) for c in mock_triples_client.query_stream.call_args_list: assert "workspace" not in c.kwargs @pytest.mark.asyncio - async def test_follow_edges_basic_functionality(self): - """Test Query.follow_edges method basic triple discovery""" + async def test_hop_and_filter_basic_functionality(self): + """Test hop_and_filter retrieves edges and scores them with reranker.""" mock_rag = MagicMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() mock_rag.triples_client = mock_triples_client + mock_rag.reranker_client = mock_reranker_client + mock_rag.label_cache = MagicMock() + mock_rag.label_cache.get.return_value = None - mock_triple1 = MagicMock() - mock_triple1.s, mock_triple1.p, mock_triple1.o = "entity1", "predicate1", "object1" + mock_triple = MagicMock() + mock_triple.s = "entity1" + mock_triple.p = "predicate1" + mock_triple.o = "object1" + mock_triples_client.query_stream.return_value = [mock_triple] + mock_triples_client.query.return_value = [] - mock_triple2 = MagicMock() - mock_triple2.s, mock_triple2.p, mock_triple2.o = "subject2", "entity1", "object2" - - mock_triple3 = MagicMock() - mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1" - - mock_triples_client.query_stream.side_effect = [ - [mock_triple1], # s=ent - [mock_triple2], # p=ent - [mock_triple3], # o=ent - ] + result = MagicMock() + result.document_id = "0" + result.query_id = "0" + result.score = 0.95 + mock_reranker_client.rerank.return_value = [result] query = Query( rag=mock_rag, collection="test_collection", verbose=False, - triple_limit=10 + triple_limit=10, + edge_limit=25, ) - subgraph = set() - await query.follow_edges("entity1", subgraph, path_length=1) - - assert mock_triples_client.query_stream.call_count == 3 - - mock_triples_client.query_stream.assert_any_call( - s="entity1", p=None, o=None, limit=10, - collection="test_collection", batch_size=20, g="" - ) - mock_triples_client.query_stream.assert_any_call( - s=None, p="entity1", o=None, limit=10, - collection="test_collection", batch_size=20, g="" - ) - mock_triples_client.query_stream.assert_any_call( - s=None, p=None, o="entity1", limit=10, - collection="test_collection", batch_size=20, g="" + selected, uri_map, edge_meta = await query.hop_and_filter( + ["entity1"], ["test concept"], ) - expected_subgraph = { - ("entity1", "predicate1", "object1"), - ("subject2", "entity1", "object2"), - ("subject3", "predicate3", "entity1") - } - assert subgraph == expected_subgraph + assert len(selected) == 1 + assert len(uri_map) == 1 + assert len(edge_meta) == 1 + + mock_reranker_client.rerank.assert_called_once() + call_kwargs = mock_reranker_client.rerank.call_args + assert call_kwargs.kwargs["limit"] == 25 @pytest.mark.asyncio - async def test_follow_edges_with_path_length_zero(self): - """Test Query.follow_edges method with path_length=0""" + async def test_hop_and_filter_with_empty_frontier(self): + """Test hop_and_filter with no seed entities returns empty.""" + mock_rag = MagicMock() + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False, + ) + + selected, uri_map, edge_meta = await query.hop_and_filter([], ["concept"]) + + assert selected == [] + assert uri_map == {} + assert edge_meta == {} + + @pytest.mark.asyncio + async def test_hop_and_filter_filters_label_triples(self): + """Test hop_and_filter skips rdfs:label edges.""" mock_rag = MagicMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() mock_rag.triples_client = mock_triples_client + mock_rag.reranker_client = mock_reranker_client + mock_rag.label_cache = MagicMock() + mock_rag.label_cache.get.return_value = None - query = Query( - rag=mock_rag, - collection="test_collection", - verbose=False - ) + label_triple = MagicMock() + label_triple.s = "entity1" + label_triple.p = "http://www.w3.org/2000/01/rdf-schema#label" + label_triple.o = "Entity One" - subgraph = set() - await query.follow_edges("entity1", subgraph, path_length=0) - - mock_triples_client.query_stream.assert_not_called() - assert subgraph == set() - - @pytest.mark.asyncio - async def test_follow_edges_with_max_subgraph_size_limit(self): - """Test Query.follow_edges method respects max_subgraph_size""" - mock_rag = MagicMock() - mock_triples_client = AsyncMock() - mock_rag.triples_client = mock_triples_client + mock_triples_client.query_stream.return_value = [label_triple] + mock_triples_client.query.return_value = [] query = Query( rag=mock_rag, collection="test_collection", verbose=False, - max_subgraph_size=2 + triple_limit=10, ) - subgraph = {("s1", "p1", "o1"), ("s2", "p2", "o2"), ("s3", "p3", "o3")} - - await query.follow_edges("entity1", subgraph, path_length=1) - - mock_triples_client.query_stream.assert_not_called() - assert len(subgraph) == 3 - - @pytest.mark.asyncio - async def test_get_subgraph_method(self): - """Test Query.get_subgraph returns (subgraph, entities, concepts) tuple""" - mock_rag = MagicMock() - - query = Query( - rag=mock_rag, - collection="test_collection", - verbose=False, - max_path_length=1 + selected, uri_map, edge_meta = await query.hop_and_filter( + ["entity1"], ["concept"], ) - # Mock get_entities to return (entities, concepts) tuple - query.get_entities = AsyncMock( - return_value=(["entity1", "entity2"], ["concept1"]) - ) - - query.follow_edges_batch = AsyncMock(return_value=( - { - ("entity1", "predicate1", "object1"), - ("entity2", "predicate2", "object2") - }, - {} - )) - - subgraph, term_map, entities, concepts = await query.get_subgraph("test query") - - query.get_entities.assert_called_once_with("test query") - query.follow_edges_batch.assert_called_once_with(["entity1", "entity2"], 1) - - assert isinstance(subgraph, list) - assert len(subgraph) == 2 - assert ("entity1", "predicate1", "object1") in subgraph - assert ("entity2", "predicate2", "object2") in subgraph - assert entities == ["entity1", "entity2"] - assert concepts == ["concept1"] - - @pytest.mark.asyncio - async def test_get_labelgraph_method(self): - """Test Query.get_labelgraph returns (labeled_edges, uri_map, entities, concepts)""" - mock_rag = MagicMock() - - query = Query( - rag=mock_rag, - collection="test_collection", - verbose=False, - max_subgraph_size=100 - ) - - test_subgraph = [ - ("entity1", "predicate1", "object1"), - ("subject2", "http://www.w3.org/2000/01/rdf-schema#label", "Label Value"), - ("entity3", "predicate3", "object3") - ] - test_entities = ["entity1", "entity3"] - test_concepts = ["concept1"] - query.get_subgraph = AsyncMock( - return_value=(test_subgraph, {}, test_entities, test_concepts) - ) - - async def mock_maybe_label(entity): - label_map = { - "entity1": "Human Entity One", - "predicate1": "Human Predicate One", - "object1": "Human Object One", - "entity3": "Human Entity Three", - "predicate3": "Human Predicate Three", - "object3": "Human Object Three" - } - return label_map.get(entity, entity) - - query.maybe_label = AsyncMock(side_effect=mock_maybe_label) - - labeled_edges, uri_map, entities, concepts = await query.get_labelgraph("test query") - - query.get_subgraph.assert_called_once_with("test query") - - # Label triples filtered out - assert len(labeled_edges) == 2 - - # maybe_label called for non-label triples - assert query.maybe_label.call_count == 6 - - expected_edges = [ - ("Human Entity One", "Human Predicate One", "Human Object One"), - ("Human Entity Three", "Human Predicate Three", "Human Object Three") - ] - assert labeled_edges == expected_edges - - assert len(uri_map) == 2 - assert entities == test_entities - assert concepts == test_concepts + assert selected == [] + mock_reranker_client.rerank.assert_not_called() @pytest.mark.asyncio async def test_graph_rag_query_method(self): """Test GraphRag.query method orchestrates full RAG pipeline with provenance""" - import json from trustgraph.retrieval.graph_rag.graph_rag import edge_id mock_prompt_client = AsyncMock() mock_embeddings_client = AsyncMock() mock_graph_embeddings_client = AsyncMock() mock_triples_client = AsyncMock() + mock_reranker_client = AsyncMock() expected_response = "This is the RAG response" - test_labelgraph = [("Subject", "Predicate", "Object")] - test_edge_id = edge_id("Subject", "Predicate", "Object") + test_selected_edges = [("Subject", "Predicate", "Object")] + test_eid = edge_id("Subject", "Predicate", "Object") test_uri_map = { - test_edge_id: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object") + test_eid: ("http://example.org/subject", "http://example.org/predicate", "http://example.org/object") + } + test_edge_metadata = { + test_eid: {"concept": "test concept", "score": 0.95} } - test_entities = ["http://example.org/subject"] - test_concepts = ["test concept"] - # Mock prompt responses for the multi-step process + mock_embeddings_client.embed.return_value = [[0.1, 0.2]] + mock_graph_embeddings_client.query.return_value = [] + async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "extract-concepts": - return PromptResult(response_type="text", text="") - elif prompt_name == "kg-edge-scoring": - return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}]) - elif prompt_name == "kg-edge-reasoning": - return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}]) + return PromptResult(response_type="text", text="test concept") elif prompt_name == "kg-synthesis": return PromptResult(response_type="text", text=expected_response) return PromptResult(response_type="text", text="") @@ -614,16 +530,16 @@ class TestQuery: embeddings_client=mock_embeddings_client, graph_embeddings_client=mock_graph_embeddings_client, triples_client=mock_triples_client, - verbose=False + reranker_client=mock_reranker_client, + verbose=False, ) - # Patch Query.get_labelgraph to return test data - original_get_labelgraph = Query.get_labelgraph + original_hop_and_filter = Query.hop_and_filter - async def mock_get_labelgraph(self, query_text): - return test_labelgraph, test_uri_map, test_entities, test_concepts + async def mock_hop_and_filter(self, seed_entities, concepts): + return test_selected_edges, test_uri_map, test_edge_metadata - Query.get_labelgraph = mock_get_labelgraph + Query.hop_and_filter = mock_hop_and_filter provenance_events = [] @@ -636,7 +552,7 @@ class TestQuery: collection="test_collection", entity_limit=25, triple_limit=15, - explain_callback=collect_provenance + explain_callback=collect_provenance, ) response_text, usage = response @@ -650,7 +566,6 @@ class TestQuery: assert len(triples) > 0 assert prov_id.startswith("urn:trustgraph:") - # Verify order assert "question" in provenance_events[0][1] assert "grounding" in provenance_events[1][1] assert "exploration" in provenance_events[2][1] @@ -658,4 +573,4 @@ class TestQuery: assert "synthesis" in provenance_events[4][1] finally: - Query.get_labelgraph = original_get_labelgraph + Query.hop_and_filter = original_hop_and_filter diff --git a/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py b/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py new file mode 100644 index 00000000..cc95228a --- /dev/null +++ b/tests/unit/test_retrieval/test_graph_rag_direction_aware_text.py @@ -0,0 +1,353 @@ +""" +Tests for direction-aware reranker text in GraphRAG hop-and-filter. + +The reranker document text varies by traversal direction: +- From S (subject is the frontier entity): text = "{p} {o}" +- From O (object is the frontier entity): text = "{s} {p}" +- From P (predicate is the frontier entity): text = "{s} {o}" +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock + +from trustgraph.retrieval.graph_rag.graph_rag import Query, LRUCacheWithTTL +from trustgraph.schema import Term, IRI, LITERAL + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_rag(reranker_results=None): + """Create a mock GraphRag with all clients stubbed.""" + rag = MagicMock() + rag.label_cache = LRUCacheWithTTL() + rag.triples_client = AsyncMock() + rag.reranker_client = AsyncMock() + + # Label lookups return empty (fall back to URI) + rag.triples_client.query.return_value = [] + + if reranker_results is not None: + rag.reranker_client.rerank.return_value = reranker_results + else: + rag.reranker_client.rerank.return_value = [] + + return rag + + +def _make_query(rag, max_path_length=1, edge_limit=25): + return Query( + rag=rag, + collection="test", + verbose=False, + entity_limit=50, + triple_limit=30, + max_subgraph_size=1000, + max_path_length=max_path_length, + edge_limit=edge_limit, + ) + + +def _make_schema_triple(s, p, o): + """Create a mock triple matching the schema interface.""" + t = MagicMock() + t.s = s + t.p = p + t.o = o + return t + + +def _reranker_result(document_id, query_id="0", score=0.9): + r = MagicMock() + r.document_id = str(document_id) + r.query_id = str(query_id) + r.score = score + return r + + +# --------------------------------------------------------------------------- +# Tests: execute_batch_triple_queries direction tracking +# --------------------------------------------------------------------------- + +class TestDirectionTracking: + + @pytest.mark.asyncio + async def test_from_s_direction(self): + """Triples from s=entity queries are tagged FROM_S.""" + triple = _make_schema_triple("ent1", "pred", "obj") + rag = _make_rag() + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + q = _make_query(rag) + + result = await q.execute_batch_triple_queries(["ent1"], 10) + + from_s = [(t, d) for t, d in result if d == Query.FROM_S] + assert len(from_s) == 1 + assert from_s[0][0] is triple + + @pytest.mark.asyncio + async def test_from_o_direction(self): + """Triples from o=entity queries are tagged FROM_O.""" + triple = _make_schema_triple("subj", "pred", "ent1") + rag = _make_rag() + + async def query_stream(s=None, p=None, o=None, **kwargs): + if o is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + q = _make_query(rag) + + result = await q.execute_batch_triple_queries(["ent1"], 10) + + from_o = [(t, d) for t, d in result if d == Query.FROM_O] + assert len(from_o) == 1 + assert from_o[0][0] is triple + + @pytest.mark.asyncio + async def test_from_p_direction(self): + """Triples from p=entity queries are tagged FROM_P.""" + triple = _make_schema_triple("subj", "ent1", "obj") + rag = _make_rag() + + async def query_stream(s=None, p=None, o=None, **kwargs): + if p is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + q = _make_query(rag) + + result = await q.execute_batch_triple_queries(["ent1"], 10) + + from_p = [(t, d) for t, d in result if d == Query.FROM_P] + assert len(from_p) == 1 + assert from_p[0][0] is triple + + +# --------------------------------------------------------------------------- +# Tests: hop_and_filter reranker document text +# --------------------------------------------------------------------------- + +class TestDirectionAwareRerankerText: + + @pytest.mark.asyncio + async def test_from_s_uses_predicate_object(self): + """From-S traversal: reranker text should be '{p} {o}'.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/entity-A"], + concepts=["likes"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + # Text should be "{p} {o}" — the URIs since no labels found + assert len(documents) == 1 + assert documents[0]["text"] == "http://ex/likes http://ex/entity-B" + + @pytest.mark.asyncio + async def test_from_o_uses_subject_predicate(self): + """From-O traversal: reranker text should be '{s} {p}'.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if o is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/entity-B"], + concepts=["likes"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + assert len(documents) == 1 + assert documents[0]["text"] == "http://ex/entity-A http://ex/likes" + + @pytest.mark.asyncio + async def test_from_p_uses_subject_object(self): + """From-P traversal: reranker text should be '{s} {o}'.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if p is not None: + return [triple] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/likes"], + concepts=["entity"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + assert len(documents) == 1 + assert documents[0]["text"] == "http://ex/entity-A http://ex/entity-B" + + @pytest.mark.asyncio + async def test_mixed_directions_produce_different_text(self): + """Edges from different directions use different text formats.""" + triple_from_s = _make_schema_triple( + "http://ex/seed", "http://ex/rel", "http://ex/target", + ) + triple_from_o = _make_schema_triple( + "http://ex/other", "http://ex/ref", "http://ex/seed", + ) + + rag = _make_rag(reranker_results=[ + _reranker_result(0), _reranker_result(1), + ]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s == "http://ex/seed": + return [triple_from_s] + if o == "http://ex/seed": + return [triple_from_o] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/seed"], + concepts=["test"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + texts = {d["text"] for d in documents} + + # From S: "{p} {o}" = "http://ex/rel http://ex/target" + assert "http://ex/rel http://ex/target" in texts + # From O: "{s} {p}" = "http://ex/other http://ex/ref" + assert "http://ex/other http://ex/ref" in texts + + @pytest.mark.asyncio + async def test_labels_applied_to_direction_text(self): + """Labels should be resolved and used in the direction-aware text.""" + triple = _make_schema_triple( + "http://ex/entity-A", + "http://ex/likes", + "http://ex/entity-B", + ) + reranker_result = _reranker_result(0) + rag = _make_rag(reranker_results=[reranker_result]) + + LABEL = "http://www.w3.org/2000/01/rdf-schema#label" + + async def query_stream(s=None, p=None, o=None, **kwargs): + if s is not None and p is None: + return [triple] + return [] + + async def label_query(s=None, p=None, o=None, limit=1, **kwargs): + if p == LABEL: + labels = { + "http://ex/entity-A": "Alice", + "http://ex/likes": "likes", + "http://ex/entity-B": "Bob", + } + if s in labels: + return [MagicMock(o=labels[s])] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + rag.triples_client.query.side_effect = label_query + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/entity-A"], + concepts=["friendship"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + assert len(documents) == 1 + # From S with labels: "{p_label} {o_label}" + assert documents[0]["text"] == "likes Bob" + + @pytest.mark.asyncio + async def test_no_duplicate_text_from_shared_object(self): + """Multiple edges sharing an object should produce distinct texts.""" + triple_a = _make_schema_triple( + "http://ex/cpu-A", "http://ex/hasCategory", "http://ex/Processors", + ) + triple_b = _make_schema_triple( + "http://ex/cpu-B", "http://ex/hasCategory", "http://ex/Processors", + ) + + rag = _make_rag(reranker_results=[ + _reranker_result(0), _reranker_result(1), + ]) + + async def query_stream(s=None, p=None, o=None, **kwargs): + if o == "http://ex/Processors": + return [triple_a, triple_b] + return [] + + rag.triples_client.query_stream.side_effect = query_stream + + q = _make_query(rag, max_path_length=1, edge_limit=10) + + await q.hop_and_filter( + seed_entities=["http://ex/Processors"], + concepts=["CPUs"], + ) + + call_args = rag.reranker_client.rerank.call_args + documents = call_args.kwargs["documents"] + texts = [d["text"] for d in documents] + + assert len(texts) == 2 + # From O: "{s} {p}" — subjects differ, so texts differ + assert texts[0] != texts[1] + assert "http://ex/cpu-A" in texts[0] + assert "http://ex/cpu-B" in texts[1] diff --git a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py index 1eb0dd72..bc2cb368 100644 --- a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py +++ b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py @@ -20,7 +20,7 @@ from trustgraph.provenance.namespaces import ( TG_GRAPH_RAG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_ANSWER_TYPE, TG_QUERY, TG_CONCEPT, TG_ENTITY, TG_EDGE_COUNT, - TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_SELECTED_EDGE, TG_EDGE, TG_SCORE, TG_EDGE_SELECTION, ) @@ -91,17 +91,17 @@ def build_mock_clients(): 1. prompt_client.prompt("extract-concepts", ...) -> concepts 2. embeddings_client.embed(concepts) -> vectors 3. graph_embeddings_client.query(vector, ...) -> entity matches - 4. triples_client.query_stream(s/p/o, ...) -> edges (follow_edges_batch) + 4. triples_client.query_stream(s/p/o, ...) -> edges (hop_and_filter) 5. triples_client.query(s, LABEL, ...) -> labels (maybe_label) - 6. prompt_client.prompt("kg-edge-scoring", ...) -> scored edges - 7. prompt_client.prompt("kg-edge-reasoning", ...) -> reasoning - 8. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns []) - 9. prompt_client.prompt("kg-synthesis", ...) -> final answer + 6. reranker_client.rerank(queries, documents, limit) -> scored edges + 7. triples_client.query(s, TG_CONTAINS, ...) -> doc tracing (returns []) + 8. prompt_client.prompt("kg-synthesis", ...) -> final answer """ prompt_client = AsyncMock() embeddings_client = AsyncMock() graph_embeddings_client = AsyncMock() triples_client = AsyncMock() + reranker_client = AsyncMock() # 1. Concept extraction prompt_responses = {} @@ -116,7 +116,7 @@ def build_mock_clients(): EmbeddingMatch(entity=Term(type=IRI, iri=ENTITY_B)), ] - # 4. Triple queries (follow_edges_batch) - return our edges + # 4. Triple queries (hop_and_filter) - return our edges kg_triples = [ make_schema_triple(*EDGE_1), make_schema_triple(*EDGE_2), @@ -130,9 +130,18 @@ def build_mock_clients(): return [] # No labels found, will fall back to URI triples_client.query.side_effect = mock_label_query - # 6+7. Edge scoring and reasoning: dynamically score/reason about - # whatever edges the query method sends us, since edge IDs are computed - # from str(Term) representations which include the full dataclass repr. + # 6. Reranker: select all documents with high scores + async def mock_rerank(queries, documents, limit): + results = [] + for i, doc in enumerate(documents): + result = MagicMock() + result.document_id = doc["id"] + result.query_id = queries[0]["id"] if queries else "0" + result.score = 0.9 - (i * 0.1) + results.append(result) + return results[:limit] + reranker_client.rerank.side_effect = mock_rerank + synthesis_answer = "Quantum computing applies physics principles to computation." async def mock_prompt(template_id, variables=None, **kwargs): @@ -141,26 +150,6 @@ def build_mock_clients(): response_type="text", text=prompt_responses["extract-concepts"], ) - elif template_id == "kg-edge-scoring": - # Score all edges highly, using the IDs that GraphRag computed - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[ - {"id": e["id"], "score": 10 - i} - for i, e in enumerate(edges) - ], - ) - elif template_id == "kg-edge-reasoning": - # Provide reasoning for each edge - edges = variables.get("knowledge", []) - return PromptResult( - response_type="jsonl", - objects=[ - {"id": e["id"], "reasoning": f"Relevant edge {i}"} - for i, e in enumerate(edges) - ], - ) elif template_id == "kg-synthesis": return PromptResult( response_type="text", @@ -170,7 +159,8 @@ def build_mock_clients(): prompt_client.prompt.side_effect = mock_prompt - return prompt_client, embeddings_client, graph_embeddings_client, triples_client + return (prompt_client, embeddings_client, graph_embeddings_client, + triples_client, reranker_client) # --------------------------------------------------------------------------- @@ -197,7 +187,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, # skip semantic pre-filter for simplicity + ) assert len(events) == 5, ( @@ -222,7 +212,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) expected_types = [ @@ -260,7 +250,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) uris = [e["explain_id"] for e in events] @@ -297,7 +287,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) q_uri = events[0]["explain_id"] @@ -320,7 +310,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) gnd_uri = events[1]["explain_id"] @@ -344,7 +334,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) exp_uri = events[2]["explain_id"] @@ -355,10 +345,10 @@ class TestGraphRagQueryProvenance: assert int(t.o.value) > 0 @pytest.mark.asyncio - async def test_focus_has_selected_edges_with_reasoning(self): + async def test_focus_has_selected_edges_with_concept_and_score(self): """ The focus event should carry selected edges as quoted triples - with reasoning text. + with cross-encoder concept and score metadata. """ clients = build_mock_clients() rag = GraphRag(*clients) @@ -371,7 +361,6 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, ) foc_uri = events[3]["explain_id"] @@ -387,11 +376,19 @@ class TestGraphRagQueryProvenance: for t in edge_t: assert t.o.triple is not None, "tg:edge object must be a quoted triple" - # Should have reasoning - reasoning = find_triples(foc_triples, TG_REASONING) - assert len(reasoning) > 0, "Focus should have reasoning for selected edges" - reasoning_texts = {t.o.value for t in reasoning} - assert any(r for r in reasoning_texts), "Reasoning should not be empty" + # Edge selections should be typed as EdgeSelection + edge_sel_uris = [t.o.iri for t in selected] + for uri in edge_sel_uris: + assert has_type(foc_triples, uri, TG_EDGE_SELECTION) + + # Should have concept and score + concepts = find_triples(foc_triples, TG_CONCEPT) + assert len(concepts) > 0, "Focus should have tg:concept for selected edges" + + scores = find_triples(foc_triples, TG_SCORE) + assert len(scores) > 0, "Focus should have tg:score for selected edges" + for t in scores: + float(t.o.value) # Should be parseable as float @pytest.mark.asyncio async def test_synthesis_is_answer_type(self): @@ -407,7 +404,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) syn_uri = events[4]["explain_id"] @@ -429,7 +426,7 @@ class TestGraphRagQueryProvenance: result_text, usage = await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) assert result_text == "Quantum computing applies physics principles to computation." @@ -449,7 +446,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + parent_uri=parent, ) @@ -465,7 +462,7 @@ class TestGraphRagQueryProvenance: result_text, usage = await rag.query( query="What is quantum computing?", - edge_score_limit=0, + ) assert result_text == "Quantum computing applies physics principles to computation." @@ -484,7 +481,7 @@ class TestGraphRagQueryProvenance: await rag.query( query="What is quantum computing?", explain_callback=explain_callback, - edge_score_limit=0, + ) for event in events: diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index bf0b2ba1..afd48f1b 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -527,7 +527,8 @@ class AsyncFlowInstance: return result.get("response", "") async def document_rag(self, query: str, collection: str, - doc_limit: int = 10, **kwargs: Any) -> str: + doc_limit: int = 10, fetch_limit: int = 0, + **kwargs: Any) -> str: """ Execute document-based RAG query (non-streaming). @@ -541,7 +542,9 @@ class AsyncFlowInstance: Args: query: User query text collection: Collection identifier containing documents - doc_limit: Maximum number of document chunks to retrieve (default: 10) + doc_limit: Document chunks selected into the prompt (default: 10) + fetch_limit: Candidate chunks fetched from the vector store before + reranking (default: 0 = derive from doc_limit) **kwargs: Additional service-specific parameters Returns: @@ -564,6 +567,7 @@ class AsyncFlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": False } request_data.update(kwargs) @@ -646,6 +650,16 @@ class AsyncFlowInstance: return await self.request("embeddings", request_data) + async def rerank(self, queries: list, documents: list, limit: int = 10, **kwargs: Any): + request_data = { + "queries": queries, + "documents": documents, + "limit": limit, + } + request_data.update(kwargs) + + return await self.request("reranker", request_data) + async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs: Any): """ Query RDF triples using pattern matching. diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 7b38a4b1..9eff3d60 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -94,7 +94,9 @@ class AsyncSocketClient: if resp.get("type") == "auth-ok": if not self._workspace_explicit: - self.workspace = resp.get("workspace", self.workspace) + self.workspace = resp.get( + "default_workspace", self.workspace, + ) elif resp.get("type") == "auth-failed": await self._socket.close() raise ProtocolException( @@ -377,12 +379,14 @@ class AsyncSocketFlowInstance: yield chunk.content async def document_rag(self, query: str, collection: str, - doc_limit: int = 10, streaming: bool = False, **kwargs): + doc_limit: int = 10, fetch_limit: int = 0, + streaming: bool = False, **kwargs): """Document RAG with optional streaming""" request = { "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": streaming } request.update(kwargs) @@ -441,6 +445,19 @@ class AsyncSocketFlowInstance: return await self.client._send_request("embeddings", self.flow_id, request) + async def rerank(self, queries: list, documents: list, limit: int = 10, + **kwargs): + request = { + "queries": queries, + "documents": documents, + "limit": limit, + } + request.update(kwargs) + + return await self.client._send_request( + "reranker", self.flow_id, request, + ) + async def triples_query(self, s=None, p=None, o=None, collection=None, limit=100, **kwargs): """Triple pattern query""" request = {"limit": limit} diff --git a/trustgraph-base/trustgraph/api/explainability.py b/trustgraph-base/trustgraph/api/explainability.py index 656ff95f..74a8f32e 100644 --- a/trustgraph-base/trustgraph/api/explainability.py +++ b/trustgraph-base/trustgraph/api/explainability.py @@ -18,6 +18,7 @@ TG_EDGE_COUNT = TG + "edgeCount" TG_SELECTED_EDGE = TG + "selectedEdge" TG_EDGE = TG + "edge" TG_REASONING = TG + "reasoning" +TG_SCORE = TG + "score" TG_DOCUMENT = TG + "document" TG_CONCEPT = TG + "concept" TG_ENTITY = TG + "entity" @@ -66,10 +67,12 @@ RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" @dataclass class EdgeSelection: - """A selected edge with reasoning from GraphRAG Focus step.""" + """A selected edge with cross-encoder metadata from GraphRAG Focus step.""" uri: str edge: Optional[Dict[str, str]] = None # {"s": ..., "p": ..., "o": ...} reasoning: str = "" + concept: str = "" + score: Optional[float] = None @dataclass @@ -209,7 +212,7 @@ class Exploration(ExplainEntity): @dataclass class Focus(ExplainEntity): - """Focus entity - selected edges with LLM reasoning (GraphRAG only).""" + """Focus entity - selected edges with cross-encoder scoring (GraphRAG only).""" selected_edge_uris: List[str] = field(default_factory=list) edge_selections: List[EdgeSelection] = field(default_factory=list) @@ -418,14 +421,26 @@ def parse_edge_selection_triples(triples: List[Tuple[str, str, Any]]) -> EdgeSel uri = triples[0][0] if triples else "" edge = None reasoning = "" + concept = "" + score = None for s, p, o in triples: if p == TG_EDGE and isinstance(o, dict): edge = o elif p == TG_REASONING: reasoning = o + elif p == TG_CONCEPT: + concept = o + elif p == TG_SCORE: + try: + score = float(o) + except (ValueError, TypeError): + score = None - return EdgeSelection(uri=uri, edge=edge, reasoning=reasoning) + return EdgeSelection( + uri=uri, edge=edge, reasoning=reasoning, + concept=concept, score=score, + ) def extract_term_value(term: Dict[str, Any]) -> Any: diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 961e348b..b9e9487b 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -415,7 +415,7 @@ class FlowInstance: def document_rag( self, query,collection="default", - doc_limit=10, + doc_limit=10, fetch_limit=0, ): """ Execute document-based Retrieval-Augmented Generation (RAG) query. @@ -426,7 +426,9 @@ class FlowInstance: Args: query: Natural language query collection: Collection identifier (default: "default") - doc_limit: Maximum document chunks to retrieve (default: 10) + doc_limit: Document chunks selected into the prompt (default: 10) + fetch_limit: Candidate chunks fetched from the vector store before + reranking (default: 0 = derive from doc_limit) Returns: str: Generated response incorporating document context @@ -447,6 +449,7 @@ class FlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, } result = self.request( @@ -491,6 +494,19 @@ class FlowInstance: input )["vectors"] + def rerank(self, queries, documents, limit=10): + + input = { + "queries": queries, + "documents": documents, + "limit": limit, + } + + return self.request( + "service/reranker", + input + ) + def graph_embeddings_query(self, text, collection, limit=10): """ Query knowledge graph entities using semantic similarity. diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 91bc67a1..efa887a1 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -168,7 +168,9 @@ class SocketClient: if resp.get("type") == "auth-ok": if self.workspace == "default": - self.workspace = resp.get("workspace", self.workspace) + self.workspace = resp.get( + "default_workspace", self.workspace, + ) elif resp.get("type") == "auth-failed": await self._socket.close() raise ProtocolException( @@ -750,6 +752,7 @@ class SocketFlowInstance: query: str, collection: str, doc_limit: int = 10, + fetch_limit: int = 0, streaming: bool = False, **kwargs: Any ) -> Union[TextCompletionResult, Iterator[RAGChunk]]: @@ -762,6 +765,7 @@ class SocketFlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": streaming } request.update(kwargs) @@ -783,6 +787,7 @@ class SocketFlowInstance: query: str, collection: str, doc_limit: int = 10, + fetch_limit: int = 0, **kwargs: Any ) -> Iterator[Union[RAGChunk, ProvenanceEvent]]: """Execute document-based RAG query with explainability support.""" @@ -790,6 +795,7 @@ class SocketFlowInstance: "query": query, "collection": collection, "doc-limit": doc_limit, + "fetch-limit": fetch_limit, "streaming": True, "explainable": True, } @@ -883,6 +889,19 @@ class SocketFlowInstance: return self.client._send_request_sync("embeddings", self.flow_id, request, False) + def rerank(self, queries: list, documents: list, limit: int = 10, + **kwargs: Any) -> Dict[str, Any]: + request = { + "queries": queries, + "documents": documents, + "limit": limit, + } + request.update(kwargs) + + return self.client._send_request_sync( + "reranker", self.flow_id, request, False, + ) + def triples_query( self, s: Optional[Union[str, Dict[str, Any]]] = None, diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 6062543b..be905116 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -42,6 +42,8 @@ from . dynamic_tool_service import DynamicToolService from . tool_service_client import ToolServiceClientSpec from . agent_client import AgentClientSpec from . structured_query_client import StructuredQueryClientSpec +from . reranker_client import RerankerClientSpec +from . reranker_service import RerankerService from . row_embeddings_query_client import RowEmbeddingsQueryClientSpec from . collection_config_handler import CollectionConfigHandler diff --git a/trustgraph-base/trustgraph/base/iam_client.py b/trustgraph-base/trustgraph/base/iam_client.py index a2878a0a..423abb8e 100644 --- a/trustgraph-base/trustgraph/base/iam_client.py +++ b/trustgraph-base/trustgraph/base/iam_client.py @@ -65,31 +65,25 @@ class IamClient(RequestResponse): async def authenticate_anonymous(self, timeout=IAM_TIMEOUT): """Request anonymous access from the IAM regime. - Returns ``(user_id, workspace, roles)`` if the regime permits - anonymous access, or raises ``RuntimeError`` with error type - ``auth-failed`` if it does not.""" + Returns ``(user_id, default_workspace, roles)`` if the regime + permits anonymous access, or raises ``RuntimeError`` with + error type ``auth-failed`` if it does not.""" resp = await self._request( operation="authenticate-anonymous", timeout=timeout, ) return ( resp.resolved_user_id, - resp.resolved_workspace, + resp.resolved_default_workspace, list(resp.resolved_roles), ) async def resolve_api_key(self, api_key, timeout=IAM_TIMEOUT): """Resolve a plaintext API key to its identity triple. - Returns ``(user_id, workspace, roles)`` or raises + Returns ``(user_id, default_workspace, roles)`` or raises ``RuntimeError`` with error type ``auth-failed`` if the key is - unknown / expired / revoked. - - Note: the ``roles`` value is a regime-internal hint and is - not used by the gateway directly under the IAM contract; - all authorisation decisions go through ``authorise()``. - Returned here only for backward compatibility with callers - that haven't migrated.""" + unknown / expired / revoked.""" resp = await self._request( operation="resolve-api-key", api_key=api_key, @@ -97,7 +91,7 @@ class IamClient(RequestResponse): ) return ( resp.resolved_user_id, - resp.resolved_workspace, + resp.resolved_default_workspace, list(resp.resolved_roles), ) diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py index d4822ece..b1813ba2 100644 --- a/trustgraph-base/trustgraph/base/prompt_client.py +++ b/trustgraph-base/trustgraph/base/prompt_client.py @@ -157,21 +157,6 @@ class PromptClient(RequestResponse): timeout = timeout, ) - async def kg_prompt(self, query, kg, timeout=600, streaming=False, chunk_callback=None): - return await self.prompt( - id = "kg-prompt", - variables = { - "query": query, - "knowledge": [ - { "s": v[0], "p": v[1], "o": v[2] } - for v in kg - ] - }, - timeout = timeout, - streaming = streaming, - chunk_callback = chunk_callback, - ) - async def document_prompt(self, query, documents, timeout=600, streaming=False, chunk_callback=None): return await self.prompt( id = "document-prompt", diff --git a/trustgraph-base/trustgraph/base/reranker_client.py b/trustgraph-base/trustgraph/base/reranker_client.py new file mode 100644 index 00000000..d0bed394 --- /dev/null +++ b/trustgraph-base/trustgraph/base/reranker_client.py @@ -0,0 +1,43 @@ + +from . request_response_spec import RequestResponse, RequestResponseSpec +from .. schema import ( + RerankerRequest, RerankerResponse, + RerankerQuery, RerankerDocument, +) + +class RerankerClient(RequestResponse): + async def rerank(self, queries, documents, limit=10, timeout=300): + + resp = await self.request( + RerankerRequest( + queries=[ + RerankerQuery(query_id=q["id"], query_text=q["text"]) + for q in queries + ], + documents=[ + RerankerDocument( + document_id=d["id"], document_text=d["text"] + ) + for d in documents + ], + limit=limit, + ), + timeout=timeout + ) + + if resp.error: + raise RuntimeError(resp.error.message) + + return resp.results + +class RerankerClientSpec(RequestResponseSpec): + def __init__( + self, request_name, response_name, + ): + super(RerankerClientSpec, self).__init__( + request_name = request_name, + request_schema = RerankerRequest, + response_name = response_name, + response_schema = RerankerResponse, + impl = RerankerClient, + ) diff --git a/trustgraph-base/trustgraph/base/reranker_service.py b/trustgraph-base/trustgraph/base/reranker_service.py new file mode 100644 index 00000000..1da3a8bf --- /dev/null +++ b/trustgraph-base/trustgraph/base/reranker_service.py @@ -0,0 +1,109 @@ + +from __future__ import annotations + +from argparse import ArgumentParser + +import logging + +from .. schema import ( + RerankerRequest, RerankerResponse, RerankerResult, Error, +) +from .. exceptions import TooManyRequests +from .. base import FlowProcessor, ConsumerSpec, ProducerSpec, ParameterSpec + +logger = logging.getLogger(__name__) + +default_ident = "reranker" +default_concurrency = 1 + +class RerankerService(FlowProcessor): + + def __init__(self, **params): + + id = params.get("id") + concurrency = params.get("concurrency", 1) + + super(RerankerService, self).__init__(**params | { + "id": id, + "concurrency": concurrency, + }) + + self.register_specification( + ConsumerSpec( + name = "request", + schema = RerankerRequest, + handler = self.on_request, + concurrency = concurrency, + ) + ) + + self.register_specification( + ProducerSpec( + name = "response", + schema = RerankerResponse + ) + ) + + self.register_specification( + ParameterSpec( + name = "model", + ) + ) + + async def on_request(self, msg, consumer, flow): + + try: + + request = msg.value() + + id = msg.properties()["id"] + + logger.debug(f"Handling reranker request {id}...") + + model = flow("model") + results = await self.on_rerank( + request.queries, request.documents, + request.limit, model=model, + ) + + await flow("response").send( + RerankerResponse( + error = None, + results = results, + ), + properties={"id": id} + ) + + logger.debug("Reranker request handled successfully") + + except TooManyRequests as e: + raise e + + except Exception as e: + + logger.error(f"Exception in reranker service: {e}", exc_info=True) + + logger.info("Sending error response...") + + await flow.producer["response"].send( + RerankerResponse( + error=Error( + type = "reranker-error", + message = str(e), + ), + results=[], + ), + properties={"id": id} + ) + + @staticmethod + def add_args(parser: ArgumentParser) -> None: + + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + + FlowProcessor.add_args(parser) diff --git a/trustgraph-base/trustgraph/clients/prompt_client.py b/trustgraph-base/trustgraph/clients/prompt_client.py index 12c9c194..ff29ec0a 100644 --- a/trustgraph-base/trustgraph/clients/prompt_client.py +++ b/trustgraph-base/trustgraph/clients/prompt_client.py @@ -140,20 +140,6 @@ class PromptClient(BaseClient): timeout=timeout ) - def request_kg_prompt(self, query, kg, timeout=300): - - return self.request( - id="kg-prompt", - variables={ - "query": query, - "knowledge": [ - { "s": v[0], "p": v[1], "o": v[2] } - for v in kg - ] - }, - timeout=timeout - ) - def request_document_prompt(self, query, documents, timeout=300): return self.request( diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py index 9fcfa6f7..097153ac 100644 --- a/trustgraph-base/trustgraph/messaging/__init__.py +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -27,6 +27,7 @@ from .translators.rows_query import RowsQueryRequestTranslator, RowsQueryRespons from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, QuestionToStructuredQueryResponseTranslator from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator +from .translators.reranker import RerankerRequestTranslator, RerankerResponseTranslator from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator @@ -163,6 +164,12 @@ TranslatorRegistry.register_service( SparqlQueryResponseTranslator() ) +TranslatorRegistry.register_service( + "reranker", + RerankerRequestTranslator(), + RerankerResponseTranslator() +) + # Register single-direction translators for document loading TranslatorRegistry.register_request("document", DocumentTranslator()) TranslatorRegistry.register_request("text-document", TextDocumentTranslator()) diff --git a/trustgraph-base/trustgraph/messaging/translators/__init__.py b/trustgraph-base/trustgraph/messaging/translators/__init__.py index 5b5820fa..b0f88e88 100644 --- a/trustgraph-base/trustgraph/messaging/translators/__init__.py +++ b/trustgraph-base/trustgraph/messaging/translators/__init__.py @@ -20,3 +20,4 @@ from .embeddings_query import ( ) from .rows_query import RowsQueryRequestTranslator, RowsQueryResponseTranslator from .diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator +from .reranker import RerankerRequestTranslator, RerankerResponseTranslator diff --git a/trustgraph-base/trustgraph/messaging/translators/iam.py b/trustgraph-base/trustgraph/messaging/translators/iam.py index 1d7bf21c..4a456044 100644 --- a/trustgraph-base/trustgraph/messaging/translators/iam.py +++ b/trustgraph-base/trustgraph/messaging/translators/iam.py @@ -5,6 +5,7 @@ from ...schema import ( UserInput, UserRecord, WorkspaceInput, WorkspaceRecord, ApiKeyInput, ApiKeyRecord, + GroupInput, GrantInput, ) from .base import MessageTranslator @@ -43,12 +44,31 @@ def _api_key_input_from_dict(d): ) +def _group_input_from_dict(d): + if d is None: + return None + return GroupInput( + name=d.get("name", ""), + description=d.get("description", ""), + enabled=d.get("enabled", True), + ) + + +def _grant_input_from_dict(d): + if d is None: + return None + return GrantInput( + capability=d.get("capability", ""), + workspace=d.get("workspace", ""), + ) + + def _user_record_to_dict(r): if r is None: return None return { "id": r.id, - "workspace": r.workspace, + "default_workspace": r.default_workspace, "username": r.username, "name": r.name, "email": r.email, @@ -102,6 +122,15 @@ class IamRequestTranslator(MessageTranslator): data.get("workspace_record") ), key=_api_key_input_from_dict(data.get("key")), + group_id=data.get("group_id", ""), + member_type=data.get("member_type", ""), + member_id=data.get("member_id", ""), + group=_group_input_from_dict(data.get("group")), + grant=_grant_input_from_dict(data.get("grant")), + capability=data.get("capability", ""), + resource_json=data.get("resource_json", ""), + parameters_json=data.get("parameters_json", ""), + authorise_checks=data.get("authorise_checks", ""), ) def encode(self, obj: IamRequest) -> Dict[str, Any]: @@ -109,6 +138,9 @@ class IamRequestTranslator(MessageTranslator): for fname in ( "workspace", "actor", "user_id", "username", "key_id", "api_key", "password", "new_password", + "group_id", "member_type", "member_id", + "capability", "resource_json", "parameters_json", + "authorise_checks", ): v = getattr(obj, fname, "") if v: @@ -135,6 +167,17 @@ class IamRequestTranslator(MessageTranslator): "name": obj.key.name, "expires": obj.key.expires, } + if obj.group is not None: + result["group"] = { + "name": obj.group.name, + "description": obj.group.description, + "enabled": obj.group.enabled, + } + if obj.grant is not None: + result["grant"] = { + "capability": obj.grant.capability, + "workspace": obj.grant.workspace, + } return result @@ -175,8 +218,8 @@ class IamResponseTranslator(MessageTranslator): result["signing_key_public"] = obj.signing_key_public if obj.resolved_user_id: result["resolved_user_id"] = obj.resolved_user_id - if obj.resolved_workspace: - result["resolved_workspace"] = obj.resolved_workspace + if obj.resolved_default_workspace: + result["resolved_default_workspace"] = obj.resolved_default_workspace if obj.resolved_roles: result["resolved_roles"] = list(obj.resolved_roles) if obj.temporary_password: @@ -190,6 +233,23 @@ class IamResponseTranslator(MessageTranslator): # setup, so it can't be dropped by a truthy-only filter. result["bootstrap_available"] = bool(obj.bootstrap_available) + # authorise / authorise-many outputs. + if obj.decision_allow: + result["decision_allow"] = obj.decision_allow + if obj.decision_ttl_seconds: + result["decision_ttl_seconds"] = obj.decision_ttl_seconds + if obj.decisions_json: + result["decisions_json"] = obj.decisions_json + + # Enterprise IAM outputs. + for fname in ( + "group_json", "groups_json", "members_json", + "grants_json", "effective_permissions_json", + ): + v = getattr(obj, fname, "") + if v: + result[fname] = v + return result def encode_with_completion( diff --git a/trustgraph-base/trustgraph/messaging/translators/reranker.py b/trustgraph-base/trustgraph/messaging/translators/reranker.py new file mode 100644 index 00000000..2d5dabc2 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/reranker.py @@ -0,0 +1,73 @@ +from typing import Dict, Any, Tuple +from ...schema import ( + RerankerRequest, RerankerResponse, + RerankerQuery, RerankerDocument, RerankerResult, +) +from .base import MessageTranslator + + +class RerankerRequestTranslator(MessageTranslator): + + def decode(self, data: Dict[str, Any]) -> RerankerRequest: + return RerankerRequest( + queries=[ + RerankerQuery( + query_id=q["query_id"], + query_text=q["query_text"], + ) + for q in data.get("queries", []) + ], + documents=[ + RerankerDocument( + document_id=d["document_id"], + document_text=d["document_text"], + ) + for d in data.get("documents", []) + ], + limit=data.get("limit", 10), + ) + + def encode(self, obj: RerankerRequest) -> Dict[str, Any]: + return { + "queries": [ + {"query_id": q.query_id, "query_text": q.query_text} + for q in obj.queries + ], + "documents": [ + {"document_id": d.document_id, "document_text": d.document_text} + for d in obj.documents + ], + "limit": obj.limit, + } + + +class RerankerResponseTranslator(MessageTranslator): + + def decode(self, data: Dict[str, Any]) -> RerankerResponse: + return RerankerResponse( + results=[ + RerankerResult( + document_id=r["document_id"], + query_id=r["query_id"], + score=r["score"], + ) + for r in data.get("results", []) + ], + ) + + def encode(self, obj: RerankerResponse) -> Dict[str, Any]: + return { + "results": [ + { + "document_id": r.document_id, + "query_id": r.query_id, + "score": r.score, + } + for r in obj.results + ], + } + + def encode_with_completion( + self, obj: RerankerResponse + ) -> Tuple[Dict[str, Any], bool]: + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index fe766522..f2a0b29a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -12,6 +12,7 @@ class DocumentRagRequestTranslator(MessageTranslator): query=data["query"], collection=data.get("collection", "default"), doc_limit=int(data.get("doc-limit", 20)), + fetch_limit=int(data.get("fetch-limit", 0)), streaming=data.get("streaming", False) ) @@ -20,6 +21,7 @@ class DocumentRagRequestTranslator(MessageTranslator): "query": obj.query, "collection": obj.collection, "doc-limit": obj.doc_limit, + "fetch-limit": obj.fetch_limit, "streaming": getattr(obj, "streaming", False) } diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index 051efc66..d96bad1e 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -64,6 +64,8 @@ from . uris import ( docrag_question_uri, docrag_grounding_uri, docrag_exploration_uri, + docrag_focus_uri, + chunk_selection_uri, docrag_synthesis_uri, ) @@ -89,9 +91,13 @@ from . namespaces import ( TG_IMAGE_TYPE, TG_SUBGRAPH_TYPE, # Query-time provenance predicates (GraphRAG) TG_QUERY, TG_CONCEPT, TG_ENTITY, - TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_REASONING, TG_SCORE, + # Edge selection entity type + TG_EDGE_SELECTION, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, + # Chunk selection entity type + TG_CHUNK_SELECTION, # Explainability entity types TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, TG_ANALYSIS, TG_CONCLUSION, @@ -130,6 +136,7 @@ from . triples import ( # Query-time provenance triple builders (DocumentRAG) docrag_question_triples, docrag_exploration_triples, + docrag_chunk_selection_triples, docrag_synthesis_triples, # Utility set_graph, @@ -194,6 +201,8 @@ __all__ = [ "docrag_question_uri", "docrag_grounding_uri", "docrag_exploration_uri", + "docrag_focus_uri", + "chunk_selection_uri", "docrag_synthesis_uri", # Namespaces "PROV", "PROV_ENTITY", "PROV_ACTIVITY", "PROV_AGENT", @@ -212,9 +221,13 @@ __all__ = [ "TG_CHUNK_TYPE", "TG_IMAGE_TYPE", "TG_SUBGRAPH_TYPE", # Query-time provenance predicates (GraphRAG) "TG_QUERY", "TG_CONCEPT", "TG_ENTITY", - "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", + "TG_EDGE_COUNT", "TG_SELECTED_EDGE", "TG_REASONING", "TG_SCORE", + # Edge selection entity type + "TG_EDGE_SELECTION", # Query-time provenance predicates (DocumentRAG) "TG_CHUNK_COUNT", "TG_SELECTED_CHUNK", + # Chunk selection entity type + "TG_CHUNK_SELECTION", # Explainability entity types "TG_QUESTION", "TG_GROUNDING", "TG_EXPLORATION", "TG_FOCUS", "TG_SYNTHESIS", "TG_ANALYSIS", "TG_CONCLUSION", @@ -250,6 +263,7 @@ __all__ = [ # Query-time provenance triple builders (DocumentRAG) "docrag_question_triples", "docrag_exploration_triples", + "docrag_chunk_selection_triples", "docrag_synthesis_triples", # Agent provenance triple builders "agent_session_triples", diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 0b14f1b9..da6e30b2 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -66,12 +66,19 @@ TG_EDGE_COUNT = TG + "edgeCount" TG_SELECTED_EDGE = TG + "selectedEdge" TG_EDGE = TG + "edge" TG_REASONING = TG + "reasoning" +TG_SCORE = TG + "score" TG_DOCUMENT = TG + "document" # Reference to document in librarian +# Edge selection entity type (cross-encoder scored edge in Focus) +TG_EDGE_SELECTION = TG + "EdgeSelection" + # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT = TG + "chunkCount" TG_SELECTED_CHUNK = TG + "selectedChunk" +# Chunk selection entity type (cross-encoder reranked chunk in Focus) +TG_CHUNK_SELECTION = TG + "ChunkSelection" + # Extraction provenance entity types TG_DOCUMENT_TYPE = TG + "Document" TG_PAGE_TYPE = TG + "Page" diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index 8dedff9a..d2374d54 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -24,10 +24,14 @@ from . namespaces import ( TG_ELEMENT_TYPES, TG_TABLE_COUNT, TG_IMAGE_COUNT, # Query-time provenance predicates (GraphRAG) TG_QUERY, TG_CONCEPT, TG_ENTITY, - TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, + TG_EDGE_COUNT, TG_SELECTED_EDGE, TG_EDGE, TG_REASONING, TG_SCORE, TG_DOCUMENT, + # Edge selection entity type + TG_EDGE_SELECTION, # Query-time provenance predicates (DocumentRAG) TG_CHUNK_COUNT, TG_SELECTED_CHUNK, + # Chunk selection entity type + TG_CHUNK_SELECTION, # Explainability entity types TG_QUESTION, TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, # Unifying types @@ -38,7 +42,10 @@ from . namespaces import ( TG_IN_TOKEN, TG_OUT_TOKEN, ) -from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri +from . uris import ( + activity_uri, agent_uri, subgraph_uri, edge_selection_uri, + chunk_selection_uri, +) def set_graph(triples: List[Triple], graph: str) -> List[Triple]: @@ -536,10 +543,9 @@ def focus_triples( _triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)), ] - # Add each selected edge with its reasoning via intermediate entity + # Add each selected edge with metadata via intermediate entity for idx, edge_info in enumerate(selected_edges_with_reasoning): edge = edge_info.get("edge") - reasoning = edge_info.get("reasoning", "") if edge: s, p, o = edge @@ -552,13 +558,32 @@ def focus_triples( _triple(focus_uri, TG_SELECTED_EDGE, _iri(edge_sel_uri)) ) + # Type the edge selection entity + triples.append( + _triple(edge_sel_uri, RDF_TYPE, _iri(TG_EDGE_SELECTION)) + ) + # Attach quoted triple to edge selection entity quoted = _quoted_triple(s, p, o) triples.append( Triple(s=_iri(edge_sel_uri), p=_iri(TG_EDGE), o=quoted) ) - # Attach reasoning to edge selection entity + # Structured cross-encoder metadata + concept = edge_info.get("concept") + if concept: + triples.append( + _triple(edge_sel_uri, TG_CONCEPT, _literal(concept)) + ) + + score = edge_info.get("score") + if score is not None: + triples.append( + _triple(edge_sel_uri, TG_SCORE, _literal(str(score))) + ) + + # Legacy reasoning text (for non-cross-encoder callers) + reasoning = edge_info.get("reasoning", "") if reasoning: triples.append( _triple(edge_sel_uri, TG_REASONING, _literal(reasoning)) @@ -698,6 +723,75 @@ def docrag_exploration_triples( return triples +def docrag_chunk_selection_triples( + focus_uri: str, + exploration_uri: str, + selected_chunks_with_scores: List[dict], + session_id: str, +) -> List[Triple]: + """ + Build triples for a document RAG focus entity (chunks selected by the + cross-encoder reranker). + + Mirrors GraphRAG's focus_triples / tg:EdgeSelection pattern: a Focus entity + derived from exploration, with one ChunkSelection sub-entity per surviving + chunk carrying the chunk reference and the reranker score. + + Structure: + a tg:Focus ; prov:wasDerivedFrom . + tg:selectedChunk . + a tg:ChunkSelection . + tg:document . + tg:score "0.97" . + + Args: + focus_uri: URI of the focus entity (from docrag_focus_uri) + exploration_uri: URI of the parent exploration entity + selected_chunks_with_scores: List of dicts with 'chunk_id' and 'score' + session_id: Session UUID for generating chunk selection URIs + + Returns: + List of Triple objects + """ + triples = [ + _triple(focus_uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(focus_uri, RDF_TYPE, _iri(TG_FOCUS)), + _triple(focus_uri, RDFS_LABEL, _literal("Chunk Selection")), + _triple(focus_uri, PROV_WAS_DERIVED_FROM, _iri(exploration_uri)), + ] + + for idx, chunk_info in enumerate(selected_chunks_with_scores): + chunk_id = chunk_info.get("chunk_id") + if not chunk_id: + continue + + chunk_sel_uri = chunk_selection_uri(session_id, idx) + + # Link focus to chunk selection entity + triples.append( + _triple(focus_uri, TG_SELECTED_CHUNK, _iri(chunk_sel_uri)) + ) + + # Type the chunk selection entity + triples.append( + _triple(chunk_sel_uri, RDF_TYPE, _iri(TG_CHUNK_SELECTION)) + ) + + # Reference the actual chunk (in librarian) + triples.append( + _triple(chunk_sel_uri, TG_DOCUMENT, _iri(chunk_id)) + ) + + # Cross-encoder score + score = chunk_info.get("score") + if score is not None: + triples.append( + _triple(chunk_sel_uri, TG_SCORE, _literal(str(score))) + ) + + return triples + + def docrag_synthesis_triples( synthesis_uri: str, exploration_uri: str, diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py index a26ac867..00beacbe 100644 --- a/trustgraph-base/trustgraph/provenance/uris.py +++ b/trustgraph-base/trustgraph/provenance/uris.py @@ -309,6 +309,35 @@ def docrag_exploration_uri(session_id: str) -> str: return f"urn:trustgraph:docrag:{session_id}/exploration" +def docrag_focus_uri(session_id: str) -> str: + """ + Generate URI for a document RAG focus entity (chunks selected by the + cross-encoder reranker). + + Args: + session_id: The session UUID. + + Returns: + URN in format: urn:trustgraph:docrag:{uuid}/focus + """ + return f"urn:trustgraph:docrag:{session_id}/focus" + + +def chunk_selection_uri(session_id: str, chunk_index: int) -> str: + """ + Generate URI for a chunk selection item (links a reranked chunk to its + score). Mirrors edge_selection_uri for GraphRAG. + + Args: + session_id: The session UUID. + chunk_index: Index of this chunk in the selection (0-based). + + Returns: + URN in format: urn:trustgraph:prov:chunk:{uuid}:{index} + """ + return f"urn:trustgraph:prov:chunk:{session_id}:{chunk_index}" + + def docrag_synthesis_uri(session_id: str) -> str: """ Generate URI for a document RAG synthesis entity (final answer). diff --git a/trustgraph-base/trustgraph/provenance/vocabulary.py b/trustgraph-base/trustgraph/provenance/vocabulary.py index afb5c30f..f5139992 100644 --- a/trustgraph-base/trustgraph/provenance/vocabulary.py +++ b/trustgraph-base/trustgraph/provenance/vocabulary.py @@ -29,6 +29,8 @@ from . namespaces import ( TG_ANSWER_TYPE, TG_REFLECTION_TYPE, TG_THOUGHT_TYPE, TG_OBSERVATION_TYPE, TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_SUBAGENT_GOAL, TG_PLAN_STEP, + TG_EDGE_SELECTION, TG_SCORE, + TG_CHUNK_SELECTION, ) @@ -93,6 +95,8 @@ TG_CLASS_LABELS = [ _label_triple(TG_FINDING, "Finding"), _label_triple(TG_PLAN_TYPE, "Plan"), _label_triple(TG_STEP_RESULT, "Step Result"), + _label_triple(TG_EDGE_SELECTION, "Edge Selection"), + _label_triple(TG_CHUNK_SELECTION, "Chunk Selection"), ] # TrustGraph predicate labels @@ -117,6 +121,7 @@ TG_PREDICATE_LABELS = [ _label_triple(TG_ENTITY, "entity"), _label_triple(TG_SUBAGENT_GOAL, "subagent goal"), _label_triple(TG_PLAN_STEP, "plan step"), + _label_triple(TG_SCORE, "score"), ] diff --git a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py index 2a214201..63dc05fd 100644 --- a/trustgraph-base/trustgraph/schema/services/__init__.py +++ b/trustgraph-base/trustgraph/schema/services/__init__.py @@ -15,4 +15,5 @@ from .diagnosis import * from .collection import * from .storage import * from .tool_service import * -from .sparql_query import * \ No newline at end of file +from .sparql_query import * +from .reranker import * \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/iam.py b/trustgraph-base/trustgraph/schema/services/iam.py index 797d6203..6b2599e0 100644 --- a/trustgraph-base/trustgraph/schema/services/iam.py +++ b/trustgraph-base/trustgraph/schema/services/iam.py @@ -29,7 +29,7 @@ class UserInput: @dataclass class UserRecord: id: str = "" - workspace: str = "" + default_workspace: str = "" username: str = "" name: str = "" email: str = "" @@ -74,6 +74,21 @@ class ApiKeyRecord: last_used: str = "" +# ---- Enterprise IAM types (additive) ---- + +@dataclass +class GroupInput: + name: str = "" + description: str = "" + enabled: bool = True + + +@dataclass +class GrantInput: + capability: str = "" + workspace: str = "" + + @dataclass class IamRequest: operation: str = "" @@ -99,6 +114,13 @@ class IamRequest: workspace_record: WorkspaceInput | None = None key: ApiKeyInput | None = None + # ---- Enterprise IAM inputs (additive) ---- + group_id: str = "" + member_type: str = "" + member_id: str = "" + group: GroupInput | None = None + grant: GrantInput | None = None + # ---- authorise / authorise-many inputs ---- # Capability string from the vocabulary in capabilities.md. capability: str = "" @@ -138,7 +160,7 @@ class IamResponse: # resolve-api-key resolved_user_id: str = "" - resolved_workspace: str = "" + resolved_default_workspace: str = "" resolved_roles: list[str] = field(default_factory=list) # reset-password @@ -164,6 +186,14 @@ class IamResponse: # authorise_checks. decisions_json: str = "" + # ---- Enterprise IAM outputs (additive) ---- + # JSON-serialised payloads for enterprise group/grant operations. + group_json: str = "" + groups_json: str = "" + members_json: str = "" + grants_json: str = "" + effective_permissions_json: str = "" + error: Error | None = None diff --git a/trustgraph-base/trustgraph/schema/services/prompt.py b/trustgraph-base/trustgraph/schema/services/prompt.py index 1696790b..0a9c23ef 100644 --- a/trustgraph-base/trustgraph/schema/services/prompt.py +++ b/trustgraph-base/trustgraph/schema/services/prompt.py @@ -6,17 +6,6 @@ from ..core.primitives import Error # Prompt services, abstract the prompt generation -# extract-definitions: -# chunk -> definitions -# extract-relationships: -# chunk -> relationships -# kg-prompt: -# query, triples -> answer -# document-prompt: -# query, documents -> answer -# extract-rows -# schema, chunk -> rows - @dataclass class PromptRequest: id: str = "" @@ -46,4 +35,4 @@ class PromptResponse: out_token: int | None = None model: str | None = None -############################################################################ \ No newline at end of file +############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/reranker.py b/trustgraph-base/trustgraph/schema/services/reranker.py new file mode 100644 index 00000000..948746e4 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/services/reranker.py @@ -0,0 +1,35 @@ + +from dataclasses import dataclass, field + +from ..core.primitives import Error + +############################################################################ + +# Cross-encoder reranker + +@dataclass +class RerankerQuery: + query_id: str = "" + query_text: str = "" + +@dataclass +class RerankerDocument: + document_id: str = "" + document_text: str = "" + +@dataclass +class RerankerRequest: + queries: list[RerankerQuery] = field(default_factory=list) + documents: list[RerankerDocument] = field(default_factory=list) + limit: int = 10 + +@dataclass +class RerankerResult: + document_id: str = "" + query_id: str = "" + score: float = 0.0 + +@dataclass +class RerankerResponse: + error: Error | None = None + results: list[RerankerResult] = field(default_factory=list) diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index e937e720..2d4e01e1 100644 --- a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -40,7 +40,10 @@ class GraphRagResponse: class DocumentRagQuery: query: str = "" collection: str = "" - doc_limit: int = 0 + doc_limit: int = 0 # docs selected into the synthesis prompt + fetch_limit: int = 0 # candidate pool fetched from the vector store + # before reranking (0 = derive from doc_limit; + # values below doc_limit are raised to it) streaming: bool = False @dataclass diff --git a/trustgraph-bedrock/pyproject.toml b/trustgraph-bedrock/pyproject.toml index 7aa2f96a..2dc724b0 100644 --- a/trustgraph-bedrock/pyproject.toml +++ b/trustgraph-bedrock/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.5,<2.6", + "trustgraph-base>=2.6,<2.7", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 16b0ae0a..193ee1cd 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.5,<2.6", + "trustgraph-base>=2.6,<2.7", "requests", "pulsar-client", "aiohttp", @@ -71,6 +71,7 @@ tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main" tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main" tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main" tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main" +tg-invoke-reranker = "trustgraph.cli.invoke_reranker:main" tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main" tg-load-doc-embeds = "trustgraph.cli.load_doc_embeds:main" tg-load-kg-core = "trustgraph.cli.load_kg_core:main" diff --git a/trustgraph-cli/trustgraph/cli/create_user.py b/trustgraph-cli/trustgraph/cli/create_user.py index c9253aca..14760454 100644 --- a/trustgraph-cli/trustgraph/cli/create_user.py +++ b/trustgraph-cli/trustgraph/cli/create_user.py @@ -53,7 +53,7 @@ def main(): help="Auth token (default: $TRUSTGRAPH_TOKEN)", ) parser.add_argument( - "--username", required=True, help="Username (unique in workspace)", + "--username", required=True, help="Username (globally unique)", ) parser.add_argument( "--password", default=None, @@ -75,10 +75,7 @@ def main(): ) parser.add_argument( "-w", "--workspace", default=None, - help=( - "Target workspace (admin only; defaults to caller's " - "assigned workspace)" - ), + help="Default workspace for the new user", ) run_main(do_create_user, parser) diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index 01512ac8..04f4deda 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -21,10 +21,12 @@ default_token = os.getenv("TRUSTGRAPH_TOKEN", None) default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_collection = 'default' default_doc_limit = 10 +default_fetch_limit = 0 def question_explainable( - url, flow_id, question_text, collection, doc_limit, token=None, debug=False, + url, flow_id, question_text, collection, doc_limit, fetch_limit=0, + token=None, debug=False, workspace="default", ): """Execute document RAG with explainability - shows provenance events inline.""" @@ -39,6 +41,7 @@ def question_explainable( query=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, ): if isinstance(item, RAGChunk): # Print response content @@ -97,7 +100,7 @@ def question_explainable( def question( - url, flow_id, question_text, collection, doc_limit, + url, flow_id, question_text, collection, doc_limit, fetch_limit=0, streaming=True, token=None, explainable=False, debug=False, show_usage=False, workspace="default", ): @@ -109,6 +112,7 @@ def question( question_text=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, token=token, debug=debug, workspace=workspace, @@ -128,6 +132,7 @@ def question( query=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, streaming=True ) @@ -155,6 +160,7 @@ def question( query=question_text, collection=collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, ) print(result.text) @@ -214,7 +220,15 @@ def main(): '-d', '--doc-limit', type=int, default=default_doc_limit, - help=f'Document limit (default: {default_doc_limit})' + help=f'Documents selected into the prompt (default: {default_doc_limit})' + ) + + parser.add_argument( + '--fetch-limit', + type=int, + default=default_fetch_limit, + help='Candidate documents fetched from the vector store before ' + 'reranking (default: derive from doc-limit)' ) parser.add_argument( @@ -251,6 +265,7 @@ def main(): question_text=args.question, collection=args.collection, doc_limit=args.doc_limit, + fetch_limit=args.fetch_limit, streaming=not args.no_streaming, token=args.token, explainable=args.explainable, diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index f39cdab0..892d2d35 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -112,14 +112,13 @@ def _question_explainable_api( if focus_full and focus_full.edge_selections: for edge_sel in focus_full.edge_selections: if edge_sel.edge: - # Resolve labels for edge components s_label, p_label, o_label = explain_client.resolve_edge_labels( edge_sel.edge, collection ) print(f" Edge: ({s_label}, {p_label}, {o_label})", file=sys.stderr) - if edge_sel.reasoning: - r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning - print(f" Reason: {r_short}", file=sys.stderr) + if edge_sel.concept or edge_sel.score is not None: + score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?" + print(f" Concept: {edge_sel.concept} Score: {score_str}", file=sys.stderr) elif isinstance(entity, Synthesis): print(f"\n [synthesis] {prov_id}", file=sys.stderr) diff --git a/trustgraph-cli/trustgraph/cli/invoke_reranker.py b/trustgraph-cli/trustgraph/cli/invoke_reranker.py new file mode 100644 index 00000000..91337c97 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/invoke_reranker.py @@ -0,0 +1,127 @@ +""" +Invokes the reranker service to score and rank documents by relevance +to one or more queries. +""" + +import argparse +import json +import os +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") + +def query(url, flow_id, queries, documents, limit, token=None, + workspace="default"): + + api = Api(url=url, token=token, workspace=workspace) + socket = api.socket() + flow = socket.flow(flow_id) + + try: + + query_objects = [ + {"query_id": str(i), "query_text": q} + for i, q in enumerate(queries) + ] + + document_objects = [ + {"document_id": str(i), "document_text": d} + for i, d in enumerate(documents) + ] + + result = flow.rerank( + queries=query_objects, + documents=document_objects, + limit=limit, + ) + + if "error" in result and result["error"]: + err = result["error"] + print(f"Error: [{err.get('type', '')}] {err.get('message', '')}") + return + + for r in result.get("results", []): + doc_idx = int(r["document_id"]) + query_idx = int(r["query_id"]) + print( + f" {r['score']:.4f} | " + f"query: {queries[query_idx]} | " + f"doc: {documents[doc_idx]}" + ) + + finally: + socket.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-reranker', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-w', '--workspace', + default=default_workspace, + help=f'Workspace (default: {default_workspace})', + ) + + parser.add_argument( + '-f', '--flow-id', + default="default", + help=f'Flow ID (default: default)' + ) + + parser.add_argument( + '-l', '--limit', + type=int, + default=10, + help='Maximum number of results (default: 10)', + ) + + parser.add_argument( + '-q', '--query', + action='append', + required=True, + help='Query text (can be specified multiple times)', + ) + + parser.add_argument( + 'documents', + nargs='+', + help='Documents to rerank', + ) + + args = parser.parse_args() + + try: + + query( + url=args.url, + flow_id=args.flow_id, + queries=args.query, + documents=args.documents, + limit=args.limit, + token=args.token, + workspace=args.workspace, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/login.py b/trustgraph-cli/trustgraph/cli/login.py index 0e87c3b0..977cf15b 100644 --- a/trustgraph-cli/trustgraph/cli/login.py +++ b/trustgraph-cli/trustgraph/cli/login.py @@ -51,8 +51,8 @@ def main(): parser.add_argument( "-w", "--workspace", default=None, help=( - "Optional workspace to log in against. Defaults to " - "the user's assigned workspace." + "Override the default workspace for this session's JWT. " + "If omitted, uses the user's stored default workspace." ), ) run_main(do_login, parser) diff --git a/trustgraph-cli/trustgraph/cli/show_explain_trace.py b/trustgraph-cli/trustgraph/cli/show_explain_trace.py index 17aaca1a..ed4a9807 100644 --- a/trustgraph-cli/trustgraph/cli/show_explain_trace.py +++ b/trustgraph-cli/trustgraph/cli/show_explain_trace.py @@ -203,9 +203,9 @@ def print_graphrag_text(trace, explain_client, flow, collection, api=None, show_ ) print(f" {i}. ({s_label}, {p_label}, {o_label})") - if edge_sel.reasoning: - r_short = edge_sel.reasoning[:100] + "..." if len(edge_sel.reasoning) > 100 else edge_sel.reasoning - print(f" Reasoning: {r_short}") + if edge_sel.concept or edge_sel.score is not None: + score_str = f"{edge_sel.score:.4f}" if edge_sel.score is not None else "?" + print(f" Concept: {edge_sel.concept} Score: {score_str}") if show_provenance and edge_sel.edge: provenance = trace_edge_provenance( @@ -519,7 +519,8 @@ def trace_to_dict(trace, trace_type): "selected_edges": [ { "edge": edge_sel.edge, - "reasoning": edge_sel.reasoning, + "concept": edge_sel.concept, + "score": edge_sel.score, } for edge_sel in focus.edge_selections ], diff --git a/trustgraph-cli/trustgraph/cli/update_user.py b/trustgraph-cli/trustgraph/cli/update_user.py index 5c1dc4d7..d24d5323 100644 --- a/trustgraph-cli/trustgraph/cli/update_user.py +++ b/trustgraph-cli/trustgraph/cli/update_user.py @@ -68,7 +68,7 @@ def do_update_user(args): print(f"username : {rec.get('username', '')}") print(f"name : {rec.get('name', '')}") print(f"email : {rec.get('email', '')}") - print(f"workspace : {rec.get('workspace', '')}") + print(f"default_ws: {rec.get('default_workspace', '')}") print(f"roles : {', '.join(rec.get('roles', []))}") print(f"enabled : {'yes' if rec.get('enabled') else 'no'}") print( @@ -114,7 +114,7 @@ def main(): "-w", "--workspace", default=None, help=( "Optional workspace integrity check — when supplied, " - "iam-svc verifies the target user's home workspace " + "iam-svc verifies the target user's default workspace " "matches" ), ) diff --git a/trustgraph-embeddings-hf/pyproject.toml b/trustgraph-embeddings-hf/pyproject.toml index 4bf17688..b8bd7d1c 100644 --- a/trustgraph-embeddings-hf/pyproject.toml +++ b/trustgraph-embeddings-hf/pyproject.toml @@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph." readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.5,<2.6", - "trustgraph-flow>=2.5,<2.6", + "trustgraph-base>=2.6,<2.7", + "trustgraph-flow>=2.6,<2.7", "torch", "urllib3", "transformers", diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 547dea3c..90647104 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.5,<2.6", + "trustgraph-base>=2.6,<2.7", "aiohttp", "anthropic", "scylla-driver", @@ -19,6 +19,7 @@ dependencies = [ "faiss-cpu", "falkordb", "fastembed", + "flashrank", "ibis", "jsonschema", "langchain", @@ -83,6 +84,7 @@ graph-embeddings-write-pinecone = "trustgraph.storage.graph_embeddings.pinecone: graph-embeddings-write-qdrant = "trustgraph.storage.graph_embeddings.qdrant:run" graph-embeddings = "trustgraph.embeddings.graph_embeddings:run" graph-rag = "trustgraph.retrieval.graph_rag:run" +reranker-flashrank = "trustgraph.reranker.flashrank:run" kg-extract-agent = "trustgraph.extract.kg.agent:run" kg-extract-definitions = "trustgraph.extract.kg.definitions:run" kg-extract-rows = "trustgraph.extract.kg.rows:run" diff --git a/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py b/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py index 96d13d28..524fd306 100644 --- a/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py +++ b/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py @@ -18,6 +18,10 @@ description : str (default "Default") Human-readable description passed to flow-svc. parameters : dict (optional) Optional parameter overrides passed to start-flow. +list_timeout : int (default 10) + Timeout in seconds for the list-flows request. +start_timeout : int (default 30) + Timeout in seconds for the start-flow request. """ from trustgraph.schema import FlowRequest @@ -34,6 +38,8 @@ class DefaultFlowStart(Initialiser): blueprint=None, description="Default", parameters=None, + list_timeout=10, + start_timeout=30, **kwargs, ): super().__init__(**kwargs) @@ -46,6 +52,8 @@ class DefaultFlowStart(Initialiser): self.blueprint = blueprint self.description = description self.parameters = dict(parameters) if parameters else {} + self.list_timeout = list_timeout + self.start_timeout = start_timeout async def run(self, ctx, old_flag, new_flag): @@ -70,7 +78,7 @@ class DefaultFlowStart(Initialiser): FlowRequest( operation="list-flows", ), - timeout=10, + timeout=self.list_timeout, ) if list_resp.error: raise RuntimeError( @@ -99,7 +107,7 @@ class DefaultFlowStart(Initialiser): description=self.description, parameters=self.parameters, ), - timeout=30, + timeout=self.start_timeout, ) if resp.error: raise RuntimeError( diff --git a/trustgraph-flow/trustgraph/bootstrap/initialisers/workspace_init.py b/trustgraph-flow/trustgraph/bootstrap/initialisers/workspace_init.py index 423c5f5e..b1881fff 100644 --- a/trustgraph-flow/trustgraph/bootstrap/initialisers/workspace_init.py +++ b/trustgraph-flow/trustgraph/bootstrap/initialisers/workspace_init.py @@ -14,7 +14,9 @@ seed_file : str (required when source=="seed-file") Path to a JSON seed file with the same shape TemplateSeed consumes. overwrite : bool (default False) On re-run (flag change), if True overwrite all keys; if False, - upsert-missing-only (preserves in-workspace customisations). + upsert-missing-only (preserves in-workspace customisations) +iam_timeout : int (default 10) + Timeout in seconds for the IAM create-workspace request. Raises (in ``run``) ------------------- @@ -41,7 +43,9 @@ class WorkspaceInit(Initialiser): source="template", seed_file=None, overwrite=False, + iam_timeout=10, **kwargs, + ): super().__init__(**kwargs) @@ -59,6 +63,7 @@ class WorkspaceInit(Initialiser): self.source = source self.seed_file = seed_file self.overwrite = overwrite + self.iam_timeout = iam_timeout async def run(self, ctx, old_flag, new_flag): await self._create_workspace(ctx) @@ -120,10 +125,10 @@ class WorkspaceInit(Initialiser): workspace_record=WorkspaceInput( id=self.workspace, name=self.workspace.title(), - enabled=True, + enabled=True, ), ), - timeout=10, + timeout=self.iam_timeout, ) if resp.error: if resp.error.type == "duplicate": diff --git a/trustgraph-flow/trustgraph/gateway/auth.py b/trustgraph-flow/trustgraph/gateway/auth.py index 273fcb5a..b08bf650 100644 --- a/trustgraph-flow/trustgraph/gateway/auth.py +++ b/trustgraph-flow/trustgraph/gateway/auth.py @@ -57,16 +57,17 @@ class Identity: # the OSS regime this is the user record's id; the gateway # treats it as a string with no semantic content. handle: str - # The workspace this credential authenticates to. Used by the - # gateway as the default-fill-in for operations that omit a - # workspace. Never used as policy input. - workspace: str + # The user's default workspace. Used by the gateway as the + # default-fill-in for operations that omit a workspace. Not a + # permission boundary — workspace access is controlled by the + # IAM regime's authorise() decision, not by this field. + default_workspace: str # Stable identifier for audit logs. In OSS this is the same # value as ``handle``; not assumed equal in the contract. principal_id: str # How the credential was presented. Non-policy; useful for # logs / metrics only. - source: str # "api-key" | "jwt" + source: str # "api-key" | "jwt" | "anonymous" def _auth_failure(): @@ -256,21 +257,22 @@ class IamAuth: raise _auth_failure() sub = claims.get("sub", "") - ws = claims.get("workspace", "") + ws = claims.get("default_workspace", "") if not sub or not ws: raise _auth_failure() - # JWT carries no policy state under the IAM contract; - # any roles / claims field is ignored here. return Identity( - handle=sub, workspace=ws, principal_id=sub, source="jwt", + handle=sub, default_workspace=ws, + principal_id=sub, source="jwt", ) async def _authenticate_anonymous(self): try: async def _call(client): return await client.authenticate_anonymous() - user_id, workspace, _roles = await self._with_client(_call) + user_id, default_workspace, _roles = await self._with_client( + _call, + ) except Exception as e: logger.debug( f"Anonymous authentication rejected: " @@ -278,11 +280,11 @@ class IamAuth: ) raise _auth_failure() - if not user_id or not workspace: + if not user_id or not default_workspace: raise _auth_failure() return Identity( - handle=user_id, workspace=workspace, + handle=user_id, default_workspace=default_workspace, principal_id=user_id, source="anonymous", ) @@ -305,7 +307,9 @@ class IamAuth: # ``roles`` is returned by the OSS regime as a hint # but is not consulted by the gateway; all policy # decisions go through ``authorise``. - user_id, workspace, _roles = await self._with_client(_call) + user_id, default_workspace, _roles = await self._with_client( + _call, + ) except Exception as e: logger.debug( f"API key resolution failed: " @@ -313,11 +317,11 @@ class IamAuth: ) raise _auth_failure() - if not user_id or not workspace: + if not user_id or not default_workspace: raise _auth_failure() identity = Identity( - handle=user_id, workspace=workspace, + handle=user_id, default_workspace=default_workspace, principal_id=user_id, source="api-key", ) self._key_cache[h] = (identity, now + API_KEY_CACHE_TTL) diff --git a/trustgraph-flow/trustgraph/gateway/capabilities.py b/trustgraph-flow/trustgraph/gateway/capabilities.py index dbbb01e0..c9d3b516 100644 --- a/trustgraph-flow/trustgraph/gateway/capabilities.py +++ b/trustgraph-flow/trustgraph/gateway/capabilities.py @@ -99,7 +99,7 @@ async def enforce_workspace(data, identity, auth, capability=None): return data requested = data.get("workspace", "") - target = requested or identity.workspace + target = requested or identity.default_workspace data["workspace"] = target if target not in auth.known_workspaces: diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index bddb009d..7285250f 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -37,6 +37,7 @@ from . graph_embeddings_query import GraphEmbeddingsQueryRequestor from . document_embeddings_query import DocumentEmbeddingsQueryRequestor from . row_embeddings_query import RowEmbeddingsQueryRequestor from . mcp_tool import McpToolRequestor +from . reranker import RerankerRequestor from . text_load import TextLoad from . document_load import DocumentLoad @@ -74,6 +75,7 @@ request_response_dispatchers = { "structured-diag": StructuredDiagRequestor, "row-embeddings": RowEmbeddingsQueryRequestor, "sparql": SparqlQueryRequestor, + "reranker": RerankerRequestor, } system_dispatchers = { diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index 9b119f8e..c1020998 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -5,6 +5,7 @@ import uuid import logging from ..capabilities import PUBLIC, AUTHENTICATED +from ..registry import ResourceLevel # Module logger logger = logging.getLogger(__name__) @@ -79,7 +80,7 @@ class Mux: self.identity = identity await self.ws.send_json({ "type": "auth-ok", - "workspace": identity.workspace, + "default_workspace": identity.default_workspace, }) async def receive(self, msg): @@ -159,12 +160,14 @@ class Mux: return # Resolve workspace (default-fill from the caller's - # bound workspace). Workspace resolution applies to all - # operations regardless of capability level. + # bound workspace). The envelope workspace is the + # single canonical workspace for routing AND + # authorisation. The inner request body's workspace + # field is not consulted — workspace-scoped services + # receive workspace from the queue identity, not the + # message body. try: await enforce_workspace(data, self.identity, self.auth) - if isinstance(inner, dict): - await enforce_workspace(inner, self.identity, self.auth) # Authorisation: capability sentinels short-circuit # the regime call; capability strings go through @@ -176,11 +179,18 @@ class Mux: "flow": data.get("flow", ""), } parameters = {} + elif op.resource_level == ResourceLevel.WORKSPACE: + # Workspace-scoped services (config, flow, + # librarian, etc.) — workspace comes from the + # envelope, same as flow-level services. + resource = { + "workspace": data.get("workspace", ""), + } + parameters = {} else: - # Build a minimal RequestContext so the matched - # operation's own extractors decide resource - # and parameters — same path the HTTP - # endpoints take. + # System-level services (IAM) — resource is + # {} and parameters come from the inner body + # (e.g. user.workspace, workspace_record.id). from ..registry import RequestContext ctx = RequestContext( body=inner if isinstance(inner, dict) else {}, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/reranker.py b/trustgraph-flow/trustgraph/gateway/dispatch/reranker.py new file mode 100644 index 00000000..e456f3d1 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/reranker.py @@ -0,0 +1,31 @@ + +from ... schema import RerankerRequest, RerankerResponse +from ... messaging import TranslatorRegistry + +from . requestor import ServiceRequestor + +class RerankerRequestor(ServiceRequestor): + def __init__( + self, backend, request_queue, response_queue, timeout, + consumer, subscriber, + ): + + super(RerankerRequestor, self).__init__( + backend=backend, + request_queue=request_queue, + response_queue=response_queue, + request_schema=RerankerRequest, + response_schema=RerankerResponse, + subscription = subscriber, + consumer_name = consumer, + timeout=timeout, + ) + + self.request_translator = TranslatorRegistry.get_request_translator("reranker") + self.response_translator = TranslatorRegistry.get_response_translator("reranker") + + def to_request(self, body): + return self.request_translator.decode(body) + + def from_response(self, message): + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/registry.py b/trustgraph-flow/trustgraph/gateway/registry.py index ca235315..14f820c2 100644 --- a/trustgraph-flow/trustgraph/gateway/registry.py +++ b/trustgraph-flow/trustgraph/gateway/registry.py @@ -92,7 +92,7 @@ class Operation: # Returns a dict with the appropriate components for the # resource level: {} for SYSTEM, {workspace} for WORKSPACE, # {workspace, flow} for FLOW. Default-fill-in of workspace - # from identity.workspace happens here when applicable. + # from identity.default_workspace happens here when applicable. extract_resource: Callable[[RequestContext], dict] # Build the parameters dict — decision-relevant fields the @@ -141,7 +141,7 @@ def _workspace_from_body(ctx: RequestContext) -> dict: workspace field, defaulting to the caller's bound workspace.""" ws = (ctx.body.get("workspace") if isinstance(ctx.body, dict) else "") if not ws and ctx.identity is not None: - ws = ctx.identity.workspace + ws = ctx.identity.default_workspace return {"workspace": ws} @@ -188,7 +188,7 @@ def _workspace_param_only(ctx: RequestContext) -> dict: or body.get("workspace") ) if not ws and ctx.identity is not None: - ws = ctx.identity.workspace + ws = ctx.identity.default_workspace return {"workspace": ws or ""} @@ -311,7 +311,7 @@ register(Operation( )) register(Operation( name="list-my-workspaces", - capability="workspaces:list-own", + capability=AUTHENTICATED, resource_level=ResourceLevel.SYSTEM, extract_resource=_empty_resource, extract_parameters=_no_parameters, @@ -506,18 +506,19 @@ _FLOW_SERVICES = { "text-completion": "llm", "prompt": "llm", "mcp-tool": "mcp", - "graph-rag": "graph:read", - "document-rag": "documents:read", + "graph-rag": "graph-rag:read", + "document-rag": "document-rag:read", "embeddings": "embeddings", - "graph-embeddings": "graph:read", - "document-embeddings": "documents:read", - "triples": "graph:read", + "graph-embeddings": "graph-embeddings:read", + "document-embeddings": "document-embeddings:read", + "triples": "triples:read", "rows": "rows:read", - "nlp-query": "rows:read", - "structured-query": "rows:read", - "structured-diag": "rows:read", - "row-embeddings": "rows:read", - "sparql": "graph:read", + "nlp-query": "nlp-query:read", + "structured-query": "structured-query:read", + "structured-diag": "structured-query:read", + "row-embeddings": "row-embeddings:read", + "sparql": "sparql:read", + "reranker": "reranker", } for _kind, _cap in _FLOW_SERVICES.items(): _register_flow_kind("flow-service", _kind, _cap) @@ -525,10 +526,10 @@ for _kind, _cap in _FLOW_SERVICES.items(): # Streaming import socket endpoints. _FLOW_IMPORTS = { - "triples": "graph:write", - "graph-embeddings": "graph:write", - "document-embeddings": "documents:write", - "entity-contexts": "documents:write", + "triples": "triples:write", + "graph-embeddings": "graph-embeddings:write", + "document-embeddings": "document-embeddings:write", + "entity-contexts": "entity-contexts:write", "rows": "rows:write", } for _kind, _cap in _FLOW_IMPORTS.items(): @@ -537,10 +538,35 @@ for _kind, _cap in _FLOW_IMPORTS.items(): # Streaming export socket endpoints. _FLOW_EXPORTS = { - "triples": "graph:read", - "graph-embeddings": "graph:read", - "document-embeddings": "documents:read", - "entity-contexts": "documents:read", + "triples": "triples:read", + "graph-embeddings": "graph-embeddings:read", + "document-embeddings": "document-embeddings:read", + "entity-contexts": "entity-contexts:read", } for _kind, _cap in _FLOW_EXPORTS.items(): _register_flow_kind("flow-export", _kind, _cap) + + +# --------------------------------------------------------------------------- +# Enterprise IAM operations. +# +# These are additive — they register alongside the OSS IAM operations. +# When the OSS regime receives an unknown operation it returns an error; +# when the enterprise regime is running, it handles them. +# --------------------------------------------------------------------------- + +for _op in ( + "create-group", "get-group", "list-groups", + "update-group", "delete-group", + "add-group-member", "remove-group-member", "list-group-members", + "add-group-grant", "remove-group-grant", "list-group-grants", + "add-user-grant", "remove-user-grant", "list-user-grants", + "resolve-effective-permissions", +): + register(Operation( + name=_op, + capability="iam:admin", + resource_level=ResourceLevel.SYSTEM, + extract_resource=_empty_resource, + extract_parameters=_no_parameters, + )) diff --git a/trustgraph-flow/trustgraph/iam/noauth/handler.py b/trustgraph-flow/trustgraph/iam/noauth/handler.py index dd70b02d..1d3c6c8f 100644 --- a/trustgraph-flow/trustgraph/iam/noauth/handler.py +++ b/trustgraph-flow/trustgraph/iam/noauth/handler.py @@ -28,14 +28,14 @@ class NoAuthHandler: def _default_identity_response(self): return IamResponse( resolved_user_id=self.default_user_id, - resolved_workspace=self.default_workspace, + resolved_default_workspace=self.default_workspace, resolved_roles=["admin"], ) def _default_user_record(self): return UserRecord( id=self.default_user_id, - workspace=self.default_workspace, + default_workspace=self.default_workspace, username=self.default_user_id, name="Anonymous User", roles=["admin"], diff --git a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py index 5f86e688..fced972e 100644 --- a/trustgraph-flow/trustgraph/iam/service/iam.py +++ b/trustgraph-flow/trustgraph/iam/service/iam.py @@ -58,21 +58,35 @@ AUTHZ_CACHE_TTL_SECONDS = 60 _READER_CAPS = { "agent", "graph:read", + "triples:read", + "sparql:read", + "graph-rag:read", + "graph-embeddings:read", "documents:read", + "document-rag:read", + "document-embeddings:read", + "entity-contexts:read", "rows:read", + "nlp-query:read", + "structured-query:read", + "row-embeddings:read", "llm", "embeddings", + "reranker", "mcp", "config:read", "flows:read", "collections:read", "knowledge:read", "keys:self", - "workspaces:list-own", } _WRITER_CAPS = _READER_CAPS | { "graph:write", + "triples:write", + "graph-embeddings:write", + "document-embeddings:write", + "entity-contexts:write", "documents:write", "rows:write", "collections:write", @@ -369,7 +383,7 @@ class IamService: ) = row return UserRecord( id=id or "", - workspace=workspace or "", + default_workspace=workspace or "", username=username or "", name=name or "", email=email or "", @@ -582,14 +596,8 @@ class IamService: if not v.password: return _err("auth-failed", "password required") - # Login accepts an optional workspace parameter. If omitted - # we use the default workspace (OSS single-workspace - # assumption). Multi-workspace enterprise editions swap in a - # resolver that looks across the caller's permitted set. - workspace = v.workspace or DEFAULT_WORKSPACE - user_id = await self.table_store.get_user_id_by_username( - workspace, v.username, + v.username, ) if not user_id: return _err("auth-failed", "no such user") @@ -610,7 +618,10 @@ class IamService: ): return _err("auth-failed", "bad credentials") - ws_row = await self.table_store.get_workspace(ws) + # JWT workspace: login request override, or the user's default. + jwt_workspace = v.workspace or ws + + ws_row = await self.table_store.get_workspace(jwt_workspace) if ws_row is None or not ws_row[2]: return _err("auth-failed", "workspace disabled") @@ -618,14 +629,10 @@ class IamService: now_ts = int(_now_dt().timestamp()) exp_ts = now_ts + JWT_TTL_SECONDS - # Per the IAM contract the gateway never reads policy state - # from the credential — roles stay server-side, reachable - # only via authorise(). JWT carries identity + workspace - # binding only. claims = { "iss": JWT_ISSUER, "sub": id, - "workspace": ws, + "default_workspace": jwt_workspace, "iat": now_ts, "exp": exp_ts, } @@ -864,20 +871,15 @@ class IamService: # user_row indices match get_user columns. Username is [2]. username = user_row[2] - record_workspace = user_row[1] # Revoke all API keys. key_rows = await self.table_store.list_api_keys_by_user(v.user_id) for kr in key_rows: await self.table_store.delete_api_key(kr[0]) - # Remove username lookup — keyed on (workspace, username), - # so use the resolved workspace from the user record rather - # than relying on the caller-supplied filter. + # Remove global username lookup. if username: - await self.table_store.delete_username_lookup( - record_workspace, username, - ) + await self.table_store.delete_username_lookup(username) # Remove user record. await self.table_store.delete_user(v.user_id) @@ -1096,13 +1098,15 @@ class IamService: return _err("auth-failed", "owning user disabled") # Workspace-disabled check. - ws_row = await self.table_store.get_workspace(user.workspace) + ws_row = await self.table_store.get_workspace( + user.default_workspace, + ) if ws_row is None or not ws_row[2]: return _err("auth-failed", "owning workspace disabled") return IamResponse( resolved_user_id=user.id, - resolved_workspace=user.workspace, + resolved_default_workspace=user.default_workspace, resolved_roles=list(user.roles), ) @@ -1129,9 +1133,9 @@ class IamService: if ws is None or not ws[2]: return _err("not-found", "workspace not found or disabled") - # Uniqueness on username within workspace. + # Global username uniqueness. existing = await self.table_store.get_user_id_by_username( - v.workspace, v.user.username, + v.user.username, ) if existing: return _err("duplicate", "username already exists") @@ -1303,8 +1307,9 @@ class IamService: return False, AUTHZ_CACHE_TTL_SECONDS # user_row layout: - # 0:id 1:workspace 2:username 3:name 4:email 5:password_hash - # 6:roles 7:enabled 8:must_change_password 9:created + # 0:id 1:default_workspace 2:username 3:name 4:email + # 5:password_hash 6:roles 7:enabled 8:must_change_password + # 9:created if not user_row[7]: # disabled return False, AUTHZ_CACHE_TTL_SECONDS diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index c8ab9c36..57958bc0 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -10,6 +10,7 @@ import logging from .... exceptions import TooManyRequests, LlmError from .... base import LlmService, LlmResult, LlmChunk +from . variants import get_variant, DEFAULT_VARIANT, VARIANTS # Module logger logger = logging.getLogger(__name__) @@ -21,6 +22,7 @@ default_temperature = 0.0 default_max_output = 4096 default_api_key = os.getenv("OPENAI_TOKEN") default_base_url = os.getenv("OPENAI_BASE_URL") +default_thinking = "off" if default_base_url is None or default_base_url == "": default_base_url = "https://api.openai.com/v1" @@ -28,16 +30,21 @@ if default_base_url is None or default_base_url == "": class Processor(LlmService): def __init__(self, **params): - + model = params.get("model", default_model) api_key = params.get("api_key", default_api_key) base_url = params.get("url", default_base_url) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) + thinking = params.get("thinking", default_thinking) + variant_name = params.get("variant", DEFAULT_VARIANT) if not api_key: api_key = "not-set" + self.variant = get_variant(variant_name) + self.thinking = thinking + super(Processor, self).__init__( **params | { "model": model, @@ -56,13 +63,28 @@ class Processor(LlmService): else: self.openai = OpenAI(api_key=api_key) - logger.info("OpenAI LLM service initialized") + logger.info( + f"OpenAI LLM service initialized " + f"(variant={self.variant.name}, thinking={self.thinking})" + ) + + def _build_kwargs(self, model_name, temperature): + """Build API call kwargs using the active variant.""" + return self.variant.completion_kwargs( + max_output=self.max_output, + temperature=temperature, + thinking=self.thinking, + ) + + def _extract_content(self, message): + """Extract visible content from a response message.""" + if hasattr(self.variant, "extract_content"): + return self.variant.extract_content(message) + return message.content async def generate_content(self, system, prompt, model=None, temperature=None): - # Use provided model or fall back to default model_name = model or self.default_model - # Use provided temperature or fall back to default effective_temperature = temperature if temperature is not None else self.temperature logger.debug(f"Using model: {model_name}") @@ -72,31 +94,38 @@ class Processor(LlmService): try: - resp = self.openai.chat.completions.create( - model=model_name, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - } - ] - } - ], - temperature=effective_temperature, - max_completion_tokens=self.max_output, + api_kwargs = self._build_kwargs(model_name, effective_temperature) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ] + + resp = self.variant.create_completion( + self.openai, model_name, messages, **api_kwargs, ) - + inputtokens = resp.usage.prompt_tokens outputtokens = resp.usage.completion_tokens - logger.debug(f"LLM response: {resp.choices[0].message.content}") + + content = self._extract_content(resp.choices[0].message) + thinking = self.variant.extract_thinking(resp.choices[0].message) + + logger.debug(f"LLM response: {content}") + if thinking: + logger.debug(f"LLM thinking: {thinking[:200]}...") logger.info(f"Input Tokens: {inputtokens}") logger.info(f"Output Tokens: {outputtokens}") resp = LlmResult( - text = resp.choices[0].message.content, + text = content, in_token = inputtokens, out_token = outputtokens, model = model_name @@ -136,9 +165,7 @@ class Processor(LlmService): Stream content generation from OpenAI. Yields LlmChunk objects with is_final=True on the last chunk. """ - # Use provided model or fall back to default model_name = model or self.default_model - # Use provided temperature or fall back to default effective_temperature = temperature if temperature is not None else self.temperature logger.debug(f"Using model (streaming): {model_name}") @@ -147,30 +174,26 @@ class Processor(LlmService): prompt = system + "\n\n" + prompt try: - response = self.openai.chat.completions.create( - model=model_name, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - } - ] - } - ], - temperature=effective_temperature, - max_completion_tokens=self.max_output, - stream=True, - stream_options={"include_usage": True} - ) + api_kwargs = self._build_kwargs(model_name, effective_temperature) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ] total_input_tokens = 0 total_output_tokens = 0 - # Stream chunks - for chunk in response: + async for chunk in self.variant.create_completion_stream( + self.openai, model_name, messages, **api_kwargs, + ): if chunk.choices and chunk.choices[0].delta.content: yield LlmChunk( text=chunk.choices[0].delta.content, @@ -254,6 +277,20 @@ class Processor(LlmService): help=f'LLM max output tokens (default: {default_max_output})' ) + parser.add_argument( + '--thinking', + choices=["off", "low", "medium", "high"], + default=default_thinking, + help=f'Thinking/reasoning effort level (default: {default_thinking})' + ) + + parser.add_argument( + '--variant', + choices=sorted(VARIANTS.keys()), + default=DEFAULT_VARIANT, + help=f'API variant (default: {DEFAULT_VARIANT})' + ) + def run(): - + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py new file mode 100644 index 00000000..87de725d --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/variants.py @@ -0,0 +1,219 @@ +""" +OpenAI API variant profiles. + +Different providers expose OpenAI-compatible APIs with subtle differences +in parameter names, thinking/reasoning support, and temperature handling. +Each variant encapsulates those quirks so the processor doesn't need +provider-specific conditionals. +""" + +import re +import logging + +logger = logging.getLogger(__name__) + + +class Variant: + """Base variant — defines the interface all variants implement.""" + + name = None + token_param = "max_completion_tokens" + temperature_with_thinking = False + + def completion_kwargs(self, max_output, temperature, thinking): + """Build provider-specific kwargs for chat.completions.create(). + + Parameters + ---------- + max_output : int + Configured max output tokens. + temperature : float + Configured temperature. + thinking : str + Thinking effort level: "off", "low", "medium", "high". + + Returns + ------- + dict + Extra kwargs to spread into the API call. + """ + kwargs = {self.token_param: max_output} + + if thinking != "off": + kwargs.update(self.thinking_kwargs(thinking)) + if not self.temperature_with_thinking: + kwargs["temperature"] = 1.0 + else: + kwargs["temperature"] = temperature + else: + kwargs["temperature"] = temperature + + return kwargs + + def thinking_kwargs(self, effort): + """Return kwargs to enable thinking at the given effort level.""" + return {} + + def extract_thinking(self, message): + """Extract thinking/reasoning content from a response message.""" + return getattr(message, "reasoning_content", None) + + def extract_thinking_stream(self, delta): + """Extract thinking content from a streaming delta.""" + return getattr(delta, "reasoning_content", None) + + def create_completion(self, client, model, messages, **kwargs): + """Call the completions API. Override for non-standard SDKs.""" + return client.chat.completions.create( + model=model, messages=messages, **kwargs, + ) + + async def create_completion_stream(self, client, model, messages, **kwargs): + """Call the streaming completions API. Override for non-standard SDKs.""" + for chunk in client.chat.completions.create( + model=model, messages=messages, stream=True, + stream_options={"include_usage": True}, **kwargs, + ): + yield chunk + + +class OpenAIVariant(Variant): + """Standard OpenAI API (GPT-4o, o1, o3, etc.).""" + + name = "openai" + token_param = "max_completion_tokens" + temperature_with_thinking = False + + def thinking_kwargs(self, effort): + return {"reasoning_effort": effort} + + +class DeepSeekVariant(Variant): + """DeepSeek API (R1, V3, etc.).""" + + name = "deepseek" + token_param = "max_completion_tokens" + temperature_with_thinking = True + + def completion_kwargs(self, max_output, temperature, thinking): + enabled = "enabled" if thinking != "off" else "disabled" + kwargs = { + self.token_param: max_output, + "temperature": temperature, + "extra_body": { + "thinking": {"type": enabled}, + }, + } + return kwargs + + def thinking_kwargs(self, effort): + return {} + + +class DashScopeVariant(Variant): + """Alibaba Cloud DashScope API (Qwen models).""" + + name = "dashscope" + token_param = "max_completion_tokens" + temperature_with_thinking = True + + def completion_kwargs(self, max_output, temperature, thinking): + enabled = thinking != "off" + return { + self.token_param: max_output, + "temperature": temperature, + "extra_body": { + "enable_thinking": enabled, + }, + } + + def thinking_kwargs(self, effort): + return {} + + +class QwenVariant(DashScopeVariant): + """Qwen — alias for DashScope.""" + + name = "qwen" + + +class MistralVariant(Variant): + """Mistral API (Mistral Large, etc.).""" + + name = "mistral" + token_param = "max_tokens" + temperature_with_thinking = False + + def thinking_kwargs(self, effort): + return {"reasoning_effort": effort} + + +class GlmVariant(Variant): + """GLM / Zhipu AI API (GLM-4, GLM-4.7, etc.).""" + + name = "glm" + token_param = "max_tokens" + temperature_with_thinking = True + + def completion_kwargs(self, max_output, temperature, thinking): + enabled = "enabled" if thinking != "off" else "disabled" + kwargs = { + self.token_param: max_output, + "temperature": temperature, + "extra_body": { + "thinking": {"type": enabled}, + }, + } + return kwargs + + def thinking_kwargs(self, effort): + return {} + + +class LlamaVariant(Variant): + """Llama models via OpenAI-compatible servers (vLLM, Ollama, etc.). + + Thinking is typically always-on or always-off depending on the model. + When present, thinking appears inline as ... tags. + """ + + name = "llama" + token_param = "max_tokens" + temperature_with_thinking = True + + def thinking_kwargs(self, effort): + return {} + + def extract_thinking(self, message): + content = message.content or "" + match = re.search(r"(.*?)", content, re.DOTALL) + return match.group(1).strip() if match else None + + def extract_content(self, message): + """Strip think tags from visible content.""" + content = message.content or "" + return re.sub(r".*?", "", content, flags=re.DOTALL).strip() + + +VARIANTS = { + "openai": OpenAIVariant, + "deepseek": DeepSeekVariant, + "qwen": QwenVariant, + "mistral": MistralVariant, + "dashscope": DashScopeVariant, + "glm": GlmVariant, + "llama": LlamaVariant, +} + +DEFAULT_VARIANT = "openai" + + +def get_variant(name): + """Look up a variant by name, raising ValueError if unknown.""" + cls = VARIANTS.get(name) + if cls is None: + raise ValueError( + f"Unknown variant {name!r}. " + f"Available: {', '.join(sorted(VARIANTS))}" + ) + return cls() diff --git a/trustgraph-flow/trustgraph/reranker/__init__.py b/trustgraph-flow/trustgraph/reranker/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/__init__.py @@ -0,0 +1 @@ + diff --git a/trustgraph-flow/trustgraph/reranker/flashrank/__init__.py b/trustgraph-flow/trustgraph/reranker/flashrank/__init__.py new file mode 100644 index 00000000..bd3b0e96 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/flashrank/__init__.py @@ -0,0 +1,2 @@ + +from . processor import * diff --git a/trustgraph-flow/trustgraph/reranker/flashrank/__main__.py b/trustgraph-flow/trustgraph/reranker/flashrank/__main__.py new file mode 100644 index 00000000..1ebce4d4 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/flashrank/__main__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from . processor import run + +if __name__ == '__main__': + run() diff --git a/trustgraph-flow/trustgraph/reranker/flashrank/processor.py b/trustgraph-flow/trustgraph/reranker/flashrank/processor.py new file mode 100644 index 00000000..481d1a79 --- /dev/null +++ b/trustgraph-flow/trustgraph/reranker/flashrank/processor.py @@ -0,0 +1,109 @@ + +""" +Reranker service using flashrank. +Scores query-document pairs and returns the top results ranked by +relevance. +""" + +import asyncio +import logging + +from ... base import RerankerService +from ... schema import RerankerResult + +from flashrank import Ranker, RerankRequest + +logger = logging.getLogger(__name__) + +default_ident = "reranker" + +default_model = "ms-marco-MiniLM-L-12-v2" + +class Processor(RerankerService): + + def __init__(self, **params): + + model = params.get("model", default_model) + + super(Processor, self).__init__( + **params | { "model": model } + ) + + self.default_model = model + + self.cached_model_name = None + self.ranker = None + + self._load_model(model) + + def _load_model(self, model_name): + if self.cached_model_name != model_name: + logger.info(f"Loading flashrank model: {model_name}") + self.ranker = Ranker(model_name=model_name) + self.cached_model_name = model_name + logger.info(f"flashrank model {model_name} loaded successfully") + else: + logger.debug(f"Using cached model: {model_name}") + + def _run_rerank(self, query, passages): + request = RerankRequest(query=query, passages=passages) + return self.ranker.rerank(request) + + async def on_rerank(self, queries, documents, limit, model=None): + + if not queries or not documents: + return [] + + use_model = model or self.default_model + + if self.cached_model_name != use_model: + await asyncio.to_thread(self._load_model, use_model) + + passages = [ + {"id": d.document_id, "text": d.document_text} + for d in documents + ] + + best_scores = {} + + for q in queries: + ranked = await asyncio.to_thread( + self._run_rerank, q.query_text, passages, + ) + + for r in ranked: + doc_id = r["id"] + score = r["score"] + score = float(score) + if doc_id not in best_scores or score > best_scores[doc_id][1]: + best_scores[doc_id] = (q.query_id, score) + + results = sorted( + best_scores.items(), + key=lambda x: x[1][1], + reverse=True, + )[:limit] + + return [ + RerankerResult( + document_id=doc_id, + query_id=query_id, + score=score, + ) + for doc_id, (query_id, score) in results + ] + + @staticmethod + def add_args(parser): + + RerankerService.add_args(parser) + + parser.add_argument( + '-m', '--model', + default=default_model, + help=f'Reranker model (default: {default_model})' + ) + +def run(): + + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index ecfa7936..f2087912 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -9,31 +9,41 @@ from trustgraph.provenance import ( docrag_question_uri, docrag_grounding_uri, docrag_exploration_uri, + docrag_focus_uri, docrag_synthesis_uri, docrag_question_triples, grounding_triples, docrag_exploration_triples, + docrag_chunk_selection_triples, docrag_synthesis_triples, set_graph, GRAPH_RETRIEVAL, ) +from .rerank import RerankCandidate, mmr_select + # Module logger logger = logging.getLogger(__name__) +# When the caller does not specify a fetch_limit, reranking over-fetches this +# many times the final doc_limit as the candidate pool, so the cross-encoder can +# recover relevant chunks the bi-encoder ranked just outside the top doc_limit. +# This is only the fallback default: an explicit fetch_limit overrides it. +OVERFETCH_FACTOR = 3 + LABEL="http://www.w3.org/2000/01/rdf-schema#label" class Query: def __init__( self, rag, workspace, collection, verbose, - doc_limit=20, track_usage=None, + fetch_limit=20, track_usage=None, ): self.rag = rag self.workspace = workspace self.collection = collection self.verbose = verbose - self.doc_limit = doc_limit + self.fetch_limit = fetch_limit self.track_usage = track_usage async def extract_concepts(self, query): @@ -91,7 +101,7 @@ class Query: # Query chunk matches for each concept concurrently per_concept_limit = max( - 1, self.doc_limit // len(vectors) + 1, self.fetch_limit // len(vectors) ) async def query_concept(vec): @@ -140,7 +150,10 @@ class DocumentRag: def __init__( self, prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk, + reranker_client=None, verbose=False, + rerank_diversity_mode="none", + rerank_diversity_lambda=0.7, ): self.verbose = verbose @@ -150,12 +163,18 @@ class DocumentRag: self.doc_embeddings_client = doc_embeddings_client self.fetch_chunk = fetch_chunk + # Optional cross-encoder reranker. When None, the retrieval path is + # byte-identical to the pre-reranker behaviour. + self.reranker_client = reranker_client + self.rerank_diversity_mode = rerank_diversity_mode + self.rerank_diversity_lambda = rerank_diversity_lambda + if self.verbose: logger.debug("DocumentRag initialized") async def query( self, query, workspace="default", collection="default", - doc_limit=20, streaming=False, chunk_callback=None, + doc_limit=20, fetch_limit=0, streaming=False, chunk_callback=None, explain_callback=None, save_answer_callback=None, ): """ @@ -165,7 +184,10 @@ class DocumentRag: query: The query string workspace: Workspace for isolation (also scopes chunk lookup) collection: Collection identifier - doc_limit: Max chunks to retrieve + doc_limit: Chunks selected into the synthesis prompt (after rerank) + fetch_limit: Candidate pool fetched from the vector store before + reranking. 0 = derive (OVERFETCH_FACTOR x doc_limit when a + reranker is wired, else doc_limit). streaming: Enable streaming LLM response chunk_callback: async def callback(chunk, end_of_stream) for streaming explain_callback: async def callback(triples, explain_id) for explainability @@ -197,6 +219,7 @@ class DocumentRag: q_uri = docrag_question_uri(session_id) gnd_uri = docrag_grounding_uri(session_id) exp_uri = docrag_exploration_uri(session_id) + foc_uri = docrag_focus_uri(session_id) syn_uri = docrag_synthesis_uri(session_id) timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") @@ -209,10 +232,21 @@ class DocumentRag: ) await explain_callback(q_triples, q_uri) + # Resolve the candidate-pool size fetched from the vector store. When a + # reranker is wired, honour an explicit fetch_limit; if unset, fall back + # to the OVERFETCH_FACTOR heuristic. Never fetch fewer than doc_limit, + # else the rerank could not fill the prompt. Without a reranker, fetch + # doc_limit as before (byte-identical behaviour). + if self.reranker_client is not None: + fl = fetch_limit or (OVERFETCH_FACTOR * doc_limit) + fetch_count = max(fl, doc_limit) + else: + fetch_count = doc_limit + q = Query( rag=self, workspace=workspace, collection=collection, verbose=self.verbose, - doc_limit=doc_limit, track_usage=track_usage, + fetch_limit=fetch_count, track_usage=track_usage, ) # Extract concepts from query (grounding step) @@ -235,6 +269,7 @@ class DocumentRag: docs, chunk_ids = await q.get_docs(concepts) # Emit exploration explainability after chunks retrieved + # (full candidate set, before any reranking) if explain_callback: exp_triples = set_graph( docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids), @@ -242,6 +277,89 @@ class DocumentRag: ) await explain_callback(exp_triples, exp_uri) + # Optional cross-encoder reranking pass between retrieval and + # synthesis. Mirrors GraphRAG's reranker usage but with a single + # query (the question). When no reranker is wired, this block is + # skipped entirely and behaviour is byte-identical to before. + reranked = False + if self.reranker_client is not None and docs: + use_diversity = self.rerank_diversity_mode == "mmr" + + # Without diversity selection, preserve the existing #1011 + # behavior: ask the reranker for exactly doc_limit results. + # + # With diversity selection enabled, ask the reranker to score the + # full fetched candidate pool first, then let MMR choose the final + # doc_limit context set. + rerank_limit = len(docs) if use_diversity else doc_limit + + results = await self.reranker_client.rerank( + queries=[{"id": "0", "text": query}], + documents=[ + {"id": str(i), "text": d} for i, d in enumerate(docs) + ], + limit=rerank_limit, + ) + + source_docs = docs + source_chunk_ids = chunk_ids + + if use_diversity: + candidates = [ + RerankCandidate( + index=int(r.document_id), + chunk_id=source_chunk_ids[int(r.document_id)], + text=source_docs[int(r.document_id)], + reranker_score=r.score, + ) + for r in results + ] + + selected_candidates = mmr_select( + candidates, + limit=doc_limit, + lambda_mult=self.rerank_diversity_lambda, + ) + + docs = [candidate.text for candidate in selected_candidates] + chunk_ids = [ + candidate.chunk_id for candidate in selected_candidates + ] + + selected_chunks_with_scores = [ + { + "chunk_id": candidate.chunk_id, + "score": candidate.reranker_score, + } + for candidate in selected_candidates + ] + + else: + # results are sorted desc by score and truncated to limit by the + # reranker service, so order gives the surviving top-N directly. + order = [int(r.document_id) for r in results] + docs = [source_docs[i] for i in order] + chunk_ids = [source_chunk_ids[i] for i in order] + + selected_chunks_with_scores = [ + {"chunk_id": chunk_ids[i], "score": r.score} + for i, r in enumerate(results) + ] + + reranked = True + + # Emit chunk-selection (focus) explainability: surviving chunks + # with their cross-encoder scores, derived from exploration. + if explain_callback: + foc_triples = set_graph( + docrag_chunk_selection_triples( + foc_uri, exp_uri, + selected_chunks_with_scores, session_id, + ), + GRAPH_RETRIEVAL + ) + await explain_callback(foc_triples, foc_uri) + if self.verbose: logger.debug("Invoking LLM...") logger.debug(f"Documents: {docs}") @@ -291,9 +409,15 @@ class DocumentRag: logger.warning(f"Failed to save answer to librarian: {e}") synthesis_doc_id = None + # When reranking ran, synthesis derives from the focus (the + # reranked chunks actually fed to the LLM), as GraphRAG always does. + # When no reranker is wired, there is no focus stage, so synthesis + # derives from exploration (the unchanged no-op lineage) - a + # deliberate divergence from GraphRAG's always-on focus. + syn_parent = foc_uri if reranked else exp_uri syn_triples = set_graph( docrag_synthesis_triples( - syn_uri, exp_uri, + syn_uri, syn_parent, document_id=synthesis_doc_id, in_token=synthesis_result.in_token if synthesis_result else None, out_token=synthesis_result.out_token if synthesis_result else None, diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index c80f4172..80dfb6b1 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -13,6 +13,7 @@ from . document_rag import DocumentRag from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import DocumentEmbeddingsClientSpec +from ... base import RerankerClientSpec from ... base import LibrarianSpec # Module logger @@ -28,14 +29,27 @@ class Processor(FlowProcessor): doc_limit = params.get("doc_limit", 5) + # Instance-default candidate-pool size fetched before cross-encoder + # reranking; the rerank step narrows it back down to doc_limit for the + # LLM. 0 means the core derives it (OVERFETCH_FACTOR x doc_limit). + fetch_limit = params.get("fetch_limit", 0) + rerank_diversity_mode = params.get("rerank_diversity_mode", "none") + rerank_diversity_lambda = params.get("rerank_diversity_lambda", 0.7) + super(Processor, self).__init__( **params | { "id": id, "doc_limit": doc_limit, + "fetch_limit": fetch_limit, + "rerank_diversity_mode": rerank_diversity_mode, + "rerank_diversity_lambda": rerank_diversity_lambda, } ) self.doc_limit = doc_limit + self.fetch_limit = fetch_limit + self.rerank_diversity_mode = rerank_diversity_mode + self.rerank_diversity_lambda = rerank_diversity_lambda self.register_specification( ConsumerSpec( @@ -66,6 +80,13 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + RerankerClientSpec( + request_name = "reranker-request", + response_name = "reranker-response", + ) + ) + self.register_specification( ProducerSpec( name = "response", @@ -105,7 +126,10 @@ class Processor(FlowProcessor): doc_embeddings_client = flow("document-embeddings-request"), prompt_client = flow("prompt-request"), fetch_chunk = fetch_chunk, + reranker_client = flow("reranker-request"), verbose=True, + rerank_diversity_mode=self.rerank_diversity_mode, + rerank_diversity_lambda=self.rerank_diversity_lambda, ) if v.doc_limit: @@ -113,6 +137,13 @@ class Processor(FlowProcessor): else: doc_limit = self.doc_limit + # Candidate-pool size: per-request override, else the instance + # default; 0 lets the core derive it from doc_limit. + if v.fetch_limit: + fetch_limit = v.fetch_limit + else: + fetch_limit = self.fetch_limit + async def send_explainability(triples, explain_id): await flow("explainability").send(Triples( metadata=Metadata( @@ -163,6 +194,7 @@ class Processor(FlowProcessor): workspace=flow.workspace, collection=v.collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, streaming=True, chunk_callback=send_chunk, explain_callback=send_explainability, @@ -188,6 +220,7 @@ class Processor(FlowProcessor): workspace=flow.workspace, collection=v.collection, doc_limit=doc_limit, + fetch_limit=fetch_limit, explain_callback=send_explainability, save_answer_callback=save_answer, ) @@ -243,6 +276,29 @@ class Processor(FlowProcessor): help=f'Default document fetch limit (default: 10)' ) + parser.add_argument( + '--fetch-limit', + type=int, + default=0, + help='Candidate chunks to fetch from the vector store and rerank ' + 'before keeping the top doc-limit for the LLM ' + '(default: derive from doc-limit)' + ) + + parser.add_argument( + '--rerank-diversity-mode', + choices=['none', 'mmr'], + default='none', + help='Optional diversity-aware selection after reranking (default: none)' + ) + + parser.add_argument( + '--rerank-diversity-lambda', + type=float, + default=0.7, + help='MMR relevance/diversity tradeoff, higher values prefer relevance' + ) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py new file mode 100644 index 00000000..a0a7e8ee --- /dev/null +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rerank.py @@ -0,0 +1,142 @@ +import re +from dataclasses import dataclass, replace +from typing import List, Sequence, Set + + +@dataclass(frozen=True) +class RerankCandidate: + """ + Candidate chunk after cross-encoder reranking. + + reranker_score is the raw score returned by the reranker backend. It may + not be normalized, so MMR should use normalized_score instead. + """ + index: int + chunk_id: str + text: str + reranker_score: float + normalized_score: float = 0.0 + + +_TOKEN_RE = re.compile(r"[A-Za-z0-9_]+") + + +def _clamp01(value: float) -> float: + return max(0.0, min(1.0, value)) + + +def _token_set(text: str) -> Set[str]: + return set(token.lower() for token in _TOKEN_RE.findall(text or "")) + + +def _jaccard(a: str, b: str) -> float: + a_tokens = _token_set(a) + b_tokens = _token_set(b) + + if not a_tokens or not b_tokens: + return 0.0 + + return len(a_tokens & b_tokens) / len(a_tokens | b_tokens) + + +def normalize_candidate_scores( + candidates: Sequence[RerankCandidate], +) -> List[RerankCandidate]: + """ + Min-max normalize reranker scores within the current candidate set. + + Reranker backends may return different score scales: probabilities, + logits, or prompt-defined scores. MMR needs a stable [0, 1] relevance + signal, so normalize per candidate set instead of assuming a global range. + """ + if not candidates: + return [] + + scores = [float(candidate.reranker_score) for candidate in candidates] + min_score = min(scores) + max_score = max(scores) + + if max_score == min_score: + return [ + replace(candidate, normalized_score=0.5) + for candidate in candidates + ] + + score_range = max_score - min_score + + return [ + replace( + candidate, + normalized_score=(float(candidate.reranker_score) - min_score) / score_range, + ) + for candidate in candidates + ] + + +def _pair_diversity_penalty( + candidate: RerankCandidate, + selected: RerankCandidate, + token_overlap_weight: float, +) -> float: + """ + Pairwise diversity penalty between two candidate chunks. + + The first revision only uses token overlap because the current Document-RAG + reranker document_id is the candidate index, not a source document id. + """ + penalty = token_overlap_weight * _jaccard(candidate.text, selected.text) + return _clamp01(penalty) + + +def mmr_select( + candidates: Sequence[RerankCandidate], + limit: int, + lambda_mult: float = 0.7, + token_overlap_weight: float = 1.0, +) -> List[RerankCandidate]: + """ + Select a diverse final context set using MMR. + + Relevance comes from normalized cross-encoder reranker scores. + Diversity comes from token overlap against already selected chunks. + """ + if limit <= 0: + return [] + + lambda_mult = _clamp01(lambda_mult) + token_overlap_weight = max(0.0, token_overlap_weight) + + remaining = normalize_candidate_scores(candidates) + selected: List[RerankCandidate] = [] + + while remaining and len(selected) < limit: + best_idx = 0 + best_score = None + + for idx, candidate in enumerate(remaining): + relevance = candidate.normalized_score + + if selected: + diversity_penalty = max( + _pair_diversity_penalty( + candidate, + chosen, + token_overlap_weight=token_overlap_weight, + ) + for chosen in selected + ) + else: + diversity_penalty = 0.0 + + mmr_score = ( + lambda_mult * relevance + - (1.0 - lambda_mult) * diversity_penalty + ) + + if best_score is None or mmr_score > best_score: + best_score = mmr_score + best_idx = idx + + selected.append(remaining.pop(best_idx)) + + return selected diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 81dc8fe2..2054cb0f 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -120,7 +120,7 @@ class Query: def __init__( self, rag, collection, verbose, entity_limit=50, triple_limit=30, max_subgraph_size=1000, - max_path_length=2, track_usage=None, + max_path_length=2, edge_limit=25, track_usage=None, ): self.rag = rag self.collection = collection @@ -129,6 +129,7 @@ class Query: self.triple_limit = triple_limit self.max_subgraph_size = max_subgraph_size self.max_path_length = max_path_length + self.edge_limit = edge_limit self.track_usage = track_usage async def extract_concepts(self, query): @@ -217,12 +218,9 @@ class Query: logger.debug(f" {ent}") return entities, concepts - + async def maybe_label(self, e): - # The label cache lives on a per-request GraphRag instance — no - # cross-request isolation concern. The collection prefix keeps - # entries from different collections distinct within one request. cache_key = f"{self.collection}:{e}" cached_label = self.rag.label_cache.get(cache_key) @@ -243,206 +241,217 @@ class Query: self.rag.label_cache.put(cache_key, label) return label + FROM_S = "from_s" + FROM_P = "from_p" + FROM_O = "from_o" + async def execute_batch_triple_queries(self, entities, limit_per_entity): - """Execute triple queries for multiple entities concurrently using streaming""" + """Execute triple queries for multiple entities concurrently. + + Returns a list of (triple, direction) tuples where direction + indicates which position the frontier entity occupied. + """ tasks = [] + directions = [] for entity in entities: - # Create concurrent streaming tasks for all 3 query types per entity - tasks.extend([ + tasks.append( self.rag.triples_client.query_stream( s=entity, p=None, o=None, limit=limit_per_entity, collection=self.collection, batch_size=20, g="", ), + ) + directions.append(self.FROM_S) + + tasks.append( self.rag.triples_client.query_stream( s=None, p=entity, o=None, limit=limit_per_entity, collection=self.collection, batch_size=20, g="", ), + ) + directions.append(self.FROM_P) + + tasks.append( self.rag.triples_client.query_stream( s=None, p=None, o=entity, limit=limit_per_entity, collection=self.collection, batch_size=20, g="", - ) - ]) + ), + ) + directions.append(self.FROM_O) - # Execute all queries concurrently results = await asyncio.gather(*tasks, return_exceptions=True) - # Combine all results all_triples = [] - for result in results: + for direction, result in zip(directions, results): if not isinstance(result, Exception) and result is not None: - all_triples.extend(result) + all_triples.extend((triple, direction) for triple in result) return all_triples - async def follow_edges_batch(self, entities, max_depth): - """Optimized iterative graph traversal with batching. - - Returns: - tuple: (subgraph, term_map) where subgraph is a set of - (str, str, str) tuples and term_map maps each string tuple - to its original (Term, Term, Term) for type-preserving - provenance. - """ - visited = set() - current_level = set(entities) - subgraph = set() - term_map = {} # (str, str, str) -> (Term, Term, Term) - - for depth in range(max_depth): - if not current_level or len(subgraph) >= self.max_subgraph_size: - break - - # Filter out already visited entities - unvisited_entities = [e for e in current_level if e not in visited] - if not unvisited_entities: - break - - # Batch query all unvisited entities at current level - triples = await self.execute_batch_triple_queries( - unvisited_entities, self.triple_limit - ) - - # Process results and collect next level entities - next_level = set() - for triple in triples: - triple_tuple = (str(triple.s), str(triple.p), str(triple.o)) - subgraph.add(triple_tuple) - term_map[triple_tuple] = (to_term(triple.s), to_term(triple.p), to_term(triple.o)) - - # Collect entities for next level (only from s and o positions) - if depth < max_depth - 1: # Don't collect for final depth - s, p, o = triple_tuple - if s not in visited: - next_level.add(s) - if o not in visited: - next_level.add(o) - - # Stop if subgraph size limit reached - if len(subgraph) >= self.max_subgraph_size: - return subgraph, term_map - - # Update for next iteration - visited.update(current_level) - current_level = next_level - - return subgraph, term_map - - async def follow_edges(self, ent, subgraph, path_length): - """Legacy method - replaced by follow_edges_batch""" - # Maintain backward compatibility with early termination checks - if path_length <= 0: - return - - if len(subgraph) >= self.max_subgraph_size: - return - - # For backward compatibility, convert to new approach - batch_result, _ = await self.follow_edges_batch([ent], path_length) - subgraph.update(batch_result) - - async def get_subgraph(self, query): - """ - Get subgraph by extracting concepts, finding entities, and traversing. - - Returns: - tuple: (subgraph, term_map, entities, concepts) where subgraph is - a list of (s, p, o) string tuples, term_map maps each string - tuple to its original (Term, Term, Term), entities is the seed - entity list, and concepts is the extracted concept list. - """ - - entities, concepts = await self.get_entities(query) - - if self.verbose: - logger.debug("Getting subgraph...") - - # Use optimized batch traversal instead of sequential processing - subgraph, term_map = await self.follow_edges_batch(entities, self.max_path_length) - - return list(subgraph), term_map, entities, concepts - async def resolve_labels_batch(self, entities): - """Resolve labels for multiple entities in parallel""" - tasks = [] - for entity in entities: - tasks.append(self.maybe_label(entity)) - + """Resolve labels for multiple entities in parallel.""" + tasks = [self.maybe_label(entity) for entity in entities] return await asyncio.gather(*tasks, return_exceptions=True) - async def get_labelgraph(self, query): - """ - Get subgraph with labels resolved for display. + async def hop_and_filter(self, seed_entities, concepts): + """Iterative hop-and-filter graph traversal with cross-encoder. + + At each hop: + 1. Retrieve all edges one hop from the frontier. + 2. Resolve labels and represent each edge as "{p} {o}". + 3. Score edges against concepts using the cross-encoder. + 4. Select the top edges; their target nodes become the next + frontier. Returns: - tuple: (labeled_edges, uri_map, entities, concepts) where: - - labeled_edges: list of (label_s, label_p, label_o) tuples - - uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o) - - entities: list of seed entity URI strings - - concepts: list of concept strings extracted from query + tuple: (selected_edges, uri_map, edge_metadata) where: + - selected_edges: list of (label_s, label_p, label_o) + - uri_map: dict mapping edge_id -> (Term, Term, Term) + - edge_metadata: dict mapping edge_id -> {concept, score} """ - subgraph, term_map, entities, concepts = await self.get_subgraph(query) + all_selected_edges = [] + uri_map = {} + edge_metadata = {} + frontier = set(seed_entities) + visited_entities = set() + seen_edges = set() - # Filter out label triples - filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL] + for hop in range(self.max_path_length): + if not frontier: + break - # Collect all unique entities that need label resolution - entities_to_resolve = set() - for s, p, o in filtered_subgraph: - entities_to_resolve.update([s, p, o]) + unvisited = [e for e in frontier if e not in visited_entities] + if not unvisited: + break - # Batch resolve labels for all entities in parallel - entity_list = list(entities_to_resolve) - resolved_labels = await self.resolve_labels_batch(entity_list) + if self.verbose: + logger.debug( + f"Hop {hop + 1}: {len(unvisited)} frontier entities" + ) - # Create entity-to-label mapping - label_map = {} - for entity, label in zip(entity_list, resolved_labels): - if not isinstance(label, Exception): - label_map[entity] = label - else: - label_map[entity] = entity # Fallback to entity itself - - # Apply labels to subgraph and build URI mapping - labeled_edges = [] - uri_map = {} # Maps edge_id of labeled edge -> original Term triple - - for s, p, o in filtered_subgraph: - labeled_triple = ( - label_map.get(s, s), - label_map.get(p, p), - label_map.get(o, o) + # Retrieve edges one hop from frontier + triples = await self.execute_batch_triple_queries( + unvisited, self.triple_limit, ) - labeled_edges.append(labeled_triple) - # Map from labeled edge ID to original Terms (preserving types) - labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2]) - uri_map[labeled_eid] = term_map.get((s, p, o), (s, p, o)) + # Deduplicate and filter already-seen edges + hop_triples = [] + hop_term_map = {} + hop_directions = {} + for triple, direction in triples: + triple_tuple = (str(triple.s), str(triple.p), str(triple.o)) + if triple_tuple[1] == LABEL: + continue + if triple_tuple in seen_edges: + continue + seen_edges.add(triple_tuple) + hop_triples.append(triple_tuple) + hop_term_map[triple_tuple] = ( + to_term(triple.s), to_term(triple.p), to_term(triple.o), + ) + hop_directions[triple_tuple] = direction - labeled_edges = labeled_edges[0:self.max_subgraph_size] + if not hop_triples: + visited_entities.update(frontier) + break - if self.verbose: - logger.debug("Subgraph:") - for edge in labeled_edges: - logger.debug(f" {str(edge)}") + if self.verbose: + logger.debug( + f"Hop {hop + 1}: {len(hop_triples)} candidate edges" + ) - if self.verbose: - logger.debug("Done.") + # Resolve labels for all entities in hop edges + entities_to_resolve = set() + for s, p, o in hop_triples: + entities_to_resolve.update([s, p, o]) - return labeled_edges, uri_map, entities, concepts + entity_list = list(entities_to_resolve) + resolved = await self.resolve_labels_batch(entity_list) + + label_map = {} + for entity, label in zip(entity_list, resolved): + if not isinstance(label, Exception): + label_map[entity] = label + else: + label_map[entity] = entity + + # Build labeled edges and documents for cross-encoder. + # The reranker text highlights the NEW information relative + # to the traversal direction: arriving from S means p,o are + # new; from O means s,p are new; from P means s,o are new. + labeled_hop = [] + for s, p, o in hop_triples: + ls = label_map.get(s, s) + lp = label_map.get(p, p) + lo = label_map.get(o, o) + labeled_hop.append((ls, lp, lo)) + + documents = [] + for i, (triple_tuple, (ls, lp, lo)) in enumerate( + zip(hop_triples, labeled_hop) + ): + direction = hop_directions[triple_tuple] + if direction == self.FROM_S: + text = f"{lp} {lo}" + elif direction == self.FROM_O: + text = f"{ls} {lp}" + else: + text = f"{ls} {lo}" + documents.append({"id": str(i), "text": text}) + + queries = [ + {"id": str(i), "text": c} + for i, c in enumerate(concepts) + ] + + # Score with cross-encoder + results = await self.rag.reranker_client.rerank( + queries=queries, + documents=documents, + limit=self.edge_limit, + ) + + # Collect selected edges and metadata + next_frontier = set() + for r in results: + idx = int(r.document_id) + ls, lp, lo = labeled_hop[idx] + s, p, o = hop_triples[idx] + eid = edge_id(ls, lp, lo) + + all_selected_edges.append((ls, lp, lo)) + uri_map[eid] = hop_term_map[(s, p, o)] + edge_metadata[eid] = { + "concept": concepts[int(r.query_id)], + "score": r.score, + } + + # Target nodes become next frontier + next_frontier.add(s) + next_frontier.add(o) + + if self.verbose: + logger.debug( + f"Hop {hop + 1}: selected {len(results)} edges" + ) + + visited_entities.update(frontier) + frontier = next_frontier - visited_entities + + return all_selected_edges, uri_map, edge_metadata async def trace_source_documents(self, edge_uris): """ Trace selected edges back to their source documents via provenance. - Follows the chain: edge → subgraph (via tg:contains) → chunk → - page → document (via prov:wasDerivedFrom), all in urn:graph:source. + Follows the chain: edge -> subgraph (via tg:contains) -> chunk -> + page -> document (via prov:wasDerivedFrom), all in urn:graph:source. Args: edge_uris: List of (s, p, o) URI string tuples @@ -453,7 +462,6 @@ class Query: # Step 1: Find subgraphs containing these edges via tg:contains subgraph_tasks = [] for s, p, o in edge_uris: - # s, p, o may be Term objects (preserving types) or strings s_term = s if isinstance(s, Term) else Term(type=IRI, iri=s) p_term = p if isinstance(p, Term) else Term(type=IRI, iri=p) o_term = o if isinstance(o, Term) else Term(type=IRI, iri=o) @@ -487,12 +495,10 @@ class Query: return [] # Step 2: Walk prov:wasDerivedFrom chain to find documents - # Each level: query ?entity prov:wasDerivedFrom ?parent - # Stop when we find entities typed tg:Document current_uris = subgraph_uris doc_uris = set() - for depth in range(4): # Max depth: subgraph → chunk → page → doc + for depth in range(4): if not current_uris: break @@ -509,7 +515,6 @@ class Query: *derivation_tasks, return_exceptions=True ) - # URIs with no parent are root documents next_uris = set() for uri, result in zip(current_uris, derivation_results): if isinstance(result, Exception) or not result: @@ -524,7 +529,6 @@ class Query: return [] # Step 3: Get all document metadata properties - # Skip structural predicates that aren't useful context SKIP_PREDICATES = { PROV_WAS_DERIVED_FROM, "http://www.w3.org/1999/02/22-rdf-syntax-ns#type", @@ -565,7 +569,7 @@ class GraphRag: def __init__( self, prompt_client, embeddings_client, graph_embeddings_client, - triples_client, verbose=False, + triples_client, reranker_client, verbose=False, ): self.verbose = verbose @@ -574,9 +578,8 @@ class GraphRag: self.embeddings_client = embeddings_client self.graph_embeddings_client = graph_embeddings_client self.triples_client = triples_client + self.reranker_client = reranker_client - # Replace simple dict with LRU cache with TTL - # CRITICAL: This cache only lives for one request due to per-request instantiation self.label_cache = LRUCacheWithTTL(max_size=5000, ttl=300) if self.verbose: @@ -585,33 +588,12 @@ class GraphRag: async def query( self, query, collection = "default", entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000, - max_path_length = 2, edge_score_limit = 30, edge_limit = 25, + max_path_length = 2, edge_limit = 25, streaming = False, chunk_callback = None, explain_callback = None, save_answer_callback = None, parent_uri = "", ): - """ - Execute a GraphRAG query with real-time explainability tracking. - - Args: - query: The query string - collection: Collection identifier - entity_limit: Max entities to retrieve - triple_limit: Max triples per entity - max_subgraph_size: Max edges in subgraph - max_path_length: Max hops from seed entities - edge_score_limit: Max edges to pass to LLM scoring (semantic pre-filter) - edge_limit: Max edges after LLM scoring - streaming: Enable streaming LLM response - chunk_callback: async def callback(chunk, end_of_stream) for streaming - explain_callback: async def callback(triples, explain_id) for real-time explainability - save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian - - Returns: - tuple: (answer_text, usage) where usage is a dict with - in_token, out_token, model - """ # Accumulate token usage across all prompt calls total_in = 0 total_out = 0 @@ -638,7 +620,9 @@ class GraphRag: foc_uri = make_focus_uri(session_id) syn_uri = make_synthesis_uri(session_id) - timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + timestamp = datetime.now(timezone.utc).isoformat().replace( + "+00:00", "Z", + ) # Emit question explainability immediately if explain_callback: @@ -657,10 +641,12 @@ class GraphRag: triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, + edge_limit = edge_limit, track_usage = track_usage, ) - kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query) + # Step 1: Extract concepts and find seed entities + seed_entities, concepts = await q.get_entities(query) # Emit grounding explain after concept extraction if explain_callback: @@ -676,11 +662,16 @@ class GraphRag: ) await explain_callback(gnd_triples, gnd_uri) - # Emit exploration explain after graph retrieval completes + # Step 2: Iterative hop-and-filter with cross-encoder + selected_edges, uri_map, edge_metadata = await q.hop_and_filter( + seed_entities, concepts, + ) + + # Emit exploration explain if explain_callback: exp_triples = set_graph( exploration_triples( - exp_uri, gnd_uri, len(kg), + exp_uri, gnd_uri, len(selected_edges), entities=seed_entities, ), GRAPH_RETRIEVAL @@ -688,235 +679,63 @@ class GraphRag: await explain_callback(exp_triples, exp_uri) if self.verbose: - logger.debug("Invoking LLM...") - logger.debug(f"Knowledge graph: {kg}") - logger.debug(f"Query: {query}") - - # Semantic pre-filter: reduce edges before expensive LLM scoring - if edge_score_limit > 0 and len(kg) > edge_score_limit: - - if self.verbose: + logger.debug(f"Selected {len(selected_edges)} edges") + for s, p, o in selected_edges: + eid = edge_id(s, p, o) + meta = edge_metadata.get(eid, {}) logger.debug( - f"Semantic pre-filter: {len(kg)} edges > " - f"limit {edge_score_limit}, filtering..." + f" {meta.get('score', 0):.4f} " + f"[{meta.get('concept', '')}] " + f"{s} | {p} | {o}" ) - # Embed edge descriptions: "subject, predicate, object" - edge_descriptions = [ - f"{s}, {p}, {o}" for s, p, o in kg - ] - - # Embed concepts and edge descriptions concurrently - concept_embed_task = self.embeddings_client.embed(concepts) - edge_embed_task = self.embeddings_client.embed(edge_descriptions) - - concept_vectors, edge_vectors = await asyncio.gather( - concept_embed_task, edge_embed_task - ) - - # Score each edge by max cosine similarity to any concept - def cosine_similarity(a, b): - dot = sum(x * y for x, y in zip(a, b)) - norm_a = math.sqrt(sum(x * x for x in a)) - norm_b = math.sqrt(sum(x * x for x in b)) - if norm_a == 0 or norm_b == 0: - return 0.0 - return dot / (norm_a * norm_b) - - edge_scores = [] - for i, edge_vec in enumerate(edge_vectors): - max_sim = max( - cosine_similarity(edge_vec, cv) - for cv in concept_vectors - ) - edge_scores.append((max_sim, i)) - - # Sort by similarity descending and keep top edge_score_limit - edge_scores.sort(reverse=True) - keep_indices = set( - idx for _, idx in edge_scores[:edge_score_limit] - ) - - # Filter kg and rebuild uri_map - filtered_kg = [] - filtered_uri_map = {} - for i, (s, p, o) in enumerate(kg): - if i in keep_indices: - filtered_kg.append((s, p, o)) - eid = edge_id(s, p, o) - if eid in uri_map: - filtered_uri_map[eid] = uri_map[eid] - - if self.verbose: - logger.debug( - f"Semantic pre-filter kept {len(filtered_kg)} " - f"of {len(kg)} edges" - ) - - kg = filtered_kg - uri_map = filtered_uri_map - - # Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)} - # uri_map already maps edge_id -> (uri_s, uri_p, uri_o) - edge_map = {} - edges_with_ids = [] - for s, p, o in kg: - eid = edge_id(s, p, o) - edge_map[eid] = (s, p, o) - edges_with_ids.append({ - "id": eid, - "s": s, - "p": p, - "o": o - }) - - if self.verbose: - logger.debug(f"Built edge map with {len(edge_map)} edges") - - # Step 1a: Edge Scoring - LLM scores edges for relevance - scoring_result = await self.prompt_client.prompt( - "kg-edge-scoring", - variables={ - "query": query, - "knowledge": edges_with_ids - } - ) - track_usage(scoring_result) - - if self.verbose: - logger.debug(f"Edge scoring result: {scoring_result}") - - # Parse scoring response (jsonl) to get edge IDs with scores - scored_edges = [] - - for obj in scoring_result.objects or []: - if isinstance(obj, dict) and "id" in obj and "score" in obj: - try: - score = int(obj["score"]) - except (ValueError, TypeError): - score = 0 - scored_edges.append({"id": obj["id"], "score": score}) - - # Select top N edges by score - scored_edges.sort(key=lambda x: x["score"], reverse=True) - top_edges = scored_edges[:edge_limit] - selected_ids = {e["id"] for e in top_edges} - - if self.verbose: - logger.debug( - f"Scored {len(scored_edges)} edges, " - f"selected top {len(selected_ids)}" - ) - - # Filter to selected edges - selected_edges = [] - for eid in selected_ids: - if eid in edge_map: - selected_edges.append(edge_map[eid]) - - # Step 1b: Edge Reasoning + Document Tracing (concurrent) - selected_edges_with_ids = [ - {"id": eid, "s": s, "p": p, "o": o} - for eid in selected_ids - if eid in edge_map - for s, p, o in [edge_map[eid]] - ] - - # Collect selected edge URIs for document tracing + # Step 3: Document tracing selected_edge_uris = [ - uri_map[eid] - for eid in selected_ids - if eid in uri_map + uri_map[edge_id(s, p, o)] + for s, p, o in selected_edges + if edge_id(s, p, o) in uri_map ] - # Run reasoning and document tracing concurrently - async def _get_reasoning(): - result = await self.prompt_client.prompt( - "kg-edge-reasoning", - variables={ - "query": query, - "knowledge": selected_edges_with_ids - } - ) - track_usage(result) - return result - - reasoning_task = _get_reasoning() - doc_trace_task = q.trace_source_documents(selected_edge_uris) - - reasoning_result, source_documents = await asyncio.gather( - reasoning_task, doc_trace_task, return_exceptions=True + source_documents = await q.trace_source_documents( + selected_edge_uris, ) - # Handle exceptions from gather - if isinstance(reasoning_result, Exception): - logger.warning( - f"Edge reasoning failed: {reasoning_result}" - ) - reasoning_result = None if isinstance(source_documents, Exception): logger.warning( f"Document tracing failed: {source_documents}" ) source_documents = [] - - if self.verbose: - logger.debug(f"Edge reasoning result: {reasoning_result}") - - # Parse reasoning response (jsonl) and build explainability data - reasoning_map = {} - - if reasoning_result is not None: - for obj in reasoning_result.objects or []: - if isinstance(obj, dict) and "id" in obj: - reasoning_map[obj["id"]] = obj.get("reasoning", "") - + # Build focus explainability data with cross-encoder metadata selected_edges_with_reasoning = [] - for eid in selected_ids: + for s, p, o in selected_edges: + eid = edge_id(s, p, o) if eid in uri_map: uri_s, uri_p, uri_o = uri_map[eid] + meta = edge_metadata.get(eid, {}) selected_edges_with_reasoning.append({ "edge": (uri_s, uri_p, uri_o), - "reasoning": reasoning_map.get(eid, ""), + "concept": meta.get("concept", ""), + "score": meta.get("score", 0), }) - if self.verbose: - logger.debug(f"Filtered to {len(selected_edges)} edges") - - # Emit focus explain after edge selection completes + # Emit focus explain if explain_callback: - # Sum scoring + reasoning token usage for focus event - focus_in = 0 - focus_out = 0 - focus_model = None - for r in [scoring_result, reasoning_result]: - if r is not None: - if r.in_token is not None: - focus_in += r.in_token - if r.out_token is not None: - focus_out += r.out_token - if r.model is not None: - focus_model = r.model - foc_triples = set_graph( focus_triples( - foc_uri, exp_uri, selected_edges_with_reasoning, session_id, - in_token=focus_in or None, - out_token=focus_out or None, - model=focus_model, + foc_uri, exp_uri, + selected_edges_with_reasoning, session_id, ), GRAPH_RETRIEVAL ) await explain_callback(foc_triples, foc_uri) - # Step 2: Synthesis - LLM generates answer from selected edges only + # Step 4: Synthesis selected_edge_dicts = [ {"s": s, "p": p, "o": o} for s, p, o in selected_edges ] - # Add source document metadata as knowledge edges for s, p, o in source_documents: selected_edge_dicts.append({ "s": s, "p": p, "o": o, @@ -928,7 +747,6 @@ class GraphRag: } if streaming and chunk_callback: - # Accumulate chunks for answer storage while forwarding to callback accumulated_chunks = [] async def accumulating_callback(chunk, end_of_stream): @@ -942,7 +760,6 @@ class GraphRag: chunk_callback=accumulating_callback ) track_usage(synthesis_result) - # Combine all chunks into full response resp = "".join(accumulated_chunks) else: synthesis_result = await self.prompt_client.prompt( @@ -955,29 +772,42 @@ class GraphRag: if self.verbose: logger.debug("Query processing complete") - # Emit synthesis explain after synthesis completes + # Emit synthesis explain if explain_callback: synthesis_doc_id = None answer_text = resp if resp else "" - # Save answer to librarian if save_answer_callback and answer_text: synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}" try: await save_answer_callback(synthesis_doc_id, answer_text) if self.verbose: - logger.debug(f"Saved answer to librarian: {synthesis_doc_id}") + logger.debug( + f"Saved answer to librarian: " + f"{synthesis_doc_id}" + ) except Exception as e: - logger.warning(f"Failed to save answer to librarian: {e}") + logger.warning( + f"Failed to save answer to librarian: {e}" + ) synthesis_doc_id = None syn_triples = set_graph( synthesis_triples( syn_uri, foc_uri, document_id=synthesis_doc_id, - in_token=synthesis_result.in_token if synthesis_result else None, - out_token=synthesis_result.out_token if synthesis_result else None, - model=synthesis_result.model if synthesis_result else None, + in_token=( + synthesis_result.in_token + if synthesis_result else None + ), + out_token=( + synthesis_result.out_token + if synthesis_result else None + ), + model=( + synthesis_result.model + if synthesis_result else None + ), ), GRAPH_RETRIEVAL ) @@ -993,4 +823,3 @@ class GraphRag: } return resp, usage - diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 959ae8e0..27ec4937 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -13,6 +13,7 @@ from . graph_rag import GraphRag from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec +from ... base import RerankerClientSpec from ... base import LibrarianSpec # Module logger @@ -32,7 +33,6 @@ class Processor(FlowProcessor): triple_limit = params.get("triple_limit", 30) max_subgraph_size = params.get("max_subgraph_size", 150) max_path_length = params.get("max_path_length", 2) - edge_score_limit = params.get("edge_score_limit", 30) edge_limit = params.get("edge_limit", 25) super(Processor, self).__init__( @@ -43,7 +43,6 @@ class Processor(FlowProcessor): "triple_limit": triple_limit, "max_subgraph_size": max_subgraph_size, "max_path_length": max_path_length, - "edge_score_limit": edge_score_limit, "edge_limit": edge_limit, } ) @@ -52,7 +51,6 @@ class Processor(FlowProcessor): self.default_triple_limit = triple_limit self.default_max_subgraph_size = max_subgraph_size self.default_max_path_length = max_path_length - self.default_edge_score_limit = edge_score_limit self.default_edge_limit = edge_limit # Workspace isolation is enforced by the flow layer (flow.workspace). @@ -96,6 +94,13 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + RerankerClientSpec( + request_name = "reranker-request", + response_name = "reranker-response", + ) + ) + self.register_specification( ProducerSpec( name = "response", @@ -163,6 +168,7 @@ class Processor(FlowProcessor): graph_embeddings_client=flow("graph-embeddings-request"), triples_client=flow("triples-request"), prompt_client=flow("prompt-request"), + reranker_client=flow("reranker-request"), verbose=True, ) @@ -186,11 +192,6 @@ class Processor(FlowProcessor): else: max_path_length = self.default_max_path_length - if v.edge_score_limit: - edge_score_limit = v.edge_score_limit - else: - edge_score_limit = self.default_edge_score_limit - if v.edge_limit: edge_limit = v.edge_limit else: @@ -225,7 +226,7 @@ class Processor(FlowProcessor): entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, - edge_score_limit = edge_score_limit, + edge_limit = edge_limit, streaming = True, chunk_callback = send_chunk, @@ -241,7 +242,7 @@ class Processor(FlowProcessor): entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, - edge_score_limit = edge_score_limit, + edge_limit = edge_limit, explain_callback = send_explainability, save_answer_callback = save_answer, @@ -338,18 +339,11 @@ class Processor(FlowProcessor): help=f'Default max path length (default: 2)' ) - parser.add_argument( - '--edge-score-limit', - type=int, - default=30, - help=f'Semantic pre-filter limit before LLM scoring (default: 30)' - ) - parser.add_argument( '--edge-limit', type=int, default=25, - help=f'Max edges after LLM scoring (default: 25)' + help=f'Max edges selected per hop by cross-encoder (default: 25)' ) # Note: Explainability triples are now stored in the request's collection diff --git a/trustgraph-flow/trustgraph/tables/iam.py b/trustgraph-flow/trustgraph/tables/iam.py index b60e9cff..d06f9e35 100644 --- a/trustgraph-flow/trustgraph/tables/iam.py +++ b/trustgraph-flow/trustgraph/tables/iam.py @@ -94,10 +94,8 @@ class IamTableStore: self.cassandra.execute(""" CREATE TABLE IF NOT EXISTS iam_users_by_username ( - workspace text, - username text, - user_id text, - PRIMARY KEY ((workspace), username) + username text PRIMARY KEY, + user_id text ); """) @@ -175,16 +173,16 @@ class IamTableStore: """) self.put_username_lookup_stmt = c.prepare(""" - INSERT INTO iam_users_by_username (workspace, username, user_id) - VALUES (?, ?, ?) + INSERT INTO iam_users_by_username (username, user_id) + VALUES (?, ?) """) self.get_user_id_by_username_stmt = c.prepare(""" SELECT user_id FROM iam_users_by_username - WHERE workspace = ? AND username = ? + WHERE username = ? """) self.delete_username_lookup_stmt = c.prepare(""" DELETE FROM iam_users_by_username - WHERE workspace = ? AND username = ? + WHERE username = ? """) self.delete_user_stmt = c.prepare(""" DELETE FROM iam_users WHERE id = ? @@ -289,7 +287,7 @@ class IamTableStore: ) await async_execute( self.cassandra, self.put_username_lookup_stmt, - (workspace, username, id), + (username, id), ) async def get_user(self, id): @@ -298,10 +296,10 @@ class IamTableStore: ) return rows[0] if rows else None - async def get_user_id_by_username(self, workspace, username): + async def get_user_id_by_username(self, username): rows = await async_execute( self.cassandra, self.get_user_id_by_username_stmt, - (workspace, username), + (username,), ) return rows[0][0] if rows else None @@ -324,10 +322,10 @@ class IamTableStore: self.cassandra, self.delete_user_stmt, (id,), ) - async def delete_username_lookup(self, workspace, username): + async def delete_username_lookup(self, username): await async_execute( self.cassandra, self.delete_username_lookup_stmt, - (workspace, username), + (username,), ) # ------------------------------------------------------------------ diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index 7378db64..11b975b2 100755 --- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -8,71 +8,180 @@ import logging import json import uuid import argparse -from dataclasses import dataclass +from dataclasses import dataclass, field from collections.abc import AsyncIterator from functools import partial from mcp.server.fastmcp import FastMCP, Context -from mcp.types import TextContent -from websockets.asyncio.client import connect +from mcp.server.auth.provider import AccessToken, TokenVerifier +from mcp.server.auth.middleware.auth_context import get_access_token from trustgraph.base.logging import add_logging_args, setup_logging -from . tg_socket import WebSocketManager +from . tg_socket import WebSocketManager, _token_key + +logger = logging.getLogger(__name__) + + +# Wire-format Term type codes (match TermTranslator compact keys) +_TERM_TYPES = { + "iri": "i", + "literal": "l", + "blank": "b", +} + + +def _make_term(value: str, term_type: str) -> dict: + """Build a compact-key Term dict for the gateway wire format. + + Args: + value: The term value (IRI string, literal text, or blank node id). + term_type: One of "iri", "literal", "blank". + """ + t = _TERM_TYPES.get(term_type) + if t is None: + raise ValueError( + f"Unknown term type '{term_type}' — " + f"expected one of: {', '.join(_TERM_TYPES)}" + ) + + if t == "i": + return {"t": t, "i": value} + elif t == "l": + return {"t": t, "v": value} + elif t == "b": + return {"t": t, "d": value} + return {"t": t} + +# ── Security boundary: MCP client → MCP server ── +# The MCP client authenticates to this server via a Bearer token in the +# HTTP Authorization header. The SDK's auth middleware extracts and +# verifies the token before any tool handler runs. +# +# We implement a pass-through TokenVerifier: the gateway is the real +# authority, so we accept any non-empty Bearer token here and forward +# it to the gateway for validation. The gateway's in-band auth +# protocol and IAM regime decide whether the token is valid. +# +# This means an invalid token will connect to the MCP server but will +# fail when the first WebSocket auth frame is sent to the gateway. +# That is intentional — the gateway is the single source of truth. + + +class PassthroughTokenVerifier(TokenVerifier): + """Accept any non-empty Bearer token and forward it downstream. + + The TrustGraph gateway is the authority for token validation, not + this MCP server. We store the raw token in the AccessToken so that + tool handlers can retrieve it via ``get_access_token().token`` and + forward it to the gateway. + """ + + async def verify_token(self, token: str) -> AccessToken | None: + if not token: + return None + return AccessToken( + token=token, + client_id="mcp-caller", + scopes=[], + ) + @dataclass class AppContext: - sockets: dict[str, WebSocketManager] - websocket_url: str - gateway_token: str + sockets: dict[str, WebSocketManager] = field(default_factory=dict) + websocket_url: str = "" + @asynccontextmanager -async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]: +async def app_lifespan( + server: FastMCP, + websocket_url: str = "ws://api-gateway:8088/api/v1/socket", +) -> AsyncIterator[AppContext]: + """Manage per-server state: the pool of per-caller WebSocket + connections to the gateway.""" - """ - Manage application lifecycle with type-safe context - """ - - # Initialize on startup - sockets = {} + sockets: dict[str, WebSocketManager] = {} try: - yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token) + yield AppContext(sockets=sockets, websocket_url=websocket_url) finally: - # Cleanup on shutdown - logging.info("Shutting down context") + logger.info("Shutting down — closing %d WebSocket(s)", len(sockets)) - for k, manager in sockets.items(): - logging.info(f"Closing socket for {k}") - await manager.stop() + for key, manager in sockets.items(): + try: + await manager.stop() + except Exception as e: + logger.warning("Error closing socket %s: %s", key, e) - logging.info("Shutdown complete") + logger.info("Shutdown complete") -async def get_socket_manager(ctx): + +def _require_token() -> str: + """Extract the caller's Bearer token from the MCP auth context. + + Raises RuntimeError if no token is present (the caller did not + authenticate). + """ + # ── Security boundary: token extraction ── + # get_access_token() reads the contextvar set by the SDK's + # AuthContextMiddleware. The token was placed there by + # PassthroughTokenVerifier.verify_token() and is the raw Bearer + # value from the MCP client's Authorization header. + access = get_access_token() + if access is None or not access.token: + raise RuntimeError( + "Authentication required — send a Bearer token in the " + "Authorization header" + ) + return access.token + + +async def get_socket_manager(ctx, token): + """Return (or create) an authenticated WebSocket for this token. + + Each unique token gets its own WebSocket connection so that + gateway-side identity, workspace binding, and capability scoping + are preserved per caller. + """ lifespan_context = ctx.request_context.lifespan_context sockets = lifespan_context.sockets websocket_url = lifespan_context.websocket_url - gateway_token = lifespan_context.gateway_token - if "default" in sockets: - logging.info("Return existing socket manager") - return sockets["default"] + key = _token_key(token) - logging.info(f"Opening socket to {websocket_url}...") + if key in sockets: + manager = sockets[key] + if manager.socket is not None: + return manager + # Socket was closed (e.g. server-side timeout) — reconnect. + del sockets[key] - # Create manager with empty pending requests - manager = WebSocketManager(websocket_url, token=gateway_token) + logger.info("Opening authenticated WebSocket to %s …", websocket_url) - # Start reader task with the proper manager + manager = WebSocketManager(websocket_url, token=token) await manager.start() - sockets["default"] = manager + # Verify the token is valid by calling whoami. This confirms the + # gateway accepted the token and gives us the caller's identity. + try: + identity = await manager.whoami() + logger.info( + "WebSocket ready — caller: %s", + identity.get("handle", "unknown"), + ) + except Exception as e: + await manager.stop() + raise RuntimeError( + f"Token rejected by gateway (whoami failed): {e}" + ) from e - logging.info("Return new socket manager") + sockets[key] = manager return manager + @dataclass class EmbeddingsResponse: vectors: List[List[float]] @@ -182,10 +291,23 @@ class PutConfigResponse: class DeleteConfigResponse: pass +@dataclass +class SparqlQueryResponse: + query_type: str + variables: List[str] + bindings: List[Dict[str, Any]] + ask_result: bool + triples: List[Dict[str, Any]] + +@dataclass +class GraphQLQueryResponse: + data: Any + errors: List[Dict[str, Any]] + @dataclass class GetPromptsResponse: prompts: List[str] - + @dataclass class GetPromptResponse: prompt: Dict[str, Any] @@ -194,31 +316,61 @@ class GetPromptResponse: class GetSystemPromptResponse: prompt: str + class McpServer: - def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""): + def __init__( + self, + host: str = "0.0.0.0", + port: int = 8000, + websocket_url: str = "ws://api-gateway:8088/api/v1/socket", + auth_issuer: str = "", + auth_resource_url: str = "", + ): self.host = host self.port = port self.websocket_url = websocket_url - self.gateway_token = gateway_token - # Create a partial function to pass websocket_url to app_lifespan - lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token) - + lifespan_with_url = partial( + app_lifespan, websocket_url=websocket_url, + ) + + # ── Security: MCP-level auth configuration ── + # The SDK requires AuthSettings whenever a token_verifier is + # present. The issuer_url tells MCP clients where to obtain + # tokens; resource_server_url identifies this server in OAuth + # protected-resource metadata. + # + # The PassthroughTokenVerifier accepts any non-empty Bearer + # token — real validation happens at the gateway. This is + # intentional: the gateway is the single source of truth for + # identity and capability checks. + from mcp.server.auth.settings import AuthSettings + + auth_settings = AuthSettings( + issuer_url=auth_issuer or f"http://{host}:{port}", + resource_server_url=auth_resource_url or f"http://{host}:{port}", + ) + self.mcp = FastMCP( - "TrustGraph", dependencies=["trustgraph-base"], - host=self.host, port=self.port, + "TrustGraph", + dependencies=["trustgraph-base"], + host=self.host, + port=self.port, lifespan=lifespan_with_url, + token_verifier=PassthroughTokenVerifier(), + auth=auth_settings, ) self._register_tools() - + def _register_tools(self): """Register all MCP tools""" - # Register all the tools that were previously registered globally self.mcp.tool()(self.embeddings) self.mcp.tool()(self.text_completion) self.mcp.tool()(self.graph_rag) self.mcp.tool()(self.agent) self.mcp.tool()(self.triples_query) + self.mcp.tool()(self.sparql_query) + self.mcp.tool()(self.graphql_query) self.mcp.tool()(self.graph_embeddings_query) self.mcp.tool()(self.get_config_all) self.mcp.tool()(self.get_config) @@ -243,67 +395,69 @@ class McpServer: self.mcp.tool()(self.load_document) self.mcp.tool()(self.remove_document) self.mcp.tool()(self.add_processing) - + def run(self): """Run the MCP server""" self.mcp.run(transport="streamable-http") + async def _get_manager(self, ctx): + """Get an authenticated WebSocket manager for the current caller. + + Extracts the Bearer token from the MCP auth context and returns + a per-token WebSocket connection to the gateway. + """ + token = _require_token() + return await get_socket_manager(ctx, token) + async def embeddings( self, - text: str, + texts: List[str], flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> EmbeddingsResponse: """ - Generate vector embeddings for the given text using TrustGraph's embedding models. - + Generate vector embeddings for the given texts using TrustGraph's embedding models. + This tool converts text into high-dimensional vectors that capture semantic meaning, enabling similarity searches, clustering, and other vector-based operations. - + Args: - text: The input text to convert into embeddings. Can be a sentence, paragraph, - or document. The text will be processed by the configured embedding model. + texts: List of input texts to convert into embeddings. Each text can be a + sentence, paragraph, or document. flow_id: Optional flow identifier to use for processing (default: "default"). Different flows may use different embedding models or configurations. - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: - EmbeddingsResponse containing a list of vectors. Each vector is a list of floats - representing the text's semantic embedding in the model's vector space. - - Example usage: - - Convert a query into embeddings for similarity search - - Generate embeddings for documents before storing them - - Create embeddings for comparison with existing knowledge + EmbeddingsResponse containing a list of vectors, one per input text. """ - logging.info("Embeddings request made") + logger.info("Embeddings request") if flow_id is None: flow_id = "default" - manager = await get_socket_manager(ctx, "trustgraph") + manager = await self._get_manager(ctx) - if ctx is None: - raise RuntimeError("No context provided") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Computing embeddings via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - await ctx.session.send_log_message( - level="info", - data=f"Computing embeddings via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, + request_data = {"texts": texts} + + gen = manager.request( + "embeddings", request_data, flow_id, workspace=workspace, ) - # Send websocket request - request_data = {"text": text} - logging.info("making request") - - gen = manager.request("embeddings", request_data, flow_id) - async for response in gen: - - # Extract vectors from response vectors = response.get("vectors", [[]]) break - + return EmbeddingsResponse(vectors=vectors) async def text_completion( @@ -311,62 +465,47 @@ class McpServer: prompt: str, system: str | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> TextCompletionResponse: """ Generate text completions using TrustGraph's language models. - - This tool sends prompts to configured language models and returns generated text. - It supports both user prompts and system instructions for controlling generation. - + Args: prompt: The main prompt or question to send to the language model. - This is the primary input that guides the model's response. system: Optional system prompt that sets the context, role, or behavior - for the AI assistant (e.g., "You are a helpful coding assistant"). - System prompts influence how the model interprets and responds. - flow_id: Optional flow identifier (default: "default"). Different flows - may use different models, parameters, or processing pipelines. - + for the AI assistant. + flow_id: Optional flow identifier (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: TextCompletionResponse containing the generated text response from the model. - - Example usage: - - Ask questions and get AI-generated answers - - Generate code, documentation, or creative content - - Perform text analysis, summarization, or transformation tasks - - Use system prompts to control tone, style, or domain expertise """ if system is None: system = "" if flow_id is None: flow_id = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - # Use websocket if context is available - logging.info("Text completion request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Generating text completion via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Generating text completion via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Send websocket request request_data = {"system": system, "prompt": prompt} - gen = manager.request("text-completion", request_data, flow_id) + gen = manager.request( + "text-completion", request_data, flow_id, workspace=workspace, + ) async for response in gen: - - # Extract vectors from response text = response.get("response", "") break - + return TextCompletionResponse(response=text) async def graph_rag( @@ -378,58 +517,43 @@ class McpServer: max_subgraph_size: int | None = None, max_path_length: int | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> GraphRagResponse: """ Perform Graph-based Retrieval Augmented Generation (GraphRAG) queries. - + GraphRAG combines knowledge graph traversal with language model generation to provide - contextually rich answers. It explores relationships between entities to build relevant - context before generating responses. - + contextually rich answers. + Args: question: The question or query to answer using the knowledge graph. - The system will find relevant entities and relationships to inform the response. collection: Knowledge collection to query (default: "default"). - Different collections may contain domain-specific knowledge. entity_limit: Maximum number of entities to retrieve during graph traversal. - Higher limits provide more context but increase processing time. triple_limit: Maximum number of relationship triples to consider. - Controls the depth of relationship exploration. max_subgraph_size: Maximum size of the subgraph to extract for context. - Larger subgraphs provide richer context but use more resources. max_path_length: Maximum path length to traverse in the knowledge graph. - Longer paths can discover distant but relevant relationships. flow_id: Processing flow to use (default: "default"). - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: GraphRagResponse containing the generated answer informed by knowledge graph context. - - Example usage: - - Answer complex questions requiring multi-hop reasoning - - Explore relationships between entities in your knowledge base - - Generate responses grounded in structured knowledge - - Perform research queries across connected information """ if collection is None: collection = "default" if flow_id is None: flow_id = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("GraphRAG request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing GraphRAG query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Processing GraphRAG query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data with all parameters request_data = { "query": question } @@ -440,20 +564,19 @@ class McpServer: if max_subgraph_size: request_data["max_subgraph_size"] = max_subgraph_size if max_path_length: request_data["max_path_length"] = max_path_length - gen = manager.request("graph-rag", request_data, flow_id) + gen = manager.request( + "graph-rag", request_data, flow_id, workspace=workspace, + ) text_chunks = [] async for response in gen: - # Handle new message format with message_type message_type = response.get("message_type", "chunk") - # Only collect text from chunk messages if message_type == "chunk": chunk_text = response.get("response", "") if chunk_text: text_chunks.append(chunk_text) - # Check if session is complete if response.get("end_of_session"): break @@ -464,404 +587,447 @@ class McpServer: question: str, collection: str | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> AgentResponse: """ Execute intelligent agent queries with reasoning and tool usage capabilities. - - The agent can perform complex multi-step reasoning, use tools, and provide - detailed thought processes. It's designed for tasks requiring planning, - analysis, and iterative problem-solving. - + Args: - question: The question or task for the agent to solve. Can be complex - queries requiring multiple steps, analysis, or tool usage. + question: The question or task for the agent to solve. collection: Knowledge collection the agent can access (default: "default"). - Determines what information and tools are available. - flow_id: Agent workflow to use (default: "default"). Different flows - may have different capabilities, tools, or reasoning strategies. - + flow_id: Agent workflow to use (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: AgentResponse containing the final answer after the agent's reasoning process. - During execution, you'll see intermediate thoughts and observations. - - Example usage: - - Solve complex analytical problems requiring multiple steps - - Perform research tasks across multiple information sources - - Handle queries that need tool usage and decision-making - - Get detailed explanations of reasoning processes - - Note: This tool provides real-time updates on the agent's thinking process - through log messages, so you can follow its reasoning steps. """ if collection is None: collection = "default" if flow_id is None: flow_id = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Agent request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing agent query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Processing agent query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data with all parameters request_data = { "question": question } if collection: request_data["collection"] = collection - gen = manager.request("agent", request_data, flow_id) + gen = manager.request( + "agent", request_data, flow_id, workspace=workspace, + ) async for response in gen: - logging.debug(f"Agent response: {response}") + logger.debug("Agent response: %s", response) - if "thought" in response: - await ctx.session.send_log_message( - level="info", - data=f"Thinking: {response['thought']}", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + if "thought" in response: + await ctx.session.send_log_message( + level="info", + data=f"Thinking: {response['thought']}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - if "observation" in response: - await ctx.session.send_log_message( - level="info", - data=f"Observation: {response['observation']}", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if "observation" in response: + await ctx.session.send_log_message( + level="info", + data=f"Observation: {response['observation']}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - # Extract vectors from response if "answer" in response: answer = response.get("answer", "") return AgentResponse(answer=answer) async def triples_query( self, - s_v: str | None = None, - s_e: bool | None = None, - p_v: str | None = None, - p_e: bool | None = None, - o_v: str | None = None, - o_e: bool | None = None, + s: str | None = None, + s_type: str | None = None, + p: str | None = None, + p_type: str | None = None, + o: str | None = None, + o_type: str | None = None, + collection: str | None = None, + graph: str | None = None, limit: int | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> TriplesQueryResponse: """ Query knowledge graph triples using subject-predicate-object patterns. - - Knowledge graphs store information as triples (subject, predicate, object). - This tool allows flexible querying by specifying any combination of these - components, with wildcards for unspecified parts. - + + Each of s, p, o is an RDF term value. Use the corresponding _type + parameter to specify the term kind: + - "iri" (default for s and p): an IRI / entity reference + - "literal" (default for o): a plain literal value + - "blank": a blank node identifier + Args: - s_v: Subject value to match (e.g., "John", "Apple Inc."). Leave None for wildcard. - s_e: Whether subject should be treated as an entity (True) or literal (False). - p_v: Predicate/relationship value (e.g., "works_for", "type_of"). Leave None for wildcard. - p_e: Whether predicate should be treated as an entity (True) or literal (False). - o_v: Object value to match (e.g., "Engineer", "Company"). Leave None for wildcard. - o_e: Whether object should be treated as an entity (True) or literal (False). + s: Subject value to match. Leave None for wildcard. + s_type: Subject term type: "iri" (default), "literal", or "blank". + p: Predicate value to match. Leave None for wildcard. + p_type: Predicate term type: "iri" (default), "literal", or "blank". + o: Object value to match. Leave None for wildcard. + o_type: Object term type: "iri", "literal" (default), or "blank". + collection: Knowledge collection to query (default: "default"). + graph: Named graph IRI to restrict the query. None = default graph, + "*" = all graphs. limit: Maximum number of triples to return (default: 20). flow_id: Processing flow identifier (default: "default"). - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: TriplesQueryResponse containing matching triples from the knowledge graph. - - Example queries: - - Find all relationships for an entity: s_v="John", others None - - Find all instances of a relationship: p_v="works_for", others None - - Find specific facts: s_v="John", p_v="works_for", o_v=None - - Explore entity types: p_v="type_of", others None - - Use this for: - - Exploring knowledge graph structure - - Finding specific facts or relationships - - Discovering connections between entities - - Validating or debugging knowledge content """ if flow_id is None: flow_id = "default" if limit is None: limit = 20 + if collection is None: collection = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Triples query request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing triples query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Processing triples query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data with Value objects request_data = { - "limit": limit + "limit": limit, + "collection": collection, } - # Add subject if provided - if s_v is not None: - request_data["s"] = {"v": s_v, "e": s_e } + if s is not None: + request_data["s"] = _make_term(s, s_type or "iri") - # Add predicate if provided - if p_v is not None: - request_data["p"] = {"v": p_v, "e": p_e } + if p is not None: + request_data["p"] = _make_term(p, p_type or "iri") - # Add object if provided - if o_v is not None: - request_data["o"] = {"v": o_v, "e": o_e } + if o is not None: + request_data["o"] = _make_term(o, o_type or "literal") - gen = manager.request("triples", request_data, flow_id) + if graph is not None: + request_data["g"] = graph + + gen = manager.request( + "triples", request_data, flow_id, workspace=workspace, + ) async for response in gen: - # Extract response data triples = response.get("response", []) break - + return TriplesQueryResponse(triples=triples) + async def sparql_query( + self, + query: str, + collection: str | None = None, + limit: int | None = None, + flow_id: str | None = None, + workspace: str | None = None, + ctx: Context = None, + ) -> SparqlQueryResponse: + """ + Execute a SPARQL query against the knowledge graph. + + Supports SELECT, ASK, CONSTRUCT, and DESCRIBE query forms. + + Args: + query: SPARQL query string (e.g. "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10"). + collection: Knowledge collection to query (default: "default"). + limit: Safety limit on number of results (default: 10000). + flow_id: Processing flow identifier (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + + Returns: + SparqlQueryResponse containing the query results. The structure depends + on query type: + - SELECT: variables (column names) and bindings (rows of Term values) + - ASK: ask_result (boolean) + - CONSTRUCT/DESCRIBE: triples + """ + + if collection is None: collection = "default" + if flow_id is None: flow_id = "default" + if limit is None: limit = 10000 + + manager = await self._get_manager(ctx) + + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing SPARQL query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + + request_data = { + "query": query, + "collection": collection, + "limit": limit, + } + + gen = manager.request( + "sparql", request_data, flow_id, workspace=workspace, + ) + + async for response in gen: + query_type = response.get("query-type", "") + return SparqlQueryResponse( + query_type=query_type, + variables=response.get("variables", []), + bindings=response.get("bindings", []), + ask_result=response.get("ask-result", False), + triples=response.get("triples", []), + ) + + async def graphql_query( + self, + query: str, + collection: str | None = None, + variables: Dict[str, Any] | None = None, + operation_name: str | None = None, + flow_id: str | None = None, + workspace: str | None = None, + ctx: Context = None, + ) -> GraphQLQueryResponse: + """ + Execute a GraphQL query against structured data (rows). + + Queries structured data schemas that have been loaded into TrustGraph. + The available types and fields depend on the schemas configured in the + target workspace. + + Args: + query: GraphQL query string (e.g. '{ customers(where: {status: {eq: "active"}}) { id name } }'). + collection: Data collection to query (default: "default"). + variables: Optional GraphQL variables as a dict. + operation_name: Optional operation name for multi-operation documents. + flow_id: Processing flow identifier (default: "default"). + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + + Returns: + GraphQLQueryResponse containing data (the query result) and errors + (any GraphQL field-level errors). + """ + + if collection is None: collection = "default" + if flow_id is None: flow_id = "default" + + manager = await self._get_manager(ctx) + + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing GraphQL query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + + request_data = { + "query": query, + "collection": collection, + "variables": variables or {}, + } + + if operation_name is not None: + request_data["operation_name"] = operation_name + + gen = manager.request( + "rows", request_data, flow_id, workspace=workspace, + ) + + async for response in gen: + return GraphQLQueryResponse( + data=response.get("data"), + errors=response.get("errors", []), + ) + async def graph_embeddings_query( self, vectors: List[List[float]], limit: int | None = None, flow_id: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> GraphEmbeddingsQueryResponse: """ Find entities in the knowledge graph using vector similarity search. - - This tool performs semantic search by comparing embedding vectors to find - the most similar entities in the knowledge graph. It's useful for finding - conceptually related information even when exact text matches don't exist. - + Args: - vectors: List of embedding vectors to search with. Each vector should be - a list of floats representing semantic embeddings (typically from - the embeddings tool). Multiple vectors can be provided for batch queries. + vectors: List of embedding vectors to search with. limit: Maximum number of similar entities to return (default: 20). - Higher limits provide more results but may include less relevant matches. flow_id: Processing flow identifier (default: "default"). - + workspace: Optional workspace to query. If omitted, uses the caller's + default workspace. + Returns: - GraphEmbeddingsQueryResponse containing entities ranked by similarity to the - input vectors, along with similarity scores and entity metadata. - - Example workflow: - 1. Use the 'embeddings' tool to convert text to vectors - 2. Use this tool to find similar entities in the knowledge graph - 3. Explore the returned entities for relevant information - - Use this for: - - Semantic search across knowledge entities - - Finding conceptually similar content - - Discovering related entities without exact keyword matches - - Building recommendation systems based on entity similarity + GraphEmbeddingsQueryResponse containing entities ranked by similarity. """ if flow_id is None: flow_id = "default" if limit is None: limit = 20 - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Graph embeddings query request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Processing graph embeddings query via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Processing graph embeddings query via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # Build request data request_data = { "vectors": vectors, "limit": limit } - gen = manager.request("graph-embeddings", request_data, flow_id) + gen = manager.request( + "graph-embeddings", request_data, flow_id, workspace=workspace, + ) async for response in gen: - # Extract entities from response entities = response.get("entities", []) break - + return GraphEmbeddingsQueryResponse(entities=entities) async def get_config_all( self, + workspace: str | None = None, ctx: Context = None, ) -> ConfigResponse: """ Retrieve the complete TrustGraph system configuration. - - This tool returns all configuration settings for the TrustGraph system, - including model configurations, API keys, flow definitions, and system parameters. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - ConfigResponse containing the full configuration as a nested dictionary - with all system settings, organized by category (e.g., models, flows, storage). - - Use this for: - - Inspecting current system configuration - - Debugging configuration issues - - Understanding available models and settings - - Auditing system setup and parameters + ConfigResponse containing the full configuration as a nested dictionary. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get config all request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving all configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving all configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) break - + return ConfigResponse(config=config) async def get_config( self, keys: List[Dict[str, str]], + workspace: str | None = None, ctx: Context = None, ) -> ConfigGetResponse: """ Retrieve specific configuration values by key. - - This tool allows you to fetch specific configuration settings without - retrieving the entire configuration. Useful for checking particular - settings or API keys. - + Args: - keys: List of configuration keys to retrieve. Each key should be a dict with: - - 'type': Configuration category (e.g., 'llm', 'embeddings', 'storage') - - 'key': Specific setting name within that category - + keys: List of configuration keys to retrieve. Each key should be a dict with + 'type' and 'key' fields. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: ConfigGetResponse containing the requested configuration values. - - Example keys: - - {'type': 'llm', 'key': 'openai.model'} - - {'type': 'embeddings', 'key': 'default.model'} - - {'type': 'storage', 'key': 'database.url'} - - Use this for: - - Checking specific model configurations - - Validating API key settings - - Inspecting individual system parameters """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get config request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving specific configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving specific configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get", "keys": keys } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: values = response.get("values", []) break - + return ConfigGetResponse(values=values) async def put_config( self, values: List[Dict[str, str]], + workspace: str | None = None, ctx: Context = None, ) -> PutConfigResponse: """ Update system configuration values. - - This tool allows you to modify TrustGraph system settings, such as - model parameters, API keys, and system behavior configurations. - + Args: - values: List of configuration updates. Each update should be a dict with: - - 'type': Configuration category (e.g., 'llm', 'embeddings') - - 'key': Specific setting name to update - - 'value': New value for the setting - + values: List of configuration updates. Each should be a dict with + 'type', 'key', and 'value' fields. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: PutConfigResponse confirming the configuration update. - - Example updates: - - {'type': 'llm', 'key': 'openai.model', 'value': 'gpt-4'} - - {'type': 'embeddings', 'key': 'batch_size', 'value': '100'} - - Use this for: - - Switching between different models - - Updating API credentials - - Modifying system behavior parameters - - Configuring processing settings - - Note: Configuration changes may require system restart to take effect. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Put config request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Updating configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Updating configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "put", "values": values } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: return PutConfigResponse() @@ -869,97 +1035,73 @@ class McpServer: async def delete_config( self, keys: List[Dict[str, str]], + workspace: str | None = None, ctx: Context = None, ) -> DeleteConfigResponse: """ Delete specific configuration entries from the system. - - This tool removes configuration settings, reverting them to system defaults - or disabling specific features. - + Args: - keys: List of configuration keys to delete. Each key should be a dict with: - - 'type': Configuration category (e.g., 'llm', 'embeddings') - - 'key': Specific setting name to remove - + keys: List of configuration keys to delete. Each should be a dict with + 'type' and 'key' fields. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: DeleteConfigResponse confirming the deletion. - - Use this for: - - Removing custom model configurations - - Clearing API credentials - - Resetting settings to defaults - - Cleaning up obsolete configurations - - Warning: Deleting essential configuration may cause system functionality - to be disabled until properly reconfigured. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Delete config request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Deleting configuration via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Deleting configuration via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "delete", "keys": keys } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: return DeleteConfigResponse() async def get_prompts( self, + workspace: str | None = None, ctx: Context = None, ) -> GetPromptsResponse: """ List all available prompt templates in the system. - - Prompt templates are reusable prompts that can be used with language models - for consistent behavior across different queries and use cases. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: GetPromptsResponse containing a list of available prompt template IDs. - Each ID can be used with get_prompt to retrieve the full template. - - Use this for: - - Discovering available prompt templates - - Exploring pre-configured prompts for different tasks - - Finding templates for specific use cases - - Understanding what prompt options are available """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get prompts request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving prompt templates via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving prompt templates via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # First get all config request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) @@ -971,49 +1113,36 @@ class McpServer: async def get_prompt( self, prompt_id: str, + workspace: str | None = None, ctx: Context = None, ) -> GetPromptResponse: """ Retrieve a specific prompt template by ID. - - Prompt templates contain structured prompts with placeholders, instructions, - and metadata for specific tasks or domains. - + Args: prompt_id: The unique identifier of the prompt template to retrieve. - Use get_prompts to see available template IDs. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - GetPromptResponse containing the complete prompt template with its - structure, placeholders, and usage instructions. - - Use this for: - - Examining prompt template structure - - Understanding how to use specific templates - - Copying or modifying existing prompts - - Learning prompt engineering patterns + GetPromptResponse containing the complete prompt template. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get prompt request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving prompt template '{prompt_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving prompt template '{prompt_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # First get all config request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) @@ -1025,44 +1154,35 @@ class McpServer: async def get_system_prompt( self, + workspace: str | None = None, ctx: Context = None, ) -> GetSystemPromptResponse: """ Retrieve the current system prompt configuration. - - The system prompt defines the default behavior, personality, and instructions - for language models across the TrustGraph system. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - GetSystemPromptResponse containing the system prompt text and configuration. - - Use this for: - - Understanding default AI behavior settings - - Checking current system-wide prompt configuration - - Auditing AI personality and instruction settings - - Debugging unexpected AI responses + GetSystemPromptResponse containing the system prompt text. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get system prompt request made via websocket") + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving system prompt via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving system prompt via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - - # First get all config request_data = { "operation": "config" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: config = response.get("config", {}) @@ -1073,51 +1193,39 @@ class McpServer: async def get_token_costs( self, + workspace: str | None = None, ctx: Context = None, ) -> ConfigTokenCostsResponse: """ Retrieve token pricing information for all configured AI models. - - This tool provides cost information for input and output tokens across - different language models, helping with budget planning and cost optimization. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - ConfigTokenCostsResponse containing pricing data for each model including: - - Model name/identifier - - Input token cost (per token) - - Output token cost (per token) - - Use this for: - - Estimating costs for different models - - Choosing cost-effective models for tasks - - Budget planning and cost analysis - - Monitoring and optimizing AI spending + ConfigTokenCostsResponse containing pricing data for each model. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get token costs request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving token costs via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving token costs via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "getvalues", "type": "token-costs" } - gen = manager.request("config", request_data, None) + gen = manager.request("config", request_data, None, workspace=workspace) async for response in gen: values = response.get("values", []) - # Transform to match TypeScript API format costs = [] for item in values: try: @@ -1130,106 +1238,89 @@ class McpServer: except (json.JSONDecodeError, AttributeError): continue break - + return ConfigTokenCostsResponse(costs=costs) async def get_knowledge_cores( self, + workspace: str | None = None, ctx: Context = None, ) -> KnowledgeCoresResponse: """ List all available knowledge graph cores in the current workspace. - Knowledge cores are packaged collections of structured knowledge that can - be loaded into the system for querying and reasoning. They contain entities, - relationships, and facts organized as knowledge graphs. + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: KnowledgeCoresResponse containing a list of available knowledge core IDs. - - Use this for: - - Discovering available knowledge collections - - Understanding what knowledge domains are accessible - - Planning which cores to load for specific tasks - - Managing knowledge resources """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get knowledge cores request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving knowledge graph cores via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving knowledge graph cores via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-kg-cores", } - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: ids = response.get("ids", []) break - + return KnowledgeCoresResponse(ids=ids) async def delete_kg_core( self, core_id: str, + workspace: str | None = None, ctx: Context = None, ) -> DeleteKgCoreResponse: """ Permanently delete a knowledge graph core. - This operation removes a knowledge core from storage. Use with caution - as this action cannot be undone. - Args: core_id: Unique identifier of the knowledge core to delete. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: DeleteKgCoreResponse confirming the deletion. - - Use this for: - - Cleaning up obsolete knowledge cores - - Removing test or experimental data - - Managing storage space - - Maintaining organized knowledge collections - - Warning: This permanently deletes the knowledge core and all its data. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Delete KG core request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Deleting knowledge graph core '{core_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Deleting knowledge graph core '{core_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "delete-kg-core", "id": core_id, } - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: break - + return DeleteKgCoreResponse() async def load_kg_core( @@ -1237,46 +1328,34 @@ class McpServer: core_id: str, flow: str, collection: str | None = None, + workspace: str | None = None, ctx: Context = None, ) -> LoadKgCoreResponse: """ Load a knowledge graph core into the active system for querying. - This operation makes a knowledge core available for GraphRAG queries, - triple searches, and other knowledge-based operations. - Args: core_id: Unique identifier of the knowledge core to load. - flow: Processing flow to use for loading the core. Different flows - may apply different processing, indexing, or optimization steps. - collection: Target collection name (default: "default"). The loaded - knowledge will be available under this collection name. + flow: Processing flow to use for loading the core. + collection: Target collection name (default: "default"). + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: LoadKgCoreResponse confirming the core has been loaded. - - Use this for: - - Making knowledge cores available for queries - - Switching between different knowledge domains - - Loading domain-specific knowledge for tasks - - Preparing knowledge for GraphRAG operations """ if collection is None: collection = "default" - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Load KG core request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Loading knowledge graph core '{core_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Loading knowledge graph core '{core_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "load-kg-core", @@ -1285,292 +1364,241 @@ class McpServer: "collection": collection } - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: break - + return LoadKgCoreResponse() async def get_kg_core( self, core_id: str, + workspace: str | None = None, ctx: Context = None, ) -> GetKgCoreResponse: """ Download and retrieve the complete content of a knowledge graph core. - This tool streams the entire content of a knowledge core, returning all - entities, relationships, and metadata. Due to potentially large data sizes, - the content is streamed in chunks. - Args: core_id: Unique identifier of the knowledge core to retrieve. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: GetKgCoreResponse containing all chunks of the knowledge core data. - Each chunk contains part of the knowledge graph structure. - - Use this for: - - Examining knowledge core content and structure - - Debugging knowledge graph data - - Exporting knowledge for backup or analysis - - Understanding the scope and quality of knowledge - - Note: Large knowledge cores may take significant time to download. - Progress updates are provided through log messages during streaming. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get KG core request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving knowledge graph core '{core_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving knowledge graph core '{core_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get-kg-core", "id": core_id, } - # Collect all streaming responses chunks = [] - gen = manager.request("knowledge", request_data, None) + gen = manager.request( + "knowledge", request_data, None, workspace=workspace, + ) async for response in gen: - # Check for end of stream if response.get("eos", False): - await ctx.session.send_log_message( - level="info", - data=f"Completed streaming KG core data", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Completed streaming KG core data", + logger="notification_stream", + related_request_id=ctx.request_id, + ) break else: chunks.append(response) - await ctx.session.send_log_message( - level="info", - data=f"Received KG core chunk ({len(chunks)} chunks so far)", - logger="notification_stream", - related_request_id=ctx.request_id, - ) - + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Received KG core chunk ({len(chunks)} chunks so far)", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + return GetKgCoreResponse(chunks=chunks) async def get_flows( self, + workspace: str | None = None, ctx: Context = None, ) -> FlowsResponse: """ List all available processing flows in the system. - - Flows define processing pipelines for different types of operations - (e.g., document processing, knowledge extraction, query handling). - Each flow encapsulates a specific workflow with configured steps. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: FlowsResponse containing a list of available flow identifiers. - - Use this for: - - Discovering available processing workflows - - Understanding what processing options are available - - Choosing appropriate flows for specific tasks - - Planning workflow-based operations """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flows request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving available flows via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving available flows via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-flows" } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: flow_ids = response.get("flow-ids", []) break - + return FlowsResponse(flow_ids=flow_ids) async def get_flow( self, flow_id: str, + workspace: str | None = None, ctx: Context = None, ) -> FlowResponse: """ Retrieve the complete definition of a specific processing flow. - - This tool returns the detailed configuration, steps, and parameters - of a processing flow, showing how it processes data and what operations it performs. - + Args: flow_id: Unique identifier of the flow to retrieve. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - FlowResponse containing the complete flow definition including: - - Flow configuration and parameters - - Processing steps and their order - - Input/output specifications - - Dependencies and requirements - - Use this for: - - Understanding how specific flows work - - Debugging flow processing issues - - Learning flow configuration patterns - - Customizing or duplicating flows + FlowResponse containing the complete flow definition. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flow request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving flow definition for '{flow_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving flow definition for '{flow_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get-flow", "flow-id": flow_id, } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: flow_data = response.get("flow", "{}") - # Parse JSON flow definition as done in TypeScript flow = json.loads(flow_data) if isinstance(flow_data, str) else flow_data break - + return FlowResponse(flow=flow) async def get_flow_classes( self, + workspace: str | None = None, ctx: Context = None, ) -> FlowClassesResponse: """ List all available flow class templates. - - Flow classes are templates that define types of processing workflows. - They serve as blueprints for creating specific flow instances with - customized parameters. - + + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: FlowClassesResponse containing a list of available flow class names. - - Use this for: - - Discovering available flow templates - - Understanding what types of processing are supported - - Planning new flow creation - - Exploring system capabilities """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flow classes request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving flow classes via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving flow classes via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-classes" } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: class_names = response.get("class-names", []) break - + return FlowClassesResponse(class_names=class_names) async def get_flow_class( self, class_name: str, + workspace: str | None = None, ctx: Context = None, ) -> FlowClassResponse: """ Retrieve the definition of a specific flow class template. - - Flow classes define the structure, parameters, and capabilities of - flow types. This tool returns the class specification including - configurable parameters and processing logic. - + Args: class_name: Name of the flow class to retrieve. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: - FlowClassResponse containing the flow class definition with: - - Class parameters and configuration options - - Processing capabilities and requirements - - Usage instructions and examples - - Use this for: - - Understanding flow class capabilities - - Learning how to configure new flows - - Troubleshooting flow creation issues - - Exploring advanced flow features + FlowClassResponse containing the flow class definition. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get flow class request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving flow class definition for '{class_name}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Retrieving flow class definition for '{class_name}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "get-class", "class-name": class_name } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: class_def_data = response.get("class-definition", "{}") - # Parse JSON class definition as done in TypeScript class_definition = json.loads(class_def_data) if isinstance(class_def_data, str) else class_def_data break - + return FlowClassResponse(class_definition=class_definition) async def start_flow( @@ -1578,43 +1606,32 @@ class McpServer: flow_id: str, class_name: str, description: str, + workspace: str | None = None, ctx: Context = None, ) -> StartFlowResponse: """ Create and start a new processing flow instance. - - This tool creates a new flow based on a flow class template and starts - it running. The flow will begin processing according to its configuration. - + Args: flow_id: Unique identifier for the new flow instance. class_name: Flow class template to use for creating the flow. - Use get_flow_classes to see available classes. description: Human-readable description of the flow's purpose. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: StartFlowResponse confirming the flow has been started. - - Use this for: - - Creating new processing workflows - - Starting automated processing tasks - - Launching background operations - - Initiating data processing pipelines """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Start flow request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Starting flow '{flow_id}' with class '{class_name}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Starting flow '{flow_id}' with class '{class_name}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "start-flow", @@ -1623,162 +1640,135 @@ class McpServer: "description": description } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: break - + return StartFlowResponse() async def stop_flow( self, flow_id: str, + workspace: str | None = None, ctx: Context = None, ) -> StopFlowResponse: """ Stop a running flow instance. - - This tool gracefully stops a running flow, allowing it to complete - current operations before shutting down. - + Args: flow_id: Unique identifier of the flow instance to stop. - + workspace: Optional workspace. If omitted, uses the caller's + default workspace. + Returns: StopFlowResponse confirming the flow has been stopped. - - Use this for: - - Stopping unwanted or completed flows - - Managing system resources - - Interrupting long-running processes - - Maintaining flow lifecycle """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Stop flow request made via websocket") - - manager = await get_socket_manager(ctx, "trustgraph") - - await ctx.session.send_log_message( - level="info", - data=f"Stopping flow '{flow_id}' via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Stopping flow '{flow_id}' via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "stop-flow", "flow-id": flow_id } - gen = manager.request("flow", request_data, None) + gen = manager.request( + "flow", request_data, None, workspace=workspace, + ) async for response in gen: break - + return StopFlowResponse() async def get_documents( self, + workspace: str | None = None, ctx: Context = None, ) -> DocumentsResponse: """ List all documents stored in the TrustGraph document library. - This tool returns metadata for all documents that have been uploaded - to the system, including their processing status and properties. + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: - DocumentsResponse containing metadata for each document including: - - Document ID and title - - Upload timestamp - - MIME type and size information - - Tags and custom metadata - - Processing status - - Use this for: - - Browsing available documents - - Managing document collections - - Finding documents by metadata - - Auditing document storage + DocumentsResponse containing metadata for each document. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get documents request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving documents list via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving documents list via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-documents", } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: document_metadatas = response.get("document-metadatas", []) break - + return DocumentsResponse(document_metadatas=document_metadatas) async def get_processing( self, + workspace: str | None = None, ctx: Context = None, ) -> ProcessingResponse: """ List all documents currently in the processing queue. - This tool shows documents that are being processed or waiting to be - processed, along with their processing status and configuration. + Args: + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: - ProcessingResponse containing processing metadata including: - - Processing job ID and document ID - - Processing flow and status - - Target collection - - Timestamp and progress information - - Use this for: - - Monitoring document processing progress - - Debugging processing issues - - Managing processing queues - - Understanding system workload + ProcessingResponse containing processing metadata. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Get processing request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Retrieving processing list via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Retrieving processing list via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "list-processing", } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: processing_metadatas = response.get("processing-metadatas", []) break - + return ProcessingResponse(processing_metadatas=processing_metadatas) async def load_document( @@ -1790,50 +1780,39 @@ class McpServer: title: str = "", comments: str = "", tags: List[str] | None = None, + workspace: str | None = None, ctx: Context = None, ) -> LoadDocumentResponse: """ Upload a document to the TrustGraph document library. - This tool stores documents with rich metadata for later processing, - search, and knowledge extraction. Documents can be text files, PDFs, - or other supported formats. - Args: document: The document content as a string. For binary files, this should be base64-encoded content. document_id: Optional unique identifier. If not provided, one will be generated. metadata: Optional list of custom metadata key-value pairs. - mime_type: MIME type of the document (e.g., 'text/plain', 'application/pdf'). + mime_type: MIME type of the document. title: Human-readable title for the document. comments: Optional description or notes about the document. - tags: List of tags for categorizing and finding the document. + tags: List of tags for categorizing the document. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: LoadDocumentResponse confirming the document has been stored. - - Use this for: - - Adding new documents to the knowledge base - - Storing reference materials and data sources - - Building document collections for processing - - Importing external content for analysis """ if tags is None: tags = [] - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Load document request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Loading document to library via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data="Loading document to library via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) import time timestamp = int(time.time()) @@ -1852,63 +1831,55 @@ class McpServer: "content": document } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: break - + return LoadDocumentResponse() async def remove_document( self, document_id: str, + workspace: str | None = None, ctx: Context = None, ) -> RemoveDocumentResponse: """ Permanently remove a document from the library. - This operation deletes a document and all its associated metadata. - Use with caution as this action cannot be undone. - Args: document_id: Unique identifier of the document to remove. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: RemoveDocumentResponse confirming the document has been deleted. - - Use this for: - - Cleaning up obsolete or incorrect documents - - Managing storage space - - Removing sensitive or inappropriate content - - Maintaining organized document collections - - Warning: This permanently deletes the document and all its metadata. """ - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Remove document request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Removing document '{document_id}' from library via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Removing document '{document_id}' from library via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) request_data = { "operation": "remove-document", "document-id": document_id, } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: break - + return RemoveDocumentResponse() async def add_processing( @@ -1918,53 +1889,37 @@ class McpServer: flow: str, collection: str | None = None, tags: List[str] | None = None, + workspace: str | None = None, ctx: Context = None, ) -> AddProcessingResponse: """ Queue a document for processing through a specific workflow. - This tool adds a document to the processing queue where it will be - processed by the specified flow to extract knowledge, create embeddings, - or perform other analysis operations. - Args: processing_id: Unique identifier for this processing job. document_id: ID of the document to process (must exist in library). - flow: Processing flow to use. Different flows perform different - types of analysis (e.g., knowledge extraction, summarization). + flow: Processing flow to use. collection: Target collection for processed knowledge (default: "default"). - Results will be stored under this collection name. tags: Optional tags for categorizing this processing job. + workspace: Optional workspace. If omitted, uses the caller's + default workspace. Returns: AddProcessingResponse confirming the document has been queued. - - Use this for: - - Processing uploaded documents into knowledge - - Extracting entities and relationships from text - - Creating searchable embeddings - - Converting documents into structured knowledge - - Note: Processing may take time depending on document size and flow complexity. - Use get_processing to monitor progress. """ if collection is None: collection = "default" if tags is None: tags = [] - if ctx is None: - raise RuntimeError("No context provided") + manager = await self._get_manager(ctx) - logging.info("Add processing request made via websocket") - - manager = await get_socket_manager(ctx) - - await ctx.session.send_log_message( - level="info", - data=f"Adding document '{document_id}' to processing queue via websocket...", - logger="notification_stream", - related_request_id=ctx.request_id, - ) + if ctx: + await ctx.session.send_log_message( + level="info", + data=f"Adding document '{document_id}' to processing queue via websocket...", + logger="notification_stream", + related_request_id=ctx.request_id, + ) import time timestamp = int(time.time()) @@ -1981,38 +1936,61 @@ class McpServer: } } - gen = manager.request("librarian", request_data, None) + gen = manager.request( + "librarian", request_data, None, workspace=workspace, + ) async for response in gen: break - + return AddProcessingResponse() + def main(): parser = argparse.ArgumentParser(description='TrustGraph MCP Server') - parser.add_argument('--host', default='0.0.0.0', help='Host to bind to (default: 0.0.0.0)') - parser.add_argument('--port', type=int, default=8000, help='Port to bind to (default: 8000)') - parser.add_argument('--websocket-url', default='ws://api-gateway:8088/api/v1/socket', help='WebSocket URL to connect to (default: ws://api-gateway:8088/api/v1/socket)') + parser.add_argument( + '--host', default='0.0.0.0', + help='Host to bind to (default: 0.0.0.0)', + ) + parser.add_argument( + '--port', type=int, default=8000, + help='Port to bind to (default: 8000)', + ) + parser.add_argument( + '--websocket-url', + default='ws://api-gateway:8088/api/v1/socket', + help='WebSocket URL for the TrustGraph gateway', + ) + parser.add_argument( + '--auth-issuer', + default=os.environ.get("AUTH_ISSUER", ""), + help='OAuth issuer URL for MCP auth metadata discovery', + ) + parser.add_argument( + '--auth-resource-url', + default=os.environ.get("AUTH_RESOURCE_URL", ""), + help='Resource server URL for OAuth protected resource metadata', + ) - # Add logging arguments add_logging_args(parser) args = parser.parse_args() - # Setup logging before creating server setup_logging(vars(args)) - # Read gateway auth token from environment - gateway_token = os.environ.get("GATEWAY_SECRET", "") - - # Create and run the MCP server - server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token) + server = McpServer( + host=args.host, + port=args.port, + websocket_url=args.websocket_url, + auth_issuer=args.auth_issuer, + auth_resource_url=args.auth_resource_url, + ) server.run() + def run(): - """Legacy function for backward compatibility""" main() + if __name__ == "__main__": main() - diff --git a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py index bff8ae75..9fbf7459 100644 --- a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py +++ b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py @@ -1,49 +1,110 @@ -from dataclasses import dataclass from websockets.asyncio.client import connect -from urllib.parse import urlencode, urlparse, urlunparse, parse_qs import asyncio import logging import json import uuid -import time +import hashlib + +logger = logging.getLogger(__name__) + + +def _token_key(token): + """Derive a dict key from a token without storing the raw secret.""" + return hashlib.sha256(token.encode()).hexdigest()[:16] + class WebSocketManager: + """Manages an authenticated WebSocket connection to the TrustGraph + gateway on behalf of a single caller. - def __init__(self, url, token=None): + Each caller token gets its own WebSocketManager so that gateway-side + identity, workspace, and capability scoping are preserved end-to-end. + """ + + def __init__(self, url, token): self.url = url + # ── Security boundary: token storage ── + # This is the MCP caller's Bearer token, forwarded verbatim to + # the gateway. It MUST NOT be logged, persisted, or shared + # across callers. It is held only for the lifetime of this + # connection so that re-auth (e.g. after a reconnect) is + # possible. self.token = token self.socket = None - - # FIXME: authentication is broken. The /api/v1/socket endpoint uses - # in-band auth (first-frame protocol via the Mux dispatcher), not - # query-parameter tokens. This query-string token is silently ignored. - # Fix: after connect(), send an auth frame with the bearer token as - # the first message, matching the gateway's in-band auth protocol. - def _build_url(self): - if not self.token: - return self.url - parsed = urlparse(self.url) - params = parse_qs(parsed.query) - params["token"] = [self.token] - new_query = urlencode(params, doseq=True) - return urlunparse(parsed._replace(query=new_query)) + self.identity = None + self.last_used = None async def start(self): - self.socket = await connect(self._build_url()) + """Connect and authenticate via the gateway's in-band auth + protocol. Raises on auth failure.""" + + # ── Security boundary: MCP server → gateway ── + # The WebSocket connects to the gateway and authenticates using + # the caller's Bearer token via the in-band first-frame auth + # protocol. The token belongs to the MCP client — we forward + # it as-is and never interpret its contents. + self.socket = await connect(self.url) self.pending_requests = {} self.running = True + + await self._authenticate() + self.reader_task = asyncio.create_task(self.reader()) + async def _authenticate(self): + """Send in-band auth frame and wait for auth-ok / auth-failed. + + The gateway expects ``{"type": "auth", "token": "..."}`` as the + first frame on a new WebSocket. Any service frame sent before + auth-ok is rejected. + """ + await self.socket.send(json.dumps({ + "type": "auth", + "token": self.token, + })) + + response_text = await asyncio.wait_for(self.socket.recv(), 10) + response = json.loads(response_text) + + if response.get("type") == "auth-ok": + logger.info( + "WebSocket authenticated, default workspace: %s", + response.get("workspace"), + ) + return + + # Auth failed — close immediately, do not leave an + # unauthenticated socket open. + await self.socket.close() + self.socket = None + + if response.get("type") == "auth-failed": + raise RuntimeError( + "Gateway rejected the authentication token" + ) + + raise RuntimeError( + f"Unexpected auth response type: {response.get('type')}" + ) + + async def whoami(self): + """Verify the token by calling the gateway's whoami endpoint. + Returns the identity dict and caches it on ``self.identity``. + """ + gen = self.request("iam", {"operation": "whoami"}, flow_id=None) + async for response in gen: + self.identity = response + return response + async def stop(self): self.running = False - await self.reader_task + if hasattr(self, "reader_task"): + await self.reader_task async def reader(self): - """ - Background task to read websocket responses and route to correct - request - """ + """Background task: read WebSocket frames and route them to the + correct pending-request queue by ``id``.""" while self.running: try: @@ -59,23 +120,21 @@ class WebSocketManager: request_id = response.get("id") if request_id and request_id in self.pending_requests: - # Put the response in the queue queue = self.pending_requests[request_id] await queue.put(response) else: - logging.warning( - f"Response for unknown request ID: {request_id}" + logger.warning( + "Response for unknown request ID: %s", request_id ) except Exception as e: - logging.error(f"Error in websocket reader: {e}") + logger.error("Error in websocket reader: %s", e) - # Put error in all pending queues for queue in self.pending_requests.values(): try: await queue.put({"error": str(e)}) - except: + except Exception: pass self.pending_requests.clear() @@ -86,25 +145,29 @@ class WebSocketManager: async def request( self, service, request_data, flow_id="default", + workspace=None, ): - """ - Send a request via websocket and handle single or streaming responses + """Send a request via WebSocket and yield responses. + + Args: + service: Gateway service name (e.g. "graph-rag", "config"). + request_data: Inner request payload. + flow_id: Optional flow identifier. ``None`` omits the field + (workspace-level services don't use flows). + workspace: Optional workspace override. When ``None`` the + gateway uses the caller's default workspace. """ - # Generate unique request ID + import time + self.last_used = time.monotonic() + request_id = f"{uuid.uuid4()}" - # Determine if this service streams responses - streaming_services = {"agent"} - is_streaming = service in streaming_services - - # Create a queue for all responses (streaming and single) response_queue = asyncio.Queue() self.pending_requests[request_id] = response_queue try: - # Build request message message = { "id": request_id, "service": service, @@ -114,7 +177,16 @@ class WebSocketManager: if flow_id is not None: message["flow"] = flow_id - # Send request + # ── Security boundary: workspace scoping ── + # When the caller supplies a workspace, we set it on the + # message envelope. The gateway's enforce_workspace() + # validates that the authenticated identity is permitted + # to access the target workspace — we MUST NOT skip or + # override that check. When workspace is None, the + # gateway default-fills from the identity's bound workspace. + if workspace is not None: + message["workspace"] = workspace + await self.socket.send(json.dumps(message)) while self.running: @@ -127,19 +199,17 @@ class WebSocketManager: continue if "error" in response: - if "message" in response["error"]: - raise RuntimeError(response["error"]["text"]) + if isinstance(response["error"], dict): + raise RuntimeError( + response["error"].get("message", str(response["error"])) + ) else: raise RuntimeError(str(response["error"])) yield response["response"] - if "complete" in response: - if response["complete"]: - break + if response.get("complete"): + break - except Exception as e: - # Clean up on error + finally: self.pending_requests.pop(request_id, None) - raise e - diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml index fa9f7cd4..4b515032 100644 --- a/trustgraph-ocr/pyproject.toml +++ b/trustgraph-ocr/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.5,<2.6", + "trustgraph-base>=2.6,<2.7", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-unstructured/pyproject.toml b/trustgraph-unstructured/pyproject.toml index f17b9812..dc987fd9 100644 --- a/trustgraph-unstructured/pyproject.toml +++ b/trustgraph-unstructured/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.5,<2.6", + "trustgraph-base>=2.6,<2.7", "pulsar-client", "prometheus-client", "python-magic", diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml index 347594fe..50acce0d 100644 --- a/trustgraph-vertexai/pyproject.toml +++ b/trustgraph-vertexai/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.5,<2.6", + "trustgraph-base>=2.6,<2.7", "pulsar-client", "google-genai", "google-api-core", diff --git a/trustgraph/pyproject.toml b/trustgraph/pyproject.toml index bcc72a41..5746f7eb 100644 --- a/trustgraph/pyproject.toml +++ b/trustgraph/pyproject.toml @@ -10,13 +10,13 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.5,<2.6", - "trustgraph-bedrock>=2.5,<2.6", - "trustgraph-cli>=2.5,<2.6", - "trustgraph-embeddings-hf>=2.5,<2.6", - "trustgraph-flow>=2.5,<2.6", - "trustgraph-unstructured>=2.5,<2.6", - "trustgraph-vertexai>=2.5,<2.6", + "trustgraph-base>=2.6,<2.7", + "trustgraph-bedrock>=2.6,<2.7", + "trustgraph-cli>=2.6,<2.7", + "trustgraph-embeddings-hf>=2.6,<2.7", + "trustgraph-flow>=2.6,<2.7", + "trustgraph-unstructured>=2.6,<2.7", + "trustgraph-vertexai>=2.6,<2.7", ] classifiers = [ "Programming Language :: Python :: 3",