Merge branch 'release/v2.3'

This commit is contained in:
Cyber MacGeddon 2026-04-16 09:13:52 +01:00
commit 1f30a3bcea
155 changed files with 6526 additions and 1885 deletions

View file

@ -22,7 +22,7 @@ jobs:
uses: actions/checkout@v3
- name: Setup packages
run: make update-package-versions VERSION=2.2.999
run: make update-package-versions VERSION=2.3.999
- name: Setup environment
run: python3 -m venv env

View file

@ -40,10 +40,9 @@ jobs:
- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
deploy-container-image:
build-platform-image:
name: Release container images
runs-on: ubuntu-24.04
name: Build ${{ matrix.container }} (${{ matrix.platform }})
permissions:
contents: write
id-token: write
@ -52,14 +51,24 @@ jobs:
strategy:
matrix:
container:
- trustgraph-base
- trustgraph-flow
- trustgraph-bedrock
- trustgraph-vertexai
- trustgraph-hf
- trustgraph-ocr
- trustgraph-unstructured
- trustgraph-mcp
- base
- flow
- bedrock
- vertexai
- hf
- ocr
- unstructured
- mcp
platform:
- amd64
- arm64
include:
- platform: amd64
runner: ubuntu-24.04
- platform: arm64
runner: ubuntu-24.04-arm
runs-on: ${{ matrix.runner }}
steps:
@ -76,12 +85,48 @@ jobs:
id: version
run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT
- name: Put version into package manifests
run: make update-package-versions VERSION=${{ steps.version.outputs.VERSION }}
- name: Build container
run: make platform-${{ matrix.container }}-${{ matrix.platform }} VERSION=${{ steps.version.outputs.VERSION }}
- name: Build container - ${{ matrix.container }}
run: make container-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }}
- name: Push container
run: make push-platform-${{ matrix.container }}-${{ matrix.platform }} VERSION=${{ steps.version.outputs.VERSION }}
- name: Push container - ${{ matrix.container }}
run: make push-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }}
combine-manifests:
name: Combine manifest ${{ matrix.container }}
runs-on: ubuntu-24.04
needs: build-platform-image
permissions:
contents: write
id-token: write
environment:
name: release
strategy:
matrix:
container:
- base
- flow
- bedrock
- vertexai
- hf
- ocr
- unstructured
- mcp
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker Hub token
run: echo ${{ secrets.DOCKER_SECRET }} > docker-token.txt
- name: Authenticate with Docker hub
run: make docker-hub-login
- name: Get version
id: version
run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT
- name: Combine and push manifest
run: make combine-manifest-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }}

181
Makefile
View file

@ -52,51 +52,12 @@ update-package-versions:
echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py
echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py
FORCE:
containers: container-base container-flow \
container-bedrock container-vertexai \
container-hf container-ocr \
container-unstructured container-mcp
containers: FORCE
${DOCKER} build -f containers/Containerfile.base \
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
${DOCKER} build -f containers/Containerfile.bedrock \
-t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} .
${DOCKER} build -f containers/Containerfile.vertexai \
-t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
${DOCKER} build -f containers/Containerfile.hf \
-t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} .
${DOCKER} build -f containers/Containerfile.ocr \
-t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
${DOCKER} build -f containers/Containerfile.unstructured \
-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
${DOCKER} build -f containers/Containerfile.mcp \
-t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
some-containers:
${DOCKER} build -f containers/Containerfile.base \
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.unstructured \
# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.vertexai \
# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.mcp \
# -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.vertexai \
# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
# ${DOCKER} build -f containers/Containerfile.bedrock \
# -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} .
basic-containers: update-package-versions
${DOCKER} build -f containers/Containerfile.base \
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
${DOCKER} build -f containers/Containerfile.flow \
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
container.ocr:
${DOCKER} build -f containers/Containerfile.ocr \
-t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
some-containers: container-base container-flow
push:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION}
@ -109,54 +70,60 @@ push:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION}
# Individual container build targets
container-trustgraph-base: update-package-versions
${DOCKER} build -f containers/Containerfile.base -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
container-%: update-package-versions
${DOCKER} build \
-f containers/Containerfile.${@:container-%=%} \
-t ${CONTAINER_BASE}/trustgraph-${@:container-%=%}:${VERSION} .
container-trustgraph-flow: update-package-versions
${DOCKER} build -f containers/Containerfile.flow -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
# Multi-arch: build both platforms sequentially into one manifest (local use)
manifest-%: update-package-versions
-@${DOCKER} manifest rm \
${CONTAINER_BASE}/trustgraph-${@:manifest-%=%}:${VERSION}
${DOCKER} build --platform linux/amd64,linux/arm64 \
-f containers/Containerfile.${@:manifest-%=%} \
--manifest \
${CONTAINER_BASE}/trustgraph-${@:manifest-%=%}:${VERSION} .
container-trustgraph-bedrock: update-package-versions
${DOCKER} build -f containers/Containerfile.bedrock -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} .
# Multi-arch: build a single platform image (for parallel CI)
platform-%-amd64: update-package-versions
${DOCKER} build --platform linux/amd64 \
-f containers/Containerfile.${@:platform-%-amd64=%} \
-t ${CONTAINER_BASE}/trustgraph-${@:platform-%-amd64=%}:${VERSION}-amd64 .
container-trustgraph-vertexai: update-package-versions
${DOCKER} build -f containers/Containerfile.vertexai -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
platform-%-arm64: update-package-versions
${DOCKER} build --platform linux/arm64 \
-f containers/Containerfile.${@:platform-%-arm64=%} \
-t ${CONTAINER_BASE}/trustgraph-${@:platform-%-arm64=%}:${VERSION}-arm64 .
container-trustgraph-hf: update-package-versions
${DOCKER} build -f containers/Containerfile.hf -t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} .
# Push a single platform image
push-platform-%-amd64:
${DOCKER} push \
${CONTAINER_BASE}/trustgraph-${@:push-platform-%-amd64=%}:${VERSION}-amd64
container-trustgraph-ocr: update-package-versions
${DOCKER} build -f containers/Containerfile.ocr -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
push-platform-%-arm64:
${DOCKER} push \
${CONTAINER_BASE}/trustgraph-${@:push-platform-%-arm64=%}:${VERSION}-arm64
container-trustgraph-unstructured: update-package-versions
${DOCKER} build -f containers/Containerfile.unstructured -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
# Combine per-platform images into a multi-arch manifest
combine-manifest-%:
-@${DOCKER} manifest rm \
${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}
${DOCKER} manifest create \
${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION} \
docker://${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}-amd64 \
docker://${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}-arm64
${DOCKER} manifest push \
${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}
container-trustgraph-mcp: update-package-versions
${DOCKER} build -f containers/Containerfile.mcp -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
# Push a container
push-container-%:
${DOCKER} push \
${CONTAINER_BASE}/trustgraph-${@:push-container-%=%}:${VERSION}
# Individual container push targets
push-trustgraph-base:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION}
push-trustgraph-flow:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-flow:${VERSION}
push-trustgraph-bedrock:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION}
push-trustgraph-vertexai:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION}
push-trustgraph-hf:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-hf:${VERSION}
push-trustgraph-ocr:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-ocr:${VERSION}
push-trustgraph-unstructured:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION}
push-trustgraph-mcp:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION}
# Push a manifest (from local multi-arch build)
push-manifest-%:
${DOCKER} manifest push \
${CONTAINER_BASE}/trustgraph-${@:push-manifest-%=%}:${VERSION}
clean:
rm -rf wheels/
@ -164,52 +131,6 @@ clean:
set-version:
echo '"${VERSION}"' > templates/values/version.jsonnet
TEMPLATES=azure bedrock claude cohere mix llamafile mistral ollama openai vertexai \
openai-neo4j storage
DCS=$(foreach template,${TEMPLATES},${template:%=tg-launch-%.yaml})
MODELS=azure bedrock claude cohere llamafile mistral ollama openai vertexai
GRAPHS=cassandra neo4j falkordb memgraph
# tg-launch-%.yaml: templates/%.jsonnet templates/components/version.jsonnet
# jsonnet -Jtemplates \
# -S ${@:tg-launch-%.yaml=templates/%.jsonnet} > $@
# VECTORDB=milvus
VECTORDB=qdrant
JSONNET_FLAGS=-J templates -J .
# Temporarily going back to how templates were built in 0.9 because this
# is going away in 0.11.
update-templates: update-dcs
JSON_TO_YAML=python -c 'import sys, yaml, json; j=json.loads(sys.stdin.read()); print(yaml.safe_dump(j))'
update-dcs: set-version
for graph in ${GRAPHS}; do \
cm=$${graph},pulsar,${VECTORDB},grafana; \
input=templates/opts-to-docker-compose.jsonnet; \
output=tg-storage-$${graph}.yaml; \
echo $${graph} '->' $${output}; \
jsonnet ${JSONNET_FLAGS} \
--ext-str options=$${cm} $${input} | \
${JSON_TO_YAML} > $${output}; \
done
for model in ${MODELS}; do \
for graph in ${GRAPHS}; do \
cm=$${graph},pulsar,${VECTORDB},embeddings-hf,graph-rag,grafana,trustgraph,$${model}; \
input=templates/opts-to-docker-compose.jsonnet; \
output=tg-launch-$${model}-$${graph}.yaml; \
echo $${model} + $${graph} '->' $${output}; \
jsonnet ${JSONNET_FLAGS} \
--ext-str options=$${cm} $${input} | \
${JSON_TO_YAML} > $${output}; \
done; \
done
docker-hub-login:
cat docker-token.txt | \
${DOCKER} login -u trustgraph --password-stdin registry-1.docker.io

View file

@ -1,22 +1,22 @@
# ----------------------------------------------------------------------------
# Build an AI container. This does the torch install which is huge, and I
# like to avoid re-doing this.
# ----------------------------------------------------------------------------
# Torch is stable and compiles for ARM64 and AMD64 on Python 3.12
FROM docker.io/fedora:42 AS ai
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
RUN dnf install -y python3.12 && \
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \
dnf clean all
RUN pip3 install torch==2.5.1+cpu \
--index-url https://download.pytorch.org/whl/cpu
# This won't work on ARM
#RUN pip3 install torch==2.5.1+cpu \
# --index-url https://download.pytorch.org/whl/cpu
RUN pip3 install torch
RUN pip3 install --no-cache-dir \
langchain==0.3.25 langchain-core==0.3.60 langchain-huggingface==0.2.0 \

View file

@ -0,0 +1,117 @@
# proc-group — run TrustGraph as a single process
A dev-focused alternative to the per-container deployment. Instead of 30+
containers each running a single processor, `processor-group` runs all the
processors as asyncio tasks inside one Python process, sharing the event
loop, Prometheus registry, and (importantly) resources on your laptop.
This is **not** for production. Scale deployments should keep using
per-processor containers — one failure bringing down the whole process,
no horizontal scaling, and a single giant log are fine for dev and a
bad idea in prod.
## What this directory contains
- `group.yaml` — the group runner config. One entry per processor, each
with the dotted class path and a params dict. Defaults (pubsub backend,
rabbitmq host, log level) are pulled in per-entry with a YAML anchor.
- `README.md` — this file.
## Prerequisites
Install the TrustGraph packages into a venv:
```
pip install trustgraph-base trustgraph-flow trustgraph-unstructured
```
`trustgraph-base` provides the `processor-group` endpoint. The others
provide the processor classes that `group.yaml` imports at runtime.
`trustgraph-unstructured` is only needed if you want `document-decoder`
(the `universal-decoder` processor).
## Running it
Start infrastructure (cassandra, qdrant, rabbitmq, garage, observability
stack) with a working compose file. These aren't packable into the group -
they're third-party services. You may be able to run these as standalone
services.
To get Cassandra to be accessible from the host, you need to
set a couple of environment variables:
```
CASSANDRA_BROADCAST_ADDRESS: 127.0.0.1
CASSANDRA_LISTEN_ADDRESS: 127.0.0.1
```
and also set `network: host`. Then start services:
```
podman-compose up -d cassandra qdrant rabbitmq
podman-compose up -d garage garage-init
podman-compose up -d loki prometheus grafana
podman-compose up -d init-trustgraph
```
`init-trustgraph` is a one-shot that seeds config and the default flow
into cassandra/rabbitmq. Don't leave too long a delay between starting
`init-trustgraph` and running the processor-group, because it needs to
talk to the config service.
Run the api-gateway separately — it's an aiohttp HTTP server, not an
`AsyncProcessor`, so the group runner doesn't host it:
Raise the file descriptor limit — 30+ processors sharing one process
open far more sockets than the default 1024 allows:
```
ulimit -n 65536
```
Then start the group from a terminal:
```
processor-group -c group.yaml --no-loki-enabled
```
You'll see every processor's startup messages interleaved in one log.
Each processor has a supervisor that restarts it independently on
failure, so a transient crash (or a dependency that isn't ready yet)
only affects that one processor — siblings keep running and the failing
one self-heals on the next retry.
Finally when everything is running you can start the API gateway from
its own terminal:
```
api-gateway \
--pubsub-backend rabbitmq --rabbitmq-host localhost \
--loki-url http://localhost:3100/loki/api/v1/push \
--no-metrics
```
## When things go wrong
- **"Too many open files"** — raise `ulimit -n` further. 65536 is
usually plenty but some workflows need more.
- **One processor failing repeatedly** — look for its id in the log. The
supervisor will log each failure before restarting. Fix the cause
(missing env var, unreachable dependency, bad params) and the
processor self-heals on the next 4-second retry without restarting
the whole group.
- **Ctrl-C leaves the process hung** — the pika and cassandra drivers
spawn non-cooperative threads that asyncio can't cancel. Use Ctrl-\
(SIGQUIT) to force-kill. Not a bug in the group runner, just a
limitation of those libraries.
## Environment variables
Processors that talk to external LLMs or APIs read their credentials
from env vars, same as in the per-container deployment:
- `OPENAI_TOKEN`, `OPENAI_BASE_URL` — for `text-completion` /
`text-completion-rag`
Export whatever your particular `group.yaml` needs before running.

View file

@ -0,0 +1,257 @@
# Multi-processor group config, derived from docker-compose.yaml.
#
# Covers every AsyncProcessor-based service from the compose file.
# Out of scope:
# - api-gateway (aiohttp, not AsyncProcessor)
# - init-trustgraph (one-shot init, not a processor)
# - document-decoder (universal-decoder, trustgraph-unstructured package —
# packable but lives in a separate image/package)
# - mcp-server (trustgraph-mcp package, separate image)
# - ddg-mcp-server (third-party image)
# - infrastructure (cassandra, rabbitmq, qdrant, garage, grafana,
# prometheus, loki, workbench-ui)
#
# Run with:
# processor-group -c group.yaml
_defaults: &defaults
pubsub_backend: rabbitmq
rabbitmq_host: localhost
log_level: INFO
processors:
- class: trustgraph.agent.orchestrator.Processor
params:
<<: *defaults
id: agent-manager
- class: trustgraph.chunking.recursive.Processor
params:
<<: *defaults
id: chunker
chunk_size: 2000
chunk_overlap: 50
- class: trustgraph.config.service.Processor
params:
<<: *defaults
id: config-svc
cassandra_host: localhost
- class: trustgraph.decoding.universal.Processor
params:
<<: *defaults
id: document-decoder
- class: trustgraph.embeddings.document_embeddings.Processor
params:
<<: *defaults
id: document-embeddings
- class: trustgraph.retrieval.document_rag.Processor
params:
<<: *defaults
id: document-rag
doc_limit: 20
- class: trustgraph.embeddings.fastembed.Processor
params:
<<: *defaults
id: embeddings
concurrency: 1
- class: trustgraph.embeddings.graph_embeddings.Processor
params:
<<: *defaults
id: graph-embeddings
- class: trustgraph.retrieval.graph_rag.Processor
params:
<<: *defaults
id: graph-rag
concurrency: 1
entity_limit: 50
triple_limit: 30
edge_limit: 30
edge_score_limit: 10
max_subgraph_size: 100
max_path_length: 2
- class: trustgraph.extract.kg.agent.Processor
params:
<<: *defaults
id: kg-extract-agent
concurrency: 1
- class: trustgraph.extract.kg.definitions.Processor
params:
<<: *defaults
id: kg-extract-definitions
concurrency: 1
- class: trustgraph.extract.kg.ontology.Processor
params:
<<: *defaults
id: kg-extract-ontology
concurrency: 1
- class: trustgraph.extract.kg.relationships.Processor
params:
<<: *defaults
id: kg-extract-relationships
concurrency: 1
- class: trustgraph.extract.kg.rows.Processor
params:
<<: *defaults
id: kg-extract-rows
concurrency: 1
- class: trustgraph.cores.service.Processor
params:
<<: *defaults
id: knowledge
cassandra_host: localhost
- class: trustgraph.storage.knowledge.store.Processor
params:
<<: *defaults
id: kg-store
cassandra_host: localhost
- class: trustgraph.librarian.Processor
params:
<<: *defaults
id: librarian
cassandra_host: localhost
object_store_endpoint: localhost:3900
object_store_access_key: GK000000000000000000000001
object_store_secret_key: b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427
object_store_region: garage
- class: trustgraph.agent.mcp_tool.Service
params:
<<: *defaults
id: mcp-tool
- class: trustgraph.metering.Processor
params:
<<: *defaults
id: metering
- class: trustgraph.metering.Processor
params:
<<: *defaults
id: metering-rag
- class: trustgraph.retrieval.nlp_query.Processor
params:
<<: *defaults
id: nlp-query
- class: trustgraph.prompt.template.Processor
params:
<<: *defaults
id: prompt
concurrency: 1
- class: trustgraph.prompt.template.Processor
params:
<<: *defaults
id: prompt-rag
concurrency: 1
- class: trustgraph.query.doc_embeddings.qdrant.Processor
params:
<<: *defaults
id: doc-embeddings-query
store_uri: http://localhost:6333
- class: trustgraph.query.graph_embeddings.qdrant.Processor
params:
<<: *defaults
id: graph-embeddings-query
store_uri: http://localhost:6333
- class: trustgraph.query.row_embeddings.qdrant.Processor
params:
<<: *defaults
id: row-embeddings-query
store_uri: http://localhost:6333
- class: trustgraph.query.rows.cassandra.Processor
params:
<<: *defaults
id: rows-query
cassandra_host: localhost
- class: trustgraph.query.triples.cassandra.Processor
params:
<<: *defaults
id: triples-query
cassandra_host: localhost
- class: trustgraph.embeddings.row_embeddings.Processor
params:
<<: *defaults
id: row-embeddings
- class: trustgraph.query.sparql.Processor
params:
<<: *defaults
id: sparql-query
- class: trustgraph.storage.doc_embeddings.qdrant.Processor
params:
<<: *defaults
id: doc-embeddings-write
store_uri: http://localhost:6333
- class: trustgraph.storage.graph_embeddings.qdrant.Processor
params:
<<: *defaults
id: graph-embeddings-write
store_uri: http://localhost:6333
- class: trustgraph.storage.row_embeddings.qdrant.Processor
params:
<<: *defaults
id: row-embeddings-write
store_uri: http://localhost:6333
- class: trustgraph.storage.rows.cassandra.Processor
params:
<<: *defaults
id: rows-write
cassandra_host: localhost
- class: trustgraph.storage.triples.cassandra.Processor
params:
<<: *defaults
id: triples-write
cassandra_host: localhost
- class: trustgraph.retrieval.structured_diag.Processor
params:
<<: *defaults
id: structured-diag
- class: trustgraph.retrieval.structured_query.Processor
params:
<<: *defaults
id: structured-query
- class: trustgraph.model.text_completion.openai.Processor
params:
<<: *defaults
id: text-completion
max_output: 8192
temperature: 0.0
- class: trustgraph.model.text_completion.openai.Processor
params:
<<: *defaults
id: text-completion-rag
max_output: 8192
temperature: 0.0

View file

@ -0,0 +1,47 @@
# Control plane. Stateful "always on" services that every flow depends on.
# Cassandra-heavy, low traffic.
_defaults: &defaults
pubsub_backend: rabbitmq
rabbitmq_host: localhost
log_level: INFO
processors:
- class: trustgraph.config.service.Processor
params:
<<: *defaults
id: config-svc
cassandra_host: localhost
- class: trustgraph.librarian.Processor
params:
<<: *defaults
id: librarian
cassandra_host: localhost
object_store_endpoint: localhost:3900
object_store_access_key: GK000000000000000000000001
object_store_secret_key: b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427
object_store_region: garage
- class: trustgraph.cores.service.Processor
params:
<<: *defaults
id: knowledge
cassandra_host: localhost
- class: trustgraph.storage.knowledge.store.Processor
params:
<<: *defaults
id: kg-store
cassandra_host: localhost
- class: trustgraph.metering.Processor
params:
<<: *defaults
id: metering
- class: trustgraph.metering.Processor
params:
<<: *defaults
id: metering-rag

View file

@ -0,0 +1,45 @@
# Embeddings store. All Qdrant-backed vector query/write processors.
# One process owns the Qdrant driver pool for the whole group.
_defaults: &defaults
pubsub_backend: rabbitmq
rabbitmq_host: localhost
log_level: INFO
processors:
- class: trustgraph.query.doc_embeddings.qdrant.Processor
params:
<<: *defaults
id: doc-embeddings-query
store_uri: http://localhost:6333
- class: trustgraph.storage.doc_embeddings.qdrant.Processor
params:
<<: *defaults
id: doc-embeddings-write
store_uri: http://localhost:6333
- class: trustgraph.query.graph_embeddings.qdrant.Processor
params:
<<: *defaults
id: graph-embeddings-query
store_uri: http://localhost:6333
- class: trustgraph.storage.graph_embeddings.qdrant.Processor
params:
<<: *defaults
id: graph-embeddings-write
store_uri: http://localhost:6333
- class: trustgraph.query.row_embeddings.qdrant.Processor
params:
<<: *defaults
id: row-embeddings-query
store_uri: http://localhost:6333
- class: trustgraph.storage.row_embeddings.qdrant.Processor
params:
<<: *defaults
id: row-embeddings-write
store_uri: http://localhost:6333

View file

@ -0,0 +1,31 @@
# Embeddings. Memory-hungry — fastembed loads an ML model at startup.
# Keep isolated from other groups so its memory footprint and restart
# latency don't affect siblings.
_defaults: &defaults
pubsub_backend: rabbitmq
rabbitmq_host: localhost
log_level: INFO
processors:
- class: trustgraph.embeddings.fastembed.Processor
params:
<<: *defaults
id: embeddings
concurrency: 1
- class: trustgraph.embeddings.document_embeddings.Processor
params:
<<: *defaults
id: document-embeddings
- class: trustgraph.embeddings.graph_embeddings.Processor
params:
<<: *defaults
id: graph-embeddings
- class: trustgraph.embeddings.row_embeddings.Processor
params:
<<: *defaults
id: row-embeddings

View file

@ -0,0 +1,52 @@
# Ingest pipeline. Document-processing hot path. Bursty, correlated
# failures — if chunker dies the extractors have nothing to do anyway.
_defaults: &defaults
pubsub_backend: rabbitmq
rabbitmq_host: localhost
log_level: INFO
processors:
- class: trustgraph.chunking.recursive.Processor
params:
<<: *defaults
id: chunker
chunk_size: 2000
chunk_overlap: 50
- class: trustgraph.extract.kg.agent.Processor
params:
<<: *defaults
id: kg-extract-agent
concurrency: 1
- class: trustgraph.extract.kg.definitions.Processor
params:
<<: *defaults
id: kg-extract-definitions
concurrency: 1
- class: trustgraph.extract.kg.ontology.Processor
params:
<<: *defaults
id: kg-extract-ontology
concurrency: 1
- class: trustgraph.extract.kg.relationships.Processor
params:
<<: *defaults
id: kg-extract-relationships
concurrency: 1
- class: trustgraph.extract.kg.rows.Processor
params:
<<: *defaults
id: kg-extract-rows
concurrency: 1
- class: trustgraph.prompt.template.Processor
params:
<<: *defaults
id: prompt
concurrency: 1

View file

@ -0,0 +1,24 @@
# LLM. Outbound text-completion calls. Isolated because the upstream
# LLM API is often the bottleneck and the most likely thing to need
# restart (provider changes, model changes, API flakiness).
_defaults: &defaults
pubsub_backend: rabbitmq
rabbitmq_host: localhost
log_level: INFO
processors:
- class: trustgraph.model.text_completion.openai.Processor
params:
<<: *defaults
id: text-completion
max_output: 8192
temperature: 0.0
- class: trustgraph.model.text_completion.openai.Processor
params:
<<: *defaults
id: text-completion-rag
max_output: 8192
temperature: 0.0

View file

@ -0,0 +1,64 @@
# RAG / retrieval / agent. Query-time serving path. Drives outbound
# LLM calls via prompt-rag. sparql-query lives here because it's a
# read-side serving endpoint, not a backend writer.
_defaults: &defaults
pubsub_backend: rabbitmq
rabbitmq_host: localhost
log_level: INFO
processors:
- class: trustgraph.agent.orchestrator.Processor
params:
<<: *defaults
id: agent-manager
- class: trustgraph.retrieval.graph_rag.Processor
params:
<<: *defaults
id: graph-rag
concurrency: 1
entity_limit: 50
triple_limit: 30
edge_limit: 30
edge_score_limit: 10
max_subgraph_size: 100
max_path_length: 2
- class: trustgraph.retrieval.document_rag.Processor
params:
<<: *defaults
id: document-rag
doc_limit: 20
- class: trustgraph.retrieval.nlp_query.Processor
params:
<<: *defaults
id: nlp-query
- class: trustgraph.retrieval.structured_query.Processor
params:
<<: *defaults
id: structured-query
- class: trustgraph.retrieval.structured_diag.Processor
params:
<<: *defaults
id: structured-diag
- class: trustgraph.query.sparql.Processor
params:
<<: *defaults
id: sparql-query
- class: trustgraph.prompt.template.Processor
params:
<<: *defaults
id: prompt-rag
concurrency: 1
- class: trustgraph.agent.mcp_tool.Service
params:
<<: *defaults
id: mcp-tool

View file

@ -0,0 +1,20 @@
# Rows store. Cassandra-backed structured row query/write.
_defaults: &defaults
pubsub_backend: rabbitmq
rabbitmq_host: localhost
log_level: INFO
processors:
- class: trustgraph.query.rows.cassandra.Processor
params:
<<: *defaults
id: rows-query
cassandra_host: localhost
- class: trustgraph.storage.rows.cassandra.Processor
params:
<<: *defaults
id: rows-write
cassandra_host: localhost

View file

@ -0,0 +1,20 @@
# Triples store. Cassandra-backed RDF triple query/write.
_defaults: &defaults
pubsub_backend: rabbitmq
rabbitmq_host: localhost
log_level: INFO
processors:
- class: trustgraph.query.triples.cassandra.Processor
params:
<<: *defaults
id: triples-query
cassandra_host: localhost
- class: trustgraph.storage.triples.cassandra.Processor
params:
<<: *defaults
id: triples-write
cassandra_host: localhost

View file

@ -131,21 +131,21 @@ async def analyse(path, url, flow, user, collection):
for i, msg in enumerate(messages):
resp = msg.get("response", {})
chunk_type = resp.get("chunk_type", "?")
message_type = resp.get("message_type", "?")
if chunk_type == "explain":
if message_type == "explain":
explain_id = resp.get("explain_id", "")
explain_ids.append(explain_id)
print(f" {i:3d} {chunk_type} {explain_id}")
print(f" {i:3d} {message_type} {explain_id}")
else:
print(f" {i:3d} {chunk_type}")
print(f" {i:3d} {message_type}")
# Rule 7: message_id on content chunks
if chunk_type in ("thought", "observation", "answer"):
if message_type in ("thought", "observation", "answer"):
mid = resp.get("message_id", "")
if not mid:
errors.append(
f"[msg {i}] {chunk_type} chunk missing message_id"
f"[msg {i}] {message_type} chunk missing message_id"
)
print()

View file

@ -6,273 +6,212 @@ parent: "Tech Specs"
# Agent Explainability: Provenance Recording
## Status
Implemented
## Overview
Add provenance recording to the React agent loop so agent sessions can be traced and debugged using the same explainability infrastructure as GraphRAG.
Agent sessions are traced and debugged using the same explainability infrastructure as GraphRAG and Document RAG. Provenance is written to `urn:graph:retrieval` and delivered inline on the explain stream.
**Design Decisions:**
- Write to `urn:graph:retrieval` (generic explainability graph)
- Linear dependency chain for now (analysis N → wasDerivedFrom → analysis N-1)
- Tools are opaque black boxes (record input/output only)
- DAG support deferred to future iteration
The canonical vocabulary for all predicates and types is published as an OWL ontology at `specs/ontology/trustgraph.ttl`.
## Entity Types
Both GraphRAG and Agent use PROV-O as the base ontology with TrustGraph-specific subtypes:
All services use PROV-O as the base ontology with TrustGraph-specific subtypes.
### GraphRAG Types
| Entity | PROV-O Type | TG Types | Description |
|--------|-------------|----------|-------------|
| Question | `prov:Activity` | `tg:Question`, `tg:GraphRagQuestion` | The user's query |
| Exploration | `prov:Entity` | `tg:Exploration` | Edges retrieved from knowledge graph |
| Focus | `prov:Entity` | `tg:Focus` | Selected edges with reasoning |
| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer |
### Agent Types
| Entity | PROV-O Type | TG Types | Description |
|--------|-------------|----------|-------------|
| Question | `prov:Activity` | `tg:Question`, `tg:AgentQuestion` | The user's query |
| Analysis | `prov:Entity` | `tg:Analysis` | Each think/act/observe cycle |
| Conclusion | `prov:Entity` | `tg:Conclusion` | Final answer |
| Entity | TG Types | Description |
|--------|----------|-------------|
| Question | `tg:Question`, `tg:GraphRagQuestion` | The user's query |
| Grounding | `tg:Grounding` | Concept extraction from query |
| Exploration | `tg:Exploration` | Edges retrieved from knowledge graph |
| Focus | `tg:Focus` | Selected edges with reasoning |
| Synthesis | `tg:Synthesis`, `tg:Answer` | Final answer |
### Document RAG Types
| Entity | PROV-O Type | TG Types | Description |
|--------|-------------|----------|-------------|
| Question | `prov:Activity` | `tg:Question`, `tg:DocRagQuestion` | The user's query |
| Exploration | `prov:Entity` | `tg:Exploration` | Chunks retrieved from document store |
| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer |
| Entity | TG Types | Description |
|--------|----------|-------------|
| Question | `tg:Question`, `tg:DocRagQuestion` | The user's query |
| Grounding | `tg:Grounding` | Concept extraction from query |
| Exploration | `tg:Exploration` | Chunks retrieved from document store |
| Synthesis | `tg:Synthesis`, `tg:Answer` | Final answer |
**Note:** Document RAG uses a subset of GraphRAG's types (no Focus step since there's no edge selection/reasoning phase).
### Agent Types (React)
| Entity | TG Types | Description |
|--------|----------|-------------|
| Question | `tg:Question`, `tg:AgentQuestion` | The user's query (session start) |
| PatternDecision | `tg:PatternDecision` | Meta-router routing decision |
| Analysis | `tg:Analysis`, `tg:ToolUse` | One think/act cycle |
| Thought | `tg:Reflection`, `tg:Thought` | Agent reasoning (sub-entity of Analysis) |
| Observation | `tg:Reflection`, `tg:Observation` | Tool result (standalone entity) |
| Conclusion | `tg:Conclusion`, `tg:Answer` | Final answer |
### Agent Types (Orchestrator — Plan)
| Entity | TG Types | Description |
|--------|----------|-------------|
| Plan | `tg:Plan` | Structured plan of steps |
| StepResult | `tg:StepResult`, `tg:Answer` | Result from executing one plan step |
| Synthesis | `tg:Synthesis`, `tg:Answer` | Final synthesised answer |
### Agent Types (Orchestrator — Supervisor)
| Entity | TG Types | Description |
|--------|----------|-------------|
| Decomposition | `tg:Decomposition` | Question decomposed into sub-goals |
| Finding | `tg:Finding`, `tg:Answer` | Result from a sub-agent |
| Synthesis | `tg:Synthesis`, `tg:Answer` | Final synthesised answer |
### Mixin Types
| Type | Description |
|------|-------------|
| `tg:Answer` | Unifying type for terminal answers (Synthesis, Conclusion, Finding, StepResult) |
| `tg:Reflection` | Unifying type for intermediate commentary (Thought, Observation) |
| `tg:ToolUse` | Applied to Analysis when a tool is invoked |
| `tg:Error` | Applied to Observation events where a failure occurred (tool error or LLM parse error) |
### Question Subtypes
All Question entities share `tg:Question` as a base type but have a specific subtype to identify the retrieval mechanism:
| Subtype | URI Pattern | Mechanism |
|---------|-------------|-----------|
| `tg:GraphRagQuestion` | `urn:trustgraph:question:{uuid}` | Knowledge graph RAG |
| `tg:DocRagQuestion` | `urn:trustgraph:docrag:{uuid}` | Document/chunk RAG |
| `tg:AgentQuestion` | `urn:trustgraph:agent:{uuid}` | ReAct agent |
| `tg:AgentQuestion` | `urn:trustgraph:agent:session:{uuid}` | Agent orchestrator |
This allows querying all questions via `tg:Question` while filtering by specific mechanism via the subtype.
## Provenance Chains
## Provenance Model
All chains use `prov:wasDerivedFrom` links. Each entity is a `prov:Entity`.
### GraphRAG
```
Question (urn:trustgraph:agent:{uuid})
│ tg:query = "User's question"
│ prov:startedAtTime = timestamp
│ rdf:type = prov:Activity, tg:Question
↓ prov:wasDerivedFrom
Analysis1 (urn:trustgraph:agent:{uuid}/i1)
│ tg:thought = "I need to query the knowledge base..."
│ tg:action = "knowledge-query"
│ tg:arguments = {"question": "..."}
│ tg:observation = "Result from tool..."
│ rdf:type = prov:Entity, tg:Analysis
↓ prov:wasDerivedFrom
Analysis2 (urn:trustgraph:agent:{uuid}/i2)
│ ...
↓ prov:wasDerivedFrom
Conclusion (urn:trustgraph:agent:{uuid}/final)
│ tg:answer = "The final response..."
│ rdf:type = prov:Entity, tg:Conclusion
Question → Grounding → Exploration → Focus → Synthesis
```
### Document RAG Provenance Model
### Document RAG
```
Question (urn:trustgraph:docrag:{uuid})
│ tg:query = "User's question"
│ prov:startedAtTime = timestamp
│ rdf:type = prov:Activity, tg:Question
↓ prov:wasGeneratedBy
Exploration (urn:trustgraph:docrag:{uuid}/exploration)
│ tg:chunkCount = 5
│ tg:selectedChunk = "chunk-id-1"
│ tg:selectedChunk = "chunk-id-2"
│ ...
│ rdf:type = prov:Entity, tg:Exploration
↓ prov:wasDerivedFrom
Synthesis (urn:trustgraph:docrag:{uuid}/synthesis)
│ tg:content = "The synthesized answer..."
│ rdf:type = prov:Entity, tg:Synthesis
Question → Grounding → Exploration → Synthesis
```
## Changes Required
### Agent React
### 1. Schema Changes
**File:** `trustgraph-base/trustgraph/schema/services/agent.py`
Add `session_id` and `collection` fields to `AgentRequest`:
```python
@dataclass
class AgentRequest:
question: str = ""
state: str = ""
group: list[str] | None = None
history: list[AgentStep] = field(default_factory=list)
user: str = ""
collection: str = "default" # NEW: Collection for provenance traces
streaming: bool = False
session_id: str = "" # NEW: For provenance tracking across iterations
```
Question → PatternDecision → Analysis(1) → Observation(1) → Analysis(2) → ... → Conclusion
```
**File:** `trustgraph-base/trustgraph/messaging/translators/agent.py`
The PatternDecision entity records which execution pattern the meta-router selected. It is only emitted on the first iteration when routing occurs.
Update translator to handle `session_id` and `collection` in both `to_pulsar()` and `from_pulsar()`.
Thought sub-entities derive from their parent Analysis. Observation entities derive from their parent Analysis (or from a sub-trace entity if the tool produced its own explainability, e.g. a GraphRAG query).
### 2. Add Explainability Producer to Agent Service
### Agent Plan-then-Execute
**File:** `trustgraph-flow/trustgraph/agent/react/service.py`
Register an "explainability" producer (same pattern as GraphRAG):
```python
from ... base import ProducerSpec
from ... schema import Triples
# In __init__:
self.register_specification(
ProducerSpec(
name = "explainability",
schema = Triples,
)
)
```
Question → PatternDecision → Plan → StepResult(0) → StepResult(1) → ... → Synthesis
```
### 3. Provenance Triple Generation
### Agent Supervisor
**File:** `trustgraph-base/trustgraph/provenance/agent.py`
Create helper functions (similar to GraphRAG's `question_triples`, `exploration_triples`, etc.):
```python
def agent_session_triples(session_uri, query, timestamp):
"""Generate triples for agent Question."""
return [
Triple(s=session_uri, p=RDF_TYPE, o=PROV_ACTIVITY),
Triple(s=session_uri, p=RDF_TYPE, o=TG_QUESTION),
Triple(s=session_uri, p=TG_QUERY, o=query),
Triple(s=session_uri, p=PROV_STARTED_AT_TIME, o=timestamp),
]
def agent_iteration_triples(iteration_uri, parent_uri, thought, action, arguments, observation):
"""Generate triples for one Analysis step."""
return [
Triple(s=iteration_uri, p=RDF_TYPE, o=PROV_ENTITY),
Triple(s=iteration_uri, p=RDF_TYPE, o=TG_ANALYSIS),
Triple(s=iteration_uri, p=TG_THOUGHT, o=thought),
Triple(s=iteration_uri, p=TG_ACTION, o=action),
Triple(s=iteration_uri, p=TG_ARGUMENTS, o=json.dumps(arguments)),
Triple(s=iteration_uri, p=TG_OBSERVATION, o=observation),
Triple(s=iteration_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri),
]
def agent_final_triples(final_uri, parent_uri, answer):
"""Generate triples for Conclusion."""
return [
Triple(s=final_uri, p=RDF_TYPE, o=PROV_ENTITY),
Triple(s=final_uri, p=RDF_TYPE, o=TG_CONCLUSION),
Triple(s=final_uri, p=TG_ANSWER, o=answer),
Triple(s=final_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri),
]
```
Question → PatternDecision → Decomposition → [fan-out sub-agents]
→ Finding(0) → Finding(1) → ... → Synthesis
```
### 4. Type Definitions
Each sub-agent runs its own session with `wasDerivedFrom` linking back to the parent's Decomposition. Findings derive from their sub-agent's Conclusion.
**File:** `trustgraph-base/trustgraph/provenance/namespaces.py`
## Predicates
Add explainability entity types and agent predicates:
```python
# Explainability entity types (used by both GraphRAG and Agent)
TG_QUESTION = TG + "Question"
TG_EXPLORATION = TG + "Exploration"
TG_FOCUS = TG + "Focus"
TG_SYNTHESIS = TG + "Synthesis"
TG_ANALYSIS = TG + "Analysis"
TG_CONCLUSION = TG + "Conclusion"
### Session / Question
| Predicate | Type | Description |
|-----------|------|-------------|
| `tg:query` | string | The user's query text |
| `prov:startedAtTime` | string | ISO timestamp |
# Agent predicates
TG_THOUGHT = TG + "thought"
TG_ACTION = TG + "action"
TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation"
TG_ANSWER = TG + "answer"
```
### Pattern Decision
| Predicate | Type | Description |
|-----------|------|-------------|
| `tg:pattern` | string | Selected pattern (react, plan-then-execute, supervisor) |
| `tg:taskType` | string | Identified task type (general, research, etc.) |
## Files Modified
### Analysis (Iteration)
| Predicate | Type | Description |
|-----------|------|-------------|
| `tg:action` | string | Tool name selected by the agent |
| `tg:arguments` | string | JSON-encoded arguments |
| `tg:thought` | IRI | Link to Thought sub-entity |
| `tg:toolCandidate` | string | Tool name available to the LLM (one per candidate) |
| `tg:stepNumber` | integer | 1-based iteration counter |
| `tg:llmDurationMs` | integer | LLM call duration in milliseconds |
| `tg:inToken` | integer | Input token count |
| `tg:outToken` | integer | Output token count |
| `tg:llmModel` | string | Model identifier |
| File | Change |
|------|--------|
| `trustgraph-base/trustgraph/schema/services/agent.py` | Add session_id and collection to AgentRequest |
| `trustgraph-base/trustgraph/messaging/translators/agent.py` | Update translator for new fields |
| `trustgraph-base/trustgraph/provenance/namespaces.py` | Add entity types, agent predicates, and Document RAG predicates |
| `trustgraph-base/trustgraph/provenance/triples.py` | Add TG types to GraphRAG triple builders, add Document RAG triple builders |
| `trustgraph-base/trustgraph/provenance/uris.py` | Add Document RAG URI generators |
| `trustgraph-base/trustgraph/provenance/__init__.py` | Export new types, predicates, and Document RAG functions |
| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id, explain_graph, and explain_triples to DocumentRagResponse |
| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields including inline triples |
| `trustgraph-flow/trustgraph/agent/react/service.py` | Add explainability producer + recording logic |
| `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` | Add explainability callback and emit provenance triples |
| `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` | Add explainability producer and wire up callback |
| `trustgraph-cli/trustgraph/cli/show_explain_trace.py` | Handle agent trace types |
| `trustgraph-cli/trustgraph/cli/list_explain_traces.py` | List agent sessions alongside GraphRAG |
### Observation
| Predicate | Type | Description |
|-----------|------|-------------|
| `tg:document` | IRI | Librarian document reference |
| `tg:toolDurationMs` | integer | Tool execution time in milliseconds |
| `tg:toolError` | string | Error message (tool failure or LLM parse error) |
## Files Created
When `tg:toolError` is present, the Observation also carries the `tg:Error` mixin type.
| File | Purpose |
|------|---------|
| `trustgraph-base/trustgraph/provenance/agent.py` | Agent-specific triple generators |
### Conclusion / Synthesis
| Predicate | Type | Description |
|-----------|------|-------------|
| `tg:document` | IRI | Librarian document reference |
| `tg:terminationReason` | string | Why the loop stopped |
| `tg:inToken` | integer | Input token count (synthesis LLM call) |
| `tg:outToken` | integer | Output token count |
| `tg:llmModel` | string | Model identifier |
## CLI Updates
Termination reason values:
- `final-answer` -- LLM produced a confident answer (react)
- `plan-complete` -- all plan steps executed (plan-then-execute)
- `subagents-complete` -- all sub-agents reported back (supervisor)
**Detection:** Both GraphRAG and Agent Questions have `tg:Question` type. Distinguished by:
1. URI pattern: `urn:trustgraph:agent:` vs `urn:trustgraph:question:`
2. Derived entities: `tg:Analysis` (agent) vs `tg:Exploration` (GraphRAG)
### Decomposition
| Predicate | Type | Description |
|-----------|------|-------------|
| `tg:subagentGoal` | string | Goal assigned to a sub-agent (one per goal) |
| `tg:inToken` | integer | Input token count |
| `tg:outToken` | integer | Output token count |
**`list_explain_traces.py`:**
- Shows Type column (Agent vs GraphRAG)
### Plan
| Predicate | Type | Description |
|-----------|------|-------------|
| `tg:planStep` | string | Goal for a plan step (one per step) |
| `tg:inToken` | integer | Input token count |
| `tg:outToken` | integer | Output token count |
**`show_explain_trace.py`:**
- Auto-detects trace type
- Agent rendering shows: Question → Analysis step(s) → Conclusion
### Token Counts on RAG Events
## Backwards Compatibility
Grounding, Focus, and Synthesis events on GraphRAG and Document RAG also carry `tg:inToken`, `tg:outToken`, and `tg:llmModel` for the LLM calls associated with that step.
- `session_id` defaults to `""` - old requests work, just won't have provenance
- `collection` defaults to `"default"` - reasonable fallback
- CLI gracefully handles both trace types
## Error Handling
Tool execution errors and LLM parse errors are captured as Observation events rather than crashing the agent:
- The error message is recorded on `tg:toolError`
- The Observation carries the `tg:Error` mixin type
- The error text becomes the observation content, visible to the LLM on the next iteration
- The provenance chain is preserved (Observation derives from Analysis)
- The agent gets another iteration to retry or choose a different approach
## Vocabulary Reference
The full OWL ontology covering all classes and predicates is at `specs/ontology/trustgraph.ttl`.
## Verification
```bash
# Run an agent query
tg-invoke-agent -q "What is the capital of France?"
# Run an agent query with explainability
tg-invoke-agent -q "What is quantum computing?" -x
# List traces (should show agent sessions with Type column)
tg-list-explain-traces -U trustgraph -C default
# Run with token usage
tg-invoke-agent -q "What is quantum computing?" --show-usage
# Show agent trace
tg-show-explain-trace "urn:trustgraph:agent:xxx"
# GraphRAG with explainability
tg-invoke-graph-rag -q "Tell me about AI" -x
# Document RAG with explainability
tg-invoke-document-rag -q "Summarize the findings" -x
```
## Future Work (Not This PR)
- DAG dependencies (when analysis N uses results from multiple prior analyses)
- Tool-specific provenance linking (KnowledgeQuery → its GraphRAG trace)
- Streaming provenance emission (emit as we go, not batch at end)

View file

@ -868,7 +868,7 @@ independently.
Response chunk fields:
message_id UUID for this message (groups chunks)
session_id Which agent session produced this chunk
chunk_type "thought" | "observation" | "answer" | ...
message_type "thought" | "observation" | "answer" | ...
content The chunk text
end_of_message True on the final chunk of this message
end_of_dialog True on the final message of the entire execution

View file

@ -209,3 +209,7 @@ def subgraph_provenance_triples(
This is a breaking change to the provenance model. Provenance has not
been released, so no migration is needed. The old `tg:reifies` /
`statement_uri` code can be removed outright.
## Vocabulary Reference
The full OWL ontology covering all extraction and query-time classes and predicates is at `specs/ontology/trustgraph.ttl`.

View file

@ -618,8 +618,13 @@ Link embedding entity IDs to chunk.
| Triple store | Use reification for triple → chunk provenance |
| Embedding provenance | Link entity ID → chunk ID |
## Vocabulary Reference
The full OWL ontology covering all extraction and query-time classes and predicates is at `specs/ontology/trustgraph.ttl`.
## References
- Query-time provenance: `docs/tech-specs/query-time-provenance.md`
- Query-time provenance: `docs/tech-specs/query-time-explainability.md`
- Agent explainability: `docs/tech-specs/agent-explainability.md`
- PROV-O standard for provenance modeling
- Existing source metadata in knowledge graph (needs audit)

View file

@ -710,17 +710,17 @@ class StreamingChunk:
@dataclasses.dataclass
class AgentThought(StreamingChunk):
"""Agent reasoning chunk"""
chunk_type: str = "thought"
message_type: str = "thought"
@dataclasses.dataclass
class AgentObservation(StreamingChunk):
"""Agent tool observation chunk"""
chunk_type: str = "observation"
message_type: str = "observation"
@dataclasses.dataclass
class AgentAnswer(StreamingChunk):
"""Agent final answer chunk"""
chunk_type: str = "final-answer"
message_type: str = "final-answer"
end_of_dialog: bool = False
@dataclasses.dataclass

View file

@ -144,6 +144,25 @@ Defined in `trustgraph-base/trustgraph/provenance/namespaces.py`:
| `TG_REASONING` | `https://trustgraph.ai/ns/reasoning` |
| `TG_CONTENT` | `https://trustgraph.ai/ns/content` |
| `TG_DOCUMENT` | `https://trustgraph.ai/ns/document` |
| `TG_IN_TOKEN` | `https://trustgraph.ai/ns/inToken` |
| `TG_OUT_TOKEN` | `https://trustgraph.ai/ns/outToken` |
| `TG_LLM_MODEL` | `https://trustgraph.ai/ns/llmModel` |
### Token Usage on Events
Grounding, Focus, and Synthesis events carry per-event LLM token counts:
| Predicate | Type | Present on |
|-----------|------|------------|
| `tg:inToken` | integer | Grounding, Focus, Synthesis |
| `tg:outToken` | integer | Grounding, Focus, Synthesis |
| `tg:llmModel` | string | Grounding, Focus, Synthesis |
- **Grounding**: tokens from the extract-concepts LLM call
- **Focus**: summed tokens from edge-scoring + edge-reasoning LLM calls
- **Synthesis**: tokens from the synthesis LLM call
Values are absent (not zero) when token counts are unavailable.
## GraphRagResponse Schema
@ -267,8 +286,13 @@ Based on the provided knowledge statements...
| `trustgraph-flow/trustgraph/query/triples/cassandra/service.py` | Quoted triple query support |
| `trustgraph-cli/trustgraph/cli/invoke_graph_rag.py` | CLI with explainability display |
## Vocabulary Reference
The full OWL ontology covering all classes and predicates is at `specs/ontology/trustgraph.ttl`.
## References
- PROV-O (W3C Provenance Ontology): https://www.w3.org/TR/prov-o/
- RDF-star: https://w3c.github.io/rdf-star/
- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md`
- Agent explainability: `docs/tech-specs/agent-explainability.md`

View file

@ -399,28 +399,28 @@ The agent produces multiple types of output during its reasoning cycle:
- Answer (final response)
- Errors
Since `chunk_type` identifies what kind of content is being sent, the separate
Since `message_type` identifies what kind of content is being sent, the separate
`answer`, `error`, `thought`, and `observation` fields can be collapsed into
a single `content` field:
```python
class AgentResponse(Record):
chunk_type = String() # "thought", "action", "observation", "answer", "error"
content = String() # The actual content (interpretation depends on chunk_type)
message_type = String() # "thought", "action", "observation", "answer", "error"
content = String() # The actual content (interpretation depends on message_type)
end_of_message = Boolean() # Current thought/action/observation/answer is complete
end_of_dialog = Boolean() # Entire agent dialog is complete
```
**Field Semantics:**
- `chunk_type`: Indicates what type of content is in the `content` field
- `message_type`: Indicates what type of content is in the `content` field
- `"thought"`: Agent reasoning/thinking
- `"action"`: Tool/action being invoked
- `"observation"`: Result from tool execution
- `"answer"`: Final answer to the user's question
- `"error"`: Error message
- `content`: The actual streamed content, interpreted based on `chunk_type`
- `content`: The actual streamed content, interpreted based on `message_type`
- `end_of_message`: When `true`, the current chunk type is complete
- Example: All tokens for the current thought have been sent
@ -434,27 +434,27 @@ class AgentResponse(Record):
When `streaming=true`:
1. **Thought streaming**:
- Multiple chunks with `chunk_type="thought"`, `end_of_message=false`
- Multiple chunks with `message_type="thought"`, `end_of_message=false`
- Final thought chunk has `end_of_message=true`
2. **Action notification**:
- Single chunk with `chunk_type="action"`, `end_of_message=true`
- Single chunk with `message_type="action"`, `end_of_message=true`
3. **Observation**:
- Chunk(s) with `chunk_type="observation"`, final has `end_of_message=true`
- Chunk(s) with `message_type="observation"`, final has `end_of_message=true`
4. **Repeat** steps 1-3 as the agent reasons
5. **Final answer**:
- `chunk_type="answer"` with the final response in `content`
- `message_type="answer"` with the final response in `content`
- Last chunk has `end_of_message=true`, `end_of_dialog=true`
**Example Stream Sequence:**
```
{chunk_type: "thought", content: "I need to", end_of_message: false, end_of_dialog: false}
{chunk_type: "thought", content: " search for...", end_of_message: true, end_of_dialog: false}
{chunk_type: "action", content: "search", end_of_message: true, end_of_dialog: false}
{chunk_type: "observation", content: "Found: ...", end_of_message: true, end_of_dialog: false}
{chunk_type: "thought", content: "Based on this", end_of_message: false, end_of_dialog: false}
{chunk_type: "thought", content: " I can answer...", end_of_message: true, end_of_dialog: false}
{chunk_type: "answer", content: "The answer is...", end_of_message: true, end_of_dialog: true}
{message_type: "thought", content: "I need to", end_of_message: false, end_of_dialog: false}
{message_type: "thought", content: " search for...", end_of_message: true, end_of_dialog: false}
{message_type: "action", content: "search", end_of_message: true, end_of_dialog: false}
{message_type: "observation", content: "Found: ...", end_of_message: true, end_of_dialog: false}
{message_type: "thought", content: "Based on this", end_of_message: false, end_of_dialog: false}
{message_type: "thought", content: " I can answer...", end_of_message: true, end_of_dialog: false}
{message_type: "answer", content: "The answer is...", end_of_message: true, end_of_dialog: true}
```
When `streaming=false`:
@ -547,7 +547,7 @@ The following questions were resolved during specification:
populated and no other fields are needed. An error is always the final
communication - no subsequent messages are permitted or expected after
an error. For LLM/Prompt streams, `end_of_stream=true`. For Agent streams,
`chunk_type="error"` with `end_of_dialog=true`.
`message_type="error"` with `end_of_dialog=true`.
3. **Partial Response Recovery**: The messaging protocol (Pulsar) is resilient,
so message-level retry is not needed. If a client loses track of the stream

View file

@ -0,0 +1,415 @@
@prefix tg: <https://trustgraph.ai/ns/> .
@prefix owl: <http://www.w3.org/2002/07/owl#> .
@prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> .
@prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .
@prefix xsd: <http://www.w3.org/2001/XMLSchema#> .
@prefix prov: <http://www.w3.org/ns/prov#> .
# =============================================================================
# Ontology declaration
# =============================================================================
<https://trustgraph.ai/ns/>
a owl:Ontology ;
rdfs:label "TrustGraph Ontology" ;
rdfs:comment "Vocabulary for TrustGraph provenance, extraction metadata, and explainability." ;
owl:versionInfo "2.3" .
# =============================================================================
# Classes — Extraction provenance
# =============================================================================
tg:Document a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Document" ;
rdfs:comment "A loaded document (PDF, text, etc.)." .
tg:Page a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Page" ;
rdfs:comment "A page within a document." .
tg:Section a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Section" ;
rdfs:comment "A structural section within a document." .
tg:Chunk a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Chunk" ;
rdfs:comment "A text chunk produced by the chunker." .
tg:Image a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Image" ;
rdfs:comment "An image extracted from a document." .
tg:Subgraph a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Subgraph" ;
rdfs:comment "A set of triples extracted from a chunk." .
# =============================================================================
# Classes — Query-time explainability (shared)
# =============================================================================
tg:Question a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Question" ;
rdfs:comment "Root entity for a query session." .
tg:GraphRagQuestion a owl:Class ;
rdfs:subClassOf tg:Question ;
rdfs:label "Graph RAG Question" ;
rdfs:comment "A question answered via graph-based RAG." .
tg:DocRagQuestion a owl:Class ;
rdfs:subClassOf tg:Question ;
rdfs:label "Document RAG Question" ;
rdfs:comment "A question answered via document-based RAG." .
tg:AgentQuestion a owl:Class ;
rdfs:subClassOf tg:Question ;
rdfs:label "Agent Question" ;
rdfs:comment "A question answered via the agent orchestrator." .
tg:Grounding a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Grounding" ;
rdfs:comment "Concept extraction step (query decomposition into search terms)." .
tg:Exploration a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Exploration" ;
rdfs:comment "Entity/chunk retrieval step." .
tg:Focus a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Focus" ;
rdfs:comment "Edge selection and scoring step (GraphRAG)." .
tg:Synthesis a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Synthesis" ;
rdfs:comment "Final answer synthesis from retrieved context." .
# =============================================================================
# Classes — Agent provenance
# =============================================================================
tg:Analysis a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Analysis" ;
rdfs:comment "One agent iteration: reasoning followed by tool selection." .
tg:ToolUse a owl:Class ;
rdfs:label "ToolUse" ;
rdfs:comment "Mixin type applied to Analysis when a tool is invoked." .
tg:Error a owl:Class ;
rdfs:label "Error" ;
rdfs:comment "Mixin type applied to events where a failure occurred (tool error, parse error)." .
tg:Conclusion a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Conclusion" ;
rdfs:comment "Agent final answer (ReAct pattern)." .
tg:PatternDecision a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Pattern Decision" ;
rdfs:comment "Meta-router decision recording which execution pattern was selected." .
# --- Unifying types ---
tg:Answer a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Answer" ;
rdfs:comment "Unifying type for any terminal answer (Synthesis, Conclusion, Finding, StepResult)." .
tg:Reflection a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Reflection" ;
rdfs:comment "Unifying type for intermediate commentary (Thought, Observation)." .
tg:Thought a owl:Class ;
rdfs:subClassOf tg:Reflection ;
rdfs:label "Thought" ;
rdfs:comment "Agent reasoning text within an iteration." .
tg:Observation a owl:Class ;
rdfs:subClassOf tg:Reflection ;
rdfs:label "Observation" ;
rdfs:comment "Tool execution result." .
# --- Orchestrator types ---
tg:Decomposition a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Decomposition" ;
rdfs:comment "Supervisor pattern: question decomposed into sub-goals." .
tg:Finding a owl:Class ;
rdfs:subClassOf tg:Answer ;
rdfs:label "Finding" ;
rdfs:comment "Result from a sub-agent execution." .
tg:Plan a owl:Class ;
rdfs:subClassOf prov:Entity ;
rdfs:label "Plan" ;
rdfs:comment "Plan-then-execute pattern: structured plan of steps." .
tg:StepResult a owl:Class ;
rdfs:subClassOf tg:Answer ;
rdfs:label "Step Result" ;
rdfs:comment "Result from executing one plan step." .
# =============================================================================
# Properties — Extraction metadata
# =============================================================================
tg:contains a owl:ObjectProperty ;
rdfs:label "contains" ;
rdfs:comment "Links a parent entity to a child (e.g. Document contains Page, Subgraph contains triple)." .
tg:pageCount a owl:DatatypeProperty ;
rdfs:label "page count" ;
rdfs:range xsd:integer ;
rdfs:domain tg:Document .
tg:mimeType a owl:DatatypeProperty ;
rdfs:label "MIME type" ;
rdfs:range xsd:string ;
rdfs:domain tg:Document .
tg:pageNumber a owl:DatatypeProperty ;
rdfs:label "page number" ;
rdfs:range xsd:integer ;
rdfs:domain tg:Page .
tg:chunkIndex a owl:DatatypeProperty ;
rdfs:label "chunk index" ;
rdfs:range xsd:integer ;
rdfs:domain tg:Chunk .
tg:charOffset a owl:DatatypeProperty ;
rdfs:label "character offset" ;
rdfs:range xsd:integer .
tg:charLength a owl:DatatypeProperty ;
rdfs:label "character length" ;
rdfs:range xsd:integer .
tg:chunkSize a owl:DatatypeProperty ;
rdfs:label "chunk size" ;
rdfs:range xsd:integer .
tg:chunkOverlap a owl:DatatypeProperty ;
rdfs:label "chunk overlap" ;
rdfs:range xsd:integer .
tg:componentVersion a owl:DatatypeProperty ;
rdfs:label "component version" ;
rdfs:range xsd:string .
tg:llmModel a owl:DatatypeProperty ;
rdfs:label "LLM model" ;
rdfs:range xsd:string .
tg:ontology a owl:DatatypeProperty ;
rdfs:label "ontology" ;
rdfs:range xsd:string .
tg:embeddingModel a owl:DatatypeProperty ;
rdfs:label "embedding model" ;
rdfs:range xsd:string .
tg:sourceText a owl:DatatypeProperty ;
rdfs:label "source text" ;
rdfs:range xsd:string .
tg:sourceCharOffset a owl:DatatypeProperty ;
rdfs:label "source character offset" ;
rdfs:range xsd:integer .
tg:sourceCharLength a owl:DatatypeProperty ;
rdfs:label "source character length" ;
rdfs:range xsd:integer .
tg:elementTypes a owl:DatatypeProperty ;
rdfs:label "element types" ;
rdfs:range xsd:string .
tg:tableCount a owl:DatatypeProperty ;
rdfs:label "table count" ;
rdfs:range xsd:integer .
tg:imageCount a owl:DatatypeProperty ;
rdfs:label "image count" ;
rdfs:range xsd:integer .
# =============================================================================
# Properties — Query-time provenance (GraphRAG / DocumentRAG)
# =============================================================================
tg:query a owl:DatatypeProperty ;
rdfs:label "query" ;
rdfs:comment "The user's query text." ;
rdfs:range xsd:string ;
rdfs:domain tg:Question .
tg:concept a owl:DatatypeProperty ;
rdfs:label "concept" ;
rdfs:comment "An extracted concept from the query." ;
rdfs:range xsd:string ;
rdfs:domain tg:Grounding .
tg:entity a owl:ObjectProperty ;
rdfs:label "entity" ;
rdfs:comment "A seed entity retrieved during exploration." ;
rdfs:domain tg:Exploration .
tg:edgeCount a owl:DatatypeProperty ;
rdfs:label "edge count" ;
rdfs:comment "Number of edges explored." ;
rdfs:range xsd:integer ;
rdfs:domain tg:Exploration .
tg:selectedEdge a owl:ObjectProperty ;
rdfs:label "selected edge" ;
rdfs:comment "Link to an edge selection entity within a Focus event." ;
rdfs:domain tg:Focus .
tg:edge a owl:ObjectProperty ;
rdfs:label "edge" ;
rdfs:comment "A quoted triple representing a knowledge graph edge." ;
rdfs:domain tg:Focus .
tg:reasoning a owl:DatatypeProperty ;
rdfs:label "reasoning" ;
rdfs:comment "LLM-generated reasoning for an edge selection." ;
rdfs:range xsd:string .
tg:document a owl:ObjectProperty ;
rdfs:label "document" ;
rdfs:comment "Reference to a document stored in the librarian." .
tg:chunkCount a owl:DatatypeProperty ;
rdfs:label "chunk count" ;
rdfs:comment "Number of document chunks retrieved (DocumentRAG)." ;
rdfs:range xsd:integer ;
rdfs:domain tg:Exploration .
tg:selectedChunk a owl:DatatypeProperty ;
rdfs:label "selected chunk" ;
rdfs:comment "A selected chunk ID (DocumentRAG)." ;
rdfs:range xsd:string ;
rdfs:domain tg:Exploration .
# =============================================================================
# Properties — Agent provenance
# =============================================================================
tg:thought a owl:ObjectProperty ;
rdfs:label "thought" ;
rdfs:comment "Links an Analysis iteration to its Thought sub-entity." ;
rdfs:domain tg:Analysis ;
rdfs:range tg:Thought .
tg:action a owl:DatatypeProperty ;
rdfs:label "action" ;
rdfs:comment "The tool/action name selected by the agent." ;
rdfs:range xsd:string ;
rdfs:domain tg:Analysis .
tg:arguments a owl:DatatypeProperty ;
rdfs:label "arguments" ;
rdfs:comment "JSON-encoded arguments passed to the tool." ;
rdfs:range xsd:string ;
rdfs:domain tg:Analysis .
tg:observation a owl:ObjectProperty ;
rdfs:label "observation" ;
rdfs:comment "Links an Analysis iteration to its Observation sub-entity." ;
rdfs:domain tg:Analysis ;
rdfs:range tg:Observation .
tg:toolCandidate a owl:DatatypeProperty ;
rdfs:label "tool candidate" ;
rdfs:comment "Name of a tool available to the LLM for this iteration. One triple per candidate." ;
rdfs:range xsd:string ;
rdfs:domain tg:Analysis .
tg:stepNumber a owl:DatatypeProperty ;
rdfs:label "step number" ;
rdfs:comment "Explicit 1-based step counter for iteration events." ;
rdfs:range xsd:integer ;
rdfs:domain tg:Analysis .
tg:terminationReason a owl:DatatypeProperty ;
rdfs:label "termination reason" ;
rdfs:comment "Why the agent loop stopped: final-answer, plan-complete, subagents-complete, max-iterations, error." ;
rdfs:range xsd:string .
tg:pattern a owl:DatatypeProperty ;
rdfs:label "pattern" ;
rdfs:comment "Selected execution pattern (react, plan-then-execute, supervisor)." ;
rdfs:range xsd:string ;
rdfs:domain tg:PatternDecision .
tg:taskType a owl:DatatypeProperty ;
rdfs:label "task type" ;
rdfs:comment "Identified task type from the meta-router (general, research, etc.)." ;
rdfs:range xsd:string ;
rdfs:domain tg:PatternDecision .
tg:llmDurationMs a owl:DatatypeProperty ;
rdfs:label "LLM duration (ms)" ;
rdfs:comment "Time spent in the LLM prompt call, in milliseconds." ;
rdfs:range xsd:integer ;
rdfs:domain tg:Analysis .
tg:toolDurationMs a owl:DatatypeProperty ;
rdfs:label "tool duration (ms)" ;
rdfs:comment "Time spent executing the tool, in milliseconds." ;
rdfs:range xsd:integer ;
rdfs:domain tg:Observation .
tg:toolError a owl:DatatypeProperty ;
rdfs:label "tool error" ;
rdfs:comment "Error message from a failed tool execution." ;
rdfs:range xsd:string ;
rdfs:domain tg:Observation .
# --- Token usage predicates (on any event that involves an LLM call) ---
tg:inToken a owl:DatatypeProperty ;
rdfs:label "input tokens" ;
rdfs:comment "Input token count for the LLM call associated with this event." ;
rdfs:range xsd:integer .
tg:outToken a owl:DatatypeProperty ;
rdfs:label "output tokens" ;
rdfs:comment "Output token count for the LLM call associated with this event." ;
rdfs:range xsd:integer .
# --- Orchestrator predicates ---
tg:subagentGoal a owl:DatatypeProperty ;
rdfs:label "sub-agent goal" ;
rdfs:comment "Goal string assigned to a sub-agent (Decomposition, Finding)." ;
rdfs:range xsd:string .
tg:planStep a owl:DatatypeProperty ;
rdfs:label "plan step" ;
rdfs:comment "Goal string for a plan step (Plan, StepResult)." ;
rdfs:range xsd:string .
# =============================================================================
# Named graphs
# =============================================================================
# These are not OWL classes but documented here for reference:
#
# (default graph) — Core knowledge facts (extracted triples)
# urn:graph:source — Extraction provenance (document → chunk → triple)
# urn:graph:retrieval — Query-time explainability (question → exploration → synthesis)

View file

@ -87,7 +87,7 @@ def sample_message_data():
"history": []
},
"AgentResponse": {
"chunk_type": "answer",
"message_type": "answer",
"content": "Machine learning is a subset of AI.",
"end_of_message": True,
"end_of_dialog": True,

View file

@ -212,7 +212,7 @@ class TestAgentMessageContracts:
# Test required fields
response = AgentResponse(**response_data)
assert hasattr(response, 'chunk_type')
assert hasattr(response, 'message_type')
assert hasattr(response, 'content')
assert hasattr(response, 'end_of_message')
assert hasattr(response, 'end_of_dialog')

View file

@ -0,0 +1,73 @@
"""
Contract tests for schema dataclass field sets.
These pin the *field names* of small, widely-constructed schema dataclasses
so that any rename, removal, or accidental addition fails CI loudly instead
of waiting for a runtime TypeError on the next websocket message.
Background: in v2.2 the `Metadata` dataclass dropped a `metadata: list[Triple]`
field but several call sites kept passing `Metadata(metadata=...)`. The bug
was only discovered when a websocket import dispatcher received its first
real message in production. A trivial structural assertion of the kind
below would have caught it at unit-test time.
Add to this file whenever a schema rename burns you. The cost of a frozen
field set is a one-line update when you intentionally evolve the schema; the
benefit is that every call site is forced to come along for the ride.
"""
import dataclasses
import pytest
from trustgraph.schema import (
Metadata,
EntityContext,
EntityEmbeddings,
ChunkEmbeddings,
)
def _field_names(dc):
return {f.name for f in dataclasses.fields(dc)}
@pytest.mark.contract
class TestSchemaFieldContracts:
"""Pin the field set of dataclasses that get constructed all over the
codebase. If you intentionally change one of these, update the
expected set in the same commit that diff will surface every call
site that needs to come along."""
def test_metadata_fields(self):
# NOTE: there is no `metadata` field. A previous regression
# constructed Metadata(metadata=...) and crashed at runtime.
assert _field_names(Metadata) == {
"id",
"root",
"user",
"collection",
}
def test_entity_embeddings_fields(self):
# NOTE: the embedding field is `vector` (singular, list[float]).
# There is no `vectors` field. Several call sites historically
# passed `vectors=` and crashed at runtime.
assert _field_names(EntityEmbeddings) == {
"entity",
"vector",
"chunk_id",
}
def test_chunk_embeddings_fields(self):
# Same `vector` (singular) convention as EntityEmbeddings.
assert _field_names(ChunkEmbeddings) == {
"chunk_id",
"vector",
}
def test_entity_context_fields(self):
assert _field_names(EntityContext) == {
"entity",
"context",
"chunk_id",
}

View file

@ -188,7 +188,7 @@ class TestAgentTranslatorCompletionFlags:
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse(
chunk_type="answer",
message_type="answer",
content="4",
end_of_message=True,
end_of_dialog=True,
@ -210,7 +210,7 @@ class TestAgentTranslatorCompletionFlags:
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse(
chunk_type="thought",
message_type="thought",
content="I need to solve this.",
end_of_message=True,
end_of_dialog=False,
@ -233,7 +233,7 @@ class TestAgentTranslatorCompletionFlags:
# Test thought message
thought_response = AgentResponse(
chunk_type="thought",
message_type="thought",
content="Processing...",
end_of_message=True,
end_of_dialog=False,
@ -247,7 +247,7 @@ class TestAgentTranslatorCompletionFlags:
# Test observation message
observation_response = AgentResponse(
chunk_type="observation",
message_type="observation",
content="Result found",
end_of_message=True,
end_of_dialog=False,
@ -268,7 +268,7 @@ class TestAgentTranslatorCompletionFlags:
# Streaming format with end_of_dialog=True
response = AgentResponse(
chunk_type="answer",
message_type="answer",
content="",
end_of_message=True,
end_of_dialog=True,

View file

@ -418,55 +418,55 @@ def sample_streaming_agent_response():
"""Sample streaming agent response chunks"""
return [
{
"chunk_type": "thought",
"message_type": "thought",
"content": "I need to search",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "thought",
"message_type": "thought",
"content": " for information",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "thought",
"message_type": "thought",
"content": " about machine learning.",
"end_of_message": True,
"end_of_dialog": False
},
{
"chunk_type": "action",
"message_type": "action",
"content": "knowledge_query",
"end_of_message": True,
"end_of_dialog": False
},
{
"chunk_type": "observation",
"message_type": "observation",
"content": "Machine learning is",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "observation",
"message_type": "observation",
"content": " a subset of AI.",
"end_of_message": True,
"end_of_dialog": False
},
{
"chunk_type": "final-answer",
"message_type": "final-answer",
"content": "Machine learning",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "final-answer",
"message_type": "final-answer",
"content": " is a subset",
"end_of_message": False,
"end_of_dialog": False
},
{
"chunk_type": "final-answer",
"message_type": "final-answer",
"content": " of artificial intelligence.",
"end_of_message": True,
"end_of_dialog": True
@ -494,10 +494,10 @@ def streaming_chunk_collector():
"""Concatenate all chunk content"""
return "".join(self.chunks)
def get_chunk_types(self):
def get_message_types(self):
"""Get list of chunk types if chunks are dicts"""
if self.chunks and isinstance(self.chunks[0], dict):
return [c.get("chunk_type") for c in self.chunks]
return [c.get("message_type") for c in self.chunks]
return []
def verify_streaming_protocol(self):

View file

@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
from trustgraph.agent.react.types import Action, Final, Tool, Argument
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -28,19 +29,25 @@ class TestAgentManagerIntegration:
# Mock prompt client
prompt_client = AsyncMock()
prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search for information about machine learning
Action: knowledge_query
Args: {
"question": "What is machine learning?"
}"""
)
# Mock graph RAG client
graph_rag_client = AsyncMock()
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
# Mock text completion client
text_completion_client = AsyncMock()
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
text_completion_client.question.return_value = PromptResult(
response_type="text",
text="Machine learning involves algorithms that improve through experience."
)
# Mock MCP tool client
mcp_tool_client = AsyncMock()
@ -147,8 +154,11 @@ Args: {
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
"""Test agent manager returning final answer"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I have enough information to answer the question
Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
)
question = "What is machine learning?"
history = []
@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
"""Test ReAct cycle ending with final answer"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I can provide a direct answer
Final Answer: Machine learning is a branch of artificial intelligence."""
)
question = "What is machine learning?"
history = []
@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
for tool_name, expected_service in tool_scenarios:
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=f"""Thought: I need to use {tool_name}
Action: {tool_name}
Args: {{
"question": "test question"
}}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -284,11 +300,14 @@ Args: {{
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
"""Test agent manager error handling for unknown tool"""
# Arrange
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to use an unknown tool
Action: unknown_tool
Args: {
"param": "value"
}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -308,11 +327,13 @@ Args: {
think_callback = AsyncMock()
observe_callback = AsyncMock()
# Act & Assert
with pytest.raises(Exception) as exc_info:
await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
assert "Tool execution failed" in str(exc_info.value)
# Act - tool errors are now caught and returned as observations
result = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
# Assert - error captured on the action, not raised
assert result.tool_error is not None
assert "Tool execution failed" in result.tool_error
assert "Error:" in result.observation
@pytest.mark.asyncio
async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context):
@ -321,11 +342,14 @@ Args: {
question = "Find information about AI and summarize it"
# Mock multi-step reasoning
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search for AI information first
Action: knowledge_query
Args: {
"question": "What is artificial intelligence?"
}"""
)
# Act
action = await agent_manager.reason(question, [], mock_flow_context)
@ -372,9 +396,12 @@ Args: {
# Format arguments as JSON
import json
args_json = json.dumps(test_case['arguments'], indent=4)
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=f"""Thought: Using {test_case['action']}
Action: {test_case['action']}
Args: {args_json}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()
@ -507,15 +534,17 @@ Args: {
]
for test_case in test_cases:
mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text=test_case["response"]
)
if test_case["error_contains"]:
# Should raise an error
with pytest.raises(RuntimeError) as exc_info:
await agent_manager.reason("test question", [], mock_flow_context)
assert "Failed to parse agent response" in str(exc_info.value)
assert test_case["error_contains"] in str(exc_info.value)
# Parse errors now return an Action with tool_error
result = await agent_manager.reason("test question", [], mock_flow_context)
assert isinstance(result, Action)
assert result.name == "__parse_error__"
assert result.tool_error is not None
else:
# Should succeed
action = await agent_manager.reason("test question", [], mock_flow_context)
@ -527,13 +556,16 @@ Args: {
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
"""Test edge cases in text parsing"""
# Test response with markdown code blocks
mock_flow_context("prompt-request").agent_react.return_value = """```
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""```
Thought: I need to search for information
Action: knowledge_query
Args: {
"question": "What is AI?"
}
```"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -541,15 +573,18 @@ Args: {
assert action.name == "knowledge_query"
# Test response with extra whitespace
mock_flow_context("prompt-request").agent_react.return_value = """
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""
Thought: I need to think about this
Action: knowledge_query
Thought: I need to think about this
Action: knowledge_query
Args: {
"question": "test"
}
"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -560,7 +595,9 @@ Args: {
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
"""Test handling of multi-line thoughts and final answers"""
# Multi-line thought
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to consider multiple factors:
1. The user's question is complex
2. I should search for comprehensive information
3. This requires using the knowledge query tool
@ -568,6 +605,7 @@ Action: knowledge_query
Args: {
"question": "complex query"
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -575,13 +613,16 @@ Args: {
assert "knowledge query tool" in action.thought
# Multi-line final answer
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I have gathered enough information
Final Answer: Here is a comprehensive answer:
1. First point about the topic
2. Second point with details
3. Final conclusion
This covers all aspects of the question."""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final)
@ -593,13 +634,16 @@ This covers all aspects of the question."""
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
"""Test JSON arguments with special characters and edge cases"""
# Test with special characters in JSON (properly escaped)
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: Processing special characters
Action: knowledge_query
Args: {
"question": "What about \\"quotes\\" and 'apostrophes'?",
"context": "Line 1\\nLine 2\\tTabbed",
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -608,7 +652,9 @@ Args: {
assert "@#$%^&*" in action.arguments["special"]
# Test with nested JSON
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: Complex arguments
Action: web_search
Args: {
"query": "test",
@ -621,6 +667,7 @@ Args: {
}
}
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@ -632,7 +679,9 @@ Args: {
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
"""Test final answers that contain JSON-like content"""
# Final answer with JSON content
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I can provide the data in JSON format
Final Answer: {
"result": "success",
"data": {
@ -642,6 +691,7 @@ Final Answer: {
},
"confidence": 0.95
}"""
)
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final)
@ -792,11 +842,14 @@ Final Answer: {
agent = AgentManager(tools=custom_tools, additional_context="")
# Mock response for custom collection query
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to search in the research papers
Action: knowledge_query_custom
Args: {
"question": "Latest AI research?"
}"""
)
think_callback = AsyncMock()
observe_callback = AsyncMock()

View file

@ -10,11 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl
from trustgraph.agent.react.types import Tool, Argument
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_agent_streaming_chunks,
assert_streaming_chunks_valid,
assert_callback_invoked,
assert_chunk_types_valid,
assert_message_types_valid,
)
@ -51,10 +52,10 @@ Args: {
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
# Non-streaming response - same text
return full_text
return PromptResult(response_type="text", text=full_text)
client.agent_react.side_effect = agent_react_streaming
return client
@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
for i, chunk in enumerate(chunks):
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk + " ", is_final)
return response
return response
return PromptResult(response_type="text", text=response)
return PromptResult(response_type="text", text=response)
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react

View file

@ -16,6 +16,7 @@ from trustgraph.schema import (
Error
)
from trustgraph.agent.react.service import Processor
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration:
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to find customers from New York using structured query
Action: structured-query
Args: {
"question": "Find all customers from New York"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -173,11 +177,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to query for a table that might not exist
Action: structured-query
Args: {
"question": "Find data from a table that doesn't exist"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -250,11 +257,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to find customers from California first
Action: structured-query
Args: {
"question": "Find all customers from California"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -339,11 +349,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to query the sales data
Action: structured-query
Args: {
"question": "Query the sales data for recent transactions"
}"""
)
# Set up flow context routing
def flow_context(service_name):
@ -447,11 +460,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information
mock_prompt_client.agent_react.return_value = PromptResult(
response_type="text",
text="""Thought: I need to get customer information
Action: structured-query
Args: {
"question": "Get customer information and format it nicely"
}"""
)
# Set up flow context routing
def flow_context(service_name):

View file

@ -10,6 +10,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
# Sample chunk content for testing - maps chunk_id to content
@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
client = AsyncMock()
client.document_prompt.return_value = (
"Machine learning is a field of artificial intelligence that enables computers to learn "
"and improve from experience without being explicitly programmed. It uses algorithms "
"to find patterns in data and make predictions or decisions."
client.document_prompt.return_value = PromptResult(
response_type="text",
text=(
"Machine learning is a field of artificial intelligence that enables computers to learn "
"and improve from experience without being explicitly programmed. It uses algorithms "
"to find patterns in data and make predictions or decisions."
)
)
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -119,6 +125,7 @@ class TestDocumentRagIntegration:
)
# Verify final response
result, usage = result
assert result is not None
assert isinstance(result, str)
assert "machine learning" in result.lower()
@ -131,7 +138,11 @@ class TestDocumentRagIntegration:
"""Test DocumentRAG behavior when no documents are retrieved"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
mock_prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="I couldn't find any relevant documents for your query."
)
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
@ -152,7 +163,8 @@ class TestDocumentRagIntegration:
documents=[]
)
assert result == "I couldn't find any relevant documents for your query."
result_text, usage = result
assert result_text == "I couldn't find any relevant documents for your query."
@pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
# Non-streaming response - same text
return full_text
return PromptResult(response_type="text", text=full_text)
client.document_prompt.side_effect = document_prompt_side_effect
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -119,11 +122,12 @@ class TestDocumentRagStreaming:
collector.verify_streaming_protocol()
# Verify full response matches concatenated chunks
result_text, usage = result
full_from_chunks = collector.get_full_text()
assert result == full_from_chunks
assert result_text == full_from_chunks
# Verify content is reasonable
assert len(result) > 0
assert len(result_text) > 0
@pytest.mark.asyncio
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
@ -159,9 +163,11 @@ class TestDocumentRagStreaming:
)
# Assert - Results should be equivalent
assert streaming_result == non_streaming_result
non_streaming_text, _ = non_streaming_result
streaming_text, _ = streaming_result
assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_result
assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
@ -180,8 +186,9 @@ class TestDocumentRagStreaming:
)
# Assert
result_text, usage = result
assert callback.call_count > 0
assert result is not None
assert result_text is not None
# Verify all callback invocations had string arguments
for call in callback.call_args_list:
@ -202,7 +209,8 @@ class TestDocumentRagStreaming:
# Assert - Should complete without error
assert result is not None
assert isinstance(result, str)
result_text, usage = result
assert isinstance(result_text, str)
@pytest.mark.asyncio
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
@ -223,7 +231,8 @@ class TestDocumentRagStreaming:
)
# Assert - Should still produce streamed response
assert result is not None
result_text, usage = result
assert result_text is not None
assert callback.call_count > 0
@pytest.mark.asyncio
@ -271,7 +280,8 @@ class TestDocumentRagStreaming:
)
# Assert
assert result is not None
result_text, usage = result
assert result_text is not None
assert callback.call_count > 0
# Verify doc_limit was passed correctly

View file

@ -12,6 +12,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -93,18 +94,21 @@ class TestGraphRagIntegration:
# 4. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return "" # No edges scored
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-reasoning":
return "" # No reasoning
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
return (
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
return PromptResult(
response_type="text",
text=(
"Machine learning is a subset of artificial intelligence that enables computers "
"to learn from data without being explicitly programmed. It uses algorithms "
"and statistical models to find patterns in data."
)
)
return ""
return PromptResult(response_type="text", text="")
client.prompt.side_effect = mock_prompt
return client
@ -169,6 +173,7 @@ class TestGraphRagIntegration:
assert mock_prompt_client.prompt.call_count == 4
# Verify final response
response, usage = response
assert response is not None
assert isinstance(response, str)
assert "machine learning" in response.lower()

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_rag_streaming_chunks,
@ -61,12 +62,12 @@ class TestGraphRagStreaming:
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
return '{"id": "abc12345", "score": 0.9}\n'
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
elif prompt_id == "kg-edge-reasoning":
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
@ -79,10 +80,10 @@ class TestGraphRagStreaming:
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
return full_text
return PromptResult(response_type="text", text=full_text)
else:
return full_text
return ""
return PromptResult(response_type="text", text=full_text)
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect
return client
@ -123,6 +124,7 @@ class TestGraphRagStreaming:
)
# Assert
response, usage = response
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
@ -172,9 +174,11 @@ class TestGraphRagStreaming:
)
# Assert - Results should be equivalent
assert streaming_response == non_streaming_response
non_streaming_text, _ = non_streaming_response
streaming_text, _ = streaming_response
assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0
assert "".join(streaming_chunks) == streaming_response
assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
@ -213,7 +217,8 @@ class TestGraphRagStreaming:
# Assert - Should complete without error
assert response is not None
assert isinstance(response, str)
response_text, usage = response
assert isinstance(response_text, str)
@pytest.mark.asyncio
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,

View file

@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
# Mock prompt client for definitions extraction
prompt_client = AsyncMock()
prompt_client.extract_definitions.return_value = [
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
]
prompt_client.extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[
{
"entity": "Machine Learning",
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
},
{
"entity": "Neural Networks",
"definition": "Computing systems inspired by biological neural networks that process information."
}
]
)
# Mock prompt client for relationships extraction
prompt_client.extract_relationships.return_value = [
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"subject": "Neural Networks",
"predicate": "is_used_in",
"object": "Machine Learning",
"object-entity": True
}
]
prompt_client.extract_relationships.return_value = PromptResult(
response_type="jsonl",
objects=[
{
"subject": "Machine Learning",
"predicate": "is_subset_of",
"object": "Artificial Intelligence",
"object-entity": True
},
{
"subject": "Neural Networks",
"predicate": "is_used_in",
"object": "Machine Learning",
"object-entity": True
}
]
)
# Mock producers for output streams
triples_producer = AsyncMock()
@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of empty extraction results"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = []
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[]
)
mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk
@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of invalid extraction response format"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="text",
text="invalid format"
) # Should be jsonl with objects list
mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk
@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
"""Test entity filtering and validation in extraction"""
# Arrange
mock_flow_context("prompt-request").extract_definitions.return_value = [
{"entity": "Valid Entity", "definition": "Valid definition"},
{"entity": "", "definition": "Empty entity"}, # Should be filtered
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
{"entity": None, "definition": "None entity"}, # Should be filtered
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
]
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
response_type="jsonl",
objects=[
{"entity": "Valid Entity", "definition": "Valid definition"},
{"entity": "", "definition": "Empty entity"}, # Should be filtered
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
{"entity": None, "definition": "None entity"}, # Should be filtered
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
]
)
sample_chunk = Chunk(
metadata=Metadata(id="test", user="user", collection="collection"),

View file

@ -16,6 +16,7 @@ from trustgraph.schema import (
Chunk, ExtractedObject, Metadata, RowSchema, Field,
PromptRequest, PromptResponse
)
from trustgraph.base import PromptResult
@pytest.mark.integration
@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
if schema_name == "customer_records":
if "john" in text.lower():
return [
{
"customer_id": "CUST001",
"name": "John Smith",
"email": "john.smith@email.com",
"phone": "555-0123"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"customer_id": "CUST001",
"name": "John Smith",
"email": "john.smith@email.com",
"phone": "555-0123"
}
]
)
elif "jane" in text.lower():
return [
{
"customer_id": "CUST002",
"name": "Jane Doe",
"email": "jane.doe@email.com",
"phone": ""
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"customer_id": "CUST002",
"name": "Jane Doe",
"email": "jane.doe@email.com",
"phone": ""
}
]
)
else:
return []
return PromptResult(response_type="jsonl", objects=[])
elif schema_name == "product_catalog":
if "laptop" in text.lower():
return [
{
"product_id": "PROD001",
"name": "Gaming Laptop",
"price": "1299.99",
"category": "electronics"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"product_id": "PROD001",
"name": "Gaming Laptop",
"price": "1299.99",
"category": "electronics"
}
]
)
elif "book" in text.lower():
return [
{
"product_id": "PROD002",
"name": "Python Programming Guide",
"price": "49.99",
"category": "books"
}
]
return PromptResult(
response_type="jsonl",
objects=[
{
"product_id": "PROD002",
"name": "Python Programming Guide",
"price": "49.99",
"category": "books"
}
]
)
else:
return []
return []
return PromptResult(response_type="jsonl", objects=[])
return PromptResult(response_type="jsonl", objects=[])
prompt_client.extract_objects.side_effect = mock_extract_objects

View file

@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.prompt.template.service import Processor
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
from trustgraph.base.text_completion_client import TextCompletionResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@ -27,34 +28,52 @@ class TestPromptStreaming:
# Mock text completion client with streaming
text_completion_client = AsyncMock()
async def streaming_request(request, recipient=None, timeout=600):
"""Simulate streaming text completion"""
if request.streaming and recipient:
# Simulate streaming chunks
chunks = [
"Machine", " learning", " is", " a", " field",
" of", " artificial", " intelligence", "."
]
# Streaming chunks to send
chunks = [
"Machine", " learning", " is", " a", " field",
" of", " artificial", " intelligence", "."
]
for i, chunk_text in enumerate(chunks):
is_final = (i == len(chunks) - 1)
response = TextCompletionResponse(
response=chunk_text,
error=None,
end_of_stream=is_final
)
final = await recipient(response)
if final:
break
# Final empty chunk
await recipient(TextCompletionResponse(
response="",
async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
"""Simulate streaming text completion via text_completion_stream"""
for i, chunk_text in enumerate(chunks):
response = TextCompletionResponse(
response=chunk_text,
error=None,
end_of_stream=True
))
end_of_stream=False
)
await handler(response)
text_completion_client.request = streaming_request
# Send final empty chunk with end_of_stream
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
return TextCompletionResult(
text=None,
in_token=10,
out_token=9,
model="test-model",
)
async def non_streaming_text_completion(system, prompt, timeout=600):
"""Simulate non-streaming text completion"""
full_text = "Machine learning is a field of artificial intelligence."
return TextCompletionResult(
text=full_text,
in_token=10,
out_token=9,
model="test-model",
)
text_completion_client.text_completion_stream = AsyncMock(
side_effect=streaming_text_completion_stream
)
text_completion_client.text_completion = AsyncMock(
side_effect=non_streaming_text_completion
)
# Mock response producer
response_producer = AsyncMock()
@ -156,14 +175,6 @@ class TestPromptStreaming:
consumer = MagicMock()
# Mock non-streaming text completion
text_completion_client = mock_flow_context_streaming("text-completion-request")
async def non_streaming_text_completion(system, prompt, streaming=False):
return "AI is the simulation of human intelligence in machines."
text_completion_client.text_completion = non_streaming_text_completion
# Act
await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming
@ -218,17 +229,12 @@ class TestPromptStreaming:
# Mock text completion client that raises an error
text_completion_client = AsyncMock()
async def failing_request(request, recipient=None, timeout=600):
if recipient:
# Send error response with proper Error schema
error_response = TextCompletionResponse(
response="",
error=Error(message="Text completion error", type="processing_error"),
end_of_stream=True
)
await recipient(error_response)
async def failing_stream(system, prompt, handler, timeout=600):
raise RuntimeError("Text completion error")
text_completion_client.request = failing_request
text_completion_client.text_completion_stream = AsyncMock(
side_effect=failing_stream
)
# Mock response producer to capture error response
response_producer = AsyncMock()
@ -255,22 +261,15 @@ class TestPromptStreaming:
consumer = MagicMock()
# Act - The service catches errors and sends error responses, doesn't raise
# Act - The service catches errors and sends an error PromptResponse
await prompt_processor_streaming.on_request(message, consumer, context)
# Assert - Verify error response was sent
assert response_producer.send.call_count > 0
# Check that at least one response contains an error
error_sent = False
for call in response_producer.send.call_args_list:
response = call.args[0]
if hasattr(response, 'error') and response.error:
error_sent = True
assert "Text completion error" in response.error.message
break
assert error_sent, "Expected error response to be sent"
# Assert - error response was sent
calls = response_producer.send.call_args_list
assert len(calls) > 0
error_response = calls[-1].args[0]
assert error_response.error is not None
assert "Text completion error" in error_response.error.message
@pytest.mark.asyncio
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
@ -315,21 +314,22 @@ class TestPromptStreaming:
# Mock text completion that sends empty chunks
text_completion_client = AsyncMock()
async def empty_streaming_request(request, recipient=None, timeout=600):
if request.streaming and recipient:
# Send empty chunk followed by final marker
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=False
))
await recipient(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
async def empty_streaming(system, prompt, handler, timeout=600):
# Send empty chunk followed by final marker
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=False
))
await handler(TextCompletionResponse(
response="",
error=None,
end_of_stream=True
))
text_completion_client.request = empty_streaming_request
text_completion_client.text_completion_stream = AsyncMock(
side_effect=empty_streaming
)
response_producer = AsyncMock()
def context_router(service_name):
@ -401,4 +401,4 @@ class TestPromptStreaming:
# Verify chunks concatenate to expected result
full_text = "".join(chunk_texts)
assert full_text == "Machine learning is a field of artificial intelligence"
assert full_text == "Machine learning is a field of artificial intelligence."

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
from trustgraph.base import PromptResult
class TestGraphRagStreamingProtocol:
@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
# Edge selection returns empty (no edges selected)
return ""
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
await chunk_callback(" answer", False)
await chunk_callback(" is here.", False)
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
return "" # Return value not used since callback handles everything
return PromptResult(response_type="text", text="")
else:
return "The answer is here."
return ""
return PromptResult(response_type="text", text="The answer is here.")
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect
return client
@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol:
await chunk_callback("Document", False)
await chunk_callback(" summary", False)
await chunk_callback(".", True) # Non-empty final chunk
return ""
return PromptResult(response_type="text", text="")
else:
return "Document summary."
return PromptResult(response_type="text", text="Document summary.")
client.document_prompt.side_effect = document_prompt_side_effect
# Mock prompt() for extract-concepts call in DocumentRag
client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases:
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
return ""
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
await chunk_callback("text", False)
await chunk_callback("", False) # Empty but not final
await chunk_callback("more", False)
await chunk_callback("", True) # Empty and final
return ""
return PromptResult(response_type="text", text="")
else:
return "textmore"
return ""
return PromptResult(response_type="text", text="textmore")
return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_with_empties

View file

@ -4,14 +4,11 @@ python_paths = .
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
addopts =
-v
--tb=short
--strict-markers
--disable-warnings
--cov=trustgraph
--cov-report=html
--cov-report=term-missing
# --cov-fail-under=80
asyncio_mode = auto
markers =

View file

@ -78,10 +78,10 @@ class TestAgentServiceNonStreaming:
# Filter out explain events — those are always sent now
content_responses = [
r for r in sent_responses if r.chunk_type != "explain"
r for r in sent_responses if r.message_type != "explain"
]
explain_responses = [
r for r in sent_responses if r.chunk_type == "explain"
r for r in sent_responses if r.message_type == "explain"
]
# Should have explain events for session, iteration, observation, and final
@ -93,7 +93,7 @@ class TestAgentServiceNonStreaming:
# Check thought message
thought_response = content_responses[0]
assert isinstance(thought_response, AgentResponse)
assert thought_response.chunk_type == "thought"
assert thought_response.message_type == "thought"
assert thought_response.content == "I need to solve this."
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
@ -101,7 +101,7 @@ class TestAgentServiceNonStreaming:
# Check observation message
observation_response = content_responses[1]
assert isinstance(observation_response, AgentResponse)
assert observation_response.chunk_type == "observation"
assert observation_response.message_type == "observation"
assert observation_response.content == "The answer is 4."
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
@ -168,10 +168,10 @@ class TestAgentServiceNonStreaming:
# Filter out explain events — those are always sent now
content_responses = [
r for r in sent_responses if r.chunk_type != "explain"
r for r in sent_responses if r.message_type != "explain"
]
explain_responses = [
r for r in sent_responses if r.chunk_type == "explain"
r for r in sent_responses if r.message_type == "explain"
]
# Should have explain events for session and final
@ -183,7 +183,7 @@ class TestAgentServiceNonStreaming:
# Check final answer message
answer_response = content_responses[0]
assert isinstance(answer_response, AgentResponse)
assert answer_response.chunk_type == "answer"
assert answer_response.message_type == "answer"
assert answer_response.content == "4"
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"

View file

@ -29,7 +29,7 @@ class TestThinkCallbackMessageId:
assert len(responses) == 1
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "thought"
assert responses[0].message_type == "thought"
@pytest.mark.asyncio
async def test_non_streaming_think_has_message_id(self, pattern):
@ -58,7 +58,7 @@ class TestObserveCallbackMessageId:
await observe("result", is_final=True)
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "observation"
assert responses[0].message_type == "observation"
class TestAnswerCallbackMessageId:
@ -74,7 +74,7 @@ class TestAnswerCallbackMessageId:
await answer("the answer")
assert responses[0].message_id == msg_id
assert responses[0].chunk_type == "answer"
assert responses[0].message_type == "answer"
@pytest.mark.asyncio
async def test_no_message_id_default(self, pattern):

View file

@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.meta_router import (
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
)
from trustgraph.base import PromptResult
def _make_config(patterns=None, task_types=None):
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
def _make_context(prompt_response):
"""Build a mock context that returns a mock prompt client."""
client = AsyncMock()
client.prompt = AsyncMock(return_value=prompt_response)
client.prompt = AsyncMock(
return_value=PromptResult(response_type="text", text=prompt_response)
)
def context(service_name):
return client
@ -274,8 +277,8 @@ class TestRoute:
nonlocal call_count
call_count += 1
if call_count == 1:
return "research" # task type
return "plan-then-execute" # pattern
return PromptResult(response_type="text", text="research")
return PromptResult(response_type="text", text="plan-then-execute")
client.prompt = mock_prompt
context = lambda name: client

View file

@ -18,6 +18,7 @@ from dataclasses import dataclass, field
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
)
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -68,7 +69,7 @@ def collect_explain_events(respond_mock):
events = []
for call in respond_mock.call_args_list:
resp = call[0][0]
if isinstance(resp, AgentResponse) and resp.chunk_type == "explain":
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
events.append({
"explain_id": resp.explain_id,
"explain_graph": resp.explain_graph,
@ -183,7 +184,7 @@ class TestReactPatternProvenance:
)
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
# Simulate the on_action callback before returning Final
if on_action:
await on_action(Action(
@ -267,7 +268,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
if on_action:
await on_action(action)
return action
@ -309,7 +310,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
context, streaming, on_action):
context, streaming, on_action, **kwargs):
if on_action:
await on_action(Action(
thought="done", name="final",
@ -355,10 +356,13 @@ class TestPlanPatternProvenance:
# Mock prompt client for plan creation
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=[
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
],
)
def flow_factory(name):
if name == "prompt-request":
@ -418,10 +422,13 @@ class TestPlanPatternProvenance:
# Mock prompt for step execution
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = {
"tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
}
mock_prompt_client.prompt.return_value = PromptResult(
response_type="json",
object={
"tool": "knowledge-query",
"arguments": {"question": "quantum computing"},
},
)
def flow_factory(name):
if name == "prompt-request":
@ -475,7 +482,7 @@ class TestPlanPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The synthesised answer."
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
def flow_factory(name):
if name == "prompt-request":
@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance:
# Mock prompt for decomposition
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = [
"What is quantum computing?",
"What are qubits?",
]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=[
"What is quantum computing?",
"What are qubits?",
],
)
def flow_factory(name):
if name == "prompt-request":
@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = "The combined answer."
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
def flow_factory(name):
if name == "prompt-request":
@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance:
flow = make_mock_flow()
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=["Goal A", "Goal B", "Goal C"],
)
def flow_factory(name):
if name == "prompt-request":

View file

@ -20,7 +20,7 @@ class TestParseChunkMessageId:
def test_thought_message_id(self, client):
resp = {
"chunk_type": "thought",
"message_type": "thought",
"content": "thinking...",
"end_of_message": False,
"message_id": "urn:trustgraph:agent:sess/i1/thought",
@ -31,7 +31,7 @@ class TestParseChunkMessageId:
def test_observation_message_id(self, client):
resp = {
"chunk_type": "observation",
"message_type": "observation",
"content": "result",
"end_of_message": True,
"message_id": "urn:trustgraph:agent:sess/i1/observation",
@ -42,7 +42,7 @@ class TestParseChunkMessageId:
def test_answer_message_id(self, client):
resp = {
"chunk_type": "answer",
"message_type": "answer",
"content": "the answer",
"end_of_message": False,
"end_of_dialog": False,
@ -54,7 +54,7 @@ class TestParseChunkMessageId:
def test_thought_missing_message_id(self, client):
resp = {
"chunk_type": "thought",
"message_type": "thought",
"content": "thinking...",
"end_of_message": False,
}
@ -64,7 +64,7 @@ class TestParseChunkMessageId:
def test_answer_missing_message_id(self, client):
resp = {
"chunk_type": "answer",
"message_type": "answer",
"content": "answer",
"end_of_message": True,
"end_of_dialog": True,

View file

@ -7,6 +7,14 @@ from trustgraph.base import metrics
@pytest.fixture(autouse=True)
def reset_metric_singletons():
"""Temporarily remove metric singletons so each test can inject mocks.
Saves any existing class-level metrics and restores them after the test
so that later tests in the same process still find the hasattr() guard
intact deleting without restoring causes every subsequent Processor()
construction to re-register the same Prometheus metric name, which raises
ValueError: Duplicated timeseries.
"""
classes_and_attrs = {
metrics.ConsumerMetrics: [
"state_metric",
@ -23,18 +31,24 @@ def reset_metric_singletons():
],
}
saved = {}
for cls, attrs in classes_and_attrs.items():
for attr in attrs:
if hasattr(cls, attr):
saved[(cls, attr)] = getattr(cls, attr)
delattr(cls, attr)
yield
# Remove anything the test may have set, then restore originals
for cls, attrs in classes_and_attrs.items():
for attr in attrs:
if hasattr(cls, attr):
delattr(cls, attr)
for (cls, attr), value in saved.items():
setattr(cls, attr, value)
def test_consumer_metrics_reuses_singletons_and_records_events(monkeypatch):
enum_factory = MagicMock()

View file

@ -236,6 +236,10 @@ async def test_subscriber_graceful_shutdown():
with patch.object(subscriber, 'run') as mock_run:
# Mock run that simulates graceful shutdown
async def mock_run_graceful():
# Honor the readiness contract: real run() signals _ready
# after binding the consumer, so start() can unblock. Mocks
# of run() must do the same or start() hangs forever.
subscriber._ready.set_result(None)
# Process messages while running, then drain
while subscriber.running or subscriber.draining:
if subscriber.draining:
@ -337,6 +341,8 @@ async def test_subscriber_pending_acks_cleanup():
with patch.object(subscriber, 'run') as mock_run:
# Mock run that simulates cleanup of pending acks
async def mock_run_cleanup():
# Honor the readiness contract — see test_subscriber_graceful_shutdown.
subscriber._ready.set_result(None)
while subscriber.running or subscriber.draining:
await asyncio.sleep(0.05)
if subscriber.draining:
@ -406,4 +412,4 @@ async def test_subscriber_multiple_subscribers():
msg1 = await queue1.get()
msg_all = await queue_all.get()
assert msg1 == {"data": "broadcast"}
assert msg_all == {"data": "broadcast"}
assert msg_all == {"data": "broadcast"}

View file

@ -0,0 +1,189 @@
"""
Regression tests for Subscriber.start() readiness barrier.
Background: prior to the eager-connect fix, Subscriber.start() created
the run() task and returned immediately. The underlying backend consumer
was lazily connected on its first receive() call, which left a setup
race for request/response clients using ephemeral per-subscriber response
queues (RabbitMQ auto-delete exclusive queues): the request would be
published before the response queue was bound, and the broker would
silently drop the reply. fetch_config(), document-embeddings, and
api-gateway all hit this with "Failed to fetch config on notify" /
"Request timeout exception" symptoms.
These tests pin the readiness contract:
await subscriber.start()
# at this point, consumer.ensure_connected() MUST have run
so that any future change which removes the eager bind, or moves it
back to lazy initialisation, fails CI loudly.
"""
import asyncio
import pytest
from unittest.mock import MagicMock
from trustgraph.base.subscriber import Subscriber
def _make_backend(ensure_connected_side_effect=None,
receive_side_effect=None):
"""Build a fake backend whose consumer records ensure_connected /
receive calls. ensure_connected_side_effect lets a test inject a
delay or exception."""
backend = MagicMock()
consumer = MagicMock()
consumer.ensure_connected = MagicMock(
side_effect=ensure_connected_side_effect,
)
# By default receive raises a timeout-style exception that the
# subscriber loop is supposed to swallow as a "no message yet" — this
# keeps the subscriber idling cleanly while the test inspects state.
if receive_side_effect is None:
receive_side_effect = TimeoutError("No message received within timeout")
consumer.receive = MagicMock(side_effect=receive_side_effect)
consumer.acknowledge = MagicMock()
consumer.negative_acknowledge = MagicMock()
consumer.pause_message_listener = MagicMock()
consumer.unsubscribe = MagicMock()
consumer.close = MagicMock()
backend.create_consumer.return_value = consumer
return backend, consumer
def _make_subscriber(backend):
return Subscriber(
backend=backend,
topic="response:tg:config",
subscription="test-sub",
consumer_name="test-consumer",
schema=dict,
max_size=10,
drain_timeout=1.0,
backpressure_strategy="block",
)
class TestSubscriberReadiness:
@pytest.mark.asyncio
async def test_start_calls_ensure_connected_before_returning(self):
"""The barrier: ensure_connected must have been invoked at least
once by the time start() returns."""
backend, consumer = _make_backend()
subscriber = _make_subscriber(backend)
await subscriber.start()
try:
consumer.ensure_connected.assert_called_once()
finally:
await subscriber.stop()
@pytest.mark.asyncio
async def test_start_blocks_until_ensure_connected_completes(self):
"""If ensure_connected is slow, start() must wait for it. This is
the actual race-condition guard it would have failed against
the buggy version where start() returned before run() had even
scheduled the consumer creation."""
connect_started = asyncio.Event()
release_connect = asyncio.Event()
# ensure_connected runs in the executor thread, so we need a
# threading-safe gate. Use a simple busy-wait on a flag set by
# the asyncio side via call_soon_threadsafe — but the simpler
# path is to give it a sleep and observe ordering.
import threading
gate = threading.Event()
def slow_connect():
connect_started.set() # safe: only mutates the Event flag
gate.wait(timeout=2.0)
backend, consumer = _make_backend(
ensure_connected_side_effect=slow_connect,
)
subscriber = _make_subscriber(backend)
start_task = asyncio.create_task(subscriber.start())
# Wait until ensure_connected has begun executing.
await asyncio.wait_for(connect_started.wait(), timeout=2.0)
# ensure_connected is in flight — start() must NOT have returned.
assert not start_task.done(), (
"start() returned before ensure_connected() completed — "
"the readiness barrier is broken and the request/response "
"race condition is back."
)
# Release the gate; start() should now complete promptly.
gate.set()
await asyncio.wait_for(start_task, timeout=2.0)
consumer.ensure_connected.assert_called_once()
await subscriber.stop()
@pytest.mark.asyncio
async def test_start_propagates_consumer_creation_failure(self):
"""If create_consumer() raises, start() must surface the error
rather than hang on the readiness future. The old code path
retried indefinitely inside run() and never let start() unblock."""
backend = MagicMock()
backend.create_consumer.side_effect = RuntimeError("broker down")
subscriber = _make_subscriber(backend)
with pytest.raises(RuntimeError, match="broker down"):
await asyncio.wait_for(subscriber.start(), timeout=2.0)
@pytest.mark.asyncio
async def test_start_propagates_ensure_connected_failure(self):
"""Same contract for an ensure_connected() that raises (e.g. the
broker is up but the queue declare/bind fails)."""
backend, consumer = _make_backend(
ensure_connected_side_effect=RuntimeError("queue declare failed"),
)
subscriber = _make_subscriber(backend)
with pytest.raises(RuntimeError, match="queue declare failed"):
await asyncio.wait_for(subscriber.start(), timeout=2.0)
@pytest.mark.asyncio
async def test_ensure_connected_runs_before_subscriber_running_log(self):
"""Subtle ordering: ensure_connected MUST happen before the
receive loop, so that any reply is captured. We assert this by
checking ensure_connected was called before any receive call."""
call_order = []
def record_ensure():
call_order.append("ensure_connected")
def record_receive(*args, **kwargs):
call_order.append("receive")
raise TimeoutError("No message received within timeout")
backend, consumer = _make_backend(
ensure_connected_side_effect=record_ensure,
receive_side_effect=record_receive,
)
subscriber = _make_subscriber(backend)
await subscriber.start()
# Give the receive loop a tick to run at least once.
await asyncio.sleep(0.05)
await subscriber.stop()
# ensure_connected must come first; receive may not have happened
# yet on a fast machine, but if it did, it must come after.
assert call_order, "neither ensure_connected nor receive was called"
assert call_order[0] == "ensure_connected"

View file

@ -70,11 +70,12 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
# Flow exposes parameter lookup via __call__: flow("chunk-size")
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 2000, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -105,10 +106,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 200 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -139,10 +140,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 1500, # Override chunk size
"chunk-overlap": 150 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -195,15 +196,15 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
# Flow.__call__ resolves parameters and producers/consumers from the
# same dict — merge both kinds here.
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 1500,
"chunk-overlap": 150,
}.get(param)
mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
}.get(name)
}.get(key)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@ -241,7 +242,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.return_value = None # No overrides
mock_flow.side_effect = lambda key: None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(

View file

@ -70,11 +70,12 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
# Flow exposes parameter lookup via __call__: flow("chunk-size")
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 400, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -105,10 +106,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 25 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -139,10 +140,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 350, # Override chunk size
"chunk-overlap": 30 # Override chunk overlap
}.get(param)
}.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@ -195,15 +196,15 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
# Flow.__call__ resolves parameters and producers/consumers from the
# same dict — merge both kinds here.
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
mock_flow.side_effect = lambda key: {
"chunk-size": 400,
"chunk-overlap": 40,
}.get(param)
mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
}.get(name)
}.get(key)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@ -245,7 +246,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
mock_flow.parameters.get.return_value = None # No overrides
mock_flow.side_effect = lambda key: None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.definitions.extract import (
Processor, default_triples_batch_size, default_entity_batch_size,
)
from trustgraph.base import PromptResult
from trustgraph.schema import (
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
)
@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_ecs_pub = AsyncMock()
mock_prompt_client = AsyncMock()
if isinstance(prompt_result, list):
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
else:
wrapped = PromptResult(response_type="text", text=prompt_result)
mock_prompt_client.extract_definitions = AsyncMock(
return_value=prompt_result
return_value=wrapped
)
def flow(name):

View file

@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
from trustgraph.schema import (
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
)
from trustgraph.base import PromptResult
# ---------------------------------------------------------------------------
@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_prompt_client = AsyncMock()
mock_prompt_client.extract_relationships = AsyncMock(
return_value=prompt_result
return_value=PromptResult(
response_type="jsonl",
objects=prompt_result,
)
)
def flow(name):

View file

@ -0,0 +1,418 @@
"""
Round-trip unit tests for the core msgpack import/export gateway endpoints.
The kg-core export endpoint receives KnowledgeResponse-shaped dicts from
the responder callback and packs them into msgpack tuples. The kg-core
import endpoint takes msgpack tuples back off the wire and rebuilds
KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor
(whose translator decodes them into real dataclasses).
Regression coverage: the previous wire format used `"vectors"` (plural)
in the entity blobs and embedded a stale `"m"` field that referenced the
removed `Metadata.metadata` triples-list field. The export side hit a
KeyError on first message; the import side built dicts that the
KnowledgeRequestTranslator (separately fixed) couldn't decode. These
tests pin both halves of the wire protocol.
"""
import msgpack
import pytest
from unittest.mock import AsyncMock, Mock, patch
from trustgraph.gateway.dispatch.core_export import CoreExport
from trustgraph.gateway.dispatch.core_import import CoreImport
# ---------------------------------------------------------------------------
# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator
# would emit). The vector wire key is *singular* on purpose; the export
# side previously read the wrong key and crashed.
# ---------------------------------------------------------------------------
def _ge_response_dict():
return {
"graph-embeddings": {
"metadata": {
"id": "doc-1",
"root": "",
"user": "alice",
"collection": "testcoll",
},
"entities": [
{
"entity": {"t": "i", "i": "http://example.org/alice"},
"vector": [0.1, 0.2, 0.3],
},
{
"entity": {"t": "i", "i": "http://example.org/bob"},
"vector": [0.4, 0.5, 0.6],
},
],
}
}
def _triples_response_dict():
return {
"triples": {
"metadata": {
"id": "doc-1",
"root": "",
"user": "alice",
"collection": "testcoll",
},
"triples": [
{
"s": {"t": "i", "i": "http://example.org/alice"},
"p": {"t": "i", "i": "http://example.org/knows"},
"o": {"t": "i", "i": "http://example.org/bob"},
},
],
}
}
def _make_request(id_="doc-1", user="alice"):
request = Mock()
request.query = {"id": id_, "user": user}
return request
def _make_data_reader(payload: bytes):
"""Mock the aiohttp StreamReader: returns payload once, then EOF."""
chunks = [payload, b""]
data = Mock()
async def fake_read(n):
return chunks.pop(0) if chunks else b""
data.read = fake_read
return data
# ---------------------------------------------------------------------------
# Export side: translator-shaped dict -> msgpack bytes
# ---------------------------------------------------------------------------
class TestCoreExportWireFormat:
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_export_packs_graph_embeddings_with_singular_vector(
self, mock_kr_class,
):
"""The export side must read `ent["vector"]` and emit `v`. The
previous bug was reading `ent["vectors"]` which KeyErrored against
the translator output."""
captured = []
async def fake_kr_process(req_dict, responder):
await responder(_ge_response_dict(), True)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
response = AsyncMock()
async def fake_write(b):
captured.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
ok = AsyncMock(return_value=response)
error = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(),
error=error,
ok=ok,
request=_make_request(),
)
# Did not raise, did not call error()
error.assert_not_called()
assert len(captured) == 1
unpacker = msgpack.Unpacker()
unpacker.feed(captured[0])
items = list(unpacker)
assert len(items) == 1
msg_type, payload = items[0]
assert msg_type == "ge"
# Metadata envelope: only id/user/collection — no stale `m["m"]`.
assert payload["m"] == {
"i": "doc-1",
"u": "alice",
"c": "testcoll",
}
# Entities: each carries the *singular* `v` and the term envelope
assert len(payload["e"]) == 2
assert payload["e"][0]["v"] == [0.1, 0.2, 0.3]
assert payload["e"][1]["v"] == [0.4, 0.5, 0.6]
assert payload["e"][0]["e"]["i"] == "http://example.org/alice"
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_export_packs_triples(self, mock_kr_class):
captured = []
async def fake_kr_process(req_dict, responder):
await responder(_triples_response_dict(), True)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
response = AsyncMock()
async def fake_write(b):
captured.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
ok = AsyncMock(return_value=response)
error = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(), error=error, ok=ok, request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
unpacker = msgpack.Unpacker()
unpacker.feed(captured[0])
items = list(unpacker)
assert len(items) == 1
msg_type, payload = items[0]
assert msg_type == "t"
assert payload["m"] == {
"i": "doc-1",
"u": "alice",
"c": "testcoll",
}
assert len(payload["t"]) == 1
# ---------------------------------------------------------------------------
# Import side: msgpack bytes -> translator-shaped dict
# ---------------------------------------------------------------------------
class TestCoreImportWireFormat:
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
async def test_import_unpacks_graph_embeddings_to_singular_vector(
self, mock_kr_class,
):
"""The import side must build dicts whose entity blobs have the
singular `vector` key that's what the KnowledgeRequestTranslator
decode side reads. Previous bug emitted `vectors`."""
captured = []
async def fake_kr_process(req_dict):
captured.append(req_dict)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
# Build a msgpack tuple matching the new wire format
payload = msgpack.packb((
"ge",
{
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
"e": [
{
"e": {"t": "i", "i": "http://example.org/alice"},
"v": [0.1, 0.2, 0.3],
},
],
},
))
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
error = AsyncMock()
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(payload),
error=error,
ok=ok,
request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
req = captured[0]
assert req["operation"] == "put-kg-core"
assert req["user"] == "alice"
assert req["id"] == "doc-1"
ge = req["graph-embeddings"]
# Metadata envelope must NOT contain a stale `metadata` key
# referencing the removed Metadata.metadata field.
assert "metadata" not in ge["metadata"]
assert ge["metadata"] == {
"id": "doc-1",
"user": "alice",
"collection": "default",
}
# Entity blob carries the singular `vector` key
assert len(ge["entities"]) == 1
ent = ge["entities"][0]
assert ent["vector"] == [0.1, 0.2, 0.3]
assert "vectors" not in ent
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
async def test_import_unpacks_triples(self, mock_kr_class):
captured = []
async def fake_kr_process(req_dict):
captured.append(req_dict)
mock_kr = AsyncMock()
mock_kr.start = AsyncMock()
mock_kr.stop = AsyncMock()
mock_kr.process = fake_kr_process
mock_kr_class.return_value = mock_kr
payload = msgpack.packb((
"t",
{
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
"t": [
{
"s": {"t": "i", "i": "http://example.org/alice"},
"p": {"t": "i", "i": "http://example.org/knows"},
"o": {"t": "i", "i": "http://example.org/bob"},
},
],
},
))
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
error = AsyncMock()
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(payload),
error=error,
ok=ok,
request=_make_request(),
)
error.assert_not_called()
assert len(captured) == 1
req = captured[0]
triples = req["triples"]
assert "metadata" not in triples["metadata"] # no stale field
assert len(triples["triples"]) == 1
# ---------------------------------------------------------------------------
# Full round-trip: export bytes feed directly into import
# ---------------------------------------------------------------------------
class TestCoreImportExportRoundTrip:
"""End-to-end: produce bytes via core_export, consume them via
core_import, and verify the dict that lands at the import-side
translator is structurally equivalent to what went in. This is the
test that catches asymmetries between the two halves."""
@pytest.mark.asyncio
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
async def test_graph_embeddings_round_trip(
self, mock_export_kr_class, mock_import_kr_class,
):
# ----- export side: capture bytes -----
export_bytes = []
async def fake_export_process(req_dict, responder):
await responder(_ge_response_dict(), True)
export_kr = AsyncMock()
export_kr.start = AsyncMock()
export_kr.stop = AsyncMock()
export_kr.process = fake_export_process
mock_export_kr_class.return_value = export_kr
response = AsyncMock()
async def fake_write(b):
export_bytes.append(b)
response.write = fake_write
response.write_eof = AsyncMock()
exporter = CoreExport(backend=Mock())
await exporter.process(
data=Mock(),
error=AsyncMock(),
ok=AsyncMock(return_value=response),
request=_make_request(),
)
assert len(export_bytes) == 1
# ----- import side: feed those bytes back in -----
import_captured = []
async def fake_import_process(req_dict):
import_captured.append(req_dict)
import_kr = AsyncMock()
import_kr.start = AsyncMock()
import_kr.stop = AsyncMock()
import_kr.process = fake_import_process
mock_import_kr_class.return_value = import_kr
importer = CoreImport(backend=Mock())
await importer.process(
data=_make_data_reader(export_bytes[0]),
error=AsyncMock(),
ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())),
request=_make_request(),
)
# ----- verify the dict the importer would hand to the translator -----
assert len(import_captured) == 1
req = import_captured[0]
original = _ge_response_dict()["graph-embeddings"]
ge = req["graph-embeddings"]
# The import side overrides id/user from the URL query (intentional),
# so we only round-trip the entity payload itself.
assert ge["metadata"]["id"] == original["metadata"]["id"]
assert ge["metadata"]["user"] == original["metadata"]["user"]
assert len(ge["entities"]) == len(original["entities"])
for got, want in zip(ge["entities"], original["entities"]):
assert got["vector"] == want["vector"]
assert got["entity"] == want["entity"]

View file

@ -0,0 +1,242 @@
"""
Unit tests for entity contexts import dispatcher.
Tests the business logic of EntityContextsImport while mocking the
Publisher and websocket components.
Regression coverage: a previous version constructed Metadata(metadata=...)
which raised TypeError at runtime as soon as a message was received. These
tests exercise receive() end-to-end so any future schema/kwarg drift in
the Metadata or EntityContexts construction is caught immediately.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport
from trustgraph.schema import EntityContexts, EntityContext, Metadata
@pytest.fixture
def mock_backend():
return Mock()
@pytest.fixture
def mock_running():
running = Mock()
running.get.return_value = True
running.stop = Mock()
return running
@pytest.fixture
def mock_websocket():
ws = Mock()
ws.close = AsyncMock()
return ws
@pytest.fixture
def sample_message():
"""Sample entity-contexts websocket message."""
return {
"metadata": {
"id": "doc-123",
"user": "testuser",
"collection": "testcollection",
},
"entities": [
{
"entity": {"v": "http://example.org/alice", "e": True},
"context": "Alice is a person.",
},
{
"entity": {"v": "http://example.org/bob", "e": True},
"context": "Bob is a person.",
},
],
}
@pytest.fixture
def empty_entities_message():
return {
"metadata": {
"id": "doc-empty",
"user": "u",
"collection": "c",
},
"entities": [],
}
class TestEntityContextsImportInitialization:
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
def test_init_creates_publisher_with_correct_params(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="ec-queue",
)
mock_publisher_class.assert_called_once_with(
mock_backend,
topic="ec-queue",
schema=EntityContexts,
)
assert dispatcher.ws is mock_websocket
assert dispatcher.running is mock_running
assert dispatcher.publisher is instance
class TestEntityContextsImportLifecycle:
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.start = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.start()
instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(
self, mock_publisher_class, mock_backend, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=None, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
class TestEntityContextsImportMessageProcessing:
"""Regression coverage for receive(): catches Metadata/schema drift."""
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_constructs_entity_contexts_correctly(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
# If Metadata or EntityContexts gain/lose kwargs, this raises
# TypeError — exactly the regression we want to catch.
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
call_args = instance.send.call_args
assert call_args[0][0] is None
sent = call_args[0][1]
assert isinstance(sent, EntityContexts)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.user == "testuser"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2
assert all(isinstance(e, EntityContext) for e in sent.entities)
assert sent.entities[0].context == "Alice is a person."
assert sent.entities[1].context == "Bob is a person."
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_entities(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, empty_entities_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = empty_entities_message
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
sent = instance.send.call_args[0][1]
assert isinstance(sent, EntityContexts)
assert sent.entities == []
assert sent.metadata.id == "doc-empty"
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
mock_publisher_class.return_value = instance
dispatcher = EntityContextsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
with pytest.raises(RuntimeError, match="publish failed"):
await dispatcher.receive(mock_msg)

View file

@ -158,7 +158,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="explain",
message_type="explain",
content="",
explain_id="urn:trustgraph:agent:session:abc123",
explain_graph="urn:graph:retrieval",
@ -179,7 +179,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="thought",
message_type="thought",
content="I need to think...",
)
@ -190,7 +190,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="explain",
message_type="explain",
explain_id="urn:trustgraph:agent:session:abc123",
explain_triples=sample_triples(),
end_of_dialog=False,
@ -203,7 +203,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
chunk_type="answer",
message_type="answer",
content="The answer is...",
end_of_dialog=True,
)

View file

@ -0,0 +1,247 @@
"""
Unit tests for graph embeddings import dispatcher.
Tests the business logic of GraphEmbeddingsImport while mocking the
Publisher and websocket components.
Regression coverage: a previous version of EntityContextsImport
constructed Metadata(metadata=...) which raised TypeError at runtime as
soon as a message was received. The same shape of bug can occur here, so
these tests exercise receive() end-to-end to catch any future schema or
kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport
from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata
@pytest.fixture
def mock_backend():
return Mock()
@pytest.fixture
def mock_running():
running = Mock()
running.get.return_value = True
running.stop = Mock()
return running
@pytest.fixture
def mock_websocket():
ws = Mock()
ws.close = AsyncMock()
return ws
@pytest.fixture
def sample_message():
"""Sample graph-embeddings websocket message."""
return {
"metadata": {
"id": "doc-123",
"user": "testuser",
"collection": "testcollection",
},
"entities": [
{
"entity": {"v": "http://example.org/alice", "e": True},
"vector": [0.1, 0.2, 0.3],
},
{
"entity": {"v": "http://example.org/bob", "e": True},
"vector": [0.4, 0.5, 0.6],
},
],
}
@pytest.fixture
def empty_entities_message():
return {
"metadata": {
"id": "doc-empty",
"user": "u",
"collection": "c",
},
"entities": [],
}
class TestGraphEmbeddingsImportInitialization:
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
def test_init_creates_publisher_with_correct_params(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket,
running=mock_running,
backend=mock_backend,
queue="ge-queue",
)
mock_publisher_class.assert_called_once_with(
mock_backend,
topic="ge-queue",
schema=GraphEmbeddings,
)
assert dispatcher.ws is mock_websocket
assert dispatcher.running is mock_running
assert dispatcher.publisher is instance
class TestGraphEmbeddingsImportLifecycle:
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_start_calls_publisher_start(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.start = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.start()
instance.start.assert_called_once()
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_stops_and_closes_properly(
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
mock_websocket.close.assert_called_once()
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_destroy_handles_none_websocket(
self, mock_publisher_class, mock_backend, mock_running
):
instance = Mock()
instance.stop = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=None, running=mock_running,
backend=mock_backend, queue="q",
)
await dispatcher.destroy()
mock_running.stop.assert_called_once()
instance.stop.assert_called_once()
class TestGraphEmbeddingsImportMessageProcessing:
"""Regression coverage for receive(): catches Metadata/schema drift."""
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_constructs_graph_embeddings_correctly(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
# If Metadata, GraphEmbeddings, or EntityEmbeddings gain/lose
# kwargs, this raises TypeError — exactly the regression we want
# to catch.
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
call_args = instance.send.call_args
assert call_args[0][0] is None
sent = call_args[0][1]
assert isinstance(sent, GraphEmbeddings)
assert isinstance(sent.metadata, Metadata)
assert sent.metadata.id == "doc-123"
assert sent.metadata.user == "testuser"
assert sent.metadata.collection == "testcollection"
assert len(sent.entities) == 2
assert all(isinstance(e, EntityEmbeddings) for e in sent.entities)
# Lock in the wire format: incoming "vector" key (singular,
# list[float]) maps to EntityEmbeddings.vector. This mirrors
# serialize_graph_embeddings() on the export side.
assert sent.entities[0].vector == [0.1, 0.2, 0.3]
assert sent.entities[1].vector == [0.4, 0.5, 0.6]
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_handles_empty_entities(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, empty_entities_message,
):
instance = Mock()
instance.send = AsyncMock()
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = empty_entities_message
await dispatcher.receive(mock_msg)
instance.send.assert_called_once()
sent = instance.send.call_args[0][1]
assert isinstance(sent, GraphEmbeddings)
assert sent.entities == []
assert sent.metadata.id == "doc-empty"
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
@pytest.mark.asyncio
async def test_receive_propagates_publisher_errors(
self, mock_publisher_class, mock_backend, mock_websocket,
mock_running, sample_message,
):
instance = Mock()
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
mock_publisher_class.return_value = instance
dispatcher = GraphEmbeddingsImport(
ws=mock_websocket, running=mock_running,
backend=mock_backend, queue="q",
)
mock_msg = Mock()
mock_msg.json.return_value = sample_message
with pytest.raises(RuntimeError, match="publish failed"):
await dispatcher.receive(mock_msg)

View file

@ -171,6 +171,14 @@ class TestApi:
patch('aiohttp.web.run_app') as mock_run_app:
mock_get_pubsub.return_value = Mock()
# Api.run() passes self.app_factory() — a coroutine — to
# web.run_app, which would normally consume it inside its own
# event loop. Since we mock run_app, close the coroutine here
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
def _consume_coro(coro, **kwargs):
coro.close()
mock_run_app.side_effect = _consume_coro
api = Api(port=8080)
api.run()

View file

@ -0,0 +1,592 @@
"""
DAG structure tests for provenance chains.
Verifies that the wasDerivedFrom chain has the expected shape for each
service. These tests catch structural regressions when new entities are
inserted into the chain (e.g. PatternDecision between session and first
iteration).
Expected chains:
GraphRAG: question grounding exploration focus synthesis
DocumentRAG: question grounding exploration synthesis
Agent React: session pattern-decision iteration (observation iteration)* final
Agent Plan: session pattern-decision plan step-result(s) synthesis
Agent Super: session pattern-decision decomposition (fan-out) finding(s) synthesis
"""
import json
import uuid
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
Triple, Term, IRI, LITERAL,
)
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL,
TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION,
TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION,
TG_OBSERVATION_TYPE,
TG_PATTERN, TG_TASK_TYPE,
)
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _collect_events(events):
"""Build a dict of explain_id → {types, derived_from, triples}."""
result = {}
for ev in events:
eid = ev["explain_id"]
triples = ev["triples"]
types = {
t.o.iri for t in triples
if t.s.iri == eid and t.p.iri == RDF_TYPE
}
parents = [
t.o.iri for t in triples
if t.s.iri == eid and t.p.iri == PROV_WAS_DERIVED_FROM
]
result[eid] = {
"types": types,
"derived_from": parents[0] if parents else None,
"triples": triples,
}
return result
def _find_by_type(dag, rdf_type):
"""Find all event IDs that have the given rdf:type."""
return [eid for eid, info in dag.items() if rdf_type in info["types"]]
def _assert_chain(dag, chain_types):
"""Assert that a linear wasDerivedFrom chain exists through the given types."""
for i in range(1, len(chain_types)):
parent_type = chain_types[i - 1]
child_type = chain_types[i]
parents = _find_by_type(dag, parent_type)
children = _find_by_type(dag, child_type)
assert parents, f"No entity with type {parent_type}"
assert children, f"No entity with type {child_type}"
# At least one child must derive from at least one parent
linked = False
for child_id in children:
derived = dag[child_id]["derived_from"]
if derived in parents:
linked = True
break
assert linked, (
f"No {child_type} derives from {parent_type}. "
f"Children derive from: "
f"{[dag[c]['derived_from'] for c in children]}"
)
# ---------------------------------------------------------------------------
# GraphRAG DAG structure
# ---------------------------------------------------------------------------
class TestGraphRagDagStructure:
"""Verify: question → grounding → exploration → focus → synthesis"""
@pytest.fixture
def mock_clients(self):
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
graph_embeddings_client = AsyncMock()
triples_client = AsyncMock()
embeddings_client.embed.return_value = [[0.1, 0.2]]
graph_embeddings_client.query.return_value = [
MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")),
]
triples_client.query_stream.return_value = [
Triple(
s=Term(type=IRI, iri="http://example.com/e1"),
p=Term(type=IRI, iri="http://example.com/p"),
o=Term(type=LITERAL, value="value"),
)
]
triples_client.query.return_value = []
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return PromptResult(response_type="text", text="concept")
elif template_id == "kg-edge-scoring":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "score": 10} for e in edges],
)
elif template_id == "kg-edge-reasoning":
edges = variables.get("knowledge", [])
return PromptResult(
response_type="jsonl",
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
)
elif template_id == "kg-synthesis":
return PromptResult(response_type="text", text="Answer.")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
@pytest.mark.asyncio
async def test_dag_chain(self, mock_clients):
rag = GraphRag(*mock_clients)
events = []
async def explain_cb(triples, explain_id):
events.append({"explain_id": explain_id, "triples": triples})
await rag.query(
query="test", explain_callback=explain_cb, edge_score_limit=0,
)
dag = _collect_events(events)
assert len(dag) == 5, f"Expected 5 events, got {len(dag)}"
_assert_chain(dag, [
TG_GRAPH_RAG_QUESTION,
TG_GROUNDING,
TG_EXPLORATION,
TG_FOCUS,
TG_SYNTHESIS,
])
# ---------------------------------------------------------------------------
# DocumentRAG DAG structure
# ---------------------------------------------------------------------------
class TestDocumentRagDagStructure:
"""Verify: question → grounding → exploration → synthesis"""
@pytest.fixture
def mock_clients(self):
from trustgraph.schema import ChunkMatch
prompt_client = AsyncMock()
embeddings_client = AsyncMock()
doc_embeddings_client = AsyncMock()
fetch_chunk = AsyncMock(return_value="Chunk content.")
embeddings_client.embed.return_value = [[0.1, 0.2]]
doc_embeddings_client.query.return_value = [
ChunkMatch(chunk_id="doc/c1", score=0.9),
]
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return PromptResult(response_type="text", text="concept")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
prompt_client.document_prompt.return_value = PromptResult(
response_type="text", text="Answer.",
)
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
@pytest.mark.asyncio
async def test_dag_chain(self, mock_clients):
rag = DocumentRag(*mock_clients)
events = []
async def explain_cb(triples, explain_id):
events.append({"explain_id": explain_id, "triples": triples})
await rag.query(
query="test", explain_callback=explain_cb,
)
dag = _collect_events(events)
assert len(dag) == 4, f"Expected 4 events, got {len(dag)}"
_assert_chain(dag, [
TG_DOC_RAG_QUESTION,
TG_GROUNDING,
TG_EXPLORATION,
TG_SYNTHESIS,
])
# ---------------------------------------------------------------------------
# Agent DAG structure — tested via service.agent_request()
# ---------------------------------------------------------------------------
def _make_processor(tools=None):
processor = MagicMock()
processor.max_iterations = 10
processor.save_answer_content = AsyncMock()
def mock_session_uri(sid):
return f"urn:trustgraph:agent:session:{sid}"
processor.provenance_session_uri.side_effect = mock_session_uri
agent = MagicMock()
agent.tools = tools or {}
agent.additional_context = ""
processor.agent = agent
processor.aggregator = MagicMock()
return processor
def _make_flow():
producers = {}
def factory(name):
if name not in producers:
producers[name] = AsyncMock()
return producers[name]
flow = MagicMock(side_effect=factory)
return flow
def _collect_agent_events(respond_mock):
events = []
for call in respond_mock.call_args_list:
resp = call[0][0]
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
events.append({
"explain_id": resp.explain_id,
"triples": resp.explain_triples,
})
return events
class TestAgentReactDagStructure:
"""
Via service.agent_request(), full two-iteration react chain:
session pattern-decision iteration(1) observation(1) final
Iteration 1: tool call observation
Iteration 2: final answer
"""
def _make_service(self):
from trustgraph.agent.orchestrator.service import Processor
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
mock_tool = MagicMock()
mock_tool.name = "lookup"
mock_tool.description = "Look things up"
mock_tool.arguments = []
mock_tool.groups = []
mock_tool.states = {}
mock_tool_impl = AsyncMock(return_value="42")
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
processor = _make_processor(tools={"lookup": mock_tool})
service = Processor.__new__(Processor)
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
service.plan_pattern = PlanThenExecutePattern(service)
service.supervisor_pattern = SupervisorPattern(service)
service.meta_router = None
return service
@pytest.mark.asyncio
async def test_dag_chain(self):
from trustgraph.agent.react.types import Action, Final
service = self._make_service()
respond = AsyncMock()
next_fn = AsyncMock()
flow = _make_flow()
session_id = str(uuid.uuid4())
# Iteration 1: tool call → returns Action, triggers on_action + tool exec
action = Action(
thought="I need to look this up",
name="lookup",
arguments={"question": "6x7"},
observation="",
)
with patch(
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
) as MockAM:
mock_am = AsyncMock()
MockAM.return_value = mock_am
async def mock_react_iter1(on_action=None, **kwargs):
if on_action:
await on_action(action)
action.observation = "42"
return action
mock_am.react.side_effect = mock_react_iter1
request1 = AgentRequest(
question="What is 6x7?",
user="testuser",
collection="default",
streaming=False,
session_id=session_id,
pattern="react",
history=[],
)
await service.agent_request(request1, respond, next_fn, flow)
# next_fn should have been called with updated history
assert next_fn.called
# Iteration 2: final answer
final = Final(thought="The answer is 42", final="42")
next_request = next_fn.call_args[0][0]
with patch(
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
) as MockAM:
mock_am = AsyncMock()
MockAM.return_value = mock_am
async def mock_react_iter2(**kwargs):
return final
mock_am.react.side_effect = mock_react_iter2
await service.agent_request(next_request, respond, next_fn, flow)
# Collect and verify DAG
events = _collect_agent_events(respond)
dag = _collect_events(events)
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
analysis_ids = _find_by_type(dag, TG_ANALYSIS)
observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE)
final_ids = _find_by_type(dag, TG_CONCLUSION)
assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}"
assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}"
assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}"
assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}"
assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}"
# Full chain:
# session → pattern-decision
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
# pattern-decision → iteration(1)
assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0]
# iteration(1) → observation(1)
assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0]
# observation(1) → final
assert dag[final_ids[0]]["derived_from"] == observation_ids[0]
class TestAgentPlanDagStructure:
"""
Via service.agent_request():
session pattern-decision plan step-result synthesis
"""
@pytest.mark.asyncio
async def test_dag_chain(self):
from trustgraph.agent.orchestrator.service import Processor
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
# Mock tool
mock_tool = MagicMock()
mock_tool.name = "knowledge-query"
mock_tool.description = "Query KB"
mock_tool.arguments = []
mock_tool.groups = []
mock_tool.states = {}
mock_tool_impl = AsyncMock(return_value="Found it")
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
processor = _make_processor(tools={"knowledge-query": mock_tool})
service = Processor.__new__(Processor)
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
service.plan_pattern = PlanThenExecutePattern(service)
service.supervisor_pattern = SupervisorPattern(service)
service.meta_router = None
respond = AsyncMock()
next_fn = AsyncMock()
flow = _make_flow()
# Mock prompt client
mock_prompt_client = AsyncMock()
call_count = 0
async def mock_prompt(id, variables=None, **kwargs):
nonlocal call_count
call_count += 1
if id == "plan-create":
return PromptResult(
response_type="jsonl",
objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}],
)
elif id == "plan-step-execute":
return PromptResult(
response_type="json",
object={"tool": "knowledge-query", "arguments": {"question": "test"}},
)
elif id == "plan-synthesise":
return PromptResult(response_type="text", text="Final answer.")
return PromptResult(response_type="text", text="")
mock_prompt_client.prompt.side_effect = mock_prompt
def flow_factory(name):
if name == "prompt-request":
return mock_prompt_client
return AsyncMock()
flow.side_effect = flow_factory
session_id = str(uuid.uuid4())
# Iteration 1: planning
request1 = AgentRequest(
question="Test?",
user="testuser",
collection="default",
streaming=False,
session_id=session_id,
pattern="plan-then-execute",
history=[],
)
await service.agent_request(request1, respond, next_fn, flow)
# Iteration 2: execute step (next_fn was called with updated request)
assert next_fn.called
next_request = next_fn.call_args[0][0]
# Iteration 3: all steps done → synthesis
# Simulate completed step in history
next_request.history[-1].plan[0].status = "completed"
next_request.history[-1].plan[0].result = "Found it"
await service.agent_request(next_request, respond, next_fn, flow)
events = _collect_agent_events(respond)
dag = _collect_events(events)
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
plan_ids = _find_by_type(dag, TG_PLAN_TYPE)
synthesis_ids = _find_by_type(dag, TG_SYNTHESIS)
assert len(session_ids) == 1
assert len(pd_ids) == 1
assert len(plan_ids) == 1
assert len(synthesis_ids) == 1
# Chain: session → pattern-decision → plan → ... → synthesis
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
assert dag[plan_ids[0]]["derived_from"] == pd_ids[0]
class TestAgentSupervisorDagStructure:
"""
Via service.agent_request():
session pattern-decision decomposition (fan-out)
"""
@pytest.mark.asyncio
async def test_dag_chain(self):
from trustgraph.agent.orchestrator.service import Processor
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
processor = _make_processor()
service = Processor.__new__(Processor)
service.max_iterations = 10
service.save_answer_content = AsyncMock()
service.provenance_session_uri = processor.provenance_session_uri
service.agent = processor.agent
service.aggregator = processor.aggregator
service.react_pattern = ReactPattern(service)
service.plan_pattern = PlanThenExecutePattern(service)
service.supervisor_pattern = SupervisorPattern(service)
service.meta_router = None
respond = AsyncMock()
next_fn = AsyncMock()
flow = _make_flow()
mock_prompt_client = AsyncMock()
mock_prompt_client.prompt.return_value = PromptResult(
response_type="jsonl",
objects=["Goal A", "Goal B"],
)
def flow_factory(name):
if name == "prompt-request":
return mock_prompt_client
return AsyncMock()
flow.side_effect = flow_factory
request = AgentRequest(
question="Research quantum computing",
user="testuser",
collection="default",
streaming=False,
session_id=str(uuid.uuid4()),
pattern="supervisor",
history=[],
)
await service.agent_request(request, respond, next_fn, flow)
events = _collect_agent_events(respond)
dag = _collect_events(events)
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
decomp_ids = _find_by_type(dag, TG_DECOMPOSITION)
assert len(session_ids) == 1
assert len(pd_ids) == 1
assert len(decomp_ids) == 1
# Chain: session → pattern-decision → decomposition
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0]
# Fan-out should have been called
assert next_fn.call_count == 2 # One per goal

View file

@ -223,7 +223,7 @@ class TestDerivedEntityTriples:
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
def test_chunk_entity_has_chunk_type(self):
def test_chunk_entity_has_message_type(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"chunker", "1.0",

View file

@ -304,14 +304,14 @@ class TestStreamingTypes:
assert chunk.content == "thinking..."
assert chunk.end_of_message is False
assert chunk.chunk_type == "thought"
assert chunk.message_type == "thought"
def test_agent_observation_creation(self):
"""Test creating AgentObservation chunk"""
chunk = AgentObservation(content="observing...", end_of_message=False)
assert chunk.content == "observing..."
assert chunk.chunk_type == "observation"
assert chunk.message_type == "observation"
def test_agent_answer_creation(self):
"""Test creating AgentAnswer chunk"""
@ -324,7 +324,7 @@ class TestStreamingTypes:
assert chunk.content == "answer"
assert chunk.end_of_message is True
assert chunk.end_of_dialog is True
assert chunk.chunk_type == "final-answer"
assert chunk.message_type == "final-answer"
def test_rag_chunk_creation(self):
"""Test creating RAGChunk"""

View file

@ -6,6 +6,7 @@ import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
from trustgraph.base import PromptResult
# Sample chunk content mapping for tests
@ -132,7 +133,7 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client
# Mock the prompt response with concept lines
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns")
query = Query(
rag=mock_rag,
@ -157,7 +158,7 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client
# Mock empty response
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query(
rag=mock_rag,
@ -258,7 +259,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "test concept"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept")
# Mock embeddings - one vector per concept
test_vectors = [[0.1, 0.2, 0.3]]
@ -273,7 +274,7 @@ class TestQuery:
expected_response = "This is the document RAG response"
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
mock_prompt_client.document_prompt.return_value = expected_response
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response)
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -315,7 +316,8 @@ class TestQuery:
assert "Relevant document content" in docs
assert "Another document" in docs
assert result == expected_response
result_text, usage = result
assert result_text == expected_response
@pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
@ -325,7 +327,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction fallback (empty → raw query)
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
@ -333,7 +335,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c5"
mock_match.score = 0.9
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Default response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -352,7 +354,8 @@ class TestQuery:
collection="default" # Default collection
)
assert result == "Default response"
result_text, usage = result
assert result_text == "Default response"
@pytest.mark.asyncio
async def test_get_docs_with_verbose_output(self):
@ -401,7 +404,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "verbose query test"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test")
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
@ -409,7 +412,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c7"
mock_match.score = 0.92
mock_doc_embeddings_client.query.return_value = [mock_match]
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -428,7 +431,8 @@ class TestQuery:
assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"]
assert result == "Verbose RAG response"
result_text, usage = result
assert result_text == "Verbose RAG response"
@pytest.mark.asyncio
async def test_get_docs_with_empty_results(self):
@ -469,11 +473,11 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
mock_prompt_client.prompt.return_value = "query with no matching docs"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs")
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = []
mock_prompt_client.document_prompt.return_value = "No documents found response"
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -490,7 +494,8 @@ class TestQuery:
documents=[]
)
assert result == "No documents found response"
result_text, usage = result
assert result_text == "No documents found response"
@pytest.mark.asyncio
async def test_get_vectors_with_verbose(self):
@ -525,7 +530,7 @@ class TestQuery:
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
# Mock concept extraction
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence")
# Mock embeddings - one vector per concept
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
@ -541,7 +546,7 @@ class TestQuery:
MagicMock(chunk_id="doc/ml3", score=0.82),
]
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
mock_prompt_client.document_prompt.return_value = final_response
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response)
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@ -584,7 +589,8 @@ class TestQuery:
assert "Common ML techniques include supervised and unsupervised learning..." in docs
assert len(docs) == 3 # doc/ml2 deduplicated
assert result == final_response
result_text, usage = result
assert result_text == final_response
@pytest.mark.asyncio
async def test_get_docs_deduplicates_across_concepts(self):

View file

@ -12,6 +12,7 @@ from unittest.mock import AsyncMock
from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -89,8 +90,8 @@ def build_mock_clients():
# 1. Concept extraction
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return "return policy\nrefund"
return ""
return PromptResult(response_type="text", text="return policy\nrefund")
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
@ -113,8 +114,9 @@ def build_mock_clients():
fetch_chunk.side_effect = mock_fetch
# 5. Synthesis
prompt_client.document_prompt.return_value = (
"Items can be returned within 30 days for a full refund."
prompt_client.document_prompt.return_value = PromptResult(
response_type="text",
text="Items can be returned within 30 days for a full refund.",
)
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(
result_text, usage = await rag.query(
query="What is the return policy?",
explain_callback=AsyncMock(),
)
assert result == "Items can be returned within 30 days for a full refund."
assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_no_explain_callback_still_works(self):
@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients()
rag = DocumentRag(*clients)
result = await rag.query(query="What is the return policy?")
assert result == "Items can be returned within 30 days for a full refund."
result_text, usage = await rag.query(query="What is the return policy?")
assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):

View file

@ -34,7 +34,7 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "test response"
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
# Setup message with custom user/collection
msg = MagicMock()
@ -97,7 +97,7 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
mock_rag_instance.query.return_value = "A document about cats."
mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None})
# Setup message with non-streaming request
msg = MagicMock()
@ -130,4 +130,5 @@ class TestDocumentRagService:
assert isinstance(sent_response, DocumentRagResponse)
assert sent_response.response == "A document about cats."
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
assert sent_response.end_of_session is True
assert sent_response.error is None

View file

@ -7,6 +7,7 @@ import unittest.mock
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
from trustgraph.base import PromptResult
class TestGraphRag:
@ -172,7 +173,7 @@ class TestQuery:
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
query = Query(
rag=mock_rag,
@ -196,7 +197,7 @@ class TestQuery:
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query(
rag=mock_rag,
@ -220,7 +221,7 @@ class TestQuery:
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
# extract_concepts returns empty -> falls back to [query]
mock_prompt_client.prompt.return_value = ""
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# embed returns one vector set for the single concept
test_vectors = [[0.1, 0.2, 0.3]]
@ -565,14 +566,14 @@ class TestQuery:
# Mock prompt responses for the multi-step process
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
return "" # Falls back to raw query
return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
return json.dumps({"id": test_edge_id, "score": 0.9})
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
elif prompt_name == "kg-edge-reasoning":
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
elif prompt_name == "kg-synthesis":
return expected_response
return ""
return PromptResult(response_type="text", text=expected_response)
return PromptResult(response_type="text", text="")
mock_prompt_client.prompt = mock_prompt
@ -607,7 +608,8 @@ class TestQuery:
explain_callback=collect_provenance
)
assert response == expected_response
response_text, usage = response
assert response_text == expected_response
# 5 events: question, grounding, exploration, focus, synthesis
assert len(provenance_events) == 5

View file

@ -13,6 +13,7 @@ from dataclasses import dataclass
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@ -136,24 +137,36 @@ def build_mock_clients():
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
return prompt_responses["extract-concepts"]
return PromptResult(
response_type="text",
text=prompt_responses["extract-concepts"],
)
elif template_id == "kg-edge-scoring":
# Score all edges highly, using the IDs that GraphRag computed
edges = variables.get("knowledge", [])
return [
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
]
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "score": 10 - i}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-edge-reasoning":
# Provide reasoning for each edge
edges = variables.get("knowledge", [])
return [
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
]
return PromptResult(
response_type="jsonl",
objects=[
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
for i, e in enumerate(edges)
],
)
elif template_id == "kg-synthesis":
return synthesis_answer
return ""
return PromptResult(
response_type="text",
text=synthesis_answer,
)
return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance:
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
result = await rag.query(
result_text, usage = await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
assert result == "Quantum computing applies physics principles to computation."
assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio
async def test_parent_uri_links_question_to_parent(self):
@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance:
clients = build_mock_clients()
rag = GraphRag(*clients)
result = await rag.query(
result_text, usage = await rag.query(
query="What is quantum computing?",
edge_score_limit=0,
)
assert result == "Quantum computing applies physics principles to computation."
assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):

View file

@ -44,7 +44,7 @@ class TestGraphRagService:
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
await explain_callback([], "urn:trustgraph:prov:selection:test")
await explain_callback([], "urn:trustgraph:prov:answer:test")
return "A small domesticated mammal."
return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query
@ -79,8 +79,8 @@ class TestGraphRagService:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
assert mock_response_producer.send.call_count == 6
# Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session)
assert mock_response_producer.send.call_count == 5
# First 4 messages are explain (emitted in real-time during query)
for i in range(4):
@ -88,17 +88,12 @@ class TestGraphRagService:
assert prov_msg.message_type == "explain"
assert prov_msg.explain_id is not None
# 5th message is chunk with response
# 5th message is chunk with response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "A small domesticated mammal."
assert chunk_msg.end_of_stream is True
# 6th message is empty chunk with end_of_session=True
close_msg = mock_response_producer.send.call_args_list[5][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True
assert chunk_msg.end_of_session is True
# Verify provenance triples were sent to provenance queue
assert mock_provenance_producer.send.call_count == 4
@ -187,7 +182,7 @@ class TestGraphRagService:
async def mock_query(**kwargs):
# Don't call explain_callback
return "Response text"
return "Response text", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query
@ -218,17 +213,12 @@ class TestGraphRagService:
# Execute
await processor.on_request(msg, consumer, flow)
# Verify: 2 messages (chunk + empty chunk to close)
assert mock_response_producer.send.call_count == 2
# Verify: 1 combined message (chunk with end_of_session)
assert mock_response_producer.send.call_count == 1
# First is the response chunk
# Single message has response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "Response text"
assert chunk_msg.end_of_stream is True
# Second is empty chunk to close session
close_msg = mock_response_producer.send.call_args_list[1][0][0]
assert close_msg.message_type == "chunk"
assert close_msg.response == ""
assert close_msg.end_of_session is True
assert chunk_msg.end_of_session is True

View file

View file

@ -0,0 +1,197 @@
"""
Unit tests for KnowledgeTableStore row deserialization.
Regression coverage: a previous version of get_graph_embeddings constructed
EntityEmbeddings(vectors=ent[1]) the schema field is `vector` (singular),
so any real Cassandra row would crash on read. These tests bypass the live
Cassandra connection entirely and exercise the row -> schema conversion
with hand-built fake rows.
"""
import pytest
from unittest.mock import Mock
from trustgraph.tables.knowledge import KnowledgeTableStore
from trustgraph.schema import (
EntityEmbeddings,
GraphEmbeddings,
Triples,
Triple,
Metadata,
IRI,
LITERAL,
)
def _make_store():
"""
Build a KnowledgeTableStore without invoking __init__ (which connects
to Cassandra). Tests inject only the attributes the method under test
actually touches.
"""
return KnowledgeTableStore.__new__(KnowledgeTableStore)
class TestGetGraphEmbeddings:
@pytest.mark.asyncio
async def test_row_converts_to_entity_embeddings_with_singular_vector(self):
"""
Cassandra rows return entities as a list of [entity_tuple, vector]
pairs in row[3]. The deserializer must construct EntityEmbeddings
with `vector=` (singular) the schema field name. A previous
version used `vectors=` and TypeError'd at runtime.
"""
# Arrange — fake row matching the get_triples_stmt result shape:
# row[0..2] are unused by the method, row[3] is the entities blob
fake_row = (
None, None, None,
[
# ((value, is_uri), vector)
(("http://example.org/alice", True), [0.1, 0.2, 0.3]),
(("http://example.org/bob", True), [0.4, 0.5, 0.6]),
(("a literal entity", False), [0.7, 0.8, 0.9]),
],
)
store = _make_store()
store.cassandra = Mock()
store.cassandra.execute = Mock(return_value=[fake_row])
store.get_graph_embeddings_stmt = Mock()
received = []
async def receiver(msg):
received.append(msg)
# Act
await store.get_graph_embeddings(
user="alice",
document_id="doc-1",
receiver=receiver,
)
# Assert
store.cassandra.execute.assert_called_once_with(
store.get_graph_embeddings_stmt,
("alice", "doc-1"),
)
assert len(received) == 1
ge = received[0]
assert isinstance(ge, GraphEmbeddings)
assert isinstance(ge.metadata, Metadata)
assert ge.metadata.id == "doc-1"
assert ge.metadata.user == "alice"
assert len(ge.entities) == 3
assert all(isinstance(e, EntityEmbeddings) for e in ge.entities)
# Vectors land in the singular `vector` field — this is the
# explicit regression assertion for the original bug.
assert ge.entities[0].vector == [0.1, 0.2, 0.3]
assert ge.entities[1].vector == [0.4, 0.5, 0.6]
assert ge.entities[2].vector == [0.7, 0.8, 0.9]
# Term type round-trips through tuple_to_term
assert ge.entities[0].entity.type == IRI
assert ge.entities[0].entity.iri == "http://example.org/alice"
assert ge.entities[1].entity.type == IRI
assert ge.entities[1].entity.iri == "http://example.org/bob"
assert ge.entities[2].entity.type == LITERAL
assert ge.entities[2].entity.value == "a literal entity"
@pytest.mark.asyncio
async def test_empty_entities_blob_yields_empty_list(self):
"""row[3] being None / empty must produce a GraphEmbeddings with
no entities, not raise."""
fake_row = (None, None, None, None)
store = _make_store()
store.cassandra = Mock()
store.cassandra.execute = Mock(return_value=[fake_row])
store.get_graph_embeddings_stmt = Mock()
received = []
async def receiver(msg):
received.append(msg)
await store.get_graph_embeddings("u", "d", receiver)
assert len(received) == 1
assert received[0].entities == []
@pytest.mark.asyncio
async def test_multiple_rows_each_emit_one_message(self):
fake_rows = [
(None, None, None, [
(("http://example.org/a", True), [1.0]),
]),
(None, None, None, [
(("http://example.org/b", True), [2.0]),
]),
]
store = _make_store()
store.cassandra = Mock()
store.cassandra.execute = Mock(return_value=fake_rows)
store.get_graph_embeddings_stmt = Mock()
received = []
async def receiver(msg):
received.append(msg)
await store.get_graph_embeddings("u", "d", receiver)
assert len(received) == 2
assert received[0].entities[0].entity.iri == "http://example.org/a"
assert received[0].entities[0].vector == [1.0]
assert received[1].entities[0].entity.iri == "http://example.org/b"
assert received[1].entities[0].vector == [2.0]
class TestGetTriples:
"""Bonus: the sibling get_triples path uses the same row[3] shape and
the same Metadata construction. Cover it for parity."""
@pytest.mark.asyncio
async def test_row_converts_to_triples(self):
# row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
fake_row = (
None, None, None,
[
(
"http://example.org/alice", True,
"http://example.org/knows", True,
"http://example.org/bob", True,
),
],
)
store = _make_store()
store.cassandra = Mock()
store.cassandra.execute = Mock(return_value=[fake_row])
store.get_triples_stmt = Mock()
received = []
async def receiver(msg):
received.append(msg)
await store.get_triples("alice", "doc-1", receiver)
assert len(received) == 1
triples_msg = received[0]
assert isinstance(triples_msg, Triples)
assert isinstance(triples_msg.metadata, Metadata)
assert triples_msg.metadata.id == "doc-1"
assert triples_msg.metadata.user == "alice"
assert len(triples_msg.triples) == 1
t = triples_msg.triples[0]
assert isinstance(t, Triple)
assert t.s.iri == "http://example.org/alice"
assert t.p.iri == "http://example.org/knows"
assert t.o.iri == "http://example.org/bob"

View file

View file

@ -0,0 +1,66 @@
"""
Round-trip unit tests for DocumentEmbeddingsTranslator.
Regression coverage: a previous version of the decode side constructed
ChunkEmbeddings(vectors=...) the schema field is `vector` (singular),
so any real DocumentEmbeddings message would crash on decode. The encode
side already wrote `"vector"`, so encodedecode was asymmetric.
"""
import pytest
from trustgraph.messaging.translators.document_loading import (
DocumentEmbeddingsTranslator,
)
from trustgraph.schema import (
DocumentEmbeddings,
ChunkEmbeddings,
Metadata,
)
@pytest.fixture
def translator():
return DocumentEmbeddingsTranslator()
@pytest.fixture
def sample():
return DocumentEmbeddings(
metadata=Metadata(
id="doc-1",
root="",
user="alice",
collection="testcoll",
),
chunks=[
ChunkEmbeddings(chunk_id="c1", vector=[0.1, 0.2, 0.3]),
ChunkEmbeddings(chunk_id="c2", vector=[0.4, 0.5, 0.6]),
],
)
class TestDocumentEmbeddingsTranslator:
def test_encode_uses_singular_vector_key(self, translator, sample):
encoded = translator.encode(sample)
chunks = encoded["chunks"]
assert all("vector" in c for c in chunks)
assert all("vectors" not in c for c in chunks)
assert chunks[0]["vector"] == [0.1, 0.2, 0.3]
def test_roundtrip_preserves_document_embeddings(self, translator, sample):
encoded = translator.encode(sample)
decoded = translator.decode(encoded)
assert isinstance(decoded, DocumentEmbeddings)
assert isinstance(decoded.metadata, Metadata)
assert decoded.metadata.id == "doc-1"
assert decoded.metadata.user == "alice"
assert decoded.metadata.collection == "testcoll"
assert len(decoded.chunks) == 2
assert decoded.chunks[0].chunk_id == "c1"
assert decoded.chunks[0].vector == [0.1, 0.2, 0.3]
assert decoded.chunks[1].chunk_id == "c2"
assert decoded.chunks[1].vector == [0.4, 0.5, 0.6]

View file

@ -0,0 +1,153 @@
"""
Round-trip unit tests for KnowledgeRequestTranslator.
Regression coverage: a previous version of the decode side constructed
EntityEmbeddings(vectors=...) the schema field is `vector` (singular),
so any real graph-embeddings KnowledgeRequest would crash on first
message. The encode side already wrote `"vector"`, so encodedecode was
asymmetric.
These tests build a real KnowledgeRequest with graph-embeddings, encode
it, decode the result, and assert the round-trip is lossless. They also
exercise the triples path so any future schema drift in Metadata or
Triples breaks the test.
"""
import pytest
from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator
from trustgraph.schema import (
KnowledgeRequest,
GraphEmbeddings,
EntityEmbeddings,
Triples,
Triple,
Metadata,
Term,
IRI,
)
def _term_iri(uri):
return Term(type=IRI, iri=uri)
@pytest.fixture
def translator():
return KnowledgeRequestTranslator()
@pytest.fixture
def graph_embeddings_request():
return KnowledgeRequest(
operation="put-kg-core",
user="alice",
id="doc-1",
flow="default",
collection="testcoll",
graph_embeddings=GraphEmbeddings(
metadata=Metadata(
id="doc-1",
root="",
user="alice",
collection="testcoll",
),
entities=[
EntityEmbeddings(
entity=_term_iri("http://example.org/alice"),
vector=[0.1, 0.2, 0.3],
),
EntityEmbeddings(
entity=_term_iri("http://example.org/bob"),
vector=[0.4, 0.5, 0.6],
),
],
),
)
@pytest.fixture
def triples_request():
return KnowledgeRequest(
operation="put-kg-core",
user="alice",
id="doc-1",
flow="default",
collection="testcoll",
triples=Triples(
metadata=Metadata(
id="doc-1",
root="",
user="alice",
collection="testcoll",
),
triples=[
Triple(
s=_term_iri("http://example.org/alice"),
p=_term_iri("http://example.org/knows"),
o=_term_iri("http://example.org/bob"),
),
],
),
)
class TestKnowledgeRequestTranslatorGraphEmbeddings:
def test_encode_produces_singular_vector_key(
self, translator, graph_embeddings_request,
):
"""The wire key must be `vector`, never `vectors`."""
encoded = translator.encode(graph_embeddings_request)
entities = encoded["graph-embeddings"]["entities"]
assert all("vector" in e for e in entities)
assert all("vectors" not in e for e in entities)
assert entities[0]["vector"] == [0.1, 0.2, 0.3]
def test_roundtrip_preserves_graph_embeddings(
self, translator, graph_embeddings_request,
):
"""encode -> decode must be lossless for the GE branch."""
encoded = translator.encode(graph_embeddings_request)
decoded = translator.decode(encoded)
assert isinstance(decoded, KnowledgeRequest)
assert decoded.operation == "put-kg-core"
assert decoded.user == "alice"
assert decoded.id == "doc-1"
assert decoded.flow == "default"
assert decoded.collection == "testcoll"
assert decoded.graph_embeddings is not None
ge = decoded.graph_embeddings
assert isinstance(ge, GraphEmbeddings)
assert isinstance(ge.metadata, Metadata)
assert ge.metadata.id == "doc-1"
assert ge.metadata.user == "alice"
assert ge.metadata.collection == "testcoll"
assert len(ge.entities) == 2
assert ge.entities[0].vector == [0.1, 0.2, 0.3]
assert ge.entities[1].vector == [0.4, 0.5, 0.6]
assert ge.entities[0].entity.iri == "http://example.org/alice"
assert ge.entities[1].entity.iri == "http://example.org/bob"
class TestKnowledgeRequestTranslatorTriples:
def test_roundtrip_preserves_triples(self, translator, triples_request):
encoded = translator.encode(triples_request)
decoded = translator.decode(encoded)
assert isinstance(decoded, KnowledgeRequest)
assert decoded.triples is not None
assert isinstance(decoded.triples.metadata, Metadata)
assert decoded.triples.metadata.id == "doc-1"
assert decoded.triples.metadata.user == "alice"
assert decoded.triples.metadata.collection == "testcoll"
assert len(decoded.triples.triples) == 1
t = decoded.triples.triples[0]
assert t.s.iri == "http://example.org/alice"
assert t.p.iri == "http://example.org/knows"
assert t.o.iri == "http://example.org/bob"

View file

@ -9,7 +9,7 @@ from .streaming_assertions import (
assert_streaming_content_matches,
assert_no_empty_chunks,
assert_streaming_error_handled,
assert_chunk_types_valid,
assert_message_types_valid,
assert_streaming_latency_acceptable,
assert_callback_invoked,
)
@ -23,7 +23,7 @@ __all__ = [
"assert_streaming_content_matches",
"assert_no_empty_chunks",
"assert_streaming_error_handled",
"assert_chunk_types_valid",
"assert_message_types_valid",
"assert_streaming_latency_acceptable",
"assert_callback_invoked",
]

View file

@ -20,14 +20,14 @@ def assert_streaming_chunks_valid(chunks: List[Any], min_chunks: int = 1):
assert all(chunk is not None for chunk in chunks), "All chunks should be non-None"
def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "chunk_type"):
def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "message_type"):
"""
Assert that streaming chunks follow an expected sequence.
Args:
chunks: List of chunk dictionaries
expected_sequence: Expected sequence of chunk types/values
key: Dictionary key to check (default: "chunk_type")
key: Dictionary key to check (default: "message_type")
"""
actual_sequence = [chunk.get(key) for chunk in chunks if key in chunk]
assert actual_sequence == expected_sequence, \
@ -39,7 +39,7 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
Assert that agent streaming chunks have valid structure.
Validates:
- All chunks have chunk_type field
- All chunks have message_type field
- All chunks have content field
- All chunks have end_of_message field
- All chunks have end_of_dialog field
@ -51,15 +51,15 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
assert len(chunks) > 0, "Expected at least one chunk"
for i, chunk in enumerate(chunks):
assert "chunk_type" in chunk, f"Chunk {i} missing chunk_type"
assert "message_type" in chunk, f"Chunk {i} missing message_type"
assert "content" in chunk, f"Chunk {i} missing content"
assert "end_of_message" in chunk, f"Chunk {i} missing end_of_message"
assert "end_of_dialog" in chunk, f"Chunk {i} missing end_of_dialog"
# Validate chunk_type values
# Validate message_type values
valid_types = ["thought", "action", "observation", "final-answer"]
assert chunk["chunk_type"] in valid_types, \
f"Invalid chunk_type '{chunk['chunk_type']}' at index {i}"
assert chunk["message_type"] in valid_types, \
f"Invalid message_type '{chunk['message_type']}' at index {i}"
# Last chunk should signal end of dialog
assert chunks[-1]["end_of_dialog"] is True, \
@ -175,7 +175,7 @@ def assert_streaming_error_handled(chunks: List[Dict[str, Any]], error_flag: str
"Error chunk should have completion flag set to True"
def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "chunk_type"):
def assert_message_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "message_type"):
"""
Assert that all chunk types are from a valid set.
@ -185,9 +185,9 @@ def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str
type_key: Dictionary key for chunk type
"""
for i, chunk in enumerate(chunks):
chunk_type = chunk.get(type_key)
assert chunk_type in valid_types, \
f"Chunk {i} has invalid type '{chunk_type}', expected one of {valid_types}"
message_type = chunk.get(type_key)
assert message_type in valid_types, \
f"Chunk {i} has invalid type '{message_type}', expected one of {valid_types}"
def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0):

View file

@ -15,6 +15,7 @@ dependencies = [
"requests",
"python-logging-loki",
"pika",
"pyyaml",
]
classifiers = [
"Programming Language :: Python :: 3",
@ -24,6 +25,9 @@ classifiers = [
[project.urls]
Homepage = "https://github.com/trustgraph-ai/trustgraph"
[project.scripts]
processor-group = "trustgraph.base.processor_group:run"
[tool.setuptools.packages.find]
include = ["trustgraph*"]
@ -31,4 +35,4 @@ include = ["trustgraph*"]
"trustgraph.i18n.packs" = ["*.json"]
[tool.setuptools.dynamic]
version = {attr = "trustgraph.base_version.__version__"}
version = {attr = "trustgraph.base_version.__version__"}

View file

@ -107,6 +107,7 @@ from .types import (
AgentObservation,
AgentAnswer,
RAGChunk,
TextCompletionResult,
ProvenanceEvent,
)
@ -185,6 +186,7 @@ __all__ = [
"AgentObservation",
"AgentAnswer",
"RAGChunk",
"TextCompletionResult",
"ProvenanceEvent",
# Exceptions

View file

@ -14,6 +14,8 @@ import aiohttp
import json
from typing import Optional, Dict, Any, List
from . types import TextCompletionResult
from . exceptions import ProtocolException, ApplicationException
@ -434,12 +436,11 @@ class AsyncFlowInstance:
return await self.request("agent", request_data)
async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str:
async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> TextCompletionResult:
"""
Generate text completion (non-streaming).
Generates a text response from an LLM given a system prompt and user prompt.
Returns the complete response text.
Note: This method does not support streaming. For streaming text generation,
use AsyncSocketFlowInstance.text_completion() instead.
@ -450,19 +451,19 @@ class AsyncFlowInstance:
**kwargs: Additional service-specific parameters
Returns:
str: Complete generated text response
TextCompletionResult: Result with text, in_token, out_token, model
Example:
```python
async_flow = await api.async_flow()
flow = async_flow.id("default")
# Generate text
response = await flow.text_completion(
result = await flow.text_completion(
system="You are a helpful assistant.",
prompt="Explain quantum computing in simple terms."
)
print(response)
print(result.text)
print(f"Tokens: {result.in_token} in, {result.out_token} out")
```
"""
request_data = {
@ -473,7 +474,12 @@ class AsyncFlowInstance:
request_data.update(kwargs)
result = await self.request("text-completion", request_data)
return result.get("response", "")
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
async def graph_rag(self, query: str, user: str, collection: str,
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,

View file

@ -4,7 +4,7 @@ import asyncio
import websockets
from typing import Optional, Dict, Any, AsyncIterator, Union
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, TextCompletionResult
from . exceptions import ProtocolException, ApplicationException
@ -178,30 +178,32 @@ class AsyncSocketClient:
def _parse_chunk(self, resp: Dict[str, Any]):
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
chunk_type = resp.get("chunk_type")
message_type = resp.get("message_type")
# Handle new GraphRAG message format with message_type
if message_type == "provenance":
return None
if chunk_type == "thought":
if message_type == "thought":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
elif chunk_type == "observation":
elif message_type == "observation":
return AgentObservation(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
elif chunk_type == "answer" or chunk_type == "final-answer":
elif message_type == "answer" or message_type == "final-answer":
return AgentAnswer(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False)
end_of_dialog=resp.get("end_of_dialog", False),
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
)
elif chunk_type == "action":
elif message_type == "action":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
@ -211,7 +213,10 @@ class AsyncSocketClient:
return RAGChunk(
content=content,
end_of_stream=resp.get("end_of_stream", False),
error=None
error=None,
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
)
async def aclose(self):
@ -269,7 +274,11 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("agent", self.flow_id, request)
async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs):
"""Text completion with optional streaming"""
"""Text completion with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an async iterator of RAGChunk (with token counts on the final chunk).
"""
request = {
"system": system,
"prompt": prompt,
@ -281,13 +290,18 @@ class AsyncSocketFlowInstance:
return self._text_completion_streaming(request)
else:
result = await self.client._send_request("text-completion", self.flow_id, request)
return result.get("response", "")
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
async def _text_completion_streaming(self, request):
"""Helper for streaming text completion"""
"""Helper for streaming text completion. Yields RAGChunk objects."""
async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request):
if hasattr(chunk, 'content'):
yield chunk.content
if isinstance(chunk, RAGChunk):
yield chunk
async def graph_rag(self, query: str, user: str, collection: str,
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,

View file

@ -11,7 +11,7 @@ import base64
from .. knowledge import hash, Uri, Literal, QuotedTriple
from .. schema import IRI, LITERAL, TRIPLE
from . types import Triple
from . types import Triple, TextCompletionResult
from . exceptions import ProtocolException
@ -360,16 +360,17 @@ class FlowInstance:
prompt: User prompt/question
Returns:
str: Generated response text
TextCompletionResult: Result with text, in_token, out_token, model
Example:
```python
flow = api.flow().id("default")
response = flow.text_completion(
result = flow.text_completion(
system="You are a helpful assistant",
prompt="What is quantum computing?"
)
print(response)
print(result.text)
print(f"Tokens: {result.in_token} in, {result.out_token} out")
```
"""
@ -379,10 +380,17 @@ class FlowInstance:
"prompt": prompt
}
return self.request(
result = self.request(
"service/text-completion",
input
)["response"]
)
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def agent(self, question, user="trustgraph", state=None, group=None, history=None):
"""
@ -498,10 +506,17 @@ class FlowInstance:
"edge-limit": edge_limit,
}
return self.request(
result = self.request(
"service/graph-rag",
input
)["response"]
)
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def document_rag(
self, query, user="trustgraph", collection="default",
@ -543,10 +558,17 @@ class FlowInstance:
"doc-limit": doc_limit,
}
return self.request(
result = self.request(
"service/document-rag",
input
)["response"]
)
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def embeddings(self, texts):
"""

View file

@ -14,7 +14,7 @@ import websockets
from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent, TextCompletionResult
from . exceptions import ProtocolException, raise_from_error_dict
@ -360,41 +360,36 @@ class SocketClient:
def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]:
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
chunk_type = resp.get("chunk_type")
message_type = resp.get("message_type")
# Handle GraphRAG/DocRAG message format with message_type
if message_type == "explain":
if include_provenance:
return self._build_provenance_event(resp)
return None
# Handle Agent message format with chunk_type="explain"
if chunk_type == "explain":
if include_provenance:
return self._build_provenance_event(resp)
return None
if chunk_type == "thought":
if message_type == "thought":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
message_id=resp.get("message_id", ""),
)
elif chunk_type == "observation":
elif message_type == "observation":
return AgentObservation(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
message_id=resp.get("message_id", ""),
)
elif chunk_type == "answer" or chunk_type == "final-answer":
elif message_type == "answer" or message_type == "final-answer":
return AgentAnswer(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False),
message_id=resp.get("message_id", ""),
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
)
elif chunk_type == "action":
elif message_type == "action":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
@ -404,7 +399,10 @@ class SocketClient:
return RAGChunk(
content=content,
end_of_stream=resp.get("end_of_stream", False),
error=None
error=None,
in_token=resp.get("in_token"),
out_token=resp.get("out_token"),
model=resp.get("model"),
)
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
@ -543,8 +541,12 @@ class SocketFlowInstance:
streaming=True, include_provenance=True
)
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
"""Execute text completion with optional streaming."""
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute text completion with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = {
"system": system,
"prompt": prompt,
@ -557,12 +559,17 @@ class SocketFlowInstance:
if streaming:
return self._text_completion_generator(result)
else:
return result.get("response", "")
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
for chunk in result:
if hasattr(chunk, 'content'):
yield chunk.content
if isinstance(chunk, RAGChunk):
yield chunk
def graph_rag(
self,
@ -577,8 +584,12 @@ class SocketFlowInstance:
edge_limit: int = 25,
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
"""Execute graph-based RAG query with optional streaming."""
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute graph-based RAG query with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = {
"query": query,
"user": user,
@ -598,7 +609,12 @@ class SocketFlowInstance:
if streaming:
return self._rag_generator(result)
else:
return result.get("response", "")
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def graph_rag_explain(
self,
@ -642,8 +658,12 @@ class SocketFlowInstance:
doc_limit: int = 10,
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
"""Execute document-based RAG query with optional streaming."""
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute document-based RAG query with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = {
"query": query,
"user": user,
@ -658,7 +678,12 @@ class SocketFlowInstance:
if streaming:
return self._rag_generator(result)
else:
return result.get("response", "")
return TextCompletionResult(
text=result.get("response", ""),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def document_rag_explain(
self,
@ -684,10 +709,10 @@ class SocketFlowInstance:
streaming=True, include_provenance=True
)
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
for chunk in result:
if hasattr(chunk, 'content'):
yield chunk.content
if isinstance(chunk, RAGChunk):
yield chunk
def prompt(
self,
@ -695,8 +720,12 @@ class SocketFlowInstance:
variables: Dict[str, str],
streaming: bool = False,
**kwargs: Any
) -> Union[str, Iterator[str]]:
"""Execute a prompt template with optional streaming."""
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
"""Execute a prompt template with optional streaming.
Non-streaming: returns a TextCompletionResult with text and token counts.
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
"""
request = {
"id": id,
"variables": variables,
@ -709,7 +738,12 @@ class SocketFlowInstance:
if streaming:
return self._rag_generator(result)
else:
return result.get("response", "")
return TextCompletionResult(
text=result.get("text", result.get("response", "")),
in_token=result.get("in_token"),
out_token=result.get("out_token"),
model=result.get("model"),
)
def graph_embeddings_query(
self,

View file

@ -149,10 +149,10 @@ class AgentThought(StreamingChunk):
Attributes:
content: Agent's thought text
end_of_message: True if this completes the current thought
chunk_type: Always "thought"
message_type: Always "thought"
message_id: Provenance URI of the entity being built
"""
chunk_type: str = "thought"
message_type: str = "thought"
message_id: str = ""
@dataclasses.dataclass
@ -166,10 +166,10 @@ class AgentObservation(StreamingChunk):
Attributes:
content: Observation text describing tool results
end_of_message: True if this completes the current observation
chunk_type: Always "observation"
message_type: Always "observation"
message_id: Provenance URI of the entity being built
"""
chunk_type: str = "observation"
message_type: str = "observation"
message_id: str = ""
@dataclasses.dataclass
@ -184,11 +184,14 @@ class AgentAnswer(StreamingChunk):
content: Answer text
end_of_message: True if this completes the current answer segment
end_of_dialog: True if this completes the entire agent interaction
chunk_type: Always "final-answer"
message_type: Always "final-answer"
"""
chunk_type: str = "final-answer"
message_type: str = "final-answer"
end_of_dialog: bool = False
message_id: str = ""
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
@dataclasses.dataclass
class RAGChunk(StreamingChunk):
@ -202,11 +205,37 @@ class RAGChunk(StreamingChunk):
content: Generated text content
end_of_stream: True if this is the final chunk of the stream
error: Optional error information if an error occurred
chunk_type: Always "rag"
in_token: Input token count (populated on the final chunk, 0 otherwise)
out_token: Output token count (populated on the final chunk, 0 otherwise)
model: Model identifier (populated on the final chunk, empty otherwise)
message_type: Always "rag"
"""
chunk_type: str = "rag"
message_type: str = "rag"
end_of_stream: bool = False
error: Optional[Dict[str, str]] = None
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
@dataclasses.dataclass
class TextCompletionResult:
"""
Result from a text completion request.
Returned by text_completion() in both streaming and non-streaming modes.
In streaming mode, text is None (chunks are delivered via the iterator).
In non-streaming mode, text contains the complete response.
Attributes:
text: Complete response text (None in streaming mode)
in_token: Input token count (None if not available)
out_token: Output token count (None if not available)
model: Model identifier (None if not available)
"""
text: Optional[str]
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
@dataclasses.dataclass
class ProvenanceEvent:

View file

@ -18,8 +18,10 @@ from . librarian_client import LibrarianClient
from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec
from . text_completion_client import TextCompletionClientSpec
from . prompt_client import PromptClientSpec
from . text_completion_client import (
TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
)
from . prompt_client import PromptClientSpec, PromptClient, PromptResult
from . triples_store_service import TriplesStoreService
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
from . document_embeddings_store_service import DocumentEmbeddingsStoreService

View file

@ -30,19 +30,19 @@ class AgentClient(RequestResponse):
raise RuntimeError(resp.error.message)
# Handle thought chunks
if resp.chunk_type == 'thought':
if resp.message_type == 'thought':
if think:
await think(resp.content, resp.end_of_message)
return False # Continue receiving
# Handle observation chunks
if resp.chunk_type == 'observation':
if resp.message_type == 'observation':
if observe:
await observe(resp.content, resp.end_of_message)
return False # Continue receiving
# Handle answer chunks
if resp.chunk_type == 'answer':
if resp.message_type == 'answer':
if resp.content:
accumulated_answer.append(resp.content)
if answer_callback:

View file

@ -58,6 +58,18 @@ class BackendProducer(Protocol):
class BackendConsumer(Protocol):
"""Protocol for backend-specific consumer."""
def ensure_connected(self) -> None:
"""
Eagerly establish the underlying connection and bind the queue.
Backends that lazily connect on first receive() must implement this
so that callers can guarantee the consumer is fully bound and
therefore able to receive responses before any related request is
published. Backends that connect at construction time may make this
a no-op.
"""
...
def receive(self, timeout_millis: int = 2000) -> Message:
"""
Receive a message from the topic.

View file

@ -88,14 +88,14 @@ class ChunkingService(FlowProcessor):
chunk_overlap = default_chunk_overlap
try:
cs = flow.parameters.get("chunk-size")
cs = flow("chunk-size")
if cs is not None:
chunk_size = int(cs)
except Exception as e:
logger.warning(f"Could not parse chunk-size parameter: {e}")
try:
co = flow.parameters.get("chunk-overlap")
co = flow("chunk-overlap")
if co is not None:
chunk_overlap = int(co)
except Exception as e:

View file

@ -8,12 +8,51 @@ ensuring consistent log formats, levels, and command-line arguments.
Supports dual output to console and Loki for centralized log aggregation.
"""
import contextvars
import logging
import logging.handlers
from queue import Queue
import os
# The current processor id for this task context. Read by
# _ProcessorIdFilter to stamp every LogRecord with its owning
# processor, and read by logging_loki's emitter via record.tags
# to label log lines in Loki. ContextVar so asyncio subtasks
# inherit their parent supervisor's processor id automatically.
current_processor_id = contextvars.ContextVar(
"current_processor_id", default="unknown"
)
def set_processor_id(pid):
"""Set the processor id for the current task context.
All subsequent log records emitted from this task and any
asyncio tasks spawned from it will be tagged with this id
in the console format and in Loki labels.
"""
current_processor_id.set(pid)
class _ProcessorIdFilter(logging.Filter):
"""Stamps every LogRecord with processor_id from the contextvar.
Attaches two fields to each record:
record.processor_id used by the console format string
record.tags merged into Loki labels by logging_loki's
emitter (it reads record.tags and combines
with the handler's static tags)
"""
def filter(self, record):
pid = current_processor_id.get()
record.processor_id = pid
existing = getattr(record, "tags", None) or {}
record.tags = {**existing, "processor": pid}
return True
def add_logging_args(parser):
"""
Add standard logging arguments to an argument parser.
@ -87,12 +126,15 @@ def setup_logging(args):
loki_url = args.get('loki_url', 'http://loki:3100/loki/api/v1/push')
loki_username = args.get('loki_username')
loki_password = args.get('loki_password')
processor_id = args.get('id') # Processor identity (e.g., "config-svc", "text-completion")
try:
from logging_loki import LokiHandler
# Create Loki handler with optional authentication and processor label
# Create Loki handler with optional authentication. The
# processor label is NOT baked in here — it's stamped onto
# each record by _ProcessorIdFilter reading the task-local
# contextvar, and logging_loki's emitter reads record.tags
# to build per-record Loki labels.
loki_handler_kwargs = {
'url': loki_url,
'version': "1",
@ -101,10 +143,6 @@ def setup_logging(args):
if loki_username and loki_password:
loki_handler_kwargs['auth'] = (loki_username, loki_password)
# Add processor label if available (for consistency with Prometheus metrics)
if processor_id:
loki_handler_kwargs['tags'] = {'processor': processor_id}
loki_handler = LokiHandler(**loki_handler_kwargs)
# Wrap in QueueHandler for non-blocking operation
@ -133,23 +171,44 @@ def setup_logging(args):
print(f"WARNING: Failed to setup Loki logging: {e}")
print("Continuing with console-only logging")
# Get processor ID for log formatting (use 'unknown' if not available)
processor_id = args.get('id', 'unknown')
# Configure logging with all handlers
# Use processor ID as the primary identifier in logs
# Configure logging with all handlers. The processor id comes
# from _ProcessorIdFilter (via contextvar) and is injected into
# each record as record.processor_id. The format string reads
# that attribute on every emit.
logging.basicConfig(
level=getattr(logging, log_level.upper()),
format=f'%(asctime)s - {processor_id} - %(levelname)s - %(message)s',
format='%(asctime)s - %(processor_id)s - %(levelname)s - %(message)s',
handlers=handlers,
force=True # Force reconfiguration if already configured
)
# Prevent recursive logging from Loki's HTTP client
if loki_enabled and queue_listener:
# Disable urllib3 logging to prevent infinite loop
logging.getLogger('urllib3').setLevel(logging.WARNING)
logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING)
# Attach the processor-id filter to every handler so all records
# passing through any sink get stamped (console, queue→loki,
# future handlers). Filters on handlers run regardless of which
# logger originated the record, so logs from pika, cassandra,
# processor code, etc. all pass through it.
processor_filter = _ProcessorIdFilter()
for h in handlers:
h.addFilter(processor_filter)
# Seed the contextvar from --id if one was supplied. In group
# mode --id isn't present; the processor_group supervisor sets
# it per task. In standalone mode AsyncProcessor.launch provides
# it via argparse default.
if args.get('id'):
set_processor_id(args['id'])
# Silence noisy third-party library loggers. These emit INFO-level
# chatter (connection churn, channel open/close, driver warnings) that
# drowns the useful signal and can't be attributed to a specific
# processor anyway. WARNING and above still propagate.
for noisy in (
'pika',
'cassandra',
'urllib3',
'urllib3.connectionpool',
):
logging.getLogger(noisy).setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.info(f"Logging configured with level: {log_level}")

View file

@ -0,0 +1,204 @@
# Multi-processor group runner. Runs multiple AsyncProcessor descendants
# as concurrent tasks inside a single process, sharing one event loop,
# one Prometheus HTTP server, and one pub/sub backend pool.
#
# Intended for dev and resource-constrained deployments. Scale deployments
# should continue to use per-processor endpoints.
#
# Group config is a YAML or JSON file with shape:
#
# processors:
# - class: trustgraph.extract.kg.definitions.extract.Processor
# params:
# id: kg-extract-definitions
# triples_batch_size: 1000
# - class: trustgraph.chunking.recursive.Processor
# params:
# id: chunker-recursive
#
# Each entry's params are passed directly to the class constructor alongside
# the shared taskgroup. Defaults live inside each processor class.
import argparse
import asyncio
import importlib
import json
import logging
import time
from prometheus_client import start_http_server
from . logging import add_logging_args, setup_logging, set_processor_id
logger = logging.getLogger(__name__)
def _load_config(path):
with open(path) as f:
text = f.read()
if path.endswith((".yaml", ".yml")):
import yaml
return yaml.safe_load(text)
return json.loads(text)
def _resolve_class(dotted):
module_path, _, class_name = dotted.rpartition(".")
if not module_path:
raise ValueError(
f"Processor class must be a dotted path, got {dotted!r}"
)
module = importlib.import_module(module_path)
return getattr(module, class_name)
RESTART_DELAY_SECONDS = 4
async def _supervise(entry):
"""Run one processor with its own nested TaskGroup, restarting on any
failure. Each processor is isolated from its siblings a crash here
does not propagate to the outer group."""
pid = entry["params"]["id"]
class_path = entry["class"]
# Stamp the contextvar for this supervisor task. Every log
# record emitted from this task — and from any inner TaskGroup
# child created by the processor — inherits this id via
# contextvar propagation. Siblings in the outer group set
# their own id in their own task context and do not interfere.
set_processor_id(pid)
while True:
try:
async with asyncio.TaskGroup() as inner_tg:
cls = _resolve_class(class_path)
params = dict(entry.get("params", {}))
params["taskgroup"] = inner_tg
logger.info(f"Starting {class_path} as {pid}")
p = cls(**params)
await p.start()
inner_tg.create_task(p.run())
# Clean exit — processor's run() returned without raising.
# Treat as a transient shutdown and restart, matching the
# behaviour of per-container `restart: on-failure`.
logger.warning(
f"Processor {pid} exited cleanly, will restart"
)
except asyncio.CancelledError:
logger.info(f"Processor {pid} cancelled")
raise
except BaseExceptionGroup as eg:
for e in eg.exceptions:
logger.error(
f"Processor {pid} failure: {type(e).__name__}: {e}",
exc_info=e,
)
except Exception as e:
logger.error(
f"Processor {pid} failure: {type(e).__name__}: {e}",
exc_info=True,
)
logger.info(
f"Restarting {pid} in {RESTART_DELAY_SECONDS}s..."
)
await asyncio.sleep(RESTART_DELAY_SECONDS)
async def run_group(config):
entries = config.get("processors", [])
if not entries:
raise RuntimeError("Group config has no processors")
seen_ids = set()
for entry in entries:
pid = entry.get("params", {}).get("id")
if pid is None:
raise RuntimeError(
f"Entry {entry.get('class')!r} missing params.id — "
f"required for metrics labelling"
)
if pid in seen_ids:
raise RuntimeError(f"Duplicate processor id {pid!r} in group")
seen_ids.add(pid)
async with asyncio.TaskGroup() as outer_tg:
for entry in entries:
outer_tg.create_task(_supervise(entry))
def run():
parser = argparse.ArgumentParser(
prog="processor-group",
description="Run multiple processors as tasks in one process",
)
parser.add_argument(
"-c", "--config",
required=True,
help="Path to group config file (JSON or YAML)",
)
parser.add_argument(
"--metrics",
action=argparse.BooleanOptionalAction,
default=True,
help="Metrics enabled (default: true)",
)
parser.add_argument(
"-P", "--metrics-port",
type=int,
default=8000,
help="Prometheus metrics port (default: 8000)",
)
add_logging_args(parser)
args = vars(parser.parse_args())
setup_logging(args)
config = _load_config(args["config"])
if args["metrics"]:
start_http_server(args["metrics_port"])
while True:
logger.info("Starting group...")
try:
asyncio.run(run_group(config))
except KeyboardInterrupt:
logger.info("Keyboard interrupt.")
return
except ExceptionGroup as e:
logger.error("Exception group:")
for se in e.exceptions:
logger.error(f" Type: {type(se)}")
logger.error(f" Exception: {se}", exc_info=se)
except Exception as e:
logger.error(f"Type: {type(e)}")
logger.error(f"Exception: {e}", exc_info=True)
logger.warning("Will retry...")
time.sleep(4)
logger.info("Retrying...")

View file

@ -1,10 +1,22 @@
import json
import asyncio
from dataclasses import dataclass
from typing import Optional, Any
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import PromptRequest, PromptResponse
@dataclass
class PromptResult:
response_type: str # "text", "json", or "jsonl"
text: Optional[str] = None # populated for "text"
object: Any = None # populated for "json"
objects: Optional[list] = None # populated for "jsonl"
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
class PromptClient(RequestResponse):
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
@ -26,17 +38,40 @@ class PromptClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
if resp.text: return resp.text
if resp.text:
return PromptResult(
response_type="text",
text=resp.text,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
return json.loads(resp.object)
parsed = json.loads(resp.object)
if isinstance(parsed, list):
return PromptResult(
response_type="jsonl",
objects=parsed,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
return PromptResult(
response_type="json",
object=parsed,
in_token=resp.in_token,
out_token=resp.out_token,
model=resp.model,
)
else:
last_text = ""
last_object = None
last_resp = None
async def forward_chunks(resp):
nonlocal last_text, last_object
nonlocal last_resp
if resp.error:
raise RuntimeError(resp.error.message)
@ -44,14 +79,13 @@ class PromptClient(RequestResponse):
end_stream = getattr(resp, 'end_of_stream', False)
if resp.text is not None:
last_text = resp.text
if chunk_callback:
if asyncio.iscoroutinefunction(chunk_callback):
await chunk_callback(resp.text, end_stream)
else:
chunk_callback(resp.text, end_stream)
elif resp.object:
last_object = resp.object
last_resp = resp
return end_stream
@ -70,10 +104,36 @@ class PromptClient(RequestResponse):
timeout=timeout
)
if last_text:
return last_text
if last_resp is None:
return PromptResult(response_type="text")
return json.loads(last_object) if last_object else None
if last_resp.object:
parsed = json.loads(last_resp.object)
if isinstance(parsed, list):
return PromptResult(
response_type="jsonl",
objects=parsed,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
return PromptResult(
response_type="json",
object=parsed,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
return PromptResult(
response_type="text",
text=last_resp.text,
in_token=last_resp.in_token,
out_token=last_resp.out_token,
model=last_resp.model,
)
async def extract_definitions(self, text, timeout=600):
return await self.prompt(
@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec):
response_schema = PromptResponse,
impl = PromptClient,
)

View file

@ -72,6 +72,16 @@ class PulsarBackendConsumer:
self._consumer = pulsar_consumer
self._schema_cls = schema_cls
def ensure_connected(self) -> None:
"""No-op for Pulsar.
PulsarBackend.create_consumer() calls client.subscribe() which is
synchronous and returns a fully-subscribed consumer, so the
consumer is already ready by the time this object is constructed.
Defined for parity with the BackendConsumer protocol used by
Subscriber.start()'s readiness barrier."""
pass
def receive(self, timeout_millis: int = 2000) -> Message:
"""Receive a message. Raises TimeoutError if no message available."""
try:

View file

@ -214,16 +214,43 @@ class RabbitMQBackendConsumer:
and self._channel.is_open
)
def ensure_connected(self) -> None:
"""Eagerly declare and bind the queue.
Without this, the queue is only declared lazily on the first
receive() call. For request/response with ephemeral per-subscriber
response queues that is a race: a request published before the
response queue is bound will have its reply silently dropped by
the broker. Subscriber.start() calls this so callers get a hard
readiness barrier."""
if not self._is_alive():
self._connect()
def receive(self, timeout_millis: int = 2000) -> Message:
"""Receive a message. Raises TimeoutError if none available."""
"""Receive a message. Raises TimeoutError if none available.
Loop ordering matters: check _incoming at the TOP of each
iteration, not as the loop condition. process_data_events
may dispatch a message via the _on_message callback during
the pump; we must re-check _incoming on the next iteration
before giving up on the deadline. The previous control
flow (`while deadline: check; pump`) could lose a wakeup if
the pump consumed the remainder of the window the
`while` check would fail before `_incoming` was re-read,
leaving a just-dispatched message stranded until the next
receive() call one full poll cycle later.
"""
if not self._is_alive():
self._connect()
timeout_seconds = timeout_millis / 1000.0
deadline = time.monotonic() + timeout_seconds
while time.monotonic() < deadline:
# Check if a message was already delivered
while True:
# Check if a message has been dispatched to our queue.
# This catches both (a) messages dispatched before this
# receive() was called and (b) messages dispatched
# during the previous iteration's process_data_events.
try:
method, properties, body = self._incoming.get_nowait()
return RabbitMQMessage(
@ -232,14 +259,16 @@ class RabbitMQBackendConsumer:
except queue.Empty:
pass
# Drive pika's I/O — delivers messages and processes heartbeats
remaining = deadline - time.monotonic()
if remaining > 0:
self._connection.process_data_events(
time_limit=min(0.1, remaining),
)
if remaining <= 0:
raise TimeoutError("No message received within timeout")
raise TimeoutError("No message received within timeout")
# Drive pika's I/O. Any messages delivered during this
# call land in _incoming via _on_message; the next
# iteration of this loop catches them at the top.
self._connection.process_data_events(
time_limit=min(0.1, remaining),
)
def acknowledge(self, message: Message) -> None:
if isinstance(message, RabbitMQMessage) and message._method:

View file

@ -41,14 +41,55 @@ class Subscriber:
self.consumer = None
self.executor = None
# Readiness barrier — completed by run() once the underlying
# backend consumer is fully connected and bound. start() awaits
# this so callers know any subsequently published request will
# have a queue ready to receive its response. Without this,
# ephemeral per-subscriber response queues (RabbitMQ auto-delete
# exclusive queues) would race the request and lose the reply.
# A Future is used (rather than an Event) so that a first-attempt
# connection failure can be propagated to start() as an exception.
self._ready = None # created in start() so we have a running loop
def __del__(self):
self.running = False
async def start(self):
self._ready = asyncio.get_event_loop().create_future()
self.task = asyncio.create_task(self.run())
# Block until run() signals readiness OR exits. The future
# carries the outcome of the first connect attempt: a value on
# success, an exception on first-attempt failure. If run() exits
# without ever signalling (e.g. cancelled, or a code path bug),
# we surface that as a clear RuntimeError rather than hanging
# forever waiting on the future.
ready_wait = asyncio.ensure_future(
asyncio.shield(self._ready)
)
try:
await asyncio.wait(
{self.task, ready_wait},
return_when=asyncio.FIRST_COMPLETED,
)
finally:
ready_wait.cancel()
if self._ready.done():
# Re-raise first-attempt connect failure if any.
self._ready.result()
return
# run() exited before _ready was settled. Propagate its exception
# if it had one, otherwise raise a generic readiness error.
if self.task.done() and self.task.exception() is not None:
raise self.task.exception()
raise RuntimeError(
"Subscriber.run() exited before signalling readiness"
)
async def stop(self):
"""Initiate graceful shutdown with draining"""
self.running = False
@ -66,6 +107,7 @@ class Subscriber:
async def run(self):
"""Enhanced run method with integrated draining logic"""
first_attempt = True
while self.running or self.draining:
if self.metrics:
@ -87,10 +129,27 @@ class Subscriber:
),
)
# Eagerly bind the queue. For backends that connect
# lazily on first receive (RabbitMQ), this is what
# closes the request/response setup race — without
# it the response queue is not bound until later and
# any reply published in the meantime is dropped.
await loop.run_in_executor(
self.executor,
lambda: self.consumer.ensure_connected(),
)
if self.metrics:
self.metrics.state("running")
logger.info("Subscriber running...")
# Signal start() that the consumer is ready. This must
# happen AFTER ensure_connected() above so callers can
# safely publish requests immediately after start() returns.
if first_attempt and not self._ready.done():
self._ready.set_result(None)
first_attempt = False
drain_end_time = None
while self.running or self.draining:
@ -162,6 +221,16 @@ class Subscriber:
except Exception as e:
logger.error(f"Subscriber exception: {e}", exc_info=True)
# First-attempt connection failure: propagate to start()
# so the caller can decide what to do (retry, give up).
# Subsequent failures use the existing retry-with-backoff
# path so a long-lived subscriber survives broker blips.
if first_attempt and not self._ready.done():
self._ready.set_exception(e)
first_attempt = False
# Falls through into finally for cleanup, then the
# outer return below ends run() so start() unblocks.
finally:
# Negative acknowledge any pending messages
for msg in self.pending_acks.values():
@ -193,6 +262,11 @@ class Subscriber:
if not self.running and not self.draining:
return
# If start() has already returned with an exception there is
# nothing more to do — exit run() rather than busy-retry.
if self._ready.done() and self._ready.exception() is not None:
return
# Sleep before retry
await asyncio.sleep(1)

View file

@ -1,47 +1,71 @@
from dataclasses import dataclass
from typing import Optional
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TextCompletionRequest, TextCompletionResponse
@dataclass
class TextCompletionResult:
text: Optional[str]
in_token: Optional[int] = None
out_token: Optional[int] = None
model: Optional[str] = None
class TextCompletionClient(RequestResponse):
async def text_completion(self, system, prompt, streaming=False, timeout=600):
# If not streaming, use original behavior
if not streaming:
resp = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = False
),
timeout=timeout
)
if resp.error:
raise RuntimeError(resp.error.message)
async def text_completion(self, system, prompt, timeout=600):
return resp.response
# For streaming: collect all chunks and return complete response
full_response = ""
async def collect_chunks(resp):
nonlocal full_response
if resp.error:
raise RuntimeError(resp.error.message)
if resp.response:
full_response += resp.response
# Return True when end_of_stream is reached
return getattr(resp, 'end_of_stream', False)
await self.request(
resp = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = True
system = system, prompt = prompt, streaming = False
),
recipient=collect_chunks,
timeout=timeout
)
return full_response
if resp.error:
raise RuntimeError(resp.error.message)
return TextCompletionResult(
text = resp.response,
in_token = resp.in_token,
out_token = resp.out_token,
model = resp.model,
)
async def text_completion_stream(
self, system, prompt, handler, timeout=600,
):
"""
Streaming text completion. `handler` is an async callable invoked
once per chunk with the chunk's TextCompletionResponse. Returns a
TextCompletionResult with text=None and token counts / model taken
from the end_of_stream message.
"""
async def on_chunk(resp):
if resp.error:
raise RuntimeError(resp.error.message)
await handler(resp)
return getattr(resp, "end_of_stream", False)
final = await self.request(
TextCompletionRequest(
system = system, prompt = prompt, streaming = True
),
recipient=on_chunk,
timeout=timeout,
)
return TextCompletionResult(
text = None,
in_token = final.in_token,
out_token = final.out_token,
model = final.model,
)
class TextCompletionClientSpec(RequestResponseSpec):
def __init__(
@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec):
response_schema = TextCompletionResponse,
impl = TextCompletionClient,
)

View file

@ -58,23 +58,23 @@ class AgentClient(BaseClient):
def inspect(x):
# Handle errors
if x.chunk_type == 'error' or x.error:
if x.message_type == 'error' or x.error:
if error_callback:
error_callback(x.content or (x.error.message if x.error else ""))
# Continue to check end_of_dialog
# Handle thought chunks
elif x.chunk_type == 'thought':
elif x.message_type == 'thought':
if think:
think(x.content, x.end_of_message)
# Handle observation chunks
elif x.chunk_type == 'observation':
elif x.message_type == 'observation':
if observe:
observe(x.content, x.end_of_message)
# Handle answer chunks
elif x.chunk_type == 'answer':
elif x.message_type == 'answer':
if x.content:
accumulated_answer.append(x.content)
if answer_callback:

View file

@ -60,8 +60,8 @@ class AgentResponseTranslator(MessageTranslator):
def encode(self, obj: AgentResponse) -> Dict[str, Any]:
result = {}
if obj.chunk_type:
result["chunk_type"] = obj.chunk_type
if obj.message_type:
result["message_type"] = obj.message_type
if obj.content:
result["content"] = obj.content
result["end_of_message"] = getattr(obj, "end_of_message", False)
@ -90,6 +90,13 @@ class AgentResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "code": obj.error.code}
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result
def encode_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:

View file

@ -151,7 +151,7 @@ class DocumentEmbeddingsTranslator(SendTranslator):
chunks = [
ChunkEmbeddings(
chunk_id=chunk["chunk_id"],
vectors=chunk["vectors"]
vector=chunk["vector"]
)
for chunk in data.get("chunks", [])
]

View file

@ -39,7 +39,7 @@ class KnowledgeRequestTranslator(MessageTranslator):
entities=[
EntityEmbeddings(
entity=self.value_translator.decode(ent["entity"]),
vectors=ent["vectors"],
vector=ent["vector"],
)
for ent in data["graph-embeddings"]["entities"]
]

View file

@ -53,6 +53,13 @@ class PromptResponseTranslator(MessageTranslator):
# Always include end_of_stream flag for streaming support
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result
def encode_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:

View file

@ -74,6 +74,13 @@ class DocumentRagResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result
def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
@ -163,6 +170,13 @@ class GraphRagResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model is not None:
result["model"] = obj.model
return result
def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:

View file

@ -29,11 +29,11 @@ class TextCompletionResponseTranslator(MessageTranslator):
def encode(self, obj: TextCompletionResponse) -> Dict[str, Any]:
result = {"response": obj.response}
if obj.in_token:
if obj.in_token is not None:
result["in_token"] = obj.in_token
if obj.out_token:
if obj.out_token is not None:
result["out_token"] = obj.out_token
if obj.model:
if obj.model is not None:
result["model"] = obj.model
# Always include end_of_stream flag for streaming support

View file

@ -59,6 +59,7 @@ from . uris import (
agent_plan_uri,
agent_step_result_uri,
agent_synthesis_uri,
agent_pattern_decision_uri,
# Document RAG provenance URIs
docrag_question_uri,
docrag_grounding_uri,
@ -102,6 +103,11 @@ from . namespaces import (
# Agent provenance predicates
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
TG_TOOL_CANDIDATE, TG_TERMINATION_REASON,
TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE,
TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR,
TG_IN_TOKEN, TG_OUT_TOKEN,
TG_ERROR_TYPE,
# Orchestrator entity types
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
# Document reference predicate
@ -141,6 +147,7 @@ from . agent import (
agent_plan_triples,
agent_step_result_triples,
agent_synthesis_triples,
agent_pattern_decision_triples,
)
# Vocabulary bootstrap
@ -182,6 +189,7 @@ __all__ = [
"agent_plan_uri",
"agent_step_result_uri",
"agent_synthesis_uri",
"agent_pattern_decision_uri",
# Document RAG provenance URIs
"docrag_question_uri",
"docrag_grounding_uri",
@ -218,6 +226,11 @@ __all__ = [
# Agent provenance predicates
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION",
"TG_SUBAGENT_GOAL", "TG_PLAN_STEP",
"TG_TOOL_CANDIDATE", "TG_TERMINATION_REASON",
"TG_STEP_NUMBER", "TG_PATTERN_DECISION", "TG_PATTERN", "TG_TASK_TYPE",
"TG_LLM_DURATION_MS", "TG_TOOL_DURATION_MS", "TG_TOOL_ERROR",
"TG_IN_TOKEN", "TG_OUT_TOKEN",
"TG_ERROR_TYPE",
# Orchestrator entity types
"TG_DECOMPOSITION", "TG_FINDING", "TG_PLAN_TYPE", "TG_STEP_RESULT",
# Document reference predicate
@ -249,6 +262,7 @@ __all__ = [
"agent_plan_triples",
"agent_step_result_triples",
"agent_synthesis_triples",
"agent_pattern_decision_triples",
# Utility
"set_graph",
# Vocabulary

View file

@ -29,6 +29,11 @@ from . namespaces import (
TG_AGENT_QUESTION,
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SYNTHESIS, TG_SUBAGENT_GOAL, TG_PLAN_STEP,
TG_TOOL_CANDIDATE, TG_TERMINATION_REASON,
TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE,
TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR,
TG_ERROR_TYPE,
TG_IN_TOKEN, TG_OUT_TOKEN, TG_LLM_MODEL,
)
@ -47,6 +52,17 @@ def _triple(s: str, p: str, o_term: Term) -> Triple:
return Triple(s=_iri(s), p=_iri(p), o=o_term)
def _append_token_triples(triples, uri, in_token=None, out_token=None,
model=None):
"""Append in_token/out_token/model triples when values are present."""
if in_token is not None:
triples.append(_triple(uri, TG_IN_TOKEN, _literal(str(in_token))))
if out_token is not None:
triples.append(_triple(uri, TG_OUT_TOKEN, _literal(str(out_token))))
if model is not None:
triples.append(_triple(uri, TG_LLM_MODEL, _literal(model)))
def agent_session_triples(
session_uri: str,
query: str,
@ -90,6 +106,43 @@ def agent_session_triples(
return triples
def agent_pattern_decision_triples(
uri: str,
session_uri: str,
pattern: str,
task_type: str = "",
) -> List[Triple]:
"""
Build triples for a meta-router pattern decision.
Creates:
- Entity declaration with tg:PatternDecision type
- wasDerivedFrom link to session
- Pattern and task type predicates
Args:
uri: URI of this decision (from agent_pattern_decision_uri)
session_uri: URI of the parent session
pattern: Selected execution pattern (e.g. "react", "plan-then-execute")
task_type: Identified task type (e.g. "general", "research")
Returns:
List of Triple objects
"""
triples = [
_triple(uri, RDF_TYPE, _iri(PROV_ENTITY)),
_triple(uri, RDF_TYPE, _iri(TG_PATTERN_DECISION)),
_triple(uri, RDFS_LABEL, _literal(f"Pattern: {pattern}")),
_triple(uri, TG_PATTERN, _literal(pattern)),
_triple(uri, PROV_WAS_DERIVED_FROM, _iri(session_uri)),
]
if task_type:
triples.append(_triple(uri, TG_TASK_TYPE, _literal(task_type)))
return triples
def agent_iteration_triples(
iteration_uri: str,
question_uri: Optional[str] = None,
@ -98,6 +151,12 @@ def agent_iteration_triples(
arguments: Dict[str, Any] = None,
thought_uri: Optional[str] = None,
thought_document_id: Optional[str] = None,
tool_candidates: Optional[List[str]] = None,
step_number: Optional[int] = None,
llm_duration_ms: Optional[int] = None,
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for one agent iteration (Analysis+ToolUse).
@ -106,6 +165,7 @@ def agent_iteration_triples(
- Entity declaration with tg:Analysis and tg:ToolUse types
- wasDerivedFrom link to question (if first iteration) or previous
- Action and arguments metadata
- Tool candidates (names of tools visible to the LLM)
- Thought sub-entity (tg:Reflection, tg:Thought) with librarian document
Args:
@ -116,6 +176,7 @@ def agent_iteration_triples(
arguments: Arguments passed to the tool (will be JSON-encoded)
thought_uri: URI for the thought sub-entity
thought_document_id: Document URI for thought in librarian
tool_candidates: List of tool names available to the LLM
Returns:
List of Triple objects
@ -132,6 +193,23 @@ def agent_iteration_triples(
_triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))),
]
if tool_candidates:
for name in tool_candidates:
triples.append(
_triple(iteration_uri, TG_TOOL_CANDIDATE, _literal(name))
)
if step_number is not None:
triples.append(
_triple(iteration_uri, TG_STEP_NUMBER, _literal(str(step_number)))
)
if llm_duration_ms is not None:
triples.append(
_triple(iteration_uri, TG_LLM_DURATION_MS,
_literal(str(llm_duration_ms)))
)
if question_uri:
triples.append(
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(question_uri))
@ -155,6 +233,8 @@ def agent_iteration_triples(
_triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id))
)
_append_token_triples(triples, iteration_uri, in_token, out_token, model)
return triples
@ -162,6 +242,8 @@ def agent_observation_triples(
observation_uri: str,
iteration_uri: str,
document_id: Optional[str] = None,
tool_duration_ms: Optional[int] = None,
tool_error: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for an agent observation (standalone entity).
@ -170,11 +252,15 @@ def agent_observation_triples(
- Entity declaration with prov:Entity and tg:Observation types
- wasDerivedFrom link to the iteration (Analysis+ToolUse)
- Document reference to librarian (if provided)
- Tool execution duration (if provided)
- Tool error message (if the tool failed)
Args:
observation_uri: URI of the observation entity
iteration_uri: URI of the iteration this observation derives from
document_id: Librarian document ID for the observation content
tool_duration_ms: Tool execution time in milliseconds
tool_error: Error message if the tool failed
Returns:
List of Triple objects
@ -191,6 +277,20 @@ def agent_observation_triples(
_triple(observation_uri, TG_DOCUMENT, _iri(document_id))
)
if tool_duration_ms is not None:
triples.append(
_triple(observation_uri, TG_TOOL_DURATION_MS,
_literal(str(tool_duration_ms)))
)
if tool_error:
triples.append(
_triple(observation_uri, TG_TOOL_ERROR, _literal(tool_error))
)
triples.append(
_triple(observation_uri, RDF_TYPE, _iri(TG_ERROR_TYPE))
)
return triples
@ -199,6 +299,10 @@ def agent_final_triples(
question_uri: Optional[str] = None,
previous_uri: Optional[str] = None,
document_id: Optional[str] = None,
termination_reason: Optional[str] = None,
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for an agent final answer (Conclusion).
@ -208,12 +312,15 @@ def agent_final_triples(
- wasGeneratedBy link to question (if no iterations)
- wasDerivedFrom link to last iteration (if iterations exist)
- Document reference to librarian
- Termination reason (why the agent loop stopped)
Args:
final_uri: URI of the final answer (from agent_final_uri)
question_uri: URI of the question activity (if no iterations)
previous_uri: URI of the last iteration (if iterations exist)
document_id: Librarian document ID for the answer content
termination_reason: Why the loop stopped, e.g. "final-answer",
"max-iterations", "error"
Returns:
List of Triple objects
@ -237,6 +344,14 @@ def agent_final_triples(
if document_id:
triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id)))
if termination_reason:
triples.append(
_triple(final_uri, TG_TERMINATION_REASON,
_literal(termination_reason))
)
_append_token_triples(triples, final_uri, in_token, out_token, model)
return triples
@ -244,6 +359,9 @@ def agent_decomposition_triples(
uri: str,
session_uri: str,
goals: List[str],
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]:
"""Build triples for a supervisor decomposition step."""
triples = [
@ -255,6 +373,7 @@ def agent_decomposition_triples(
]
for goal in goals:
triples.append(_triple(uri, TG_SUBAGENT_GOAL, _literal(goal)))
_append_token_triples(triples, uri, in_token, out_token, model)
return triples
@ -282,6 +401,9 @@ def agent_plan_triples(
uri: str,
session_uri: str,
steps: List[str],
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]:
"""Build triples for a plan-then-execute plan."""
triples = [
@ -293,6 +415,7 @@ def agent_plan_triples(
]
for step in steps:
triples.append(_triple(uri, TG_PLAN_STEP, _literal(step)))
_append_token_triples(triples, uri, in_token, out_token, model)
return triples
@ -301,6 +424,9 @@ def agent_step_result_triples(
plan_uri: str,
goal: str,
document_id: Optional[str] = None,
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]:
"""Build triples for a plan step result."""
triples = [
@ -313,6 +439,7 @@ def agent_step_result_triples(
]
if document_id:
triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id)))
_append_token_triples(triples, uri, in_token, out_token, model)
return triples
@ -320,6 +447,10 @@ def agent_synthesis_triples(
uri: str,
previous_uris,
document_id: Optional[str] = None,
termination_reason: Optional[str] = None,
in_token: Optional[int] = None,
out_token: Optional[int] = None,
model: Optional[str] = None,
) -> List[Triple]:
"""Build triples for a synthesis answer.
@ -327,6 +458,8 @@ def agent_synthesis_triples(
uri: URI of the synthesis entity
previous_uris: Single URI string or list of URIs to derive from
document_id: Librarian document ID for the answer content
termination_reason: Why the agent loop stopped
in_token/out_token/model: Token usage for the synthesis LLM call
"""
triples = [
_triple(uri, RDF_TYPE, _iri(PROV_ENTITY)),
@ -342,4 +475,12 @@ def agent_synthesis_triples(
if document_id:
triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id)))
if termination_reason:
triples.append(
_triple(uri, TG_TERMINATION_REASON, _literal(termination_reason))
)
_append_token_triples(triples, uri, in_token, out_token, model)
return triples

Some files were not shown because too many files have changed in this diff Show more