diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml
index 2243da10..b1ae8611 100644
--- a/.github/workflows/pull-request.yaml
+++ b/.github/workflows/pull-request.yaml
@@ -22,7 +22,7 @@ jobs:
uses: actions/checkout@v3
- name: Setup packages
- run: make update-package-versions VERSION=2.2.999
+ run: make update-package-versions VERSION=2.3.999
- name: Setup environment
run: python3 -m venv env
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index dc2fc89b..07af8db9 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -40,10 +40,9 @@ jobs:
- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
- deploy-container-image:
+ build-platform-image:
- name: Release container images
- runs-on: ubuntu-24.04
+ name: Build ${{ matrix.container }} (${{ matrix.platform }})
permissions:
contents: write
id-token: write
@@ -52,14 +51,24 @@ jobs:
strategy:
matrix:
container:
- - trustgraph-base
- - trustgraph-flow
- - trustgraph-bedrock
- - trustgraph-vertexai
- - trustgraph-hf
- - trustgraph-ocr
- - trustgraph-unstructured
- - trustgraph-mcp
+ - base
+ - flow
+ - bedrock
+ - vertexai
+ - hf
+ - ocr
+ - unstructured
+ - mcp
+ platform:
+ - amd64
+ - arm64
+ include:
+ - platform: amd64
+ runner: ubuntu-24.04
+ - platform: arm64
+ runner: ubuntu-24.04-arm
+
+ runs-on: ${{ matrix.runner }}
steps:
@@ -76,12 +85,48 @@ jobs:
id: version
run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT
- - name: Put version into package manifests
- run: make update-package-versions VERSION=${{ steps.version.outputs.VERSION }}
+ - name: Build container
+ run: make platform-${{ matrix.container }}-${{ matrix.platform }} VERSION=${{ steps.version.outputs.VERSION }}
- - name: Build container - ${{ matrix.container }}
- run: make container-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }}
+ - name: Push container
+ run: make push-platform-${{ matrix.container }}-${{ matrix.platform }} VERSION=${{ steps.version.outputs.VERSION }}
- - name: Push container - ${{ matrix.container }}
- run: make push-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }}
+ combine-manifests:
+ name: Combine manifest ${{ matrix.container }}
+ runs-on: ubuntu-24.04
+ needs: build-platform-image
+ permissions:
+ contents: write
+ id-token: write
+ environment:
+ name: release
+ strategy:
+ matrix:
+ container:
+ - base
+ - flow
+ - bedrock
+ - vertexai
+ - hf
+ - ocr
+ - unstructured
+ - mcp
+
+ steps:
+
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Docker Hub token
+ run: echo ${{ secrets.DOCKER_SECRET }} > docker-token.txt
+
+ - name: Authenticate with Docker hub
+ run: make docker-hub-login
+
+ - name: Get version
+ id: version
+ run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT
+
+ - name: Combine and push manifest
+ run: make combine-manifest-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }}
diff --git a/Makefile b/Makefile
index 197a6c63..85f10fdd 100644
--- a/Makefile
+++ b/Makefile
@@ -52,51 +52,12 @@ update-package-versions:
echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py
echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py
-FORCE:
+containers: container-base container-flow \
+container-bedrock container-vertexai \
+container-hf container-ocr \
+container-unstructured container-mcp
-containers: FORCE
- ${DOCKER} build -f containers/Containerfile.base \
- -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.flow \
- -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.bedrock \
- -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.vertexai \
- -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.hf \
- -t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.ocr \
- -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.unstructured \
- -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.mcp \
- -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
-
-some-containers:
- ${DOCKER} build -f containers/Containerfile.base \
- -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.flow \
- -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
-# ${DOCKER} build -f containers/Containerfile.unstructured \
-# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
-# ${DOCKER} build -f containers/Containerfile.vertexai \
-# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
-# ${DOCKER} build -f containers/Containerfile.mcp \
-# -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
-# ${DOCKER} build -f containers/Containerfile.vertexai \
-# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
-# ${DOCKER} build -f containers/Containerfile.bedrock \
-# -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} .
-
-basic-containers: update-package-versions
- ${DOCKER} build -f containers/Containerfile.base \
- -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
- ${DOCKER} build -f containers/Containerfile.flow \
- -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
-
-container.ocr:
- ${DOCKER} build -f containers/Containerfile.ocr \
- -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
+some-containers: container-base container-flow
push:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION}
@@ -109,54 +70,60 @@ push:
${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION}
# Individual container build targets
-container-trustgraph-base: update-package-versions
- ${DOCKER} build -f containers/Containerfile.base -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
+container-%: update-package-versions
+ ${DOCKER} build \
+ -f containers/Containerfile.${@:container-%=%} \
+ -t ${CONTAINER_BASE}/trustgraph-${@:container-%=%}:${VERSION} .
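+# e.g. `make container-flow VERSION=1.0.0` builds containers/Containerfile.flow
+# and tags it ${CONTAINER_BASE}/trustgraph-flow:1.0.0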
-container-trustgraph-flow: update-package-versions
- ${DOCKER} build -f containers/Containerfile.flow -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
+# Multi-arch: build both platforms sequentially into one manifest (local use)
+manifest-%: update-package-versions
+ -@${DOCKER} manifest rm \
+ ${CONTAINER_BASE}/trustgraph-${@:manifest-%=%}:${VERSION}
+ ${DOCKER} build --platform linux/amd64,linux/arm64 \
+ -f containers/Containerfile.${@:manifest-%=%} \
+ --manifest \
+ ${CONTAINER_BASE}/trustgraph-${@:manifest-%=%}:${VERSION} .
-container-trustgraph-bedrock: update-package-versions
- ${DOCKER} build -f containers/Containerfile.bedrock -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} .
+# Multi-arch: build a single platform image (for parallel CI)
+platform-%-amd64: update-package-versions
+ ${DOCKER} build --platform linux/amd64 \
+ -f containers/Containerfile.${@:platform-%-amd64=%} \
+ -t ${CONTAINER_BASE}/trustgraph-${@:platform-%-amd64=%}:${VERSION}-amd64 .
-container-trustgraph-vertexai: update-package-versions
- ${DOCKER} build -f containers/Containerfile.vertexai -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
+platform-%-arm64: update-package-versions
+ ${DOCKER} build --platform linux/arm64 \
+ -f containers/Containerfile.${@:platform-%-arm64=%} \
+ -t ${CONTAINER_BASE}/trustgraph-${@:platform-%-arm64=%}:${VERSION}-arm64 .
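+# e.g. `make platform-flow-arm64 VERSION=1.0.0` builds only the arm64
+# image, tagged trustgraph-flow:1.0.0-arm64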
-container-trustgraph-hf: update-package-versions
- ${DOCKER} build -f containers/Containerfile.hf -t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} .
+# Push a single platform image
+push-platform-%-amd64:
+ ${DOCKER} push \
+ ${CONTAINER_BASE}/trustgraph-${@:push-platform-%-amd64=%}:${VERSION}-amd64
-container-trustgraph-ocr: update-package-versions
- ${DOCKER} build -f containers/Containerfile.ocr -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
+push-platform-%-arm64:
+ ${DOCKER} push \
+ ${CONTAINER_BASE}/trustgraph-${@:push-platform-%-arm64=%}:${VERSION}-arm64
-container-trustgraph-unstructured: update-package-versions
- ${DOCKER} build -f containers/Containerfile.unstructured -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
+# Combine per-platform images into a multi-arch manifest
+combine-manifest-%:
+ -@${DOCKER} manifest rm \
+ ${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}
+ ${DOCKER} manifest create \
+ ${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION} \
+ docker://${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}-amd64 \
+ docker://${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}-arm64
+ ${DOCKER} manifest push \
+ ${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}
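+# e.g. `make combine-manifest-flow VERSION=1.0.0` stitches the pushed
+# :1.0.0-amd64 and :1.0.0-arm64 tags into one manifest and pushes it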
-container-trustgraph-mcp: update-package-versions
- ${DOCKER} build -f containers/Containerfile.mcp -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
+# Push a container
+push-container-%:
+ ${DOCKER} push \
+ ${CONTAINER_BASE}/trustgraph-${@:push-container-%=%}:${VERSION}
-# Individual container push targets
-push-trustgraph-base:
- ${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION}
-
-push-trustgraph-flow:
- ${DOCKER} push ${CONTAINER_BASE}/trustgraph-flow:${VERSION}
-
-push-trustgraph-bedrock:
- ${DOCKER} push ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION}
-
-push-trustgraph-vertexai:
- ${DOCKER} push ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION}
-
-push-trustgraph-hf:
- ${DOCKER} push ${CONTAINER_BASE}/trustgraph-hf:${VERSION}
-
-push-trustgraph-ocr:
- ${DOCKER} push ${CONTAINER_BASE}/trustgraph-ocr:${VERSION}
-
-push-trustgraph-unstructured:
- ${DOCKER} push ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION}
-
-push-trustgraph-mcp:
- ${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION}
+# Push a manifest (from local multi-arch build)
+push-manifest-%:
+ ${DOCKER} manifest push \
+ ${CONTAINER_BASE}/trustgraph-${@:push-manifest-%=%}:${VERSION}
clean:
rm -rf wheels/
@@ -164,52 +131,6 @@ clean:
set-version:
echo '"${VERSION}"' > templates/values/version.jsonnet
-TEMPLATES=azure bedrock claude cohere mix llamafile mistral ollama openai vertexai \
- openai-neo4j storage
-
-DCS=$(foreach template,${TEMPLATES},${template:%=tg-launch-%.yaml})
-
-MODELS=azure bedrock claude cohere llamafile mistral ollama openai vertexai
-GRAPHS=cassandra neo4j falkordb memgraph
-
-# tg-launch-%.yaml: templates/%.jsonnet templates/components/version.jsonnet
-# jsonnet -Jtemplates \
-# -S ${@:tg-launch-%.yaml=templates/%.jsonnet} > $@
-
-# VECTORDB=milvus
-VECTORDB=qdrant
-
-JSONNET_FLAGS=-J templates -J .
-
-# Temporarily going back to how templates were built in 0.9 because this
-# is going away in 0.11.
-
-update-templates: update-dcs
-
-JSON_TO_YAML=python -c 'import sys, yaml, json; j=json.loads(sys.stdin.read()); print(yaml.safe_dump(j))'
-
-update-dcs: set-version
- for graph in ${GRAPHS}; do \
- cm=$${graph},pulsar,${VECTORDB},grafana; \
- input=templates/opts-to-docker-compose.jsonnet; \
- output=tg-storage-$${graph}.yaml; \
- echo $${graph} '->' $${output}; \
- jsonnet ${JSONNET_FLAGS} \
- --ext-str options=$${cm} $${input} | \
- ${JSON_TO_YAML} > $${output}; \
- done
- for model in ${MODELS}; do \
- for graph in ${GRAPHS}; do \
- cm=$${graph},pulsar,${VECTORDB},embeddings-hf,graph-rag,grafana,trustgraph,$${model}; \
- input=templates/opts-to-docker-compose.jsonnet; \
- output=tg-launch-$${model}-$${graph}.yaml; \
- echo $${model} + $${graph} '->' $${output}; \
- jsonnet ${JSONNET_FLAGS} \
- --ext-str options=$${cm} $${input} | \
- ${JSON_TO_YAML} > $${output}; \
- done; \
- done
-
docker-hub-login:
cat docker-token.txt | \
${DOCKER} login -u trustgraph --password-stdin registry-1.docker.io
diff --git a/containers/Containerfile.hf b/containers/Containerfile.hf
index 351300ae..a1ec5346 100644
--- a/containers/Containerfile.hf
+++ b/containers/Containerfile.hf
@@ -1,22 +1,22 @@
-# ----------------------------------------------------------------------------
-# Build an AI container. This does the torch install which is huge, and I
-# like to avoid re-doing this.
-# ----------------------------------------------------------------------------
+# Torch is stable and has wheels for both ARM64 and AMD64 on Python 3.12
FROM docker.io/fedora:42 AS ai
ENV PIP_BREAK_SYSTEM_PACKAGES=1
-RUN dnf install -y python3.13 && \
- alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
+RUN dnf install -y python3.12 && \
+ alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \
dnf clean all
-RUN pip3 install torch==2.5.1+cpu \
- --index-url https://download.pytorch.org/whl/cpu
+# The +cpu wheels are x86_64-only, so this won't work on ARM:
+#RUN pip3 install torch==2.5.1+cpu \
+# --index-url https://download.pytorch.org/whl/cpu
+
+RUN pip3 install torch
RUN pip3 install --no-cache-dir \
langchain==0.3.25 langchain-core==0.3.60 langchain-huggingface==0.2.0 \
diff --git a/dev-tools/proc-group/README.md b/dev-tools/proc-group/README.md
new file mode 100644
index 00000000..1874ea36
--- /dev/null
+++ b/dev-tools/proc-group/README.md
@@ -0,0 +1,117 @@
+# proc-group — run TrustGraph as a single process
+
+A dev-focused alternative to the per-container deployment. Instead of 30+
+containers each running a single processor, `processor-group` runs all the
+processors as asyncio tasks inside one Python process, sharing the event
+loop, Prometheus registry, and (importantly) resources on your laptop.
+
+This is **not** for production. Scale deployments should keep using
+per-processor containers — one failure can bring down the whole process,
+there's no horizontal scaling, and everything shares a single giant
+log. That's fine for dev and a bad idea in prod.
+
+## What this directory contains
+
+- `group.yaml` — the group runner config. One entry per processor, each
+ with the dotted class path and a params dict. Defaults (pubsub backend,
+ rabbitmq host, log level) are pulled in per-entry with a YAML anchor.
+- `groups/` — the same processors split into smaller per-subsystem
+  configs (control, ingest, embeddings, embeddings-store, rows-store,
+  triples-store, rag, llm), for running subsets as separate processes.
+- `README.md` — this file.
+
+## Prerequisites
+
+Install the TrustGraph packages into a venv:
+
+```
+pip install trustgraph-base trustgraph-flow trustgraph-unstructured
+```
+
+`trustgraph-base` provides the `processor-group` entry point. The others
+provide the processor classes that `group.yaml` imports at runtime.
+`trustgraph-unstructured` is only needed if you want `document-decoder`
+(the `universal-decoder` processor).
+
+## Running it
+
+Start the infrastructure (cassandra, qdrant, rabbitmq, garage, and the
+observability stack) with a working compose file. These can't be packed
+into the group — they're third-party services. You may be able to run
+them as standalone services instead.
+
+To make Cassandra accessible from the host, you need to set a couple
+of environment variables:
+```
+ CASSANDRA_BROADCAST_ADDRESS: 127.0.0.1
+ CASSANDRA_LISTEN_ADDRESS: 127.0.0.1
+```
+and also set `network_mode: host` on the service, as sketched below.
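+
+A minimal sketch of the compose stanza, assuming a service named
+`cassandra` (the image tag is illustrative):
+
+```
+  cassandra:
+    image: docker.io/cassandra:4.1
+    network_mode: host
+    environment:
+      CASSANDRA_BROADCAST_ADDRESS: 127.0.0.1
+      CASSANDRA_LISTEN_ADDRESS: 127.0.0.1
+```
+
+Then start services: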
+
+```
+podman-compose up -d cassandra qdrant rabbitmq
+podman-compose up -d garage garage-init
+podman-compose up -d loki prometheus grafana
+podman-compose up -d init-trustgraph
+```
+
+`init-trustgraph` is a one-shot job that seeds the config and the
+default flow into cassandra/rabbitmq. Don't leave too long a delay
+between starting `init-trustgraph` and running the processor-group,
+because `init-trustgraph` needs to talk to the config service, which
+runs inside the group.
+
+Run the api-gateway separately — it's an aiohttp HTTP server, not an
+`AsyncProcessor`, so the group runner doesn't host it. The command is
+shown below, once the group is up.
+
+Raise the file descriptor limit — 30+ processors sharing one process
+open far more sockets than the default 1024 allows:
+
+```
+ulimit -n 65536
+```
+
+Then start the group from a terminal:
+
+```
+processor-group -c group.yaml --no-loki-enabled
+```
+
+You'll see every processor's startup messages interleaved in one log.
+Each processor has a supervisor that restarts it independently on
+failure, so a transient crash (or a dependency that isn't ready yet)
+only affects that one processor — siblings keep running and the failing
+one self-heals on the next retry.
+
+Finally, when everything is running, you can start the API gateway from
+its own terminal:
+
+```
+api-gateway \
+ --pubsub-backend rabbitmq --rabbitmq-host localhost \
+ --loki-url http://localhost:3100/loki/api/v1/push \
+ --no-metrics
+```
+
+## When things go wrong
+
+- **"Too many open files"** — raise `ulimit -n` further. 65536 is
+ usually plenty but some workflows need more.
+- **One processor failing repeatedly** — look for its id in the log. The
+ supervisor will log each failure before restarting. Fix the cause
+ (missing env var, unreachable dependency, bad params) and the
+ processor self-heals on the next 4-second retry without restarting
+ the whole group.
+- **Ctrl-C leaves the process hung** — the pika and cassandra drivers
+ spawn non-cooperative threads that asyncio can't cancel. Use Ctrl-\
+ (SIGQUIT) to force-kill. Not a bug in the group runner, just a
+ limitation of those libraries.
+
+## Environment variables
+
+Processors that talk to external LLMs or APIs read their credentials
+from env vars, same as in the per-container deployment:
+
+- `OPENAI_TOKEN`, `OPENAI_BASE_URL` — for `text-completion` /
+ `text-completion-rag`
+
+Export whatever your particular `group.yaml` needs before running.
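+
+For example (placeholder values):
+
+```
+export OPENAI_TOKEN=sk-...                       # hypothetical token
+export OPENAI_BASE_URL=https://api.openai.com/v1
+```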
+
diff --git a/dev-tools/proc-group/group.yaml b/dev-tools/proc-group/group.yaml
new file mode 100644
index 00000000..98ef5016
--- /dev/null
+++ b/dev-tools/proc-group/group.yaml
@@ -0,0 +1,257 @@
+# Multi-processor group config, derived from docker-compose.yaml.
+#
+# Covers every AsyncProcessor-based service from the compose file.
+# Out of scope:
+# - api-gateway (aiohttp, not AsyncProcessor)
+# - init-trustgraph (one-shot init, not a processor)
+# - mcp-server (trustgraph-mcp package, separate image)
+# - ddg-mcp-server (third-party image)
+# - infrastructure (cassandra, rabbitmq, qdrant, garage, grafana,
+# prometheus, loki, workbench-ui)
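+#
+# document-decoder (universal-decoder) is included below; it needs the
+# trustgraph-unstructured package installed.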
+#
+# Run with:
+# processor-group -c group.yaml
+
+_defaults: &defaults
+ pubsub_backend: rabbitmq
+ rabbitmq_host: localhost
+ log_level: INFO
+
+processors:
+
+ - class: trustgraph.agent.orchestrator.Processor
+ params:
+ <<: *defaults
+ id: agent-manager
+
+ - class: trustgraph.chunking.recursive.Processor
+ params:
+ <<: *defaults
+ id: chunker
+ chunk_size: 2000
+ chunk_overlap: 50
+
+ - class: trustgraph.config.service.Processor
+ params:
+ <<: *defaults
+ id: config-svc
+ cassandra_host: localhost
+
+ - class: trustgraph.decoding.universal.Processor
+ params:
+ <<: *defaults
+ id: document-decoder
+
+ - class: trustgraph.embeddings.document_embeddings.Processor
+ params:
+ <<: *defaults
+ id: document-embeddings
+
+ - class: trustgraph.retrieval.document_rag.Processor
+ params:
+ <<: *defaults
+ id: document-rag
+ doc_limit: 20
+
+ - class: trustgraph.embeddings.fastembed.Processor
+ params:
+ <<: *defaults
+ id: embeddings
+ concurrency: 1
+
+ - class: trustgraph.embeddings.graph_embeddings.Processor
+ params:
+ <<: *defaults
+ id: graph-embeddings
+
+ - class: trustgraph.retrieval.graph_rag.Processor
+ params:
+ <<: *defaults
+ id: graph-rag
+ concurrency: 1
+ entity_limit: 50
+ triple_limit: 30
+ edge_limit: 30
+ edge_score_limit: 10
+ max_subgraph_size: 100
+ max_path_length: 2
+
+ - class: trustgraph.extract.kg.agent.Processor
+ params:
+ <<: *defaults
+ id: kg-extract-agent
+ concurrency: 1
+
+ - class: trustgraph.extract.kg.definitions.Processor
+ params:
+ <<: *defaults
+ id: kg-extract-definitions
+ concurrency: 1
+
+ - class: trustgraph.extract.kg.ontology.Processor
+ params:
+ <<: *defaults
+ id: kg-extract-ontology
+ concurrency: 1
+
+ - class: trustgraph.extract.kg.relationships.Processor
+ params:
+ <<: *defaults
+ id: kg-extract-relationships
+ concurrency: 1
+
+ - class: trustgraph.extract.kg.rows.Processor
+ params:
+ <<: *defaults
+ id: kg-extract-rows
+ concurrency: 1
+
+ - class: trustgraph.cores.service.Processor
+ params:
+ <<: *defaults
+ id: knowledge
+ cassandra_host: localhost
+
+ - class: trustgraph.storage.knowledge.store.Processor
+ params:
+ <<: *defaults
+ id: kg-store
+ cassandra_host: localhost
+
+ - class: trustgraph.librarian.Processor
+ params:
+ <<: *defaults
+ id: librarian
+ cassandra_host: localhost
+ object_store_endpoint: localhost:3900
+ object_store_access_key: GK000000000000000000000001
+ object_store_secret_key: b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427
+ object_store_region: garage
+
+ - class: trustgraph.agent.mcp_tool.Service
+ params:
+ <<: *defaults
+ id: mcp-tool
+
+ - class: trustgraph.metering.Processor
+ params:
+ <<: *defaults
+ id: metering
+
+ - class: trustgraph.metering.Processor
+ params:
+ <<: *defaults
+ id: metering-rag
+
+ - class: trustgraph.retrieval.nlp_query.Processor
+ params:
+ <<: *defaults
+ id: nlp-query
+
+ - class: trustgraph.prompt.template.Processor
+ params:
+ <<: *defaults
+ id: prompt
+ concurrency: 1
+
+ - class: trustgraph.prompt.template.Processor
+ params:
+ <<: *defaults
+ id: prompt-rag
+ concurrency: 1
+
+ - class: trustgraph.query.doc_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: doc-embeddings-query
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.query.graph_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: graph-embeddings-query
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.query.row_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: row-embeddings-query
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.query.rows.cassandra.Processor
+ params:
+ <<: *defaults
+ id: rows-query
+ cassandra_host: localhost
+
+ - class: trustgraph.query.triples.cassandra.Processor
+ params:
+ <<: *defaults
+ id: triples-query
+ cassandra_host: localhost
+
+ - class: trustgraph.embeddings.row_embeddings.Processor
+ params:
+ <<: *defaults
+ id: row-embeddings
+
+ - class: trustgraph.query.sparql.Processor
+ params:
+ <<: *defaults
+ id: sparql-query
+
+ - class: trustgraph.storage.doc_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: doc-embeddings-write
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.storage.graph_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: graph-embeddings-write
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.storage.row_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: row-embeddings-write
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.storage.rows.cassandra.Processor
+ params:
+ <<: *defaults
+ id: rows-write
+ cassandra_host: localhost
+
+ - class: trustgraph.storage.triples.cassandra.Processor
+ params:
+ <<: *defaults
+ id: triples-write
+ cassandra_host: localhost
+
+ - class: trustgraph.retrieval.structured_diag.Processor
+ params:
+ <<: *defaults
+ id: structured-diag
+
+ - class: trustgraph.retrieval.structured_query.Processor
+ params:
+ <<: *defaults
+ id: structured-query
+
+ - class: trustgraph.model.text_completion.openai.Processor
+ params:
+ <<: *defaults
+ id: text-completion
+ max_output: 8192
+ temperature: 0.0
+
+ - class: trustgraph.model.text_completion.openai.Processor
+ params:
+ <<: *defaults
+ id: text-completion-rag
+ max_output: 8192
+ temperature: 0.0
diff --git a/dev-tools/proc-group/groups/control.yaml b/dev-tools/proc-group/groups/control.yaml
new file mode 100644
index 00000000..b9ee9bfa
--- /dev/null
+++ b/dev-tools/proc-group/groups/control.yaml
@@ -0,0 +1,47 @@
+# Control plane. Stateful "always on" services that every flow depends on.
+# Cassandra-heavy, low traffic.
+
+_defaults: &defaults
+ pubsub_backend: rabbitmq
+ rabbitmq_host: localhost
+ log_level: INFO
+
+processors:
+
+ - class: trustgraph.config.service.Processor
+ params:
+ <<: *defaults
+ id: config-svc
+ cassandra_host: localhost
+
+ - class: trustgraph.librarian.Processor
+ params:
+ <<: *defaults
+ id: librarian
+ cassandra_host: localhost
+ object_store_endpoint: localhost:3900
+ object_store_access_key: GK000000000000000000000001
+ object_store_secret_key: b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427
+ object_store_region: garage
+
+ - class: trustgraph.cores.service.Processor
+ params:
+ <<: *defaults
+ id: knowledge
+ cassandra_host: localhost
+
+ - class: trustgraph.storage.knowledge.store.Processor
+ params:
+ <<: *defaults
+ id: kg-store
+ cassandra_host: localhost
+
+ - class: trustgraph.metering.Processor
+ params:
+ <<: *defaults
+ id: metering
+
+ - class: trustgraph.metering.Processor
+ params:
+ <<: *defaults
+ id: metering-rag
diff --git a/dev-tools/proc-group/groups/embeddings-store.yaml b/dev-tools/proc-group/groups/embeddings-store.yaml
new file mode 100644
index 00000000..b5d4a6c8
--- /dev/null
+++ b/dev-tools/proc-group/groups/embeddings-store.yaml
@@ -0,0 +1,45 @@
+# Embeddings store. All Qdrant-backed vector query/write processors.
+# One process owns the Qdrant driver pool for the whole group.
+
+_defaults: &defaults
+ pubsub_backend: rabbitmq
+ rabbitmq_host: localhost
+ log_level: INFO
+
+processors:
+
+ - class: trustgraph.query.doc_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: doc-embeddings-query
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.storage.doc_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: doc-embeddings-write
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.query.graph_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: graph-embeddings-query
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.storage.graph_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: graph-embeddings-write
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.query.row_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: row-embeddings-query
+ store_uri: http://localhost:6333
+
+ - class: trustgraph.storage.row_embeddings.qdrant.Processor
+ params:
+ <<: *defaults
+ id: row-embeddings-write
+ store_uri: http://localhost:6333
diff --git a/dev-tools/proc-group/groups/embeddings.yaml b/dev-tools/proc-group/groups/embeddings.yaml
new file mode 100644
index 00000000..a4e0298b
--- /dev/null
+++ b/dev-tools/proc-group/groups/embeddings.yaml
@@ -0,0 +1,31 @@
+# Embeddings. Memory-hungry — fastembed loads an ML model at startup.
+# Keep isolated from other groups so its memory footprint and restart
+# latency don't affect siblings.
+
+_defaults: &defaults
+ pubsub_backend: rabbitmq
+ rabbitmq_host: localhost
+ log_level: INFO
+
+processors:
+
+ - class: trustgraph.embeddings.fastembed.Processor
+ params:
+ <<: *defaults
+ id: embeddings
+ concurrency: 1
+
+ - class: trustgraph.embeddings.document_embeddings.Processor
+ params:
+ <<: *defaults
+ id: document-embeddings
+
+ - class: trustgraph.embeddings.graph_embeddings.Processor
+ params:
+ <<: *defaults
+ id: graph-embeddings
+
+ - class: trustgraph.embeddings.row_embeddings.Processor
+ params:
+ <<: *defaults
+ id: row-embeddings
diff --git a/dev-tools/proc-group/groups/ingest.yaml b/dev-tools/proc-group/groups/ingest.yaml
new file mode 100644
index 00000000..146a6339
--- /dev/null
+++ b/dev-tools/proc-group/groups/ingest.yaml
@@ -0,0 +1,52 @@
+# Ingest pipeline. Document-processing hot path. Bursty, with correlated
+# failures — if the chunker dies, the extractors have nothing to do anyway.
+
+_defaults: &defaults
+ pubsub_backend: rabbitmq
+ rabbitmq_host: localhost
+ log_level: INFO
+
+processors:
+
+ - class: trustgraph.chunking.recursive.Processor
+ params:
+ <<: *defaults
+ id: chunker
+ chunk_size: 2000
+ chunk_overlap: 50
+
+ - class: trustgraph.extract.kg.agent.Processor
+ params:
+ <<: *defaults
+ id: kg-extract-agent
+ concurrency: 1
+
+ - class: trustgraph.extract.kg.definitions.Processor
+ params:
+ <<: *defaults
+ id: kg-extract-definitions
+ concurrency: 1
+
+ - class: trustgraph.extract.kg.ontology.Processor
+ params:
+ <<: *defaults
+ id: kg-extract-ontology
+ concurrency: 1
+
+ - class: trustgraph.extract.kg.relationships.Processor
+ params:
+ <<: *defaults
+ id: kg-extract-relationships
+ concurrency: 1
+
+ - class: trustgraph.extract.kg.rows.Processor
+ params:
+ <<: *defaults
+ id: kg-extract-rows
+ concurrency: 1
+
+ - class: trustgraph.prompt.template.Processor
+ params:
+ <<: *defaults
+ id: prompt
+ concurrency: 1
diff --git a/dev-tools/proc-group/groups/llm.yaml b/dev-tools/proc-group/groups/llm.yaml
new file mode 100644
index 00000000..35930dbf
--- /dev/null
+++ b/dev-tools/proc-group/groups/llm.yaml
@@ -0,0 +1,24 @@
+# LLM. Outbound text-completion calls. Isolated because the upstream
+# LLM API is often the bottleneck and the most likely thing to need a
+# restart (provider changes, model changes, API flakiness).
+
+_defaults: &defaults
+ pubsub_backend: rabbitmq
+ rabbitmq_host: localhost
+ log_level: INFO
+
+processors:
+
+ - class: trustgraph.model.text_completion.openai.Processor
+ params:
+ <<: *defaults
+ id: text-completion
+ max_output: 8192
+ temperature: 0.0
+
+ - class: trustgraph.model.text_completion.openai.Processor
+ params:
+ <<: *defaults
+ id: text-completion-rag
+ max_output: 8192
+ temperature: 0.0
diff --git a/dev-tools/proc-group/groups/rag.yaml b/dev-tools/proc-group/groups/rag.yaml
new file mode 100644
index 00000000..be27086b
--- /dev/null
+++ b/dev-tools/proc-group/groups/rag.yaml
@@ -0,0 +1,64 @@
+# RAG / retrieval / agent. Query-time serving path. Drives outbound
+# LLM calls via prompt-rag. sparql-query lives here because it's a
+# read-side serving endpoint, not a backend writer.
+
+_defaults: &defaults
+ pubsub_backend: rabbitmq
+ rabbitmq_host: localhost
+ log_level: INFO
+
+processors:
+
+ - class: trustgraph.agent.orchestrator.Processor
+ params:
+ <<: *defaults
+ id: agent-manager
+
+ - class: trustgraph.retrieval.graph_rag.Processor
+ params:
+ <<: *defaults
+ id: graph-rag
+ concurrency: 1
+ entity_limit: 50
+ triple_limit: 30
+ edge_limit: 30
+ edge_score_limit: 10
+ max_subgraph_size: 100
+ max_path_length: 2
+
+ - class: trustgraph.retrieval.document_rag.Processor
+ params:
+ <<: *defaults
+ id: document-rag
+ doc_limit: 20
+
+ - class: trustgraph.retrieval.nlp_query.Processor
+ params:
+ <<: *defaults
+ id: nlp-query
+
+ - class: trustgraph.retrieval.structured_query.Processor
+ params:
+ <<: *defaults
+ id: structured-query
+
+ - class: trustgraph.retrieval.structured_diag.Processor
+ params:
+ <<: *defaults
+ id: structured-diag
+
+ - class: trustgraph.query.sparql.Processor
+ params:
+ <<: *defaults
+ id: sparql-query
+
+ - class: trustgraph.prompt.template.Processor
+ params:
+ <<: *defaults
+ id: prompt-rag
+ concurrency: 1
+
+ - class: trustgraph.agent.mcp_tool.Service
+ params:
+ <<: *defaults
+ id: mcp-tool
diff --git a/dev-tools/proc-group/groups/rows-store.yaml b/dev-tools/proc-group/groups/rows-store.yaml
new file mode 100644
index 00000000..ed52556d
--- /dev/null
+++ b/dev-tools/proc-group/groups/rows-store.yaml
@@ -0,0 +1,20 @@
+# Rows store. Cassandra-backed structured row query/write.
+
+_defaults: &defaults
+ pubsub_backend: rabbitmq
+ rabbitmq_host: localhost
+ log_level: INFO
+
+processors:
+
+ - class: trustgraph.query.rows.cassandra.Processor
+ params:
+ <<: *defaults
+ id: rows-query
+ cassandra_host: localhost
+
+ - class: trustgraph.storage.rows.cassandra.Processor
+ params:
+ <<: *defaults
+ id: rows-write
+ cassandra_host: localhost
diff --git a/dev-tools/proc-group/groups/triples-store.yaml b/dev-tools/proc-group/groups/triples-store.yaml
new file mode 100644
index 00000000..4e32bfbd
--- /dev/null
+++ b/dev-tools/proc-group/groups/triples-store.yaml
@@ -0,0 +1,20 @@
+# Triples store. Cassandra-backed RDF triple query/write.
+
+_defaults: &defaults
+ pubsub_backend: rabbitmq
+ rabbitmq_host: localhost
+ log_level: INFO
+
+processors:
+
+ - class: trustgraph.query.triples.cassandra.Processor
+ params:
+ <<: *defaults
+ id: triples-query
+ cassandra_host: localhost
+
+ - class: trustgraph.storage.triples.cassandra.Processor
+ params:
+ <<: *defaults
+ id: triples-write
+ cassandra_host: localhost
diff --git a/dev-tools/tests/agent_dag/analyse_trace.py b/dev-tools/tests/agent_dag/analyse_trace.py
index b71cdebe..42cca118 100644
--- a/dev-tools/tests/agent_dag/analyse_trace.py
+++ b/dev-tools/tests/agent_dag/analyse_trace.py
@@ -131,21 +131,21 @@ async def analyse(path, url, flow, user, collection):
for i, msg in enumerate(messages):
resp = msg.get("response", {})
- chunk_type = resp.get("chunk_type", "?")
+ message_type = resp.get("message_type", "?")
- if chunk_type == "explain":
+ if message_type == "explain":
explain_id = resp.get("explain_id", "")
explain_ids.append(explain_id)
- print(f" {i:3d} {chunk_type} {explain_id}")
+ print(f" {i:3d} {message_type} {explain_id}")
else:
- print(f" {i:3d} {chunk_type}")
+ print(f" {i:3d} {message_type}")
# Rule 7: message_id on content chunks
- if chunk_type in ("thought", "observation", "answer"):
+ if message_type in ("thought", "observation", "answer"):
mid = resp.get("message_id", "")
if not mid:
errors.append(
- f"[msg {i}] {chunk_type} chunk missing message_id"
+ f"[msg {i}] {message_type} chunk missing message_id"
)
print()
diff --git a/docs/tech-specs/agent-explainability.md b/docs/tech-specs/agent-explainability.md
index e054f023..add39fd7 100644
--- a/docs/tech-specs/agent-explainability.md
+++ b/docs/tech-specs/agent-explainability.md
@@ -6,273 +6,212 @@ parent: "Tech Specs"
# Agent Explainability: Provenance Recording
+## Status
+
+Implemented
+
## Overview
-Add provenance recording to the React agent loop so agent sessions can be traced and debugged using the same explainability infrastructure as GraphRAG.
+Agent sessions are traced and debugged using the same explainability infrastructure as GraphRAG and Document RAG. Provenance is written to `urn:graph:retrieval` and delivered inline on the explain stream.
-**Design Decisions:**
-- Write to `urn:graph:retrieval` (generic explainability graph)
-- Linear dependency chain for now (analysis N → wasDerivedFrom → analysis N-1)
-- Tools are opaque black boxes (record input/output only)
-- DAG support deferred to future iteration
+The canonical vocabulary for all predicates and types is published as an OWL ontology at `specs/ontology/trustgraph.ttl`.
## Entity Types
-Both GraphRAG and Agent use PROV-O as the base ontology with TrustGraph-specific subtypes:
+All services use PROV-O as the base ontology with TrustGraph-specific subtypes.
### GraphRAG Types
-| Entity | PROV-O Type | TG Types | Description |
-|--------|-------------|----------|-------------|
-| Question | `prov:Activity` | `tg:Question`, `tg:GraphRagQuestion` | The user's query |
-| Exploration | `prov:Entity` | `tg:Exploration` | Edges retrieved from knowledge graph |
-| Focus | `prov:Entity` | `tg:Focus` | Selected edges with reasoning |
-| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer |
-
-### Agent Types
-| Entity | PROV-O Type | TG Types | Description |
-|--------|-------------|----------|-------------|
-| Question | `prov:Activity` | `tg:Question`, `tg:AgentQuestion` | The user's query |
-| Analysis | `prov:Entity` | `tg:Analysis` | Each think/act/observe cycle |
-| Conclusion | `prov:Entity` | `tg:Conclusion` | Final answer |
+| Entity | TG Types | Description |
+|--------|----------|-------------|
+| Question | `tg:Question`, `tg:GraphRagQuestion` | The user's query |
+| Grounding | `tg:Grounding` | Concept extraction from query |
+| Exploration | `tg:Exploration` | Edges retrieved from knowledge graph |
+| Focus | `tg:Focus` | Selected edges with reasoning |
+| Synthesis | `tg:Synthesis`, `tg:Answer` | Final answer |
### Document RAG Types
-| Entity | PROV-O Type | TG Types | Description |
-|--------|-------------|----------|-------------|
-| Question | `prov:Activity` | `tg:Question`, `tg:DocRagQuestion` | The user's query |
-| Exploration | `prov:Entity` | `tg:Exploration` | Chunks retrieved from document store |
-| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer |
+| Entity | TG Types | Description |
+|--------|----------|-------------|
+| Question | `tg:Question`, `tg:DocRagQuestion` | The user's query |
+| Grounding | `tg:Grounding` | Concept extraction from query |
+| Exploration | `tg:Exploration` | Chunks retrieved from document store |
+| Synthesis | `tg:Synthesis`, `tg:Answer` | Final answer |
-**Note:** Document RAG uses a subset of GraphRAG's types (no Focus step since there's no edge selection/reasoning phase).
+### Agent Types (React)
+| Entity | TG Types | Description |
+|--------|----------|-------------|
+| Question | `tg:Question`, `tg:AgentQuestion` | The user's query (session start) |
+| PatternDecision | `tg:PatternDecision` | Meta-router routing decision |
+| Analysis | `tg:Analysis`, `tg:ToolUse` | One think/act cycle |
+| Thought | `tg:Reflection`, `tg:Thought` | Agent reasoning (sub-entity of Analysis) |
+| Observation | `tg:Reflection`, `tg:Observation` | Tool result (standalone entity) |
+| Conclusion | `tg:Conclusion`, `tg:Answer` | Final answer |
+
+### Agent Types (Orchestrator — Plan)
+| Entity | TG Types | Description |
+|--------|----------|-------------|
+| Plan | `tg:Plan` | Structured plan of steps |
+| StepResult | `tg:StepResult`, `tg:Answer` | Result from executing one plan step |
+| Synthesis | `tg:Synthesis`, `tg:Answer` | Final synthesised answer |
+
+### Agent Types (Orchestrator — Supervisor)
+| Entity | TG Types | Description |
+|--------|----------|-------------|
+| Decomposition | `tg:Decomposition` | Question decomposed into sub-goals |
+| Finding | `tg:Finding`, `tg:Answer` | Result from a sub-agent |
+| Synthesis | `tg:Synthesis`, `tg:Answer` | Final synthesised answer |
+
+### Mixin Types
+| Type | Description |
+|------|-------------|
+| `tg:Answer` | Unifying type for terminal answers (Synthesis, Conclusion, Finding, StepResult) |
+| `tg:Reflection` | Unifying type for intermediate commentary (Thought, Observation) |
+| `tg:ToolUse` | Applied to Analysis when a tool is invoked |
+| `tg:Error` | Applied to Observation events where a failure occurred (tool error or LLM parse error) |
### Question Subtypes
-All Question entities share `tg:Question` as a base type but have a specific subtype to identify the retrieval mechanism:
-
| Subtype | URI Pattern | Mechanism |
|---------|-------------|-----------|
| `tg:GraphRagQuestion` | `urn:trustgraph:question:{uuid}` | Knowledge graph RAG |
| `tg:DocRagQuestion` | `urn:trustgraph:docrag:{uuid}` | Document/chunk RAG |
-| `tg:AgentQuestion` | `urn:trustgraph:agent:{uuid}` | ReAct agent |
+| `tg:AgentQuestion` | `urn:trustgraph:agent:session:{uuid}` | Agent orchestrator |
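+
+A hypothetical SPARQL query over the explainability graph, selecting
+agent sessions only:
+
+```
+PREFIX tg: <https://trustgraph.ai/ns/>
+SELECT ?q WHERE {
+  ?q a tg:Question, tg:AgentQuestion .
+}
+```
+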
-This allows querying all questions via `tg:Question` while filtering by specific mechanism via the subtype.
+## Provenance Chains
-## Provenance Model
+All chains use `prov:wasDerivedFrom` links. Each entity is a `prov:Entity`.
+
+### GraphRAG
```
-Question (urn:trustgraph:agent:{uuid})
- │
- │ tg:query = "User's question"
- │ prov:startedAtTime = timestamp
- │ rdf:type = prov:Activity, tg:Question
- │
- ↓ prov:wasDerivedFrom
- │
-Analysis1 (urn:trustgraph:agent:{uuid}/i1)
- │
- │ tg:thought = "I need to query the knowledge base..."
- │ tg:action = "knowledge-query"
- │ tg:arguments = {"question": "..."}
- │ tg:observation = "Result from tool..."
- │ rdf:type = prov:Entity, tg:Analysis
- │
- ↓ prov:wasDerivedFrom
- │
-Analysis2 (urn:trustgraph:agent:{uuid}/i2)
- │ ...
- ↓ prov:wasDerivedFrom
- │
-Conclusion (urn:trustgraph:agent:{uuid}/final)
- │
- │ tg:answer = "The final response..."
- │ rdf:type = prov:Entity, tg:Conclusion
+Question → Grounding → Exploration → Focus → Synthesis
```
-### Document RAG Provenance Model
+### Document RAG
```
-Question (urn:trustgraph:docrag:{uuid})
- │
- │ tg:query = "User's question"
- │ prov:startedAtTime = timestamp
- │ rdf:type = prov:Activity, tg:Question
- │
- ↓ prov:wasGeneratedBy
- │
-Exploration (urn:trustgraph:docrag:{uuid}/exploration)
- │
- │ tg:chunkCount = 5
- │ tg:selectedChunk = "chunk-id-1"
- │ tg:selectedChunk = "chunk-id-2"
- │ ...
- │ rdf:type = prov:Entity, tg:Exploration
- │
- ↓ prov:wasDerivedFrom
- │
-Synthesis (urn:trustgraph:docrag:{uuid}/synthesis)
- │
- │ tg:content = "The synthesized answer..."
- │ rdf:type = prov:Entity, tg:Synthesis
+Question → Grounding → Exploration → Synthesis
```
-## Changes Required
+### Agent React
-### 1. Schema Changes
-
-**File:** `trustgraph-base/trustgraph/schema/services/agent.py`
-
-Add `session_id` and `collection` fields to `AgentRequest`:
-```python
-@dataclass
-class AgentRequest:
- question: str = ""
- state: str = ""
- group: list[str] | None = None
- history: list[AgentStep] = field(default_factory=list)
- user: str = ""
- collection: str = "default" # NEW: Collection for provenance traces
- streaming: bool = False
- session_id: str = "" # NEW: For provenance tracking across iterations
+```
+Question → PatternDecision → Analysis(1) → Observation(1) → Analysis(2) → ... → Conclusion
```
-**File:** `trustgraph-base/trustgraph/messaging/translators/agent.py`
+The PatternDecision entity records which execution pattern the meta-router selected. It is only emitted on the first iteration when routing occurs.
-Update translator to handle `session_id` and `collection` in both `to_pulsar()` and `from_pulsar()`.
+Thought sub-entities derive from their parent Analysis. Observation entities derive from their parent Analysis (or from a sub-trace entity if the tool produced its own explainability, e.g. a GraphRAG query).
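+
+A sketch, in Turtle, of the triples for one React iteration. Only the
+session URI pattern is normative; the analysis/thought URI suffixes and
+the literal values here are illustrative:
+
+```
+@prefix prov: <http://www.w3.org/ns/prov#> .
+@prefix tg:   <https://trustgraph.ai/ns/> .
+
+<urn:trustgraph:agent:session:123/analysis/1>
+    a prov:Entity, tg:Analysis, tg:ToolUse ;
+    tg:stepNumber 1 ;
+    tg:action "knowledge-query" ;
+    tg:arguments "{\"question\": \"...\"}" ;
+    tg:thought <urn:trustgraph:agent:session:123/thought/1> ;
+    prov:wasDerivedFrom <urn:trustgraph:agent:session:123/decision> .
+
+<urn:trustgraph:agent:session:123/thought/1>
+    a prov:Entity, tg:Reflection, tg:Thought ;
+    prov:wasDerivedFrom <urn:trustgraph:agent:session:123/analysis/1> .
+```
+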
-### 2. Add Explainability Producer to Agent Service
+### Agent Plan-then-Execute
-**File:** `trustgraph-flow/trustgraph/agent/react/service.py`
-
-Register an "explainability" producer (same pattern as GraphRAG):
-```python
-from ... base import ProducerSpec
-from ... schema import Triples
-
-# In __init__:
-self.register_specification(
- ProducerSpec(
- name = "explainability",
- schema = Triples,
- )
-)
+```
+Question → PatternDecision → Plan → StepResult(0) → StepResult(1) → ... → Synthesis
```
-### 3. Provenance Triple Generation
+### Agent Supervisor
-**File:** `trustgraph-base/trustgraph/provenance/agent.py`
-
-Create helper functions (similar to GraphRAG's `question_triples`, `exploration_triples`, etc.):
-```python
-def agent_session_triples(session_uri, query, timestamp):
- """Generate triples for agent Question."""
- return [
- Triple(s=session_uri, p=RDF_TYPE, o=PROV_ACTIVITY),
- Triple(s=session_uri, p=RDF_TYPE, o=TG_QUESTION),
- Triple(s=session_uri, p=TG_QUERY, o=query),
- Triple(s=session_uri, p=PROV_STARTED_AT_TIME, o=timestamp),
- ]
-
-def agent_iteration_triples(iteration_uri, parent_uri, thought, action, arguments, observation):
- """Generate triples for one Analysis step."""
- return [
- Triple(s=iteration_uri, p=RDF_TYPE, o=PROV_ENTITY),
- Triple(s=iteration_uri, p=RDF_TYPE, o=TG_ANALYSIS),
- Triple(s=iteration_uri, p=TG_THOUGHT, o=thought),
- Triple(s=iteration_uri, p=TG_ACTION, o=action),
- Triple(s=iteration_uri, p=TG_ARGUMENTS, o=json.dumps(arguments)),
- Triple(s=iteration_uri, p=TG_OBSERVATION, o=observation),
- Triple(s=iteration_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri),
- ]
-
-def agent_final_triples(final_uri, parent_uri, answer):
- """Generate triples for Conclusion."""
- return [
- Triple(s=final_uri, p=RDF_TYPE, o=PROV_ENTITY),
- Triple(s=final_uri, p=RDF_TYPE, o=TG_CONCLUSION),
- Triple(s=final_uri, p=TG_ANSWER, o=answer),
- Triple(s=final_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri),
- ]
+```
+Question → PatternDecision → Decomposition → [fan-out sub-agents]
+ → Finding(0) → Finding(1) → ... → Synthesis
```
-### 4. Type Definitions
+Each sub-agent runs its own session with `wasDerivedFrom` linking back to the parent's Decomposition. Findings derive from their sub-agent's Conclusion.
-**File:** `trustgraph-base/trustgraph/provenance/namespaces.py`
+## Predicates
-Add explainability entity types and agent predicates:
-```python
-# Explainability entity types (used by both GraphRAG and Agent)
-TG_QUESTION = TG + "Question"
-TG_EXPLORATION = TG + "Exploration"
-TG_FOCUS = TG + "Focus"
-TG_SYNTHESIS = TG + "Synthesis"
-TG_ANALYSIS = TG + "Analysis"
-TG_CONCLUSION = TG + "Conclusion"
+### Session / Question
+| Predicate | Type | Description |
+|-----------|------|-------------|
+| `tg:query` | string | The user's query text |
+| `prov:startedAtTime` | string | ISO timestamp |
-# Agent predicates
-TG_THOUGHT = TG + "thought"
-TG_ACTION = TG + "action"
-TG_ARGUMENTS = TG + "arguments"
-TG_OBSERVATION = TG + "observation"
-TG_ANSWER = TG + "answer"
-```
+### Pattern Decision
+| Predicate | Type | Description |
+|-----------|------|-------------|
+| `tg:pattern` | string | Selected pattern (react, plan-then-execute, supervisor) |
+| `tg:taskType` | string | Identified task type (general, research, etc.) |
-## Files Modified
+### Analysis (Iteration)
+| Predicate | Type | Description |
+|-----------|------|-------------|
+| `tg:action` | string | Tool name selected by the agent |
+| `tg:arguments` | string | JSON-encoded arguments |
+| `tg:thought` | IRI | Link to Thought sub-entity |
+| `tg:toolCandidate` | string | Tool name available to the LLM (one per candidate) |
+| `tg:stepNumber` | integer | 1-based iteration counter |
+| `tg:llmDurationMs` | integer | LLM call duration in milliseconds |
+| `tg:inToken` | integer | Input token count |
+| `tg:outToken` | integer | Output token count |
+| `tg:llmModel` | string | Model identifier |
-| File | Change |
-|------|--------|
-| `trustgraph-base/trustgraph/schema/services/agent.py` | Add session_id and collection to AgentRequest |
-| `trustgraph-base/trustgraph/messaging/translators/agent.py` | Update translator for new fields |
-| `trustgraph-base/trustgraph/provenance/namespaces.py` | Add entity types, agent predicates, and Document RAG predicates |
-| `trustgraph-base/trustgraph/provenance/triples.py` | Add TG types to GraphRAG triple builders, add Document RAG triple builders |
-| `trustgraph-base/trustgraph/provenance/uris.py` | Add Document RAG URI generators |
-| `trustgraph-base/trustgraph/provenance/__init__.py` | Export new types, predicates, and Document RAG functions |
-| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id, explain_graph, and explain_triples to DocumentRagResponse |
-| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields including inline triples |
-| `trustgraph-flow/trustgraph/agent/react/service.py` | Add explainability producer + recording logic |
-| `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` | Add explainability callback and emit provenance triples |
-| `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` | Add explainability producer and wire up callback |
-| `trustgraph-cli/trustgraph/cli/show_explain_trace.py` | Handle agent trace types |
-| `trustgraph-cli/trustgraph/cli/list_explain_traces.py` | List agent sessions alongside GraphRAG |
+### Observation
+| Predicate | Type | Description |
+|-----------|------|-------------|
+| `tg:document` | IRI | Librarian document reference |
+| `tg:toolDurationMs` | integer | Tool execution time in milliseconds |
+| `tg:toolError` | string | Error message (tool failure or LLM parse error) |
-## Files Created
+When `tg:toolError` is present, the Observation also carries the `tg:Error` mixin type.
-| File | Purpose |
-|------|---------|
-| `trustgraph-base/trustgraph/provenance/agent.py` | Agent-specific triple generators |
+### Conclusion / Synthesis
+| Predicate | Type | Description |
+|-----------|------|-------------|
+| `tg:document` | IRI | Librarian document reference |
+| `tg:terminationReason` | string | Why the loop stopped |
+| `tg:inToken` | integer | Input token count (synthesis LLM call) |
+| `tg:outToken` | integer | Output token count |
+| `tg:llmModel` | string | Model identifier |
-## CLI Updates
+Termination reason values:
+- `final-answer` -- LLM produced a confident answer (react)
+- `plan-complete` -- all plan steps executed (plan-then-execute)
+- `subagents-complete` -- all sub-agents reported back (supervisor)
-**Detection:** Both GraphRAG and Agent Questions have `tg:Question` type. Distinguished by:
-1. URI pattern: `urn:trustgraph:agent:` vs `urn:trustgraph:question:`
-2. Derived entities: `tg:Analysis` (agent) vs `tg:Exploration` (GraphRAG)
+### Decomposition
+| Predicate | Type | Description |
+|-----------|------|-------------|
+| `tg:subagentGoal` | string | Goal assigned to a sub-agent (one per goal) |
+| `tg:inToken` | integer | Input token count |
+| `tg:outToken` | integer | Output token count |
-**`list_explain_traces.py`:**
-- Shows Type column (Agent vs GraphRAG)
+### Plan
+| Predicate | Type | Description |
+|-----------|------|-------------|
+| `tg:planStep` | string | Goal for a plan step (one per step) |
+| `tg:inToken` | integer | Input token count |
+| `tg:outToken` | integer | Output token count |
-**`show_explain_trace.py`:**
-- Auto-detects trace type
-- Agent rendering shows: Question → Analysis step(s) → Conclusion
+### Token Counts on RAG Events
-## Backwards Compatibility
+Grounding, Focus, and Synthesis events on GraphRAG and Document RAG also carry `tg:inToken`, `tg:outToken`, and `tg:llmModel` for the LLM calls associated with each step.
-- `session_id` defaults to `""` - old requests work, just won't have provenance
-- `collection` defaults to `"default"` - reasonable fallback
-- CLI gracefully handles both trace types
+## Error Handling
+
+Tool execution errors and LLM parse errors are captured as Observation events rather than crashing the agent:
+
+- The error message is recorded on `tg:toolError`
+- The Observation carries the `tg:Error` mixin type
+- The error text becomes the observation content, visible to the LLM on the next iteration
+- The provenance chain is preserved (Observation derives from Analysis)
+- The agent gets another iteration to retry or choose a different approach
+
+## Vocabulary Reference
+
+The full OWL ontology covering all classes and predicates is at `specs/ontology/trustgraph.ttl`.
## Verification
```bash
-# Run an agent query
-tg-invoke-agent -q "What is the capital of France?"
+# Run an agent query with explainability
+tg-invoke-agent -q "What is quantum computing?" -x
-# List traces (should show agent sessions with Type column)
-tg-list-explain-traces -U trustgraph -C default
+# Run with token usage
+tg-invoke-agent -q "What is quantum computing?" --show-usage
-# Show agent trace
-tg-show-explain-trace "urn:trustgraph:agent:xxx"
+# GraphRAG with explainability
+tg-invoke-graph-rag -q "Tell me about AI" -x
+
+# Document RAG with explainability
+tg-invoke-document-rag -q "Summarize the findings" -x
```
-
-## Future Work (Not This PR)
-
-- DAG dependencies (when analysis N uses results from multiple prior analyses)
-- Tool-specific provenance linking (KnowledgeQuery → its GraphRAG trace)
-- Streaming provenance emission (emit as we go, not batch at end)
diff --git a/docs/tech-specs/agent-orchestration.md b/docs/tech-specs/agent-orchestration.md
index 19261af0..ab1569ed 100644
--- a/docs/tech-specs/agent-orchestration.md
+++ b/docs/tech-specs/agent-orchestration.md
@@ -868,7 +868,7 @@ independently.
Response chunk fields:
message_id UUID for this message (groups chunks)
session_id Which agent session produced this chunk
- chunk_type "thought" | "observation" | "answer" | ...
+ message_type "thought" | "observation" | "answer" | ...
content The chunk text
end_of_message True on the final chunk of this message
end_of_dialog True on the final message of the entire execution
diff --git a/docs/tech-specs/extraction-provenance-subgraph.md b/docs/tech-specs/extraction-provenance-subgraph.md
index 197691f2..62d3a701 100644
--- a/docs/tech-specs/extraction-provenance-subgraph.md
+++ b/docs/tech-specs/extraction-provenance-subgraph.md
@@ -209,3 +209,7 @@ def subgraph_provenance_triples(
This is a breaking change to the provenance model. Provenance has not
been released, so no migration is needed. The old `tg:reifies` /
`statement_uri` code can be removed outright.
+
+## Vocabulary Reference
+
+The full OWL ontology covering all extraction and query-time classes and predicates is at `specs/ontology/trustgraph.ttl`.
diff --git a/docs/tech-specs/extraction-time-provenance.md b/docs/tech-specs/extraction-time-provenance.md
index e6197f8a..6c4b4513 100644
--- a/docs/tech-specs/extraction-time-provenance.md
+++ b/docs/tech-specs/extraction-time-provenance.md
@@ -618,8 +618,13 @@ Link embedding entity IDs to chunk.
| Triple store | Use reification for triple → chunk provenance |
| Embedding provenance | Link entity ID → chunk ID |
+## Vocabulary Reference
+
+The full OWL ontology covering all extraction and query-time classes and predicates is at `specs/ontology/trustgraph.ttl`.
+
## References
-- Query-time provenance: `docs/tech-specs/query-time-provenance.md`
+- Query-time provenance: `docs/tech-specs/query-time-explainability.md`
+- Agent explainability: `docs/tech-specs/agent-explainability.md`
- PROV-O standard for provenance modeling
- Existing source metadata in knowledge graph (needs audit)
diff --git a/docs/tech-specs/python-api-refactor.md b/docs/tech-specs/python-api-refactor.md
index 97ebe2f7..dd0022df 100644
--- a/docs/tech-specs/python-api-refactor.md
+++ b/docs/tech-specs/python-api-refactor.md
@@ -710,17 +710,17 @@ class StreamingChunk:
@dataclasses.dataclass
class AgentThought(StreamingChunk):
"""Agent reasoning chunk"""
- chunk_type: str = "thought"
+ message_type: str = "thought"
@dataclasses.dataclass
class AgentObservation(StreamingChunk):
"""Agent tool observation chunk"""
- chunk_type: str = "observation"
+ message_type: str = "observation"
@dataclasses.dataclass
class AgentAnswer(StreamingChunk):
"""Agent final answer chunk"""
- chunk_type: str = "final-answer"
+ message_type: str = "final-answer"
end_of_dialog: bool = False
@dataclasses.dataclass
diff --git a/docs/tech-specs/query-time-explainability.md b/docs/tech-specs/query-time-explainability.md
index 0e1d18f6..02598563 100644
--- a/docs/tech-specs/query-time-explainability.md
+++ b/docs/tech-specs/query-time-explainability.md
@@ -144,6 +144,25 @@ Defined in `trustgraph-base/trustgraph/provenance/namespaces.py`:
| `TG_REASONING` | `https://trustgraph.ai/ns/reasoning` |
| `TG_CONTENT` | `https://trustgraph.ai/ns/content` |
| `TG_DOCUMENT` | `https://trustgraph.ai/ns/document` |
+| `TG_IN_TOKEN` | `https://trustgraph.ai/ns/inToken` |
+| `TG_OUT_TOKEN` | `https://trustgraph.ai/ns/outToken` |
+| `TG_LLM_MODEL` | `https://trustgraph.ai/ns/llmModel` |
+
+### Token Usage on Events
+
+Grounding, Focus, and Synthesis events carry per-event LLM token counts:
+
+| Predicate | Type | Present on |
+|-----------|------|------------|
+| `tg:inToken` | integer | Grounding, Focus, Synthesis |
+| `tg:outToken` | integer | Grounding, Focus, Synthesis |
+| `tg:llmModel` | string | Grounding, Focus, Synthesis |
+
+- **Grounding**: tokens from the extract-concepts LLM call
+- **Focus**: summed tokens from edge-scoring + edge-reasoning LLM calls
+- **Synthesis**: tokens from the synthesis LLM call
+
+Values are absent (not zero) when token counts are unavailable.
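+
+A sketch of these predicates on a Synthesis event, in Turtle (URI and
+values illustrative):
+
+```
+@prefix tg: <https://trustgraph.ai/ns/> .
+
+<urn:trustgraph:question:123/synthesis>
+    tg:inToken 1542 ;
+    tg:outToken 208 ;
+    tg:llmModel "gpt-4o" .
+```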
## GraphRagResponse Schema
@@ -267,8 +286,13 @@ Based on the provided knowledge statements...
| `trustgraph-flow/trustgraph/query/triples/cassandra/service.py` | Quoted triple query support |
| `trustgraph-cli/trustgraph/cli/invoke_graph_rag.py` | CLI with explainability display |
+## Vocabulary Reference
+
+The full OWL ontology covering all classes and predicates is at `specs/ontology/trustgraph.ttl`.
+
## References
- PROV-O (W3C Provenance Ontology): https://www.w3.org/TR/prov-o/
- RDF-star: https://w3c.github.io/rdf-star/
- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md`
+- Agent explainability: `docs/tech-specs/agent-explainability.md`
diff --git a/docs/tech-specs/streaming-llm-responses.md b/docs/tech-specs/streaming-llm-responses.md
index 7ecea5a5..5f6d9877 100644
--- a/docs/tech-specs/streaming-llm-responses.md
+++ b/docs/tech-specs/streaming-llm-responses.md
@@ -399,28 +399,28 @@ The agent produces multiple types of output during its reasoning cycle:
- Answer (final response)
- Errors
-Since `chunk_type` identifies what kind of content is being sent, the separate
+Since `message_type` identifies what kind of content is being sent, the separate
`answer`, `error`, `thought`, and `observation` fields can be collapsed into
a single `content` field:
```python
class AgentResponse(Record):
- chunk_type = String() # "thought", "action", "observation", "answer", "error"
- content = String() # The actual content (interpretation depends on chunk_type)
+ message_type = String() # "thought", "action", "observation", "answer", "error"
+ content = String() # The actual content (interpretation depends on message_type)
end_of_message = Boolean() # Current thought/action/observation/answer is complete
end_of_dialog = Boolean() # Entire agent dialog is complete
```
**Field Semantics:**
-- `chunk_type`: Indicates what type of content is in the `content` field
+- `message_type`: Indicates what type of content is in the `content` field
- `"thought"`: Agent reasoning/thinking
- `"action"`: Tool/action being invoked
- `"observation"`: Result from tool execution
- `"answer"`: Final answer to the user's question
- `"error"`: Error message
-- `content`: The actual streamed content, interpreted based on `chunk_type`
+- `content`: The actual streamed content, interpreted based on `message_type`
- `end_of_message`: When `true`, the current chunk type is complete
- Example: All tokens for the current thought have been sent
@@ -434,27 +434,27 @@ class AgentResponse(Record):
When `streaming=true`:
1. **Thought streaming**:
- - Multiple chunks with `chunk_type="thought"`, `end_of_message=false`
+ - Multiple chunks with `message_type="thought"`, `end_of_message=false`
- Final thought chunk has `end_of_message=true`
2. **Action notification**:
- - Single chunk with `chunk_type="action"`, `end_of_message=true`
+ - Single chunk with `message_type="action"`, `end_of_message=true`
3. **Observation**:
- - Chunk(s) with `chunk_type="observation"`, final has `end_of_message=true`
+ - Chunk(s) with `message_type="observation"`, final has `end_of_message=true`
4. **Repeat** steps 1-3 as the agent reasons
5. **Final answer**:
- - `chunk_type="answer"` with the final response in `content`
+ - `message_type="answer"` with the final response in `content`
- Last chunk has `end_of_message=true`, `end_of_dialog=true`
**Example Stream Sequence:**
```
-{chunk_type: "thought", content: "I need to", end_of_message: false, end_of_dialog: false}
-{chunk_type: "thought", content: " search for...", end_of_message: true, end_of_dialog: false}
-{chunk_type: "action", content: "search", end_of_message: true, end_of_dialog: false}
-{chunk_type: "observation", content: "Found: ...", end_of_message: true, end_of_dialog: false}
-{chunk_type: "thought", content: "Based on this", end_of_message: false, end_of_dialog: false}
-{chunk_type: "thought", content: " I can answer...", end_of_message: true, end_of_dialog: false}
-{chunk_type: "answer", content: "The answer is...", end_of_message: true, end_of_dialog: true}
+{message_type: "thought", content: "I need to", end_of_message: false, end_of_dialog: false}
+{message_type: "thought", content: " search for...", end_of_message: true, end_of_dialog: false}
+{message_type: "action", content: "search", end_of_message: true, end_of_dialog: false}
+{message_type: "observation", content: "Found: ...", end_of_message: true, end_of_dialog: false}
+{message_type: "thought", content: "Based on this", end_of_message: false, end_of_dialog: false}
+{message_type: "thought", content: " I can answer...", end_of_message: true, end_of_dialog: false}
+{message_type: "answer", content: "The answer is...", end_of_message: true, end_of_dialog: true}
```
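+
+A minimal consumer sketch (assuming chunks arrive as dicts shaped like the
+sequence above; all names here are illustrative):
+
+```python
+def collect_dialog(chunks):
+    """Reassemble a streamed agent dialog, grouped by message_type."""
+    buffer, messages = [], []
+    for c in chunks:
+        buffer.append(c["content"])
+        if c["end_of_message"]:
+            # A complete thought/action/observation/answer has been received
+            messages.append((c["message_type"], "".join(buffer)))
+            buffer = []
+        if c["end_of_dialog"]:
+            break
+    return messages
+```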
When `streaming=false`:
@@ -547,7 +547,7 @@ The following questions were resolved during specification:
populated and no other fields are needed. An error is always the final
communication - no subsequent messages are permitted or expected after
an error. For LLM/Prompt streams, `end_of_stream=true`. For Agent streams,
- `chunk_type="error"` with `end_of_dialog=true`.
+ `message_type="error"` with `end_of_dialog=true`.
3. **Partial Response Recovery**: The messaging protocol (Pulsar) is resilient,
so message-level retry is not needed. If a client loses track of the stream
diff --git a/specs/ontology/trustgraph.ttl b/specs/ontology/trustgraph.ttl
new file mode 100644
index 00000000..4c7de612
--- /dev/null
+++ b/specs/ontology/trustgraph.ttl
@@ -0,0 +1,415 @@
+@prefix tg: <https://trustgraph.ai/ns/> .
+@prefix owl: <http://www.w3.org/2002/07/owl#> .
+@prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> .
+@prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .
+@prefix xsd: <http://www.w3.org/2001/XMLSchema#> .
+@prefix prov: <http://www.w3.org/ns/prov#> .
+
+# =============================================================================
+# Ontology declaration
+# =============================================================================
+
+<https://trustgraph.ai/ns/>
+ a owl:Ontology ;
+ rdfs:label "TrustGraph Ontology" ;
+ rdfs:comment "Vocabulary for TrustGraph provenance, extraction metadata, and explainability." ;
+ owl:versionInfo "2.3" .
+
+# =============================================================================
+# Classes — Extraction provenance
+# =============================================================================
+
+tg:Document a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Document" ;
+ rdfs:comment "A loaded document (PDF, text, etc.)." .
+
+tg:Page a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Page" ;
+ rdfs:comment "A page within a document." .
+
+tg:Section a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Section" ;
+ rdfs:comment "A structural section within a document." .
+
+tg:Chunk a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Chunk" ;
+ rdfs:comment "A text chunk produced by the chunker." .
+
+tg:Image a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Image" ;
+ rdfs:comment "An image extracted from a document." .
+
+tg:Subgraph a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Subgraph" ;
+ rdfs:comment "A set of triples extracted from a chunk." .
+
+# =============================================================================
+# Classes — Query-time explainability (shared)
+# =============================================================================
+
+tg:Question a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Question" ;
+ rdfs:comment "Root entity for a query session." .
+
+tg:GraphRagQuestion a owl:Class ;
+ rdfs:subClassOf tg:Question ;
+ rdfs:label "Graph RAG Question" ;
+ rdfs:comment "A question answered via graph-based RAG." .
+
+tg:DocRagQuestion a owl:Class ;
+ rdfs:subClassOf tg:Question ;
+ rdfs:label "Document RAG Question" ;
+ rdfs:comment "A question answered via document-based RAG." .
+
+tg:AgentQuestion a owl:Class ;
+ rdfs:subClassOf tg:Question ;
+ rdfs:label "Agent Question" ;
+ rdfs:comment "A question answered via the agent orchestrator." .
+
+tg:Grounding a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Grounding" ;
+ rdfs:comment "Concept extraction step (query decomposition into search terms)." .
+
+tg:Exploration a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Exploration" ;
+ rdfs:comment "Entity/chunk retrieval step." .
+
+tg:Focus a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Focus" ;
+ rdfs:comment "Edge selection and scoring step (GraphRAG)." .
+
+tg:Synthesis a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Synthesis" ;
+ rdfs:comment "Final answer synthesis from retrieved context." .
+
+# =============================================================================
+# Classes — Agent provenance
+# =============================================================================
+
+tg:Analysis a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Analysis" ;
+ rdfs:comment "One agent iteration: reasoning followed by tool selection." .
+
+tg:ToolUse a owl:Class ;
+ rdfs:label "ToolUse" ;
+ rdfs:comment "Mixin type applied to Analysis when a tool is invoked." .
+
+tg:Error a owl:Class ;
+ rdfs:label "Error" ;
+ rdfs:comment "Mixin type applied to events where a failure occurred (tool error, parse error)." .
+
+tg:Conclusion a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Conclusion" ;
+ rdfs:comment "Agent final answer (ReAct pattern)." .
+
+tg:PatternDecision a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Pattern Decision" ;
+ rdfs:comment "Meta-router decision recording which execution pattern was selected." .
+
+# --- Unifying types ---
+
+tg:Answer a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Answer" ;
+ rdfs:comment "Unifying type for any terminal answer (Synthesis, Conclusion, Finding, StepResult)." .
+
+tg:Reflection a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Reflection" ;
+ rdfs:comment "Unifying type for intermediate commentary (Thought, Observation)." .
+
+tg:Thought a owl:Class ;
+ rdfs:subClassOf tg:Reflection ;
+ rdfs:label "Thought" ;
+ rdfs:comment "Agent reasoning text within an iteration." .
+
+tg:Observation a owl:Class ;
+ rdfs:subClassOf tg:Reflection ;
+ rdfs:label "Observation" ;
+ rdfs:comment "Tool execution result." .
+
+# --- Orchestrator types ---
+
+tg:Decomposition a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Decomposition" ;
+ rdfs:comment "Supervisor pattern: question decomposed into sub-goals." .
+
+tg:Finding a owl:Class ;
+ rdfs:subClassOf tg:Answer ;
+ rdfs:label "Finding" ;
+ rdfs:comment "Result from a sub-agent execution." .
+
+tg:Plan a owl:Class ;
+ rdfs:subClassOf prov:Entity ;
+ rdfs:label "Plan" ;
+ rdfs:comment "Plan-then-execute pattern: structured plan of steps." .
+
+tg:StepResult a owl:Class ;
+ rdfs:subClassOf tg:Answer ;
+ rdfs:label "Step Result" ;
+ rdfs:comment "Result from executing one plan step." .
+
+# =============================================================================
+# Properties — Extraction metadata
+# =============================================================================
+
+tg:contains a owl:ObjectProperty ;
+ rdfs:label "contains" ;
+ rdfs:comment "Links a parent entity to a child (e.g. Document contains Page, Subgraph contains triple)." .
+
+tg:pageCount a owl:DatatypeProperty ;
+ rdfs:label "page count" ;
+ rdfs:range xsd:integer ;
+ rdfs:domain tg:Document .
+
+tg:mimeType a owl:DatatypeProperty ;
+ rdfs:label "MIME type" ;
+ rdfs:range xsd:string ;
+ rdfs:domain tg:Document .
+
+tg:pageNumber a owl:DatatypeProperty ;
+ rdfs:label "page number" ;
+ rdfs:range xsd:integer ;
+ rdfs:domain tg:Page .
+
+tg:chunkIndex a owl:DatatypeProperty ;
+ rdfs:label "chunk index" ;
+ rdfs:range xsd:integer ;
+ rdfs:domain tg:Chunk .
+
+tg:charOffset a owl:DatatypeProperty ;
+ rdfs:label "character offset" ;
+ rdfs:range xsd:integer .
+
+tg:charLength a owl:DatatypeProperty ;
+ rdfs:label "character length" ;
+ rdfs:range xsd:integer .
+
+tg:chunkSize a owl:DatatypeProperty ;
+ rdfs:label "chunk size" ;
+ rdfs:range xsd:integer .
+
+tg:chunkOverlap a owl:DatatypeProperty ;
+ rdfs:label "chunk overlap" ;
+ rdfs:range xsd:integer .
+
+tg:componentVersion a owl:DatatypeProperty ;
+ rdfs:label "component version" ;
+ rdfs:range xsd:string .
+
+tg:llmModel a owl:DatatypeProperty ;
+ rdfs:label "LLM model" ;
+ rdfs:range xsd:string .
+
+tg:ontology a owl:DatatypeProperty ;
+ rdfs:label "ontology" ;
+ rdfs:range xsd:string .
+
+tg:embeddingModel a owl:DatatypeProperty ;
+ rdfs:label "embedding model" ;
+ rdfs:range xsd:string .
+
+tg:sourceText a owl:DatatypeProperty ;
+ rdfs:label "source text" ;
+ rdfs:range xsd:string .
+
+tg:sourceCharOffset a owl:DatatypeProperty ;
+ rdfs:label "source character offset" ;
+ rdfs:range xsd:integer .
+
+tg:sourceCharLength a owl:DatatypeProperty ;
+ rdfs:label "source character length" ;
+ rdfs:range xsd:integer .
+
+tg:elementTypes a owl:DatatypeProperty ;
+ rdfs:label "element types" ;
+ rdfs:range xsd:string .
+
+tg:tableCount a owl:DatatypeProperty ;
+ rdfs:label "table count" ;
+ rdfs:range xsd:integer .
+
+tg:imageCount a owl:DatatypeProperty ;
+ rdfs:label "image count" ;
+ rdfs:range xsd:integer .
+
+# =============================================================================
+# Properties — Query-time provenance (GraphRAG / DocumentRAG)
+# =============================================================================
+
+tg:query a owl:DatatypeProperty ;
+ rdfs:label "query" ;
+ rdfs:comment "The user's query text." ;
+ rdfs:range xsd:string ;
+ rdfs:domain tg:Question .
+
+tg:concept a owl:DatatypeProperty ;
+ rdfs:label "concept" ;
+ rdfs:comment "An extracted concept from the query." ;
+ rdfs:range xsd:string ;
+ rdfs:domain tg:Grounding .
+
+tg:entity a owl:ObjectProperty ;
+ rdfs:label "entity" ;
+ rdfs:comment "A seed entity retrieved during exploration." ;
+ rdfs:domain tg:Exploration .
+
+tg:edgeCount a owl:DatatypeProperty ;
+ rdfs:label "edge count" ;
+ rdfs:comment "Number of edges explored." ;
+ rdfs:range xsd:integer ;
+ rdfs:domain tg:Exploration .
+
+tg:selectedEdge a owl:ObjectProperty ;
+ rdfs:label "selected edge" ;
+ rdfs:comment "Link to an edge selection entity within a Focus event." ;
+ rdfs:domain tg:Focus .
+
+tg:edge a owl:ObjectProperty ;
+ rdfs:label "edge" ;
+ rdfs:comment "A quoted triple representing a knowledge graph edge." ;
+ rdfs:domain tg:Focus .
+
+tg:reasoning a owl:DatatypeProperty ;
+ rdfs:label "reasoning" ;
+ rdfs:comment "LLM-generated reasoning for an edge selection." ;
+ rdfs:range xsd:string .
+
+tg:document a owl:ObjectProperty ;
+ rdfs:label "document" ;
+ rdfs:comment "Reference to a document stored in the librarian." .
+
+tg:chunkCount a owl:DatatypeProperty ;
+ rdfs:label "chunk count" ;
+ rdfs:comment "Number of document chunks retrieved (DocumentRAG)." ;
+ rdfs:range xsd:integer ;
+ rdfs:domain tg:Exploration .
+
+tg:selectedChunk a owl:DatatypeProperty ;
+ rdfs:label "selected chunk" ;
+ rdfs:comment "A selected chunk ID (DocumentRAG)." ;
+ rdfs:range xsd:string ;
+ rdfs:domain tg:Exploration .
+
+# =============================================================================
+# Properties — Agent provenance
+# =============================================================================
+
+tg:thought a owl:ObjectProperty ;
+ rdfs:label "thought" ;
+ rdfs:comment "Links an Analysis iteration to its Thought sub-entity." ;
+ rdfs:domain tg:Analysis ;
+ rdfs:range tg:Thought .
+
+tg:action a owl:DatatypeProperty ;
+ rdfs:label "action" ;
+ rdfs:comment "The tool/action name selected by the agent." ;
+ rdfs:range xsd:string ;
+ rdfs:domain tg:Analysis .
+
+tg:arguments a owl:DatatypeProperty ;
+ rdfs:label "arguments" ;
+ rdfs:comment "JSON-encoded arguments passed to the tool." ;
+ rdfs:range xsd:string ;
+ rdfs:domain tg:Analysis .
+
+tg:observation a owl:ObjectProperty ;
+ rdfs:label "observation" ;
+ rdfs:comment "Links an Analysis iteration to its Observation sub-entity." ;
+ rdfs:domain tg:Analysis ;
+ rdfs:range tg:Observation .
+
+tg:toolCandidate a owl:DatatypeProperty ;
+ rdfs:label "tool candidate" ;
+ rdfs:comment "Name of a tool available to the LLM for this iteration. One triple per candidate." ;
+ rdfs:range xsd:string ;
+ rdfs:domain tg:Analysis .
+
+tg:stepNumber a owl:DatatypeProperty ;
+ rdfs:label "step number" ;
+ rdfs:comment "Explicit 1-based step counter for iteration events." ;
+ rdfs:range xsd:integer ;
+ rdfs:domain tg:Analysis .
+
+tg:terminationReason a owl:DatatypeProperty ;
+ rdfs:label "termination reason" ;
+ rdfs:comment "Why the agent loop stopped: final-answer, plan-complete, subagents-complete, max-iterations, error." ;
+ rdfs:range xsd:string .
+
+tg:pattern a owl:DatatypeProperty ;
+ rdfs:label "pattern" ;
+ rdfs:comment "Selected execution pattern (react, plan-then-execute, supervisor)." ;
+ rdfs:range xsd:string ;
+ rdfs:domain tg:PatternDecision .
+
+tg:taskType a owl:DatatypeProperty ;
+ rdfs:label "task type" ;
+ rdfs:comment "Identified task type from the meta-router (general, research, etc.)." ;
+ rdfs:range xsd:string ;
+ rdfs:domain tg:PatternDecision .
+
+tg:llmDurationMs a owl:DatatypeProperty ;
+ rdfs:label "LLM duration (ms)" ;
+ rdfs:comment "Time spent in the LLM prompt call, in milliseconds." ;
+ rdfs:range xsd:integer ;
+ rdfs:domain tg:Analysis .
+
+tg:toolDurationMs a owl:DatatypeProperty ;
+ rdfs:label "tool duration (ms)" ;
+ rdfs:comment "Time spent executing the tool, in milliseconds." ;
+ rdfs:range xsd:integer ;
+ rdfs:domain tg:Observation .
+
+tg:toolError a owl:DatatypeProperty ;
+ rdfs:label "tool error" ;
+ rdfs:comment "Error message from a failed tool execution." ;
+ rdfs:range xsd:string ;
+ rdfs:domain tg:Observation .
+
+# --- Token usage predicates (on any event that involves an LLM call) ---
+
+tg:inToken a owl:DatatypeProperty ;
+ rdfs:label "input tokens" ;
+ rdfs:comment "Input token count for the LLM call associated with this event." ;
+ rdfs:range xsd:integer .
+
+tg:outToken a owl:DatatypeProperty ;
+ rdfs:label "output tokens" ;
+ rdfs:comment "Output token count for the LLM call associated with this event." ;
+ rdfs:range xsd:integer .
+
+# --- Orchestrator predicates ---
+
+tg:subagentGoal a owl:DatatypeProperty ;
+ rdfs:label "sub-agent goal" ;
+ rdfs:comment "Goal string assigned to a sub-agent (Decomposition, Finding)." ;
+ rdfs:range xsd:string .
+
+tg:planStep a owl:DatatypeProperty ;
+ rdfs:label "plan step" ;
+ rdfs:comment "Goal string for a plan step (Plan, StepResult)." ;
+ rdfs:range xsd:string .
+
+# =============================================================================
+# Named graphs
+# =============================================================================
+# These are not OWL classes but documented here for reference:
+#
+# (default graph) — Core knowledge facts (extracted triples)
+# urn:graph:source — Extraction provenance (document → chunk → triple)
+# urn:graph:retrieval — Query-time explainability (question → exploration → synthesis)
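+#
+# Illustrative TriG-style sketch (graph syntax and IRIs are hypothetical,
+# not asserted by this ontology):
+#
+#   <urn:graph:retrieval> {
+#       <urn:example:q1> a tg:GraphRagQuestion ;
+#           tg:query "What is machine learning?" .
+#       <urn:example:g1> a tg:Grounding ;
+#           tg:concept "machine learning" ;
+#           tg:inToken 120 ;
+#           tg:outToken 8 ;
+#           tg:llmModel "example-model" .
+#   }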
diff --git a/tests/contract/conftest.py b/tests/contract/conftest.py
index 15082437..4fdfe83b 100644
--- a/tests/contract/conftest.py
+++ b/tests/contract/conftest.py
@@ -87,7 +87,7 @@ def sample_message_data():
"history": []
},
"AgentResponse": {
- "chunk_type": "answer",
+ "message_type": "answer",
"content": "Machine learning is a subset of AI.",
"end_of_message": True,
"end_of_dialog": True,
diff --git a/tests/contract/test_message_contracts.py b/tests/contract/test_message_contracts.py
index bc5bece1..6b7f82e7 100644
--- a/tests/contract/test_message_contracts.py
+++ b/tests/contract/test_message_contracts.py
@@ -212,7 +212,7 @@ class TestAgentMessageContracts:
# Test required fields
response = AgentResponse(**response_data)
- assert hasattr(response, 'chunk_type')
+ assert hasattr(response, 'message_type')
assert hasattr(response, 'content')
assert hasattr(response, 'end_of_message')
assert hasattr(response, 'end_of_dialog')
diff --git a/tests/contract/test_schema_field_contracts.py b/tests/contract/test_schema_field_contracts.py
new file mode 100644
index 00000000..4b7c3da5
--- /dev/null
+++ b/tests/contract/test_schema_field_contracts.py
@@ -0,0 +1,73 @@
+"""
+Contract tests for schema dataclass field sets.
+
+These pin the *field names* of small, widely-constructed schema dataclasses
+so that any rename, removal, or accidental addition fails CI loudly instead
+of waiting for a runtime TypeError on the next websocket message.
+
+Background: in v2.2 the `Metadata` dataclass dropped a `metadata: list[Triple]`
+field but several call sites kept passing `Metadata(metadata=...)`. The bug
+was only discovered when a websocket import dispatcher received its first
+real message in production. A trivial structural assertion of the kind
+below would have caught it at unit-test time.
+
+Add to this file whenever a schema rename burns you. The cost of a frozen
+field set is a one-line update when you intentionally evolve the schema; the
+benefit is that every call site is forced to come along for the ride.
+"""
+
+import dataclasses
+import pytest
+
+from trustgraph.schema import (
+ Metadata,
+ EntityContext,
+ EntityEmbeddings,
+ ChunkEmbeddings,
+)
+
+
+def _field_names(dc):
+ return {f.name for f in dataclasses.fields(dc)}
+
+
+@pytest.mark.contract
+class TestSchemaFieldContracts:
+ """Pin the field set of dataclasses that get constructed all over the
+ codebase. If you intentionally change one of these, update the
+ expected set in the same commit — that diff will surface every call
+ site that needs to come along."""
+
+ def test_metadata_fields(self):
+ # NOTE: there is no `metadata` field. A previous regression
+ # constructed Metadata(metadata=...) and crashed at runtime.
+ assert _field_names(Metadata) == {
+ "id",
+ "root",
+ "user",
+ "collection",
+ }
+
+ def test_entity_embeddings_fields(self):
+ # NOTE: the embedding field is `vector` (singular, list[float]).
+ # There is no `vectors` field. Several call sites historically
+ # passed `vectors=` and crashed at runtime.
+ assert _field_names(EntityEmbeddings) == {
+ "entity",
+ "vector",
+ "chunk_id",
+ }
+
+ def test_chunk_embeddings_fields(self):
+ # Same `vector` (singular) convention as EntityEmbeddings.
+ assert _field_names(ChunkEmbeddings) == {
+ "chunk_id",
+ "vector",
+ }
+
+ def test_entity_context_fields(self):
+ assert _field_names(EntityContext) == {
+ "entity",
+ "context",
+ "chunk_id",
+ }
diff --git a/tests/contract/test_translator_completion_flags.py b/tests/contract/test_translator_completion_flags.py
index 91ce1b77..606061f9 100644
--- a/tests/contract/test_translator_completion_flags.py
+++ b/tests/contract/test_translator_completion_flags.py
@@ -188,7 +188,7 @@ class TestAgentTranslatorCompletionFlags:
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content="4",
end_of_message=True,
end_of_dialog=True,
@@ -210,7 +210,7 @@ class TestAgentTranslatorCompletionFlags:
# Arrange
translator = TranslatorRegistry.get_response_translator("agent")
response = AgentResponse(
- chunk_type="thought",
+ message_type="thought",
content="I need to solve this.",
end_of_message=True,
end_of_dialog=False,
@@ -233,7 +233,7 @@ class TestAgentTranslatorCompletionFlags:
# Test thought message
thought_response = AgentResponse(
- chunk_type="thought",
+ message_type="thought",
content="Processing...",
end_of_message=True,
end_of_dialog=False,
@@ -247,7 +247,7 @@ class TestAgentTranslatorCompletionFlags:
# Test observation message
observation_response = AgentResponse(
- chunk_type="observation",
+ message_type="observation",
content="Result found",
end_of_message=True,
end_of_dialog=False,
@@ -268,7 +268,7 @@ class TestAgentTranslatorCompletionFlags:
# Streaming format with end_of_dialog=True
response = AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content="",
end_of_message=True,
end_of_dialog=True,
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 7e18f0de..44a9f127 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -418,55 +418,55 @@ def sample_streaming_agent_response():
"""Sample streaming agent response chunks"""
return [
{
- "chunk_type": "thought",
+ "message_type": "thought",
"content": "I need to search",
"end_of_message": False,
"end_of_dialog": False
},
{
- "chunk_type": "thought",
+ "message_type": "thought",
"content": " for information",
"end_of_message": False,
"end_of_dialog": False
},
{
- "chunk_type": "thought",
+ "message_type": "thought",
"content": " about machine learning.",
"end_of_message": True,
"end_of_dialog": False
},
{
- "chunk_type": "action",
+ "message_type": "action",
"content": "knowledge_query",
"end_of_message": True,
"end_of_dialog": False
},
{
- "chunk_type": "observation",
+ "message_type": "observation",
"content": "Machine learning is",
"end_of_message": False,
"end_of_dialog": False
},
{
- "chunk_type": "observation",
+ "message_type": "observation",
"content": " a subset of AI.",
"end_of_message": True,
"end_of_dialog": False
},
{
- "chunk_type": "final-answer",
+ "message_type": "final-answer",
"content": "Machine learning",
"end_of_message": False,
"end_of_dialog": False
},
{
- "chunk_type": "final-answer",
+ "message_type": "final-answer",
"content": " is a subset",
"end_of_message": False,
"end_of_dialog": False
},
{
- "chunk_type": "final-answer",
+ "message_type": "final-answer",
"content": " of artificial intelligence.",
"end_of_message": True,
"end_of_dialog": True
@@ -494,10 +494,10 @@ def streaming_chunk_collector():
"""Concatenate all chunk content"""
return "".join(self.chunks)
- def get_chunk_types(self):
+ def get_message_types(self):
"""Get list of chunk types if chunks are dicts"""
if self.chunks and isinstance(self.chunks[0], dict):
- return [c.get("chunk_type") for c in self.chunks]
+ return [c.get("message_type") for c in self.chunks]
return []
def verify_streaming_protocol(self):
diff --git a/tests/integration/test_agent_manager_integration.py b/tests/integration/test_agent_manager_integration.py
index 652894a2..743ab4d2 100644
--- a/tests/integration/test_agent_manager_integration.py
+++ b/tests/integration/test_agent_manager_integration.py
@@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
from trustgraph.agent.react.types import Action, Final, Tool, Argument
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
+from trustgraph.base import PromptResult
@pytest.mark.integration
@@ -28,19 +29,25 @@ class TestAgentManagerIntegration:
# Mock prompt client
prompt_client = AsyncMock()
- prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
+ prompt_client.agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I need to search for information about machine learning
Action: knowledge_query
Args: {
"question": "What is machine learning?"
}"""
-
+ )
+
# Mock graph RAG client
graph_rag_client = AsyncMock()
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
-
+
# Mock text completion client
text_completion_client = AsyncMock()
- text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
+ text_completion_client.question.return_value = PromptResult(
+ response_type="text",
+ text="Machine learning involves algorithms that improve through experience."
+ )
# Mock MCP tool client
mcp_tool_client = AsyncMock()
@@ -147,8 +154,11 @@ Args: {
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
"""Test agent manager returning final answer"""
# Arrange
- mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I have enough information to answer the question
Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
+ )
question = "What is machine learning?"
history = []
@@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
"""Test ReAct cycle ending with final answer"""
# Arrange
- mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I can provide a direct answer
Final Answer: Machine learning is a branch of artificial intelligence."""
+ )
question = "What is machine learning?"
history = []
@@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
for tool_name, expected_service in tool_scenarios:
# Arrange
- mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text=f"""Thought: I need to use {tool_name}
Action: {tool_name}
Args: {{
"question": "test question"
}}"""
+ )
think_callback = AsyncMock()
observe_callback = AsyncMock()
@@ -284,11 +300,14 @@ Args: {{
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
"""Test agent manager error handling for unknown tool"""
# Arrange
- mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I need to use an unknown tool
Action: unknown_tool
Args: {
"param": "value"
}"""
+ )
think_callback = AsyncMock()
observe_callback = AsyncMock()
@@ -308,11 +327,13 @@ Args: {
think_callback = AsyncMock()
observe_callback = AsyncMock()
- # Act & Assert
- with pytest.raises(Exception) as exc_info:
- await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
-
- assert "Tool execution failed" in str(exc_info.value)
+ # Act - tool errors are now caught and returned as observations
+ result = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
+
+ # Assert - error captured on the action, not raised
+ assert result.tool_error is not None
+ assert "Tool execution failed" in result.tool_error
+ assert "Error:" in result.observation
@pytest.mark.asyncio
async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context):
@@ -321,11 +342,14 @@ Args: {
question = "Find information about AI and summarize it"
# Mock multi-step reasoning
- mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I need to search for AI information first
Action: knowledge_query
Args: {
"question": "What is artificial intelligence?"
}"""
+ )
# Act
action = await agent_manager.reason(question, [], mock_flow_context)
@@ -372,9 +396,12 @@ Args: {
# Format arguments as JSON
import json
args_json = json.dumps(test_case['arguments'], indent=4)
- mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text=f"""Thought: Using {test_case['action']}
Action: {test_case['action']}
Args: {args_json}"""
+ )
think_callback = AsyncMock()
observe_callback = AsyncMock()
@@ -507,15 +534,17 @@ Args: {
]
for test_case in test_cases:
- mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text=test_case["response"]
+ )
if test_case["error_contains"]:
- # Should raise an error
- with pytest.raises(RuntimeError) as exc_info:
- await agent_manager.reason("test question", [], mock_flow_context)
-
- assert "Failed to parse agent response" in str(exc_info.value)
- assert test_case["error_contains"] in str(exc_info.value)
+ # Parse errors now return an Action with tool_error
+ result = await agent_manager.reason("test question", [], mock_flow_context)
+ assert isinstance(result, Action)
+ assert result.name == "__parse_error__"
+ assert result.tool_error is not None
else:
# Should succeed
action = await agent_manager.reason("test question", [], mock_flow_context)
@@ -527,13 +556,16 @@ Args: {
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
"""Test edge cases in text parsing"""
# Test response with markdown code blocks
- mock_flow_context("prompt-request").agent_react.return_value = """```
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""```
Thought: I need to search for information
Action: knowledge_query
Args: {
"question": "What is AI?"
}
```"""
+ )
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@@ -541,15 +573,18 @@ Args: {
assert action.name == "knowledge_query"
# Test response with extra whitespace
- mock_flow_context("prompt-request").agent_react.return_value = """
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""
-Thought: I need to think about this
-Action: knowledge_query
+Thought: I need to think about this
+Action: knowledge_query
Args: {
"question": "test"
}
"""
+ )
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@@ -560,7 +595,9 @@ Args: {
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
"""Test handling of multi-line thoughts and final answers"""
# Multi-line thought
- mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I need to consider multiple factors:
1. The user's question is complex
2. I should search for comprehensive information
3. This requires using the knowledge query tool
@@ -568,6 +605,7 @@ Action: knowledge_query
Args: {
"question": "complex query"
}"""
+ )
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@@ -575,13 +613,16 @@ Args: {
assert "knowledge query tool" in action.thought
# Multi-line final answer
- mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I have gathered enough information
Final Answer: Here is a comprehensive answer:
1. First point about the topic
2. Second point with details
3. Final conclusion
This covers all aspects of the question."""
+ )
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final)
@@ -593,13 +634,16 @@ This covers all aspects of the question."""
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
"""Test JSON arguments with special characters and edge cases"""
# Test with special characters in JSON (properly escaped)
- mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: Processing special characters
Action: knowledge_query
Args: {
"question": "What about \\"quotes\\" and 'apostrophes'?",
"context": "Line 1\\nLine 2\\tTabbed",
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
}"""
+ )
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@@ -608,7 +652,9 @@ Args: {
assert "@#$%^&*" in action.arguments["special"]
# Test with nested JSON
- mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: Complex arguments
Action: web_search
Args: {
"query": "test",
@@ -621,6 +667,7 @@ Args: {
}
}
}"""
+ )
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Action)
@@ -632,7 +679,9 @@ Args: {
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
"""Test final answers that contain JSON-like content"""
# Final answer with JSON content
- mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I can provide the data in JSON format
Final Answer: {
"result": "success",
"data": {
@@ -642,6 +691,7 @@ Final Answer: {
},
"confidence": 0.95
}"""
+ )
action = await agent_manager.reason("test", [], mock_flow_context)
assert isinstance(action, Final)
@@ -792,11 +842,14 @@ Final Answer: {
agent = AgentManager(tools=custom_tools, additional_context="")
# Mock response for custom collection query
- mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
+ mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I need to search in the research papers
Action: knowledge_query_custom
Args: {
"question": "Latest AI research?"
}"""
+ )
think_callback = AsyncMock()
observe_callback = AsyncMock()
diff --git a/tests/integration/test_agent_streaming_integration.py b/tests/integration/test_agent_streaming_integration.py
index d6004c21..de7372f1 100644
--- a/tests/integration/test_agent_streaming_integration.py
+++ b/tests/integration/test_agent_streaming_integration.py
@@ -10,11 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.react.agent_manager import AgentManager
from trustgraph.agent.react.tools import KnowledgeQueryImpl
from trustgraph.agent.react.types import Tool, Argument
+from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_agent_streaming_chunks,
assert_streaming_chunks_valid,
assert_callback_invoked,
- assert_chunk_types_valid,
+ assert_message_types_valid,
)
@@ -51,10 +52,10 @@ Args: {
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
- return full_text
+ return PromptResult(response_type="text", text=full_text)
else:
# Non-streaming response - same text
- return full_text
+ return PromptResult(response_type="text", text=full_text)
client.agent_react.side_effect = agent_react_streaming
return client
@@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
for i, chunk in enumerate(chunks):
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk + " ", is_final)
- return response
- return response
+ return PromptResult(response_type="text", text=response)
+ return PromptResult(response_type="text", text=response)
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react
diff --git a/tests/integration/test_agent_structured_query_integration.py b/tests/integration/test_agent_structured_query_integration.py
index 0fedd2b5..2442bf10 100644
--- a/tests/integration/test_agent_structured_query_integration.py
+++ b/tests/integration/test_agent_structured_query_integration.py
@@ -16,6 +16,7 @@ from trustgraph.schema import (
Error
)
from trustgraph.agent.react.service import Processor
+from trustgraph.base import PromptResult
@pytest.mark.integration
@@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration:
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
- mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query
+ mock_prompt_client.agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I need to find customers from New York using structured query
Action: structured-query
Args: {
"question": "Find all customers from New York"
}"""
+ )
# Set up flow context routing
def flow_context(service_name):
@@ -173,11 +177,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
- mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist
+ mock_prompt_client.agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I need to query for a table that might not exist
Action: structured-query
Args: {
"question": "Find data from a table that doesn't exist"
}"""
+ )
# Set up flow context routing
def flow_context(service_name):
@@ -250,11 +257,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
- mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first
+ mock_prompt_client.agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I need to find customers from California first
Action: structured-query
Args: {
"question": "Find all customers from California"
}"""
+ )
# Set up flow context routing
def flow_context(service_name):
@@ -339,11 +349,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
- mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data
+ mock_prompt_client.agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I need to query the sales data
Action: structured-query
Args: {
"question": "Query the sales data for recent transactions"
}"""
+ )
# Set up flow context routing
def flow_context(service_name):
@@ -447,11 +460,14 @@ Args: {
# Mock the prompt client that agent calls for reasoning
mock_prompt_client = AsyncMock()
- mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information
+ mock_prompt_client.agent_react.return_value = PromptResult(
+ response_type="text",
+ text="""Thought: I need to get customer information
Action: structured-query
Args: {
"question": "Get customer information and format it nicely"
}"""
+ )
# Set up flow context routing
def flow_context(service_name):
diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py
index e9df05cf..8c165385 100644
--- a/tests/integration/test_document_rag_integration.py
+++ b/tests/integration/test_document_rag_integration.py
@@ -10,6 +10,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
+from trustgraph.base import PromptResult
# Sample chunk content for testing - maps chunk_id to content
@@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
def mock_prompt_client(self):
"""Mock prompt client that generates realistic responses"""
client = AsyncMock()
- client.document_prompt.return_value = (
- "Machine learning is a field of artificial intelligence that enables computers to learn "
- "and improve from experience without being explicitly programmed. It uses algorithms "
- "to find patterns in data and make predictions or decisions."
+ client.document_prompt.return_value = PromptResult(
+ response_type="text",
+ text=(
+ "Machine learning is a field of artificial intelligence that enables computers to learn "
+ "and improve from experience without being explicitly programmed. It uses algorithms "
+ "to find patterns in data and make predictions or decisions."
+ )
)
+ # Mock prompt() for extract-concepts call in DocumentRag
+ client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@@ -119,6 +125,7 @@ class TestDocumentRagIntegration:
)
# Verify final response
+ result, usage = result
assert result is not None
assert isinstance(result, str)
assert "machine learning" in result.lower()
@@ -131,7 +138,11 @@ class TestDocumentRagIntegration:
"""Test DocumentRAG behavior when no documents are retrieved"""
# Arrange
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
- mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
+ mock_prompt_client.document_prompt.return_value = PromptResult(
+ response_type="text",
+ text="I couldn't find any relevant documents for your query."
+ )
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
document_rag = DocumentRag(
embeddings_client=mock_embeddings_client,
@@ -152,7 +163,8 @@ class TestDocumentRagIntegration:
documents=[]
)
- assert result == "I couldn't find any relevant documents for your query."
+ result_text, usage = result
+ assert result_text == "I couldn't find any relevant documents for your query."
@pytest.mark.asyncio
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py
index dad30a8f..e2c032ad 100644
--- a/tests/integration/test_document_rag_streaming_integration.py
+++ b/tests/integration/test_document_rag_streaming_integration.py
@@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import ChunkMatch
+from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
- return full_text
+ return PromptResult(response_type="text", text=full_text)
else:
# Non-streaming response - same text
- return full_text
+ return PromptResult(response_type="text", text=full_text)
client.document_prompt.side_effect = document_prompt_side_effect
+ # Mock prompt() for extract-concepts call in DocumentRag
+ client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@@ -119,11 +122,12 @@ class TestDocumentRagStreaming:
collector.verify_streaming_protocol()
# Verify full response matches concatenated chunks
+ result_text, usage = result
full_from_chunks = collector.get_full_text()
- assert result == full_from_chunks
+ assert result_text == full_from_chunks
# Verify content is reasonable
- assert len(result) > 0
+ assert len(result_text) > 0
@pytest.mark.asyncio
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
@@ -159,9 +163,11 @@ class TestDocumentRagStreaming:
)
# Assert - Results should be equivalent
- assert streaming_result == non_streaming_result
+ non_streaming_text, _ = non_streaming_result
+ streaming_text, _ = streaming_result
+ assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0
- assert "".join(streaming_chunks) == streaming_result
+ assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
@@ -180,8 +186,9 @@ class TestDocumentRagStreaming:
)
# Assert
+ result_text, usage = result
assert callback.call_count > 0
- assert result is not None
+ assert result_text is not None
# Verify all callback invocations had string arguments
for call in callback.call_args_list:
@@ -202,7 +209,8 @@ class TestDocumentRagStreaming:
# Assert - Should complete without error
assert result is not None
- assert isinstance(result, str)
+ result_text, usage = result
+ assert isinstance(result_text, str)
@pytest.mark.asyncio
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
@@ -223,7 +231,8 @@ class TestDocumentRagStreaming:
)
# Assert - Should still produce streamed response
- assert result is not None
+ result_text, usage = result
+ assert result_text is not None
assert callback.call_count > 0
@pytest.mark.asyncio
@@ -271,7 +280,8 @@ class TestDocumentRagStreaming:
)
# Assert
- assert result is not None
+ result_text, usage = result
+ assert result_text is not None
assert callback.call_count > 0
# Verify doc_limit was passed correctly
diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py
index 5e3279e3..9c3cdf45 100644
--- a/tests/integration/test_graph_rag_integration.py
+++ b/tests/integration/test_graph_rag_integration.py
@@ -12,6 +12,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
+from trustgraph.base import PromptResult
@pytest.mark.integration
@@ -93,18 +94,21 @@ class TestGraphRagIntegration:
# 4. kg-synthesis returns the final answer
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
- return "" # Falls back to raw query
+ return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
- return "" # No edges scored
+ return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-reasoning":
- return "" # No reasoning
+ return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
- return (
- "Machine learning is a subset of artificial intelligence that enables computers "
- "to learn from data without being explicitly programmed. It uses algorithms "
- "and statistical models to find patterns in data."
+ return PromptResult(
+ response_type="text",
+ text=(
+ "Machine learning is a subset of artificial intelligence that enables computers "
+ "to learn from data without being explicitly programmed. It uses algorithms "
+ "and statistical models to find patterns in data."
+ )
)
- return ""
+ return PromptResult(response_type="text", text="")
client.prompt.side_effect = mock_prompt
return client
@@ -169,6 +173,7 @@ class TestGraphRagIntegration:
assert mock_prompt_client.prompt.call_count == 4
# Verify final response
+ response, usage = response
assert response is not None
assert isinstance(response, str)
assert "machine learning" in response.lower()
diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py
index b66c5289..95c494bb 100644
--- a/tests/integration/test_graph_rag_streaming_integration.py
+++ b/tests/integration/test_graph_rag_streaming_integration.py
@@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.schema import EntityMatch, Term, IRI
+from trustgraph.base import PromptResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_rag_streaming_chunks,
@@ -61,12 +62,12 @@ class TestGraphRagStreaming:
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
if prompt_id == "extract-concepts":
- return "" # Falls back to raw query
+ return PromptResult(response_type="text", text="")
elif prompt_id == "kg-edge-scoring":
# Edge scoring returns JSONL with IDs and scores
- return '{"id": "abc12345", "score": 0.9}\n'
+ return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
elif prompt_id == "kg-edge-reasoning":
- return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
+ return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
elif prompt_id == "kg-synthesis":
if streaming and chunk_callback:
# Simulate streaming chunks with end_of_stream flags
@@ -79,10 +80,10 @@ class TestGraphRagStreaming:
is_final = (i == len(chunks) - 1)
await chunk_callback(chunk, is_final)
- return full_text
+ return PromptResult(response_type="text", text=full_text)
else:
- return full_text
- return ""
+ return PromptResult(response_type="text", text=full_text)
+ return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect
return client
@@ -123,6 +124,7 @@ class TestGraphRagStreaming:
)
# Assert
+ response, usage = response
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
@@ -172,9 +174,11 @@ class TestGraphRagStreaming:
)
# Assert - Results should be equivalent
- assert streaming_response == non_streaming_response
+ non_streaming_text, _ = non_streaming_response
+ streaming_text, _ = streaming_response
+ assert streaming_text == non_streaming_text
assert len(streaming_chunks) > 0
- assert "".join(streaming_chunks) == streaming_response
+ assert "".join(streaming_chunks) == streaming_text
@pytest.mark.asyncio
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
@@ -213,7 +217,8 @@ class TestGraphRagStreaming:
# Assert - Should complete without error
assert response is not None
- assert isinstance(response, str)
+ response_text, usage = response
+ assert isinstance(response_text, str)
@pytest.mark.asyncio
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
diff --git a/tests/integration/test_kg_extract_store_integration.py b/tests/integration/test_kg_extract_store_integration.py
index 4d8b60ad..84c0905d 100644
--- a/tests/integration/test_kg_extract_store_integration.py
+++ b/tests/integration/test_kg_extract_store_integration.py
@@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
+from trustgraph.base import PromptResult
@pytest.mark.integration
@@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
# Mock prompt client for definitions extraction
prompt_client = AsyncMock()
- prompt_client.extract_definitions.return_value = [
- {
- "entity": "Machine Learning",
- "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
- },
- {
- "entity": "Neural Networks",
- "definition": "Computing systems inspired by biological neural networks that process information."
- }
- ]
-
+ prompt_client.extract_definitions.return_value = PromptResult(
+ response_type="jsonl",
+ objects=[
+ {
+ "entity": "Machine Learning",
+ "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
+ },
+ {
+ "entity": "Neural Networks",
+ "definition": "Computing systems inspired by biological neural networks that process information."
+ }
+ ]
+ )
+
# Mock prompt client for relationships extraction
- prompt_client.extract_relationships.return_value = [
- {
- "subject": "Machine Learning",
- "predicate": "is_subset_of",
- "object": "Artificial Intelligence",
- "object-entity": True
- },
- {
- "subject": "Neural Networks",
- "predicate": "is_used_in",
- "object": "Machine Learning",
- "object-entity": True
- }
- ]
+ prompt_client.extract_relationships.return_value = PromptResult(
+ response_type="jsonl",
+ objects=[
+ {
+ "subject": "Machine Learning",
+ "predicate": "is_subset_of",
+ "object": "Artificial Intelligence",
+ "object-entity": True
+ },
+ {
+ "subject": "Neural Networks",
+ "predicate": "is_used_in",
+ "object": "Machine Learning",
+ "object-entity": True
+ }
+ ]
+ )
# Mock producers for output streams
triples_producer = AsyncMock()
@@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of empty extraction results"""
# Arrange
- mock_flow_context("prompt-request").extract_definitions.return_value = []
+ mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
+ response_type="jsonl",
+ objects=[]
+ )
mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk
@@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
"""Test handling of invalid extraction response format"""
# Arrange
- mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
+ mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
+ response_type="text",
+ text="invalid format"
+ ) # Should be jsonl with objects list
mock_msg = MagicMock()
mock_msg.value.return_value = sample_chunk
@@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration:
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
"""Test entity filtering and validation in extraction"""
# Arrange
- mock_flow_context("prompt-request").extract_definitions.return_value = [
- {"entity": "Valid Entity", "definition": "Valid definition"},
- {"entity": "", "definition": "Empty entity"}, # Should be filtered
- {"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
- {"entity": None, "definition": "None entity"}, # Should be filtered
- {"entity": "Valid Entity 3", "definition": None}, # Should be filtered
- ]
+ mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
+ response_type="jsonl",
+ objects=[
+ {"entity": "Valid Entity", "definition": "Valid definition"},
+ {"entity": "", "definition": "Empty entity"}, # Should be filtered
+ {"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
+ {"entity": None, "definition": "None entity"}, # Should be filtered
+ {"entity": "Valid Entity 3", "definition": None}, # Should be filtered
+ ]
+ )
sample_chunk = Chunk(
metadata=Metadata(id="test", user="user", collection="collection"),
diff --git a/tests/integration/test_object_extraction_integration.py b/tests/integration/test_object_extraction_integration.py
index faa63381..22ba9a3f 100644
--- a/tests/integration/test_object_extraction_integration.py
+++ b/tests/integration/test_object_extraction_integration.py
@@ -16,6 +16,7 @@ from trustgraph.schema import (
Chunk, ExtractedObject, Metadata, RowSchema, Field,
PromptRequest, PromptResponse
)
+from trustgraph.base import PromptResult
@pytest.mark.integration
@@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
if schema_name == "customer_records":
if "john" in text.lower():
- return [
- {
- "customer_id": "CUST001",
- "name": "John Smith",
- "email": "john.smith@email.com",
- "phone": "555-0123"
- }
- ]
+ return PromptResult(
+ response_type="jsonl",
+ objects=[
+ {
+ "customer_id": "CUST001",
+ "name": "John Smith",
+ "email": "john.smith@email.com",
+ "phone": "555-0123"
+ }
+ ]
+ )
elif "jane" in text.lower():
- return [
- {
- "customer_id": "CUST002",
- "name": "Jane Doe",
- "email": "jane.doe@email.com",
- "phone": ""
- }
- ]
+ return PromptResult(
+ response_type="jsonl",
+ objects=[
+ {
+ "customer_id": "CUST002",
+ "name": "Jane Doe",
+ "email": "jane.doe@email.com",
+ "phone": ""
+ }
+ ]
+ )
else:
- return []
-
+ return PromptResult(response_type="jsonl", objects=[])
+
elif schema_name == "product_catalog":
if "laptop" in text.lower():
- return [
- {
- "product_id": "PROD001",
- "name": "Gaming Laptop",
- "price": "1299.99",
- "category": "electronics"
- }
- ]
+ return PromptResult(
+ response_type="jsonl",
+ objects=[
+ {
+ "product_id": "PROD001",
+ "name": "Gaming Laptop",
+ "price": "1299.99",
+ "category": "electronics"
+ }
+ ]
+ )
elif "book" in text.lower():
- return [
- {
- "product_id": "PROD002",
- "name": "Python Programming Guide",
- "price": "49.99",
- "category": "books"
- }
- ]
+ return PromptResult(
+ response_type="jsonl",
+ objects=[
+ {
+ "product_id": "PROD002",
+ "name": "Python Programming Guide",
+ "price": "49.99",
+ "category": "books"
+ }
+ ]
+ )
else:
- return []
-
- return []
+ return PromptResult(response_type="jsonl", objects=[])
+
+ return PromptResult(response_type="jsonl", objects=[])
prompt_client.extract_objects.side_effect = mock_extract_objects
diff --git a/tests/integration/test_prompt_streaming_integration.py b/tests/integration/test_prompt_streaming_integration.py
index 9b1a06b6..a1414e2d 100644
--- a/tests/integration/test_prompt_streaming_integration.py
+++ b/tests/integration/test_prompt_streaming_integration.py
@@ -9,6 +9,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock
from trustgraph.prompt.template.service import Processor
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
+from trustgraph.base.text_completion_client import TextCompletionResult
from tests.utils.streaming_assertions import (
assert_streaming_chunks_valid,
assert_callback_invoked,
@@ -27,34 +28,52 @@ class TestPromptStreaming:
# Mock text completion client with streaming
text_completion_client = AsyncMock()
- async def streaming_request(request, recipient=None, timeout=600):
- """Simulate streaming text completion"""
- if request.streaming and recipient:
- # Simulate streaming chunks
- chunks = [
- "Machine", " learning", " is", " a", " field",
- " of", " artificial", " intelligence", "."
- ]
+ # Streaming chunks to send
+ chunks = [
+ "Machine", " learning", " is", " a", " field",
+ " of", " artificial", " intelligence", "."
+ ]
- for i, chunk_text in enumerate(chunks):
- is_final = (i == len(chunks) - 1)
- response = TextCompletionResponse(
- response=chunk_text,
- error=None,
- end_of_stream=is_final
- )
- final = await recipient(response)
- if final:
- break
-
- # Final empty chunk
- await recipient(TextCompletionResponse(
- response="",
+ async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
+ """Simulate streaming text completion via text_completion_stream"""
+        for chunk_text in chunks:
+ response = TextCompletionResponse(
+ response=chunk_text,
error=None,
- end_of_stream=True
- ))
+ end_of_stream=False
+ )
+ await handler(response)
- text_completion_client.request = streaming_request
+ # Send final empty chunk with end_of_stream
+ await handler(TextCompletionResponse(
+ response="",
+ error=None,
+ end_of_stream=True
+ ))
+
+ return TextCompletionResult(
+ text=None,
+ in_token=10,
+ out_token=9,
+ model="test-model",
+ )
+
+ async def non_streaming_text_completion(system, prompt, timeout=600):
+ """Simulate non-streaming text completion"""
+ full_text = "Machine learning is a field of artificial intelligence."
+ return TextCompletionResult(
+ text=full_text,
+ in_token=10,
+ out_token=9,
+ model="test-model",
+ )
+
+ text_completion_client.text_completion_stream = AsyncMock(
+ side_effect=streaming_text_completion_stream
+ )
+ text_completion_client.text_completion = AsyncMock(
+ side_effect=non_streaming_text_completion
+ )
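+    # Contract these mocks assume: text_completion_stream(system, prompt,
+    # handler, timeout) pushes TextCompletionResponse chunks to handler and
+    # returns a TextCompletionResult carrying only token counts, while
+    # text_completion(system, prompt, timeout) returns the full text in
+    # TextCompletionResult.text.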
# Mock response producer
response_producer = AsyncMock()
@@ -156,14 +175,6 @@ class TestPromptStreaming:
consumer = MagicMock()
- # Mock non-streaming text completion
- text_completion_client = mock_flow_context_streaming("text-completion-request")
-
- async def non_streaming_text_completion(system, prompt, streaming=False):
- return "AI is the simulation of human intelligence in machines."
-
- text_completion_client.text_completion = non_streaming_text_completion
-
# Act
await prompt_processor_streaming.on_request(
message, consumer, mock_flow_context_streaming
@@ -218,17 +229,12 @@ class TestPromptStreaming:
# Mock text completion client that raises an error
text_completion_client = AsyncMock()
- async def failing_request(request, recipient=None, timeout=600):
- if recipient:
- # Send error response with proper Error schema
- error_response = TextCompletionResponse(
- response="",
- error=Error(message="Text completion error", type="processing_error"),
- end_of_stream=True
- )
- await recipient(error_response)
+ async def failing_stream(system, prompt, handler, timeout=600):
+ raise RuntimeError("Text completion error")
- text_completion_client.request = failing_request
+ text_completion_client.text_completion_stream = AsyncMock(
+ side_effect=failing_stream
+ )
# Mock response producer to capture error response
response_producer = AsyncMock()
@@ -255,22 +261,15 @@ class TestPromptStreaming:
consumer = MagicMock()
- # Act - The service catches errors and sends error responses, doesn't raise
+ # Act - The service catches errors and sends an error PromptResponse
await prompt_processor_streaming.on_request(message, consumer, context)
- # Assert - Verify error response was sent
- assert response_producer.send.call_count > 0
-
- # Check that at least one response contains an error
- error_sent = False
- for call in response_producer.send.call_args_list:
- response = call.args[0]
- if hasattr(response, 'error') and response.error:
- error_sent = True
- assert "Text completion error" in response.error.message
- break
-
- assert error_sent, "Expected error response to be sent"
+ # Assert - error response was sent
+ calls = response_producer.send.call_args_list
+ assert len(calls) > 0
+ error_response = calls[-1].args[0]
+ assert error_response.error is not None
+ assert "Text completion error" in error_response.error.message
@pytest.mark.asyncio
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
@@ -315,21 +314,22 @@ class TestPromptStreaming:
# Mock text completion that sends empty chunks
text_completion_client = AsyncMock()
- async def empty_streaming_request(request, recipient=None, timeout=600):
- if request.streaming and recipient:
- # Send empty chunk followed by final marker
- await recipient(TextCompletionResponse(
- response="",
- error=None,
- end_of_stream=False
- ))
- await recipient(TextCompletionResponse(
- response="",
- error=None,
- end_of_stream=True
- ))
+ async def empty_streaming(system, prompt, handler, timeout=600):
+ # Send empty chunk followed by final marker
+ await handler(TextCompletionResponse(
+ response="",
+ error=None,
+ end_of_stream=False
+ ))
+ await handler(TextCompletionResponse(
+ response="",
+ error=None,
+ end_of_stream=True
+ ))
- text_completion_client.request = empty_streaming_request
+ text_completion_client.text_completion_stream = AsyncMock(
+ side_effect=empty_streaming
+ )
response_producer = AsyncMock()
def context_router(service_name):
@@ -401,4 +401,4 @@ class TestPromptStreaming:
# Verify chunks concatenate to expected result
full_text = "".join(chunk_texts)
- assert full_text == "Machine learning is a field of artificial intelligence"
+ assert full_text == "Machine learning is a field of artificial intelligence."
diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py
index f5fe14b5..83a90412 100644
--- a/tests/integration/test_rag_streaming_protocol.py
+++ b/tests/integration/test_rag_streaming_protocol.py
@@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
+from trustgraph.base import PromptResult
class TestGraphRagStreamingProtocol:
@@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
- # Edge selection returns empty (no edges selected)
- return ""
+ return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
@@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
await chunk_callback(" answer", False)
await chunk_callback(" is here.", False)
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
- return "" # Return value not used since callback handles everything
+ return PromptResult(response_type="text", text="")
else:
- return "The answer is here."
- return ""
+ return PromptResult(response_type="text", text="The answer is here.")
+ return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_side_effect
return client
@@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol:
await chunk_callback("Document", False)
await chunk_callback(" summary", False)
await chunk_callback(".", True) # Non-empty final chunk
- return ""
+ return PromptResult(response_type="text", text="")
else:
- return "Document summary."
+ return PromptResult(response_type="text", text="Document summary.")
client.document_prompt.side_effect = document_prompt_side_effect
+ # Mock prompt() for extract-concepts call in DocumentRag
+ client.prompt.return_value = PromptResult(response_type="text", text="")
return client
@pytest.fixture
@@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases:
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "kg-edge-selection":
- return ""
+ return PromptResult(response_type="text", text="")
elif prompt_name == "kg-synthesis":
if streaming and chunk_callback:
await chunk_callback("text", False)
await chunk_callback("", False) # Empty but not final
await chunk_callback("more", False)
await chunk_callback("", True) # Empty and final
- return ""
+ return PromptResult(response_type="text", text="")
else:
- return "textmore"
- return ""
+ return PromptResult(response_type="text", text="textmore")
+ return PromptResult(response_type="text", text="")
client.prompt.side_effect = prompt_with_empties
diff --git a/tests/pytest.ini b/tests/pytest.ini
index b032a9d4..8541bd8f 100644
--- a/tests/pytest.ini
+++ b/tests/pytest.ini
@@ -4,14 +4,11 @@ python_paths = .
python_files = test_*.py
python_classes = Test*
python_functions = test_*
-addopts =
+addopts =
-v
--tb=short
--strict-markers
--disable-warnings
- --cov=trustgraph
- --cov-report=html
- --cov-report=term-missing
# --cov-fail-under=80
asyncio_mode = auto
markers =
diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py
index 0b9b283a..bb58e5ee 100644
--- a/tests/unit/test_agent/test_agent_service_non_streaming.py
+++ b/tests/unit/test_agent/test_agent_service_non_streaming.py
@@ -78,10 +78,10 @@ class TestAgentServiceNonStreaming:
# Filter out explain events — those are always sent now
content_responses = [
- r for r in sent_responses if r.chunk_type != "explain"
+ r for r in sent_responses if r.message_type != "explain"
]
explain_responses = [
- r for r in sent_responses if r.chunk_type == "explain"
+ r for r in sent_responses if r.message_type == "explain"
]
# Should have explain events for session, iteration, observation, and final
@@ -93,7 +93,7 @@ class TestAgentServiceNonStreaming:
# Check thought message
thought_response = content_responses[0]
assert isinstance(thought_response, AgentResponse)
- assert thought_response.chunk_type == "thought"
+ assert thought_response.message_type == "thought"
assert thought_response.content == "I need to solve this."
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
@@ -101,7 +101,7 @@ class TestAgentServiceNonStreaming:
# Check observation message
observation_response = content_responses[1]
assert isinstance(observation_response, AgentResponse)
- assert observation_response.chunk_type == "observation"
+ assert observation_response.message_type == "observation"
assert observation_response.content == "The answer is 4."
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
@@ -168,10 +168,10 @@ class TestAgentServiceNonStreaming:
# Filter out explain events — those are always sent now
content_responses = [
- r for r in sent_responses if r.chunk_type != "explain"
+ r for r in sent_responses if r.message_type != "explain"
]
explain_responses = [
- r for r in sent_responses if r.chunk_type == "explain"
+ r for r in sent_responses if r.message_type == "explain"
]
# Should have explain events for session and final
@@ -183,7 +183,7 @@ class TestAgentServiceNonStreaming:
# Check final answer message
answer_response = content_responses[0]
assert isinstance(answer_response, AgentResponse)
- assert answer_response.chunk_type == "answer"
+ assert answer_response.message_type == "answer"
assert answer_response.content == "4"
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"
diff --git a/tests/unit/test_agent/test_callback_message_id.py b/tests/unit/test_agent/test_callback_message_id.py
index 7cb0ee54..2c4964a5 100644
--- a/tests/unit/test_agent/test_callback_message_id.py
+++ b/tests/unit/test_agent/test_callback_message_id.py
@@ -29,7 +29,7 @@ class TestThinkCallbackMessageId:
assert len(responses) == 1
assert responses[0].message_id == msg_id
- assert responses[0].chunk_type == "thought"
+ assert responses[0].message_type == "thought"
@pytest.mark.asyncio
async def test_non_streaming_think_has_message_id(self, pattern):
@@ -58,7 +58,7 @@ class TestObserveCallbackMessageId:
await observe("result", is_final=True)
assert responses[0].message_id == msg_id
- assert responses[0].chunk_type == "observation"
+ assert responses[0].message_type == "observation"
class TestAnswerCallbackMessageId:
@@ -74,7 +74,7 @@ class TestAnswerCallbackMessageId:
await answer("the answer")
assert responses[0].message_id == msg_id
- assert responses[0].chunk_type == "answer"
+ assert responses[0].message_type == "answer"
@pytest.mark.asyncio
async def test_no_message_id_default(self, pattern):
diff --git a/tests/unit/test_agent/test_meta_router.py b/tests/unit/test_agent/test_meta_router.py
index da0c634c..da8c6c79 100644
--- a/tests/unit/test_agent/test_meta_router.py
+++ b/tests/unit/test_agent/test_meta_router.py
@@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.agent.orchestrator.meta_router import (
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
)
+from trustgraph.base import PromptResult
def _make_config(patterns=None, task_types=None):
@@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
def _make_context(prompt_response):
"""Build a mock context that returns a mock prompt client."""
client = AsyncMock()
- client.prompt = AsyncMock(return_value=prompt_response)
+ client.prompt = AsyncMock(
+ return_value=PromptResult(response_type="text", text=prompt_response)
+ )
def context(service_name):
return client
@@ -274,8 +277,8 @@ class TestRoute:
nonlocal call_count
call_count += 1
if call_count == 1:
- return "research" # task type
- return "plan-then-execute" # pattern
+ return PromptResult(response_type="text", text="research")
+ return PromptResult(response_type="text", text="plan-then-execute")
client.prompt = mock_prompt
context = lambda name: client
diff --git a/tests/unit/test_agent/test_orchestrator_provenance_integration.py b/tests/unit/test_agent/test_orchestrator_provenance_integration.py
index 96d41259..63d87ba1 100644
--- a/tests/unit/test_agent/test_orchestrator_provenance_integration.py
+++ b/tests/unit/test_agent/test_orchestrator_provenance_integration.py
@@ -18,6 +18,7 @@ from dataclasses import dataclass, field
from trustgraph.schema import (
AgentRequest, AgentResponse, AgentStep, PlanStep,
)
+from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@@ -68,7 +69,7 @@ def collect_explain_events(respond_mock):
events = []
for call in respond_mock.call_args_list:
resp = call[0][0]
- if isinstance(resp, AgentResponse) and resp.chunk_type == "explain":
+ if isinstance(resp, AgentResponse) and resp.message_type == "explain":
events.append({
"explain_id": resp.explain_id,
"explain_graph": resp.explain_graph,
@@ -183,7 +184,7 @@ class TestReactPatternProvenance:
)
async def mock_react(question, history, think, observe, answer,
- context, streaming, on_action):
+ context, streaming, on_action, **kwargs):
# Simulate the on_action callback before returning Final
if on_action:
await on_action(Action(
@@ -267,7 +268,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
- context, streaming, on_action):
+ context, streaming, on_action, **kwargs):
if on_action:
await on_action(action)
return action
@@ -309,7 +310,7 @@ class TestReactPatternProvenance:
MockAM.return_value = mock_am
async def mock_react(question, history, think, observe, answer,
- context, streaming, on_action):
+ context, streaming, on_action, **kwargs):
if on_action:
await on_action(Action(
thought="done", name="final",
@@ -355,10 +356,13 @@ class TestPlanPatternProvenance:
# Mock prompt client for plan creation
mock_prompt_client = AsyncMock()
- mock_prompt_client.prompt.return_value = [
- {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
- {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
- ]
+ mock_prompt_client.prompt.return_value = PromptResult(
+ response_type="jsonl",
+ objects=[
+ {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
+ {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
+ ],
+ )
def flow_factory(name):
if name == "prompt-request":
@@ -418,10 +422,13 @@ class TestPlanPatternProvenance:
# Mock prompt for step execution
mock_prompt_client = AsyncMock()
- mock_prompt_client.prompt.return_value = {
- "tool": "knowledge-query",
- "arguments": {"question": "quantum computing"},
- }
+ mock_prompt_client.prompt.return_value = PromptResult(
+ response_type="json",
+ object={
+ "tool": "knowledge-query",
+ "arguments": {"question": "quantum computing"},
+ },
+ )
def flow_factory(name):
if name == "prompt-request":
@@ -475,7 +482,7 @@ class TestPlanPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
- mock_prompt_client.prompt.return_value = "The synthesised answer."
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
def flow_factory(name):
if name == "prompt-request":
@@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance:
# Mock prompt for decomposition
mock_prompt_client = AsyncMock()
- mock_prompt_client.prompt.return_value = [
- "What is quantum computing?",
- "What are qubits?",
- ]
+ mock_prompt_client.prompt.return_value = PromptResult(
+ response_type="jsonl",
+ objects=[
+ "What is quantum computing?",
+ "What are qubits?",
+ ],
+ )
def flow_factory(name):
if name == "prompt-request":
@@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance:
# Mock prompt for synthesis
mock_prompt_client = AsyncMock()
- mock_prompt_client.prompt.return_value = "The combined answer."
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
def flow_factory(name):
if name == "prompt-request":
@@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance:
flow = make_mock_flow()
mock_prompt_client = AsyncMock()
- mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
+ mock_prompt_client.prompt.return_value = PromptResult(
+ response_type="jsonl",
+ objects=["Goal A", "Goal B", "Goal C"],
+ )
def flow_factory(name):
if name == "prompt-request":
diff --git a/tests/unit/test_agent/test_parse_chunk_message_id.py b/tests/unit/test_agent/test_parse_chunk_message_id.py
index 38942f1e..36d2220e 100644
--- a/tests/unit/test_agent/test_parse_chunk_message_id.py
+++ b/tests/unit/test_agent/test_parse_chunk_message_id.py
@@ -20,7 +20,7 @@ class TestParseChunkMessageId:
def test_thought_message_id(self, client):
resp = {
- "chunk_type": "thought",
+ "message_type": "thought",
"content": "thinking...",
"end_of_message": False,
"message_id": "urn:trustgraph:agent:sess/i1/thought",
@@ -31,7 +31,7 @@ class TestParseChunkMessageId:
def test_observation_message_id(self, client):
resp = {
- "chunk_type": "observation",
+ "message_type": "observation",
"content": "result",
"end_of_message": True,
"message_id": "urn:trustgraph:agent:sess/i1/observation",
@@ -42,7 +42,7 @@ class TestParseChunkMessageId:
def test_answer_message_id(self, client):
resp = {
- "chunk_type": "answer",
+ "message_type": "answer",
"content": "the answer",
"end_of_message": False,
"end_of_dialog": False,
@@ -54,7 +54,7 @@ class TestParseChunkMessageId:
def test_thought_missing_message_id(self, client):
resp = {
- "chunk_type": "thought",
+ "message_type": "thought",
"content": "thinking...",
"end_of_message": False,
}
@@ -64,7 +64,7 @@ class TestParseChunkMessageId:
def test_answer_missing_message_id(self, client):
resp = {
- "chunk_type": "answer",
+ "message_type": "answer",
"content": "answer",
"end_of_message": True,
"end_of_dialog": True,
diff --git a/tests/unit/test_base/test_metrics.py b/tests/unit/test_base/test_metrics.py
index b5a792a1..0496db20 100644
--- a/tests/unit/test_base/test_metrics.py
+++ b/tests/unit/test_base/test_metrics.py
@@ -7,6 +7,14 @@ from trustgraph.base import metrics
@pytest.fixture(autouse=True)
def reset_metric_singletons():
+ """Temporarily remove metric singletons so each test can inject mocks.
+
+ Saves any existing class-level metrics and restores them after the test
+ so that later tests in the same process still find the hasattr() guard
+ intact — deleting without restoring causes every subsequent Processor()
+ construction to re-register the same Prometheus metric name, which raises
+ ValueError: Duplicated timeseries.
+ """
classes_and_attrs = {
metrics.ConsumerMetrics: [
"state_metric",
@@ -23,18 +31,24 @@ def reset_metric_singletons():
],
}
+ saved = {}
for cls, attrs in classes_and_attrs.items():
for attr in attrs:
if hasattr(cls, attr):
+ saved[(cls, attr)] = getattr(cls, attr)
delattr(cls, attr)
yield
+ # Remove anything the test may have set, then restore originals
for cls, attrs in classes_and_attrs.items():
for attr in attrs:
if hasattr(cls, attr):
delattr(cls, attr)
+ for (cls, attr), value in saved.items():
+ setattr(cls, attr, value)
+
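+# Illustrative failure this restore prevents (a sketch; construction details
+# assumed):
+#
+#     ConsumerMetrics(...)              # registers state_metric and friends
+#     del ConsumerMetrics.state_metric  # teardown that forgets to restore
+#     ConsumerMetrics(...)              # hasattr() guard misses; the same
+#                                       # Prometheus name is re-registered and
+#                                       # prometheus_client raises
+#                                       # "Duplicated timeseries"
+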
def test_consumer_metrics_reuses_singletons_and_records_events(monkeypatch):
enum_factory = MagicMock()
diff --git a/tests/unit/test_base/test_subscriber_graceful_shutdown.py b/tests/unit/test_base/test_subscriber_graceful_shutdown.py
index ec14f66b..cbd4d535 100644
--- a/tests/unit/test_base/test_subscriber_graceful_shutdown.py
+++ b/tests/unit/test_base/test_subscriber_graceful_shutdown.py
@@ -236,6 +236,10 @@ async def test_subscriber_graceful_shutdown():
with patch.object(subscriber, 'run') as mock_run:
# Mock run that simulates graceful shutdown
async def mock_run_graceful():
+ # Honor the readiness contract: real run() signals _ready
+ # after binding the consumer, so start() can unblock. Mocks
+ # of run() must do the same or start() hangs forever.
+ subscriber._ready.set_result(None)
# Process messages while running, then drain
while subscriber.running or subscriber.draining:
if subscriber.draining:
@@ -337,6 +341,8 @@ async def test_subscriber_pending_acks_cleanup():
with patch.object(subscriber, 'run') as mock_run:
# Mock run that simulates cleanup of pending acks
async def mock_run_cleanup():
+ # Honor the readiness contract — see test_subscriber_graceful_shutdown.
+ subscriber._ready.set_result(None)
while subscriber.running or subscriber.draining:
await asyncio.sleep(0.05)
if subscriber.draining:
@@ -406,4 +412,4 @@ async def test_subscriber_multiple_subscribers():
msg1 = await queue1.get()
msg_all = await queue_all.get()
assert msg1 == {"data": "broadcast"}
- assert msg_all == {"data": "broadcast"}
\ No newline at end of file
+ assert msg_all == {"data": "broadcast"}
diff --git a/tests/unit/test_base/test_subscriber_readiness.py b/tests/unit/test_base/test_subscriber_readiness.py
new file mode 100644
index 00000000..e1ef47de
--- /dev/null
+++ b/tests/unit/test_base/test_subscriber_readiness.py
@@ -0,0 +1,189 @@
+"""
+Regression tests for Subscriber.start() readiness barrier.
+
+Background: prior to the eager-connect fix, Subscriber.start() created
+the run() task and returned immediately. The underlying backend consumer
+was lazily connected on its first receive() call, which left a setup
+race for request/response clients using ephemeral per-subscriber response
+queues (RabbitMQ auto-delete exclusive queues): the request would be
+published before the response queue was bound, and the broker would
+silently drop the reply. fetch_config(), document-embeddings, and
+api-gateway all hit this with "Failed to fetch config on notify" /
+"Request timeout exception" symptoms.
+
+These tests pin the readiness contract:
+
+ await subscriber.start()
+ # at this point, consumer.ensure_connected() MUST have run
+
+so that any future change which removes the eager bind, or moves it
+back to lazy initialisation, fails CI loudly.
+"""
+
+import asyncio
+
+import pytest
+from unittest.mock import MagicMock
+
+from trustgraph.base.subscriber import Subscriber
+
+
+def _make_backend(ensure_connected_side_effect=None,
+ receive_side_effect=None):
+ """Build a fake backend whose consumer records ensure_connected /
+ receive calls. ensure_connected_side_effect lets a test inject a
+ delay or exception."""
+ backend = MagicMock()
+ consumer = MagicMock()
+
+ consumer.ensure_connected = MagicMock(
+ side_effect=ensure_connected_side_effect,
+ )
+
+ # By default receive raises a timeout-style exception that the
+ # subscriber loop is supposed to swallow as a "no message yet" — this
+ # keeps the subscriber idling cleanly while the test inspects state.
+ if receive_side_effect is None:
+ receive_side_effect = TimeoutError("No message received within timeout")
+ consumer.receive = MagicMock(side_effect=receive_side_effect)
+
+ consumer.acknowledge = MagicMock()
+ consumer.negative_acknowledge = MagicMock()
+ consumer.pause_message_listener = MagicMock()
+ consumer.unsubscribe = MagicMock()
+ consumer.close = MagicMock()
+
+ backend.create_consumer.return_value = consumer
+ return backend, consumer
+
+
+def _make_subscriber(backend):
+ return Subscriber(
+ backend=backend,
+ topic="response:tg:config",
+ subscription="test-sub",
+ consumer_name="test-consumer",
+ schema=dict,
+ max_size=10,
+ drain_timeout=1.0,
+ backpressure_strategy="block",
+ )
+
+
+class TestSubscriberReadiness:
+
+ @pytest.mark.asyncio
+ async def test_start_calls_ensure_connected_before_returning(self):
+ """The barrier: ensure_connected must have been invoked at least
+ once by the time start() returns."""
+ backend, consumer = _make_backend()
+ subscriber = _make_subscriber(backend)
+
+ await subscriber.start()
+
+ try:
+ consumer.ensure_connected.assert_called_once()
+ finally:
+ await subscriber.stop()
+
+ @pytest.mark.asyncio
+ async def test_start_blocks_until_ensure_connected_completes(self):
+ """If ensure_connected is slow, start() must wait for it. This is
+ the actual race-condition guard — it would have failed against
+ the buggy version where start() returned before run() had even
+ scheduled the consumer creation."""
+ connect_started = asyncio.Event()
+ release_connect = asyncio.Event()
+
+        # ensure_connected runs in an executor thread, so the gate that
+        # holds it open must be thread-safe: a threading.Event which the
+        # asyncio side releases once it has checked that start() is still
+        # pending.
+ import threading
+ gate = threading.Event()
+
+ def slow_connect():
+ connect_started.set() # safe: only mutates the Event flag
+ gate.wait(timeout=2.0)
+
+ backend, consumer = _make_backend(
+ ensure_connected_side_effect=slow_connect,
+ )
+ subscriber = _make_subscriber(backend)
+
+ start_task = asyncio.create_task(subscriber.start())
+
+ # Wait until ensure_connected has begun executing.
+ await asyncio.wait_for(connect_started.wait(), timeout=2.0)
+
+ # ensure_connected is in flight — start() must NOT have returned.
+ assert not start_task.done(), (
+ "start() returned before ensure_connected() completed — "
+ "the readiness barrier is broken and the request/response "
+ "race condition is back."
+ )
+
+ # Release the gate; start() should now complete promptly.
+ gate.set()
+ await asyncio.wait_for(start_task, timeout=2.0)
+
+ consumer.ensure_connected.assert_called_once()
+
+ await subscriber.stop()
+
+ @pytest.mark.asyncio
+ async def test_start_propagates_consumer_creation_failure(self):
+ """If create_consumer() raises, start() must surface the error
+ rather than hang on the readiness future. The old code path
+ retried indefinitely inside run() and never let start() unblock."""
+ backend = MagicMock()
+ backend.create_consumer.side_effect = RuntimeError("broker down")
+
+ subscriber = _make_subscriber(backend)
+
+ with pytest.raises(RuntimeError, match="broker down"):
+ await asyncio.wait_for(subscriber.start(), timeout=2.0)
+
+ @pytest.mark.asyncio
+ async def test_start_propagates_ensure_connected_failure(self):
+ """Same contract for an ensure_connected() that raises (e.g. the
+ broker is up but the queue declare/bind fails)."""
+ backend, consumer = _make_backend(
+ ensure_connected_side_effect=RuntimeError("queue declare failed"),
+ )
+ subscriber = _make_subscriber(backend)
+
+ with pytest.raises(RuntimeError, match="queue declare failed"):
+ await asyncio.wait_for(subscriber.start(), timeout=2.0)
+
+ @pytest.mark.asyncio
+ async def test_ensure_connected_runs_before_subscriber_running_log(self):
+ """Subtle ordering: ensure_connected MUST happen before the
+ receive loop, so that any reply is captured. We assert this by
+ checking ensure_connected was called before any receive call."""
+ call_order = []
+
+ def record_ensure():
+ call_order.append("ensure_connected")
+
+ def record_receive(*args, **kwargs):
+ call_order.append("receive")
+ raise TimeoutError("No message received within timeout")
+
+ backend, consumer = _make_backend(
+ ensure_connected_side_effect=record_ensure,
+ receive_side_effect=record_receive,
+ )
+ subscriber = _make_subscriber(backend)
+
+ await subscriber.start()
+
+ # Give the receive loop a tick to run at least once.
+ await asyncio.sleep(0.05)
+
+ await subscriber.stop()
+
+ # ensure_connected must come first; receive may not have happened
+ # yet on a fast machine, but if it did, it must come after.
+ assert call_order, "neither ensure_connected nor receive was called"
+ assert call_order[0] == "ensure_connected"
diff --git a/tests/unit/test_chunking/test_recursive_chunker.py b/tests/unit/test_chunking/test_recursive_chunker.py
index a5ec59c8..d1a5d247 100644
--- a/tests/unit/test_chunking/test_recursive_chunker.py
+++ b/tests/unit/test_chunking/test_recursive_chunker.py
@@ -70,11 +70,12 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
+ # Flow exposes parameter lookup via __call__: flow("chunk-size")
mock_flow = MagicMock()
- mock_flow.parameters.get.side_effect = lambda param: {
+ mock_flow.side_effect = lambda key: {
"chunk-size": 2000, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
- }.get(param)
+ }.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@@ -105,10 +106,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
- mock_flow.parameters.get.side_effect = lambda param: {
+ mock_flow.side_effect = lambda key: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 200 # Override chunk overlap
- }.get(param)
+ }.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@@ -139,10 +140,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
- mock_flow.parameters.get.side_effect = lambda param: {
+ mock_flow.side_effect = lambda key: {
"chunk-size": 1500, # Override chunk size
"chunk-overlap": 150 # Override chunk overlap
- }.get(param)
+ }.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@@ -195,15 +196,15 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
+ # Flow.__call__ resolves parameters and producers/consumers from the
+ # same dict — merge both kinds here.
mock_flow = MagicMock()
- mock_flow.parameters.get.side_effect = lambda param: {
+ mock_flow.side_effect = lambda key: {
"chunk-size": 1500,
"chunk-overlap": 150,
- }.get(param)
- mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
- }.get(name)
+ }.get(key)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@@ -241,7 +242,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
- mock_flow.parameters.get.return_value = None # No overrides
+ mock_flow.side_effect = lambda key: None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
diff --git a/tests/unit/test_chunking/test_token_chunker.py b/tests/unit/test_chunking/test_token_chunker.py
index f3f83904..dba4ca94 100644
--- a/tests/unit/test_chunking/test_token_chunker.py
+++ b/tests/unit/test_chunking/test_token_chunker.py
@@ -70,11 +70,12 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
# Mock message and flow
mock_message = MagicMock()
mock_consumer = MagicMock()
+ # Flow exposes parameter lookup via __call__: flow("chunk-size")
mock_flow = MagicMock()
- mock_flow.parameters.get.side_effect = lambda param: {
+ mock_flow.side_effect = lambda key: {
"chunk-size": 400, # Override chunk size
"chunk-overlap": None # Use default chunk overlap
- }.get(param)
+ }.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@@ -105,10 +106,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
- mock_flow.parameters.get.side_effect = lambda param: {
+ mock_flow.side_effect = lambda key: {
"chunk-size": None, # Use default chunk size
"chunk-overlap": 25 # Override chunk overlap
- }.get(param)
+ }.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@@ -139,10 +140,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
- mock_flow.parameters.get.side_effect = lambda param: {
+ mock_flow.side_effect = lambda key: {
"chunk-size": 350, # Override chunk size
"chunk-overlap": 30 # Override chunk overlap
- }.get(param)
+ }.get(key)
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
@@ -195,15 +196,15 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_consumer = MagicMock()
mock_producer = AsyncMock()
mock_triples_producer = AsyncMock()
+ # Flow.__call__ resolves parameters and producers/consumers from the
+ # same dict — merge both kinds here.
mock_flow = MagicMock()
- mock_flow.parameters.get.side_effect = lambda param: {
+ mock_flow.side_effect = lambda key: {
"chunk-size": 400,
"chunk-overlap": 40,
- }.get(param)
- mock_flow.side_effect = lambda name: {
"output": mock_producer,
"triples": mock_triples_producer,
- }.get(name)
+ }.get(key)
# Act
await processor.on_message(mock_message, mock_consumer, mock_flow)
@@ -245,7 +246,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
mock_message = MagicMock()
mock_consumer = MagicMock()
mock_flow = MagicMock()
- mock_flow.parameters.get.return_value = None # No overrides
+ mock_flow.side_effect = lambda key: None # No overrides
# Act
chunk_size, chunk_overlap = await processor.chunk_document(
diff --git a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py
index b651b59e..cbc9a05a 100644
--- a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py
+++ b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py
@@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
from trustgraph.extract.kg.definitions.extract import (
Processor, default_triples_batch_size, default_entity_batch_size,
)
+from trustgraph.base import PromptResult
from trustgraph.schema import (
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
)
@@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_ecs_pub = AsyncMock()
mock_prompt_client = AsyncMock()
+ if isinstance(prompt_result, list):
+ wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
+ else:
+ wrapped = PromptResult(response_type="text", text=prompt_result)
mock_prompt_client.extract_definitions = AsyncMock(
- return_value=prompt_result
+ return_value=wrapped
)
def flow(name):
diff --git a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py
index cf3b1fb0..d9861cf3 100644
--- a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py
+++ b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py
@@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
from trustgraph.schema import (
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
)
+from trustgraph.base import PromptResult
# ---------------------------------------------------------------------------
@@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
mock_triples_pub = AsyncMock()
mock_prompt_client = AsyncMock()
mock_prompt_client.extract_relationships = AsyncMock(
- return_value=prompt_result
+ return_value=PromptResult(
+ response_type="jsonl",
+ objects=prompt_result,
+ )
)
def flow(name):
diff --git a/tests/unit/test_gateway/test_core_import_export_roundtrip.py b/tests/unit/test_gateway/test_core_import_export_roundtrip.py
new file mode 100644
index 00000000..843a2b7b
--- /dev/null
+++ b/tests/unit/test_gateway/test_core_import_export_roundtrip.py
@@ -0,0 +1,418 @@
+"""
+Round-trip unit tests for the core msgpack import/export gateway endpoints.
+
+The kg-core export endpoint receives KnowledgeResponse-shaped dicts from
+the responder callback and packs them into msgpack tuples. The kg-core
+import endpoint takes msgpack tuples back off the wire and rebuilds
+KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor
+(whose translator decodes them into real dataclasses).
+
+Regression coverage: the previous wire format used `"vectors"` (plural)
+in the entity blobs and embedded a stale `"m"` field that referenced the
+removed `Metadata.metadata` triples-list field. The export side hit a
+KeyError on first message; the import side built dicts that the
+KnowledgeRequestTranslator (separately fixed) couldn't decode. These
+tests pin both halves of the wire protocol.
+"""
+
+import msgpack
+import pytest
+from unittest.mock import AsyncMock, Mock, patch
+
+from trustgraph.gateway.dispatch.core_export import CoreExport
+from trustgraph.gateway.dispatch.core_import import CoreImport
+
+
+# ---------------------------------------------------------------------------
+# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator
+# would emit). The vector wire key is *singular* on purpose; the export
+# side previously read the wrong key and crashed.
+# ---------------------------------------------------------------------------
+
+
+def _ge_response_dict():
+ return {
+ "graph-embeddings": {
+ "metadata": {
+ "id": "doc-1",
+ "root": "",
+ "user": "alice",
+ "collection": "testcoll",
+ },
+ "entities": [
+ {
+ "entity": {"t": "i", "i": "http://example.org/alice"},
+ "vector": [0.1, 0.2, 0.3],
+ },
+ {
+ "entity": {"t": "i", "i": "http://example.org/bob"},
+ "vector": [0.4, 0.5, 0.6],
+ },
+ ],
+ }
+ }
+
+
+def _triples_response_dict():
+ return {
+ "triples": {
+ "metadata": {
+ "id": "doc-1",
+ "root": "",
+ "user": "alice",
+ "collection": "testcoll",
+ },
+ "triples": [
+ {
+ "s": {"t": "i", "i": "http://example.org/alice"},
+ "p": {"t": "i", "i": "http://example.org/knows"},
+ "o": {"t": "i", "i": "http://example.org/bob"},
+ },
+ ],
+ }
+ }
+
+
+def _make_request(id_="doc-1", user="alice"):
+ request = Mock()
+ request.query = {"id": id_, "user": user}
+ return request
+
+
+def _make_data_reader(payload: bytes):
+ """Mock the aiohttp StreamReader: returns payload once, then EOF."""
+ chunks = [payload, b""]
+
+ data = Mock()
+
+ async def fake_read(n):
+ return chunks.pop(0) if chunks else b""
+
+ data.read = fake_read
+ return data
+
+
+# ---------------------------------------------------------------------------
+# Export side: translator-shaped dict -> msgpack bytes
+# ---------------------------------------------------------------------------
+
+
+class TestCoreExportWireFormat:
+
+ @pytest.mark.asyncio
+ @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
+ async def test_export_packs_graph_embeddings_with_singular_vector(
+ self, mock_kr_class,
+ ):
+ """The export side must read `ent["vector"]` and emit `v`. The
+ previous bug was reading `ent["vectors"]` which KeyErrored against
+ the translator output."""
+ captured = []
+
+ async def fake_kr_process(req_dict, responder):
+ await responder(_ge_response_dict(), True)
+
+ mock_kr = AsyncMock()
+ mock_kr.start = AsyncMock()
+ mock_kr.stop = AsyncMock()
+ mock_kr.process = fake_kr_process
+ mock_kr_class.return_value = mock_kr
+
+ response = AsyncMock()
+
+ async def fake_write(b):
+ captured.append(b)
+
+ response.write = fake_write
+ response.write_eof = AsyncMock()
+
+ ok = AsyncMock(return_value=response)
+ error = AsyncMock()
+
+ exporter = CoreExport(backend=Mock())
+ await exporter.process(
+ data=Mock(),
+ error=error,
+ ok=ok,
+ request=_make_request(),
+ )
+
+ # Did not raise, did not call error()
+ error.assert_not_called()
+ assert len(captured) == 1
+
+ unpacker = msgpack.Unpacker()
+ unpacker.feed(captured[0])
+ items = list(unpacker)
+
+ assert len(items) == 1
+ msg_type, payload = items[0]
+ assert msg_type == "ge"
+
+ # Metadata envelope: only id/user/collection — no stale `m["m"]`.
+ assert payload["m"] == {
+ "i": "doc-1",
+ "u": "alice",
+ "c": "testcoll",
+ }
+
+ # Entities: each carries the *singular* `v` and the term envelope
+ assert len(payload["e"]) == 2
+ assert payload["e"][0]["v"] == [0.1, 0.2, 0.3]
+ assert payload["e"][1]["v"] == [0.4, 0.5, 0.6]
+ assert payload["e"][0]["e"]["i"] == "http://example.org/alice"
+
+ @pytest.mark.asyncio
+ @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
+ async def test_export_packs_triples(self, mock_kr_class):
+ captured = []
+
+ async def fake_kr_process(req_dict, responder):
+ await responder(_triples_response_dict(), True)
+
+ mock_kr = AsyncMock()
+ mock_kr.start = AsyncMock()
+ mock_kr.stop = AsyncMock()
+ mock_kr.process = fake_kr_process
+ mock_kr_class.return_value = mock_kr
+
+ response = AsyncMock()
+
+ async def fake_write(b):
+ captured.append(b)
+
+ response.write = fake_write
+ response.write_eof = AsyncMock()
+
+ ok = AsyncMock(return_value=response)
+ error = AsyncMock()
+
+ exporter = CoreExport(backend=Mock())
+ await exporter.process(
+ data=Mock(), error=error, ok=ok, request=_make_request(),
+ )
+
+ error.assert_not_called()
+ assert len(captured) == 1
+
+ unpacker = msgpack.Unpacker()
+ unpacker.feed(captured[0])
+ items = list(unpacker)
+ assert len(items) == 1
+
+ msg_type, payload = items[0]
+ assert msg_type == "t"
+ assert payload["m"] == {
+ "i": "doc-1",
+ "u": "alice",
+ "c": "testcoll",
+ }
+ assert len(payload["t"]) == 1
+
+
+# ---------------------------------------------------------------------------
+# Import side: msgpack bytes -> translator-shaped dict
+# ---------------------------------------------------------------------------
+
+
+class TestCoreImportWireFormat:
+
+ @pytest.mark.asyncio
+ @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
+ async def test_import_unpacks_graph_embeddings_to_singular_vector(
+ self, mock_kr_class,
+ ):
+ """The import side must build dicts whose entity blobs have the
+ singular `vector` key — that's what the KnowledgeRequestTranslator
+ decode side reads. Previous bug emitted `vectors`."""
+ captured = []
+
+ async def fake_kr_process(req_dict):
+ captured.append(req_dict)
+
+ mock_kr = AsyncMock()
+ mock_kr.start = AsyncMock()
+ mock_kr.stop = AsyncMock()
+ mock_kr.process = fake_kr_process
+ mock_kr_class.return_value = mock_kr
+
+ # Build a msgpack tuple matching the new wire format
+ payload = msgpack.packb((
+ "ge",
+ {
+ "m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
+ "e": [
+ {
+ "e": {"t": "i", "i": "http://example.org/alice"},
+ "v": [0.1, 0.2, 0.3],
+ },
+ ],
+ },
+ ))
+
+ ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
+ error = AsyncMock()
+
+ importer = CoreImport(backend=Mock())
+ await importer.process(
+ data=_make_data_reader(payload),
+ error=error,
+ ok=ok,
+ request=_make_request(),
+ )
+
+ error.assert_not_called()
+ assert len(captured) == 1
+
+ req = captured[0]
+ assert req["operation"] == "put-kg-core"
+ assert req["user"] == "alice"
+ assert req["id"] == "doc-1"
+
+ ge = req["graph-embeddings"]
+ # Metadata envelope must NOT contain a stale `metadata` key
+ # referencing the removed Metadata.metadata field.
+ assert "metadata" not in ge["metadata"]
+ assert ge["metadata"] == {
+ "id": "doc-1",
+ "user": "alice",
+ "collection": "default",
+ }
+
+ # Entity blob carries the singular `vector` key
+ assert len(ge["entities"]) == 1
+ ent = ge["entities"][0]
+ assert ent["vector"] == [0.1, 0.2, 0.3]
+ assert "vectors" not in ent
+
+ @pytest.mark.asyncio
+ @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
+ async def test_import_unpacks_triples(self, mock_kr_class):
+ captured = []
+
+ async def fake_kr_process(req_dict):
+ captured.append(req_dict)
+
+ mock_kr = AsyncMock()
+ mock_kr.start = AsyncMock()
+ mock_kr.stop = AsyncMock()
+ mock_kr.process = fake_kr_process
+ mock_kr_class.return_value = mock_kr
+
+ payload = msgpack.packb((
+ "t",
+ {
+ "m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
+ "t": [
+ {
+ "s": {"t": "i", "i": "http://example.org/alice"},
+ "p": {"t": "i", "i": "http://example.org/knows"},
+ "o": {"t": "i", "i": "http://example.org/bob"},
+ },
+ ],
+ },
+ ))
+
+ ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
+ error = AsyncMock()
+
+ importer = CoreImport(backend=Mock())
+ await importer.process(
+ data=_make_data_reader(payload),
+ error=error,
+ ok=ok,
+ request=_make_request(),
+ )
+
+ error.assert_not_called()
+ assert len(captured) == 1
+
+ req = captured[0]
+ triples = req["triples"]
+ assert "metadata" not in triples["metadata"] # no stale field
+ assert len(triples["triples"]) == 1
+
+
+# ---------------------------------------------------------------------------
+# Full round-trip: export bytes feed directly into import
+# ---------------------------------------------------------------------------
+
+
+class TestCoreImportExportRoundTrip:
+ """End-to-end: produce bytes via core_export, consume them via
+ core_import, and verify the dict that lands at the import-side
+ translator is structurally equivalent to what went in. This is the
+ test that catches asymmetries between the two halves."""
+
+ @pytest.mark.asyncio
+ @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
+ @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
+ async def test_graph_embeddings_round_trip(
+ self, mock_export_kr_class, mock_import_kr_class,
+ ):
+ # ----- export side: capture bytes -----
+ export_bytes = []
+
+ async def fake_export_process(req_dict, responder):
+ await responder(_ge_response_dict(), True)
+
+ export_kr = AsyncMock()
+ export_kr.start = AsyncMock()
+ export_kr.stop = AsyncMock()
+ export_kr.process = fake_export_process
+ mock_export_kr_class.return_value = export_kr
+
+ response = AsyncMock()
+
+ async def fake_write(b):
+ export_bytes.append(b)
+
+ response.write = fake_write
+ response.write_eof = AsyncMock()
+
+ exporter = CoreExport(backend=Mock())
+ await exporter.process(
+ data=Mock(),
+ error=AsyncMock(),
+ ok=AsyncMock(return_value=response),
+ request=_make_request(),
+ )
+
+ assert len(export_bytes) == 1
+
+ # ----- import side: feed those bytes back in -----
+ import_captured = []
+
+ async def fake_import_process(req_dict):
+ import_captured.append(req_dict)
+
+ import_kr = AsyncMock()
+ import_kr.start = AsyncMock()
+ import_kr.stop = AsyncMock()
+ import_kr.process = fake_import_process
+ mock_import_kr_class.return_value = import_kr
+
+ importer = CoreImport(backend=Mock())
+ await importer.process(
+ data=_make_data_reader(export_bytes[0]),
+ error=AsyncMock(),
+ ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())),
+ request=_make_request(),
+ )
+
+ # ----- verify the dict the importer would hand to the translator -----
+ assert len(import_captured) == 1
+ req = import_captured[0]
+
+ original = _ge_response_dict()["graph-embeddings"]
+
+ ge = req["graph-embeddings"]
+ # The import side overrides id/user from the URL query (intentional),
+ # so we only round-trip the entity payload itself.
+ assert ge["metadata"]["id"] == original["metadata"]["id"]
+ assert ge["metadata"]["user"] == original["metadata"]["user"]
+
+ assert len(ge["entities"]) == len(original["entities"])
+ for got, want in zip(ge["entities"], original["entities"]):
+ assert got["vector"] == want["vector"]
+ assert got["entity"] == want["entity"]
diff --git a/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py b/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py
new file mode 100644
index 00000000..8eddeba9
--- /dev/null
+++ b/tests/unit/test_gateway/test_entity_contexts_import_dispatcher.py
@@ -0,0 +1,242 @@
+"""
+Unit tests for entity contexts import dispatcher.
+
+Tests the business logic of EntityContextsImport while mocking the
+Publisher and websocket components.
+
+Regression coverage: a previous version constructed Metadata(metadata=...)
+which raised TypeError at runtime as soon as a message was received. These
+tests exercise receive() end-to-end so any future schema/kwarg drift in
+the Metadata or EntityContexts construction is caught immediately.
+"""
+
+import pytest
+from unittest.mock import Mock, AsyncMock, patch
+
+from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport
+from trustgraph.schema import EntityContexts, EntityContext, Metadata
+
+
+@pytest.fixture
+def mock_backend():
+ return Mock()
+
+
+@pytest.fixture
+def mock_running():
+ running = Mock()
+ running.get.return_value = True
+ running.stop = Mock()
+ return running
+
+
+@pytest.fixture
+def mock_websocket():
+ ws = Mock()
+ ws.close = AsyncMock()
+ return ws
+
+
+@pytest.fixture
+def sample_message():
+ """Sample entity-contexts websocket message."""
+ return {
+ "metadata": {
+ "id": "doc-123",
+ "user": "testuser",
+ "collection": "testcollection",
+ },
+ "entities": [
+ {
+ "entity": {"v": "http://example.org/alice", "e": True},
+ "context": "Alice is a person.",
+ },
+ {
+ "entity": {"v": "http://example.org/bob", "e": True},
+ "context": "Bob is a person.",
+ },
+ ],
+ }
+
+
+@pytest.fixture
+def empty_entities_message():
+ return {
+ "metadata": {
+ "id": "doc-empty",
+ "user": "u",
+ "collection": "c",
+ },
+ "entities": [],
+ }
+
+
+class TestEntityContextsImportInitialization:
+
+ @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
+ def test_init_creates_publisher_with_correct_params(
+ self, mock_publisher_class, mock_backend, mock_websocket, mock_running
+ ):
+ instance = Mock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = EntityContextsImport(
+ ws=mock_websocket,
+ running=mock_running,
+ backend=mock_backend,
+ queue="ec-queue",
+ )
+
+ mock_publisher_class.assert_called_once_with(
+ mock_backend,
+ topic="ec-queue",
+ schema=EntityContexts,
+ )
+ assert dispatcher.ws is mock_websocket
+ assert dispatcher.running is mock_running
+ assert dispatcher.publisher is instance
+
+
+class TestEntityContextsImportLifecycle:
+
+ @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_start_calls_publisher_start(
+ self, mock_publisher_class, mock_backend, mock_websocket, mock_running
+ ):
+ instance = Mock()
+ instance.start = AsyncMock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = EntityContextsImport(
+ ws=mock_websocket, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+ await dispatcher.start()
+ instance.start.assert_called_once()
+
+ @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_destroy_stops_and_closes_properly(
+ self, mock_publisher_class, mock_backend, mock_websocket, mock_running
+ ):
+ instance = Mock()
+ instance.stop = AsyncMock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = EntityContextsImport(
+ ws=mock_websocket, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+ await dispatcher.destroy()
+
+ mock_running.stop.assert_called_once()
+ instance.stop.assert_called_once()
+ mock_websocket.close.assert_called_once()
+
+ @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_destroy_handles_none_websocket(
+ self, mock_publisher_class, mock_backend, mock_running
+ ):
+ instance = Mock()
+ instance.stop = AsyncMock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = EntityContextsImport(
+ ws=None, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+ await dispatcher.destroy()
+
+ mock_running.stop.assert_called_once()
+ instance.stop.assert_called_once()
+
+
+class TestEntityContextsImportMessageProcessing:
+ """Regression coverage for receive(): catches Metadata/schema drift."""
+
+ @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_receive_constructs_entity_contexts_correctly(
+ self, mock_publisher_class, mock_backend, mock_websocket,
+ mock_running, sample_message,
+ ):
+ instance = Mock()
+ instance.send = AsyncMock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = EntityContextsImport(
+ ws=mock_websocket, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+
+ mock_msg = Mock()
+ mock_msg.json.return_value = sample_message
+
+        # If Metadata, EntityContext, or EntityContexts gain/lose
+        # kwargs, this raises TypeError — exactly the regression we
+        # want to catch.
+ await dispatcher.receive(mock_msg)
+
+ instance.send.assert_called_once()
+ call_args = instance.send.call_args
+ assert call_args[0][0] is None
+
+ sent = call_args[0][1]
+ assert isinstance(sent, EntityContexts)
+ assert isinstance(sent.metadata, Metadata)
+ assert sent.metadata.id == "doc-123"
+ assert sent.metadata.user == "testuser"
+ assert sent.metadata.collection == "testcollection"
+
+ assert len(sent.entities) == 2
+ assert all(isinstance(e, EntityContext) for e in sent.entities)
+ assert sent.entities[0].context == "Alice is a person."
+ assert sent.entities[1].context == "Bob is a person."
+
+ @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_receive_handles_empty_entities(
+ self, mock_publisher_class, mock_backend, mock_websocket,
+ mock_running, empty_entities_message,
+ ):
+ instance = Mock()
+ instance.send = AsyncMock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = EntityContextsImport(
+ ws=mock_websocket, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+
+ mock_msg = Mock()
+ mock_msg.json.return_value = empty_entities_message
+
+ await dispatcher.receive(mock_msg)
+
+ instance.send.assert_called_once()
+ sent = instance.send.call_args[0][1]
+ assert isinstance(sent, EntityContexts)
+ assert sent.entities == []
+ assert sent.metadata.id == "doc-empty"
+
+ @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_receive_propagates_publisher_errors(
+ self, mock_publisher_class, mock_backend, mock_websocket,
+ mock_running, sample_message,
+ ):
+ instance = Mock()
+ instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
+ mock_publisher_class.return_value = instance
+
+ dispatcher = EntityContextsImport(
+ ws=mock_websocket, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+
+ mock_msg = Mock()
+ mock_msg.json.return_value = sample_message
+
+ with pytest.raises(RuntimeError, match="publish failed"):
+ await dispatcher.receive(mock_msg)
diff --git a/tests/unit/test_gateway/test_explain_triples.py b/tests/unit/test_gateway/test_explain_triples.py
index 24e77410..42a2f4c5 100644
--- a/tests/unit/test_gateway/test_explain_triples.py
+++ b/tests/unit/test_gateway/test_explain_triples.py
@@ -158,7 +158,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
- chunk_type="explain",
+ message_type="explain",
content="",
explain_id="urn:trustgraph:agent:session:abc123",
explain_graph="urn:graph:retrieval",
@@ -179,7 +179,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
- chunk_type="thought",
+ message_type="thought",
content="I need to think...",
)
@@ -190,7 +190,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
- chunk_type="explain",
+ message_type="explain",
explain_id="urn:trustgraph:agent:session:abc123",
explain_triples=sample_triples(),
end_of_dialog=False,
@@ -203,7 +203,7 @@ class TestAgentExplainTriples:
translator = AgentResponseTranslator()
response = AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content="The answer is...",
end_of_dialog=True,
)
diff --git a/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py b/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py
new file mode 100644
index 00000000..fa277178
--- /dev/null
+++ b/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py
@@ -0,0 +1,247 @@
+"""
+Unit tests for graph embeddings import dispatcher.
+
+Tests the business logic of GraphEmbeddingsImport while mocking the
+Publisher and websocket components.
+
+Regression coverage: a previous version of EntityContextsImport
+constructed Metadata(metadata=...) which raised TypeError at runtime as
+soon as a message was received. The same shape of bug can occur here, so
+these tests exercise receive() end-to-end to catch any future schema or
+kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction.
+"""
+
+import pytest
+from unittest.mock import Mock, AsyncMock, patch
+
+from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport
+from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata
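+
+# Shape of the original bug, for orientation (illustrative; the exact
+# corrected call is an assumption, not quoted from the dispatcher):
+#
+#     Metadata(metadata=message["metadata"])    # buggy: unexpected kwarg
+#     Metadata(**message["metadata"])           # the intended shape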
+
+
+@pytest.fixture
+def mock_backend():
+ return Mock()
+
+
+@pytest.fixture
+def mock_running():
+ running = Mock()
+ running.get.return_value = True
+ running.stop = Mock()
+ return running
+
+
+@pytest.fixture
+def mock_websocket():
+ ws = Mock()
+ ws.close = AsyncMock()
+ return ws
+
+
+@pytest.fixture
+def sample_message():
+ """Sample graph-embeddings websocket message."""
+ return {
+ "metadata": {
+ "id": "doc-123",
+ "user": "testuser",
+ "collection": "testcollection",
+ },
+ "entities": [
+ {
+ "entity": {"v": "http://example.org/alice", "e": True},
+ "vector": [0.1, 0.2, 0.3],
+ },
+ {
+ "entity": {"v": "http://example.org/bob", "e": True},
+ "vector": [0.4, 0.5, 0.6],
+ },
+ ],
+ }
+
+
+@pytest.fixture
+def empty_entities_message():
+ return {
+ "metadata": {
+ "id": "doc-empty",
+ "user": "u",
+ "collection": "c",
+ },
+ "entities": [],
+ }
+
+
+class TestGraphEmbeddingsImportInitialization:
+
+ @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
+ def test_init_creates_publisher_with_correct_params(
+ self, mock_publisher_class, mock_backend, mock_websocket, mock_running
+ ):
+ instance = Mock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = GraphEmbeddingsImport(
+ ws=mock_websocket,
+ running=mock_running,
+ backend=mock_backend,
+ queue="ge-queue",
+ )
+
+ mock_publisher_class.assert_called_once_with(
+ mock_backend,
+ topic="ge-queue",
+ schema=GraphEmbeddings,
+ )
+ assert dispatcher.ws is mock_websocket
+ assert dispatcher.running is mock_running
+ assert dispatcher.publisher is instance
+
+
+class TestGraphEmbeddingsImportLifecycle:
+
+ @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_start_calls_publisher_start(
+ self, mock_publisher_class, mock_backend, mock_websocket, mock_running
+ ):
+ instance = Mock()
+ instance.start = AsyncMock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = GraphEmbeddingsImport(
+ ws=mock_websocket, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+ await dispatcher.start()
+ instance.start.assert_called_once()
+
+ @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_destroy_stops_and_closes_properly(
+ self, mock_publisher_class, mock_backend, mock_websocket, mock_running
+ ):
+ instance = Mock()
+ instance.stop = AsyncMock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = GraphEmbeddingsImport(
+ ws=mock_websocket, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+ await dispatcher.destroy()
+
+ mock_running.stop.assert_called_once()
+ instance.stop.assert_called_once()
+ mock_websocket.close.assert_called_once()
+
+ @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_destroy_handles_none_websocket(
+ self, mock_publisher_class, mock_backend, mock_running
+ ):
+ instance = Mock()
+ instance.stop = AsyncMock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = GraphEmbeddingsImport(
+ ws=None, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+ await dispatcher.destroy()
+
+ mock_running.stop.assert_called_once()
+ instance.stop.assert_called_once()
+
+
+class TestGraphEmbeddingsImportMessageProcessing:
+ """Regression coverage for receive(): catches Metadata/schema drift."""
+
+ @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_receive_constructs_graph_embeddings_correctly(
+ self, mock_publisher_class, mock_backend, mock_websocket,
+ mock_running, sample_message,
+ ):
+ instance = Mock()
+ instance.send = AsyncMock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = GraphEmbeddingsImport(
+ ws=mock_websocket, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+
+ mock_msg = Mock()
+ mock_msg.json.return_value = sample_message
+
+ # If Metadata, GraphEmbeddings, or EntityEmbeddings gain/lose
+ # kwargs, this raises TypeError — exactly the regression we want
+ # to catch.
+ await dispatcher.receive(mock_msg)
+
+ instance.send.assert_called_once()
+ call_args = instance.send.call_args
+ assert call_args[0][0] is None
+
+ sent = call_args[0][1]
+ assert isinstance(sent, GraphEmbeddings)
+ assert isinstance(sent.metadata, Metadata)
+ assert sent.metadata.id == "doc-123"
+ assert sent.metadata.user == "testuser"
+ assert sent.metadata.collection == "testcollection"
+
+ assert len(sent.entities) == 2
+ assert all(isinstance(e, EntityEmbeddings) for e in sent.entities)
+ # Lock in the wire format: incoming "vector" key (singular,
+ # list[float]) maps to EntityEmbeddings.vector. This mirrors
+ # serialize_graph_embeddings() on the export side.
+ assert sent.entities[0].vector == [0.1, 0.2, 0.3]
+ assert sent.entities[1].vector == [0.4, 0.5, 0.6]
+
+ @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_receive_handles_empty_entities(
+ self, mock_publisher_class, mock_backend, mock_websocket,
+ mock_running, empty_entities_message,
+ ):
+ instance = Mock()
+ instance.send = AsyncMock()
+ mock_publisher_class.return_value = instance
+
+ dispatcher = GraphEmbeddingsImport(
+ ws=mock_websocket, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+
+ mock_msg = Mock()
+ mock_msg.json.return_value = empty_entities_message
+
+ await dispatcher.receive(mock_msg)
+
+ instance.send.assert_called_once()
+ sent = instance.send.call_args[0][1]
+ assert isinstance(sent, GraphEmbeddings)
+ assert sent.entities == []
+ assert sent.metadata.id == "doc-empty"
+
+ @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
+ @pytest.mark.asyncio
+ async def test_receive_propagates_publisher_errors(
+ self, mock_publisher_class, mock_backend, mock_websocket,
+ mock_running, sample_message,
+ ):
+ instance = Mock()
+ instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
+ mock_publisher_class.return_value = instance
+
+ dispatcher = GraphEmbeddingsImport(
+ ws=mock_websocket, running=mock_running,
+ backend=mock_backend, queue="q",
+ )
+
+ mock_msg = Mock()
+ mock_msg.json.return_value = sample_message
+
+ with pytest.raises(RuntimeError, match="publish failed"):
+ await dispatcher.receive(mock_msg)
diff --git a/tests/unit/test_gateway/test_service.py b/tests/unit/test_gateway/test_service.py
index 22d9ab04..71428db4 100644
--- a/tests/unit/test_gateway/test_service.py
+++ b/tests/unit/test_gateway/test_service.py
@@ -171,6 +171,14 @@ class TestApi:
patch('aiohttp.web.run_app') as mock_run_app:
mock_get_pubsub.return_value = Mock()
+ # Api.run() passes self.app_factory() — a coroutine — to
+ # web.run_app, which would normally consume it inside its own
+ # event loop. Since we mock run_app, close the coroutine here
+ # so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
+ def _consume_coro(coro, **kwargs):
+ coro.close()
+ mock_run_app.side_effect = _consume_coro
+
api = Api(port=8080)
api.run()
diff --git a/tests/unit/test_provenance/test_dag_structure.py b/tests/unit/test_provenance/test_dag_structure.py
new file mode 100644
index 00000000..184560f0
--- /dev/null
+++ b/tests/unit/test_provenance/test_dag_structure.py
@@ -0,0 +1,592 @@
+"""
+DAG structure tests for provenance chains.
+
+Verifies that the wasDerivedFrom chain has the expected shape for each
+service. These tests catch structural regressions when new entities are
+inserted into the chain (e.g. PatternDecision between session and first
+iteration).
+
+Expected chains:
+
+    GraphRAG:         question → grounding → exploration → focus → synthesis
+    DocumentRAG:      question → grounding → exploration → synthesis
+    Agent React:      session → pattern-decision → iteration → (observation → iteration)* → final
+    Agent Plan:       session → pattern-decision → plan → step-result(s) → synthesis
+    Agent Supervisor: session → pattern-decision → decomposition → (fan-out) → finding(s) → synthesis
+"""
+
+import json
+import uuid
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from trustgraph.schema import (
+ AgentRequest, AgentResponse, AgentStep, PlanStep,
+ Triple, Term, IRI, LITERAL,
+)
+from trustgraph.base import PromptResult
+
+from trustgraph.provenance.namespaces import (
+ RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL,
+ TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
+ TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
+ TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION,
+ TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION,
+ TG_OBSERVATION_TYPE,
+ TG_PATTERN, TG_TASK_TYPE,
+)
+
+from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
+from trustgraph.retrieval.document_rag.document_rag import DocumentRag
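+
+# One link in any of the chains above is a single provenance triple of
+# the form below (URIs illustrative):
+#
+#     Triple(
+#         s=Term(type=IRI, iri="urn:trustgraph:prov:grounding:1"),
+#         p=Term(type=IRI, iri=PROV_WAS_DERIVED_FROM),
+#         o=Term(type=IRI, iri="urn:trustgraph:prov:question:1"),
+#     )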
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _collect_events(events):
+ """Build a dict of explain_id → {types, derived_from, triples}."""
+ result = {}
+ for ev in events:
+ eid = ev["explain_id"]
+ triples = ev["triples"]
+ types = {
+ t.o.iri for t in triples
+ if t.s.iri == eid and t.p.iri == RDF_TYPE
+ }
+ parents = [
+ t.o.iri for t in triples
+ if t.s.iri == eid and t.p.iri == PROV_WAS_DERIVED_FROM
+ ]
+ result[eid] = {
+ "types": types,
+ "derived_from": parents[0] if parents else None,
+ "triples": triples,
+ }
+ return result
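+
+# Output shape (illustrative):
+#
+#     {"urn:...:question:1":  {"types": {TG_GRAPH_RAG_QUESTION},
+#                              "derived_from": None, "triples": [...]},
+#      "urn:...:grounding:1": {"types": {TG_GROUNDING},
+#                              "derived_from": "urn:...:question:1",
+#                              "triples": [...]}}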
+
+
+def _find_by_type(dag, rdf_type):
+ """Find all event IDs that have the given rdf:type."""
+ return [eid for eid, info in dag.items() if rdf_type in info["types"]]
+
+
+def _assert_chain(dag, chain_types):
+ """Assert that a linear wasDerivedFrom chain exists through the given types."""
+ for i in range(1, len(chain_types)):
+ parent_type = chain_types[i - 1]
+ child_type = chain_types[i]
+ parents = _find_by_type(dag, parent_type)
+ children = _find_by_type(dag, child_type)
+ assert parents, f"No entity with type {parent_type}"
+ assert children, f"No entity with type {child_type}"
+ # At least one child must derive from at least one parent
+ linked = False
+ for child_id in children:
+ derived = dag[child_id]["derived_from"]
+ if derived in parents:
+ linked = True
+ break
+ assert linked, (
+ f"No {child_type} derives from {parent_type}. "
+ f"Children derive from: "
+ f"{[dag[c]['derived_from'] for c in children]}"
+ )
+
+
+# ---------------------------------------------------------------------------
+# GraphRAG DAG structure
+# ---------------------------------------------------------------------------
+
+class TestGraphRagDagStructure:
+ """Verify: question → grounding → exploration → focus → synthesis"""
+
+ @pytest.fixture
+ def mock_clients(self):
+ prompt_client = AsyncMock()
+ embeddings_client = AsyncMock()
+ graph_embeddings_client = AsyncMock()
+ triples_client = AsyncMock()
+
+ embeddings_client.embed.return_value = [[0.1, 0.2]]
+ graph_embeddings_client.query.return_value = [
+ MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")),
+ ]
+ triples_client.query_stream.return_value = [
+ Triple(
+ s=Term(type=IRI, iri="http://example.com/e1"),
+ p=Term(type=IRI, iri="http://example.com/p"),
+ o=Term(type=LITERAL, value="value"),
+ )
+ ]
+ triples_client.query.return_value = []
+
+ async def mock_prompt(template_id, variables=None, **kwargs):
+ if template_id == "extract-concepts":
+ return PromptResult(response_type="text", text="concept")
+ elif template_id == "kg-edge-scoring":
+ edges = variables.get("knowledge", [])
+ return PromptResult(
+ response_type="jsonl",
+ objects=[{"id": e["id"], "score": 10} for e in edges],
+ )
+ elif template_id == "kg-edge-reasoning":
+ edges = variables.get("knowledge", [])
+ return PromptResult(
+ response_type="jsonl",
+ objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
+ )
+ elif template_id == "kg-synthesis":
+ return PromptResult(response_type="text", text="Answer.")
+ return PromptResult(response_type="text", text="")
+
+ prompt_client.prompt.side_effect = mock_prompt
+ return prompt_client, embeddings_client, graph_embeddings_client, triples_client
+
+ @pytest.mark.asyncio
+ async def test_dag_chain(self, mock_clients):
+ rag = GraphRag(*mock_clients)
+ events = []
+
+ async def explain_cb(triples, explain_id):
+ events.append({"explain_id": explain_id, "triples": triples})
+
+ await rag.query(
+ query="test", explain_callback=explain_cb, edge_score_limit=0,
+ )
+
+ dag = _collect_events(events)
+ assert len(dag) == 5, f"Expected 5 events, got {len(dag)}"
+
+ _assert_chain(dag, [
+ TG_GRAPH_RAG_QUESTION,
+ TG_GROUNDING,
+ TG_EXPLORATION,
+ TG_FOCUS,
+ TG_SYNTHESIS,
+ ])
+
+
+# ---------------------------------------------------------------------------
+# DocumentRAG DAG structure
+# ---------------------------------------------------------------------------
+
+class TestDocumentRagDagStructure:
+ """Verify: question → grounding → exploration → synthesis"""
+
+ @pytest.fixture
+ def mock_clients(self):
+ from trustgraph.schema import ChunkMatch
+
+ prompt_client = AsyncMock()
+ embeddings_client = AsyncMock()
+ doc_embeddings_client = AsyncMock()
+ fetch_chunk = AsyncMock(return_value="Chunk content.")
+
+ embeddings_client.embed.return_value = [[0.1, 0.2]]
+ doc_embeddings_client.query.return_value = [
+ ChunkMatch(chunk_id="doc/c1", score=0.9),
+ ]
+
+ async def mock_prompt(template_id, variables=None, **kwargs):
+ if template_id == "extract-concepts":
+ return PromptResult(response_type="text", text="concept")
+ return PromptResult(response_type="text", text="")
+
+ prompt_client.prompt.side_effect = mock_prompt
+ prompt_client.document_prompt.return_value = PromptResult(
+ response_type="text", text="Answer.",
+ )
+
+ return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
+
+ @pytest.mark.asyncio
+ async def test_dag_chain(self, mock_clients):
+ rag = DocumentRag(*mock_clients)
+ events = []
+
+ async def explain_cb(triples, explain_id):
+ events.append({"explain_id": explain_id, "triples": triples})
+
+ await rag.query(
+ query="test", explain_callback=explain_cb,
+ )
+
+ dag = _collect_events(events)
+ assert len(dag) == 4, f"Expected 4 events, got {len(dag)}"
+
+ _assert_chain(dag, [
+ TG_DOC_RAG_QUESTION,
+ TG_GROUNDING,
+ TG_EXPLORATION,
+ TG_SYNTHESIS,
+ ])
+
+
+# ---------------------------------------------------------------------------
+# Agent DAG structure — tested via service.agent_request()
+# ---------------------------------------------------------------------------
+
+def _make_processor(tools=None):
+ processor = MagicMock()
+ processor.max_iterations = 10
+ processor.save_answer_content = AsyncMock()
+
+ def mock_session_uri(sid):
+ return f"urn:trustgraph:agent:session:{sid}"
+ processor.provenance_session_uri.side_effect = mock_session_uri
+
+ agent = MagicMock()
+ agent.tools = tools or {}
+ agent.additional_context = ""
+ processor.agent = agent
+ processor.aggregator = MagicMock()
+
+ return processor
+
+
+def _make_flow():
+ producers = {}
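+    # Cache one AsyncMock producer per queue name so repeated
+    # flow("name") calls return the same mock for assertions.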
+
+ def factory(name):
+ if name not in producers:
+ producers[name] = AsyncMock()
+ return producers[name]
+
+ flow = MagicMock(side_effect=factory)
+ return flow
+
+
+def _collect_agent_events(respond_mock):
+ events = []
+ for call in respond_mock.call_args_list:
+ resp = call[0][0]
+ if isinstance(resp, AgentResponse) and resp.message_type == "explain":
+ events.append({
+ "explain_id": resp.explain_id,
+ "triples": resp.explain_triples,
+ })
+ return events
+
+
+class TestAgentReactDagStructure:
+ """
+ Via service.agent_request(), full two-iteration react chain:
+ session → pattern-decision → iteration(1) → observation(1) → final
+
+ Iteration 1: tool call → observation
+ Iteration 2: final answer
+ """
+
+ def _make_service(self):
+ from trustgraph.agent.orchestrator.service import Processor
+ from trustgraph.agent.orchestrator.react_pattern import ReactPattern
+ from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
+ from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
+
+ mock_tool = MagicMock()
+ mock_tool.name = "lookup"
+ mock_tool.description = "Look things up"
+ mock_tool.arguments = []
+ mock_tool.groups = []
+ mock_tool.states = {}
+ mock_tool_impl = AsyncMock(return_value="42")
+ mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
+
+ processor = _make_processor(tools={"lookup": mock_tool})
+
+ service = Processor.__new__(Processor)
+ service.max_iterations = 10
+ service.save_answer_content = AsyncMock()
+ service.provenance_session_uri = processor.provenance_session_uri
+ service.agent = processor.agent
+ service.aggregator = processor.aggregator
+
+ service.react_pattern = ReactPattern(service)
+ service.plan_pattern = PlanThenExecutePattern(service)
+ service.supervisor_pattern = SupervisorPattern(service)
+ service.meta_router = None
+
+ return service
+
+ @pytest.mark.asyncio
+ async def test_dag_chain(self):
+ from trustgraph.agent.react.types import Action, Final
+
+ service = self._make_service()
+
+ respond = AsyncMock()
+ next_fn = AsyncMock()
+ flow = _make_flow()
+ session_id = str(uuid.uuid4())
+
+ # Iteration 1: tool call → returns Action, triggers on_action + tool exec
+ action = Action(
+ thought="I need to look this up",
+ name="lookup",
+ arguments={"question": "6x7"},
+ observation="",
+ )
+
+ with patch(
+ "trustgraph.agent.orchestrator.react_pattern.AgentManager"
+ ) as MockAM:
+ mock_am = AsyncMock()
+ MockAM.return_value = mock_am
+
+ async def mock_react_iter1(on_action=None, **kwargs):
+ if on_action:
+ await on_action(action)
+ action.observation = "42"
+ return action
+
+ mock_am.react.side_effect = mock_react_iter1
+
+ request1 = AgentRequest(
+ question="What is 6x7?",
+ user="testuser",
+ collection="default",
+ streaming=False,
+ session_id=session_id,
+ pattern="react",
+ history=[],
+ )
+
+ await service.agent_request(request1, respond, next_fn, flow)
+
+ # next_fn should have been called with updated history
+ assert next_fn.called
+
+ # Iteration 2: final answer
+ final = Final(thought="The answer is 42", final="42")
+ next_request = next_fn.call_args[0][0]
+
+ with patch(
+ "trustgraph.agent.orchestrator.react_pattern.AgentManager"
+ ) as MockAM:
+ mock_am = AsyncMock()
+ MockAM.return_value = mock_am
+
+ async def mock_react_iter2(**kwargs):
+ return final
+
+ mock_am.react.side_effect = mock_react_iter2
+
+ await service.agent_request(next_request, respond, next_fn, flow)
+
+ # Collect and verify DAG
+ events = _collect_agent_events(respond)
+ dag = _collect_events(events)
+
+ session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
+ pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
+ analysis_ids = _find_by_type(dag, TG_ANALYSIS)
+ observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE)
+ final_ids = _find_by_type(dag, TG_CONCLUSION)
+
+ assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}"
+ assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}"
+ assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}"
+ assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}"
+ assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}"
+
+ # Full chain:
+ # session → pattern-decision
+ assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
+
+ # pattern-decision → iteration(1)
+ assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0]
+
+ # iteration(1) → observation(1)
+ assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0]
+
+ # observation(1) → final
+ assert dag[final_ids[0]]["derived_from"] == observation_ids[0]
+
+
+class TestAgentPlanDagStructure:
+ """
+ Via service.agent_request():
+ session → pattern-decision → plan → step-result → synthesis
+ """
+
+ @pytest.mark.asyncio
+ async def test_dag_chain(self):
+ from trustgraph.agent.orchestrator.service import Processor
+ from trustgraph.agent.orchestrator.react_pattern import ReactPattern
+ from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
+ from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
+
+ # Mock tool
+ mock_tool = MagicMock()
+ mock_tool.name = "knowledge-query"
+ mock_tool.description = "Query KB"
+ mock_tool.arguments = []
+ mock_tool.groups = []
+ mock_tool.states = {}
+ mock_tool_impl = AsyncMock(return_value="Found it")
+ mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
+
+ processor = _make_processor(tools={"knowledge-query": mock_tool})
+
+ service = Processor.__new__(Processor)
+ service.max_iterations = 10
+ service.save_answer_content = AsyncMock()
+ service.provenance_session_uri = processor.provenance_session_uri
+ service.agent = processor.agent
+ service.aggregator = processor.aggregator
+
+ service.react_pattern = ReactPattern(service)
+ service.plan_pattern = PlanThenExecutePattern(service)
+ service.supervisor_pattern = SupervisorPattern(service)
+ service.meta_router = None
+
+ respond = AsyncMock()
+ next_fn = AsyncMock()
+ flow = _make_flow()
+
+ # Mock prompt client
+ mock_prompt_client = AsyncMock()
+
+        async def mock_prompt(id, variables=None, **kwargs):
+            if id == "plan-create":
+ return PromptResult(
+ response_type="jsonl",
+ objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}],
+ )
+ elif id == "plan-step-execute":
+ return PromptResult(
+ response_type="json",
+ object={"tool": "knowledge-query", "arguments": {"question": "test"}},
+ )
+ elif id == "plan-synthesise":
+ return PromptResult(response_type="text", text="Final answer.")
+ return PromptResult(response_type="text", text="")
+
+ mock_prompt_client.prompt.side_effect = mock_prompt
+
+ def flow_factory(name):
+ if name == "prompt-request":
+ return mock_prompt_client
+ return AsyncMock()
+ flow.side_effect = flow_factory
+
+ session_id = str(uuid.uuid4())
+
+ # Iteration 1: planning
+ request1 = AgentRequest(
+ question="Test?",
+ user="testuser",
+ collection="default",
+ streaming=False,
+ session_id=session_id,
+ pattern="plan-then-execute",
+ history=[],
+ )
+ await service.agent_request(request1, respond, next_fn, flow)
+
+        # next_fn was called with the updated request (the iteration-2
+        # input, which would normally execute the step)
+        assert next_fn.called
+        next_request = next_fn.call_args[0][0]
+
+        # Skip real step execution: mark the step completed in history
+        # so the next call goes straight to synthesis
+        next_request.history[-1].plan[0].status = "completed"
+        next_request.history[-1].plan[0].result = "Found it"
+
+ await service.agent_request(next_request, respond, next_fn, flow)
+
+ events = _collect_agent_events(respond)
+ dag = _collect_events(events)
+
+ session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
+ pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
+ plan_ids = _find_by_type(dag, TG_PLAN_TYPE)
+ synthesis_ids = _find_by_type(dag, TG_SYNTHESIS)
+
+ assert len(session_ids) == 1
+ assert len(pd_ids) == 1
+ assert len(plan_ids) == 1
+ assert len(synthesis_ids) == 1
+
+ # Chain: session → pattern-decision → plan → ... → synthesis
+ assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
+ assert dag[plan_ids[0]]["derived_from"] == pd_ids[0]
+
+
+class TestAgentSupervisorDagStructure:
+ """
+ Via service.agent_request():
+ session → pattern-decision → decomposition → (fan-out)
+ """
+
+ @pytest.mark.asyncio
+ async def test_dag_chain(self):
+ from trustgraph.agent.orchestrator.service import Processor
+ from trustgraph.agent.orchestrator.react_pattern import ReactPattern
+ from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
+ from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
+
+ processor = _make_processor()
+
+ service = Processor.__new__(Processor)
+ service.max_iterations = 10
+ service.save_answer_content = AsyncMock()
+ service.provenance_session_uri = processor.provenance_session_uri
+ service.agent = processor.agent
+ service.aggregator = processor.aggregator
+
+ service.react_pattern = ReactPattern(service)
+ service.plan_pattern = PlanThenExecutePattern(service)
+ service.supervisor_pattern = SupervisorPattern(service)
+ service.meta_router = None
+
+ respond = AsyncMock()
+ next_fn = AsyncMock()
+ flow = _make_flow()
+
+ mock_prompt_client = AsyncMock()
+ mock_prompt_client.prompt.return_value = PromptResult(
+ response_type="jsonl",
+ objects=["Goal A", "Goal B"],
+ )
+
+ def flow_factory(name):
+ if name == "prompt-request":
+ return mock_prompt_client
+ return AsyncMock()
+ flow.side_effect = flow_factory
+
+ request = AgentRequest(
+ question="Research quantum computing",
+ user="testuser",
+ collection="default",
+ streaming=False,
+ session_id=str(uuid.uuid4()),
+ pattern="supervisor",
+ history=[],
+ )
+
+ await service.agent_request(request, respond, next_fn, flow)
+
+ events = _collect_agent_events(respond)
+ dag = _collect_events(events)
+
+ session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
+ pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
+ decomp_ids = _find_by_type(dag, TG_DECOMPOSITION)
+
+ assert len(session_ids) == 1
+ assert len(pd_ids) == 1
+ assert len(decomp_ids) == 1
+
+ # Chain: session → pattern-decision → decomposition
+ assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
+ assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0]
+
+ # Fan-out should have been called
+ assert next_fn.call_count == 2 # One per goal
diff --git a/tests/unit/test_provenance/test_triples.py b/tests/unit/test_provenance/test_triples.py
index 792db028..f906a00d 100644
--- a/tests/unit/test_provenance/test_triples.py
+++ b/tests/unit/test_provenance/test_triples.py
@@ -223,7 +223,7 @@ class TestDerivedEntityTriples:
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
- def test_chunk_entity_has_chunk_type(self):
+    def test_chunk_entity_has_chunk_rdf_type(self):
triples = derived_entity_triples(
self.ENTITY_URI, self.PARENT_URI,
"chunker", "1.0",
diff --git a/tests/unit/test_python_api_client.py b/tests/unit/test_python_api_client.py
index 80443a0c..0b6709fb 100644
--- a/tests/unit/test_python_api_client.py
+++ b/tests/unit/test_python_api_client.py
@@ -304,14 +304,14 @@ class TestStreamingTypes:
assert chunk.content == "thinking..."
assert chunk.end_of_message is False
- assert chunk.chunk_type == "thought"
+ assert chunk.message_type == "thought"
def test_agent_observation_creation(self):
"""Test creating AgentObservation chunk"""
chunk = AgentObservation(content="observing...", end_of_message=False)
assert chunk.content == "observing..."
- assert chunk.chunk_type == "observation"
+ assert chunk.message_type == "observation"
def test_agent_answer_creation(self):
"""Test creating AgentAnswer chunk"""
@@ -324,7 +324,7 @@ class TestStreamingTypes:
assert chunk.content == "answer"
assert chunk.end_of_message is True
assert chunk.end_of_dialog is True
- assert chunk.chunk_type == "final-answer"
+ assert chunk.message_type == "final-answer"
def test_rag_chunk_creation(self):
"""Test creating RAGChunk"""
diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py
index 27508ba4..1ff85f5a 100644
--- a/tests/unit/test_retrieval/test_document_rag.py
+++ b/tests/unit/test_retrieval/test_document_rag.py
@@ -6,6 +6,7 @@ import pytest
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
+from trustgraph.base import PromptResult
# Sample chunk content mapping for tests
@@ -132,7 +133,7 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client
# Mock the prompt response with concept lines
- mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns")
query = Query(
rag=mock_rag,
@@ -157,7 +158,7 @@ class TestQuery:
mock_rag.prompt_client = mock_prompt_client
# Mock empty response
- mock_prompt_client.prompt.return_value = ""
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query(
rag=mock_rag,
@@ -258,7 +259,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
- mock_prompt_client.prompt.return_value = "test concept"
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept")
# Mock embeddings - one vector per concept
test_vectors = [[0.1, 0.2, 0.3]]
@@ -273,7 +274,7 @@ class TestQuery:
expected_response = "This is the document RAG response"
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
- mock_prompt_client.document_prompt.return_value = expected_response
+ mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response)
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@@ -315,7 +316,8 @@ class TestQuery:
assert "Relevant document content" in docs
assert "Another document" in docs
- assert result == expected_response
+ result_text, usage = result
+ assert result_text == expected_response
@pytest.mark.asyncio
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
@@ -325,7 +327,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction fallback (empty → raw query)
- mock_prompt_client.prompt.return_value = ""
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
@@ -333,7 +335,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c5"
mock_match.score = 0.9
mock_doc_embeddings_client.query.return_value = [mock_match]
- mock_prompt_client.document_prompt.return_value = "Default response"
+ mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@@ -352,7 +354,8 @@ class TestQuery:
collection="default" # Default collection
)
- assert result == "Default response"
+ result_text, usage = result
+ assert result_text == "Default response"
@pytest.mark.asyncio
async def test_get_docs_with_verbose_output(self):
@@ -401,7 +404,7 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
- mock_prompt_client.prompt.return_value = "verbose query test"
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test")
# Mock responses
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
@@ -409,7 +412,7 @@ class TestQuery:
mock_match.chunk_id = "doc/c7"
mock_match.score = 0.92
mock_doc_embeddings_client.query.return_value = [mock_match]
- mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
+ mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@@ -428,7 +431,8 @@ class TestQuery:
assert call_args.kwargs["query"] == "verbose query test"
assert "Verbose doc content" in call_args.kwargs["documents"]
- assert result == "Verbose RAG response"
+ result_text, usage = result
+ assert result_text == "Verbose RAG response"
@pytest.mark.asyncio
async def test_get_docs_with_empty_results(self):
@@ -469,11 +473,11 @@ class TestQuery:
mock_doc_embeddings_client = AsyncMock()
# Mock concept extraction
- mock_prompt_client.prompt.return_value = "query with no matching docs"
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs")
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
mock_doc_embeddings_client.query.return_value = []
- mock_prompt_client.document_prompt.return_value = "No documents found response"
+ mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response")
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@@ -490,7 +494,8 @@ class TestQuery:
documents=[]
)
- assert result == "No documents found response"
+ result_text, usage = result
+ assert result_text == "No documents found response"
@pytest.mark.asyncio
async def test_get_vectors_with_verbose(self):
@@ -525,7 +530,7 @@ class TestQuery:
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
# Mock concept extraction
- mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence")
# Mock embeddings - one vector per concept
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
@@ -541,7 +546,7 @@ class TestQuery:
MagicMock(chunk_id="doc/ml3", score=0.82),
]
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
- mock_prompt_client.document_prompt.return_value = final_response
+ mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response)
document_rag = DocumentRag(
prompt_client=mock_prompt_client,
@@ -584,7 +589,8 @@ class TestQuery:
assert "Common ML techniques include supervised and unsupervised learning..." in docs
assert len(docs) == 3 # doc/ml2 deduplicated
- assert result == final_response
+ result_text, usage = result
+ assert result_text == final_response
@pytest.mark.asyncio
async def test_get_docs_deduplicates_across_concepts(self):
diff --git a/tests/unit/test_retrieval/test_document_rag_provenance_integration.py b/tests/unit/test_retrieval/test_document_rag_provenance_integration.py
index 74157285..8fa10642 100644
--- a/tests/unit/test_retrieval/test_document_rag_provenance_integration.py
+++ b/tests/unit/test_retrieval/test_document_rag_provenance_integration.py
@@ -12,6 +12,7 @@ from unittest.mock import AsyncMock
from dataclasses import dataclass
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
+from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@@ -89,8 +90,8 @@ def build_mock_clients():
# 1. Concept extraction
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
- return "return policy\nrefund"
- return ""
+ return PromptResult(response_type="text", text="return policy\nrefund")
+ return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
@@ -113,8 +114,9 @@ def build_mock_clients():
fetch_chunk.side_effect = mock_fetch
# 5. Synthesis
- prompt_client.document_prompt.return_value = (
- "Items can be returned within 30 days for a full refund."
+ prompt_client.document_prompt.return_value = PromptResult(
+ response_type="text",
+ text="Items can be returned within 30 days for a full refund.",
)
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
@@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients()
rag = DocumentRag(*clients)
- result = await rag.query(
+ result_text, usage = await rag.query(
query="What is the return policy?",
explain_callback=AsyncMock(),
)
- assert result == "Items can be returned within 30 days for a full refund."
+ assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_no_explain_callback_still_works(self):
@@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance:
clients = build_mock_clients()
rag = DocumentRag(*clients)
- result = await rag.query(query="What is the return policy?")
- assert result == "Items can be returned within 30 days for a full refund."
+ result_text, usage = await rag.query(query="What is the return policy?")
+ assert result_text == "Items can be returned within 30 days for a full refund."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):
diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py
index 05e1bb60..a5d42f3a 100644
--- a/tests/unit/test_retrieval/test_document_rag_service.py
+++ b/tests/unit/test_retrieval/test_document_rag_service.py
@@ -34,7 +34,7 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
- mock_rag_instance.query.return_value = "test response"
+ mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
# Setup message with custom user/collection
msg = MagicMock()
@@ -97,7 +97,7 @@ class TestDocumentRagService:
# Setup mock DocumentRag instance
mock_rag_instance = AsyncMock()
mock_document_rag_class.return_value = mock_rag_instance
- mock_rag_instance.query.return_value = "A document about cats."
+ mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None})
# Setup message with non-streaming request
msg = MagicMock()
@@ -130,4 +130,5 @@ class TestDocumentRagService:
assert isinstance(sent_response, DocumentRagResponse)
assert sent_response.response == "A document about cats."
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
+ assert sent_response.end_of_session is True
assert sent_response.error is None
\ No newline at end of file
diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py
index 00d8b72a..00a9551f 100644
--- a/tests/unit/test_retrieval/test_graph_rag.py
+++ b/tests/unit/test_retrieval/test_graph_rag.py
@@ -7,6 +7,7 @@ import unittest.mock
from unittest.mock import MagicMock, AsyncMock
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
+from trustgraph.base import PromptResult
class TestGraphRag:
@@ -172,7 +173,7 @@ class TestQuery:
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
- mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
query = Query(
rag=mock_rag,
@@ -196,7 +197,7 @@ class TestQuery:
mock_prompt_client = AsyncMock()
mock_rag.prompt_client = mock_prompt_client
- mock_prompt_client.prompt.return_value = ""
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
query = Query(
rag=mock_rag,
@@ -220,7 +221,7 @@ class TestQuery:
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
# extract_concepts returns empty -> falls back to [query]
- mock_prompt_client.prompt.return_value = ""
+ mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
# embed returns one vector set for the single concept
test_vectors = [[0.1, 0.2, 0.3]]
@@ -565,14 +566,14 @@ class TestQuery:
# Mock prompt responses for the multi-step process
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
if prompt_name == "extract-concepts":
- return "" # Falls back to raw query
+ return PromptResult(response_type="text", text="")
elif prompt_name == "kg-edge-scoring":
- return json.dumps({"id": test_edge_id, "score": 0.9})
+ return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
elif prompt_name == "kg-edge-reasoning":
- return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
+ return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
elif prompt_name == "kg-synthesis":
- return expected_response
- return ""
+ return PromptResult(response_type="text", text=expected_response)
+ return PromptResult(response_type="text", text="")
mock_prompt_client.prompt = mock_prompt
@@ -607,7 +608,8 @@ class TestQuery:
explain_callback=collect_provenance
)
- assert response == expected_response
+ response_text, usage = response
+ assert response_text == expected_response
# 5 events: question, grounding, exploration, focus, synthesis
assert len(provenance_events) == 5
diff --git a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py
index 36536f7d..1eb0dd72 100644
--- a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py
+++ b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py
@@ -13,6 +13,7 @@ from dataclasses import dataclass
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
+from trustgraph.base import PromptResult
from trustgraph.provenance.namespaces import (
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@@ -136,24 +137,36 @@ def build_mock_clients():
async def mock_prompt(template_id, variables=None, **kwargs):
if template_id == "extract-concepts":
- return prompt_responses["extract-concepts"]
+ return PromptResult(
+ response_type="text",
+ text=prompt_responses["extract-concepts"],
+ )
elif template_id == "kg-edge-scoring":
# Score all edges highly, using the IDs that GraphRag computed
edges = variables.get("knowledge", [])
- return [
- {"id": e["id"], "score": 10 - i}
- for i, e in enumerate(edges)
- ]
+ return PromptResult(
+ response_type="jsonl",
+ objects=[
+ {"id": e["id"], "score": 10 - i}
+ for i, e in enumerate(edges)
+ ],
+ )
elif template_id == "kg-edge-reasoning":
# Provide reasoning for each edge
edges = variables.get("knowledge", [])
- return [
- {"id": e["id"], "reasoning": f"Relevant edge {i}"}
- for i, e in enumerate(edges)
- ]
+ return PromptResult(
+ response_type="jsonl",
+ objects=[
+ {"id": e["id"], "reasoning": f"Relevant edge {i}"}
+ for i, e in enumerate(edges)
+ ],
+ )
elif template_id == "kg-synthesis":
- return synthesis_answer
- return ""
+ return PromptResult(
+ response_type="text",
+ text=synthesis_answer,
+ )
+ return PromptResult(response_type="text", text="")
prompt_client.prompt.side_effect = mock_prompt
@@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance:
async def explain_callback(triples, explain_id):
events.append({"triples": triples, "explain_id": explain_id})
- result = await rag.query(
+ result_text, usage = await rag.query(
query="What is quantum computing?",
explain_callback=explain_callback,
edge_score_limit=0,
)
- assert result == "Quantum computing applies physics principles to computation."
+ assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio
async def test_parent_uri_links_question_to_parent(self):
@@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance:
clients = build_mock_clients()
rag = GraphRag(*clients)
- result = await rag.query(
+ result_text, usage = await rag.query(
query="What is quantum computing?",
edge_score_limit=0,
)
- assert result == "Quantum computing applies physics principles to computation."
+ assert result_text == "Quantum computing applies physics principles to computation."
@pytest.mark.asyncio
async def test_all_triples_in_retrieval_graph(self):
diff --git a/tests/unit/test_retrieval/test_graph_rag_service.py b/tests/unit/test_retrieval/test_graph_rag_service.py
index 2cd62286..606aa7fe 100644
--- a/tests/unit/test_retrieval/test_graph_rag_service.py
+++ b/tests/unit/test_retrieval/test_graph_rag_service.py
@@ -44,7 +44,7 @@ class TestGraphRagService:
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
await explain_callback([], "urn:trustgraph:prov:selection:test")
await explain_callback([], "urn:trustgraph:prov:answer:test")
- return "A small domesticated mammal."
+ return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query
@@ -79,8 +79,8 @@ class TestGraphRagService:
# Execute
await processor.on_request(msg, consumer, flow)
- # Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
- assert mock_response_producer.send.call_count == 6
+ # Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session)
+ assert mock_response_producer.send.call_count == 5
# First 4 messages are explain (emitted in real-time during query)
for i in range(4):
@@ -88,17 +88,12 @@ class TestGraphRagService:
assert prov_msg.message_type == "explain"
assert prov_msg.explain_id is not None
- # 5th message is chunk with response
+ # 5th message is chunk with response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "A small domesticated mammal."
assert chunk_msg.end_of_stream is True
-
- # 6th message is empty chunk with end_of_session=True
- close_msg = mock_response_producer.send.call_args_list[5][0][0]
- assert close_msg.message_type == "chunk"
- assert close_msg.response == ""
- assert close_msg.end_of_session is True
+ assert chunk_msg.end_of_session is True
# Verify provenance triples were sent to provenance queue
assert mock_provenance_producer.send.call_count == 4
@@ -187,7 +182,7 @@ class TestGraphRagService:
async def mock_query(**kwargs):
# Don't call explain_callback
- return "Response text"
+ return "Response text", {"in_token": None, "out_token": None, "model": None}
mock_rag_instance.query.side_effect = mock_query
@@ -218,17 +213,12 @@ class TestGraphRagService:
# Execute
await processor.on_request(msg, consumer, flow)
- # Verify: 2 messages (chunk + empty chunk to close)
- assert mock_response_producer.send.call_count == 2
+ # Verify: 1 combined message (chunk with end_of_session)
+ assert mock_response_producer.send.call_count == 1
- # First is the response chunk
+ # Single message has response and end_of_session
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
assert chunk_msg.message_type == "chunk"
assert chunk_msg.response == "Response text"
assert chunk_msg.end_of_stream is True
-
- # Second is empty chunk to close session
- close_msg = mock_response_producer.send.call_args_list[1][0][0]
- assert close_msg.message_type == "chunk"
- assert close_msg.response == ""
- assert close_msg.end_of_session is True
+ assert chunk_msg.end_of_session is True
diff --git a/tests/unit/test_tables/__init__.py b/tests/unit/test_tables/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py
new file mode 100644
index 00000000..5129b01e
--- /dev/null
+++ b/tests/unit/test_tables/test_knowledge_table_store.py
@@ -0,0 +1,197 @@
+"""
+Unit tests for KnowledgeTableStore row deserialization.
+
+Regression coverage: a previous version of get_graph_embeddings constructed
+EntityEmbeddings(vectors=ent[1]) — the schema field is `vector` (singular),
+so any real Cassandra row would crash on read. These tests bypass the live
+Cassandra connection entirely and exercise the row -> schema conversion
+with hand-built fake rows.
+"""
+
+import pytest
+from unittest.mock import Mock
+
+from trustgraph.tables.knowledge import KnowledgeTableStore
+from trustgraph.schema import (
+ EntityEmbeddings,
+ GraphEmbeddings,
+ Triples,
+ Triple,
+ Metadata,
+ IRI,
+ LITERAL,
+)
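+
+# Row layouts assumed by the fakes below (values illustrative):
+#
+#     graph-embeddings row: (_, _, _, [((value, is_uri), vector), ...])
+#     triples row:          (_, _, _, [(s, s_uri, p, p_uri, o, o_uri), ...])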
+
+
+def _make_store():
+ """
+ Build a KnowledgeTableStore without invoking __init__ (which connects
+ to Cassandra). Tests inject only the attributes the method under test
+ actually touches.
+ """
+ return KnowledgeTableStore.__new__(KnowledgeTableStore)
+
+
+class TestGetGraphEmbeddings:
+
+ @pytest.mark.asyncio
+ async def test_row_converts_to_entity_embeddings_with_singular_vector(self):
+ """
+ Cassandra rows return entities as a list of [entity_tuple, vector]
+ pairs in row[3]. The deserializer must construct EntityEmbeddings
+ with `vector=` (singular) — the schema field name. A previous
+        version used `vectors=` and raised TypeError at runtime.
+ """
+        # Arrange — fake row matching the get_graph_embeddings_stmt
+        # result shape: row[0..2] are unused by the method, row[3] is
+        # the entities blob
+ fake_row = (
+ None, None, None,
+ [
+ # ((value, is_uri), vector)
+ (("http://example.org/alice", True), [0.1, 0.2, 0.3]),
+ (("http://example.org/bob", True), [0.4, 0.5, 0.6]),
+ (("a literal entity", False), [0.7, 0.8, 0.9]),
+ ],
+ )
+
+ store = _make_store()
+ store.cassandra = Mock()
+ store.cassandra.execute = Mock(return_value=[fake_row])
+ store.get_graph_embeddings_stmt = Mock()
+
+ received = []
+
+ async def receiver(msg):
+ received.append(msg)
+
+ # Act
+ await store.get_graph_embeddings(
+ user="alice",
+ document_id="doc-1",
+ receiver=receiver,
+ )
+
+ # Assert
+ store.cassandra.execute.assert_called_once_with(
+ store.get_graph_embeddings_stmt,
+ ("alice", "doc-1"),
+ )
+
+ assert len(received) == 1
+ ge = received[0]
+ assert isinstance(ge, GraphEmbeddings)
+ assert isinstance(ge.metadata, Metadata)
+ assert ge.metadata.id == "doc-1"
+ assert ge.metadata.user == "alice"
+
+ assert len(ge.entities) == 3
+ assert all(isinstance(e, EntityEmbeddings) for e in ge.entities)
+
+ # Vectors land in the singular `vector` field — this is the
+ # explicit regression assertion for the original bug.
+ assert ge.entities[0].vector == [0.1, 0.2, 0.3]
+ assert ge.entities[1].vector == [0.4, 0.5, 0.6]
+ assert ge.entities[2].vector == [0.7, 0.8, 0.9]
+
+ # Term type round-trips through tuple_to_term
+ assert ge.entities[0].entity.type == IRI
+ assert ge.entities[0].entity.iri == "http://example.org/alice"
+ assert ge.entities[1].entity.type == IRI
+ assert ge.entities[1].entity.iri == "http://example.org/bob"
+ assert ge.entities[2].entity.type == LITERAL
+ assert ge.entities[2].entity.value == "a literal entity"
+
+ @pytest.mark.asyncio
+ async def test_empty_entities_blob_yields_empty_list(self):
+ """row[3] being None / empty must produce a GraphEmbeddings with
+ no entities, not raise."""
+ fake_row = (None, None, None, None)
+
+ store = _make_store()
+ store.cassandra = Mock()
+ store.cassandra.execute = Mock(return_value=[fake_row])
+ store.get_graph_embeddings_stmt = Mock()
+
+ received = []
+
+ async def receiver(msg):
+ received.append(msg)
+
+ await store.get_graph_embeddings("u", "d", receiver)
+
+ assert len(received) == 1
+ assert received[0].entities == []
+
+ @pytest.mark.asyncio
+ async def test_multiple_rows_each_emit_one_message(self):
+ fake_rows = [
+ (None, None, None, [
+ (("http://example.org/a", True), [1.0]),
+ ]),
+ (None, None, None, [
+ (("http://example.org/b", True), [2.0]),
+ ]),
+ ]
+
+ store = _make_store()
+ store.cassandra = Mock()
+ store.cassandra.execute = Mock(return_value=fake_rows)
+ store.get_graph_embeddings_stmt = Mock()
+
+ received = []
+
+ async def receiver(msg):
+ received.append(msg)
+
+ await store.get_graph_embeddings("u", "d", receiver)
+
+ assert len(received) == 2
+ assert received[0].entities[0].entity.iri == "http://example.org/a"
+ assert received[0].entities[0].vector == [1.0]
+ assert received[1].entities[0].entity.iri == "http://example.org/b"
+ assert received[1].entities[0].vector == [2.0]
+
+
+class TestGetTriples:
+    """The sibling get_triples path also carries its payload in row[3]
+    (as 6-tuples rather than (entity, vector) pairs) and builds
+    Metadata the same way. Cover it for parity."""
+
+ @pytest.mark.asyncio
+ async def test_row_converts_to_triples(self):
+ # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
+ fake_row = (
+ None, None, None,
+ [
+ (
+ "http://example.org/alice", True,
+ "http://example.org/knows", True,
+ "http://example.org/bob", True,
+ ),
+ ],
+ )
+
+ store = _make_store()
+ store.cassandra = Mock()
+ store.cassandra.execute = Mock(return_value=[fake_row])
+ store.get_triples_stmt = Mock()
+
+ received = []
+
+ async def receiver(msg):
+ received.append(msg)
+
+ await store.get_triples("alice", "doc-1", receiver)
+
+ assert len(received) == 1
+ triples_msg = received[0]
+ assert isinstance(triples_msg, Triples)
+ assert isinstance(triples_msg.metadata, Metadata)
+ assert triples_msg.metadata.id == "doc-1"
+ assert triples_msg.metadata.user == "alice"
+
+ assert len(triples_msg.triples) == 1
+ t = triples_msg.triples[0]
+ assert isinstance(t, Triple)
+ assert t.s.iri == "http://example.org/alice"
+ assert t.p.iri == "http://example.org/knows"
+ assert t.o.iri == "http://example.org/bob"
diff --git a/tests/unit/test_translators/__init__.py b/tests/unit/test_translators/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py b/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py
new file mode 100644
index 00000000..72f4796b
--- /dev/null
+++ b/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py
@@ -0,0 +1,66 @@
+"""
+Round-trip unit tests for DocumentEmbeddingsTranslator.
+
+Regression coverage: a previous version of the decode side constructed
+ChunkEmbeddings(vectors=...) — the schema field is `vector` (singular),
+so any real DocumentEmbeddings message would crash on decode. The encode
+side already wrote `"vector"`, so encode→decode was asymmetric.
+"""
+
+import pytest
+
+from trustgraph.messaging.translators.document_loading import (
+ DocumentEmbeddingsTranslator,
+)
+from trustgraph.schema import (
+ DocumentEmbeddings,
+ ChunkEmbeddings,
+ Metadata,
+)
+
+
+@pytest.fixture
+def translator():
+ return DocumentEmbeddingsTranslator()
+
+
+@pytest.fixture
+def sample():
+ return DocumentEmbeddings(
+ metadata=Metadata(
+ id="doc-1",
+ root="",
+ user="alice",
+ collection="testcoll",
+ ),
+ chunks=[
+ ChunkEmbeddings(chunk_id="c1", vector=[0.1, 0.2, 0.3]),
+ ChunkEmbeddings(chunk_id="c2", vector=[0.4, 0.5, 0.6]),
+ ],
+ )
+
+
+class TestDocumentEmbeddingsTranslator:
+
+ def test_encode_uses_singular_vector_key(self, translator, sample):
+ encoded = translator.encode(sample)
+ chunks = encoded["chunks"]
+ assert all("vector" in c for c in chunks)
+ assert all("vectors" not in c for c in chunks)
+ assert chunks[0]["vector"] == [0.1, 0.2, 0.3]
+
+ def test_roundtrip_preserves_document_embeddings(self, translator, sample):
+ encoded = translator.encode(sample)
+ decoded = translator.decode(encoded)
+
+ assert isinstance(decoded, DocumentEmbeddings)
+ assert isinstance(decoded.metadata, Metadata)
+ assert decoded.metadata.id == "doc-1"
+ assert decoded.metadata.user == "alice"
+ assert decoded.metadata.collection == "testcoll"
+
+ assert len(decoded.chunks) == 2
+ assert decoded.chunks[0].chunk_id == "c1"
+ assert decoded.chunks[0].vector == [0.1, 0.2, 0.3]
+ assert decoded.chunks[1].chunk_id == "c2"
+ assert decoded.chunks[1].vector == [0.4, 0.5, 0.6]
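The decode-side fix these tests lock in amounts to one keyword, sketched here against the schema (not the translator's actual code):

    from trustgraph.schema import ChunkEmbeddings

    def decode_chunk(d):
        return ChunkEmbeddings(
            chunk_id=d.get("chunk_id"),
            # Singular `vector`: the old decode passed vectors=..., which
            # crashed on any real message even though encode was correct.
            vector=d.get("vector", []),
        )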
diff --git a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
new file mode 100644
index 00000000..57e7ae17
--- /dev/null
+++ b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
@@ -0,0 +1,153 @@
+"""
+Round-trip unit tests for KnowledgeRequestTranslator.
+
+Regression coverage: a previous version of the decode side constructed
+EntityEmbeddings(vectors=...) — the schema field is `vector` (singular),
+so any real graph-embeddings KnowledgeRequest would crash on the first
+message. The encode side already wrote `"vector"`, so encode→decode
+was asymmetric.
+
+These tests build a real KnowledgeRequest with graph-embeddings, encode
+it, decode the result, and assert the round-trip is lossless. They also
+exercise the triples path so any future schema drift in Metadata or
+Triples breaks the test.
+"""
+
+import pytest
+
+from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator
+from trustgraph.schema import (
+ KnowledgeRequest,
+ GraphEmbeddings,
+ EntityEmbeddings,
+ Triples,
+ Triple,
+ Metadata,
+ Term,
+ IRI,
+)
+
+
+def _term_iri(uri):
+ return Term(type=IRI, iri=uri)
+
+
+@pytest.fixture
+def translator():
+ return KnowledgeRequestTranslator()
+
+
+@pytest.fixture
+def graph_embeddings_request():
+ return KnowledgeRequest(
+ operation="put-kg-core",
+ user="alice",
+ id="doc-1",
+ flow="default",
+ collection="testcoll",
+ graph_embeddings=GraphEmbeddings(
+ metadata=Metadata(
+ id="doc-1",
+ root="",
+ user="alice",
+ collection="testcoll",
+ ),
+ entities=[
+ EntityEmbeddings(
+ entity=_term_iri("http://example.org/alice"),
+ vector=[0.1, 0.2, 0.3],
+ ),
+ EntityEmbeddings(
+ entity=_term_iri("http://example.org/bob"),
+ vector=[0.4, 0.5, 0.6],
+ ),
+ ],
+ ),
+ )
+
+
+@pytest.fixture
+def triples_request():
+ return KnowledgeRequest(
+ operation="put-kg-core",
+ user="alice",
+ id="doc-1",
+ flow="default",
+ collection="testcoll",
+ triples=Triples(
+ metadata=Metadata(
+ id="doc-1",
+ root="",
+ user="alice",
+ collection="testcoll",
+ ),
+ triples=[
+ Triple(
+ s=_term_iri("http://example.org/alice"),
+ p=_term_iri("http://example.org/knows"),
+ o=_term_iri("http://example.org/bob"),
+ ),
+ ],
+ ),
+ )
+
+
+class TestKnowledgeRequestTranslatorGraphEmbeddings:
+
+ def test_encode_produces_singular_vector_key(
+ self, translator, graph_embeddings_request,
+ ):
+ """The wire key must be `vector`, never `vectors`."""
+ encoded = translator.encode(graph_embeddings_request)
+ entities = encoded["graph-embeddings"]["entities"]
+ assert all("vector" in e for e in entities)
+ assert all("vectors" not in e for e in entities)
+ assert entities[0]["vector"] == [0.1, 0.2, 0.3]
+
+ def test_roundtrip_preserves_graph_embeddings(
+ self, translator, graph_embeddings_request,
+ ):
+ """encode -> decode must be lossless for the GE branch."""
+ encoded = translator.encode(graph_embeddings_request)
+ decoded = translator.decode(encoded)
+
+ assert isinstance(decoded, KnowledgeRequest)
+ assert decoded.operation == "put-kg-core"
+ assert decoded.user == "alice"
+ assert decoded.id == "doc-1"
+ assert decoded.flow == "default"
+ assert decoded.collection == "testcoll"
+
+ assert decoded.graph_embeddings is not None
+ ge = decoded.graph_embeddings
+ assert isinstance(ge, GraphEmbeddings)
+ assert isinstance(ge.metadata, Metadata)
+ assert ge.metadata.id == "doc-1"
+ assert ge.metadata.user == "alice"
+ assert ge.metadata.collection == "testcoll"
+
+ assert len(ge.entities) == 2
+ assert ge.entities[0].vector == [0.1, 0.2, 0.3]
+ assert ge.entities[1].vector == [0.4, 0.5, 0.6]
+ assert ge.entities[0].entity.iri == "http://example.org/alice"
+ assert ge.entities[1].entity.iri == "http://example.org/bob"
+
+
+class TestKnowledgeRequestTranslatorTriples:
+
+ def test_roundtrip_preserves_triples(self, translator, triples_request):
+ encoded = translator.encode(triples_request)
+ decoded = translator.decode(encoded)
+
+ assert isinstance(decoded, KnowledgeRequest)
+ assert decoded.triples is not None
+ assert isinstance(decoded.triples.metadata, Metadata)
+ assert decoded.triples.metadata.id == "doc-1"
+ assert decoded.triples.metadata.user == "alice"
+ assert decoded.triples.metadata.collection == "testcoll"
+
+ assert len(decoded.triples.triples) == 1
+ t = decoded.triples.triples[0]
+ assert t.s.iri == "http://example.org/alice"
+ assert t.p.iri == "http://example.org/knows"
+ assert t.o.iri == "http://example.org/bob"
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
index 985bcbf1..c8c676c9 100644
--- a/tests/utils/__init__.py
+++ b/tests/utils/__init__.py
@@ -9,7 +9,7 @@ from .streaming_assertions import (
assert_streaming_content_matches,
assert_no_empty_chunks,
assert_streaming_error_handled,
- assert_chunk_types_valid,
+ assert_message_types_valid,
assert_streaming_latency_acceptable,
assert_callback_invoked,
)
@@ -23,7 +23,7 @@ __all__ = [
"assert_streaming_content_matches",
"assert_no_empty_chunks",
"assert_streaming_error_handled",
- "assert_chunk_types_valid",
+ "assert_message_types_valid",
"assert_streaming_latency_acceptable",
"assert_callback_invoked",
]
diff --git a/tests/utils/streaming_assertions.py b/tests/utils/streaming_assertions.py
index cc9164ed..945bb031 100644
--- a/tests/utils/streaming_assertions.py
+++ b/tests/utils/streaming_assertions.py
@@ -20,14 +20,14 @@ def assert_streaming_chunks_valid(chunks: List[Any], min_chunks: int = 1):
assert all(chunk is not None for chunk in chunks), "All chunks should be non-None"
-def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "chunk_type"):
+def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "message_type"):
"""
Assert that streaming chunks follow an expected sequence.
Args:
chunks: List of chunk dictionaries
expected_sequence: Expected sequence of chunk types/values
- key: Dictionary key to check (default: "chunk_type")
+ key: Dictionary key to check (default: "message_type")
"""
actual_sequence = [chunk.get(key) for chunk in chunks if key in chunk]
assert actual_sequence == expected_sequence, \
@@ -39,7 +39,7 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
Assert that agent streaming chunks have valid structure.
Validates:
- - All chunks have chunk_type field
+ - All chunks have message_type field
- All chunks have content field
- All chunks have end_of_message field
- All chunks have end_of_dialog field
@@ -51,15 +51,15 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
assert len(chunks) > 0, "Expected at least one chunk"
for i, chunk in enumerate(chunks):
- assert "chunk_type" in chunk, f"Chunk {i} missing chunk_type"
+ assert "message_type" in chunk, f"Chunk {i} missing message_type"
assert "content" in chunk, f"Chunk {i} missing content"
assert "end_of_message" in chunk, f"Chunk {i} missing end_of_message"
assert "end_of_dialog" in chunk, f"Chunk {i} missing end_of_dialog"
- # Validate chunk_type values
+ # Validate message_type values
valid_types = ["thought", "action", "observation", "final-answer"]
- assert chunk["chunk_type"] in valid_types, \
- f"Invalid chunk_type '{chunk['chunk_type']}' at index {i}"
+ assert chunk["message_type"] in valid_types, \
+ f"Invalid message_type '{chunk['message_type']}' at index {i}"
# Last chunk should signal end of dialog
assert chunks[-1]["end_of_dialog"] is True, \
@@ -175,7 +175,7 @@ def assert_streaming_error_handled(chunks: List[Dict[str, Any]], error_flag: str
"Error chunk should have completion flag set to True"
-def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "chunk_type"):
+def assert_message_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "message_type"):
"""
Assert that all chunk types are from a valid set.
@@ -185,9 +185,9 @@ def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str
type_key: Dictionary key for chunk type
"""
for i, chunk in enumerate(chunks):
- chunk_type = chunk.get(type_key)
- assert chunk_type in valid_types, \
- f"Chunk {i} has invalid type '{chunk_type}', expected one of {valid_types}"
+ message_type = chunk.get(type_key)
+ assert message_type in valid_types, \
+ f"Chunk {i} has invalid type '{message_type}', expected one of {valid_types}"
def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0):
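A quick usage sketch of the renamed helper, with chunk dicts shaped like the agent streaming messages validated above:

    from tests.utils import assert_message_types_valid

    chunks = [
        {"message_type": "thought", "content": "Working..."},
        {"message_type": "final-answer", "content": "42"},
    ]

    # Passes: every chunk's message_type is in the valid set.
    assert_message_types_valid(
        chunks,
        valid_types=["thought", "action", "observation", "final-answer"],
    )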
diff --git a/trustgraph-base/pyproject.toml b/trustgraph-base/pyproject.toml
index 216ccbd6..4f1bce76 100644
--- a/trustgraph-base/pyproject.toml
+++ b/trustgraph-base/pyproject.toml
@@ -15,6 +15,7 @@ dependencies = [
"requests",
"python-logging-loki",
"pika",
+ "pyyaml",
]
classifiers = [
"Programming Language :: Python :: 3",
@@ -24,6 +25,9 @@ classifiers = [
[project.urls]
Homepage = "https://github.com/trustgraph-ai/trustgraph"
+[project.scripts]
+processor-group = "trustgraph.base.processor_group:run"
+
[tool.setuptools.packages.find]
include = ["trustgraph*"]
@@ -31,4 +35,4 @@ include = ["trustgraph*"]
"trustgraph.i18n.packs" = ["*.json"]
[tool.setuptools.dynamic]
-version = {attr = "trustgraph.base_version.__version__"}
\ No newline at end of file
+version = {attr = "trustgraph.base_version.__version__"}
diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py
index 8b703dc7..2f44aad0 100644
--- a/trustgraph-base/trustgraph/api/__init__.py
+++ b/trustgraph-base/trustgraph/api/__init__.py
@@ -107,6 +107,7 @@ from .types import (
AgentObservation,
AgentAnswer,
RAGChunk,
+ TextCompletionResult,
ProvenanceEvent,
)
@@ -185,6 +186,7 @@ __all__ = [
"AgentObservation",
"AgentAnswer",
"RAGChunk",
+ "TextCompletionResult",
"ProvenanceEvent",
# Exceptions
diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py
index 2ff37307..68899341 100644
--- a/trustgraph-base/trustgraph/api/async_flow.py
+++ b/trustgraph-base/trustgraph/api/async_flow.py
@@ -14,6 +14,8 @@ import aiohttp
import json
from typing import Optional, Dict, Any, List
+from . types import TextCompletionResult
+
from . exceptions import ProtocolException, ApplicationException
@@ -434,12 +436,11 @@ class AsyncFlowInstance:
return await self.request("agent", request_data)
- async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str:
+ async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> TextCompletionResult:
"""
Generate text completion (non-streaming).
Generates a text response from an LLM given a system prompt and user prompt.
- Returns the complete response text.
Note: This method does not support streaming. For streaming text generation,
use AsyncSocketFlowInstance.text_completion() instead.
@@ -450,19 +451,19 @@ class AsyncFlowInstance:
**kwargs: Additional service-specific parameters
Returns:
- str: Complete generated text response
+ TextCompletionResult: Result with text, in_token, out_token, model
Example:
```python
async_flow = await api.async_flow()
flow = async_flow.id("default")
- # Generate text
- response = await flow.text_completion(
+ result = await flow.text_completion(
system="You are a helpful assistant.",
prompt="Explain quantum computing in simple terms."
)
- print(response)
+ print(result.text)
+ print(f"Tokens: {result.in_token} in, {result.out_token} out")
```
"""
request_data = {
@@ -473,7 +474,12 @@ class AsyncFlowInstance:
request_data.update(kwargs)
result = await self.request("text-completion", request_data)
- return result.get("response", "")
+ return TextCompletionResult(
+ text=result.get("response", ""),
+ in_token=result.get("in_token"),
+ out_token=result.get("out_token"),
+ model=result.get("model"),
+ )
async def graph_rag(self, query: str, user: str, collection: str,
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py
index 7a239b07..6e5064ab 100644
--- a/trustgraph-base/trustgraph/api/async_socket_client.py
+++ b/trustgraph-base/trustgraph/api/async_socket_client.py
@@ -4,7 +4,7 @@ import asyncio
import websockets
from typing import Optional, Dict, Any, AsyncIterator, Union
-from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk
+from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, TextCompletionResult
from . exceptions import ProtocolException, ApplicationException
@@ -178,30 +178,32 @@ class AsyncSocketClient:
def _parse_chunk(self, resp: Dict[str, Any]):
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
- chunk_type = resp.get("chunk_type")
message_type = resp.get("message_type")
# Handle new GraphRAG message format with message_type
if message_type == "provenance":
return None
- if chunk_type == "thought":
+ if message_type == "thought":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
- elif chunk_type == "observation":
+ elif message_type == "observation":
return AgentObservation(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
- elif chunk_type == "answer" or chunk_type == "final-answer":
+ elif message_type == "answer" or message_type == "final-answer":
return AgentAnswer(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
- end_of_dialog=resp.get("end_of_dialog", False)
+ end_of_dialog=resp.get("end_of_dialog", False),
+ in_token=resp.get("in_token"),
+ out_token=resp.get("out_token"),
+ model=resp.get("model"),
)
- elif chunk_type == "action":
+ elif message_type == "action":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
@@ -211,7 +213,10 @@ class AsyncSocketClient:
return RAGChunk(
content=content,
end_of_stream=resp.get("end_of_stream", False),
- error=None
+ error=None,
+ in_token=resp.get("in_token"),
+ out_token=resp.get("out_token"),
+ model=resp.get("model"),
)
async def aclose(self):
@@ -269,7 +274,11 @@ class AsyncSocketFlowInstance:
return await self.client._send_request("agent", self.flow_id, request)
async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs):
- """Text completion with optional streaming"""
+ """Text completion with optional streaming.
+
+ Non-streaming: returns a TextCompletionResult with text and token counts.
+ Streaming: returns an async iterator of RAGChunk (with token counts on the final chunk).
+ """
request = {
"system": system,
"prompt": prompt,
@@ -281,13 +290,18 @@ class AsyncSocketFlowInstance:
return self._text_completion_streaming(request)
else:
result = await self.client._send_request("text-completion", self.flow_id, request)
- return result.get("response", "")
+ return TextCompletionResult(
+ text=result.get("response", ""),
+ in_token=result.get("in_token"),
+ out_token=result.get("out_token"),
+ model=result.get("model"),
+ )
async def _text_completion_streaming(self, request):
- """Helper for streaming text completion"""
+ """Helper for streaming text completion. Yields RAGChunk objects."""
async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request):
- if hasattr(chunk, 'content'):
- yield chunk.content
+ if isinstance(chunk, RAGChunk):
+ yield chunk
async def graph_rag(self, query: str, user: str, collection: str,
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
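A hedged consumer sketch for the new streaming contract (connection setup elided; `flow` is assumed to be an AsyncSocketFlowInstance): streaming now yields RAGChunk objects rather than bare strings, so callers read .content per chunk and take the token counts from the final chunk:

    async def stream_completion(flow):
        async for chunk in await flow.text_completion(
            system="You are a helpful assistant.",
            prompt="Explain quantum computing.",
            streaming=True,
        ):
            print(chunk.content, end="")
            if chunk.end_of_stream:
                print(f"\n{chunk.in_token} in / {chunk.out_token} out "
                      f"({chunk.model})")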
diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py
index 0aa55347..7ee32dad 100644
--- a/trustgraph-base/trustgraph/api/flow.py
+++ b/trustgraph-base/trustgraph/api/flow.py
@@ -11,7 +11,7 @@ import base64
from .. knowledge import hash, Uri, Literal, QuotedTriple
from .. schema import IRI, LITERAL, TRIPLE
-from . types import Triple
+from . types import Triple, TextCompletionResult
from . exceptions import ProtocolException
@@ -360,16 +360,17 @@ class FlowInstance:
prompt: User prompt/question
Returns:
- str: Generated response text
+ TextCompletionResult: Result with text, in_token, out_token, model
Example:
```python
flow = api.flow().id("default")
- response = flow.text_completion(
+ result = flow.text_completion(
system="You are a helpful assistant",
prompt="What is quantum computing?"
)
- print(response)
+ print(result.text)
+ print(f"Tokens: {result.in_token} in, {result.out_token} out")
```
"""
@@ -379,10 +380,17 @@ class FlowInstance:
"prompt": prompt
}
- return self.request(
+ result = self.request(
"service/text-completion",
input
- )["response"]
+ )
+
+ return TextCompletionResult(
+ text=result.get("response", ""),
+ in_token=result.get("in_token"),
+ out_token=result.get("out_token"),
+ model=result.get("model"),
+ )
def agent(self, question, user="trustgraph", state=None, group=None, history=None):
"""
@@ -498,10 +506,17 @@ class FlowInstance:
"edge-limit": edge_limit,
}
- return self.request(
+ result = self.request(
"service/graph-rag",
input
- )["response"]
+ )
+
+ return TextCompletionResult(
+ text=result.get("response", ""),
+ in_token=result.get("in_token"),
+ out_token=result.get("out_token"),
+ model=result.get("model"),
+ )
def document_rag(
self, query, user="trustgraph", collection="default",
@@ -543,10 +558,17 @@ class FlowInstance:
"doc-limit": doc_limit,
}
- return self.request(
+ result = self.request(
"service/document-rag",
input
- )["response"]
+ )
+
+ return TextCompletionResult(
+ text=result.get("response", ""),
+ in_token=result.get("in_token"),
+ out_token=result.get("out_token"),
+ model=result.get("model"),
+ )
def embeddings(self, texts):
"""
diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py
index b6ceba00..c590c9b4 100644
--- a/trustgraph-base/trustgraph/api/socket_client.py
+++ b/trustgraph-base/trustgraph/api/socket_client.py
@@ -14,7 +14,7 @@ import websockets
from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock
-from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent
+from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent, TextCompletionResult
from . exceptions import ProtocolException, raise_from_error_dict
@@ -360,41 +360,36 @@ class SocketClient:
def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]:
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
- chunk_type = resp.get("chunk_type")
message_type = resp.get("message_type")
- # Handle GraphRAG/DocRAG message format with message_type
if message_type == "explain":
if include_provenance:
return self._build_provenance_event(resp)
return None
- # Handle Agent message format with chunk_type="explain"
- if chunk_type == "explain":
- if include_provenance:
- return self._build_provenance_event(resp)
- return None
-
- if chunk_type == "thought":
+ if message_type == "thought":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
message_id=resp.get("message_id", ""),
)
- elif chunk_type == "observation":
+ elif message_type == "observation":
return AgentObservation(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
message_id=resp.get("message_id", ""),
)
- elif chunk_type == "answer" or chunk_type == "final-answer":
+ elif message_type == "answer" or message_type == "final-answer":
return AgentAnswer(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False),
end_of_dialog=resp.get("end_of_dialog", False),
message_id=resp.get("message_id", ""),
+ in_token=resp.get("in_token"),
+ out_token=resp.get("out_token"),
+ model=resp.get("model"),
)
- elif chunk_type == "action":
+ elif message_type == "action":
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
@@ -404,7 +399,10 @@ class SocketClient:
return RAGChunk(
content=content,
end_of_stream=resp.get("end_of_stream", False),
- error=None
+ error=None,
+ in_token=resp.get("in_token"),
+ out_token=resp.get("out_token"),
+ model=resp.get("model"),
)
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
@@ -543,8 +541,12 @@ class SocketFlowInstance:
streaming=True, include_provenance=True
)
- def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
- """Execute text completion with optional streaming."""
+ def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
+ """Execute text completion with optional streaming.
+
+ Non-streaming: returns a TextCompletionResult with text and token counts.
+ Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
+ """
request = {
"system": system,
"prompt": prompt,
@@ -557,12 +559,17 @@ class SocketFlowInstance:
if streaming:
return self._text_completion_generator(result)
else:
- return result.get("response", "")
+ return TextCompletionResult(
+ text=result.get("response", ""),
+ in_token=result.get("in_token"),
+ out_token=result.get("out_token"),
+ model=result.get("model"),
+ )
- def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
+ def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
for chunk in result:
- if hasattr(chunk, 'content'):
- yield chunk.content
+ if isinstance(chunk, RAGChunk):
+ yield chunk
def graph_rag(
self,
@@ -577,8 +584,12 @@ class SocketFlowInstance:
edge_limit: int = 25,
streaming: bool = False,
**kwargs: Any
- ) -> Union[str, Iterator[str]]:
- """Execute graph-based RAG query with optional streaming."""
+ ) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
+ """Execute graph-based RAG query with optional streaming.
+
+ Non-streaming: returns a TextCompletionResult with text and token counts.
+ Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
+ """
request = {
"query": query,
"user": user,
@@ -598,7 +609,12 @@ class SocketFlowInstance:
if streaming:
return self._rag_generator(result)
else:
- return result.get("response", "")
+ return TextCompletionResult(
+ text=result.get("response", ""),
+ in_token=result.get("in_token"),
+ out_token=result.get("out_token"),
+ model=result.get("model"),
+ )
def graph_rag_explain(
self,
@@ -642,8 +658,12 @@ class SocketFlowInstance:
doc_limit: int = 10,
streaming: bool = False,
**kwargs: Any
- ) -> Union[str, Iterator[str]]:
- """Execute document-based RAG query with optional streaming."""
+ ) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
+ """Execute document-based RAG query with optional streaming.
+
+ Non-streaming: returns a TextCompletionResult with text and token counts.
+ Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
+ """
request = {
"query": query,
"user": user,
@@ -658,7 +678,12 @@ class SocketFlowInstance:
if streaming:
return self._rag_generator(result)
else:
- return result.get("response", "")
+ return TextCompletionResult(
+ text=result.get("response", ""),
+ in_token=result.get("in_token"),
+ out_token=result.get("out_token"),
+ model=result.get("model"),
+ )
def document_rag_explain(
self,
@@ -684,10 +709,10 @@ class SocketFlowInstance:
streaming=True, include_provenance=True
)
- def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
+ def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
for chunk in result:
- if hasattr(chunk, 'content'):
- yield chunk.content
+ if isinstance(chunk, RAGChunk):
+ yield chunk
def prompt(
self,
@@ -695,8 +720,12 @@ class SocketFlowInstance:
variables: Dict[str, str],
streaming: bool = False,
**kwargs: Any
- ) -> Union[str, Iterator[str]]:
- """Execute a prompt template with optional streaming."""
+ ) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
+ """Execute a prompt template with optional streaming.
+
+ Non-streaming: returns a TextCompletionResult with text and token counts.
+ Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
+ """
request = {
"id": id,
"variables": variables,
@@ -709,7 +738,12 @@ class SocketFlowInstance:
if streaming:
return self._rag_generator(result)
else:
- return result.get("response", "")
+ return TextCompletionResult(
+ text=result.get("text", result.get("response", "")),
+ in_token=result.get("in_token"),
+ out_token=result.get("out_token"),
+ model=result.get("model"),
+ )
def graph_embeddings_query(
self,
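A sketch of the non-streaming socket prompt() path, assuming `flow` is a SocketFlowInstance and the template id is illustrative; note the text falls back from the "text" key to "response", depending on which one the service populates:

    result = flow.prompt(
        id="question",
        variables={"question": "What is TrustGraph?"},
    )
    print(result.text, result.model)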
diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py
index 55635584..f5987b0e 100644
--- a/trustgraph-base/trustgraph/api/types.py
+++ b/trustgraph-base/trustgraph/api/types.py
@@ -149,10 +149,10 @@ class AgentThought(StreamingChunk):
Attributes:
content: Agent's thought text
end_of_message: True if this completes the current thought
- chunk_type: Always "thought"
+ message_type: Always "thought"
message_id: Provenance URI of the entity being built
"""
- chunk_type: str = "thought"
+ message_type: str = "thought"
message_id: str = ""
@dataclasses.dataclass
@@ -166,10 +166,10 @@ class AgentObservation(StreamingChunk):
Attributes:
content: Observation text describing tool results
end_of_message: True if this completes the current observation
- chunk_type: Always "observation"
+ message_type: Always "observation"
message_id: Provenance URI of the entity being built
"""
- chunk_type: str = "observation"
+ message_type: str = "observation"
message_id: str = ""
@dataclasses.dataclass
@@ -184,11 +184,14 @@ class AgentAnswer(StreamingChunk):
content: Answer text
end_of_message: True if this completes the current answer segment
end_of_dialog: True if this completes the entire agent interaction
- chunk_type: Always "final-answer"
+ message_type: Always "final-answer"
"""
- chunk_type: str = "final-answer"
+ message_type: str = "final-answer"
end_of_dialog: bool = False
message_id: str = ""
+ in_token: Optional[int] = None
+ out_token: Optional[int] = None
+ model: Optional[str] = None
@dataclasses.dataclass
class RAGChunk(StreamingChunk):
@@ -202,11 +205,37 @@ class RAGChunk(StreamingChunk):
content: Generated text content
end_of_stream: True if this is the final chunk of the stream
error: Optional error information if an error occurred
- chunk_type: Always "rag"
+ in_token: Input token count (populated on the final chunk, None before)
+ out_token: Output token count (populated on the final chunk, None before)
+ model: Model identifier (populated on the final chunk, None before)
+ message_type: Always "rag"
"""
- chunk_type: str = "rag"
+ message_type: str = "rag"
end_of_stream: bool = False
error: Optional[Dict[str, str]] = None
+ in_token: Optional[int] = None
+ out_token: Optional[int] = None
+ model: Optional[str] = None
+
+@dataclasses.dataclass
+class TextCompletionResult:
+ """
+ Result from a text completion request.
+
+ Returned by text_completion() and the RAG helpers in non-streaming
+ mode; the streaming variants deliver RAGChunk objects through an
+ iterator instead. Where a streaming path does return this type,
+ text is None and the token counts are taken from the final chunk.
+ In non-streaming mode, text contains the complete response.
+
+ Attributes:
+ text: Complete response text (None in streaming mode)
+ in_token: Input token count (None if not available)
+ out_token: Output token count (None if not available)
+ model: Model identifier (None if not available)
+ """
+ text: Optional[str]
+ in_token: Optional[int] = None
+ out_token: Optional[int] = None
+ model: Optional[str] = None
@dataclasses.dataclass
class ProvenanceEvent:
diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py
index 24b6c1f0..ce17a585 100644
--- a/trustgraph-base/trustgraph/base/__init__.py
+++ b/trustgraph-base/trustgraph/base/__init__.py
@@ -18,8 +18,10 @@ from . librarian_client import LibrarianClient
from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec
-from . text_completion_client import TextCompletionClientSpec
-from . prompt_client import PromptClientSpec
+from . text_completion_client import (
+ TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
+)
+from . prompt_client import PromptClientSpec, PromptClient, PromptResult
from . triples_store_service import TriplesStoreService
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
from . document_embeddings_store_service import DocumentEmbeddingsStoreService
diff --git a/trustgraph-base/trustgraph/base/agent_client.py b/trustgraph-base/trustgraph/base/agent_client.py
index d73d03b9..393864fa 100644
--- a/trustgraph-base/trustgraph/base/agent_client.py
+++ b/trustgraph-base/trustgraph/base/agent_client.py
@@ -30,19 +30,19 @@ class AgentClient(RequestResponse):
raise RuntimeError(resp.error.message)
# Handle thought chunks
- if resp.chunk_type == 'thought':
+ if resp.message_type == 'thought':
if think:
await think(resp.content, resp.end_of_message)
return False # Continue receiving
# Handle observation chunks
- if resp.chunk_type == 'observation':
+ if resp.message_type == 'observation':
if observe:
await observe(resp.content, resp.end_of_message)
return False # Continue receiving
# Handle answer chunks
- if resp.chunk_type == 'answer':
+ if resp.message_type == 'answer':
if resp.content:
accumulated_answer.append(resp.content)
if answer_callback:
diff --git a/trustgraph-base/trustgraph/base/backend.py b/trustgraph-base/trustgraph/base/backend.py
index 9b9a42af..f0d6b397 100644
--- a/trustgraph-base/trustgraph/base/backend.py
+++ b/trustgraph-base/trustgraph/base/backend.py
@@ -58,6 +58,18 @@ class BackendProducer(Protocol):
class BackendConsumer(Protocol):
"""Protocol for backend-specific consumer."""
+ def ensure_connected(self) -> None:
+ """
+ Eagerly establish the underlying connection and bind the queue.
+
+ Backends that lazily connect on first receive() must implement this
+ so that callers can guarantee the consumer is fully bound — and
+ therefore able to receive responses — before any related request is
+ published. Backends that connect at construction time may make this
+ a no-op.
+ """
+ ...
+
def receive(self, timeout_millis: int = 2000) -> Message:
"""
Receive a message from the topic.
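An illustrative conformance sketch for the new protocol method (the transport calls are hypothetical): a backend that connects and binds at construction time satisfies ensure_connected with a no-op, mirroring the Pulsar implementation further down; a lazy backend would do its connect/bind work here instead:

    class EagerConsumer:
        def __init__(self, conn, topic):
            self._conn = conn       # already connected and bound here
            self._topic = topic

        def ensure_connected(self) -> None:
            pass                    # nothing to do: bound at construction

        def receive(self, timeout_millis: int = 2000):
            # Hypothetical transport call; raises TimeoutError on no message.
            return self._conn.poll(self._topic, timeout_millis)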
diff --git a/trustgraph-base/trustgraph/base/chunking_service.py b/trustgraph-base/trustgraph/base/chunking_service.py
index d4bf4cd4..4bd78428 100644
--- a/trustgraph-base/trustgraph/base/chunking_service.py
+++ b/trustgraph-base/trustgraph/base/chunking_service.py
@@ -88,14 +88,14 @@ class ChunkingService(FlowProcessor):
chunk_overlap = default_chunk_overlap
try:
- cs = flow.parameters.get("chunk-size")
+ cs = flow("chunk-size")
if cs is not None:
chunk_size = int(cs)
except Exception as e:
logger.warning(f"Could not parse chunk-size parameter: {e}")
try:
- co = flow.parameters.get("chunk-overlap")
+ co = flow("chunk-overlap")
if co is not None:
chunk_overlap = int(co)
except Exception as e:
diff --git a/trustgraph-base/trustgraph/base/logging.py b/trustgraph-base/trustgraph/base/logging.py
index 7bab6091..93cd8fa5 100644
--- a/trustgraph-base/trustgraph/base/logging.py
+++ b/trustgraph-base/trustgraph/base/logging.py
@@ -8,12 +8,51 @@ ensuring consistent log formats, levels, and command-line arguments.
Supports dual output to console and Loki for centralized log aggregation.
"""
+import contextvars
import logging
import logging.handlers
from queue import Queue
import os
+# The current processor id for this task context. Read by
+# _ProcessorIdFilter to stamp every LogRecord with its owning
+# processor, and read by logging_loki's emitter via record.tags
+# to label log lines in Loki. A ContextVar lets asyncio subtasks
+# inherit their parent supervisor's processor id automatically.
+current_processor_id = contextvars.ContextVar(
+ "current_processor_id", default="unknown"
+)
+
+
+def set_processor_id(pid):
+ """Set the processor id for the current task context.
+
+ All subsequent log records emitted from this task — and any
+ asyncio tasks spawned from it — will be tagged with this id
+ in the console format and in Loki labels.
+ """
+ current_processor_id.set(pid)
+
+
+class _ProcessorIdFilter(logging.Filter):
+ """Stamps every LogRecord with processor_id from the contextvar.
+
+ Attaches two fields to each record:
+ record.processor_id — used by the console format string
+ record.tags — merged into Loki labels by logging_loki's
+ emitter (it reads record.tags and combines
+ with the handler's static tags)
+ """
+
+ def filter(self, record):
+ pid = current_processor_id.get()
+ record.processor_id = pid
+ existing = getattr(record, "tags", None) or {}
+ record.tags = {**existing, "processor": pid}
+ return True
+
+
def add_logging_args(parser):
"""
Add standard logging arguments to an argument parser.
@@ -87,12 +126,15 @@ def setup_logging(args):
loki_url = args.get('loki_url', 'http://loki:3100/loki/api/v1/push')
loki_username = args.get('loki_username')
loki_password = args.get('loki_password')
- processor_id = args.get('id') # Processor identity (e.g., "config-svc", "text-completion")
try:
from logging_loki import LokiHandler
- # Create Loki handler with optional authentication and processor label
+ # Create Loki handler with optional authentication. The
+ # processor label is NOT baked in here — it's stamped onto
+ # each record by _ProcessorIdFilter reading the task-local
+ # contextvar, and logging_loki's emitter reads record.tags
+ # to build per-record Loki labels.
loki_handler_kwargs = {
'url': loki_url,
'version': "1",
@@ -101,10 +143,6 @@ def setup_logging(args):
if loki_username and loki_password:
loki_handler_kwargs['auth'] = (loki_username, loki_password)
- # Add processor label if available (for consistency with Prometheus metrics)
- if processor_id:
- loki_handler_kwargs['tags'] = {'processor': processor_id}
-
loki_handler = LokiHandler(**loki_handler_kwargs)
# Wrap in QueueHandler for non-blocking operation
@@ -133,23 +171,44 @@ def setup_logging(args):
print(f"WARNING: Failed to setup Loki logging: {e}")
print("Continuing with console-only logging")
- # Get processor ID for log formatting (use 'unknown' if not available)
- processor_id = args.get('id', 'unknown')
-
- # Configure logging with all handlers
- # Use processor ID as the primary identifier in logs
+ # Configure logging with all handlers. The processor id comes
+ # from _ProcessorIdFilter (via contextvar) and is injected into
+ # each record as record.processor_id. The format string reads
+ # that attribute on every emit.
logging.basicConfig(
level=getattr(logging, log_level.upper()),
- format=f'%(asctime)s - {processor_id} - %(levelname)s - %(message)s',
+ format='%(asctime)s - %(processor_id)s - %(levelname)s - %(message)s',
handlers=handlers,
force=True # Force reconfiguration if already configured
)
- # Prevent recursive logging from Loki's HTTP client
- if loki_enabled and queue_listener:
- # Disable urllib3 logging to prevent infinite loop
- logging.getLogger('urllib3').setLevel(logging.WARNING)
- logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING)
+ # Attach the processor-id filter to every handler so all records
+ # passing through any sink get stamped (console, queue→loki,
+ # future handlers). Filters on handlers run regardless of which
+ # logger originated the record, so logs from pika, cassandra,
+ # processor code, etc. all pass through it.
+ processor_filter = _ProcessorIdFilter()
+ for h in handlers:
+ h.addFilter(processor_filter)
+
+ # Seed the contextvar from --id if one was supplied. In group
+ # mode --id isn't present; the processor_group supervisor sets
+ # it per task. In standalone mode AsyncProcessor.launch provides
+ # it via argparse default.
+ if args.get('id'):
+ set_processor_id(args['id'])
+
+ # Silence noisy third-party library loggers. These emit INFO-level
+ # chatter (connection churn, channel open/close, driver warnings) that
+ # drowns the useful signal and can't be attributed to a specific
+ # processor anyway. WARNING and above still propagate.
+ for noisy in (
+ 'pika',
+ 'cassandra',
+ 'urllib3',
+ 'urllib3.connectionpool',
+ ):
+ logging.getLogger(noisy).setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.info(f"Logging configured with level: {log_level}")
diff --git a/trustgraph-base/trustgraph/base/processor_group.py b/trustgraph-base/trustgraph/base/processor_group.py
new file mode 100644
index 00000000..d27b82c4
--- /dev/null
+++ b/trustgraph-base/trustgraph/base/processor_group.py
@@ -0,0 +1,204 @@
+
+# Multi-processor group runner. Runs multiple AsyncProcessor descendants
+# as concurrent tasks inside a single process, sharing one event loop,
+# one Prometheus HTTP server, and one pub/sub backend pool.
+#
+# Intended for dev and resource-constrained deployments. Scale deployments
+# should continue to use per-processor endpoints.
+#
+# Group config is a YAML or JSON file with shape:
+#
+# processors:
+# - class: trustgraph.extract.kg.definitions.extract.Processor
+# params:
+# id: kg-extract-definitions
+# triples_batch_size: 1000
+# - class: trustgraph.chunking.recursive.Processor
+# params:
+# id: chunker-recursive
+#
+# Each entry's params are passed directly to the class constructor alongside
+# the shared taskgroup. Defaults live inside each processor class.
+
+import argparse
+import asyncio
+import importlib
+import json
+import logging
+import time
+
+from prometheus_client import start_http_server
+
+from . logging import add_logging_args, setup_logging, set_processor_id
+
+logger = logging.getLogger(__name__)
+
+
+def _load_config(path):
+ with open(path) as f:
+ text = f.read()
+ if path.endswith((".yaml", ".yml")):
+ import yaml
+ return yaml.safe_load(text)
+ return json.loads(text)
+
+
+def _resolve_class(dotted):
+ module_path, _, class_name = dotted.rpartition(".")
+ if not module_path:
+ raise ValueError(
+ f"Processor class must be a dotted path, got {dotted!r}"
+ )
+ module = importlib.import_module(module_path)
+ return getattr(module, class_name)
+
+
+RESTART_DELAY_SECONDS = 4
+
+
+async def _supervise(entry):
+ """Run one processor with its own nested TaskGroup, restarting on any
+ failure. Each processor is isolated from its siblings — a crash here
+ does not propagate to the outer group."""
+
+ pid = entry["params"]["id"]
+ class_path = entry["class"]
+
+ # Stamp the contextvar for this supervisor task. Every log
+ # record emitted from this task — and from any inner TaskGroup
+ # child created by the processor — inherits this id via
+ # contextvar propagation. Siblings in the outer group set
+ # their own id in their own task context and do not interfere.
+ set_processor_id(pid)
+
+ while True:
+
+ try:
+
+ async with asyncio.TaskGroup() as inner_tg:
+
+ cls = _resolve_class(class_path)
+ params = dict(entry.get("params", {}))
+ params["taskgroup"] = inner_tg
+
+ logger.info(f"Starting {class_path} as {pid}")
+
+ p = cls(**params)
+ await p.start()
+ inner_tg.create_task(p.run())
+
+ # Clean exit — processor's run() returned without raising.
+ # Treat as a transient shutdown and restart, matching the
+ # behaviour of per-container `restart: on-failure`.
+ logger.warning(
+ f"Processor {pid} exited cleanly, will restart"
+ )
+
+ except asyncio.CancelledError:
+ logger.info(f"Processor {pid} cancelled")
+ raise
+
+ except BaseExceptionGroup as eg:
+ for e in eg.exceptions:
+ logger.error(
+ f"Processor {pid} failure: {type(e).__name__}: {e}",
+ exc_info=e,
+ )
+
+ except Exception as e:
+ logger.error(
+ f"Processor {pid} failure: {type(e).__name__}: {e}",
+ exc_info=True,
+ )
+
+ logger.info(
+ f"Restarting {pid} in {RESTART_DELAY_SECONDS}s..."
+ )
+ await asyncio.sleep(RESTART_DELAY_SECONDS)
+
+
+async def run_group(config):
+
+ entries = config.get("processors", [])
+ if not entries:
+ raise RuntimeError("Group config has no processors")
+
+ seen_ids = set()
+ for entry in entries:
+ pid = entry.get("params", {}).get("id")
+ if pid is None:
+ raise RuntimeError(
+ f"Entry {entry.get('class')!r} missing params.id — "
+ f"required for metrics labelling"
+ )
+ if pid in seen_ids:
+ raise RuntimeError(f"Duplicate processor id {pid!r} in group")
+ seen_ids.add(pid)
+
+ async with asyncio.TaskGroup() as outer_tg:
+ for entry in entries:
+ outer_tg.create_task(_supervise(entry))
+
+
+def run():
+
+ parser = argparse.ArgumentParser(
+ prog="processor-group",
+ description="Run multiple processors as tasks in one process",
+ )
+
+ parser.add_argument(
+ "-c", "--config",
+ required=True,
+ help="Path to group config file (JSON or YAML)",
+ )
+
+ parser.add_argument(
+ "--metrics",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Metrics enabled (default: true)",
+ )
+
+ parser.add_argument(
+ "-P", "--metrics-port",
+ type=int,
+ default=8000,
+ help="Prometheus metrics port (default: 8000)",
+ )
+
+ add_logging_args(parser)
+
+ args = vars(parser.parse_args())
+
+ setup_logging(args)
+
+ config = _load_config(args["config"])
+
+ if args["metrics"]:
+ start_http_server(args["metrics_port"])
+
+ while True:
+
+ logger.info("Starting group...")
+
+ try:
+ asyncio.run(run_group(config))
+
+ except KeyboardInterrupt:
+ logger.info("Keyboard interrupt.")
+ return
+
+ except ExceptionGroup as e:
+ logger.error("Exception group:")
+ for se in e.exceptions:
+ logger.error(f" Type: {type(se)}")
+ logger.error(f" Exception: {se}", exc_info=se)
+
+ except Exception as e:
+ logger.error(f"Type: {type(e)}")
+ logger.error(f"Exception: {e}", exc_info=True)
+
+ logger.warning("Will retry...")
+ time.sleep(4)
+ logger.info("Retrying...")
diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py
index 6859a9f0..853e7e66 100644
--- a/trustgraph-base/trustgraph/base/prompt_client.py
+++ b/trustgraph-base/trustgraph/base/prompt_client.py
@@ -1,10 +1,22 @@
import json
import asyncio
+from dataclasses import dataclass
+from typing import Optional, Any
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import PromptRequest, PromptResponse
+@dataclass
+class PromptResult:
+ response_type: str # "text", "json", or "jsonl"
+ text: Optional[str] = None # populated for "text"
+ object: Any = None # populated for "json"
+ objects: Optional[list] = None # populated for "jsonl"
+ in_token: Optional[int] = None
+ out_token: Optional[int] = None
+ model: Optional[str] = None
+
class PromptClient(RequestResponse):
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
@@ -26,17 +38,40 @@ class PromptClient(RequestResponse):
if resp.error:
raise RuntimeError(resp.error.message)
- if resp.text: return resp.text
+ if resp.text:
+ return PromptResult(
+ response_type="text",
+ text=resp.text,
+ in_token=resp.in_token,
+ out_token=resp.out_token,
+ model=resp.model,
+ )
- return json.loads(resp.object)
+ parsed = json.loads(resp.object)
+
+ if isinstance(parsed, list):
+ return PromptResult(
+ response_type="jsonl",
+ objects=parsed,
+ in_token=resp.in_token,
+ out_token=resp.out_token,
+ model=resp.model,
+ )
+
+ return PromptResult(
+ response_type="json",
+ object=parsed,
+ in_token=resp.in_token,
+ out_token=resp.out_token,
+ model=resp.model,
+ )
else:
- last_text = ""
- last_object = None
+ last_resp = None
async def forward_chunks(resp):
- nonlocal last_text, last_object
+ nonlocal last_resp
if resp.error:
raise RuntimeError(resp.error.message)
@@ -44,14 +79,13 @@ class PromptClient(RequestResponse):
end_stream = getattr(resp, 'end_of_stream', False)
if resp.text is not None:
- last_text = resp.text
if chunk_callback:
if asyncio.iscoroutinefunction(chunk_callback):
await chunk_callback(resp.text, end_stream)
else:
chunk_callback(resp.text, end_stream)
- elif resp.object:
- last_object = resp.object
+
+ last_resp = resp
return end_stream
@@ -70,10 +104,36 @@ class PromptClient(RequestResponse):
timeout=timeout
)
- if last_text:
- return last_text
+ if last_resp is None:
+ return PromptResult(response_type="text")
- return json.loads(last_object) if last_object else None
+ if last_resp.object:
+ parsed = json.loads(last_resp.object)
+
+ if isinstance(parsed, list):
+ return PromptResult(
+ response_type="jsonl",
+ objects=parsed,
+ in_token=last_resp.in_token,
+ out_token=last_resp.out_token,
+ model=last_resp.model,
+ )
+
+ return PromptResult(
+ response_type="json",
+ object=parsed,
+ in_token=last_resp.in_token,
+ out_token=last_resp.out_token,
+ model=last_resp.model,
+ )
+
+ return PromptResult(
+ response_type="text",
+ text=last_resp.text,
+ in_token=last_resp.in_token,
+ out_token=last_resp.out_token,
+ model=last_resp.model,
+ )
async def extract_definitions(self, text, timeout=600):
return await self.prompt(
@@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec):
response_schema = PromptResponse,
impl = PromptClient,
)
-
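A consumer sketch for the new PromptResult, dispatching on response_type the way the decode above populates it (the template id and variables are illustrative):

    async def show(prompt_client, text):
        result = await prompt_client.prompt(
            "extract-definitions", {"text": text},
        )
        if result.response_type == "text":
            print(result.text)
        elif result.response_type == "jsonl":
            for obj in result.objects:
                print(obj)
        else:                       # "json"
            print(result.object)
        print(result.in_token, result.out_token, result.model)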
diff --git a/trustgraph-base/trustgraph/base/pulsar_backend.py b/trustgraph-base/trustgraph/base/pulsar_backend.py
index a567191e..6f125399 100644
--- a/trustgraph-base/trustgraph/base/pulsar_backend.py
+++ b/trustgraph-base/trustgraph/base/pulsar_backend.py
@@ -72,6 +72,16 @@ class PulsarBackendConsumer:
self._consumer = pulsar_consumer
self._schema_cls = schema_cls
+ def ensure_connected(self) -> None:
+ """No-op for Pulsar.
+
+ PulsarBackend.create_consumer() calls client.subscribe() which is
+ synchronous and returns a fully-subscribed consumer, so the
+ consumer is already ready by the time this object is constructed.
+ Defined for parity with the BackendConsumer protocol used by
+ Subscriber.start()'s readiness barrier."""
+ pass
+
def receive(self, timeout_millis: int = 2000) -> Message:
"""Receive a message. Raises TimeoutError if no message available."""
try:
diff --git a/trustgraph-base/trustgraph/base/rabbitmq_backend.py b/trustgraph-base/trustgraph/base/rabbitmq_backend.py
index 3fafcead..7de51a0a 100644
--- a/trustgraph-base/trustgraph/base/rabbitmq_backend.py
+++ b/trustgraph-base/trustgraph/base/rabbitmq_backend.py
@@ -214,16 +214,43 @@ class RabbitMQBackendConsumer:
and self._channel.is_open
)
+ def ensure_connected(self) -> None:
+ """Eagerly declare and bind the queue.
+
+ Without this, the queue is only declared lazily on the first
+ receive() call. For request/response with ephemeral per-subscriber
+ response queues that is a race: a request published before the
+ response queue is bound will have its reply silently dropped by
+ the broker. Subscriber.start() calls this so callers get a hard
+ readiness barrier."""
+ if not self._is_alive():
+ self._connect()
+
def receive(self, timeout_millis: int = 2000) -> Message:
- """Receive a message. Raises TimeoutError if none available."""
+ """Receive a message. Raises TimeoutError if none available.
+
+ Loop ordering matters: check _incoming at the TOP of each
+ iteration, not as the loop condition. process_data_events
+ may dispatch a message via the _on_message callback during
+ the pump; we must re-check _incoming on the next iteration
+ before giving up on the deadline. The previous control
+ flow (`while deadline: check; pump`) could lose a wakeup if
+ the pump consumed the remainder of the window — the
+ `while` check would fail before `_incoming` was re-read,
+ leaving a just-dispatched message stranded until the next
+ receive() call one full poll cycle later.
+ """
if not self._is_alive():
self._connect()
timeout_seconds = timeout_millis / 1000.0
deadline = time.monotonic() + timeout_seconds
- while time.monotonic() < deadline:
- # Check if a message was already delivered
+ while True:
+ # Check if a message has been dispatched to our queue.
+ # This catches both (a) messages dispatched before this
+ # receive() was called and (b) messages dispatched
+ # during the previous iteration's process_data_events.
try:
method, properties, body = self._incoming.get_nowait()
return RabbitMQMessage(
@@ -232,14 +259,16 @@ class RabbitMQBackendConsumer:
except queue.Empty:
pass
- # Drive pika's I/O — delivers messages and processes heartbeats
remaining = deadline - time.monotonic()
- if remaining > 0:
- self._connection.process_data_events(
- time_limit=min(0.1, remaining),
- )
+ if remaining <= 0:
+ raise TimeoutError("No message received within timeout")
- raise TimeoutError("No message received within timeout")
+ # Drive pika's I/O. Any messages delivered during this
+ # call land in _incoming via _on_message; the next
+ # iteration of this loop catches them at the top.
+ self._connection.process_data_events(
+ time_limit=min(0.1, remaining),
+ )
def acknowledge(self, message: Message) -> None:
if isinstance(message, RabbitMQMessage) and message._method:
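The fixed control flow generalises to any poll-plus-inbox consumer; a standalone restatement of the shape, checking the inbox at the top of every iteration so a message dispatched by the pump is never stranded behind an expired deadline:

    import queue
    import time

    def receive(inbox: queue.Queue, pump, timeout_millis=2000):
        deadline = time.monotonic() + timeout_millis / 1000.0
        while True:
            try:
                return inbox.get_nowait()   # catches pump deliveries first
            except queue.Empty:
                pass
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError("No message received within timeout")
            pump(min(0.1, remaining))       # may enqueue into inbox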
diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py
index 8c68e51c..82ffe8e2 100644
--- a/trustgraph-base/trustgraph/base/subscriber.py
+++ b/trustgraph-base/trustgraph/base/subscriber.py
@@ -41,14 +41,55 @@ class Subscriber:
self.consumer = None
self.executor = None
+ # Readiness barrier — completed by run() once the underlying
+ # backend consumer is fully connected and bound. start() awaits
+ # this so callers know any subsequently published request will
+ # have a queue ready to receive its response. Without this,
+ # ephemeral per-subscriber response queues (RabbitMQ auto-delete
+ # exclusive queues) would race the request and lose the reply.
+ # A Future is used (rather than an Event) so that a first-attempt
+ # connection failure can be propagated to start() as an exception.
+ self._ready = None # created in start() so we have a running loop
+
def __del__(self):
self.running = False
async def start(self):
+ self._ready = asyncio.get_running_loop().create_future()
self.task = asyncio.create_task(self.run())
+ # Block until run() signals readiness OR exits. The future
+ # carries the outcome of the first connect attempt: a value on
+ # success, an exception on first-attempt failure. If run() exits
+ # without ever signalling (e.g. cancelled, or a code path bug),
+ # we surface that as a clear RuntimeError rather than hanging
+ # forever waiting on the future.
+ ready_wait = asyncio.ensure_future(
+ asyncio.shield(self._ready)
+ )
+ try:
+ await asyncio.wait(
+ {self.task, ready_wait},
+ return_when=asyncio.FIRST_COMPLETED,
+ )
+ finally:
+ ready_wait.cancel()
+
+ if self._ready.done():
+ # Re-raise first-attempt connect failure if any.
+ self._ready.result()
+ return
+
+ # run() exited before _ready was settled. Propagate its exception
+ # if it had one, otherwise raise a generic readiness error.
+ if self.task.done() and self.task.exception() is not None:
+ raise self.task.exception()
+ raise RuntimeError(
+ "Subscriber.run() exited before signalling readiness"
+ )
+
async def stop(self):
"""Initiate graceful shutdown with draining"""
self.running = False
@@ -66,6 +107,7 @@ class Subscriber:
async def run(self):
"""Enhanced run method with integrated draining logic"""
+ first_attempt = True
while self.running or self.draining:
if self.metrics:
@@ -87,10 +129,27 @@ class Subscriber:
),
)
+ # Eagerly bind the queue. For backends that connect
+ # lazily on first receive (RabbitMQ), this is what
+ # closes the request/response setup race — without
+ # it the response queue is not bound until later and
+ # any reply published in the meantime is dropped.
+ await loop.run_in_executor(
+ self.executor,
+ lambda: self.consumer.ensure_connected(),
+ )
+
if self.metrics:
self.metrics.state("running")
logger.info("Subscriber running...")
+
+ # Signal start() that the consumer is ready. This must
+ # happen AFTER ensure_connected() above so callers can
+ # safely publish requests immediately after start() returns.
+ if first_attempt and not self._ready.done():
+ self._ready.set_result(None)
+ first_attempt = False
drain_end_time = None
while self.running or self.draining:
@@ -162,6 +221,16 @@ class Subscriber:
except Exception as e:
logger.error(f"Subscriber exception: {e}", exc_info=True)
+ # First-attempt connection failure: propagate to start()
+ # so the caller can decide what to do (retry, give up).
+ # Subsequent failures use the existing retry-with-backoff
+ # path so a long-lived subscriber survives broker blips.
+ if first_attempt and not self._ready.done():
+ self._ready.set_exception(e)
+ first_attempt = False
+ # Falls through into finally for cleanup, then the
+ # outer return below ends run() so start() unblocks.
+
finally:
# Negative acknowledge any pending messages
for msg in self.pending_acks.values():
@@ -193,6 +262,11 @@ class Subscriber:
if not self.running and not self.draining:
return
+ # If start() has already returned with an exception there is
+ # nothing more to do — exit run() rather than busy-retry.
+ if self._ready.done() and self._ready.exception() is not None:
+ return
+
# Sleep before retry
await asyncio.sleep(1)
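A caller-side sketch of what the barrier buys (the producer object is hypothetical): once start() returns, the response queue is bound, so a request published immediately afterwards cannot lose its reply, and a first-attempt broker failure surfaces as an exception instead of a hang:

    async def request_response(subscriber, producer, request):
        await subscriber.start()    # hard barrier: consumer queue is bound
        producer.send(request)      # reply can no longer be dropped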
diff --git a/trustgraph-base/trustgraph/base/text_completion_client.py b/trustgraph-base/trustgraph/base/text_completion_client.py
index ae93e22e..876d71df 100644
--- a/trustgraph-base/trustgraph/base/text_completion_client.py
+++ b/trustgraph-base/trustgraph/base/text_completion_client.py
@@ -1,47 +1,71 @@
+from dataclasses import dataclass
+from typing import Optional
+
from . request_response_spec import RequestResponse, RequestResponseSpec
from .. schema import TextCompletionRequest, TextCompletionResponse
+@dataclass
+class TextCompletionResult:
+ text: Optional[str]
+ in_token: Optional[int] = None
+ out_token: Optional[int] = None
+ model: Optional[str] = None
+
class TextCompletionClient(RequestResponse):
- async def text_completion(self, system, prompt, streaming=False, timeout=600):
- # If not streaming, use original behavior
- if not streaming:
- resp = await self.request(
- TextCompletionRequest(
- system = system, prompt = prompt, streaming = False
- ),
- timeout=timeout
- )
- if resp.error:
- raise RuntimeError(resp.error.message)
+ async def text_completion(self, system, prompt, timeout=600):
- return resp.response
-
- # For streaming: collect all chunks and return complete response
- full_response = ""
-
- async def collect_chunks(resp):
- nonlocal full_response
-
- if resp.error:
- raise RuntimeError(resp.error.message)
-
- if resp.response:
- full_response += resp.response
-
- # Return True when end_of_stream is reached
- return getattr(resp, 'end_of_stream', False)
-
- await self.request(
+ resp = await self.request(
TextCompletionRequest(
- system = system, prompt = prompt, streaming = True
+ system = system, prompt = prompt, streaming = False
),
- recipient=collect_chunks,
timeout=timeout
)
- return full_response
+ if resp.error:
+ raise RuntimeError(resp.error.message)
+
+ return TextCompletionResult(
+ text = resp.response,
+ in_token = resp.in_token,
+ out_token = resp.out_token,
+ model = resp.model,
+ )
+
+ async def text_completion_stream(
+ self, system, prompt, handler, timeout=600,
+ ):
+ """
+ Streaming text completion. `handler` is an async callable invoked
+ once per chunk with the chunk's TextCompletionResponse. Returns a
+ TextCompletionResult with text=None and token counts / model taken
+ from the end_of_stream message.
+ """
+
+ async def on_chunk(resp):
+
+ if resp.error:
+ raise RuntimeError(resp.error.message)
+
+ await handler(resp)
+
+ return getattr(resp, "end_of_stream", False)
+
+ final = await self.request(
+ TextCompletionRequest(
+ system = system, prompt = prompt, streaming = True
+ ),
+ recipient=on_chunk,
+ timeout=timeout,
+ )
+
+ return TextCompletionResult(
+ text = None,
+ in_token = final.in_token,
+ out_token = final.out_token,
+ model = final.model,
+ )
class TextCompletionClientSpec(RequestResponseSpec):
def __init__(
@@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec):
response_schema = TextCompletionResponse,
impl = TextCompletionClient,
)
-
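
A usage sketch for the reworked client above, written for use inside an async
function with client an already-constructed TextCompletionClient; the prompts
are illustrative. The plain call now returns a structured TextCompletionResult
rather than a bare string, and streaming moves to text_completion_stream()
with an explicit per-chunk handler:

    # Non-streaming: one request, one structured result.
    result = await client.text_completion(
        system="You are terse.", prompt="Name one RDF serialization.",
    )
    print(result.text, result.in_token, result.out_token, result.model)

    # Streaming: chunks reach the handler as they arrive; the returned
    # result has text=None plus usage from the end_of_stream message.
    async def on_chunk(resp):
        if resp.response:
            print(resp.response, end="", flush=True)

    usage = await client.text_completion_stream(
        system="You are terse.", prompt="Count to five.", handler=on_chunk,
    )
    print(f"\n[{usage.in_token} in / {usage.out_token} out, {usage.model}]")
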
diff --git a/trustgraph-base/trustgraph/clients/agent_client.py b/trustgraph-base/trustgraph/clients/agent_client.py
index 1cadbdd5..d17ea37a 100644
--- a/trustgraph-base/trustgraph/clients/agent_client.py
+++ b/trustgraph-base/trustgraph/clients/agent_client.py
@@ -58,23 +58,23 @@ class AgentClient(BaseClient):
def inspect(x):
# Handle errors
- if x.chunk_type == 'error' or x.error:
+ if x.message_type == 'error' or x.error:
if error_callback:
error_callback(x.content or (x.error.message if x.error else ""))
# Continue to check end_of_dialog
# Handle thought chunks
- elif x.chunk_type == 'thought':
+ elif x.message_type == 'thought':
if think:
think(x.content, x.end_of_message)
# Handle observation chunks
- elif x.chunk_type == 'observation':
+ elif x.message_type == 'observation':
if observe:
observe(x.content, x.end_of_message)
# Handle answer chunks
- elif x.chunk_type == 'answer':
+ elif x.message_type == 'answer':
if x.content:
accumulated_answer.append(x.content)
if answer_callback:
diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py
index 8cf525f5..7df59907 100644
--- a/trustgraph-base/trustgraph/messaging/translators/agent.py
+++ b/trustgraph-base/trustgraph/messaging/translators/agent.py
@@ -60,8 +60,8 @@ class AgentResponseTranslator(MessageTranslator):
def encode(self, obj: AgentResponse) -> Dict[str, Any]:
result = {}
- if obj.chunk_type:
- result["chunk_type"] = obj.chunk_type
+ if obj.message_type:
+ result["message_type"] = obj.message_type
if obj.content:
result["content"] = obj.content
result["end_of_message"] = getattr(obj, "end_of_message", False)
@@ -90,6 +90,13 @@ class AgentResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "code": obj.error.code}
+ if obj.in_token is not None:
+ result["in_token"] = obj.in_token
+ if obj.out_token is not None:
+ result["out_token"] = obj.out_token
+ if obj.model is not None:
+ result["model"] = obj.model
+
return result
def encode_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:
diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py
index 3e7062e2..df2aa3ba 100644
--- a/trustgraph-base/trustgraph/messaging/translators/document_loading.py
+++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py
@@ -151,7 +151,7 @@ class DocumentEmbeddingsTranslator(SendTranslator):
chunks = [
ChunkEmbeddings(
chunk_id=chunk["chunk_id"],
- vectors=chunk["vectors"]
+ vector=chunk["vector"]
)
for chunk in data.get("chunks", [])
]
diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py
index 2f11d75a..f819dc9c 100644
--- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py
+++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py
@@ -39,7 +39,7 @@ class KnowledgeRequestTranslator(MessageTranslator):
entities=[
EntityEmbeddings(
entity=self.value_translator.decode(ent["entity"]),
- vectors=ent["vectors"],
+ vector=ent["vector"],
)
for ent in data["graph-embeddings"]["entities"]
]
diff --git a/trustgraph-base/trustgraph/messaging/translators/prompt.py b/trustgraph-base/trustgraph/messaging/translators/prompt.py
index 4345e6fd..7f76bf4a 100644
--- a/trustgraph-base/trustgraph/messaging/translators/prompt.py
+++ b/trustgraph-base/trustgraph/messaging/translators/prompt.py
@@ -53,6 +53,13 @@ class PromptResponseTranslator(MessageTranslator):
# Always include end_of_stream flag for streaming support
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
+ if obj.in_token is not None:
+ result["in_token"] = obj.in_token
+ if obj.out_token is not None:
+ result["out_token"] = obj.out_token
+ if obj.model is not None:
+ result["model"] = obj.model
+
return result
def encode_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py
index 849bee94..e37b76e1 100644
--- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py
+++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py
@@ -74,6 +74,13 @@ class DocumentRagResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
+ if obj.in_token is not None:
+ result["in_token"] = obj.in_token
+ if obj.out_token is not None:
+ result["out_token"] = obj.out_token
+ if obj.model is not None:
+ result["model"] = obj.model
+
return result
def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
@@ -163,6 +170,13 @@ class GraphRagResponseTranslator(MessageTranslator):
if hasattr(obj, 'error') and obj.error and obj.error.message:
result["error"] = {"message": obj.error.message, "type": obj.error.type}
+ if obj.in_token is not None:
+ result["in_token"] = obj.in_token
+ if obj.out_token is not None:
+ result["out_token"] = obj.out_token
+ if obj.model is not None:
+ result["model"] = obj.model
+
return result
def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
diff --git a/trustgraph-base/trustgraph/messaging/translators/text_completion.py b/trustgraph-base/trustgraph/messaging/translators/text_completion.py
index 596ff744..62cc4afb 100644
--- a/trustgraph-base/trustgraph/messaging/translators/text_completion.py
+++ b/trustgraph-base/trustgraph/messaging/translators/text_completion.py
@@ -29,11 +29,11 @@ class TextCompletionResponseTranslator(MessageTranslator):
def encode(self, obj: TextCompletionResponse) -> Dict[str, Any]:
result = {"response": obj.response}
- if obj.in_token:
+ if obj.in_token is not None:
result["in_token"] = obj.in_token
- if obj.out_token:
+ if obj.out_token is not None:
result["out_token"] = obj.out_token
- if obj.model:
+ if obj.model is not None:
result["model"] = obj.model
# Always include end_of_stream flag for streaming support
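
The is-not-None switch above is behavioural, not cosmetic: under the old
truthiness test, a legitimate token count of zero (or an empty model string)
was silently dropped from the encoded message. A two-case illustration:

    result = {}
    in_token = 0                      # a real, reportable count of zero

    if in_token:                      # old test: 0 is falsy, field vanished
        result["in_token"] = in_token
    assert result == {}

    if in_token is not None:          # new test: zero survives encoding
        result["in_token"] = in_token
    assert result == {"in_token": 0}
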
diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py
index e6ce0a9e..051efc66 100644
--- a/trustgraph-base/trustgraph/provenance/__init__.py
+++ b/trustgraph-base/trustgraph/provenance/__init__.py
@@ -59,6 +59,7 @@ from . uris import (
agent_plan_uri,
agent_step_result_uri,
agent_synthesis_uri,
+ agent_pattern_decision_uri,
# Document RAG provenance URIs
docrag_question_uri,
docrag_grounding_uri,
@@ -102,6 +103,11 @@ from . namespaces import (
# Agent provenance predicates
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
+ TG_TOOL_CANDIDATE, TG_TERMINATION_REASON,
+ TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE,
+ TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR,
+ TG_IN_TOKEN, TG_OUT_TOKEN,
+ TG_ERROR_TYPE,
# Orchestrator entity types
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
# Document reference predicate
@@ -141,6 +147,7 @@ from . agent import (
agent_plan_triples,
agent_step_result_triples,
agent_synthesis_triples,
+ agent_pattern_decision_triples,
)
# Vocabulary bootstrap
@@ -182,6 +189,7 @@ __all__ = [
"agent_plan_uri",
"agent_step_result_uri",
"agent_synthesis_uri",
+ "agent_pattern_decision_uri",
# Document RAG provenance URIs
"docrag_question_uri",
"docrag_grounding_uri",
@@ -218,6 +226,11 @@ __all__ = [
# Agent provenance predicates
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION",
"TG_SUBAGENT_GOAL", "TG_PLAN_STEP",
+ "TG_TOOL_CANDIDATE", "TG_TERMINATION_REASON",
+ "TG_STEP_NUMBER", "TG_PATTERN_DECISION", "TG_PATTERN", "TG_TASK_TYPE",
+ "TG_LLM_DURATION_MS", "TG_TOOL_DURATION_MS", "TG_TOOL_ERROR",
+ "TG_IN_TOKEN", "TG_OUT_TOKEN",
+ "TG_ERROR_TYPE",
# Orchestrator entity types
"TG_DECOMPOSITION", "TG_FINDING", "TG_PLAN_TYPE", "TG_STEP_RESULT",
# Document reference predicate
@@ -249,6 +262,7 @@ __all__ = [
"agent_plan_triples",
"agent_step_result_triples",
"agent_synthesis_triples",
+ "agent_pattern_decision_triples",
# Utility
"set_graph",
# Vocabulary
diff --git a/trustgraph-base/trustgraph/provenance/agent.py b/trustgraph-base/trustgraph/provenance/agent.py
index 7203174e..5c4f0b2e 100644
--- a/trustgraph-base/trustgraph/provenance/agent.py
+++ b/trustgraph-base/trustgraph/provenance/agent.py
@@ -29,6 +29,11 @@ from . namespaces import (
TG_AGENT_QUESTION,
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
TG_SYNTHESIS, TG_SUBAGENT_GOAL, TG_PLAN_STEP,
+ TG_TOOL_CANDIDATE, TG_TERMINATION_REASON,
+ TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE,
+ TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR,
+ TG_ERROR_TYPE,
+ TG_IN_TOKEN, TG_OUT_TOKEN, TG_LLM_MODEL,
)
@@ -47,6 +52,17 @@ def _triple(s: str, p: str, o_term: Term) -> Triple:
return Triple(s=_iri(s), p=_iri(p), o=o_term)
+def _append_token_triples(triples, uri, in_token=None, out_token=None,
+ model=None):
+ """Append in_token/out_token/model triples when values are present."""
+ if in_token is not None:
+ triples.append(_triple(uri, TG_IN_TOKEN, _literal(str(in_token))))
+ if out_token is not None:
+ triples.append(_triple(uri, TG_OUT_TOKEN, _literal(str(out_token))))
+ if model is not None:
+ triples.append(_triple(uri, TG_LLM_MODEL, _literal(model)))
+
+
def agent_session_triples(
session_uri: str,
query: str,
@@ -90,6 +106,43 @@ def agent_session_triples(
return triples
+def agent_pattern_decision_triples(
+ uri: str,
+ session_uri: str,
+ pattern: str,
+ task_type: str = "",
+) -> List[Triple]:
+ """
+ Build triples for a meta-router pattern decision.
+
+ Creates:
+ - Entity declaration with tg:PatternDecision type
+ - wasDerivedFrom link to session
+ - Pattern and task type predicates
+
+ Args:
+ uri: URI of this decision (from agent_pattern_decision_uri)
+ session_uri: URI of the parent session
+ pattern: Selected execution pattern (e.g. "react", "plan-then-execute")
+ task_type: Identified task type (e.g. "general", "research")
+
+ Returns:
+ List of Triple objects
+ """
+ triples = [
+ _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)),
+ _triple(uri, RDF_TYPE, _iri(TG_PATTERN_DECISION)),
+ _triple(uri, RDFS_LABEL, _literal(f"Pattern: {pattern}")),
+ _triple(uri, TG_PATTERN, _literal(pattern)),
+ _triple(uri, PROV_WAS_DERIVED_FROM, _iri(session_uri)),
+ ]
+
+ if task_type:
+ triples.append(_triple(uri, TG_TASK_TYPE, _literal(task_type)))
+
+ return triples
+
+
def agent_iteration_triples(
iteration_uri: str,
question_uri: Optional[str] = None,
@@ -98,6 +151,12 @@ def agent_iteration_triples(
arguments: Dict[str, Any] = None,
thought_uri: Optional[str] = None,
thought_document_id: Optional[str] = None,
+ tool_candidates: Optional[List[str]] = None,
+ step_number: Optional[int] = None,
+ llm_duration_ms: Optional[int] = None,
+ in_token: Optional[int] = None,
+ out_token: Optional[int] = None,
+ model: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for one agent iteration (Analysis+ToolUse).
@@ -106,6 +165,7 @@ def agent_iteration_triples(
- Entity declaration with tg:Analysis and tg:ToolUse types
- wasDerivedFrom link to question (if first iteration) or previous
- Action and arguments metadata
+ - Tool candidates (names of tools visible to the LLM)
- Thought sub-entity (tg:Reflection, tg:Thought) with librarian document
Args:
@@ -116,6 +176,7 @@ def agent_iteration_triples(
arguments: Arguments passed to the tool (will be JSON-encoded)
thought_uri: URI for the thought sub-entity
thought_document_id: Document URI for thought in librarian
+ tool_candidates: List of tool names available to the LLM
Returns:
List of Triple objects
@@ -132,6 +193,23 @@ def agent_iteration_triples(
_triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))),
]
+ if tool_candidates:
+ for name in tool_candidates:
+ triples.append(
+ _triple(iteration_uri, TG_TOOL_CANDIDATE, _literal(name))
+ )
+
+ if step_number is not None:
+ triples.append(
+ _triple(iteration_uri, TG_STEP_NUMBER, _literal(str(step_number)))
+ )
+
+ if llm_duration_ms is not None:
+ triples.append(
+ _triple(iteration_uri, TG_LLM_DURATION_MS,
+ _literal(str(llm_duration_ms)))
+ )
+
if question_uri:
triples.append(
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(question_uri))
@@ -155,6 +233,8 @@ def agent_iteration_triples(
_triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id))
)
+ _append_token_triples(triples, iteration_uri, in_token, out_token, model)
+
return triples
@@ -162,6 +242,8 @@ def agent_observation_triples(
observation_uri: str,
iteration_uri: str,
document_id: Optional[str] = None,
+ tool_duration_ms: Optional[int] = None,
+ tool_error: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for an agent observation (standalone entity).
@@ -170,11 +252,15 @@ def agent_observation_triples(
- Entity declaration with prov:Entity and tg:Observation types
- wasDerivedFrom link to the iteration (Analysis+ToolUse)
- Document reference to librarian (if provided)
+ - Tool execution duration (if provided)
+ - Tool error message (if the tool failed)
Args:
observation_uri: URI of the observation entity
iteration_uri: URI of the iteration this observation derives from
document_id: Librarian document ID for the observation content
+ tool_duration_ms: Tool execution time in milliseconds
+ tool_error: Error message if the tool failed
Returns:
List of Triple objects
@@ -191,6 +277,20 @@ def agent_observation_triples(
_triple(observation_uri, TG_DOCUMENT, _iri(document_id))
)
+ if tool_duration_ms is not None:
+ triples.append(
+ _triple(observation_uri, TG_TOOL_DURATION_MS,
+ _literal(str(tool_duration_ms)))
+ )
+
+ if tool_error:
+ triples.append(
+ _triple(observation_uri, TG_TOOL_ERROR, _literal(tool_error))
+ )
+ triples.append(
+ _triple(observation_uri, RDF_TYPE, _iri(TG_ERROR_TYPE))
+ )
+
return triples
@@ -199,6 +299,10 @@ def agent_final_triples(
question_uri: Optional[str] = None,
previous_uri: Optional[str] = None,
document_id: Optional[str] = None,
+ termination_reason: Optional[str] = None,
+ in_token: Optional[int] = None,
+ out_token: Optional[int] = None,
+ model: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for an agent final answer (Conclusion).
@@ -208,12 +312,15 @@ def agent_final_triples(
- wasGeneratedBy link to question (if no iterations)
- wasDerivedFrom link to last iteration (if iterations exist)
- Document reference to librarian
+ - Termination reason (why the agent loop stopped)
Args:
final_uri: URI of the final answer (from agent_final_uri)
question_uri: URI of the question activity (if no iterations)
previous_uri: URI of the last iteration (if iterations exist)
document_id: Librarian document ID for the answer content
+ termination_reason: Why the loop stopped, e.g. "final-answer",
+ "max-iterations", "error"
Returns:
List of Triple objects
@@ -237,6 +344,14 @@ def agent_final_triples(
if document_id:
triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id)))
+ if termination_reason:
+ triples.append(
+ _triple(final_uri, TG_TERMINATION_REASON,
+ _literal(termination_reason))
+ )
+
+ _append_token_triples(triples, final_uri, in_token, out_token, model)
+
return triples
@@ -244,6 +359,9 @@ def agent_decomposition_triples(
uri: str,
session_uri: str,
goals: List[str],
+ in_token: Optional[int] = None,
+ out_token: Optional[int] = None,
+ model: Optional[str] = None,
) -> List[Triple]:
"""Build triples for a supervisor decomposition step."""
triples = [
@@ -255,6 +373,7 @@ def agent_decomposition_triples(
]
for goal in goals:
triples.append(_triple(uri, TG_SUBAGENT_GOAL, _literal(goal)))
+ _append_token_triples(triples, uri, in_token, out_token, model)
return triples
@@ -282,6 +401,9 @@ def agent_plan_triples(
uri: str,
session_uri: str,
steps: List[str],
+ in_token: Optional[int] = None,
+ out_token: Optional[int] = None,
+ model: Optional[str] = None,
) -> List[Triple]:
"""Build triples for a plan-then-execute plan."""
triples = [
@@ -293,6 +415,7 @@ def agent_plan_triples(
]
for step in steps:
triples.append(_triple(uri, TG_PLAN_STEP, _literal(step)))
+ _append_token_triples(triples, uri, in_token, out_token, model)
return triples
@@ -301,6 +424,9 @@ def agent_step_result_triples(
plan_uri: str,
goal: str,
document_id: Optional[str] = None,
+ in_token: Optional[int] = None,
+ out_token: Optional[int] = None,
+ model: Optional[str] = None,
) -> List[Triple]:
"""Build triples for a plan step result."""
triples = [
@@ -313,6 +439,7 @@ def agent_step_result_triples(
]
if document_id:
triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id)))
+ _append_token_triples(triples, uri, in_token, out_token, model)
return triples
@@ -320,6 +447,10 @@ def agent_synthesis_triples(
uri: str,
previous_uris,
document_id: Optional[str] = None,
+ termination_reason: Optional[str] = None,
+ in_token: Optional[int] = None,
+ out_token: Optional[int] = None,
+ model: Optional[str] = None,
) -> List[Triple]:
"""Build triples for a synthesis answer.
@@ -327,6 +458,8 @@ def agent_synthesis_triples(
uri: URI of the synthesis entity
previous_uris: Single URI string or list of URIs to derive from
document_id: Librarian document ID for the answer content
+ termination_reason: Why the agent loop stopped
+ in_token/out_token/model: Token usage for the synthesis LLM call
"""
triples = [
_triple(uri, RDF_TYPE, _iri(PROV_ENTITY)),
@@ -342,4 +475,12 @@ def agent_synthesis_triples(
if document_id:
triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id)))
+
+ if termination_reason:
+ triples.append(
+ _triple(uri, TG_TERMINATION_REASON, _literal(termination_reason))
+ )
+
+ _append_token_triples(triples, uri, in_token, out_token, model)
+
return triples
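
A sketch of the enriched final-answer provenance, assuming the package-level
re-exports shown elsewhere in this diff; the URIs, token counts and model
name are placeholders:

    from trustgraph.provenance import agent_final_triples

    triples = agent_final_triples(
        final_uri="urn:trustgraph:agent:sess-1/final",
        previous_uri="urn:trustgraph:agent:sess-1/iteration/3",
        document_id="urn:trustgraph:doc:abc",   # placeholder librarian doc
        termination_reason="final-answer",
        in_token=812, out_token=64, model="example-model",
    )
    for t in triples:                           # Triple(s=..., p=..., o=...)
        print(t.s, t.p, t.o)
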
diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py
index 9e7fbb2d..0b14f1b9 100644
--- a/trustgraph-base/trustgraph/provenance/namespaces.py
+++ b/trustgraph-base/trustgraph/provenance/namespaces.py
@@ -119,6 +119,18 @@ TG_ARGUMENTS = TG + "arguments"
TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity
TG_SUBAGENT_GOAL = TG + "subagentGoal" # Goal string on Decomposition/Finding
TG_PLAN_STEP = TG + "planStep" # Step goal string on Plan/StepResult
+TG_TOOL_CANDIDATE = TG + "toolCandidate" # Tool name on Analysis events
+TG_TERMINATION_REASON = TG + "terminationReason" # Why the agent loop stopped
+TG_STEP_NUMBER = TG + "stepNumber" # Explicit step counter on iteration events
+TG_PATTERN_DECISION = TG + "PatternDecision" # Meta-router routing decision entity type
+TG_PATTERN = TG + "pattern" # Selected execution pattern
+TG_TASK_TYPE = TG + "taskType" # Identified task type
+TG_LLM_DURATION_MS = TG + "llmDurationMs" # LLM call duration in milliseconds
+TG_TOOL_DURATION_MS = TG + "toolDurationMs" # Tool execution duration in milliseconds
+TG_TOOL_ERROR = TG + "toolError" # Error message from a failed tool execution
+TG_ERROR_TYPE = TG + "Error" # Mixin type for failure events
+TG_IN_TOKEN = TG + "inToken" # Input token count for an LLM call
+TG_OUT_TOKEN = TG + "outToken" # Output token count for an LLM call
# Named graph URIs for RDF datasets
# These separate different types of data while keeping them in the same collection
diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py
index 920a3482..8bdfc2cb 100644
--- a/trustgraph-base/trustgraph/provenance/triples.py
+++ b/trustgraph-base/trustgraph/provenance/triples.py
@@ -34,6 +34,8 @@ from . namespaces import (
TG_ANSWER_TYPE,
# Question subtypes
TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
+ # Token usage
+    TG_IN_TOKEN, TG_OUT_TOKEN, TG_LLM_MODEL,
)
from . uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri
@@ -74,6 +76,17 @@ def _triple(s: str, p: str, o_term: Term) -> Triple:
return Triple(s=_iri(s), p=_iri(p), o=o_term)
+def _append_token_triples(triples, uri, in_token=None, out_token=None,
+ model=None):
+ """Append in_token/out_token/model triples when values are present."""
+ if in_token is not None:
+ triples.append(_triple(uri, TG_IN_TOKEN, _literal(str(in_token))))
+ if out_token is not None:
+ triples.append(_triple(uri, TG_OUT_TOKEN, _literal(str(out_token))))
+ if model is not None:
+ triples.append(_triple(uri, TG_LLM_MODEL, _literal(model)))
+
+
def document_triples(
doc_uri: str,
title: Optional[str] = None,
@@ -396,6 +409,9 @@ def grounding_triples(
grounding_uri: str,
question_uri: str,
concepts: List[str],
+ in_token: Optional[int] = None,
+ out_token: Optional[int] = None,
+ model: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for a grounding entity (concept decomposition of query).
@@ -423,6 +439,8 @@ def grounding_triples(
for concept in concepts:
triples.append(_triple(grounding_uri, TG_CONCEPT, _literal(concept)))
+ _append_token_triples(triples, grounding_uri, in_token, out_token, model)
+
return triples
@@ -485,6 +503,9 @@ def focus_triples(
exploration_uri: str,
selected_edges_with_reasoning: List[dict],
session_id: str = "",
+ in_token: Optional[int] = None,
+ out_token: Optional[int] = None,
+ model: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for a focus entity (selected edges with reasoning).
@@ -543,6 +564,8 @@ def focus_triples(
_triple(edge_sel_uri, TG_REASONING, _literal(reasoning))
)
+ _append_token_triples(triples, focus_uri, in_token, out_token, model)
+
return triples
@@ -550,6 +573,9 @@ def synthesis_triples(
synthesis_uri: str,
focus_uri: str,
document_id: Optional[str] = None,
+ in_token: Optional[int] = None,
+ out_token: Optional[int] = None,
+ model: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for a synthesis entity (final answer).
@@ -578,6 +604,8 @@ def synthesis_triples(
if document_id:
triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id)))
+ _append_token_triples(triples, synthesis_uri, in_token, out_token, model)
+
return triples
@@ -674,6 +702,9 @@ def docrag_synthesis_triples(
synthesis_uri: str,
exploration_uri: str,
document_id: Optional[str] = None,
+ in_token: Optional[int] = None,
+ out_token: Optional[int] = None,
+ model: Optional[str] = None,
) -> List[Triple]:
"""
Build triples for a document RAG synthesis entity (final answer).
@@ -702,4 +733,6 @@ def docrag_synthesis_triples(
if document_id:
triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id)))
+ _append_token_triples(triples, synthesis_uri, in_token, out_token, model)
+
return triples
diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py
index a3aadef6..a26ac867 100644
--- a/trustgraph-base/trustgraph/provenance/uris.py
+++ b/trustgraph-base/trustgraph/provenance/uris.py
@@ -259,6 +259,11 @@ def agent_synthesis_uri(session_id: str) -> str:
return f"urn:trustgraph:agent:{session_id}/synthesis"
+def agent_pattern_decision_uri(session_id: str) -> str:
+ """Generate URI for a meta-router pattern decision."""
+ return f"urn:trustgraph:agent:{session_id}/pattern-decision"
+
+
# Document RAG provenance URIs
# These URIs use the urn:trustgraph:docrag: namespace to distinguish
# document RAG provenance from graph RAG provenance
diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py
index fbc0101c..cd4a2b45 100644
--- a/trustgraph-base/trustgraph/schema/services/agent.py
+++ b/trustgraph-base/trustgraph/schema/services/agent.py
@@ -51,8 +51,8 @@ class AgentRequest:
@dataclass
class AgentResponse:
# Streaming-first design
- chunk_type: str = "" # "thought", "action", "observation", "answer", "explain", "error"
- content: str = "" # The actual content (interpretation depends on chunk_type)
+ message_type: str = "" # "thought", "action", "observation", "answer", "explain", "error"
+ content: str = "" # The actual content (interpretation depends on message_type)
end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete
end_of_dialog: bool = False # Entire agent dialog is complete
@@ -66,5 +66,10 @@ class AgentResponse:
error: Error | None = None
+ # Token usage (populated on end_of_dialog message)
+ in_token: int | None = None
+ out_token: int | None = None
+ model: str | None = None
+
############################################################################
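
For consumers of the agent stream, the new fields ride only on the
end_of_dialog message. A hedged sketch, assuming responses is an async
iterator of AgentResponse inside an async function:

    async for resp in responses:
        if resp.message_type == "answer" and resp.content:
            print(resp.content, end="", flush=True)
        if resp.end_of_dialog:
            print(f"\n[{resp.in_token} in / {resp.out_token} out, "
                  f"{resp.model}]")
            break
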
diff --git a/trustgraph-base/trustgraph/schema/services/llm.py b/trustgraph-base/trustgraph/schema/services/llm.py
index 0fd6ab90..89c0cd54 100644
--- a/trustgraph-base/trustgraph/schema/services/llm.py
+++ b/trustgraph-base/trustgraph/schema/services/llm.py
@@ -17,9 +17,9 @@ class TextCompletionRequest:
class TextCompletionResponse:
error: Error | None = None
response: str = ""
- in_token: int = 0
- out_token: int = 0
- model: str = ""
+ in_token: int | None = None
+ out_token: int | None = None
+ model: str | None = None
end_of_stream: bool = False # Indicates final message in stream
############################################################################
diff --git a/trustgraph-base/trustgraph/schema/services/prompt.py b/trustgraph-base/trustgraph/schema/services/prompt.py
index f7388102..1696790b 100644
--- a/trustgraph-base/trustgraph/schema/services/prompt.py
+++ b/trustgraph-base/trustgraph/schema/services/prompt.py
@@ -41,4 +41,9 @@ class PromptResponse:
# Indicates final message in stream
end_of_stream: bool = False
+ # Token usage from the underlying text completion
+ in_token: int | None = None
+ out_token: int | None = None
+ model: str | None = None
+
############################################################################
\ No newline at end of file
diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py
index 4b17733d..a1af9170 100644
--- a/trustgraph-base/trustgraph/schema/services/retrieval.py
+++ b/trustgraph-base/trustgraph/schema/services/retrieval.py
@@ -29,6 +29,9 @@ class GraphRagResponse:
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete
+ in_token: int | None = None
+ out_token: int | None = None
+ model: str | None = None
############################################################################
@@ -52,3 +55,6 @@ class DocumentRagResponse:
explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step
message_type: str = "" # "chunk" or "explain"
end_of_session: bool = False # Entire session complete
+ in_token: int | None = None
+ out_token: int | None = None
+ model: str | None = None
diff --git a/trustgraph-bedrock/pyproject.toml b/trustgraph-bedrock/pyproject.toml
index 6dd017c7..2d65461b 100644
--- a/trustgraph-bedrock/pyproject.toml
+++ b/trustgraph-bedrock/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
- "trustgraph-base>=2.2,<2.3",
+ "trustgraph-base>=2.3,<2.4",
"pulsar-client",
"prometheus-client",
"boto3",
diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml
index 2b111cae..0151fef4 100644
--- a/trustgraph-cli/pyproject.toml
+++ b/trustgraph-cli/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
- "trustgraph-base>=2.2,<2.3",
+ "trustgraph-base>=2.3,<2.4",
"requests",
"pulsar-client",
"aiohttp",
diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py
index 026286d0..b379c2df 100644
--- a/trustgraph-cli/trustgraph/cli/invoke_agent.py
+++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py
@@ -126,7 +126,7 @@ def question_explainable(
try:
# Track last chunk type for formatting
- last_chunk_type = None
+ last_message_type = None
current_outputter = None
# Stream agent with explainability - process events as they arrive
@@ -138,7 +138,7 @@ def question_explainable(
group=group,
):
if isinstance(item, AgentThought):
- if last_chunk_type != "thought":
+ if last_message_type != "thought":
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
@@ -146,7 +146,7 @@ def question_explainable(
if verbose:
current_outputter = Outputter(width=78, prefix="\U0001f914 ")
current_outputter.__enter__()
- last_chunk_type = "thought"
+ last_message_type = "thought"
if current_outputter:
current_outputter.output(item.content)
if current_outputter.word_buffer:
@@ -155,7 +155,7 @@ def question_explainable(
current_outputter.word_buffer = ""
elif isinstance(item, AgentObservation):
- if last_chunk_type != "observation":
+ if last_message_type != "observation":
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
@@ -163,7 +163,7 @@ def question_explainable(
if verbose:
current_outputter = Outputter(width=78, prefix="\U0001f4a1 ")
current_outputter.__enter__()
- last_chunk_type = "observation"
+ last_message_type = "observation"
if current_outputter:
current_outputter.output(item.content)
if current_outputter.word_buffer:
@@ -172,12 +172,12 @@ def question_explainable(
current_outputter.word_buffer = ""
elif isinstance(item, AgentAnswer):
- if last_chunk_type != "answer":
+ if last_message_type != "answer":
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
print()
- last_chunk_type = "answer"
+ last_message_type = "answer"
# Print answer content directly
print(item.content, end="", flush=True)
@@ -261,7 +261,7 @@ def question_explainable(
current_outputter = None
# Final newline if we ended with answer
- if last_chunk_type == "answer":
+ if last_message_type == "answer":
print()
finally:
@@ -272,7 +272,8 @@ def question(
url, question, flow_id, user, collection,
plan=None, state=None, group=None, pattern=None,
verbose=False, streaming=True,
- token=None, explainable=False, debug=False
+ token=None, explainable=False, debug=False,
+ show_usage=False
):
# Explainable mode uses the API to capture and process provenance events
if explainable:
@@ -321,15 +322,16 @@ def question(
# Handle streaming response
if streaming:
# Track last chunk type and current outputter for streaming
- last_chunk_type = None
+ last_message_type = None
current_outputter = None
+ last_answer_chunk = None
for chunk in response:
- chunk_type = chunk.chunk_type
+ message_type = chunk.message_type
content = chunk.content
# Check if we're switching to a new message type
- if last_chunk_type != chunk_type:
+ if last_message_type != message_type:
# Close previous outputter if exists
if current_outputter:
current_outputter.__exit__(None, None, None)
@@ -337,15 +339,15 @@ def question(
print() # Blank line between message types
# Create new outputter for new message type
- if chunk_type == "thought" and verbose:
+ if message_type == "thought" and verbose:
current_outputter = Outputter(width=78, prefix="\U0001f914 ")
current_outputter.__enter__()
- elif chunk_type == "observation" and verbose:
+ elif message_type == "observation" and verbose:
current_outputter = Outputter(width=78, prefix="\U0001f4a1 ")
current_outputter.__enter__()
# For answer, don't use Outputter - just print as-is
- last_chunk_type = chunk_type
+ last_message_type = message_type
# Output the chunk
if current_outputter:
@@ -355,33 +357,42 @@ def question(
print(current_outputter.word_buffer, end="", flush=True)
current_outputter.column += len(current_outputter.word_buffer)
current_outputter.word_buffer = ""
- elif chunk_type == "final-answer":
+ elif message_type == "final-answer":
print(content, end="", flush=True)
+ last_answer_chunk = chunk
# Close any remaining outputter
if current_outputter:
current_outputter.__exit__(None, None, None)
current_outputter = None
# Add final newline if we were outputting answer
- elif last_chunk_type == "final-answer":
+ elif last_message_type == "final-answer":
print()
+ if show_usage and last_answer_chunk:
+ print(
+ f"Input tokens: {last_answer_chunk.in_token} "
+ f"Output tokens: {last_answer_chunk.out_token} "
+ f"Model: {last_answer_chunk.model}",
+ file=sys.stderr,
+ )
+
else:
# Non-streaming response - but agents use multipart messaging
# so we iterate through the chunks (which are complete messages, not text chunks)
for chunk in response:
# Display thoughts if verbose
- if chunk.chunk_type == "thought" and verbose:
+ if chunk.message_type == "thought" and verbose:
output(wrap(chunk.content), "\U0001f914 ")
print()
# Display observations if verbose
- elif chunk.chunk_type == "observation" and verbose:
+ elif chunk.message_type == "observation" and verbose:
output(wrap(chunk.content), "\U0001f4a1 ")
print()
# Display answer
- elif chunk.chunk_type == "final-answer" or chunk.chunk_type == "answer":
+ elif chunk.message_type == "final-answer" or chunk.message_type == "answer":
print(chunk.content)
finally:
@@ -477,6 +488,12 @@ def main():
help='Show debug output for troubleshooting'
)
+ parser.add_argument(
+ '--show-usage',
+ action='store_true',
+ help='Show token usage and model on stderr'
+ )
+
args = parser.parse_args()
try:
@@ -496,6 +513,7 @@ def main():
token = args.token,
explainable = args.explainable,
debug = args.debug,
+ show_usage = args.show_usage,
)
except Exception as e:
diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py
index 066b92f4..d566f51d 100644
--- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py
+++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py
@@ -99,7 +99,8 @@ def question_explainable(
def question(
url, flow_id, question_text, user, collection, doc_limit,
- streaming=True, token=None, explainable=False, debug=False
+ streaming=True, token=None, explainable=False, debug=False,
+ show_usage=False
):
# Explainable mode uses the API to capture and process provenance events
if explainable:
@@ -133,22 +134,40 @@ def question(
)
# Stream output
+ last_chunk = None
for chunk in response:
- print(chunk, end="", flush=True)
+ print(chunk.content, end="", flush=True)
+ last_chunk = chunk
print() # Final newline
+ if show_usage and last_chunk:
+ print(
+ f"Input tokens: {last_chunk.in_token} "
+ f"Output tokens: {last_chunk.out_token} "
+ f"Model: {last_chunk.model}",
+ file=sys.stderr,
+ )
+
finally:
socket.close()
else:
# Use REST API for non-streaming
flow = api.flow().id(flow_id)
- resp = flow.document_rag(
+ result = flow.document_rag(
query=question_text,
user=user,
collection=collection,
doc_limit=doc_limit,
)
- print(resp)
+ print(result.text)
+
+ if show_usage:
+ print(
+ f"Input tokens: {result.in_token} "
+ f"Output tokens: {result.out_token} "
+ f"Model: {result.model}",
+ file=sys.stderr,
+ )
def main():
@@ -219,6 +238,12 @@ def main():
help='Show debug output for troubleshooting'
)
+ parser.add_argument(
+ '--show-usage',
+ action='store_true',
+ help='Show token usage and model on stderr'
+ )
+
args = parser.parse_args()
try:
@@ -234,6 +259,7 @@ def main():
token=args.token,
explainable=args.explainable,
debug=args.debug,
+ show_usage=args.show_usage,
)
except Exception as e:
diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py
index 230cc54b..c9efe54d 100644
--- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py
+++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py
@@ -753,7 +753,7 @@ def question(
url, flow_id, question, user, collection, entity_limit, triple_limit,
max_subgraph_size, max_path_length, edge_score_limit=50,
edge_limit=25, streaming=True, token=None,
- explainable=False, debug=False
+ explainable=False, debug=False, show_usage=False
):
# Explainable mode uses the API to capture and process provenance events
@@ -798,16 +798,26 @@ def question(
)
# Stream output
+ last_chunk = None
for chunk in response:
- print(chunk, end="", flush=True)
+ print(chunk.content, end="", flush=True)
+ last_chunk = chunk
print() # Final newline
+ if show_usage and last_chunk:
+ print(
+ f"Input tokens: {last_chunk.in_token} "
+ f"Output tokens: {last_chunk.out_token} "
+ f"Model: {last_chunk.model}",
+ file=sys.stderr,
+ )
+
finally:
socket.close()
else:
# Use REST API for non-streaming
flow = api.flow().id(flow_id)
- resp = flow.graph_rag(
+ result = flow.graph_rag(
query=question,
user=user,
collection=collection,
@@ -818,7 +828,15 @@ def question(
edge_score_limit=edge_score_limit,
edge_limit=edge_limit,
)
- print(resp)
+ print(result.text)
+
+ if show_usage:
+ print(
+ f"Input tokens: {result.in_token} "
+ f"Output tokens: {result.out_token} "
+ f"Model: {result.model}",
+ file=sys.stderr,
+ )
def main():
@@ -923,6 +941,12 @@ def main():
help='Show debug output for troubleshooting'
)
+ parser.add_argument(
+ '--show-usage',
+ action='store_true',
+ help='Show token usage and model on stderr'
+ )
+
args = parser.parse_args()
try:
@@ -943,6 +967,7 @@ def main():
token=args.token,
explainable=args.explainable,
debug=args.debug,
+ show_usage=args.show_usage,
)
except Exception as e:
diff --git a/trustgraph-cli/trustgraph/cli/invoke_llm.py b/trustgraph-cli/trustgraph/cli/invoke_llm.py
index a1611625..3bf521f6 100644
--- a/trustgraph-cli/trustgraph/cli/invoke_llm.py
+++ b/trustgraph-cli/trustgraph/cli/invoke_llm.py
@@ -10,7 +10,8 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
-def query(url, flow_id, system, prompt, streaming=True, token=None):
+def query(url, flow_id, system, prompt, streaming=True, token=None,
+ show_usage=False):
# Create API client
api = Api(url=url, token=token)
@@ -26,14 +27,29 @@ def query(url, flow_id, system, prompt, streaming=True, token=None):
)
if streaming:
- # Stream output to stdout without newline
+ last_chunk = None
for chunk in response:
- print(chunk, end="", flush=True)
- # Add final newline after streaming
+ print(chunk.content, end="", flush=True)
+ last_chunk = chunk
print()
+
+        if show_usage and last_chunk:
+            import sys
+            # Usage goes to stderr so piped stdout stays clean
+            print(
+                f"Input tokens: {last_chunk.in_token} "
+                f"Output tokens: {last_chunk.out_token} "
+                f"Model: {last_chunk.model}",
+                file=sys.stderr,
+            )
else:
- # Non-streaming: print complete response
- print(response)
+ print(response.text)
+
+            if show_usage:
+                import sys
+                print(
+                    f"Input tokens: {response.in_token} "
+                    f"Output tokens: {response.out_token} "
+                    f"Model: {response.model}",
+                    file=sys.stderr,
+                )
finally:
# Clean up socket connection
@@ -82,6 +98,12 @@ def main():
help='Disable streaming (default: streaming enabled)'
)
+ parser.add_argument(
+ '--show-usage',
+ action='store_true',
+ help='Show token usage and model on stderr'
+ )
+
args = parser.parse_args()
try:
@@ -93,6 +115,7 @@ def main():
prompt=args.prompt[0],
streaming=not args.no_streaming,
token=args.token,
+ show_usage=args.show_usage,
)
except Exception as e:
diff --git a/trustgraph-cli/trustgraph/cli/invoke_prompt.py b/trustgraph-cli/trustgraph/cli/invoke_prompt.py
index 09cc9043..86f7a024 100644
--- a/trustgraph-cli/trustgraph/cli/invoke_prompt.py
+++ b/trustgraph-cli/trustgraph/cli/invoke_prompt.py
@@ -15,7 +15,8 @@ from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
-def query(url, flow_id, template_id, variables, streaming=True, token=None):
+def query(url, flow_id, template_id, variables, streaming=True, token=None,
+ show_usage=False):
# Create API client
api = Api(url=url, token=token)
@@ -31,16 +32,30 @@ def query(url, flow_id, template_id, variables, streaming=True, token=None):
)
if streaming:
- # Stream output (prompt yields strings directly)
+ last_chunk = None
for chunk in response:
- if chunk:
- print(chunk, end="", flush=True)
- # Add final newline after streaming
+ if chunk.content:
+ print(chunk.content, end="", flush=True)
+ last_chunk = chunk
print()
+            if show_usage and last_chunk:
+                import sys
+                # Usage goes to stderr so piped stdout stays clean
+                print(
+                    f"Input tokens: {last_chunk.in_token} "
+                    f"Output tokens: {last_chunk.out_token} "
+                    f"Model: {last_chunk.model}",
+                    file=sys.stderr,
+                )
else:
- # Non-streaming: print complete response
- print(response)
+ print(response.text)
+
+            if show_usage:
+                import sys
+                print(
+                    f"Input tokens: {response.in_token} "
+                    f"Output tokens: {response.out_token} "
+                    f"Model: {response.model}",
+                    file=sys.stderr,
+                )
finally:
# Clean up socket connection
@@ -92,6 +107,12 @@ specified multiple times''',
help='Disable streaming (default: streaming enabled for text responses)'
)
+ parser.add_argument(
+ '--show-usage',
+ action='store_true',
+ help='Show token usage and model on stderr'
+ )
+
args = parser.parse_args()
variables = {}
@@ -113,6 +134,7 @@ specified multiple times''',
variables=variables,
streaming=not args.no_streaming,
token=args.token,
+ show_usage=args.show_usage,
)
except Exception as e:
diff --git a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py
index 82e48456..7b1ae9a6 100644
--- a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py
+++ b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py
@@ -62,6 +62,11 @@ def sparql_query(url, token, flow_id, query, user, collection, limit,
limit=limit,
batch_size=batch_size,
):
+ if "error" in response:
+ err = response["error"]
+ msg = err.get("message", err) if isinstance(err, dict) else err
+ raise RuntimeError(msg)
+
query_type = response.get("query-type", "select")
# ASK queries - just print and return
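
The guard above tolerates both error payload shapes, a dict carrying a
message or a bare string. A minimal sketch of that normalisation against
hand-written payloads:

    def extract_error(response):
        # Dict payloads yield their message; plain strings pass through
        err = response["error"]
        return err.get("message", err) if isinstance(err, dict) else err

    assert extract_error({"error": {"message": "bad query"}}) == "bad query"
    assert extract_error({"error": "bad query"}) == "bad query"
    assert extract_error({"error": {"code": 400}}) == {"code": 400}
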
diff --git a/trustgraph-cli/trustgraph/cli/verify_system_status.py b/trustgraph-cli/trustgraph/cli/verify_system_status.py
index d7aa1d93..9491deaa 100644
--- a/trustgraph-cli/trustgraph/cli/verify_system_status.py
+++ b/trustgraph-cli/trustgraph/cli/verify_system_status.py
@@ -403,15 +403,8 @@ def main():
# Phase 1: Infrastructure
print(tr.t("cli.verify_system_status.phase_1"))
print("-" * 60)
- if not checker.run_check(
- tr.t("cli.verify_system_status.check_name.pulsar"),
- check_pulsar,
- args.pulsar_url,
- args.check_timeout,
- tr,
- ):
- print(f"\n⚠️ {tr.t('cli.verify_system_status.pulsar_not_responding')}")
- print()
+ # Pulsar check is skipped — not all deployments use Pulsar.
+ # The API Gateway check covers broker connectivity indirectly.
checker.run_check(
tr.t("cli.verify_system_status.check_name.api_gateway"),
diff --git a/trustgraph-embeddings-hf/pyproject.toml b/trustgraph-embeddings-hf/pyproject.toml
index 1bc7bcb4..459f6123 100644
--- a/trustgraph-embeddings-hf/pyproject.toml
+++ b/trustgraph-embeddings-hf/pyproject.toml
@@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph."
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
- "trustgraph-base>=2.2,<2.3",
- "trustgraph-flow>=2.2,<2.3",
+ "trustgraph-base>=2.3,<2.4",
+ "trustgraph-flow>=2.3,<2.4",
"torch",
"urllib3",
"transformers",
diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml
index b2df4a4c..14e919c0 100644
--- a/trustgraph-flow/pyproject.toml
+++ b/trustgraph-flow/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
- "trustgraph-base>=2.2,<2.3",
+ "trustgraph-base>=2.3,<2.4",
"aiohttp",
"anthropic",
"scylla-driver",
diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py b/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py
index c3b1afa6..97b87134 100644
--- a/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py
+++ b/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py
@@ -53,7 +53,7 @@ class MetaRouter:
"general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""},
}
- async def identify_task_type(self, question, context):
+ async def identify_task_type(self, question, context, usage=None):
"""
Use the LLM to classify the question into one of the known task types.
@@ -71,7 +71,7 @@ class MetaRouter:
try:
client = context("prompt-request")
- response = await client.prompt(
+ result = await client.prompt(
id="task-type-classify",
variables={
"question": question,
@@ -81,7 +81,9 @@ class MetaRouter:
],
},
)
- selected = response.strip().lower().replace('"', '').replace("'", "")
+ if usage:
+ usage.track(result)
+ selected = result.text.strip().lower().replace('"', '').replace("'", "")
if selected in self.task_types:
framing = self.task_types[selected].get("framing", DEFAULT_FRAMING)
@@ -100,7 +102,7 @@ class MetaRouter:
)
return DEFAULT_TASK_TYPE, framing
- async def select_pattern(self, question, task_type, context):
+ async def select_pattern(self, question, task_type, context, usage=None):
"""
Use the LLM to select the best execution pattern for this task type.
@@ -120,7 +122,7 @@ class MetaRouter:
try:
client = context("prompt-request")
- response = await client.prompt(
+ result = await client.prompt(
id="pattern-select",
variables={
"question": question,
@@ -133,7 +135,9 @@ class MetaRouter:
],
},
)
- selected = response.strip().lower().replace('"', '').replace("'", "")
+ if usage:
+ usage.track(result)
+ selected = result.text.strip().lower().replace('"', '').replace("'", "")
if selected in valid_patterns:
logger.info(f"MetaRouter: selected pattern '{selected}'")
@@ -148,19 +152,20 @@ class MetaRouter:
logger.warning(f"MetaRouter: pattern selection failed: {e}")
return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN
- async def route(self, question, context):
+ async def route(self, question, context, usage=None):
"""
Full routing pipeline: identify task type, then select pattern.
Args:
question: The user's query.
context: UserAwareContext (flow wrapper).
+ usage: Optional UsageTracker for token counting.
Returns:
(pattern, task_type, framing) tuple.
"""
- task_type, framing = await self.identify_task_type(question, context)
- pattern = await self.select_pattern(question, task_type, context)
+ task_type, framing = await self.identify_task_type(question, context, usage=usage)
+ pattern = await self.select_pattern(question, task_type, context, usage=usage)
logger.info(
f"MetaRouter: route result — "
f"pattern={pattern}, task_type={task_type}, framing={framing!r}"
diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py
index c18c5bac..88d4ee72 100644
--- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py
+++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py
@@ -25,6 +25,7 @@ from trustgraph.provenance import (
agent_plan_uri,
agent_step_result_uri,
agent_synthesis_uri,
+ agent_pattern_decision_uri,
agent_session_triples,
agent_iteration_triples,
agent_observation_triples,
@@ -34,6 +35,7 @@ from trustgraph.provenance import (
agent_plan_triples,
agent_step_result_triples,
agent_synthesis_triples,
+ agent_pattern_decision_triples,
set_graph,
GRAPH_RETRIEVAL,
)
@@ -65,6 +67,37 @@ class UserAwareContext:
return client
+class UsageTracker:
+ """Accumulates token usage across multiple prompt calls."""
+
+ def __init__(self):
+ self.total_in = 0
+ self.total_out = 0
+ self.last_model = None
+
+ def track(self, result):
+ """Track usage from a PromptResult."""
+ if result is not None:
+ if getattr(result, "in_token", None) is not None:
+ self.total_in += result.in_token
+ if getattr(result, "out_token", None) is not None:
+ self.total_out += result.out_token
+ if getattr(result, "model", None) is not None:
+ self.last_model = result.model
+
+ @property
+ def in_token(self):
+ return self.total_in if self.total_in > 0 else None
+
+ @property
+ def out_token(self):
+ return self.total_out if self.total_out > 0 else None
+
+ @property
+ def model(self):
+ return self.last_model
+
+
class PatternBase:
"""
Shared infrastructure for all agent patterns.
@@ -151,7 +184,7 @@ class PatternBase:
logger.debug(f"Think: {x} (is_final={is_final})")
if streaming:
r = AgentResponse(
- chunk_type="thought",
+ message_type="thought",
content=x,
end_of_message=is_final,
end_of_dialog=False,
@@ -159,7 +192,7 @@ class PatternBase:
)
else:
r = AgentResponse(
- chunk_type="thought",
+ message_type="thought",
content=x,
end_of_message=True,
end_of_dialog=False,
@@ -174,7 +207,7 @@ class PatternBase:
logger.debug(f"Observe: {x} (is_final={is_final})")
if streaming:
r = AgentResponse(
- chunk_type="observation",
+ message_type="observation",
content=x,
end_of_message=is_final,
end_of_dialog=False,
@@ -182,7 +215,7 @@ class PatternBase:
)
else:
r = AgentResponse(
- chunk_type="observation",
+ message_type="observation",
content=x,
end_of_message=True,
end_of_dialog=False,
@@ -197,7 +230,7 @@ class PatternBase:
logger.debug(f"Answer: {x}")
if streaming:
r = AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content=x,
end_of_message=False,
end_of_dialog=False,
@@ -205,7 +238,7 @@ class PatternBase:
)
else:
r = AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content=x,
end_of_message=True,
end_of_dialog=False,
@@ -239,16 +272,43 @@ class PatternBase:
logger.debug(f"Emitted session triples for {session_uri}")
await respond(AgentResponse(
- chunk_type="explain",
+ message_type="explain",
content="",
explain_id=session_uri,
explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples,
))
+ async def emit_pattern_decision_triples(
+ self, flow, session_id, session_uri, pattern, task_type,
+ user, collection, respond,
+ ):
+ """Emit provenance triples for a meta-router pattern decision."""
+ uri = agent_pattern_decision_uri(session_id)
+ triples = set_graph(
+ agent_pattern_decision_triples(
+ uri, session_uri, pattern, task_type,
+ ),
+ GRAPH_RETRIEVAL,
+ )
+ await flow("explainability").send(Triples(
+ metadata=Metadata(id=uri, user=user, collection=collection),
+ triples=triples,
+ ))
+ await respond(AgentResponse(
+ message_type="explain", content="",
+ explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
+ explain_triples=triples,
+ ))
+ return uri
+
async def emit_iteration_triples(self, flow, session_id, iteration_num,
session_uri, act, request, respond,
- streaming):
+ streaming, tool_candidates=None,
+ step_number=None,
+ llm_duration_ms=None,
+ in_token=None, out_token=None,
+ model=None):
"""Emit provenance triples for an iteration (Analysis+ToolUse)."""
iteration_uri = agent_iteration_uri(session_id, iteration_num)
@@ -288,6 +348,12 @@ class PatternBase:
arguments=act.arguments,
thought_uri=thought_entity_uri if thought_doc_id else None,
thought_document_id=thought_doc_id,
+ tool_candidates=tool_candidates,
+ step_number=step_number,
+ llm_duration_ms=llm_duration_ms,
+ in_token=in_token,
+ out_token=out_token,
+ model=model,
),
GRAPH_RETRIEVAL,
)
@@ -302,7 +368,7 @@ class PatternBase:
logger.debug(f"Emitted iteration triples for {iteration_uri}")
await respond(AgentResponse(
- chunk_type="explain",
+ message_type="explain",
content="",
explain_id=iteration_uri,
explain_graph=GRAPH_RETRIEVAL,
@@ -311,7 +377,9 @@ class PatternBase:
async def emit_observation_triples(self, flow, session_id, iteration_num,
observation_text, request, respond,
- context=None):
+ context=None,
+ tool_duration_ms=None,
+ tool_error=None):
"""Emit provenance triples for a standalone Observation entity."""
iteration_uri = agent_iteration_uri(session_id, iteration_num)
observation_entity_uri = agent_observation_uri(session_id, iteration_num)
@@ -344,6 +412,8 @@ class PatternBase:
observation_entity_uri,
parent_uri,
document_id=observation_doc_id,
+ tool_duration_ms=tool_duration_ms,
+ tool_error=tool_error,
),
GRAPH_RETRIEVAL,
)
@@ -358,7 +428,7 @@ class PatternBase:
logger.debug(f"Emitted observation triples for {observation_entity_uri}")
await respond(AgentResponse(
- chunk_type="explain",
+ message_type="explain",
content="",
explain_id=observation_entity_uri,
explain_graph=GRAPH_RETRIEVAL,
@@ -367,7 +437,7 @@ class PatternBase:
async def emit_final_triples(self, flow, session_id, iteration_num,
session_uri, answer_text, request, respond,
- streaming):
+ streaming, termination_reason=None):
"""Emit provenance triples for the final answer and save to librarian."""
final_uri = agent_final_uri(session_id)
@@ -401,6 +471,7 @@ class PatternBase:
question_uri=final_question_uri,
previous_uri=final_previous_uri,
document_id=answer_doc_id,
+ termination_reason=termination_reason,
),
GRAPH_RETRIEVAL,
)
@@ -415,7 +486,7 @@ class PatternBase:
logger.debug(f"Emitted final triples for {final_uri}")
await respond(AgentResponse(
- chunk_type="explain",
+ message_type="explain",
content="",
explain_id=final_uri,
explain_graph=GRAPH_RETRIEVAL,
@@ -439,7 +510,7 @@ class PatternBase:
triples=triples,
))
await respond(AgentResponse(
- chunk_type="explain", content="",
+ message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples,
))
@@ -478,7 +549,7 @@ class PatternBase:
triples=triples,
))
await respond(AgentResponse(
- chunk_type="explain", content="",
+ message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples,
))
@@ -498,7 +569,7 @@ class PatternBase:
triples=triples,
))
await respond(AgentResponse(
- chunk_type="explain", content="",
+ message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples,
))
@@ -531,14 +602,14 @@ class PatternBase:
triples=triples,
))
await respond(AgentResponse(
- chunk_type="explain", content="",
+ message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples,
))
async def emit_synthesis_triples(
self, flow, session_id, previous_uris, answer_text, user, collection,
- respond, streaming,
+ respond, streaming, termination_reason=None,
):
"""Emit provenance for a synthesis answer."""
uri = agent_synthesis_uri(session_id)
@@ -555,7 +626,10 @@ class PatternBase:
doc_id = None
triples = set_graph(
- agent_synthesis_triples(uri, previous_uris, doc_id),
+ agent_synthesis_triples(
+ uri, previous_uris, doc_id,
+ termination_reason=termination_reason,
+ ),
GRAPH_RETRIEVAL,
)
await flow("explainability").send(Triples(
@@ -563,7 +637,7 @@ class PatternBase:
triples=triples,
))
await respond(AgentResponse(
- chunk_type="explain", content="",
+ message_type="explain", content="",
explain_id=uri, explain_graph=GRAPH_RETRIEVAL,
explain_triples=triples,
))
@@ -571,7 +645,8 @@ class PatternBase:
# ---- Response helpers ---------------------------------------------------
async def prompt_as_answer(self, client, prompt_id, variables,
- respond, streaming, message_id=""):
+ respond, streaming, message_id="",
+ usage=None):
"""Call a prompt template, forwarding chunks as answer
AgentResponse messages when streaming is enabled.
@@ -584,29 +659,35 @@ class PatternBase:
if text:
accumulated.append(text)
await respond(AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content=text,
end_of_message=False,
end_of_dialog=False,
message_id=message_id,
))
- await client.prompt(
+ result = await client.prompt(
id=prompt_id,
variables=variables,
streaming=True,
chunk_callback=on_chunk,
)
+ if usage:
+ usage.track(result)
return "".join(accumulated)
else:
- return await client.prompt(
+ result = await client.prompt(
id=prompt_id,
variables=variables,
)
+ if usage:
+ usage.track(result)
+ return result.text
async def send_final_response(self, respond, streaming, answer_text,
- already_streamed=False, message_id=""):
+ already_streamed=False, message_id="",
+ usage=None):
"""Send the answer content and end-of-dialog marker.
Args:
@@ -614,33 +695,44 @@ class PatternBase:
via streaming callbacks (e.g. ReactPattern). Only the
end-of-dialog marker is emitted.
message_id: Provenance URI for the answer entity.
+ usage: UsageTracker with accumulated token counts.
"""
+ usage_kwargs = {}
+ if usage:
+ usage_kwargs = {
+ "in_token": usage.in_token,
+ "out_token": usage.out_token,
+ "model": usage.model,
+ }
+
if streaming and not already_streamed:
# Answer wasn't streamed yet — send it as a chunk first
if answer_text:
await respond(AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content=answer_text,
end_of_message=False,
end_of_dialog=False,
message_id=message_id,
))
if streaming:
- # End-of-dialog marker
+ # End-of-dialog marker with usage
await respond(AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content="",
end_of_message=True,
end_of_dialog=True,
message_id=message_id,
+ **usage_kwargs,
))
else:
await respond(AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content=answer_text,
end_of_message=True,
end_of_dialog=True,
message_id=message_id,
+ **usage_kwargs,
))
def build_next_request(self, request, history, session_id, collection,
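
Note: the hunks above call usage.track(result) and later read
usage.in_token, usage.out_token and usage.model, but UsageTracker itself
is defined in pattern_base.py outside these hunks. A minimal sketch
consistent with those call sites (hypothetical; the real class may differ):

    class UsageTracker:
        """Accumulates token usage across prompt calls (sketch only)."""

        def __init__(self):
            self.in_token = 0
            self.out_token = 0
            self.model = None

        def track(self, result):
            # result is the prompt client's response object; token
            # fields may be None when the backend reports no usage
            if result is None:
                return
            if getattr(result, "in_token", None) is not None:
                self.in_token += result.in_token
            if getattr(result, "out_token", None) is not None:
                self.out_token += result.out_token
            if getattr(result, "model", None) is not None:
                self.model = result.model
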
diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py
index 59d22929..1de31a92 100644
--- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py
+++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py
@@ -18,7 +18,7 @@ from trustgraph.provenance import (
agent_synthesis_uri,
)
-from . pattern_base import PatternBase
+from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__)
@@ -35,7 +35,11 @@ class PlanThenExecutePattern(PatternBase):
Subsequent calls execute the next pending plan step via ReACT.
"""
- async def iterate(self, request, respond, next, flow):
+ async def iterate(self, request, respond, next, flow, usage=None,
+ pattern_decision_uri=None):
+
+ if usage is None:
+ usage = UsageTracker()
streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@@ -63,17 +67,19 @@ class PlanThenExecutePattern(PatternBase):
# Determine current phase by checking history for a plan step
plan = self._extract_plan(request.history)
+ derive_from_uri = pattern_decision_uri or session_uri
+
if plan is None:
await self._planning_iteration(
request, respond, next, flow,
- session_id, collection, streaming, session_uri,
- iteration_num,
+ session_id, collection, streaming, derive_from_uri,
+ iteration_num, usage=usage,
)
else:
await self._execution_iteration(
request, respond, next, flow,
- session_id, collection, streaming, session_uri,
- iteration_num, plan,
+ session_id, collection, streaming, derive_from_uri,
+ iteration_num, plan, usage=usage,
)
def _extract_plan(self, history):
@@ -98,7 +104,7 @@ class PlanThenExecutePattern(PatternBase):
async def _planning_iteration(self, request, respond, next, flow,
session_id, collection, streaming,
- session_uri, iteration_num):
+ session_uri, iteration_num, usage=None):
"""Ask the LLM to produce a structured plan."""
think = self.make_think_callback(respond, streaming)
@@ -113,7 +119,7 @@ class PlanThenExecutePattern(PatternBase):
client = context("prompt-request")
# Use the plan-create prompt template
- plan_steps = await client.prompt(
+ result = await client.prompt(
id="plan-create",
variables={
"question": request.question,
@@ -124,7 +130,10 @@ class PlanThenExecutePattern(PatternBase):
],
},
)
+ if usage:
+ usage.track(result)
+ plan_steps = result.objects
# Validate we got a list
if not isinstance(plan_steps, list) or not plan_steps:
logger.warning("plan-create returned invalid result, falling back to single step")
@@ -187,7 +196,8 @@ class PlanThenExecutePattern(PatternBase):
async def _execution_iteration(self, request, respond, next, flow,
session_id, collection, streaming,
- session_uri, iteration_num, plan):
+ session_uri, iteration_num, plan,
+ usage=None):
"""Execute the next pending plan step via single-shot tool call."""
pending_idx = self._find_next_pending_step(plan)
@@ -198,6 +208,7 @@ class PlanThenExecutePattern(PatternBase):
request, respond, next, flow,
session_id, collection, streaming,
session_uri, iteration_num, plan,
+ usage=usage,
)
return
@@ -240,7 +251,7 @@ class PlanThenExecutePattern(PatternBase):
client = context("prompt-request")
# Single-shot: ask LLM which tool + arguments to use for this goal
- tool_call = await client.prompt(
+ result = await client.prompt(
id="plan-step-execute",
variables={
"goal": goal,
@@ -258,7 +269,10 @@ class PlanThenExecutePattern(PatternBase):
],
},
)
+ if usage:
+ usage.track(result)
+ tool_call = result.object
tool_name = tool_call.get("tool", "")
tool_arguments = tool_call.get("arguments", {})
@@ -330,7 +344,8 @@ class PlanThenExecutePattern(PatternBase):
async def _synthesise(self, request, respond, next, flow,
session_id, collection, streaming,
- session_uri, iteration_num, plan):
+ session_uri, iteration_num, plan,
+ usage=None):
"""Synthesise a final answer from all completed plan step results."""
think = self.make_think_callback(respond, streaming)
@@ -365,6 +380,7 @@ class PlanThenExecutePattern(PatternBase):
respond=respond,
streaming=streaming,
message_id=synthesis_msg_id,
+ usage=usage,
)
# Emit synthesis provenance (links back to last step result)
@@ -372,6 +388,7 @@ class PlanThenExecutePattern(PatternBase):
await self.emit_synthesis_triples(
flow, session_id, last_step_uri,
response_text, request.user, collection, respond, streaming,
+ termination_reason="plan-complete",
)
if self.is_subagent(request):
@@ -380,4 +397,5 @@ class PlanThenExecutePattern(PatternBase):
await self.send_final_response(
respond, streaming, response_text, already_streamed=streaming,
message_id=synthesis_msg_id,
+ usage=usage,
)
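
Note: call sites in this file now unpack result.objects (plan-create)
and result.object (plan-step-execute) instead of receiving raw parsed
values, and also feed the result to usage.track(). The prompt client's
return type is not shown in this diff; an assumed shape, for reference:

    import dataclasses

    @dataclasses.dataclass
    class PromptResult:               # hypothetical name
        text: str | None = None      # raw completion text
        object: dict | None = None   # parsed JSON object result
        objects: list | None = None  # parsed JSON-lines / array result
        in_token: int | None = None  # token usage, when reported
        out_token: int | None = None
        model: str | None = None

Callers pick the accessor that matches the template's output format,
which is why each call site changed from a direct return to a named
result variable.
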
diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py
index 67ded823..25264c26 100644
--- a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py
+++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py
@@ -23,7 +23,7 @@ from ..react.agent_manager import AgentManager
from ..react.types import Action, Final
from ..tool_filter import get_next_state
-from . pattern_base import PatternBase
+from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__)
@@ -37,7 +37,11 @@ class ReactPattern(PatternBase):
result is appended to history and a next-request is emitted.
"""
- async def iterate(self, request, respond, next, flow):
+ async def iterate(self, request, respond, next, flow, usage=None,
+ pattern_decision_uri=None):
+
+ if usage is None:
+ usage = UsageTracker()
streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@@ -105,11 +109,23 @@ class ReactPattern(PatternBase):
session_id, iteration_num,
)
+ # Tool names available to the LLM for this iteration
+ tool_candidates = [t.name for t in filtered_tools.values()]
+
+ # Use pattern decision as derivation source if available
+ derive_from_uri = pattern_decision_uri or session_uri
+
# Callback: emit Analysis+ToolUse triples before tool executes
async def on_action(act):
await self.emit_iteration_triples(
- flow, session_id, iteration_num, session_uri,
+ flow, session_id, iteration_num, derive_from_uri,
act, request, respond, streaming,
+ tool_candidates=tool_candidates,
+ step_number=iteration_num,
+ llm_duration_ms=getattr(act, 'llm_duration_ms', None),
+ in_token=getattr(act, 'in_token', None),
+ out_token=getattr(act, 'out_token', None),
+ model=getattr(act, 'llm_model', None),
)
act = await temp_agent.react(
@@ -121,6 +137,7 @@ class ReactPattern(PatternBase):
context=context,
streaming=streaming,
on_action=on_action,
+ usage=usage,
)
logger.debug(f"Action: {act}")
@@ -134,8 +151,9 @@ class ReactPattern(PatternBase):
# Emit final provenance
await self.emit_final_triples(
- flow, session_id, iteration_num, session_uri,
+ flow, session_id, iteration_num, derive_from_uri,
f, request, respond, streaming,
+ termination_reason="final-answer",
)
if self.is_subagent(request):
@@ -144,6 +162,7 @@ class ReactPattern(PatternBase):
await self.send_final_response(
respond, streaming, f, already_streamed=streaming,
message_id=answer_msg_id,
+ usage=usage,
)
return
@@ -152,6 +171,8 @@ class ReactPattern(PatternBase):
flow, session_id, iteration_num,
act.observation, request, respond,
context=context,
+ tool_duration_ms=getattr(act, 'tool_duration_ms', None),
+ tool_error=getattr(act, 'tool_error', None),
)
history.append(act)
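
Note: AgentManager (below) now awaits on_action(act) before the tool
executes, and before the __parse_error__ short-circuit, so provenance is
emitted even when the LLM reply fails to parse. A minimal consumer,
assuming only the telemetry attributes stamped in react():

    import logging

    async def on_action(act):
        # Every telemetry field is optional; read defensively, exactly
        # as the pattern code above does with getattr defaults.
        logging.getLogger(__name__).info(
            "action=%s llm_ms=%s tokens=%s/%s model=%s",
            act.name,
            getattr(act, "llm_duration_ms", None),
            getattr(act, "in_token", None),
            getattr(act, "out_token", None),
            getattr(act, "llm_model", None),
        )
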
diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py
index 5bf8e2fd..3d08154d 100644
--- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py
+++ b/trustgraph-flow/trustgraph/agent/orchestrator/service.py
@@ -23,6 +23,7 @@ from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
+from . pattern_base import UsageTracker, PatternBase
from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
@@ -493,6 +494,8 @@ class Processor(AgentService):
async def agent_request(self, request, respond, next, flow):
+ usage = UsageTracker()
+
try:
# Intercept subagent completion messages
@@ -516,7 +519,7 @@ class Processor(AgentService):
if self.meta_router:
pattern, task_type, framing = await self.meta_router.route(
- request.question, context,
+ request.question, context, usage=usage,
)
else:
pattern = "react"
@@ -534,19 +537,31 @@ class Processor(AgentService):
)
# Dispatch to the selected pattern
+ selected = self.react_pattern
if pattern == "plan-then-execute":
- await self.plan_pattern.iterate(
- request, respond, next, flow,
- )
+ selected = self.plan_pattern
elif pattern == "supervisor":
- await self.supervisor_pattern.iterate(
- request, respond, next, flow,
- )
- else:
- # Default to react
- await self.react_pattern.iterate(
- request, respond, next, flow,
- )
+ selected = self.supervisor_pattern
+
+ # Emit pattern decision provenance on first iteration
+ pattern_decision_uri = None
+ if not request.history and pattern:
+ session_id = getattr(request, 'session_id', '')
+ if session_id:
+ session_uri = self.provenance_session_uri(session_id)
+ pattern_decision_uri = \
+ await selected.emit_pattern_decision_triples(
+ flow, session_id, session_uri,
+ pattern, getattr(request, 'task_type', ''),
+ request.user,
+ getattr(request, 'collection', 'default'),
+ respond,
+ )
+
+ await selected.iterate(
+ request, respond, next, flow, usage=usage,
+ pattern_decision_uri=pattern_decision_uri,
+ )
except Exception as e:
@@ -562,7 +577,7 @@ class Processor(AgentService):
)
r = AgentResponse(
- chunk_type="error",
+ message_type="error",
content=str(e),
end_of_message=True,
end_of_dialog=True,
diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py
index d5537876..973a9966 100644
--- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py
+++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py
@@ -22,7 +22,7 @@ from trustgraph.provenance import (
agent_synthesis_uri,
)
-from . pattern_base import PatternBase
+from . pattern_base import PatternBase, UsageTracker
logger = logging.getLogger(__name__)
@@ -38,7 +38,11 @@ class SupervisorPattern(PatternBase):
- "synthesise": triggered by aggregator with results in subagent_results
"""
- async def iterate(self, request, respond, next, flow):
+ async def iterate(self, request, respond, next, flow, usage=None,
+ pattern_decision_uri=None):
+
+ if usage is None:
+ usage = UsageTracker()
streaming = getattr(request, 'streaming', False)
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
@@ -67,22 +71,26 @@ class SupervisorPattern(PatternBase):
)
)
+ derive_from_uri = pattern_decision_uri or session_uri
+
if has_results:
await self._synthesise(
request, respond, next, flow,
session_id, collection, streaming,
- session_uri, iteration_num,
+ derive_from_uri, iteration_num,
+ usage=usage,
)
else:
await self._decompose_and_fanout(
request, respond, next, flow,
session_id, collection, streaming,
- session_uri, iteration_num,
+ derive_from_uri, iteration_num,
+ usage=usage,
)
async def _decompose_and_fanout(self, request, respond, next, flow,
session_id, collection, streaming,
- session_uri, iteration_num):
+ session_uri, iteration_num, usage=None):
"""Decompose the question into sub-goals and fan out subagents."""
decompose_msg_id = agent_decomposition_uri(session_id)
@@ -100,7 +108,7 @@ class SupervisorPattern(PatternBase):
client = context("prompt-request")
# Use the supervisor-decompose prompt template
- goals = await client.prompt(
+ result = await client.prompt(
id="supervisor-decompose",
variables={
"question": request.question,
@@ -112,7 +120,10 @@ class SupervisorPattern(PatternBase):
],
},
)
+ if usage:
+ usage.track(result)
+ goals = result.objects
# Validate result
if not isinstance(goals, list):
goals = []
@@ -175,7 +186,7 @@ class SupervisorPattern(PatternBase):
async def _synthesise(self, request, respond, next, flow,
session_id, collection, streaming,
- session_uri, iteration_num):
+ session_uri, iteration_num, usage=None):
"""Synthesise final answer from subagent results."""
synthesis_msg_id = agent_synthesis_uri(session_id)
@@ -216,6 +227,7 @@ class SupervisorPattern(PatternBase):
respond=respond,
streaming=streaming,
message_id=synthesis_msg_id,
+ usage=usage,
)
# Emit synthesis provenance (links back to all findings)
@@ -226,9 +238,11 @@ class SupervisorPattern(PatternBase):
await self.emit_synthesis_triples(
flow, session_id, finding_uris,
response_text, request.user, collection, respond, streaming,
+ termination_reason="subagents-complete",
)
await self.send_final_response(
respond, streaming, response_text, already_streamed=streaming,
message_id=synthesis_msg_id,
+ usage=usage,
)
diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py
index e86a2d6c..82a8f905 100644
--- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py
+++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py
@@ -3,6 +3,7 @@ import logging
import json
import re
import asyncio
+import time
from . types import Action, Final
@@ -170,7 +171,7 @@ class AgentManager:
raise ValueError(f"Could not parse response: {text}")
- async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None):
+ async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None, usage=None):
logger.debug(f"calling reason: {question}")
@@ -255,11 +256,14 @@ class AgentManager:
client = context("prompt-request")
# Get streaming response
- response_text = await client.agent_react(
+ prompt_result = await client.agent_react(
variables=variables,
streaming=True,
chunk_callback=on_chunk
)
+ self._last_prompt_result = prompt_result
+ if usage:
+ usage.track(prompt_result)
# Finalize parser
parser.finalize()
@@ -267,7 +271,13 @@ class AgentManager:
# Get result
result = parser.get_result()
if result is None:
- raise RuntimeError("Parser failed to produce a result")
+ return Action(
+ thought="",
+ name="__parse_error__",
+ arguments={},
+ observation="",
+ tool_error="LLM response could not be parsed (streaming)",
+ )
return result
@@ -275,10 +285,14 @@ class AgentManager:
# Non-streaming path - get complete text and parse
client = context("prompt-request")
- response_text = await client.agent_react(
+ prompt_result = await client.agent_react(
variables=variables,
streaming=False
)
+ self._last_prompt_result = prompt_result
+ if usage:
+ usage.track(prompt_result)
+ response_text = prompt_result.text
logger.debug(f"Response text:\n{response_text}")
@@ -289,11 +303,19 @@ class AgentManager:
except ValueError as e:
logger.error(f"Failed to parse response: {e}")
logger.error(f"Response was: {response_text}")
- raise RuntimeError(f"Failed to parse agent response: {e}")
+ return Action(
+ thought="",
+ name="__parse_error__",
+ arguments={},
+ observation="",
+ tool_error=f"LLM parse error: {e}",
+ )
async def react(self, question, history, think, observe, context,
- streaming=False, answer=None, on_action=None):
+ streaming=False, answer=None, on_action=None,
+ usage=None):
+ t0 = time.monotonic()
act = await self.reason(
question = question,
history = history,
@@ -302,7 +324,14 @@ class AgentManager:
think = think,
observe = observe,
answer = answer,
+ usage = usage,
)
+ act.llm_duration_ms = int((time.monotonic() - t0) * 1000)
+ pr = getattr(self, '_last_prompt_result', None)
+ if pr:
+ act.in_token = pr.in_token
+ act.out_token = pr.out_token
+ act.llm_model = pr.model
if isinstance(act, Final):
@@ -321,24 +350,43 @@ class AgentManager:
logger.debug(f"ACTION: {act.name}")
+ # Notify caller before tool execution (for provenance)
+ if on_action:
+ await on_action(act)
+
+ # Handle parse errors — skip tool execution
+ if act.name == "__parse_error__":
+ resp = f"Error: {act.tool_error}"
+ act.tool_duration_ms = 0
+ await observe(resp, is_final=True)
+ act.observation = resp
+ return act
+
if act.name in self.tools:
action = self.tools[act.name]
else:
raise RuntimeError(f"No action for {act.name}!")
- # Notify caller before tool execution (for provenance)
- if on_action:
- await on_action(act)
+ t0 = time.monotonic()
+ try:
+ resp = await action.implementation(context).invoke(
+ **act.arguments
+ )
- resp = await action.implementation(context).invoke(
- **act.arguments
- )
+ if isinstance(resp, str):
+ resp = resp.strip()
+ else:
+ resp = str(resp)
+ resp = resp.strip()
- if isinstance(resp, str):
- resp = resp.strip()
- else:
- resp = str(resp)
- resp = resp.strip()
+ act.tool_error = None
+
+ except Exception as e:
+ logger.error(f"Tool execution error ({act.name}): {e}")
+ resp = f"Error: {e}"
+ act.tool_error = str(e)
+
+ act.tool_duration_ms = int((time.monotonic() - t0) * 1000)
await observe(resp, is_final=True)
diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py
index 2c7423d8..00432181 100755
--- a/trustgraph-flow/trustgraph/agent/react/service.py
+++ b/trustgraph-flow/trustgraph/agent/react/service.py
@@ -469,7 +469,7 @@ class Processor(AgentService):
# Send explain event for session
await respond(AgentResponse(
- chunk_type="explain",
+ message_type="explain",
content="",
explain_id=session_uri,
explain_graph=GRAPH_RETRIEVAL,
@@ -492,7 +492,7 @@ class Processor(AgentService):
if streaming:
r = AgentResponse(
- chunk_type="thought",
+ message_type="thought",
content=x,
end_of_message=is_final,
end_of_dialog=False,
@@ -500,7 +500,7 @@ class Processor(AgentService):
)
else:
r = AgentResponse(
- chunk_type="thought",
+ message_type="thought",
content=x,
end_of_message=True,
end_of_dialog=False,
@@ -515,7 +515,7 @@ class Processor(AgentService):
if streaming:
r = AgentResponse(
- chunk_type="observation",
+ message_type="observation",
content=x,
end_of_message=is_final,
end_of_dialog=False,
@@ -523,7 +523,7 @@ class Processor(AgentService):
)
else:
r = AgentResponse(
- chunk_type="observation",
+ message_type="observation",
content=x,
end_of_message=True,
end_of_dialog=False,
@@ -540,7 +540,7 @@ class Processor(AgentService):
if streaming:
r = AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content=x,
end_of_message=False,
end_of_dialog=False,
@@ -548,7 +548,7 @@ class Processor(AgentService):
)
else:
r = AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content=x,
end_of_message=True,
end_of_dialog=False,
@@ -637,7 +637,7 @@ class Processor(AgentService):
logger.debug(f"Emitted iteration triples for {iter_uri}")
await respond(AgentResponse(
- chunk_type="explain",
+ message_type="explain",
content="",
explain_id=iter_uri,
explain_graph=GRAPH_RETRIEVAL,
@@ -715,7 +715,7 @@ class Processor(AgentService):
# Send explain event for conclusion
await respond(AgentResponse(
- chunk_type="explain",
+ message_type="explain",
content="",
explain_id=final_uri,
explain_graph=GRAPH_RETRIEVAL,
@@ -725,7 +725,7 @@ class Processor(AgentService):
if streaming:
# End-of-dialog marker — answer chunks already sent via callback
r = AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content="",
end_of_message=True,
end_of_dialog=True,
@@ -733,7 +733,7 @@ class Processor(AgentService):
)
else:
r = AgentResponse(
- chunk_type="answer",
+ message_type="answer",
content=f,
end_of_message=True,
end_of_dialog=True,
@@ -792,7 +792,7 @@ class Processor(AgentService):
# Send explain event for observation
await respond(AgentResponse(
- chunk_type="explain",
+ message_type="explain",
content="",
explain_id=observation_entity_uri,
explain_graph=GRAPH_RETRIEVAL,
@@ -847,7 +847,7 @@ class Processor(AgentService):
streaming = getattr(request, 'streaming', False) if 'request' in locals() else False
r = AgentResponse(
- chunk_type="error",
+ message_type="error",
content=str(e),
end_of_message=True,
end_of_dialog=True,
diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py
index 6fd96ade..6674c999 100644
--- a/trustgraph-flow/trustgraph/agent/react/tools.py
+++ b/trustgraph-flow/trustgraph/agent/react/tools.py
@@ -42,7 +42,7 @@ class KnowledgeQueryImpl:
async def explain_callback(explain_id, explain_graph, explain_triples=None):
self.context.last_sub_explain_uri = explain_id
await respond(AgentResponse(
- chunk_type="explain",
+ message_type="explain",
content="",
explain_id=explain_id,
explain_graph=explain_graph,
@@ -78,9 +78,10 @@ class TextCompletionImpl:
async def invoke(self, **arguments):
client = self.context("prompt-request")
logger.debug("Prompt question...")
- return await client.question(
+ result = await client.question(
arguments.get("question")
)
+ return result.text
# This tool implementation knows how to do MCP tool invocation. This uses
# the mcp-tool service.
@@ -227,10 +228,11 @@ class PromptImpl:
async def invoke(self, **arguments):
client = self.context("prompt-request")
logger.debug(f"Prompt template invocation: {self.template_id}...")
- return await client.prompt(
+ result = await client.prompt(
id=self.template_id,
variables=arguments
)
+ return result.text
# This tool implementation invokes a dynamically configured tool service
diff --git a/trustgraph-flow/trustgraph/agent/react/types.py b/trustgraph-flow/trustgraph/agent/react/types.py
index 7180db3e..ee0a677f 100644
--- a/trustgraph-flow/trustgraph/agent/react/types.py
+++ b/trustgraph-flow/trustgraph/agent/react/types.py
@@ -22,9 +22,19 @@ class Action:
name : str
arguments : dict
observation : str
-
+ llm_duration_ms : int | None = None
+ tool_duration_ms : int | None = None
+ tool_error : str | None = None
+ in_token : int | None = None
+ out_token : int | None = None
+ llm_model : str | None = None
+
@dataclasses.dataclass
class Final:
thought : str
final : str
+ llm_duration_ms : int | None = None
+ in_token : int | None = None
+ out_token : int | None = None
+ llm_model : str | None = None
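
Illustration only (values invented): the new fields default to None, so
existing Action(...) constructors keep working, and telemetry is stamped
after construction, as AgentManager.react does above:

    from trustgraph.agent.react.types import Action

    act = Action(thought="look it up", name="knowledge-query",
                 arguments={"question": "..."}, observation="")
    act.llm_duration_ms = 412      # wall-clock LLM time, ms
    act.in_token, act.out_token = 935, 88
    act.llm_model = "example-model"
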
diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py
index 1a03ac9f..a5fee382 100755
--- a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py
+++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py
@@ -4,6 +4,7 @@ Embeddings service, applies an embeddings model using fastembed
Input is text, output is embeddings vector.
"""
+import asyncio
import logging
from ... base import EmbeddingsService
@@ -37,7 +38,13 @@ class Processor(EmbeddingsService):
self._load_model(model)
def _load_model(self, model_name):
- """Load a model, caching it for reuse"""
+ """Load a model, caching it for reuse.
+
+ Synchronous — CPU and I/O heavy. Callers that run on the
+ event loop must dispatch via asyncio.to_thread to avoid
+ freezing the loop (which, in processor-group deployments,
+ freezes every sibling processor in the same process).
+ """
if self.cached_model_name != model_name:
logger.info(f"Loading FastEmbed model: {model_name}")
self.embeddings = TextEmbedding(model_name=model_name)
@@ -46,6 +53,11 @@ class Processor(EmbeddingsService):
else:
logger.debug(f"Using cached model: {model_name}")
+ def _run_embed(self, texts):
+ """Synchronous embed call. Runs in a worker thread via
+ asyncio.to_thread from on_embeddings."""
+ return list(self.embeddings.embed(texts))
+
async def on_embeddings(self, texts, model=None):
if not texts:
@@ -53,11 +65,18 @@ class Processor(EmbeddingsService):
use_model = model or self.default_model
- # Reload model if it has changed
- self._load_model(use_model)
+ # Reload model if it has changed. Model loading is sync
+ # and can take seconds; push it to a worker thread so the
+ # event loop (and any sibling processors in group mode)
+ # stays responsive.
+ if self.cached_model_name != use_model:
+ await asyncio.to_thread(self._load_model, use_model)
- # FastEmbed processes the full batch efficiently
- vecs = list(self.embeddings.embed(texts))
+ # FastEmbed inference is synchronous ONNX runtime work.
+ # Dispatch to a worker thread so the event loop stays
+ # responsive for other tasks (important in group mode
+ # where the loop is shared across many processors).
+ vecs = await asyncio.to_thread(self._run_embed, texts)
# Return list of vectors, one per input text
return [v.tolist() for v in vecs]
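
Note: this is the standard pattern for keeping blocking work off a
shared event loop; a generic reduction of the change above (stub model,
stdlib only):

    import asyncio

    def blocking_embed(texts):
        # Stands in for TextEmbedding.embed: sync, CPU-bound ONNX work
        return [[0.0] * 384 for _ in texts]

    async def embed(texts):
        # Runs in the default thread pool; the loop keeps servicing
        # sibling processors while the embedding computes.
        return await asyncio.to_thread(blocking_embed, texts)
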
diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py
index 2bb88c8a..9b5bbb79 100755
--- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py
+++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py
@@ -117,10 +117,11 @@ class Processor(FlowProcessor):
try:
- defs = await flow("prompt-request").extract_definitions(
+ result = await flow("prompt-request").extract_definitions(
text = chunk
)
+ defs = result.objects
logger.debug(f"Definitions response: {defs}")
if type(defs) != list:
diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py
index 29808cae..bdb0e6e8 100644
--- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py
+++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py
@@ -376,10 +376,11 @@ class Processor(FlowProcessor):
"""
try:
# Call prompt service with simplified format prompt
- extraction_response = await flow("prompt-request").prompt(
+ result = await flow("prompt-request").prompt(
id="extract-with-ontologies",
variables=prompt_variables
)
+ extraction_response = result.object
logger.debug(f"Simplified extraction response: {extraction_response}")
# Parse response into structured format
diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py
index b557ec32..8068a23d 100755
--- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py
+++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py
@@ -100,10 +100,11 @@ class Processor(FlowProcessor):
try:
- rels = await flow("prompt-request").extract_relationships(
+ result = await flow("prompt-request").extract_relationships(
text = chunk
)
+ rels = result.objects
logger.debug(f"Prompt response: {rels}")
if type(rels) != list:
diff --git a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py
index 8fd494b0..973bb3d7 100644
--- a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py
+++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py
@@ -148,11 +148,12 @@ class Processor(FlowProcessor):
schema_dict = row_schema_translator.encode(schema)
# Use prompt client to extract rows based on schema
- objects = await flow("prompt-request").extract_objects(
+ result = await flow("prompt-request").extract_objects(
schema=schema_dict,
text=text
)
-
+
+ objects = result.objects
if not isinstance(objects, list):
return []
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
index 62626046..3a37c4e3 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
@@ -40,15 +40,14 @@ class CoreExport:
"ge",
{
"m": {
- "i": data["metadata"]["id"],
- "m": data["metadata"]["metadata"],
+ "i": data["metadata"]["id"],
"u": data["metadata"]["user"],
"c": data["metadata"]["collection"],
},
"e": [
{
"e": ent["entity"],
- "v": ent["vectors"],
+ "v": ent["vector"],
}
for ent in data["entities"]
]
@@ -65,8 +64,7 @@ class CoreExport:
"t",
{
"m": {
- "i": data["metadata"]["id"],
- "m": data["metadata"]["metadata"],
+ "i": data["metadata"]["id"],
"u": data["metadata"]["user"],
"c": data["metadata"]["collection"],
},
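
Note: after this change a "ge" (graph-embeddings) export record carries
only id/user/collection metadata and one "vector" per entity. An
illustrative record (values invented; the entity payload is whatever
the serializer emits):

    record = (
        "ge",
        {
            "m": {"i": "doc-123", "u": "alice", "c": "default"},
            "e": [
                {"e": {"v": "http://example.org/Entity1"},
                 "v": [0.12, -0.31, 0.07]},
            ],
        },
    )
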
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
index af22a5b0..0ca07319 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
@@ -48,7 +48,6 @@ class CoreImport:
"triples": {
"metadata": {
"id": id,
- "metadata": msg["m"]["m"],
"user": user,
"collection": "default", # Not used?
},
@@ -57,7 +56,7 @@ class CoreImport:
}
await kr.process(msg)
-
+
elif unpacked[0] == "ge":
msg = unpacked[1]
msg = {
@@ -67,14 +66,13 @@ class CoreImport:
"graph-embeddings": {
"metadata": {
"id": id,
- "metadata": msg["m"]["m"],
"user": user,
"collection": "default", # Not used?
},
"entities": [
{
"entity": ent["e"],
- "vectors": ent["v"],
+ "vector": ent["v"],
}
for ent in msg["e"]
]
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py
index 6e01a5ca..de0fe52d 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py
@@ -8,7 +8,7 @@ from ... schema import Metadata
from ... schema import EntityContexts, EntityContext
from ... base import Publisher
-from . serialize import to_subgraph, to_value
+from . serialize import to_value
# Module logger
logger = logging.getLogger(__name__)
@@ -48,7 +48,6 @@ class EntityContextsImport:
elt = EntityContexts(
metadata=Metadata(
id=data["metadata"]["id"],
- metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py
index 8abf5e9c..7c7dc915 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py
@@ -8,7 +8,7 @@ from ... schema import Metadata
from ... schema import GraphEmbeddings, EntityEmbeddings
from ... base import Publisher
-from . serialize import to_subgraph, to_value
+from . serialize import to_value
# Module logger
logger = logging.getLogger(__name__)
@@ -48,14 +48,13 @@ class GraphEmbeddingsImport:
elt = GraphEmbeddings(
metadata=Metadata(
id=data["metadata"]["id"],
- metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
entities=[
EntityEmbeddings(
entity=to_value(ent["entity"]),
- vectors=ent["vectors"],
+ vector=ent["vector"],
)
for ent in data["entities"]
]
diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py
index 8d1aca9e..4e465bf7 100755
--- a/trustgraph-flow/trustgraph/gateway/service.py
+++ b/trustgraph-flow/trustgraph/gateway/service.py
@@ -9,7 +9,7 @@ from aiohttp import web
import logging
import os
-from trustgraph.base.logging import setup_logging
+from trustgraph.base.logging import setup_logging, add_logging_args
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
from . auth import Authenticator
@@ -195,12 +195,7 @@ def run():
help=f'Secret API token (default: no auth)',
)
- parser.add_argument(
- '-l', '--log-level',
- default='INFO',
- choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
- help=f'Log level (default: INFO)'
- )
+ add_logging_args(parser)
parser.add_argument(
'--metrics',
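
Note: add_logging_args lives in trustgraph.base.logging and is not shown
in this diff. Assuming it mirrors the inline flag it replaces, a sketch:

    def add_logging_args(parser):
        # Sketch only; the shared helper may register more options.
        parser.add_argument(
            "-l", "--log-level",
            default="INFO",
            choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
            help="Log level (default: INFO)",
        )
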
diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py
index 3e0b610c..46460b1f 100644
--- a/trustgraph-flow/trustgraph/metering/counter.py
+++ b/trustgraph-flow/trustgraph/metering/counter.py
@@ -102,10 +102,10 @@ class Processor(FlowProcessor):
__class__.cost_metric.labels(model=modelname, direction="input").inc(cost_in)
__class__.cost_metric.labels(model=modelname, direction="output").inc(cost_out)
- logger.info(f"Model: {modelname}")
- logger.info(f"Input Tokens: {num_in}")
- logger.info(f"Output Tokens: {num_out}")
- logger.info(f"Cost for call: ${cost_per_call}")
+ logger.debug(
+ f"Model: {modelname}, in={num_in}, out={num_out}, "
+ f"cost=${cost_per_call}"
+ )
@staticmethod
def add_args(parser):
diff --git a/trustgraph-flow/trustgraph/prompt/template/service.py b/trustgraph-flow/trustgraph/prompt/template/service.py
index 97298e13..c599ce77 100755
--- a/trustgraph-flow/trustgraph/prompt/template/service.py
+++ b/trustgraph-flow/trustgraph/prompt/template/service.py
@@ -11,7 +11,6 @@ import logging
from ...schema import Definition, Relationship, Triple
from ...schema import Topic
from ...schema import PromptRequest, PromptResponse, Error
-from ...schema import TextCompletionRequest, TextCompletionResponse
from ...base import FlowProcessor
from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
@@ -124,35 +123,26 @@ class Processor(FlowProcessor):
logger.debug(f"System prompt: {system}")
logger.debug(f"User prompt: {prompt}")
- # Use the text completion client with recipient handler
- client = flow("text-completion-request")
-
async def forward_chunks(resp):
- if resp.error:
- raise RuntimeError(resp.error.message)
-
is_final = getattr(resp, 'end_of_stream', False)
# Always send a message if there's content OR if it's the final message
if resp.response or is_final:
- # Forward each chunk immediately
r = PromptResponse(
text=resp.response if resp.response else "",
object=None,
error=None,
end_of_stream=is_final,
+ in_token=resp.in_token,
+ out_token=resp.out_token,
+ model=resp.model,
)
await flow("response").send(r, properties={"id": id})
- # Return True when end_of_stream
- return is_final
-
- await client.request(
- TextCompletionRequest(
- system=system, prompt=prompt, streaming=True
- ),
- recipient=forward_chunks,
- timeout=600
+ await flow("text-completion-request").text_completion_stream(
+ system=system, prompt=prompt,
+ handler=forward_chunks,
+ timeout=600,
)
# Return empty string since we already sent all chunks
@@ -167,17 +157,21 @@ class Processor(FlowProcessor):
return
# Non-streaming path (original behavior)
+ usage = {}
+
async def llm(system, prompt):
logger.debug(f"System prompt: {system}")
logger.debug(f"User prompt: {prompt}")
- resp = await flow("text-completion-request").text_completion(
- system = system, prompt = prompt, streaming = False,
- )
-
try:
- return resp
+ result = await flow("text-completion-request").text_completion(
+ system = system, prompt = prompt,
+ )
+ usage["in_token"] = result.in_token
+ usage["out_token"] = result.out_token
+ usage["model"] = result.model
+ return result.text
except Exception as e:
logger.error(f"LLM Exception: {e}", exc_info=True)
return None
@@ -199,6 +193,9 @@ class Processor(FlowProcessor):
object=None,
error=None,
end_of_stream=True,
+ in_token=usage.get("in_token", 0),
+ out_token=usage.get("out_token", 0),
+ model=usage.get("model", ""),
)
await flow("response").send(r, properties={"id": id})
@@ -215,6 +212,9 @@ class Processor(FlowProcessor):
object=json.dumps(resp),
error=None,
end_of_stream=True,
+ in_token=usage.get("in_token", 0),
+ out_token=usage.get("out_token", 0),
+ model=usage.get("model", ""),
)
await flow("response").send(r, properties={"id": id})
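
Note: the streaming path now hands a handler to text_completion_stream()
instead of driving the raw request() loop. A condensed handler sketch,
using only the fields this hunk forwards (response, end_of_stream,
in_token, out_token, model):

    async def print_chunks(resp):
        final = getattr(resp, "end_of_stream", False)
        if resp.response:
            print(resp.response, end="", flush=True)
        if final:
            print(f"\n[{resp.model}: {resp.in_token} in, "
                  f"{resp.out_token} out]")

    # await flow("text-completion-request").text_completion_stream(
    #     system=system, prompt=prompt, handler=print_chunks, timeout=600)
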
diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py
index f928a911..019d5610 100644
--- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py
+++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py
@@ -23,6 +23,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError
from .... schema import Error, RowSchema, Field as SchemaField
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
+from .... tables.cassandra_async import async_execute
from ... graphql import GraphQLSchemaBuilder, SortDirection
@@ -263,7 +264,7 @@ class Processor(FlowProcessor):
query += f" LIMIT {limit}"
try:
- rows = self.session.execute(query, params)
+ rows = await async_execute(self.session, query, params)
for row in rows:
# Convert data map to dict with proper field names
row_dict = dict(row.data) if row.data else {}
@@ -301,7 +302,7 @@ class Processor(FlowProcessor):
params = [collection, schema_name, primary_index]
try:
- rows = self.session.execute(query, params)
+ rows = await async_execute(self.session, query, params)
for row in rows:
row_dict = dict(row.data) if row.data else {}
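
Note: async_execute comes from the new tables.cassandra_async module,
which this diff does not include. A plausible single-page
implementation, bridging cassandra-driver's ResponseFuture callbacks
(which fire on a driver thread) onto the running loop:

    import asyncio

    async def async_execute(session, query, params=None):
        # Sketch only: returns the first page, which is sufficient for
        # the bounded LIMIT queries above; multi-page results would
        # need start_fetching_next_page handling.
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        rf = session.execute_async(query, params)
        rf.add_callbacks(
            callback=lambda rows: loop.call_soon_threadsafe(
                fut.set_result, rows),
            errback=lambda exc: loop.call_soon_threadsafe(
                fut.set_exception, exc),
        )
        return await fut
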
diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py
index f1f5ba60..905aaaf2 100755
--- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py
+++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py
@@ -4,6 +4,7 @@ Triples query service. Input is a (s, p, o, g) quad pattern, some values may be
null. Output is a list of quads.
"""
+import asyncio
import logging
import json
@@ -200,7 +201,11 @@ class Processor(TriplesQueryService):
try:
- self.ensure_connection(query.user)
+ # ensure_connection may construct a fresh
+ # EntityCentricKnowledgeGraph, which performs synchronous
+ # schema setup against Cassandra. Push it to a worker thread
+ # so the event loop doesn't block on each user's first use.
+ await asyncio.to_thread(self.ensure_connection, query.user)
# Extract values from query
s_val = get_term_value(query.s)
@@ -218,14 +223,21 @@ class Processor(TriplesQueryService):
quads = []
+ # All self.tg.get_* calls below are sync wrappers around
+ # cassandra session.execute. Materialise the results in a
+ # worker thread so iterating them never triggers sync paging
+ # back on the event loop.
+
# Route to appropriate query method based on which fields are specified
if s_val is not None:
if p_val is not None:
if o_val is not None:
# SPO specified - find matching graphs
- resp = self.tg.get_spo(
- query.collection, s_val, p_val, o_val, g=g_val,
- limit=query.limit
+ resp = await asyncio.to_thread(
+ lambda: list(self.tg.get_spo(
+ query.collection, s_val, p_val, o_val,
+ g=g_val, limit=query.limit,
+ ))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@@ -233,9 +245,11 @@ class Processor(TriplesQueryService):
quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
else:
# SP specified
- resp = self.tg.get_sp(
- query.collection, s_val, p_val, g=g_val,
- limit=query.limit
+ resp = await asyncio.to_thread(
+ lambda: list(self.tg.get_sp(
+ query.collection, s_val, p_val,
+ g=g_val, limit=query.limit,
+ ))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@@ -244,9 +258,11 @@ class Processor(TriplesQueryService):
else:
if o_val is not None:
# SO specified
- resp = self.tg.get_os(
- query.collection, o_val, s_val, g=g_val,
- limit=query.limit
+ resp = await asyncio.to_thread(
+ lambda: list(self.tg.get_os(
+ query.collection, o_val, s_val,
+ g=g_val, limit=query.limit,
+ ))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@@ -254,9 +270,11 @@ class Processor(TriplesQueryService):
quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
else:
# S only
- resp = self.tg.get_s(
- query.collection, s_val, g=g_val,
- limit=query.limit
+ resp = await asyncio.to_thread(
+ lambda: list(self.tg.get_s(
+ query.collection, s_val,
+ g=g_val, limit=query.limit,
+ ))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@@ -266,9 +284,11 @@ class Processor(TriplesQueryService):
if p_val is not None:
if o_val is not None:
# PO specified
- resp = self.tg.get_po(
- query.collection, p_val, o_val, g=g_val,
- limit=query.limit
+ resp = await asyncio.to_thread(
+ lambda: list(self.tg.get_po(
+ query.collection, p_val, o_val,
+ g=g_val, limit=query.limit,
+ ))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@@ -276,9 +296,11 @@ class Processor(TriplesQueryService):
quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
else:
# P only
- resp = self.tg.get_p(
- query.collection, p_val, g=g_val,
- limit=query.limit
+ resp = await asyncio.to_thread(
+ lambda: list(self.tg.get_p(
+ query.collection, p_val,
+ g=g_val, limit=query.limit,
+ ))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@@ -287,9 +309,11 @@ class Processor(TriplesQueryService):
else:
if o_val is not None:
# O only
- resp = self.tg.get_o(
- query.collection, o_val, g=g_val,
- limit=query.limit
+ resp = await asyncio.to_thread(
+ lambda: list(self.tg.get_o(
+ query.collection, o_val,
+ g=g_val, limit=query.limit,
+ ))
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
@@ -297,9 +321,10 @@ class Processor(TriplesQueryService):
quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
else:
# Nothing specified - get all
- resp = self.tg.get_all(
- query.collection,
- limit=query.limit
+ resp = await asyncio.to_thread(
+ lambda: list(self.tg.get_all(
+ query.collection, limit=query.limit,
+ ))
)
for t in resp:
# Note: quads_by_collection uses 'd' for graph field
@@ -340,7 +365,7 @@ class Processor(TriplesQueryService):
Uses Cassandra's paging to fetch results incrementally.
"""
try:
- self.ensure_connection(query.user)
+ await asyncio.to_thread(self.ensure_connection, query.user)
batch_size = query.batch_size if query.batch_size > 0 else 20
limit = query.limit if query.limit > 0 else 10000
@@ -374,9 +399,16 @@ class Processor(TriplesQueryService):
yield batch, is_final
return
- # Create statement with fetch_size for true streaming
+ # Materialise in a worker thread. We lose true streaming
+ # paging (the driver fetches all pages eagerly inside the
+ # thread) but the event loop stays responsive, and result
+ # sets at this layer are typically small enough that this
+ # is acceptable. If true async paging is needed later,
+ # revisit using ResponseFuture page callbacks.
statement = SimpleStatement(cql, fetch_size=batch_size)
- result_set = self.tg.session.execute(statement, params)
+ result_set = await asyncio.to_thread(
+ lambda: list(self.tg.session.execute(statement, params))
+ )
batch = []
count = 0
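
Note: the comment above defers "true async paging" to later. For
reference, a sketch of that alternative using cassandra-driver's page
callbacks (the same registered callback fires once per page after each
start_fetching_next_page call):

    import asyncio

    async def iter_pages(session, statement, params):
        loop = asyncio.get_running_loop()
        pages = asyncio.Queue()
        rf = session.execute_async(statement, params)

        def on_page(rows):
            # Runs on a driver thread; hop back onto the loop
            loop.call_soon_threadsafe(
                pages.put_nowait, ("page", list(rows)))
            if rf.has_more_pages:
                rf.start_fetching_next_page()
            else:
                loop.call_soon_threadsafe(
                    pages.put_nowait, ("done", None))

        def on_err(exc):
            loop.call_soon_threadsafe(pages.put_nowait, ("error", exc))

        rf.add_callbacks(on_page, on_err)
        while True:
            kind, payload = await pages.get()
            if kind == "page":
                yield payload
            elif kind == "error":
                raise payload
            else:
                return
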
diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py
index 730a7226..625b1386 100644
--- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py
+++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py
@@ -27,24 +27,27 @@ class Query:
def __init__(
self, rag, user, collection, verbose,
- doc_limit=20
+ doc_limit=20, track_usage=None,
):
self.rag = rag
self.user = user
self.collection = collection
self.verbose = verbose
self.doc_limit = doc_limit
+ self.track_usage = track_usage
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
- response = await self.rag.prompt_client.prompt(
+ result = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
+ if self.track_usage:
+ self.track_usage(result)
concepts = []
- if isinstance(response, str):
- for line in response.strip().split('\n'):
+ if result.text:
+ for line in result.text.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
@@ -53,6 +56,8 @@ class Query:
if not concepts:
concepts = [query]
+ self.concepts_usage = result
+
if self.verbose:
logger.debug(f"Extracted concepts: {concepts}")
@@ -167,8 +172,23 @@ class DocumentRag:
save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
Returns:
- str: The synthesized answer text
+ tuple: (answer_text, usage) where usage is a dict with
+ in_token, out_token, model
"""
+ total_in = 0
+ total_out = 0
+ last_model = None
+
+ def track_usage(result):
+ nonlocal total_in, total_out, last_model
+ if result is not None:
+ if result.in_token is not None:
+ total_in += result.in_token
+ if result.out_token is not None:
+ total_out += result.out_token
+ if result.model is not None:
+ last_model = result.model
+
if self.verbose:
logger.debug("Constructing prompt...")
@@ -191,7 +211,7 @@ class DocumentRag:
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
- doc_limit=doc_limit
+ doc_limit=doc_limit, track_usage=track_usage,
)
# Extract concepts from query (grounding step)
@@ -199,8 +219,14 @@ class DocumentRag:
# Emit grounding explainability after concept extraction
if explain_callback:
+ cu = getattr(q, 'concepts_usage', None)
gnd_triples = set_graph(
- grounding_triples(gnd_uri, q_uri, concepts),
+ grounding_triples(
+ gnd_uri, q_uri, concepts,
+ in_token=cu.in_token if cu else None,
+ out_token=cu.out_token if cu else None,
+ model=cu.model if cu else None,
+ ),
GRAPH_RETRIEVAL
)
await explain_callback(gnd_triples, gnd_uri)
@@ -228,19 +254,22 @@ class DocumentRag:
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
- resp = await self.prompt_client.document_prompt(
+ synthesis_result = await self.prompt_client.document_prompt(
query=query,
documents=docs,
streaming=True,
chunk_callback=accumulating_callback
)
+ track_usage(synthesis_result)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
- resp = await self.prompt_client.document_prompt(
+ synthesis_result = await self.prompt_client.document_prompt(
query=query,
documents=docs
)
+ track_usage(synthesis_result)
+ resp = synthesis_result.text
if self.verbose:
logger.debug("Query processing complete")
@@ -265,6 +294,9 @@ class DocumentRag:
docrag_synthesis_triples(
syn_uri, exp_uri,
document_id=synthesis_doc_id,
+ in_token=synthesis_result.in_token if synthesis_result else None,
+ out_token=synthesis_result.out_token if synthesis_result else None,
+ model=synthesis_result.model if synthesis_result else None,
),
GRAPH_RETRIEVAL
)
@@ -273,5 +305,11 @@ class DocumentRag:
if self.verbose:
logger.debug(f"Emitted explain for session {session_id}")
- return resp
+ usage = {
+ "in_token": total_in if total_in > 0 else None,
+ "out_token": total_out if total_out > 0 else None,
+ "model": last_model,
+ }
+
+ return resp, usage
diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py
index 3b281fe3..dc7296ad 100755
--- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py
+++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py
@@ -200,7 +200,7 @@ class Processor(FlowProcessor):
# Query with streaming enabled
# All chunks (including final one with end_of_stream=True) are sent via callback
- await self.rag.query(
+ response, usage = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
@@ -217,12 +217,15 @@ class Processor(FlowProcessor):
response=None,
end_of_session=True,
message_type="end",
+ in_token=usage.get("in_token"),
+ out_token=usage.get("out_token"),
+ model=usage.get("model"),
),
properties={"id": id}
)
else:
- # Non-streaming path (existing behavior)
- response = await self.rag.query(
+ # Non-streaming path - single response with answer and token usage
+ response, usage = await self.rag.query(
v.query,
user=v.user,
collection=v.collection,
@@ -233,11 +236,15 @@ class Processor(FlowProcessor):
await flow("response").send(
DocumentRagResponse(
- response = response,
- end_of_stream = True,
- error = None
+ response=response,
+ end_of_stream=True,
+ end_of_session=True,
+ error=None,
+ in_token=usage.get("in_token"),
+ out_token=usage.get("out_token"),
+ model=usage.get("model"),
),
- properties = {"id": id}
+ properties={"id": id}
)
logger.info("Request processing complete")
diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
index 5cf7b991..cf9f5c4e 100644
--- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
+++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py
@@ -121,7 +121,7 @@ class Query:
def __init__(
self, rag, user, collection, verbose,
entity_limit=50, triple_limit=30, max_subgraph_size=1000,
- max_path_length=2,
+ max_path_length=2, track_usage=None,
):
self.rag = rag
self.user = user
@@ -131,17 +131,20 @@ class Query:
self.triple_limit = triple_limit
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
+ self.track_usage = track_usage
async def extract_concepts(self, query):
"""Extract key concepts from query for independent embedding."""
- response = await self.rag.prompt_client.prompt(
+ result = await self.rag.prompt_client.prompt(
"extract-concepts",
variables={"query": query}
)
+ if self.track_usage:
+ self.track_usage(result)
concepts = []
- if isinstance(response, str):
- for line in response.strip().split('\n'):
+ if result.text:
+ for line in result.text.strip().split('\n'):
line = line.strip()
if line:
concepts.append(line)
@@ -149,6 +152,8 @@ class Query:
if self.verbose:
logger.debug(f"Extracted concepts: {concepts}")
+ self.concepts_usage = result
+
# Fall back to raw query if extraction returns nothing
return concepts if concepts else [query]
@@ -609,8 +614,24 @@ class GraphRag:
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
Returns:
- str: The synthesized answer text
+ tuple: (answer_text, usage) where usage is a dict with
+ in_token, out_token, model
"""
+ # Accumulate token usage across all prompt calls
+ total_in = 0
+ total_out = 0
+ last_model = None
+
+ def track_usage(result):
+ nonlocal total_in, total_out, last_model
+ if result is not None:
+ if result.in_token is not None:
+ total_in += result.in_token
+ if result.out_token is not None:
+ total_out += result.out_token
+ if result.model is not None:
+ last_model = result.model
+
if self.verbose:
logger.debug("Constructing prompt...")
@@ -641,14 +662,21 @@ class GraphRag:
triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
+ track_usage = track_usage,
)
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
# Emit grounding explain after concept extraction
if explain_callback:
+ cu = getattr(q, 'concepts_usage', None)
gnd_triples = set_graph(
- grounding_triples(gnd_uri, q_uri, concepts),
+ grounding_triples(
+ gnd_uri, q_uri, concepts,
+ in_token=cu.in_token if cu else None,
+ out_token=cu.out_token if cu else None,
+ model=cu.model if cu else None,
+ ),
GRAPH_RETRIEVAL
)
await explain_callback(gnd_triples, gnd_uri)
@@ -751,21 +779,22 @@ class GraphRag:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1a: Edge Scoring - LLM scores edges for relevance
- scoring_response = await self.prompt_client.prompt(
+ scoring_result = await self.prompt_client.prompt(
"kg-edge-scoring",
variables={
"query": query,
"knowledge": edges_with_ids
}
)
+ track_usage(scoring_result)
if self.verbose:
- logger.debug(f"Edge scoring response: {scoring_response}")
+ logger.debug(f"Edge scoring result: {scoring_result}")
- # Parse scoring response to get edge IDs with scores
+ # Parse scoring response (jsonl) to get edge IDs with scores
scored_edges = []
- def parse_scored_edge(obj):
+ for obj in scoring_result.objects or []:
if isinstance(obj, dict) and "id" in obj and "score" in obj:
try:
score = int(obj["score"])
@@ -773,21 +802,6 @@ class GraphRag:
score = 0
scored_edges.append({"id": obj["id"], "score": score})
- if isinstance(scoring_response, list):
- for obj in scoring_response:
- parse_scored_edge(obj)
- elif isinstance(scoring_response, str):
- for line in scoring_response.strip().split('\n'):
- line = line.strip()
- if not line:
- continue
- try:
- parse_scored_edge(json.loads(line))
- except json.JSONDecodeError:
- logger.warning(
- f"Failed to parse edge scoring line: {line}"
- )
-
# Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit]
@@ -821,25 +835,30 @@ class GraphRag:
]
# Run reasoning and document tracing concurrently
- reasoning_task = self.prompt_client.prompt(
- "kg-edge-reasoning",
- variables={
- "query": query,
- "knowledge": selected_edges_with_ids
- }
- )
+ async def _get_reasoning():
+ result = await self.prompt_client.prompt(
+ "kg-edge-reasoning",
+ variables={
+ "query": query,
+ "knowledge": selected_edges_with_ids
+ }
+ )
+ track_usage(result)
+ return result
+
+ reasoning_task = _get_reasoning()
doc_trace_task = q.trace_source_documents(selected_edge_uris)
- reasoning_response, source_documents = await asyncio.gather(
+ reasoning_result, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True
)
# Handle exceptions from gather
- if isinstance(reasoning_response, Exception):
+ if isinstance(reasoning_result, Exception):
logger.warning(
- f"Edge reasoning failed: {reasoning_response}"
+ f"Edge reasoning failed: {reasoning_result}"
)
- reasoning_response = ""
+ reasoning_result = None
if isinstance(source_documents, Exception):
logger.warning(
f"Document tracing failed: {source_documents}"
@@ -848,29 +867,15 @@ class GraphRag:
if self.verbose:
- logger.debug(f"Edge reasoning response: {reasoning_response}")
+ logger.debug(f"Edge reasoning result: {reasoning_result}")
- # Parse reasoning response and build explainability data
+ # Parse reasoning response (jsonl) and build explainability data
reasoning_map = {}
- def parse_reasoning(obj):
- if isinstance(obj, dict) and "id" in obj:
- reasoning_map[obj["id"]] = obj.get("reasoning", "")
-
- if isinstance(reasoning_response, list):
- for obj in reasoning_response:
- parse_reasoning(obj)
- elif isinstance(reasoning_response, str):
- for line in reasoning_response.strip().split('\n'):
- line = line.strip()
- if not line:
- continue
- try:
- parse_reasoning(json.loads(line))
- except json.JSONDecodeError:
- logger.warning(
- f"Failed to parse edge reasoning line: {line}"
- )
+ if reasoning_result is not None:
+ for obj in reasoning_result.objects or []:
+ if isinstance(obj, dict) and "id" in obj:
+ reasoning_map[obj["id"]] = obj.get("reasoning", "")
selected_edges_with_reasoning = []
for eid in selected_ids:
@@ -886,9 +891,25 @@ class GraphRag:
# Emit focus explain after edge selection completes
if explain_callback:
+ # Sum scoring + reasoning token usage for focus event
+ focus_in = 0
+ focus_out = 0
+ focus_model = None
+ for r in [scoring_result, reasoning_result]:
+ if r is not None:
+ if r.in_token is not None:
+ focus_in += r.in_token
+ if r.out_token is not None:
+ focus_out += r.out_token
+ if r.model is not None:
+ focus_model = r.model
+
foc_triples = set_graph(
focus_triples(
- foc_uri, exp_uri, selected_edges_with_reasoning, session_id
+ foc_uri, exp_uri, selected_edges_with_reasoning, session_id,
+ in_token=focus_in or None,
+ out_token=focus_out or None,
+ model=focus_model,
),
GRAPH_RETRIEVAL
)
@@ -919,19 +940,22 @@ class GraphRag:
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
- await self.prompt_client.prompt(
+ synthesis_result = await self.prompt_client.prompt(
"kg-synthesis",
variables=synthesis_variables,
streaming=True,
chunk_callback=accumulating_callback
)
+ track_usage(synthesis_result)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
- resp = await self.prompt_client.prompt(
+ synthesis_result = await self.prompt_client.prompt(
"kg-synthesis",
variables=synthesis_variables,
)
+ track_usage(synthesis_result)
+ resp = synthesis_result.text
if self.verbose:
logger.debug("Query processing complete")
@@ -956,6 +980,9 @@ class GraphRag:
synthesis_triples(
syn_uri, foc_uri,
document_id=synthesis_doc_id,
+ in_token=synthesis_result.in_token if synthesis_result else None,
+ out_token=synthesis_result.out_token if synthesis_result else None,
+ model=synthesis_result.model if synthesis_result else None,
),
GRAPH_RETRIEVAL
)
@@ -964,5 +991,11 @@ class GraphRag:
if self.verbose:
logger.debug(f"Emitted explain for session {session_id}")
- return resp
+ usage = {
+ "in_token": total_in if total_in > 0 else None,
+ "out_token": total_out if total_out > 0 else None,
+ "model": last_model,
+ }
+
+ return resp, usage
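
For orientation before the caller-side changes below: graph_rag.py now threads token accounting through the scoring, reasoning, and synthesis calls, and query() returns a (text, usage) pair instead of a bare string. A minimal sketch of the accumulator this implies (the real track_usage helper is defined earlier in this diff; the names and shapes below are illustrative only):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class PromptResult:
        # stands in for the prompt client's result object
        text: str = ""
        in_token: Optional[int] = None
        out_token: Optional[int] = None
        model: Optional[str] = None

    class UsageTracker:
        """Sums token counts across several LLM calls; keeps the last model."""
        def __init__(self):
            self.total_in, self.total_out, self.last_model = 0, 0, None

        def track(self, result):
            if result is None:
                return
            if result.in_token is not None:
                self.total_in += result.in_token
            if result.out_token is not None:
                self.total_out += result.out_token
            if result.model is not None:
                self.last_model = result.model

        def summary(self):
            # zero counts collapse to None, matching the usage dict above
            return {
                "in_token": self.total_in or None,
                "out_token": self.total_out or None,
                "model": self.last_model,
            }
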
diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py
index abf10e90..15c30ba1 100755
--- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py
+++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py
@@ -332,7 +332,7 @@ class Processor(FlowProcessor):
)
# Query with streaming and real-time explain
- response = await rag.query(
+ response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
@@ -348,7 +348,7 @@ class Processor(FlowProcessor):
else:
# Non-streaming path with real-time explain
- response = await rag.query(
+ response, usage = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
@@ -360,23 +360,30 @@ class Processor(FlowProcessor):
parent_uri = v.parent_uri,
)
- # Send chunk with response
+ # Send a single response carrying the answer and token usage
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response=response,
end_of_stream=True,
- error=None,
+ end_of_session=True,
+ in_token=usage.get("in_token"),
+ out_token=usage.get("out_token"),
+ model=usage.get("model"),
),
properties={"id": id}
)
+ return
- # Send final message to close session
+ # Streaming: send final message to close session with token usage
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response="",
end_of_session=True,
+ in_token=usage.get("in_token"),
+ out_token=usage.get("out_token"),
+ model=usage.get("model"),
),
properties={"id": id}
)
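
The wire protocol change above is worth spelling out: in the non-streaming path a single chunk now carries the full answer plus usage and closes the session; in the streaming path the usage rides on the final empty chunk. A hypothetical consumer (field names taken from GraphRagResponse as it appears in this diff; the loop itself is illustrative, not part of the change):

    async def consume(responses):
        # responses: async iterator of GraphRagResponse messages
        parts, usage = [], None
        async for msg in responses:
            if msg.response:
                parts.append(msg.response)
            if msg.end_of_session:
                usage = {
                    "in_token": msg.in_token,
                    "out_token": msg.out_token,
                    "model": msg.model,
                }
                break
        return "".join(parts), usage
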
diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
index 673cba4d..d0eec2e1 100755
--- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
+++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py
@@ -13,6 +13,7 @@ Uses a single 'rows' table with the schema:
Each row is written multiple times - once per indexed field defined in the schema.
"""
+import asyncio
import json
import logging
import re
@@ -26,6 +27,7 @@ from .... schema import RowSchema, Field
from .... base import FlowProcessor, ConsumerSpec
from .... base import CollectionConfigHandler
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
+from .... tables.cassandra_async import async_execute
# Module logger
logger = logging.getLogger(__name__)
@@ -361,11 +363,15 @@ class Processor(CollectionConfigHandler, FlowProcessor):
schema_name = obj.schema_name
source = getattr(obj.metadata, 'source', '') or ''
- # Ensure tables exist
- self.ensure_tables(keyspace)
+ # Ensure tables exist (sync DDL — push to a worker thread
+ # so the event loop stays responsive when running in a
+ # processor group sharing the loop with siblings).
+ await asyncio.to_thread(self.ensure_tables, keyspace)
# Register partitions if first time seeing this (collection, schema_name)
- self.register_partitions(keyspace, collection, schema_name)
+ await asyncio.to_thread(
+ self.register_partitions, keyspace, collection, schema_name
+ )
safe_keyspace = self.sanitize_name(keyspace)
@@ -406,9 +412,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
continue
try:
- self.session.execute(
+ await async_execute(
+ self.session,
insert_cql,
- (collection, schema_name, index_name, index_value, data_map, source)
+ (collection, schema_name, index_name, index_value, data_map, source),
)
rows_written += 1
except Exception as e:
@@ -425,18 +432,18 @@ class Processor(CollectionConfigHandler, FlowProcessor):
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Create/verify collection exists in Cassandra row store"""
- # Connect if not already connected
- self.connect_cassandra()
+ # Connect if not already connected (sync, push to thread)
+ await asyncio.to_thread(self.connect_cassandra)
- # Ensure tables exist
- self.ensure_tables(user)
+ # Ensure tables exist (sync DDL, push to thread)
+ await asyncio.to_thread(self.ensure_tables, user)
logger.info(f"Collection {collection} ready for user {user}")
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection using partition tracking"""
# Connect if not already connected
- self.connect_cassandra()
+ await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user)
@@ -446,8 +453,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
SELECT keyspace_name FROM system_schema.keyspaces
WHERE keyspace_name = %s
"""
- result = self.session.execute(check_keyspace_cql, (safe_keyspace,))
- if not result.one():
+ result = await async_execute(
+ self.session, check_keyspace_cql, (safe_keyspace,)
+ )
+ if not result:
logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete")
return
self.known_keyspaces.add(user)
@@ -459,8 +468,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"""
try:
- partitions = self.session.execute(select_partitions_cql, (collection,))
- partition_list = list(partitions)
+ partition_list = await async_execute(
+ self.session, select_partitions_cql, (collection,)
+ )
except Exception as e:
logger.error(f"Failed to query partitions for collection {collection}: {e}")
raise
@@ -474,9 +484,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
partitions_deleted = 0
for partition in partition_list:
try:
- self.session.execute(
+ await async_execute(
+ self.session,
delete_rows_cql,
- (collection, partition.schema_name, partition.index_name)
+ (collection, partition.schema_name, partition.index_name),
)
partitions_deleted += 1
except Exception as e:
@@ -493,7 +504,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"""
try:
- self.session.execute(delete_partitions_cql, (collection,))
+ await async_execute(
+ self.session, delete_partitions_cql, (collection,)
+ )
except Exception as e:
logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
raise
@@ -512,7 +525,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
"""Delete all data for a specific collection + schema combination"""
# Connect if not already connected
- self.connect_cassandra()
+ await asyncio.to_thread(self.connect_cassandra)
safe_keyspace = self.sanitize_name(user)
@@ -523,8 +536,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"""
try:
- partitions = self.session.execute(select_partitions_cql, (collection, schema_name))
- partition_list = list(partitions)
+ partition_list = await async_execute(
+ self.session, select_partitions_cql, (collection, schema_name)
+ )
except Exception as e:
logger.error(
f"Failed to query partitions for {collection}/{schema_name}: {e}"
@@ -540,9 +554,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
partitions_deleted = 0
for partition in partition_list:
try:
- self.session.execute(
+ await async_execute(
+ self.session,
delete_rows_cql,
- (collection, schema_name, partition.index_name)
+ (collection, schema_name, partition.index_name),
)
partitions_deleted += 1
except Exception as e:
@@ -559,7 +574,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
"""
try:
- self.session.execute(delete_partitions_cql, (collection, schema_name))
+ await async_execute(
+ self.session,
+ delete_partitions_cql,
+ (collection, schema_name),
+ )
except Exception as e:
logger.error(
f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"
diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py
index 2a240f0b..01d95c8b 100755
--- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py
+++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py
@@ -3,6 +3,7 @@
Graph writer. Input is graph edge. Writes edges to Cassandra graph.
"""
+import asyncio
import base64
import os
import argparse
@@ -150,59 +151,71 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
user = message.metadata.user
- if self.table is None or self.table != user:
+ # The cassandra-driver work below — connection, schema
+ # setup, and per-triple inserts — is all synchronous.
+ # Wrap the whole batch in a worker thread so the event
+ # loop stays responsive for sibling processors when
+ # running in a processor group.
- self.tg = None
+ def _do_store():
- # Use factory function to select implementation
- KGClass = EntityCentricKnowledgeGraph
+ if self.table is None or self.table != user:
- try:
- if self.cassandra_username and self.cassandra_password:
- self.tg = KGClass(
- hosts=self.cassandra_host,
- keyspace=message.metadata.user,
- username=self.cassandra_username, password=self.cassandra_password
- )
- else:
- self.tg = KGClass(
- hosts=self.cassandra_host,
- keyspace=message.metadata.user,
- )
- except Exception as e:
- logger.error(f"Exception: {e}", exc_info=True)
- time.sleep(1)
- raise e
+ self.tg = None
- self.table = user
+ # Use factory function to select implementation
+ KGClass = EntityCentricKnowledgeGraph
- for t in message.triples:
- # Extract values from Term objects
- s_val = get_term_value(t.s)
- p_val = get_term_value(t.p)
- o_val = get_term_value(t.o)
- # t.g is None for default graph, or a graph IRI
- g_val = t.g if t.g is not None else DEFAULT_GRAPH
+ try:
+ if self.cassandra_username and self.cassandra_password:
+ self.tg = KGClass(
+ hosts=self.cassandra_host,
+ keyspace=message.metadata.user,
+ username=self.cassandra_username,
+ password=self.cassandra_password,
+ )
+ else:
+ self.tg = KGClass(
+ hosts=self.cassandra_host,
+ keyspace=message.metadata.user,
+ )
+ except Exception as e:
+ logger.error(f"Exception: {e}", exc_info=True)
+ time.sleep(1)
+ raise e
- # Extract object type metadata for entity-centric storage
- otype = get_term_otype(t.o)
- dtype = get_term_dtype(t.o)
- lang = get_term_lang(t.o)
+ self.table = user
- self.tg.insert(
- message.metadata.collection,
- s_val,
- p_val,
- o_val,
- g=g_val,
- otype=otype,
- dtype=dtype,
- lang=lang
- )
+ for t in message.triples:
+ # Extract values from Term objects
+ s_val = get_term_value(t.s)
+ p_val = get_term_value(t.p)
+ o_val = get_term_value(t.o)
+ # t.g is None for default graph, or a graph IRI
+ g_val = t.g if t.g is not None else DEFAULT_GRAPH
+
+ # Extract object type metadata for entity-centric storage
+ otype = get_term_otype(t.o)
+ dtype = get_term_dtype(t.o)
+ lang = get_term_lang(t.o)
+
+ self.tg.insert(
+ message.metadata.collection,
+ s_val,
+ p_val,
+ o_val,
+ g=g_val,
+ otype=otype,
+ dtype=dtype,
+ lang=lang,
+ )
+
+ await asyncio.to_thread(_do_store)
async def create_collection(self, user: str, collection: str, metadata: dict):
"""Create a collection in Cassandra triple store via config push"""
- try:
+
+ def _do_create():
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != user:
self.tg = None
@@ -216,7 +229,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
hosts=self.cassandra_host,
keyspace=user,
username=self.cassandra_username,
- password=self.cassandra_password
+ password=self.cassandra_password,
)
else:
self.tg = KGClass(
@@ -238,13 +251,16 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
self.tg.create_collection(collection)
logger.info(f"Created collection {collection}")
+ try:
+ await asyncio.to_thread(_do_create)
except Exception as e:
logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
raise
async def delete_collection(self, user: str, collection: str):
"""Delete all data for a specific collection from the unified triples table"""
- try:
+
+ def _do_delete():
# Create or reuse connection for this user's keyspace
if self.table is None or self.table != user:
self.tg = None
@@ -258,7 +274,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
hosts=self.cassandra_host,
keyspace=user,
username=self.cassandra_username,
- password=self.cassandra_password
+ password=self.cassandra_password,
)
else:
self.tg = KGClass(
@@ -275,6 +291,8 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
self.tg.delete_collection(collection)
logger.info(f"Deleted all triples for collection {collection} from keyspace {user}")
+ try:
+ await asyncio.to_thread(_do_delete)
except Exception as e:
logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
raise
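
Note the design choice in the triples writer above: the whole batch, connection checks included, goes through a single to_thread call rather than one hop per insert. A sketch of the shape (illustrative names only):

    import asyncio

    def write_batch(insert, items):
        # runs entirely on one worker thread; insert is a sync driver call
        for item in items:
            insert(item)

    async def store(insert, items):
        # one thread handoff for the batch, not len(items) handoffs
        await asyncio.to_thread(write_batch, insert, items)

This keeps per-triple overhead down at the cost of serialising the batch on a single worker thread, which matches the previous, fully synchronous throughput behaviour.
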
diff --git a/trustgraph-flow/trustgraph/tables/cassandra_async.py b/trustgraph-flow/trustgraph/tables/cassandra_async.py
new file mode 100644
index 00000000..2f497748
--- /dev/null
+++ b/trustgraph-flow/trustgraph/tables/cassandra_async.py
@@ -0,0 +1,78 @@
+"""
+Async wrapper for cassandra-driver sessions.
+
+The cassandra driver exposes a callback-based async API via
+session.execute_async, returning a ResponseFuture that fires
+on_result / on_error from the driver's own worker thread.
+This module bridges that into an awaitable interface.
+
+Usage:
+ from ..tables.cassandra_async import async_execute
+
+ rows = await async_execute(self.cassandra, stmt, (param1, param2))
+ for row in rows:
+ ...
+
+Notes:
+ - Rows are materialised into a list inside the driver callback
+ thread before the future is resolved, so subsequent iteration
+ in the caller never triggers a sync page-fetch on the asyncio
+ loop. This is safe for single-page results (the common case
+ in this codebase); if a query needs pagination, handle it
+ explicitly.
+ - Callbacks fire on a driver worker thread; call_soon_threadsafe
+ is used to hand the result back to the asyncio loop.
+ - Errors from the driver are re-raised in the awaiting coroutine.
+"""
+
+import asyncio
+
+
+async def async_execute(session, query, parameters=None):
+ """Execute a CQL statement asynchronously.
+
+ Args:
+ session: cassandra.cluster.Session (self.cassandra)
+ query: statement string or PreparedStatement
+ parameters: tuple/list of bind params, or None
+
+ Returns:
+ A list of rows (materialised from the first result page).
+ """
+
+ loop = asyncio.get_running_loop()
+ fut = loop.create_future()
+
+ def on_result(rows):
+ # Materialise on the driver thread so the loop thread
+ # never touches a lazy iterator that might trigger
+ # further sync I/O.
+ try:
+ materialised = list(rows) if rows is not None else []
+ except Exception as e:
+ loop.call_soon_threadsafe(
+ _set_exception_if_pending, fut, e
+ )
+ return
+ loop.call_soon_threadsafe(
+ _set_result_if_pending, fut, materialised
+ )
+
+ def on_error(exc):
+ loop.call_soon_threadsafe(
+ _set_exception_if_pending, fut, exc
+ )
+
+ rf = session.execute_async(query, parameters)
+ rf.add_callbacks(on_result, on_error)
+ return await fut
+
+
+def _set_result_if_pending(fut, result):
+ if not fut.done():
+ fut.set_result(result)
+
+
+def _set_exception_if_pending(fut, exc):
+ if not fut.done():
+ fut.set_exception(exc)
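
The new module is an instance of a general bridge pattern: a callback fired on a foreign thread hands its result back to the loop via call_soon_threadsafe. A self-contained demo with a plain thread standing in for the driver's worker (no Cassandra needed):

    import asyncio
    import threading

    async def bridged_call():
        loop = asyncio.get_running_loop()
        fut = loop.create_future()

        def worker():
            # simulates the driver thread firing on_result
            loop.call_soon_threadsafe(fut.set_result, ["row1", "row2"])

        threading.Thread(target=worker).start()
        return await fut

    print(asyncio.run(bridged_call()))   # ['row1', 'row2']
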
diff --git a/trustgraph-flow/trustgraph/tables/config.py b/trustgraph-flow/trustgraph/tables/config.py
index fb9ea0a7..d9a8711b 100644
--- a/trustgraph-flow/trustgraph/tables/config.py
+++ b/trustgraph-flow/trustgraph/tables/config.py
@@ -11,6 +11,8 @@ import time
import asyncio
import logging
+from . cassandra_async import async_execute
+
logger = logging.getLogger(__name__)
class ConfigTableStore:
@@ -102,21 +104,20 @@ class ConfigTableStore:
async def inc_version(self):
- self.cassandra.execute("""
+ await async_execute(self.cassandra, """
UPDATE version set version = version + 1
WHERE id = 'version'
""")
async def get_version(self):
- resp = self.cassandra.execute("""
+ rows = await async_execute(self.cassandra, """
SELECT version FROM version
WHERE id = 'version'
""")
- row = resp.one()
-
- if row: return row[0]
+ if rows:
+ return rows[0][0]
return None
@@ -153,150 +154,91 @@ class ConfigTableStore:
""")
async def put_config(self, cls, key, value):
-
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.put_config_stmt,
- ( cls, key, value )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.put_config_stmt,
+ (cls, key, value),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
async def get_value(self, cls, key):
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.get_value_stmt,
+ (cls, key),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.get_value_stmt,
- ( cls, key )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
-
- for row in resp:
+ for row in rows:
return row[0]
-
return None
async def get_values(self, cls):
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.get_values_stmt,
+ (cls,),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.get_values_stmt,
- ( cls, )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
-
- return [
- [row[0], row[1]]
- for row in resp
- ]
+ return [[row[0], row[1]] for row in rows]
async def get_classes(self):
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.get_classes_stmt,
+ (),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.get_classes_stmt,
- ()
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
-
- return [
- row[0] for row in resp
- ]
+ return [row[0] for row in rows]
async def get_all(self):
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.get_all_stmt,
+ (),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.get_all_stmt,
- ()
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
-
- return [
- (row[0], row[1], row[2])
- for row in resp
- ]
+ return [(row[0], row[1], row[2]) for row in rows]
async def get_keys(self, cls):
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.get_keys_stmt,
+ (cls,),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.get_keys_stmt,
- ( cls, )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
-
- return [
- row[0] for row in resp
- ]
+ return [row[0] for row in rows]
async def delete_key(self, cls, key):
-
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.delete_key_stmt,
- (cls, key)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.delete_key_stmt,
+ (cls, key),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
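
One observation on the ConfigTableStore rewrite, which applies equally to the knowledge.py and library.py hunks below: the removed while True loops never actually retried, since the except clause re-raised before a second iteration could run, so each loop body executed at most once. Replacing them with a plain try/except is therefore behaviour-preserving, not just tidier. And because async_execute materialises rows into a list, the for row in rows and bool(rows) patterns that replace cursor iteration are safe where resp.one() or length checks previously were not.
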
diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py
index 430dc3c9..b06f4862 100644
--- a/trustgraph-flow/trustgraph/tables/knowledge.py
+++ b/trustgraph-flow/trustgraph/tables/knowledge.py
@@ -4,6 +4,8 @@ from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
from cassandra.cluster import Cluster
+from . cassandra_async import async_execute
+
def term_to_tuple(term):
"""Convert Term to (value, is_uri) tuple for database storage."""
@@ -225,25 +227,19 @@ class KnowledgeTableStore:
for v in m.triples
]
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.insert_triples_stmt,
- (
- uuid.uuid4(), m.metadata.user,
- m.metadata.root or m.metadata.id, when,
- [], triples,
- )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.insert_triples_stmt,
+ (
+ uuid.uuid4(), m.metadata.user,
+ m.metadata.root or m.metadata.id, when,
+ [], triples,
+ ),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
async def add_graph_embeddings(self, m):
@@ -257,25 +253,19 @@ class KnowledgeTableStore:
for v in m.entities
]
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.insert_graph_embeddings_stmt,
- (
- uuid.uuid4(), m.metadata.user,
- m.metadata.root or m.metadata.id, when,
- [], entities,
- )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.insert_graph_embeddings_stmt,
+ (
+ uuid.uuid4(), m.metadata.user,
+ m.metadata.root or m.metadata.id, when,
+ [], entities,
+ ),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
async def add_document_embeddings(self, m):
@@ -289,50 +279,35 @@ class KnowledgeTableStore:
for v in m.chunks
]
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.insert_document_embeddings_stmt,
- (
- uuid.uuid4(), m.metadata.user,
- m.metadata.root or m.metadata.id, when,
- [], chunks,
- )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.insert_document_embeddings_stmt,
+ (
+ uuid.uuid4(), m.metadata.user,
+ m.metadata.root or m.metadata.id, when,
+ [], chunks,
+ ),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
async def list_kg_cores(self, user):
logger.debug("List kg cores...")
- while True:
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.list_cores_stmt,
+ (user,),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- try:
-
- resp = self.cassandra.execute(
- self.list_cores_stmt,
- (user,)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
-
-
- lst = [
- row[1]
- for row in resp
- ]
+ lst = [row[1] for row in rows]
logger.debug("Done")
@@ -342,56 +317,41 @@ class KnowledgeTableStore:
logger.debug("Delete kg cores...")
- while True:
+ try:
+ await async_execute(
+ self.cassandra,
+ self.delete_triples_stmt,
+ (user, document_id),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- try:
-
- resp = self.cassandra.execute(
- self.delete_triples_stmt,
- (user, document_id)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
-
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.delete_graph_embeddings_stmt,
- (user, document_id)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.delete_graph_embeddings_stmt,
+ (user, document_id),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
async def get_triples(self, user, document_id, receiver):
logger.debug("Get triples...")
- while True:
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.get_triples_stmt,
+ (user, document_id),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- try:
-
- resp = self.cassandra.execute(
- self.get_triples_stmt,
- (user, document_id)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
-
- for row in resp:
+ for row in rows:
if row[3]:
triples = [
@@ -422,28 +382,23 @@ class KnowledgeTableStore:
logger.debug("Get GE...")
- while True:
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.get_graph_embeddings_stmt,
+ (user, document_id),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- try:
-
- resp = self.cassandra.execute(
- self.get_graph_embeddings_stmt,
- (user, document_id)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
-
- for row in resp:
+ for row in rows:
if row[3]:
entities = [
EntityEmbeddings(
entity = tuple_to_term(ent[0][0], ent[0][1]),
- vectors = ent[1]
+ vector = ent[1]
)
for ent in row[3]
]
diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py
index 11dd9022..c85ae72a 100644
--- a/trustgraph-flow/trustgraph/tables/library.py
+++ b/trustgraph-flow/trustgraph/tables/library.py
@@ -31,6 +31,8 @@ import time
import asyncio
import logging
+from . cassandra_async import async_execute
+
logger = logging.getLogger(__name__)
class LibraryTableStore:
@@ -321,18 +323,13 @@ class LibraryTableStore:
async def document_exists(self, user, id):
- resp = self.cassandra.execute(
+ rows = await async_execute(
+ self.cassandra,
self.test_document_exists_stmt,
- ( user, id )
+ (user, id),
)
- # If a row exists, document exists. It's a cursor, can't just
- # count the length
-
- for row in resp:
- return True
-
- return False
+ return bool(rows)
async def add_document(self, document, object_id):
@@ -349,26 +346,20 @@ class LibraryTableStore:
parent_id = getattr(document, 'parent_id', '') or ''
document_type = getattr(document, 'document_type', 'source') or 'source'
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.insert_document_stmt,
- (
- document.id, document.user, int(document.time * 1000),
- document.kind, document.title, document.comments,
- metadata, document.tags, object_id,
- parent_id, document_type
- )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.insert_document_stmt,
+ (
+ document.id, document.user, int(document.time * 1000),
+ document.kind, document.title, document.comments,
+ metadata, document.tags, object_id,
+ parent_id, document_type
+ ),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
logger.debug("Add complete")
@@ -383,25 +374,19 @@ class LibraryTableStore:
for v in document.metadata
]
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.update_document_stmt,
- (
- int(document.time * 1000), document.title,
- document.comments, metadata, document.tags,
- document.user, document.id
- )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.update_document_stmt,
+ (
+ int(document.time * 1000), document.title,
+ document.comments, metadata, document.tags,
+ document.user, document.id
+ ),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
logger.debug("Update complete")
@@ -409,23 +394,15 @@ class LibraryTableStore:
logger.info(f"Removing document {document_id}")
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.delete_document_stmt,
- (
- user, document_id
- )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.delete_document_stmt,
+ (user, document_id),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
logger.debug("Delete complete")
@@ -433,21 +410,15 @@ class LibraryTableStore:
logger.debug("List documents...")
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.list_document_stmt,
- (user,)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
-
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.list_document_stmt,
+ (user,),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
lst = [
DocumentMetadata(
@@ -469,7 +440,7 @@ class LibraryTableStore:
parent_id = row[8] if row[8] else "",
document_type = row[9] if row[9] else "source",
)
- for row in resp
+ for row in rows
]
logger.debug("Done")
@@ -481,20 +452,15 @@ class LibraryTableStore:
logger.debug(f"List children for parent {parent_id}")
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.list_children_stmt,
- (parent_id,)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.list_children_stmt,
+ (parent_id,),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
lst = [
DocumentMetadata(
@@ -516,7 +482,7 @@ class LibraryTableStore:
parent_id = row[9] if row[9] else "",
document_type = row[10] if row[10] else "source",
)
- for row in resp
+ for row in rows
]
logger.debug("Done")
@@ -527,23 +493,17 @@ class LibraryTableStore:
logger.debug("Get document")
- while True:
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.get_document_stmt,
+ (user, id),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- try:
-
- resp = self.cassandra.execute(
- self.get_document_stmt,
- (user, id)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
-
-
- for row in resp:
+ for row in rows:
doc = DocumentMetadata(
id = id,
user = user,
@@ -573,23 +533,17 @@ class LibraryTableStore:
logger.debug("Get document obj ID")
- while True:
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.get_document_stmt,
+ (user, id),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- try:
-
- resp = self.cassandra.execute(
- self.get_document_stmt,
- (user, id)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
-
-
- for row in resp:
+ for row in rows:
logger.debug("Done")
return row[6]
@@ -597,43 +551,32 @@ class LibraryTableStore:
async def processing_exists(self, user, id):
- resp = self.cassandra.execute(
+ rows = await async_execute(
+ self.cassandra,
self.test_processing_exists_stmt,
- ( user, id )
+ (user, id),
)
- # If a row exists, document exists. It's a cursor, can't just
- # count the length
-
- for row in resp:
- return True
-
- return False
+ return bool(rows)
async def add_processing(self, processing):
logger.info(f"Adding processing {processing.id}")
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.insert_processing_stmt,
- (
- processing.id, processing.document_id,
- int(processing.time * 1000), processing.flow,
- processing.user, processing.collection,
- processing.tags
- )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.insert_processing_stmt,
+ (
+ processing.id, processing.document_id,
+ int(processing.time * 1000), processing.flow,
+ processing.user, processing.collection,
+ processing.tags
+ ),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
logger.debug("Add complete")
@@ -641,23 +584,15 @@ class LibraryTableStore:
logger.info(f"Removing processing {processing_id}")
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.delete_processing_stmt,
- (
- user, processing_id
- )
- )
-
- break
-
- except Exception as e:
-
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.delete_processing_stmt,
+ (user, processing_id),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
logger.debug("Delete complete")
@@ -665,21 +600,15 @@ class LibraryTableStore:
logger.debug("List processing objects")
- while True:
-
- try:
-
- resp = self.cassandra.execute(
- self.list_processing_stmt,
- (user,)
- )
-
- break
-
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
-
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.list_processing_stmt,
+ (user,),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
lst = [
ProcessingMetadata(
@@ -691,7 +620,7 @@ class LibraryTableStore:
collection = row[4],
tags = row[5] if row[5] else [],
)
- for row in resp
+ for row in rows
]
logger.debug("Done")
@@ -718,20 +647,19 @@ class LibraryTableStore:
now = int(time.time() * 1000)
- while True:
- try:
- self.cassandra.execute(
- self.insert_upload_session_stmt,
- (
- upload_id, user, document_id, document_metadata,
- s3_upload_id, object_id, total_size, chunk_size,
- total_chunks, {}, now, now
- )
- )
- break
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.insert_upload_session_stmt,
+ (
+ upload_id, user, document_id, document_metadata,
+ s3_upload_id, object_id, total_size, chunk_size,
+ total_chunks, {}, now, now
+ ),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
logger.debug("Upload session created")
@@ -740,18 +668,17 @@ class LibraryTableStore:
logger.debug(f"Get upload session {upload_id}")
- while True:
- try:
- resp = self.cassandra.execute(
- self.get_upload_session_stmt,
- (upload_id,)
- )
- break
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.get_upload_session_stmt,
+ (upload_id,),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
- for row in resp:
+ for row in rows:
session = {
"upload_id": row[0],
"user": row[1],
@@ -778,20 +705,19 @@ class LibraryTableStore:
now = int(time.time() * 1000)
- while True:
- try:
- self.cassandra.execute(
- self.update_upload_session_chunk_stmt,
- (
- {chunk_index: etag},
- now,
- upload_id
- )
- )
- break
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.update_upload_session_chunk_stmt,
+ (
+ {chunk_index: etag},
+ now,
+ upload_id
+ ),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
logger.debug("Chunk recorded")
@@ -800,16 +726,15 @@ class LibraryTableStore:
logger.info(f"Deleting upload session {upload_id}")
- while True:
- try:
- self.cassandra.execute(
- self.delete_upload_session_stmt,
- (upload_id,)
- )
- break
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ await async_execute(
+ self.cassandra,
+ self.delete_upload_session_stmt,
+ (upload_id,),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
logger.debug("Upload session deleted")
@@ -818,19 +743,18 @@ class LibraryTableStore:
logger.debug(f"List upload sessions for {user}")
- while True:
- try:
- resp = self.cassandra.execute(
- self.list_upload_sessions_stmt,
- (user,)
- )
- break
- except Exception as e:
- logger.error("Exception occurred", exc_info=True)
- raise e
+ try:
+ rows = await async_execute(
+ self.cassandra,
+ self.list_upload_sessions_stmt,
+ (user,),
+ )
+ except Exception:
+ logger.error("Exception occurred", exc_info=True)
+ raise
sessions = []
- for row in resp:
+ for row in rows:
chunks_received = row[6] if row[6] else {}
sessions.append({
"upload_id": row[0],
diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml
index deab0d5c..cd1d20a1 100644
--- a/trustgraph-ocr/pyproject.toml
+++ b/trustgraph-ocr/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
- "trustgraph-base>=2.2,<2.3",
+ "trustgraph-base>=2.3,<2.4",
"pulsar-client",
"prometheus-client",
"boto3",
diff --git a/trustgraph-unstructured/pyproject.toml b/trustgraph-unstructured/pyproject.toml
index 33265edb..d8879329 100644
--- a/trustgraph-unstructured/pyproject.toml
+++ b/trustgraph-unstructured/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
- "trustgraph-base>=2.2,<2.3",
+ "trustgraph-base>=2.3,<2.4",
"pulsar-client",
"prometheus-client",
"python-magic",
diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml
index 9eb75ed1..45958ef3 100644
--- a/trustgraph-vertexai/pyproject.toml
+++ b/trustgraph-vertexai/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
- "trustgraph-base>=2.2,<2.3",
+ "trustgraph-base>=2.3,<2.4",
"pulsar-client",
"google-genai",
"google-api-core",