mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-18 03:45:12 +02:00
Merge branch 'release/v2.3'
This commit is contained in:
commit
1f30a3bcea
155 changed files with 6526 additions and 1885 deletions
2
.github/workflows/pull-request.yaml
vendored
2
.github/workflows/pull-request.yaml
vendored
|
|
@ -22,7 +22,7 @@ jobs:
|
|||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup packages
|
||||
run: make update-package-versions VERSION=2.2.999
|
||||
run: make update-package-versions VERSION=2.3.999
|
||||
|
||||
- name: Setup environment
|
||||
run: python3 -m venv env
|
||||
|
|
|
|||
79
.github/workflows/release.yaml
vendored
79
.github/workflows/release.yaml
vendored
|
|
@ -40,10 +40,9 @@ jobs:
|
|||
- name: Publish release distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
|
||||
deploy-container-image:
|
||||
build-platform-image:
|
||||
|
||||
name: Release container images
|
||||
runs-on: ubuntu-24.04
|
||||
name: Build ${{ matrix.container }} (${{ matrix.platform }})
|
||||
permissions:
|
||||
contents: write
|
||||
id-token: write
|
||||
|
|
@ -52,14 +51,24 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
container:
|
||||
- trustgraph-base
|
||||
- trustgraph-flow
|
||||
- trustgraph-bedrock
|
||||
- trustgraph-vertexai
|
||||
- trustgraph-hf
|
||||
- trustgraph-ocr
|
||||
- trustgraph-unstructured
|
||||
- trustgraph-mcp
|
||||
- base
|
||||
- flow
|
||||
- bedrock
|
||||
- vertexai
|
||||
- hf
|
||||
- ocr
|
||||
- unstructured
|
||||
- mcp
|
||||
platform:
|
||||
- amd64
|
||||
- arm64
|
||||
include:
|
||||
- platform: amd64
|
||||
runner: ubuntu-24.04
|
||||
- platform: arm64
|
||||
runner: ubuntu-24.04-arm
|
||||
|
||||
runs-on: ${{ matrix.runner }}
|
||||
|
||||
steps:
|
||||
|
||||
|
|
@ -76,12 +85,48 @@ jobs:
|
|||
id: version
|
||||
run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Put version into package manifests
|
||||
run: make update-package-versions VERSION=${{ steps.version.outputs.VERSION }}
|
||||
- name: Build container
|
||||
run: make platform-${{ matrix.container }}-${{ matrix.platform }} VERSION=${{ steps.version.outputs.VERSION }}
|
||||
|
||||
- name: Build container - ${{ matrix.container }}
|
||||
run: make container-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }}
|
||||
- name: Push container
|
||||
run: make push-platform-${{ matrix.container }}-${{ matrix.platform }} VERSION=${{ steps.version.outputs.VERSION }}
|
||||
|
||||
- name: Push container - ${{ matrix.container }}
|
||||
run: make push-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }}
|
||||
combine-manifests:
|
||||
|
||||
name: Combine manifest ${{ matrix.container }}
|
||||
runs-on: ubuntu-24.04
|
||||
needs: build-platform-image
|
||||
permissions:
|
||||
contents: write
|
||||
id-token: write
|
||||
environment:
|
||||
name: release
|
||||
strategy:
|
||||
matrix:
|
||||
container:
|
||||
- base
|
||||
- flow
|
||||
- bedrock
|
||||
- vertexai
|
||||
- hf
|
||||
- ocr
|
||||
- unstructured
|
||||
- mcp
|
||||
|
||||
steps:
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Docker Hub token
|
||||
run: echo ${{ secrets.DOCKER_SECRET }} > docker-token.txt
|
||||
|
||||
- name: Authenticate with Docker hub
|
||||
run: make docker-hub-login
|
||||
|
||||
- name: Get version
|
||||
id: version
|
||||
run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Combine and push manifest
|
||||
run: make combine-manifest-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }}
|
||||
|
|
|
|||
181
Makefile
181
Makefile
|
|
@ -52,51 +52,12 @@ update-package-versions:
|
|||
echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py
|
||||
echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py
|
||||
|
||||
FORCE:
|
||||
containers: container-base container-flow \
|
||||
container-bedrock container-vertexai \
|
||||
container-hf container-ocr \
|
||||
container-unstructured container-mcp
|
||||
|
||||
containers: FORCE
|
||||
${DOCKER} build -f containers/Containerfile.base \
|
||||
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.flow \
|
||||
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.bedrock \
|
||||
-t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.vertexai \
|
||||
-t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.hf \
|
||||
-t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.ocr \
|
||||
-t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.unstructured \
|
||||
-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.mcp \
|
||||
-t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
|
||||
|
||||
some-containers:
|
||||
${DOCKER} build -f containers/Containerfile.base \
|
||||
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.flow \
|
||||
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.unstructured \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.vertexai \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.mcp \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.vertexai \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
|
||||
# ${DOCKER} build -f containers/Containerfile.bedrock \
|
||||
# -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} .
|
||||
|
||||
basic-containers: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.base \
|
||||
-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
|
||||
${DOCKER} build -f containers/Containerfile.flow \
|
||||
-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
|
||||
|
||||
container.ocr:
|
||||
${DOCKER} build -f containers/Containerfile.ocr \
|
||||
-t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
|
||||
some-containers: container-base container-flow
|
||||
|
||||
push:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION}
|
||||
|
|
@ -109,54 +70,60 @@ push:
|
|||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION}
|
||||
|
||||
# Individual container build targets
|
||||
container-trustgraph-base: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.base -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
|
||||
container-%: update-package-versions
|
||||
${DOCKER} build \
|
||||
-f containers/Containerfile.${@:container-%=%} \
|
||||
-t ${CONTAINER_BASE}/trustgraph-${@:container-%=%}:${VERSION} .
|
||||
|
||||
container-trustgraph-flow: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.flow -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
|
||||
# Multi-arch: build both platforms sequentially into one manifest (local use)
|
||||
manifest-%: update-package-versions
|
||||
-@${DOCKER} manifest rm \
|
||||
${CONTAINER_BASE}/trustgraph-${@:manifest-%=%}:${VERSION}
|
||||
${DOCKER} build --platform linux/amd64,linux/arm64 \
|
||||
-f containers/Containerfile.${@:manifest-%=%} \
|
||||
--manifest \
|
||||
${CONTAINER_BASE}/trustgraph-${@:manifest-%=%}:${VERSION} .
|
||||
|
||||
container-trustgraph-bedrock: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.bedrock -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} .
|
||||
# Multi-arch: build a single platform image (for parallel CI)
|
||||
platform-%-amd64: update-package-versions
|
||||
${DOCKER} build --platform linux/amd64 \
|
||||
-f containers/Containerfile.${@:platform-%-amd64=%} \
|
||||
-t ${CONTAINER_BASE}/trustgraph-${@:platform-%-amd64=%}:${VERSION}-amd64 .
|
||||
|
||||
container-trustgraph-vertexai: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.vertexai -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
|
||||
platform-%-arm64: update-package-versions
|
||||
${DOCKER} build --platform linux/arm64 \
|
||||
-f containers/Containerfile.${@:platform-%-arm64=%} \
|
||||
-t ${CONTAINER_BASE}/trustgraph-${@:platform-%-arm64=%}:${VERSION}-arm64 .
|
||||
|
||||
container-trustgraph-hf: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.hf -t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} .
|
||||
# Push a single platform image
|
||||
push-platform-%-amd64:
|
||||
${DOCKER} push \
|
||||
${CONTAINER_BASE}/trustgraph-${@:push-platform-%-amd64=%}:${VERSION}-amd64
|
||||
|
||||
container-trustgraph-ocr: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.ocr -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} .
|
||||
push-platform-%-arm64:
|
||||
${DOCKER} push \
|
||||
${CONTAINER_BASE}/trustgraph-${@:push-platform-%-arm64=%}:${VERSION}-arm64
|
||||
|
||||
container-trustgraph-unstructured: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.unstructured -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
|
||||
# Combine per-platform images into a multi-arch manifest
|
||||
combine-manifest-%:
|
||||
-@${DOCKER} manifest rm \
|
||||
${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}
|
||||
${DOCKER} manifest create \
|
||||
${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION} \
|
||||
docker://${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}-amd64 \
|
||||
docker://${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}-arm64
|
||||
${DOCKER} manifest push \
|
||||
${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}
|
||||
|
||||
container-trustgraph-mcp: update-package-versions
|
||||
${DOCKER} build -f containers/Containerfile.mcp -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} .
|
||||
# Push a container
|
||||
push-container-%:
|
||||
${DOCKER} push \
|
||||
${CONTAINER_BASE}/trustgraph-${@:push-container-%=%}:${VERSION}
|
||||
|
||||
# Individual container push targets
|
||||
push-trustgraph-base:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION}
|
||||
|
||||
push-trustgraph-flow:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-flow:${VERSION}
|
||||
|
||||
push-trustgraph-bedrock:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION}
|
||||
|
||||
push-trustgraph-vertexai:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION}
|
||||
|
||||
push-trustgraph-hf:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-hf:${VERSION}
|
||||
|
||||
push-trustgraph-ocr:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-ocr:${VERSION}
|
||||
|
||||
push-trustgraph-unstructured:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION}
|
||||
|
||||
push-trustgraph-mcp:
|
||||
${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION}
|
||||
# Push a manifest (from local multi-arch build)
|
||||
push-manifest-%:
|
||||
${DOCKER} manifest push \
|
||||
${CONTAINER_BASE}/trustgraph-${@:push-manifest-%=%}:${VERSION}
|
||||
|
||||
clean:
|
||||
rm -rf wheels/
|
||||
|
|
@ -164,52 +131,6 @@ clean:
|
|||
set-version:
|
||||
echo '"${VERSION}"' > templates/values/version.jsonnet
|
||||
|
||||
TEMPLATES=azure bedrock claude cohere mix llamafile mistral ollama openai vertexai \
|
||||
openai-neo4j storage
|
||||
|
||||
DCS=$(foreach template,${TEMPLATES},${template:%=tg-launch-%.yaml})
|
||||
|
||||
MODELS=azure bedrock claude cohere llamafile mistral ollama openai vertexai
|
||||
GRAPHS=cassandra neo4j falkordb memgraph
|
||||
|
||||
# tg-launch-%.yaml: templates/%.jsonnet templates/components/version.jsonnet
|
||||
# jsonnet -Jtemplates \
|
||||
# -S ${@:tg-launch-%.yaml=templates/%.jsonnet} > $@
|
||||
|
||||
# VECTORDB=milvus
|
||||
VECTORDB=qdrant
|
||||
|
||||
JSONNET_FLAGS=-J templates -J .
|
||||
|
||||
# Temporarily going back to how templates were built in 0.9 because this
|
||||
# is going away in 0.11.
|
||||
|
||||
update-templates: update-dcs
|
||||
|
||||
JSON_TO_YAML=python -c 'import sys, yaml, json; j=json.loads(sys.stdin.read()); print(yaml.safe_dump(j))'
|
||||
|
||||
update-dcs: set-version
|
||||
for graph in ${GRAPHS}; do \
|
||||
cm=$${graph},pulsar,${VECTORDB},grafana; \
|
||||
input=templates/opts-to-docker-compose.jsonnet; \
|
||||
output=tg-storage-$${graph}.yaml; \
|
||||
echo $${graph} '->' $${output}; \
|
||||
jsonnet ${JSONNET_FLAGS} \
|
||||
--ext-str options=$${cm} $${input} | \
|
||||
${JSON_TO_YAML} > $${output}; \
|
||||
done
|
||||
for model in ${MODELS}; do \
|
||||
for graph in ${GRAPHS}; do \
|
||||
cm=$${graph},pulsar,${VECTORDB},embeddings-hf,graph-rag,grafana,trustgraph,$${model}; \
|
||||
input=templates/opts-to-docker-compose.jsonnet; \
|
||||
output=tg-launch-$${model}-$${graph}.yaml; \
|
||||
echo $${model} + $${graph} '->' $${output}; \
|
||||
jsonnet ${JSONNET_FLAGS} \
|
||||
--ext-str options=$${cm} $${input} | \
|
||||
${JSON_TO_YAML} > $${output}; \
|
||||
done; \
|
||||
done
|
||||
|
||||
docker-hub-login:
|
||||
cat docker-token.txt | \
|
||||
${DOCKER} login -u trustgraph --password-stdin registry-1.docker.io
|
||||
|
|
|
|||
|
|
@ -1,22 +1,22 @@
|
|||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Build an AI container. This does the torch install which is huge, and I
|
||||
# like to avoid re-doing this.
|
||||
# ----------------------------------------------------------------------------
|
||||
# Torch is stable and compiles for ARM64 and AMD64 on Python 3.12
|
||||
|
||||
FROM docker.io/fedora:42 AS ai
|
||||
|
||||
ENV PIP_BREAK_SYSTEM_PACKAGES=1
|
||||
|
||||
RUN dnf install -y python3.13 && \
|
||||
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
|
||||
RUN dnf install -y python3.12 && \
|
||||
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
|
||||
python -m ensurepip --upgrade && \
|
||||
pip3 install --no-cache-dir build wheel aiohttp && \
|
||||
pip3 install --no-cache-dir pulsar-client==3.7.0 && \
|
||||
dnf clean all
|
||||
|
||||
RUN pip3 install torch==2.5.1+cpu \
|
||||
--index-url https://download.pytorch.org/whl/cpu
|
||||
# This won't work on ARM
|
||||
#RUN pip3 install torch==2.5.1+cpu \
|
||||
# --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
RUN pip3 install torch
|
||||
|
||||
RUN pip3 install --no-cache-dir \
|
||||
langchain==0.3.25 langchain-core==0.3.60 langchain-huggingface==0.2.0 \
|
||||
|
|
|
|||
117
dev-tools/proc-group/README.md
Normal file
117
dev-tools/proc-group/README.md
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
# proc-group — run TrustGraph as a single process
|
||||
|
||||
A dev-focused alternative to the per-container deployment. Instead of 30+
|
||||
containers each running a single processor, `processor-group` runs all the
|
||||
processors as asyncio tasks inside one Python process, sharing the event
|
||||
loop, Prometheus registry, and (importantly) resources on your laptop.
|
||||
|
||||
This is **not** for production. Scale deployments should keep using
|
||||
per-processor containers — one failure bringing down the whole process,
|
||||
no horizontal scaling, and a single giant log are fine for dev and a
|
||||
bad idea in prod.
|
||||
|
||||
## What this directory contains
|
||||
|
||||
- `group.yaml` — the group runner config. One entry per processor, each
|
||||
with the dotted class path and a params dict. Defaults (pubsub backend,
|
||||
rabbitmq host, log level) are pulled in per-entry with a YAML anchor.
|
||||
- `README.md` — this file.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Install the TrustGraph packages into a venv:
|
||||
|
||||
```
|
||||
pip install trustgraph-base trustgraph-flow trustgraph-unstructured
|
||||
```
|
||||
|
||||
`trustgraph-base` provides the `processor-group` endpoint. The others
|
||||
provide the processor classes that `group.yaml` imports at runtime.
|
||||
`trustgraph-unstructured` is only needed if you want `document-decoder`
|
||||
(the `universal-decoder` processor).
|
||||
|
||||
## Running it
|
||||
|
||||
Start infrastructure (cassandra, qdrant, rabbitmq, garage, observability
|
||||
stack) with a working compose file. These aren't packable into the group -
|
||||
they're third-party services. You may be able to run these as standalone
|
||||
services.
|
||||
|
||||
To get Cassandra to be accessible from the host, you need to
|
||||
set a couple of environment variables:
|
||||
```
|
||||
CASSANDRA_BROADCAST_ADDRESS: 127.0.0.1
|
||||
CASSANDRA_LISTEN_ADDRESS: 127.0.0.1
|
||||
```
|
||||
and also set `network: host`. Then start services:
|
||||
|
||||
```
|
||||
podman-compose up -d cassandra qdrant rabbitmq
|
||||
podman-compose up -d garage garage-init
|
||||
podman-compose up -d loki prometheus grafana
|
||||
podman-compose up -d init-trustgraph
|
||||
```
|
||||
|
||||
`init-trustgraph` is a one-shot that seeds config and the default flow
|
||||
into cassandra/rabbitmq. Don't leave too long a delay between starting
|
||||
`init-trustgraph` and running the processor-group, because it needs to
|
||||
talk to the config service.
|
||||
|
||||
Run the api-gateway separately — it's an aiohttp HTTP server, not an
|
||||
`AsyncProcessor`, so the group runner doesn't host it:
|
||||
|
||||
Raise the file descriptor limit — 30+ processors sharing one process
|
||||
open far more sockets than the default 1024 allows:
|
||||
|
||||
```
|
||||
ulimit -n 65536
|
||||
```
|
||||
|
||||
Then start the group from a terminal:
|
||||
|
||||
```
|
||||
processor-group -c group.yaml --no-loki-enabled
|
||||
```
|
||||
|
||||
You'll see every processor's startup messages interleaved in one log.
|
||||
Each processor has a supervisor that restarts it independently on
|
||||
failure, so a transient crash (or a dependency that isn't ready yet)
|
||||
only affects that one processor — siblings keep running and the failing
|
||||
one self-heals on the next retry.
|
||||
|
||||
Finally when everything is running you can start the API gateway from
|
||||
its own terminal:
|
||||
|
||||
```
|
||||
api-gateway \
|
||||
--pubsub-backend rabbitmq --rabbitmq-host localhost \
|
||||
--loki-url http://localhost:3100/loki/api/v1/push \
|
||||
--no-metrics
|
||||
```
|
||||
|
||||
|
||||
|
||||
## When things go wrong
|
||||
|
||||
- **"Too many open files"** — raise `ulimit -n` further. 65536 is
|
||||
usually plenty but some workflows need more.
|
||||
- **One processor failing repeatedly** — look for its id in the log. The
|
||||
supervisor will log each failure before restarting. Fix the cause
|
||||
(missing env var, unreachable dependency, bad params) and the
|
||||
processor self-heals on the next 4-second retry without restarting
|
||||
the whole group.
|
||||
- **Ctrl-C leaves the process hung** — the pika and cassandra drivers
|
||||
spawn non-cooperative threads that asyncio can't cancel. Use Ctrl-\
|
||||
(SIGQUIT) to force-kill. Not a bug in the group runner, just a
|
||||
limitation of those libraries.
|
||||
|
||||
## Environment variables
|
||||
|
||||
Processors that talk to external LLMs or APIs read their credentials
|
||||
from env vars, same as in the per-container deployment:
|
||||
|
||||
- `OPENAI_TOKEN`, `OPENAI_BASE_URL` — for `text-completion` /
|
||||
`text-completion-rag`
|
||||
|
||||
Export whatever your particular `group.yaml` needs before running.
|
||||
|
||||
257
dev-tools/proc-group/group.yaml
Normal file
257
dev-tools/proc-group/group.yaml
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
# Multi-processor group config, derived from docker-compose.yaml.
|
||||
#
|
||||
# Covers every AsyncProcessor-based service from the compose file.
|
||||
# Out of scope:
|
||||
# - api-gateway (aiohttp, not AsyncProcessor)
|
||||
# - init-trustgraph (one-shot init, not a processor)
|
||||
# - document-decoder (universal-decoder, trustgraph-unstructured package —
|
||||
# packable but lives in a separate image/package)
|
||||
# - mcp-server (trustgraph-mcp package, separate image)
|
||||
# - ddg-mcp-server (third-party image)
|
||||
# - infrastructure (cassandra, rabbitmq, qdrant, garage, grafana,
|
||||
# prometheus, loki, workbench-ui)
|
||||
#
|
||||
# Run with:
|
||||
# processor-group -c group.yaml
|
||||
|
||||
_defaults: &defaults
|
||||
pubsub_backend: rabbitmq
|
||||
rabbitmq_host: localhost
|
||||
log_level: INFO
|
||||
|
||||
processors:
|
||||
|
||||
- class: trustgraph.agent.orchestrator.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: agent-manager
|
||||
|
||||
- class: trustgraph.chunking.recursive.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: chunker
|
||||
chunk_size: 2000
|
||||
chunk_overlap: 50
|
||||
|
||||
- class: trustgraph.config.service.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: config-svc
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.decoding.universal.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: document-decoder
|
||||
|
||||
- class: trustgraph.embeddings.document_embeddings.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: document-embeddings
|
||||
|
||||
- class: trustgraph.retrieval.document_rag.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: document-rag
|
||||
doc_limit: 20
|
||||
|
||||
- class: trustgraph.embeddings.fastembed.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: embeddings
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.embeddings.graph_embeddings.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: graph-embeddings
|
||||
|
||||
- class: trustgraph.retrieval.graph_rag.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: graph-rag
|
||||
concurrency: 1
|
||||
entity_limit: 50
|
||||
triple_limit: 30
|
||||
edge_limit: 30
|
||||
edge_score_limit: 10
|
||||
max_subgraph_size: 100
|
||||
max_path_length: 2
|
||||
|
||||
- class: trustgraph.extract.kg.agent.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-extract-agent
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.extract.kg.definitions.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-extract-definitions
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.extract.kg.ontology.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-extract-ontology
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.extract.kg.relationships.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-extract-relationships
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.extract.kg.rows.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-extract-rows
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.cores.service.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: knowledge
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.storage.knowledge.store.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-store
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.librarian.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: librarian
|
||||
cassandra_host: localhost
|
||||
object_store_endpoint: localhost:3900
|
||||
object_store_access_key: GK000000000000000000000001
|
||||
object_store_secret_key: b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427
|
||||
object_store_region: garage
|
||||
|
||||
- class: trustgraph.agent.mcp_tool.Service
|
||||
params:
|
||||
<<: *defaults
|
||||
id: mcp-tool
|
||||
|
||||
- class: trustgraph.metering.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: metering
|
||||
|
||||
- class: trustgraph.metering.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: metering-rag
|
||||
|
||||
- class: trustgraph.retrieval.nlp_query.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: nlp-query
|
||||
|
||||
- class: trustgraph.prompt.template.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: prompt
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.prompt.template.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: prompt-rag
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.query.doc_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: doc-embeddings-query
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.query.graph_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: graph-embeddings-query
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.query.row_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: row-embeddings-query
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.query.rows.cassandra.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: rows-query
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.query.triples.cassandra.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: triples-query
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.embeddings.row_embeddings.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: row-embeddings
|
||||
|
||||
- class: trustgraph.query.sparql.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: sparql-query
|
||||
|
||||
- class: trustgraph.storage.doc_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: doc-embeddings-write
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.storage.graph_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: graph-embeddings-write
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.storage.row_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: row-embeddings-write
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.storage.rows.cassandra.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: rows-write
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.storage.triples.cassandra.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: triples-write
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.retrieval.structured_diag.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: structured-diag
|
||||
|
||||
- class: trustgraph.retrieval.structured_query.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: structured-query
|
||||
|
||||
- class: trustgraph.model.text_completion.openai.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: text-completion
|
||||
max_output: 8192
|
||||
temperature: 0.0
|
||||
|
||||
- class: trustgraph.model.text_completion.openai.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: text-completion-rag
|
||||
max_output: 8192
|
||||
temperature: 0.0
|
||||
47
dev-tools/proc-group/groups/control.yaml
Normal file
47
dev-tools/proc-group/groups/control.yaml
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
# Control plane. Stateful "always on" services that every flow depends on.
|
||||
# Cassandra-heavy, low traffic.
|
||||
|
||||
_defaults: &defaults
|
||||
pubsub_backend: rabbitmq
|
||||
rabbitmq_host: localhost
|
||||
log_level: INFO
|
||||
|
||||
processors:
|
||||
|
||||
- class: trustgraph.config.service.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: config-svc
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.librarian.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: librarian
|
||||
cassandra_host: localhost
|
||||
object_store_endpoint: localhost:3900
|
||||
object_store_access_key: GK000000000000000000000001
|
||||
object_store_secret_key: b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427
|
||||
object_store_region: garage
|
||||
|
||||
- class: trustgraph.cores.service.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: knowledge
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.storage.knowledge.store.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-store
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.metering.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: metering
|
||||
|
||||
- class: trustgraph.metering.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: metering-rag
|
||||
45
dev-tools/proc-group/groups/embeddings-store.yaml
Normal file
45
dev-tools/proc-group/groups/embeddings-store.yaml
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
# Embeddings store. All Qdrant-backed vector query/write processors.
|
||||
# One process owns the Qdrant driver pool for the whole group.
|
||||
|
||||
_defaults: &defaults
|
||||
pubsub_backend: rabbitmq
|
||||
rabbitmq_host: localhost
|
||||
log_level: INFO
|
||||
|
||||
processors:
|
||||
|
||||
- class: trustgraph.query.doc_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: doc-embeddings-query
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.storage.doc_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: doc-embeddings-write
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.query.graph_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: graph-embeddings-query
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.storage.graph_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: graph-embeddings-write
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.query.row_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: row-embeddings-query
|
||||
store_uri: http://localhost:6333
|
||||
|
||||
- class: trustgraph.storage.row_embeddings.qdrant.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: row-embeddings-write
|
||||
store_uri: http://localhost:6333
|
||||
31
dev-tools/proc-group/groups/embeddings.yaml
Normal file
31
dev-tools/proc-group/groups/embeddings.yaml
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Embeddings. Memory-hungry — fastembed loads an ML model at startup.
|
||||
# Keep isolated from other groups so its memory footprint and restart
|
||||
# latency don't affect siblings.
|
||||
|
||||
_defaults: &defaults
|
||||
pubsub_backend: rabbitmq
|
||||
rabbitmq_host: localhost
|
||||
log_level: INFO
|
||||
|
||||
processors:
|
||||
|
||||
- class: trustgraph.embeddings.fastembed.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: embeddings
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.embeddings.document_embeddings.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: document-embeddings
|
||||
|
||||
- class: trustgraph.embeddings.graph_embeddings.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: graph-embeddings
|
||||
|
||||
- class: trustgraph.embeddings.row_embeddings.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: row-embeddings
|
||||
52
dev-tools/proc-group/groups/ingest.yaml
Normal file
52
dev-tools/proc-group/groups/ingest.yaml
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
# Ingest pipeline. Document-processing hot path. Bursty, correlated
|
||||
# failures — if chunker dies the extractors have nothing to do anyway.
|
||||
|
||||
_defaults: &defaults
|
||||
pubsub_backend: rabbitmq
|
||||
rabbitmq_host: localhost
|
||||
log_level: INFO
|
||||
|
||||
processors:
|
||||
|
||||
- class: trustgraph.chunking.recursive.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: chunker
|
||||
chunk_size: 2000
|
||||
chunk_overlap: 50
|
||||
|
||||
- class: trustgraph.extract.kg.agent.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-extract-agent
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.extract.kg.definitions.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-extract-definitions
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.extract.kg.ontology.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-extract-ontology
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.extract.kg.relationships.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-extract-relationships
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.extract.kg.rows.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: kg-extract-rows
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.prompt.template.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: prompt
|
||||
concurrency: 1
|
||||
24
dev-tools/proc-group/groups/llm.yaml
Normal file
24
dev-tools/proc-group/groups/llm.yaml
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
# LLM. Outbound text-completion calls. Isolated because the upstream
|
||||
# LLM API is often the bottleneck and the most likely thing to need
|
||||
# restart (provider changes, model changes, API flakiness).
|
||||
|
||||
_defaults: &defaults
|
||||
pubsub_backend: rabbitmq
|
||||
rabbitmq_host: localhost
|
||||
log_level: INFO
|
||||
|
||||
processors:
|
||||
|
||||
- class: trustgraph.model.text_completion.openai.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: text-completion
|
||||
max_output: 8192
|
||||
temperature: 0.0
|
||||
|
||||
- class: trustgraph.model.text_completion.openai.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: text-completion-rag
|
||||
max_output: 8192
|
||||
temperature: 0.0
|
||||
64
dev-tools/proc-group/groups/rag.yaml
Normal file
64
dev-tools/proc-group/groups/rag.yaml
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# RAG / retrieval / agent. Query-time serving path. Drives outbound
|
||||
# LLM calls via prompt-rag. sparql-query lives here because it's a
|
||||
# read-side serving endpoint, not a backend writer.
|
||||
|
||||
_defaults: &defaults
|
||||
pubsub_backend: rabbitmq
|
||||
rabbitmq_host: localhost
|
||||
log_level: INFO
|
||||
|
||||
processors:
|
||||
|
||||
- class: trustgraph.agent.orchestrator.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: agent-manager
|
||||
|
||||
- class: trustgraph.retrieval.graph_rag.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: graph-rag
|
||||
concurrency: 1
|
||||
entity_limit: 50
|
||||
triple_limit: 30
|
||||
edge_limit: 30
|
||||
edge_score_limit: 10
|
||||
max_subgraph_size: 100
|
||||
max_path_length: 2
|
||||
|
||||
- class: trustgraph.retrieval.document_rag.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: document-rag
|
||||
doc_limit: 20
|
||||
|
||||
- class: trustgraph.retrieval.nlp_query.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: nlp-query
|
||||
|
||||
- class: trustgraph.retrieval.structured_query.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: structured-query
|
||||
|
||||
- class: trustgraph.retrieval.structured_diag.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: structured-diag
|
||||
|
||||
- class: trustgraph.query.sparql.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: sparql-query
|
||||
|
||||
- class: trustgraph.prompt.template.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: prompt-rag
|
||||
concurrency: 1
|
||||
|
||||
- class: trustgraph.agent.mcp_tool.Service
|
||||
params:
|
||||
<<: *defaults
|
||||
id: mcp-tool
|
||||
20
dev-tools/proc-group/groups/rows-store.yaml
Normal file
20
dev-tools/proc-group/groups/rows-store.yaml
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Rows store. Cassandra-backed structured row query/write.
|
||||
|
||||
_defaults: &defaults
|
||||
pubsub_backend: rabbitmq
|
||||
rabbitmq_host: localhost
|
||||
log_level: INFO
|
||||
|
||||
processors:
|
||||
|
||||
- class: trustgraph.query.rows.cassandra.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: rows-query
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.storage.rows.cassandra.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: rows-write
|
||||
cassandra_host: localhost
|
||||
20
dev-tools/proc-group/groups/triples-store.yaml
Normal file
20
dev-tools/proc-group/groups/triples-store.yaml
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Triples store. Cassandra-backed RDF triple query/write.
|
||||
|
||||
_defaults: &defaults
|
||||
pubsub_backend: rabbitmq
|
||||
rabbitmq_host: localhost
|
||||
log_level: INFO
|
||||
|
||||
processors:
|
||||
|
||||
- class: trustgraph.query.triples.cassandra.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: triples-query
|
||||
cassandra_host: localhost
|
||||
|
||||
- class: trustgraph.storage.triples.cassandra.Processor
|
||||
params:
|
||||
<<: *defaults
|
||||
id: triples-write
|
||||
cassandra_host: localhost
|
||||
|
|
@ -131,21 +131,21 @@ async def analyse(path, url, flow, user, collection):
|
|||
|
||||
for i, msg in enumerate(messages):
|
||||
resp = msg.get("response", {})
|
||||
chunk_type = resp.get("chunk_type", "?")
|
||||
message_type = resp.get("message_type", "?")
|
||||
|
||||
if chunk_type == "explain":
|
||||
if message_type == "explain":
|
||||
explain_id = resp.get("explain_id", "")
|
||||
explain_ids.append(explain_id)
|
||||
print(f" {i:3d} {chunk_type} {explain_id}")
|
||||
print(f" {i:3d} {message_type} {explain_id}")
|
||||
else:
|
||||
print(f" {i:3d} {chunk_type}")
|
||||
print(f" {i:3d} {message_type}")
|
||||
|
||||
# Rule 7: message_id on content chunks
|
||||
if chunk_type in ("thought", "observation", "answer"):
|
||||
if message_type in ("thought", "observation", "answer"):
|
||||
mid = resp.get("message_id", "")
|
||||
if not mid:
|
||||
errors.append(
|
||||
f"[msg {i}] {chunk_type} chunk missing message_id"
|
||||
f"[msg {i}] {message_type} chunk missing message_id"
|
||||
)
|
||||
|
||||
print()
|
||||
|
|
|
|||
|
|
@ -6,273 +6,212 @@ parent: "Tech Specs"
|
|||
|
||||
# Agent Explainability: Provenance Recording
|
||||
|
||||
## Status
|
||||
|
||||
Implemented
|
||||
|
||||
## Overview
|
||||
|
||||
Add provenance recording to the React agent loop so agent sessions can be traced and debugged using the same explainability infrastructure as GraphRAG.
|
||||
Agent sessions are traced and debugged using the same explainability infrastructure as GraphRAG and Document RAG. Provenance is written to `urn:graph:retrieval` and delivered inline on the explain stream.
|
||||
|
||||
**Design Decisions:**
|
||||
- Write to `urn:graph:retrieval` (generic explainability graph)
|
||||
- Linear dependency chain for now (analysis N → wasDerivedFrom → analysis N-1)
|
||||
- Tools are opaque black boxes (record input/output only)
|
||||
- DAG support deferred to future iteration
|
||||
The canonical vocabulary for all predicates and types is published as an OWL ontology at `specs/ontology/trustgraph.ttl`.
|
||||
|
||||
## Entity Types
|
||||
|
||||
Both GraphRAG and Agent use PROV-O as the base ontology with TrustGraph-specific subtypes:
|
||||
All services use PROV-O as the base ontology with TrustGraph-specific subtypes.
|
||||
|
||||
### GraphRAG Types
|
||||
| Entity | PROV-O Type | TG Types | Description |
|
||||
|--------|-------------|----------|-------------|
|
||||
| Question | `prov:Activity` | `tg:Question`, `tg:GraphRagQuestion` | The user's query |
|
||||
| Exploration | `prov:Entity` | `tg:Exploration` | Edges retrieved from knowledge graph |
|
||||
| Focus | `prov:Entity` | `tg:Focus` | Selected edges with reasoning |
|
||||
| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer |
|
||||
|
||||
### Agent Types
|
||||
| Entity | PROV-O Type | TG Types | Description |
|
||||
|--------|-------------|----------|-------------|
|
||||
| Question | `prov:Activity` | `tg:Question`, `tg:AgentQuestion` | The user's query |
|
||||
| Analysis | `prov:Entity` | `tg:Analysis` | Each think/act/observe cycle |
|
||||
| Conclusion | `prov:Entity` | `tg:Conclusion` | Final answer |
|
||||
| Entity | TG Types | Description |
|
||||
|--------|----------|-------------|
|
||||
| Question | `tg:Question`, `tg:GraphRagQuestion` | The user's query |
|
||||
| Grounding | `tg:Grounding` | Concept extraction from query |
|
||||
| Exploration | `tg:Exploration` | Edges retrieved from knowledge graph |
|
||||
| Focus | `tg:Focus` | Selected edges with reasoning |
|
||||
| Synthesis | `tg:Synthesis`, `tg:Answer` | Final answer |
|
||||
|
||||
### Document RAG Types
|
||||
| Entity | PROV-O Type | TG Types | Description |
|
||||
|--------|-------------|----------|-------------|
|
||||
| Question | `prov:Activity` | `tg:Question`, `tg:DocRagQuestion` | The user's query |
|
||||
| Exploration | `prov:Entity` | `tg:Exploration` | Chunks retrieved from document store |
|
||||
| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer |
|
||||
| Entity | TG Types | Description |
|
||||
|--------|----------|-------------|
|
||||
| Question | `tg:Question`, `tg:DocRagQuestion` | The user's query |
|
||||
| Grounding | `tg:Grounding` | Concept extraction from query |
|
||||
| Exploration | `tg:Exploration` | Chunks retrieved from document store |
|
||||
| Synthesis | `tg:Synthesis`, `tg:Answer` | Final answer |
|
||||
|
||||
**Note:** Document RAG uses a subset of GraphRAG's types (no Focus step since there's no edge selection/reasoning phase).
|
||||
### Agent Types (React)
|
||||
| Entity | TG Types | Description |
|
||||
|--------|----------|-------------|
|
||||
| Question | `tg:Question`, `tg:AgentQuestion` | The user's query (session start) |
|
||||
| PatternDecision | `tg:PatternDecision` | Meta-router routing decision |
|
||||
| Analysis | `tg:Analysis`, `tg:ToolUse` | One think/act cycle |
|
||||
| Thought | `tg:Reflection`, `tg:Thought` | Agent reasoning (sub-entity of Analysis) |
|
||||
| Observation | `tg:Reflection`, `tg:Observation` | Tool result (standalone entity) |
|
||||
| Conclusion | `tg:Conclusion`, `tg:Answer` | Final answer |
|
||||
|
||||
### Agent Types (Orchestrator — Plan)
|
||||
| Entity | TG Types | Description |
|
||||
|--------|----------|-------------|
|
||||
| Plan | `tg:Plan` | Structured plan of steps |
|
||||
| StepResult | `tg:StepResult`, `tg:Answer` | Result from executing one plan step |
|
||||
| Synthesis | `tg:Synthesis`, `tg:Answer` | Final synthesised answer |
|
||||
|
||||
### Agent Types (Orchestrator — Supervisor)
|
||||
| Entity | TG Types | Description |
|
||||
|--------|----------|-------------|
|
||||
| Decomposition | `tg:Decomposition` | Question decomposed into sub-goals |
|
||||
| Finding | `tg:Finding`, `tg:Answer` | Result from a sub-agent |
|
||||
| Synthesis | `tg:Synthesis`, `tg:Answer` | Final synthesised answer |
|
||||
|
||||
### Mixin Types
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `tg:Answer` | Unifying type for terminal answers (Synthesis, Conclusion, Finding, StepResult) |
|
||||
| `tg:Reflection` | Unifying type for intermediate commentary (Thought, Observation) |
|
||||
| `tg:ToolUse` | Applied to Analysis when a tool is invoked |
|
||||
| `tg:Error` | Applied to Observation events where a failure occurred (tool error or LLM parse error) |
|
||||
|
||||
### Question Subtypes
|
||||
|
||||
All Question entities share `tg:Question` as a base type but have a specific subtype to identify the retrieval mechanism:
|
||||
|
||||
| Subtype | URI Pattern | Mechanism |
|
||||
|---------|-------------|-----------|
|
||||
| `tg:GraphRagQuestion` | `urn:trustgraph:question:{uuid}` | Knowledge graph RAG |
|
||||
| `tg:DocRagQuestion` | `urn:trustgraph:docrag:{uuid}` | Document/chunk RAG |
|
||||
| `tg:AgentQuestion` | `urn:trustgraph:agent:{uuid}` | ReAct agent |
|
||||
| `tg:AgentQuestion` | `urn:trustgraph:agent:session:{uuid}` | Agent orchestrator |
|
||||
|
||||
This allows querying all questions via `tg:Question` while filtering by specific mechanism via the subtype.
|
||||
## Provenance Chains
|
||||
|
||||
## Provenance Model
|
||||
All chains use `prov:wasDerivedFrom` links. Each entity is a `prov:Entity`.
|
||||
|
||||
### GraphRAG
|
||||
|
||||
```
|
||||
Question (urn:trustgraph:agent:{uuid})
|
||||
│
|
||||
│ tg:query = "User's question"
|
||||
│ prov:startedAtTime = timestamp
|
||||
│ rdf:type = prov:Activity, tg:Question
|
||||
│
|
||||
↓ prov:wasDerivedFrom
|
||||
│
|
||||
Analysis1 (urn:trustgraph:agent:{uuid}/i1)
|
||||
│
|
||||
│ tg:thought = "I need to query the knowledge base..."
|
||||
│ tg:action = "knowledge-query"
|
||||
│ tg:arguments = {"question": "..."}
|
||||
│ tg:observation = "Result from tool..."
|
||||
│ rdf:type = prov:Entity, tg:Analysis
|
||||
│
|
||||
↓ prov:wasDerivedFrom
|
||||
│
|
||||
Analysis2 (urn:trustgraph:agent:{uuid}/i2)
|
||||
│ ...
|
||||
↓ prov:wasDerivedFrom
|
||||
│
|
||||
Conclusion (urn:trustgraph:agent:{uuid}/final)
|
||||
│
|
||||
│ tg:answer = "The final response..."
|
||||
│ rdf:type = prov:Entity, tg:Conclusion
|
||||
Question → Grounding → Exploration → Focus → Synthesis
|
||||
```
|
||||
|
||||
### Document RAG Provenance Model
|
||||
### Document RAG
|
||||
|
||||
```
|
||||
Question (urn:trustgraph:docrag:{uuid})
|
||||
│
|
||||
│ tg:query = "User's question"
|
||||
│ prov:startedAtTime = timestamp
|
||||
│ rdf:type = prov:Activity, tg:Question
|
||||
│
|
||||
↓ prov:wasGeneratedBy
|
||||
│
|
||||
Exploration (urn:trustgraph:docrag:{uuid}/exploration)
|
||||
│
|
||||
│ tg:chunkCount = 5
|
||||
│ tg:selectedChunk = "chunk-id-1"
|
||||
│ tg:selectedChunk = "chunk-id-2"
|
||||
│ ...
|
||||
│ rdf:type = prov:Entity, tg:Exploration
|
||||
│
|
||||
↓ prov:wasDerivedFrom
|
||||
│
|
||||
Synthesis (urn:trustgraph:docrag:{uuid}/synthesis)
|
||||
│
|
||||
│ tg:content = "The synthesized answer..."
|
||||
│ rdf:type = prov:Entity, tg:Synthesis
|
||||
Question → Grounding → Exploration → Synthesis
|
||||
```
|
||||
|
||||
## Changes Required
|
||||
### Agent React
|
||||
|
||||
### 1. Schema Changes
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/schema/services/agent.py`
|
||||
|
||||
Add `session_id` and `collection` fields to `AgentRequest`:
|
||||
```python
|
||||
@dataclass
|
||||
class AgentRequest:
|
||||
question: str = ""
|
||||
state: str = ""
|
||||
group: list[str] | None = None
|
||||
history: list[AgentStep] = field(default_factory=list)
|
||||
user: str = ""
|
||||
collection: str = "default" # NEW: Collection for provenance traces
|
||||
streaming: bool = False
|
||||
session_id: str = "" # NEW: For provenance tracking across iterations
|
||||
```
|
||||
Question → PatternDecision → Analysis(1) → Observation(1) → Analysis(2) → ... → Conclusion
|
||||
```
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/messaging/translators/agent.py`
|
||||
The PatternDecision entity records which execution pattern the meta-router selected. It is only emitted on the first iteration when routing occurs.
|
||||
|
||||
Update translator to handle `session_id` and `collection` in both `to_pulsar()` and `from_pulsar()`.
|
||||
Thought sub-entities derive from their parent Analysis. Observation entities derive from their parent Analysis (or from a sub-trace entity if the tool produced its own explainability, e.g. a GraphRAG query).
|
||||
|
||||
### 2. Add Explainability Producer to Agent Service
|
||||
### Agent Plan-then-Execute
|
||||
|
||||
**File:** `trustgraph-flow/trustgraph/agent/react/service.py`
|
||||
|
||||
Register an "explainability" producer (same pattern as GraphRAG):
|
||||
```python
|
||||
from ... base import ProducerSpec
|
||||
from ... schema import Triples
|
||||
|
||||
# In __init__:
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "explainability",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
```
|
||||
Question → PatternDecision → Plan → StepResult(0) → StepResult(1) → ... → Synthesis
|
||||
```
|
||||
|
||||
### 3. Provenance Triple Generation
|
||||
### Agent Supervisor
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/provenance/agent.py`
|
||||
|
||||
Create helper functions (similar to GraphRAG's `question_triples`, `exploration_triples`, etc.):
|
||||
```python
|
||||
def agent_session_triples(session_uri, query, timestamp):
|
||||
"""Generate triples for agent Question."""
|
||||
return [
|
||||
Triple(s=session_uri, p=RDF_TYPE, o=PROV_ACTIVITY),
|
||||
Triple(s=session_uri, p=RDF_TYPE, o=TG_QUESTION),
|
||||
Triple(s=session_uri, p=TG_QUERY, o=query),
|
||||
Triple(s=session_uri, p=PROV_STARTED_AT_TIME, o=timestamp),
|
||||
]
|
||||
|
||||
def agent_iteration_triples(iteration_uri, parent_uri, thought, action, arguments, observation):
|
||||
"""Generate triples for one Analysis step."""
|
||||
return [
|
||||
Triple(s=iteration_uri, p=RDF_TYPE, o=PROV_ENTITY),
|
||||
Triple(s=iteration_uri, p=RDF_TYPE, o=TG_ANALYSIS),
|
||||
Triple(s=iteration_uri, p=TG_THOUGHT, o=thought),
|
||||
Triple(s=iteration_uri, p=TG_ACTION, o=action),
|
||||
Triple(s=iteration_uri, p=TG_ARGUMENTS, o=json.dumps(arguments)),
|
||||
Triple(s=iteration_uri, p=TG_OBSERVATION, o=observation),
|
||||
Triple(s=iteration_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri),
|
||||
]
|
||||
|
||||
def agent_final_triples(final_uri, parent_uri, answer):
|
||||
"""Generate triples for Conclusion."""
|
||||
return [
|
||||
Triple(s=final_uri, p=RDF_TYPE, o=PROV_ENTITY),
|
||||
Triple(s=final_uri, p=RDF_TYPE, o=TG_CONCLUSION),
|
||||
Triple(s=final_uri, p=TG_ANSWER, o=answer),
|
||||
Triple(s=final_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri),
|
||||
]
|
||||
```
|
||||
Question → PatternDecision → Decomposition → [fan-out sub-agents]
|
||||
→ Finding(0) → Finding(1) → ... → Synthesis
|
||||
```
|
||||
|
||||
### 4. Type Definitions
|
||||
Each sub-agent runs its own session with `wasDerivedFrom` linking back to the parent's Decomposition. Findings derive from their sub-agent's Conclusion.
|
||||
|
||||
**File:** `trustgraph-base/trustgraph/provenance/namespaces.py`
|
||||
## Predicates
|
||||
|
||||
Add explainability entity types and agent predicates:
|
||||
```python
|
||||
# Explainability entity types (used by both GraphRAG and Agent)
|
||||
TG_QUESTION = TG + "Question"
|
||||
TG_EXPLORATION = TG + "Exploration"
|
||||
TG_FOCUS = TG + "Focus"
|
||||
TG_SYNTHESIS = TG + "Synthesis"
|
||||
TG_ANALYSIS = TG + "Analysis"
|
||||
TG_CONCLUSION = TG + "Conclusion"
|
||||
### Session / Question
|
||||
| Predicate | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `tg:query` | string | The user's query text |
|
||||
| `prov:startedAtTime` | string | ISO timestamp |
|
||||
|
||||
# Agent predicates
|
||||
TG_THOUGHT = TG + "thought"
|
||||
TG_ACTION = TG + "action"
|
||||
TG_ARGUMENTS = TG + "arguments"
|
||||
TG_OBSERVATION = TG + "observation"
|
||||
TG_ANSWER = TG + "answer"
|
||||
```
|
||||
### Pattern Decision
|
||||
| Predicate | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `tg:pattern` | string | Selected pattern (react, plan-then-execute, supervisor) |
|
||||
| `tg:taskType` | string | Identified task type (general, research, etc.) |
|
||||
|
||||
## Files Modified
|
||||
### Analysis (Iteration)
|
||||
| Predicate | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `tg:action` | string | Tool name selected by the agent |
|
||||
| `tg:arguments` | string | JSON-encoded arguments |
|
||||
| `tg:thought` | IRI | Link to Thought sub-entity |
|
||||
| `tg:toolCandidate` | string | Tool name available to the LLM (one per candidate) |
|
||||
| `tg:stepNumber` | integer | 1-based iteration counter |
|
||||
| `tg:llmDurationMs` | integer | LLM call duration in milliseconds |
|
||||
| `tg:inToken` | integer | Input token count |
|
||||
| `tg:outToken` | integer | Output token count |
|
||||
| `tg:llmModel` | string | Model identifier |
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `trustgraph-base/trustgraph/schema/services/agent.py` | Add session_id and collection to AgentRequest |
|
||||
| `trustgraph-base/trustgraph/messaging/translators/agent.py` | Update translator for new fields |
|
||||
| `trustgraph-base/trustgraph/provenance/namespaces.py` | Add entity types, agent predicates, and Document RAG predicates |
|
||||
| `trustgraph-base/trustgraph/provenance/triples.py` | Add TG types to GraphRAG triple builders, add Document RAG triple builders |
|
||||
| `trustgraph-base/trustgraph/provenance/uris.py` | Add Document RAG URI generators |
|
||||
| `trustgraph-base/trustgraph/provenance/__init__.py` | Export new types, predicates, and Document RAG functions |
|
||||
| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id, explain_graph, and explain_triples to DocumentRagResponse |
|
||||
| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields including inline triples |
|
||||
| `trustgraph-flow/trustgraph/agent/react/service.py` | Add explainability producer + recording logic |
|
||||
| `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` | Add explainability callback and emit provenance triples |
|
||||
| `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` | Add explainability producer and wire up callback |
|
||||
| `trustgraph-cli/trustgraph/cli/show_explain_trace.py` | Handle agent trace types |
|
||||
| `trustgraph-cli/trustgraph/cli/list_explain_traces.py` | List agent sessions alongside GraphRAG |
|
||||
### Observation
|
||||
| Predicate | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `tg:document` | IRI | Librarian document reference |
|
||||
| `tg:toolDurationMs` | integer | Tool execution time in milliseconds |
|
||||
| `tg:toolError` | string | Error message (tool failure or LLM parse error) |
|
||||
|
||||
## Files Created
|
||||
When `tg:toolError` is present, the Observation also carries the `tg:Error` mixin type.
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `trustgraph-base/trustgraph/provenance/agent.py` | Agent-specific triple generators |
|
||||
### Conclusion / Synthesis
|
||||
| Predicate | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `tg:document` | IRI | Librarian document reference |
|
||||
| `tg:terminationReason` | string | Why the loop stopped |
|
||||
| `tg:inToken` | integer | Input token count (synthesis LLM call) |
|
||||
| `tg:outToken` | integer | Output token count |
|
||||
| `tg:llmModel` | string | Model identifier |
|
||||
|
||||
## CLI Updates
|
||||
Termination reason values:
|
||||
- `final-answer` -- LLM produced a confident answer (react)
|
||||
- `plan-complete` -- all plan steps executed (plan-then-execute)
|
||||
- `subagents-complete` -- all sub-agents reported back (supervisor)
|
||||
|
||||
**Detection:** Both GraphRAG and Agent Questions have `tg:Question` type. Distinguished by:
|
||||
1. URI pattern: `urn:trustgraph:agent:` vs `urn:trustgraph:question:`
|
||||
2. Derived entities: `tg:Analysis` (agent) vs `tg:Exploration` (GraphRAG)
|
||||
### Decomposition
|
||||
| Predicate | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `tg:subagentGoal` | string | Goal assigned to a sub-agent (one per goal) |
|
||||
| `tg:inToken` | integer | Input token count |
|
||||
| `tg:outToken` | integer | Output token count |
|
||||
|
||||
**`list_explain_traces.py`:**
|
||||
- Shows Type column (Agent vs GraphRAG)
|
||||
### Plan
|
||||
| Predicate | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `tg:planStep` | string | Goal for a plan step (one per step) |
|
||||
| `tg:inToken` | integer | Input token count |
|
||||
| `tg:outToken` | integer | Output token count |
|
||||
|
||||
**`show_explain_trace.py`:**
|
||||
- Auto-detects trace type
|
||||
- Agent rendering shows: Question → Analysis step(s) → Conclusion
|
||||
### Token Counts on RAG Events
|
||||
|
||||
## Backwards Compatibility
|
||||
Grounding, Focus, and Synthesis events on GraphRAG and Document RAG also carry `tg:inToken`, `tg:outToken`, and `tg:llmModel` for the LLM calls associated with that step.
|
||||
|
||||
- `session_id` defaults to `""` - old requests work, just won't have provenance
|
||||
- `collection` defaults to `"default"` - reasonable fallback
|
||||
- CLI gracefully handles both trace types
|
||||
## Error Handling
|
||||
|
||||
Tool execution errors and LLM parse errors are captured as Observation events rather than crashing the agent:
|
||||
|
||||
- The error message is recorded on `tg:toolError`
|
||||
- The Observation carries the `tg:Error` mixin type
|
||||
- The error text becomes the observation content, visible to the LLM on the next iteration
|
||||
- The provenance chain is preserved (Observation derives from Analysis)
|
||||
- The agent gets another iteration to retry or choose a different approach
|
||||
|
||||
## Vocabulary Reference
|
||||
|
||||
The full OWL ontology covering all classes and predicates is at `specs/ontology/trustgraph.ttl`.
|
||||
|
||||
## Verification
|
||||
|
||||
```bash
|
||||
# Run an agent query
|
||||
tg-invoke-agent -q "What is the capital of France?"
|
||||
# Run an agent query with explainability
|
||||
tg-invoke-agent -q "What is quantum computing?" -x
|
||||
|
||||
# List traces (should show agent sessions with Type column)
|
||||
tg-list-explain-traces -U trustgraph -C default
|
||||
# Run with token usage
|
||||
tg-invoke-agent -q "What is quantum computing?" --show-usage
|
||||
|
||||
# Show agent trace
|
||||
tg-show-explain-trace "urn:trustgraph:agent:xxx"
|
||||
# GraphRAG with explainability
|
||||
tg-invoke-graph-rag -q "Tell me about AI" -x
|
||||
|
||||
# Document RAG with explainability
|
||||
tg-invoke-document-rag -q "Summarize the findings" -x
|
||||
```
|
||||
|
||||
## Future Work (Not This PR)
|
||||
|
||||
- DAG dependencies (when analysis N uses results from multiple prior analyses)
|
||||
- Tool-specific provenance linking (KnowledgeQuery → its GraphRAG trace)
|
||||
- Streaming provenance emission (emit as we go, not batch at end)
|
||||
|
|
|
|||
|
|
@ -868,7 +868,7 @@ independently.
|
|||
Response chunk fields:
|
||||
message_id UUID for this message (groups chunks)
|
||||
session_id Which agent session produced this chunk
|
||||
chunk_type "thought" | "observation" | "answer" | ...
|
||||
message_type "thought" | "observation" | "answer" | ...
|
||||
content The chunk text
|
||||
end_of_message True on the final chunk of this message
|
||||
end_of_dialog True on the final message of the entire execution
|
||||
|
|
|
|||
|
|
@ -209,3 +209,7 @@ def subgraph_provenance_triples(
|
|||
This is a breaking change to the provenance model. Provenance has not
|
||||
been released, so no migration is needed. The old `tg:reifies` /
|
||||
`statement_uri` code can be removed outright.
|
||||
|
||||
## Vocabulary Reference
|
||||
|
||||
The full OWL ontology covering all extraction and query-time classes and predicates is at `specs/ontology/trustgraph.ttl`.
|
||||
|
|
|
|||
|
|
@ -618,8 +618,13 @@ Link embedding entity IDs to chunk.
|
|||
| Triple store | Use reification for triple → chunk provenance |
|
||||
| Embedding provenance | Link entity ID → chunk ID |
|
||||
|
||||
## Vocabulary Reference
|
||||
|
||||
The full OWL ontology covering all extraction and query-time classes and predicates is at `specs/ontology/trustgraph.ttl`.
|
||||
|
||||
## References
|
||||
|
||||
- Query-time provenance: `docs/tech-specs/query-time-provenance.md`
|
||||
- Query-time provenance: `docs/tech-specs/query-time-explainability.md`
|
||||
- Agent explainability: `docs/tech-specs/agent-explainability.md`
|
||||
- PROV-O standard for provenance modeling
|
||||
- Existing source metadata in knowledge graph (needs audit)
|
||||
|
|
|
|||
|
|
@ -710,17 +710,17 @@ class StreamingChunk:
|
|||
@dataclasses.dataclass
|
||||
class AgentThought(StreamingChunk):
|
||||
"""Agent reasoning chunk"""
|
||||
chunk_type: str = "thought"
|
||||
message_type: str = "thought"
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AgentObservation(StreamingChunk):
|
||||
"""Agent tool observation chunk"""
|
||||
chunk_type: str = "observation"
|
||||
message_type: str = "observation"
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AgentAnswer(StreamingChunk):
|
||||
"""Agent final answer chunk"""
|
||||
chunk_type: str = "final-answer"
|
||||
message_type: str = "final-answer"
|
||||
end_of_dialog: bool = False
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
|
|||
|
|
@ -144,6 +144,25 @@ Defined in `trustgraph-base/trustgraph/provenance/namespaces.py`:
|
|||
| `TG_REASONING` | `https://trustgraph.ai/ns/reasoning` |
|
||||
| `TG_CONTENT` | `https://trustgraph.ai/ns/content` |
|
||||
| `TG_DOCUMENT` | `https://trustgraph.ai/ns/document` |
|
||||
| `TG_IN_TOKEN` | `https://trustgraph.ai/ns/inToken` |
|
||||
| `TG_OUT_TOKEN` | `https://trustgraph.ai/ns/outToken` |
|
||||
| `TG_LLM_MODEL` | `https://trustgraph.ai/ns/llmModel` |
|
||||
|
||||
### Token Usage on Events
|
||||
|
||||
Grounding, Focus, and Synthesis events carry per-event LLM token counts:
|
||||
|
||||
| Predicate | Type | Present on |
|
||||
|-----------|------|------------|
|
||||
| `tg:inToken` | integer | Grounding, Focus, Synthesis |
|
||||
| `tg:outToken` | integer | Grounding, Focus, Synthesis |
|
||||
| `tg:llmModel` | string | Grounding, Focus, Synthesis |
|
||||
|
||||
- **Grounding**: tokens from the extract-concepts LLM call
|
||||
- **Focus**: summed tokens from edge-scoring + edge-reasoning LLM calls
|
||||
- **Synthesis**: tokens from the synthesis LLM call
|
||||
|
||||
Values are absent (not zero) when token counts are unavailable.
|
||||
|
||||
## GraphRagResponse Schema
|
||||
|
||||
|
|
@ -267,8 +286,13 @@ Based on the provided knowledge statements...
|
|||
| `trustgraph-flow/trustgraph/query/triples/cassandra/service.py` | Quoted triple query support |
|
||||
| `trustgraph-cli/trustgraph/cli/invoke_graph_rag.py` | CLI with explainability display |
|
||||
|
||||
## Vocabulary Reference
|
||||
|
||||
The full OWL ontology covering all classes and predicates is at `specs/ontology/trustgraph.ttl`.
|
||||
|
||||
## References
|
||||
|
||||
- PROV-O (W3C Provenance Ontology): https://www.w3.org/TR/prov-o/
|
||||
- RDF-star: https://w3c.github.io/rdf-star/
|
||||
- Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md`
|
||||
- Agent explainability: `docs/tech-specs/agent-explainability.md`
|
||||
|
|
|
|||
|
|
@ -399,28 +399,28 @@ The agent produces multiple types of output during its reasoning cycle:
|
|||
- Answer (final response)
|
||||
- Errors
|
||||
|
||||
Since `chunk_type` identifies what kind of content is being sent, the separate
|
||||
Since `message_type` identifies what kind of content is being sent, the separate
|
||||
`answer`, `error`, `thought`, and `observation` fields can be collapsed into
|
||||
a single `content` field:
|
||||
|
||||
```python
|
||||
class AgentResponse(Record):
|
||||
chunk_type = String() # "thought", "action", "observation", "answer", "error"
|
||||
content = String() # The actual content (interpretation depends on chunk_type)
|
||||
message_type = String() # "thought", "action", "observation", "answer", "error"
|
||||
content = String() # The actual content (interpretation depends on message_type)
|
||||
end_of_message = Boolean() # Current thought/action/observation/answer is complete
|
||||
end_of_dialog = Boolean() # Entire agent dialog is complete
|
||||
```
|
||||
|
||||
**Field Semantics:**
|
||||
|
||||
- `chunk_type`: Indicates what type of content is in the `content` field
|
||||
- `message_type`: Indicates what type of content is in the `content` field
|
||||
- `"thought"`: Agent reasoning/thinking
|
||||
- `"action"`: Tool/action being invoked
|
||||
- `"observation"`: Result from tool execution
|
||||
- `"answer"`: Final answer to the user's question
|
||||
- `"error"`: Error message
|
||||
|
||||
- `content`: The actual streamed content, interpreted based on `chunk_type`
|
||||
- `content`: The actual streamed content, interpreted based on `message_type`
|
||||
|
||||
- `end_of_message`: When `true`, the current chunk type is complete
|
||||
- Example: All tokens for the current thought have been sent
|
||||
|
|
@ -434,27 +434,27 @@ class AgentResponse(Record):
|
|||
When `streaming=true`:
|
||||
|
||||
1. **Thought streaming**:
|
||||
- Multiple chunks with `chunk_type="thought"`, `end_of_message=false`
|
||||
- Multiple chunks with `message_type="thought"`, `end_of_message=false`
|
||||
- Final thought chunk has `end_of_message=true`
|
||||
2. **Action notification**:
|
||||
- Single chunk with `chunk_type="action"`, `end_of_message=true`
|
||||
- Single chunk with `message_type="action"`, `end_of_message=true`
|
||||
3. **Observation**:
|
||||
- Chunk(s) with `chunk_type="observation"`, final has `end_of_message=true`
|
||||
- Chunk(s) with `message_type="observation"`, final has `end_of_message=true`
|
||||
4. **Repeat** steps 1-3 as the agent reasons
|
||||
5. **Final answer**:
|
||||
- `chunk_type="answer"` with the final response in `content`
|
||||
- `message_type="answer"` with the final response in `content`
|
||||
- Last chunk has `end_of_message=true`, `end_of_dialog=true`
|
||||
|
||||
**Example Stream Sequence:**
|
||||
|
||||
```
|
||||
{chunk_type: "thought", content: "I need to", end_of_message: false, end_of_dialog: false}
|
||||
{chunk_type: "thought", content: " search for...", end_of_message: true, end_of_dialog: false}
|
||||
{chunk_type: "action", content: "search", end_of_message: true, end_of_dialog: false}
|
||||
{chunk_type: "observation", content: "Found: ...", end_of_message: true, end_of_dialog: false}
|
||||
{chunk_type: "thought", content: "Based on this", end_of_message: false, end_of_dialog: false}
|
||||
{chunk_type: "thought", content: " I can answer...", end_of_message: true, end_of_dialog: false}
|
||||
{chunk_type: "answer", content: "The answer is...", end_of_message: true, end_of_dialog: true}
|
||||
{message_type: "thought", content: "I need to", end_of_message: false, end_of_dialog: false}
|
||||
{message_type: "thought", content: " search for...", end_of_message: true, end_of_dialog: false}
|
||||
{message_type: "action", content: "search", end_of_message: true, end_of_dialog: false}
|
||||
{message_type: "observation", content: "Found: ...", end_of_message: true, end_of_dialog: false}
|
||||
{message_type: "thought", content: "Based on this", end_of_message: false, end_of_dialog: false}
|
||||
{message_type: "thought", content: " I can answer...", end_of_message: true, end_of_dialog: false}
|
||||
{message_type: "answer", content: "The answer is...", end_of_message: true, end_of_dialog: true}
|
||||
```
|
||||
|
||||
When `streaming=false`:
|
||||
|
|
@ -547,7 +547,7 @@ The following questions were resolved during specification:
|
|||
populated and no other fields are needed. An error is always the final
|
||||
communication - no subsequent messages are permitted or expected after
|
||||
an error. For LLM/Prompt streams, `end_of_stream=true`. For Agent streams,
|
||||
`chunk_type="error"` with `end_of_dialog=true`.
|
||||
`message_type="error"` with `end_of_dialog=true`.
|
||||
|
||||
3. **Partial Response Recovery**: The messaging protocol (Pulsar) is resilient,
|
||||
so message-level retry is not needed. If a client loses track of the stream
|
||||
|
|
|
|||
415
specs/ontology/trustgraph.ttl
Normal file
415
specs/ontology/trustgraph.ttl
Normal file
|
|
@ -0,0 +1,415 @@
|
|||
@prefix tg: <https://trustgraph.ai/ns/> .
|
||||
@prefix owl: <http://www.w3.org/2002/07/owl#> .
|
||||
@prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> .
|
||||
@prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .
|
||||
@prefix xsd: <http://www.w3.org/2001/XMLSchema#> .
|
||||
@prefix prov: <http://www.w3.org/ns/prov#> .
|
||||
|
||||
# =============================================================================
|
||||
# Ontology declaration
|
||||
# =============================================================================
|
||||
|
||||
<https://trustgraph.ai/ns/>
|
||||
a owl:Ontology ;
|
||||
rdfs:label "TrustGraph Ontology" ;
|
||||
rdfs:comment "Vocabulary for TrustGraph provenance, extraction metadata, and explainability." ;
|
||||
owl:versionInfo "2.3" .
|
||||
|
||||
# =============================================================================
|
||||
# Classes — Extraction provenance
|
||||
# =============================================================================
|
||||
|
||||
tg:Document a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Document" ;
|
||||
rdfs:comment "A loaded document (PDF, text, etc.)." .
|
||||
|
||||
tg:Page a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Page" ;
|
||||
rdfs:comment "A page within a document." .
|
||||
|
||||
tg:Section a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Section" ;
|
||||
rdfs:comment "A structural section within a document." .
|
||||
|
||||
tg:Chunk a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Chunk" ;
|
||||
rdfs:comment "A text chunk produced by the chunker." .
|
||||
|
||||
tg:Image a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Image" ;
|
||||
rdfs:comment "An image extracted from a document." .
|
||||
|
||||
tg:Subgraph a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Subgraph" ;
|
||||
rdfs:comment "A set of triples extracted from a chunk." .
|
||||
|
||||
# =============================================================================
|
||||
# Classes — Query-time explainability (shared)
|
||||
# =============================================================================
|
||||
|
||||
tg:Question a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Question" ;
|
||||
rdfs:comment "Root entity for a query session." .
|
||||
|
||||
tg:GraphRagQuestion a owl:Class ;
|
||||
rdfs:subClassOf tg:Question ;
|
||||
rdfs:label "Graph RAG Question" ;
|
||||
rdfs:comment "A question answered via graph-based RAG." .
|
||||
|
||||
tg:DocRagQuestion a owl:Class ;
|
||||
rdfs:subClassOf tg:Question ;
|
||||
rdfs:label "Document RAG Question" ;
|
||||
rdfs:comment "A question answered via document-based RAG." .
|
||||
|
||||
tg:AgentQuestion a owl:Class ;
|
||||
rdfs:subClassOf tg:Question ;
|
||||
rdfs:label "Agent Question" ;
|
||||
rdfs:comment "A question answered via the agent orchestrator." .
|
||||
|
||||
tg:Grounding a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Grounding" ;
|
||||
rdfs:comment "Concept extraction step (query decomposition into search terms)." .
|
||||
|
||||
tg:Exploration a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Exploration" ;
|
||||
rdfs:comment "Entity/chunk retrieval step." .
|
||||
|
||||
tg:Focus a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Focus" ;
|
||||
rdfs:comment "Edge selection and scoring step (GraphRAG)." .
|
||||
|
||||
tg:Synthesis a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Synthesis" ;
|
||||
rdfs:comment "Final answer synthesis from retrieved context." .
|
||||
|
||||
# =============================================================================
|
||||
# Classes — Agent provenance
|
||||
# =============================================================================
|
||||
|
||||
tg:Analysis a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Analysis" ;
|
||||
rdfs:comment "One agent iteration: reasoning followed by tool selection." .
|
||||
|
||||
tg:ToolUse a owl:Class ;
|
||||
rdfs:label "ToolUse" ;
|
||||
rdfs:comment "Mixin type applied to Analysis when a tool is invoked." .
|
||||
|
||||
tg:Error a owl:Class ;
|
||||
rdfs:label "Error" ;
|
||||
rdfs:comment "Mixin type applied to events where a failure occurred (tool error, parse error)." .
|
||||
|
||||
tg:Conclusion a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Conclusion" ;
|
||||
rdfs:comment "Agent final answer (ReAct pattern)." .
|
||||
|
||||
tg:PatternDecision a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Pattern Decision" ;
|
||||
rdfs:comment "Meta-router decision recording which execution pattern was selected." .
|
||||
|
||||
# --- Unifying types ---
|
||||
|
||||
tg:Answer a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Answer" ;
|
||||
rdfs:comment "Unifying type for any terminal answer (Synthesis, Conclusion, Finding, StepResult)." .
|
||||
|
||||
tg:Reflection a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Reflection" ;
|
||||
rdfs:comment "Unifying type for intermediate commentary (Thought, Observation)." .
|
||||
|
||||
tg:Thought a owl:Class ;
|
||||
rdfs:subClassOf tg:Reflection ;
|
||||
rdfs:label "Thought" ;
|
||||
rdfs:comment "Agent reasoning text within an iteration." .
|
||||
|
||||
tg:Observation a owl:Class ;
|
||||
rdfs:subClassOf tg:Reflection ;
|
||||
rdfs:label "Observation" ;
|
||||
rdfs:comment "Tool execution result." .
|
||||
|
||||
# --- Orchestrator types ---
|
||||
|
||||
tg:Decomposition a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Decomposition" ;
|
||||
rdfs:comment "Supervisor pattern: question decomposed into sub-goals." .
|
||||
|
||||
tg:Finding a owl:Class ;
|
||||
rdfs:subClassOf tg:Answer ;
|
||||
rdfs:label "Finding" ;
|
||||
rdfs:comment "Result from a sub-agent execution." .
|
||||
|
||||
tg:Plan a owl:Class ;
|
||||
rdfs:subClassOf prov:Entity ;
|
||||
rdfs:label "Plan" ;
|
||||
rdfs:comment "Plan-then-execute pattern: structured plan of steps." .
|
||||
|
||||
tg:StepResult a owl:Class ;
|
||||
rdfs:subClassOf tg:Answer ;
|
||||
rdfs:label "Step Result" ;
|
||||
rdfs:comment "Result from executing one plan step." .
|
||||
|
||||
# =============================================================================
|
||||
# Properties — Extraction metadata
|
||||
# =============================================================================
|
||||
|
||||
tg:contains a owl:ObjectProperty ;
|
||||
rdfs:label "contains" ;
|
||||
rdfs:comment "Links a parent entity to a child (e.g. Document contains Page, Subgraph contains triple)." .
|
||||
|
||||
tg:pageCount a owl:DatatypeProperty ;
|
||||
rdfs:label "page count" ;
|
||||
rdfs:range xsd:integer ;
|
||||
rdfs:domain tg:Document .
|
||||
|
||||
tg:mimeType a owl:DatatypeProperty ;
|
||||
rdfs:label "MIME type" ;
|
||||
rdfs:range xsd:string ;
|
||||
rdfs:domain tg:Document .
|
||||
|
||||
tg:pageNumber a owl:DatatypeProperty ;
|
||||
rdfs:label "page number" ;
|
||||
rdfs:range xsd:integer ;
|
||||
rdfs:domain tg:Page .
|
||||
|
||||
tg:chunkIndex a owl:DatatypeProperty ;
|
||||
rdfs:label "chunk index" ;
|
||||
rdfs:range xsd:integer ;
|
||||
rdfs:domain tg:Chunk .
|
||||
|
||||
tg:charOffset a owl:DatatypeProperty ;
|
||||
rdfs:label "character offset" ;
|
||||
rdfs:range xsd:integer .
|
||||
|
||||
tg:charLength a owl:DatatypeProperty ;
|
||||
rdfs:label "character length" ;
|
||||
rdfs:range xsd:integer .
|
||||
|
||||
tg:chunkSize a owl:DatatypeProperty ;
|
||||
rdfs:label "chunk size" ;
|
||||
rdfs:range xsd:integer .
|
||||
|
||||
tg:chunkOverlap a owl:DatatypeProperty ;
|
||||
rdfs:label "chunk overlap" ;
|
||||
rdfs:range xsd:integer .
|
||||
|
||||
tg:componentVersion a owl:DatatypeProperty ;
|
||||
rdfs:label "component version" ;
|
||||
rdfs:range xsd:string .
|
||||
|
||||
tg:llmModel a owl:DatatypeProperty ;
|
||||
rdfs:label "LLM model" ;
|
||||
rdfs:range xsd:string .
|
||||
|
||||
tg:ontology a owl:DatatypeProperty ;
|
||||
rdfs:label "ontology" ;
|
||||
rdfs:range xsd:string .
|
||||
|
||||
tg:embeddingModel a owl:DatatypeProperty ;
|
||||
rdfs:label "embedding model" ;
|
||||
rdfs:range xsd:string .
|
||||
|
||||
tg:sourceText a owl:DatatypeProperty ;
|
||||
rdfs:label "source text" ;
|
||||
rdfs:range xsd:string .
|
||||
|
||||
tg:sourceCharOffset a owl:DatatypeProperty ;
|
||||
rdfs:label "source character offset" ;
|
||||
rdfs:range xsd:integer .
|
||||
|
||||
tg:sourceCharLength a owl:DatatypeProperty ;
|
||||
rdfs:label "source character length" ;
|
||||
rdfs:range xsd:integer .
|
||||
|
||||
tg:elementTypes a owl:DatatypeProperty ;
|
||||
rdfs:label "element types" ;
|
||||
rdfs:range xsd:string .
|
||||
|
||||
tg:tableCount a owl:DatatypeProperty ;
|
||||
rdfs:label "table count" ;
|
||||
rdfs:range xsd:integer .
|
||||
|
||||
tg:imageCount a owl:DatatypeProperty ;
|
||||
rdfs:label "image count" ;
|
||||
rdfs:range xsd:integer .
|
||||
|
||||
# =============================================================================
|
||||
# Properties — Query-time provenance (GraphRAG / DocumentRAG)
|
||||
# =============================================================================
|
||||
|
||||
tg:query a owl:DatatypeProperty ;
|
||||
rdfs:label "query" ;
|
||||
rdfs:comment "The user's query text." ;
|
||||
rdfs:range xsd:string ;
|
||||
rdfs:domain tg:Question .
|
||||
|
||||
tg:concept a owl:DatatypeProperty ;
|
||||
rdfs:label "concept" ;
|
||||
rdfs:comment "An extracted concept from the query." ;
|
||||
rdfs:range xsd:string ;
|
||||
rdfs:domain tg:Grounding .
|
||||
|
||||
tg:entity a owl:ObjectProperty ;
|
||||
rdfs:label "entity" ;
|
||||
rdfs:comment "A seed entity retrieved during exploration." ;
|
||||
rdfs:domain tg:Exploration .
|
||||
|
||||
tg:edgeCount a owl:DatatypeProperty ;
|
||||
rdfs:label "edge count" ;
|
||||
rdfs:comment "Number of edges explored." ;
|
||||
rdfs:range xsd:integer ;
|
||||
rdfs:domain tg:Exploration .
|
||||
|
||||
tg:selectedEdge a owl:ObjectProperty ;
|
||||
rdfs:label "selected edge" ;
|
||||
rdfs:comment "Link to an edge selection entity within a Focus event." ;
|
||||
rdfs:domain tg:Focus .
|
||||
|
||||
tg:edge a owl:ObjectProperty ;
|
||||
rdfs:label "edge" ;
|
||||
rdfs:comment "A quoted triple representing a knowledge graph edge." ;
|
||||
rdfs:domain tg:Focus .
|
||||
|
||||
tg:reasoning a owl:DatatypeProperty ;
|
||||
rdfs:label "reasoning" ;
|
||||
rdfs:comment "LLM-generated reasoning for an edge selection." ;
|
||||
rdfs:range xsd:string .
|
||||
|
||||
tg:document a owl:ObjectProperty ;
|
||||
rdfs:label "document" ;
|
||||
rdfs:comment "Reference to a document stored in the librarian." .
|
||||
|
||||
tg:chunkCount a owl:DatatypeProperty ;
|
||||
rdfs:label "chunk count" ;
|
||||
rdfs:comment "Number of document chunks retrieved (DocumentRAG)." ;
|
||||
rdfs:range xsd:integer ;
|
||||
rdfs:domain tg:Exploration .
|
||||
|
||||
tg:selectedChunk a owl:DatatypeProperty ;
|
||||
rdfs:label "selected chunk" ;
|
||||
rdfs:comment "A selected chunk ID (DocumentRAG)." ;
|
||||
rdfs:range xsd:string ;
|
||||
rdfs:domain tg:Exploration .
|
||||
|
||||
# =============================================================================
|
||||
# Properties — Agent provenance
|
||||
# =============================================================================
|
||||
|
||||
tg:thought a owl:ObjectProperty ;
|
||||
rdfs:label "thought" ;
|
||||
rdfs:comment "Links an Analysis iteration to its Thought sub-entity." ;
|
||||
rdfs:domain tg:Analysis ;
|
||||
rdfs:range tg:Thought .
|
||||
|
||||
tg:action a owl:DatatypeProperty ;
|
||||
rdfs:label "action" ;
|
||||
rdfs:comment "The tool/action name selected by the agent." ;
|
||||
rdfs:range xsd:string ;
|
||||
rdfs:domain tg:Analysis .
|
||||
|
||||
tg:arguments a owl:DatatypeProperty ;
|
||||
rdfs:label "arguments" ;
|
||||
rdfs:comment "JSON-encoded arguments passed to the tool." ;
|
||||
rdfs:range xsd:string ;
|
||||
rdfs:domain tg:Analysis .
|
||||
|
||||
tg:observation a owl:ObjectProperty ;
|
||||
rdfs:label "observation" ;
|
||||
rdfs:comment "Links an Analysis iteration to its Observation sub-entity." ;
|
||||
rdfs:domain tg:Analysis ;
|
||||
rdfs:range tg:Observation .
|
||||
|
||||
tg:toolCandidate a owl:DatatypeProperty ;
|
||||
rdfs:label "tool candidate" ;
|
||||
rdfs:comment "Name of a tool available to the LLM for this iteration. One triple per candidate." ;
|
||||
rdfs:range xsd:string ;
|
||||
rdfs:domain tg:Analysis .
|
||||
|
||||
tg:stepNumber a owl:DatatypeProperty ;
|
||||
rdfs:label "step number" ;
|
||||
rdfs:comment "Explicit 1-based step counter for iteration events." ;
|
||||
rdfs:range xsd:integer ;
|
||||
rdfs:domain tg:Analysis .
|
||||
|
||||
tg:terminationReason a owl:DatatypeProperty ;
|
||||
rdfs:label "termination reason" ;
|
||||
rdfs:comment "Why the agent loop stopped: final-answer, plan-complete, subagents-complete, max-iterations, error." ;
|
||||
rdfs:range xsd:string .
|
||||
|
||||
tg:pattern a owl:DatatypeProperty ;
|
||||
rdfs:label "pattern" ;
|
||||
rdfs:comment "Selected execution pattern (react, plan-then-execute, supervisor)." ;
|
||||
rdfs:range xsd:string ;
|
||||
rdfs:domain tg:PatternDecision .
|
||||
|
||||
tg:taskType a owl:DatatypeProperty ;
|
||||
rdfs:label "task type" ;
|
||||
rdfs:comment "Identified task type from the meta-router (general, research, etc.)." ;
|
||||
rdfs:range xsd:string ;
|
||||
rdfs:domain tg:PatternDecision .
|
||||
|
||||
tg:llmDurationMs a owl:DatatypeProperty ;
|
||||
rdfs:label "LLM duration (ms)" ;
|
||||
rdfs:comment "Time spent in the LLM prompt call, in milliseconds." ;
|
||||
rdfs:range xsd:integer ;
|
||||
rdfs:domain tg:Analysis .
|
||||
|
||||
tg:toolDurationMs a owl:DatatypeProperty ;
|
||||
rdfs:label "tool duration (ms)" ;
|
||||
rdfs:comment "Time spent executing the tool, in milliseconds." ;
|
||||
rdfs:range xsd:integer ;
|
||||
rdfs:domain tg:Observation .
|
||||
|
||||
tg:toolError a owl:DatatypeProperty ;
|
||||
rdfs:label "tool error" ;
|
||||
rdfs:comment "Error message from a failed tool execution." ;
|
||||
rdfs:range xsd:string ;
|
||||
rdfs:domain tg:Observation .
|
||||
|
||||
# --- Token usage predicates (on any event that involves an LLM call) ---
|
||||
|
||||
tg:inToken a owl:DatatypeProperty ;
|
||||
rdfs:label "input tokens" ;
|
||||
rdfs:comment "Input token count for the LLM call associated with this event." ;
|
||||
rdfs:range xsd:integer .
|
||||
|
||||
tg:outToken a owl:DatatypeProperty ;
|
||||
rdfs:label "output tokens" ;
|
||||
rdfs:comment "Output token count for the LLM call associated with this event." ;
|
||||
rdfs:range xsd:integer .
|
||||
|
||||
# --- Orchestrator predicates ---
|
||||
|
||||
tg:subagentGoal a owl:DatatypeProperty ;
|
||||
rdfs:label "sub-agent goal" ;
|
||||
rdfs:comment "Goal string assigned to a sub-agent (Decomposition, Finding)." ;
|
||||
rdfs:range xsd:string .
|
||||
|
||||
tg:planStep a owl:DatatypeProperty ;
|
||||
rdfs:label "plan step" ;
|
||||
rdfs:comment "Goal string for a plan step (Plan, StepResult)." ;
|
||||
rdfs:range xsd:string .
|
||||
|
||||
# =============================================================================
|
||||
# Named graphs
|
||||
# =============================================================================
|
||||
# These are not OWL classes but documented here for reference:
|
||||
#
|
||||
# (default graph) — Core knowledge facts (extracted triples)
|
||||
# urn:graph:source — Extraction provenance (document → chunk → triple)
|
||||
# urn:graph:retrieval — Query-time explainability (question → exploration → synthesis)
|
||||
|
|
@ -87,7 +87,7 @@ def sample_message_data():
|
|||
"history": []
|
||||
},
|
||||
"AgentResponse": {
|
||||
"chunk_type": "answer",
|
||||
"message_type": "answer",
|
||||
"content": "Machine learning is a subset of AI.",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True,
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ class TestAgentMessageContracts:
|
|||
|
||||
# Test required fields
|
||||
response = AgentResponse(**response_data)
|
||||
assert hasattr(response, 'chunk_type')
|
||||
assert hasattr(response, 'message_type')
|
||||
assert hasattr(response, 'content')
|
||||
assert hasattr(response, 'end_of_message')
|
||||
assert hasattr(response, 'end_of_dialog')
|
||||
|
|
|
|||
73
tests/contract/test_schema_field_contracts.py
Normal file
73
tests/contract/test_schema_field_contracts.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""
|
||||
Contract tests for schema dataclass field sets.
|
||||
|
||||
These pin the *field names* of small, widely-constructed schema dataclasses
|
||||
so that any rename, removal, or accidental addition fails CI loudly instead
|
||||
of waiting for a runtime TypeError on the next websocket message.
|
||||
|
||||
Background: in v2.2 the `Metadata` dataclass dropped a `metadata: list[Triple]`
|
||||
field but several call sites kept passing `Metadata(metadata=...)`. The bug
|
||||
was only discovered when a websocket import dispatcher received its first
|
||||
real message in production. A trivial structural assertion of the kind
|
||||
below would have caught it at unit-test time.
|
||||
|
||||
Add to this file whenever a schema rename burns you. The cost of a frozen
|
||||
field set is a one-line update when you intentionally evolve the schema; the
|
||||
benefit is that every call site is forced to come along for the ride.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import pytest
|
||||
|
||||
from trustgraph.schema import (
|
||||
Metadata,
|
||||
EntityContext,
|
||||
EntityEmbeddings,
|
||||
ChunkEmbeddings,
|
||||
)
|
||||
|
||||
|
||||
def _field_names(dc):
|
||||
return {f.name for f in dataclasses.fields(dc)}
|
||||
|
||||
|
||||
@pytest.mark.contract
|
||||
class TestSchemaFieldContracts:
|
||||
"""Pin the field set of dataclasses that get constructed all over the
|
||||
codebase. If you intentionally change one of these, update the
|
||||
expected set in the same commit — that diff will surface every call
|
||||
site that needs to come along."""
|
||||
|
||||
def test_metadata_fields(self):
|
||||
# NOTE: there is no `metadata` field. A previous regression
|
||||
# constructed Metadata(metadata=...) and crashed at runtime.
|
||||
assert _field_names(Metadata) == {
|
||||
"id",
|
||||
"root",
|
||||
"user",
|
||||
"collection",
|
||||
}
|
||||
|
||||
def test_entity_embeddings_fields(self):
|
||||
# NOTE: the embedding field is `vector` (singular, list[float]).
|
||||
# There is no `vectors` field. Several call sites historically
|
||||
# passed `vectors=` and crashed at runtime.
|
||||
assert _field_names(EntityEmbeddings) == {
|
||||
"entity",
|
||||
"vector",
|
||||
"chunk_id",
|
||||
}
|
||||
|
||||
def test_chunk_embeddings_fields(self):
|
||||
# Same `vector` (singular) convention as EntityEmbeddings.
|
||||
assert _field_names(ChunkEmbeddings) == {
|
||||
"chunk_id",
|
||||
"vector",
|
||||
}
|
||||
|
||||
def test_entity_context_fields(self):
|
||||
assert _field_names(EntityContext) == {
|
||||
"entity",
|
||||
"context",
|
||||
"chunk_id",
|
||||
}
|
||||
|
|
@ -188,7 +188,7 @@ class TestAgentTranslatorCompletionFlags:
|
|||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("agent")
|
||||
response = AgentResponse(
|
||||
chunk_type="answer",
|
||||
message_type="answer",
|
||||
content="4",
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
|
|
@ -210,7 +210,7 @@ class TestAgentTranslatorCompletionFlags:
|
|||
# Arrange
|
||||
translator = TranslatorRegistry.get_response_translator("agent")
|
||||
response = AgentResponse(
|
||||
chunk_type="thought",
|
||||
message_type="thought",
|
||||
content="I need to solve this.",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
|
|
@ -233,7 +233,7 @@ class TestAgentTranslatorCompletionFlags:
|
|||
|
||||
# Test thought message
|
||||
thought_response = AgentResponse(
|
||||
chunk_type="thought",
|
||||
message_type="thought",
|
||||
content="Processing...",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
|
|
@ -247,7 +247,7 @@ class TestAgentTranslatorCompletionFlags:
|
|||
|
||||
# Test observation message
|
||||
observation_response = AgentResponse(
|
||||
chunk_type="observation",
|
||||
message_type="observation",
|
||||
content="Result found",
|
||||
end_of_message=True,
|
||||
end_of_dialog=False,
|
||||
|
|
@ -268,7 +268,7 @@ class TestAgentTranslatorCompletionFlags:
|
|||
|
||||
# Streaming format with end_of_dialog=True
|
||||
response = AgentResponse(
|
||||
chunk_type="answer",
|
||||
message_type="answer",
|
||||
content="",
|
||||
end_of_message=True,
|
||||
end_of_dialog=True,
|
||||
|
|
|
|||
|
|
@ -418,55 +418,55 @@ def sample_streaming_agent_response():
|
|||
"""Sample streaming agent response chunks"""
|
||||
return [
|
||||
{
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": "I need to search",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": " for information",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": " about machine learning.",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "action",
|
||||
"message_type": "action",
|
||||
"content": "knowledge_query",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "observation",
|
||||
"message_type": "observation",
|
||||
"content": "Machine learning is",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "observation",
|
||||
"message_type": "observation",
|
||||
"content": " a subset of AI.",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "final-answer",
|
||||
"message_type": "final-answer",
|
||||
"content": "Machine learning",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "final-answer",
|
||||
"message_type": "final-answer",
|
||||
"content": " is a subset",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False
|
||||
},
|
||||
{
|
||||
"chunk_type": "final-answer",
|
||||
"message_type": "final-answer",
|
||||
"content": " of artificial intelligence.",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True
|
||||
|
|
@ -494,10 +494,10 @@ def streaming_chunk_collector():
|
|||
"""Concatenate all chunk content"""
|
||||
return "".join(self.chunks)
|
||||
|
||||
def get_chunk_types(self):
|
||||
def get_message_types(self):
|
||||
"""Get list of chunk types if chunks are dicts"""
|
||||
if self.chunks and isinstance(self.chunks[0], dict):
|
||||
return [c.get("chunk_type") for c in self.chunks]
|
||||
return [c.get("message_type") for c in self.chunks]
|
||||
return []
|
||||
|
||||
def verify_streaming_protocol(self):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager
|
|||
from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl
|
||||
from trustgraph.agent.react.types import Action, Final, Tool, Argument
|
||||
from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -28,19 +29,25 @@ class TestAgentManagerIntegration:
|
|||
|
||||
# Mock prompt client
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning
|
||||
prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search for information about machine learning
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is machine learning?"
|
||||
}"""
|
||||
|
||||
)
|
||||
|
||||
# Mock graph RAG client
|
||||
graph_rag_client = AsyncMock()
|
||||
graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data."
|
||||
|
||||
|
||||
# Mock text completion client
|
||||
text_completion_client = AsyncMock()
|
||||
text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience."
|
||||
text_completion_client.question.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Machine learning involves algorithms that improve through experience."
|
||||
)
|
||||
|
||||
# Mock MCP tool client
|
||||
mcp_tool_client = AsyncMock()
|
||||
|
|
@ -147,8 +154,11 @@ Args: {
|
|||
async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager returning final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I have enough information to answer the question
|
||||
Final Answer: Machine learning is a field of AI that enables computers to learn from data."""
|
||||
)
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
|
@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn
|
|||
async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context):
|
||||
"""Test ReAct cycle ending with final answer"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I can provide a direct answer
|
||||
Final Answer: Machine learning is a branch of artificial intelligence."""
|
||||
)
|
||||
|
||||
question = "What is machine learning?"
|
||||
history = []
|
||||
|
|
@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence."""
|
|||
|
||||
for tool_name, expected_service in tool_scenarios:
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name}
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=f"""Thought: I need to use {tool_name}
|
||||
Action: {tool_name}
|
||||
Args: {{
|
||||
"question": "test question"
|
||||
}}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -284,11 +300,14 @@ Args: {{
|
|||
async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context):
|
||||
"""Test agent manager error handling for unknown tool"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to use an unknown tool
|
||||
Action: unknown_tool
|
||||
Args: {
|
||||
"param": "value"
|
||||
}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -308,11 +327,13 @@ Args: {
|
|||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
assert "Tool execution failed" in str(exc_info.value)
|
||||
# Act - tool errors are now caught and returned as observations
|
||||
result = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context)
|
||||
|
||||
# Assert - error captured on the action, not raised
|
||||
assert result.tool_error is not None
|
||||
assert "Tool execution failed" in result.tool_error
|
||||
assert "Error:" in result.observation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context):
|
||||
|
|
@ -321,11 +342,14 @@ Args: {
|
|||
question = "Find information about AI and summarize it"
|
||||
|
||||
# Mock multi-step reasoning
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search for AI information first
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is artificial intelligence?"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Act
|
||||
action = await agent_manager.reason(question, [], mock_flow_context)
|
||||
|
|
@ -372,9 +396,12 @@ Args: {
|
|||
# Format arguments as JSON
|
||||
import json
|
||||
args_json = json.dumps(test_case['arguments'], indent=4)
|
||||
mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']}
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=f"""Thought: Using {test_case['action']}
|
||||
Action: {test_case['action']}
|
||||
Args: {args_json}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
@ -507,15 +534,17 @@ Args: {
|
|||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
mock_flow_context("prompt-request").agent_react.return_value = test_case["response"]
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=test_case["response"]
|
||||
)
|
||||
|
||||
if test_case["error_contains"]:
|
||||
# Should raise an error
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await agent_manager.reason("test question", [], mock_flow_context)
|
||||
|
||||
assert "Failed to parse agent response" in str(exc_info.value)
|
||||
assert test_case["error_contains"] in str(exc_info.value)
|
||||
# Parse errors now return an Action with tool_error
|
||||
result = await agent_manager.reason("test question", [], mock_flow_context)
|
||||
assert isinstance(result, Action)
|
||||
assert result.name == "__parse_error__"
|
||||
assert result.tool_error is not None
|
||||
else:
|
||||
# Should succeed
|
||||
action = await agent_manager.reason("test question", [], mock_flow_context)
|
||||
|
|
@ -527,13 +556,16 @@ Args: {
|
|||
async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context):
|
||||
"""Test edge cases in text parsing"""
|
||||
# Test response with markdown code blocks
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """```
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""```
|
||||
Thought: I need to search for information
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What is AI?"
|
||||
}
|
||||
```"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -541,15 +573,18 @@ Args: {
|
|||
assert action.name == "knowledge_query"
|
||||
|
||||
# Test response with extra whitespace
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""
|
||||
|
||||
Thought: I need to think about this
|
||||
Action: knowledge_query
|
||||
Thought: I need to think about this
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "test"
|
||||
}
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -560,7 +595,9 @@ Args: {
|
|||
async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context):
|
||||
"""Test handling of multi-line thoughts and final answers"""
|
||||
# Multi-line thought
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors:
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to consider multiple factors:
|
||||
1. The user's question is complex
|
||||
2. I should search for comprehensive information
|
||||
3. This requires using the knowledge query tool
|
||||
|
|
@ -568,6 +605,7 @@ Action: knowledge_query
|
|||
Args: {
|
||||
"question": "complex query"
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -575,13 +613,16 @@ Args: {
|
|||
assert "knowledge query tool" in action.thought
|
||||
|
||||
# Multi-line final answer
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I have gathered enough information
|
||||
Final Answer: Here is a comprehensive answer:
|
||||
1. First point about the topic
|
||||
2. Second point with details
|
||||
3. Final conclusion
|
||||
|
||||
This covers all aspects of the question."""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Final)
|
||||
|
|
@ -593,13 +634,16 @@ This covers all aspects of the question."""
|
|||
async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context):
|
||||
"""Test JSON arguments with special characters and edge cases"""
|
||||
# Test with special characters in JSON (properly escaped)
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: Processing special characters
|
||||
Action: knowledge_query
|
||||
Args: {
|
||||
"question": "What about \\"quotes\\" and 'apostrophes'?",
|
||||
"context": "Line 1\\nLine 2\\tTabbed",
|
||||
"special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?"
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -608,7 +652,9 @@ Args: {
|
|||
assert "@#$%^&*" in action.arguments["special"]
|
||||
|
||||
# Test with nested JSON
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: Complex arguments
|
||||
Action: web_search
|
||||
Args: {
|
||||
"query": "test",
|
||||
|
|
@ -621,6 +667,7 @@ Args: {
|
|||
}
|
||||
}
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Action)
|
||||
|
|
@ -632,7 +679,9 @@ Args: {
|
|||
async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context):
|
||||
"""Test final answers that contain JSON-like content"""
|
||||
# Final answer with JSON content
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I can provide the data in JSON format
|
||||
Final Answer: {
|
||||
"result": "success",
|
||||
"data": {
|
||||
|
|
@ -642,6 +691,7 @@ Final Answer: {
|
|||
},
|
||||
"confidence": 0.95
|
||||
}"""
|
||||
)
|
||||
|
||||
action = await agent_manager.reason("test", [], mock_flow_context)
|
||||
assert isinstance(action, Final)
|
||||
|
|
@ -792,11 +842,14 @@ Final Answer: {
|
|||
agent = AgentManager(tools=custom_tools, additional_context="")
|
||||
|
||||
# Mock response for custom collection query
|
||||
mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers
|
||||
mock_flow_context("prompt-request").agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to search in the research papers
|
||||
Action: knowledge_query_custom
|
||||
Args: {
|
||||
"question": "Latest AI research?"
|
||||
}"""
|
||||
)
|
||||
|
||||
think_callback = AsyncMock()
|
||||
observe_callback = AsyncMock()
|
||||
|
|
|
|||
|
|
@ -10,11 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.agent.react.agent_manager import AgentManager
|
||||
from trustgraph.agent.react.tools import KnowledgeQueryImpl
|
||||
from trustgraph.agent.react.types import Tool, Argument
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_agent_streaming_chunks,
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
assert_chunk_types_valid,
|
||||
assert_message_types_valid,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -51,10 +52,10 @@ Args: {
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
|
||||
client.agent_react.side_effect = agent_react_streaming
|
||||
return client
|
||||
|
|
@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines."""
|
|||
for i, chunk in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk + " ", is_final)
|
||||
return response
|
||||
return response
|
||||
return PromptResult(response_type="text", text=response)
|
||||
return PromptResult(response_type="text", text=response)
|
||||
|
||||
mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from trustgraph.schema import (
|
|||
Error
|
||||
)
|
||||
from trustgraph.agent.react.service import Processor
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration:
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to find customers from New York using structured query
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find all customers from New York"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -173,11 +177,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to query for a table that might not exist
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find data from a table that doesn't exist"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -250,11 +257,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to find customers from California first
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Find all customers from California"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -339,11 +349,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to query the sales data
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Query the sales data for recent transactions"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
@ -447,11 +460,14 @@ Args: {
|
|||
|
||||
# Mock the prompt client that agent calls for reasoning
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information
|
||||
mock_prompt_client.agent_react.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="""Thought: I need to get customer information
|
||||
Action: structured-query
|
||||
Args: {
|
||||
"question": "Get customer information and format it nicely"
|
||||
}"""
|
||||
)
|
||||
|
||||
# Set up flow context routing
|
||||
def flow_context(service_name):
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# Sample chunk content for testing - maps chunk_id to content
|
||||
|
|
@ -61,11 +62,16 @@ class TestDocumentRagIntegration:
|
|||
def mock_prompt_client(self):
|
||||
"""Mock prompt client that generates realistic responses"""
|
||||
client = AsyncMock()
|
||||
client.document_prompt.return_value = (
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text=(
|
||||
"Machine learning is a field of artificial intelligence that enables computers to learn "
|
||||
"and improve from experience without being explicitly programmed. It uses algorithms "
|
||||
"to find patterns in data and make predictions or decisions."
|
||||
)
|
||||
)
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -119,6 +125,7 @@ class TestDocumentRagIntegration:
|
|||
)
|
||||
|
||||
# Verify final response
|
||||
result, usage = result
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert "machine learning" in result.lower()
|
||||
|
|
@ -131,7 +138,11 @@ class TestDocumentRagIntegration:
|
|||
"""Test DocumentRAG behavior when no documents are retrieved"""
|
||||
# Arrange
|
||||
mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found
|
||||
mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query."
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="I couldn't find any relevant documents for your query."
|
||||
)
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
embeddings_client=mock_embeddings_client,
|
||||
|
|
@ -152,7 +163,8 @@ class TestDocumentRagIntegration:
|
|||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "I couldn't find any relevant documents for your query."
|
||||
result_text, usage = result
|
||||
assert result_text == "I couldn't find any relevant documents for your query."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import ChunkMatch
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
|
|
@ -74,12 +75,14 @@ class TestDocumentRagStreaming:
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
# Non-streaming response - same text
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
|
||||
client.document_prompt.side_effect = document_prompt_side_effect
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -119,11 +122,12 @@ class TestDocumentRagStreaming:
|
|||
collector.verify_streaming_protocol()
|
||||
|
||||
# Verify full response matches concatenated chunks
|
||||
result_text, usage = result
|
||||
full_from_chunks = collector.get_full_text()
|
||||
assert result == full_from_chunks
|
||||
assert result_text == full_from_chunks
|
||||
|
||||
# Verify content is reasonable
|
||||
assert len(result) > 0
|
||||
assert len(result_text) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming):
|
||||
|
|
@ -159,9 +163,11 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_result == non_streaming_result
|
||||
non_streaming_text, _ = non_streaming_result
|
||||
streaming_text, _ = streaming_result
|
||||
assert streaming_text == non_streaming_text
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_result
|
||||
assert "".join(streaming_chunks) == streaming_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming):
|
||||
|
|
@ -180,8 +186,9 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
result_text, usage = result
|
||||
assert callback.call_count > 0
|
||||
assert result is not None
|
||||
assert result_text is not None
|
||||
|
||||
# Verify all callback invocations had string arguments
|
||||
for call in callback.call_args_list:
|
||||
|
|
@ -202,7 +209,8 @@ class TestDocumentRagStreaming:
|
|||
|
||||
# Assert - Should complete without error
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
result_text, usage = result
|
||||
assert isinstance(result_text, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming,
|
||||
|
|
@ -223,7 +231,8 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Should still produce streamed response
|
||||
assert result is not None
|
||||
result_text, usage = result
|
||||
assert result_text is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -271,7 +280,8 @@ class TestDocumentRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
result_text, usage = result
|
||||
assert result_text is not None
|
||||
assert callback.call_count > 0
|
||||
|
||||
# Verify doc_limit was passed correctly
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -93,18 +94,21 @@ class TestGraphRagIntegration:
|
|||
# 4. kg-synthesis returns the final answer
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return "" # No edges scored
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return "" # No reasoning
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return (
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=(
|
||||
"Machine learning is a subset of artificial intelligence that enables computers "
|
||||
"to learn from data without being explicitly programmed. It uses algorithms "
|
||||
"and statistical models to find patterns in data."
|
||||
)
|
||||
)
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = mock_prompt
|
||||
return client
|
||||
|
|
@ -169,6 +173,7 @@ class TestGraphRagIntegration:
|
|||
assert mock_prompt_client.prompt.call_count == 4
|
||||
|
||||
# Verify final response
|
||||
response, usage = response
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
assert "machine learning" in response.lower()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.schema import EntityMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_rag_streaming_chunks,
|
||||
|
|
@ -61,12 +62,12 @@ class TestGraphRagStreaming:
|
|||
|
||||
async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs):
|
||||
if prompt_id == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_id == "kg-edge-scoring":
|
||||
# Edge scoring returns JSONL with IDs and scores
|
||||
return '{"id": "abc12345", "score": 0.9}\n'
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n')
|
||||
elif prompt_id == "kg-edge-reasoning":
|
||||
return '{"id": "abc12345", "reasoning": "Relevant to query"}\n'
|
||||
return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n')
|
||||
elif prompt_id == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate streaming chunks with end_of_stream flags
|
||||
|
|
@ -79,10 +80,10 @@ class TestGraphRagStreaming:
|
|||
is_final = (i == len(chunks) - 1)
|
||||
await chunk_callback(chunk, is_final)
|
||||
|
||||
return full_text
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
else:
|
||||
return full_text
|
||||
return ""
|
||||
return PromptResult(response_type="text", text=full_text)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
|
@ -123,6 +124,7 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert
|
||||
response, usage = response
|
||||
assert_streaming_chunks_valid(collector.chunks, min_chunks=1)
|
||||
assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1)
|
||||
|
||||
|
|
@ -172,9 +174,11 @@ class TestGraphRagStreaming:
|
|||
)
|
||||
|
||||
# Assert - Results should be equivalent
|
||||
assert streaming_response == non_streaming_response
|
||||
non_streaming_text, _ = non_streaming_response
|
||||
streaming_text, _ = streaming_response
|
||||
assert streaming_text == non_streaming_text
|
||||
assert len(streaming_chunks) > 0
|
||||
assert "".join(streaming_chunks) == streaming_response
|
||||
assert "".join(streaming_chunks) == streaming_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming):
|
||||
|
|
@ -213,7 +217,8 @@ class TestGraphRagStreaming:
|
|||
|
||||
# Assert - Should complete without error
|
||||
assert response is not None
|
||||
assert isinstance(response, str)
|
||||
response_text, usage = response
|
||||
assert isinstance(response_text, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces
|
|||
from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL
|
||||
from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings
|
||||
from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
|
||||
# Mock prompt client for definitions extraction
|
||||
prompt_client = AsyncMock()
|
||||
prompt_client.extract_definitions.return_value = [
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
|
||||
prompt_client.extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"entity": "Machine Learning",
|
||||
"definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming."
|
||||
},
|
||||
{
|
||||
"entity": "Neural Networks",
|
||||
"definition": "Computing systems inspired by biological neural networks that process information."
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Mock prompt client for relationships extraction
|
||||
prompt_client.extract_relationships.return_value = [
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
prompt_client.extract_relationships.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"subject": "Machine Learning",
|
||||
"predicate": "is_subset_of",
|
||||
"object": "Artificial Intelligence",
|
||||
"object-entity": True
|
||||
},
|
||||
{
|
||||
"subject": "Neural Networks",
|
||||
"predicate": "is_used_in",
|
||||
"object": "Machine Learning",
|
||||
"object-entity": True
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Mock producers for output streams
|
||||
triples_producer = AsyncMock()
|
||||
|
|
@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of empty extraction results"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = []
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[]
|
||||
)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
|
|
@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk):
|
||||
"""Test handling of invalid extraction response format"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="invalid format"
|
||||
) # Should be jsonl with objects list
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.value.return_value = sample_chunk
|
||||
|
|
@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration:
|
|||
async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context):
|
||||
"""Test entity filtering and validation in extraction"""
|
||||
# Arrange
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = [
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"entity": "Valid Entity", "definition": "Valid definition"},
|
||||
{"entity": "", "definition": "Empty entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 2", "definition": ""}, # Should be filtered
|
||||
{"entity": None, "definition": "None entity"}, # Should be filtered
|
||||
{"entity": "Valid Entity 3", "definition": None}, # Should be filtered
|
||||
]
|
||||
)
|
||||
|
||||
sample_chunk = Chunk(
|
||||
metadata=Metadata(id="test", user="user", collection="collection"),
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from trustgraph.schema import (
|
|||
Chunk, ExtractedObject, Metadata, RowSchema, Field,
|
||||
PromptRequest, PromptResponse
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration:
|
|||
schema_name = schema.get("name") if isinstance(schema, dict) else schema.name
|
||||
if schema_name == "customer_records":
|
||||
if "john" in text.lower():
|
||||
return [
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Smith",
|
||||
"email": "john.smith@email.com",
|
||||
"phone": "555-0123"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"customer_id": "CUST001",
|
||||
"name": "John Smith",
|
||||
"email": "john.smith@email.com",
|
||||
"phone": "555-0123"
|
||||
}
|
||||
]
|
||||
)
|
||||
elif "jane" in text.lower():
|
||||
return [
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Doe",
|
||||
"email": "jane.doe@email.com",
|
||||
"phone": ""
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"customer_id": "CUST002",
|
||||
"name": "Jane Doe",
|
||||
"email": "jane.doe@email.com",
|
||||
"phone": ""
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
elif schema_name == "product_catalog":
|
||||
if "laptop" in text.lower():
|
||||
return [
|
||||
{
|
||||
"product_id": "PROD001",
|
||||
"name": "Gaming Laptop",
|
||||
"price": "1299.99",
|
||||
"category": "electronics"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"product_id": "PROD001",
|
||||
"name": "Gaming Laptop",
|
||||
"price": "1299.99",
|
||||
"category": "electronics"
|
||||
}
|
||||
]
|
||||
)
|
||||
elif "book" in text.lower():
|
||||
return [
|
||||
{
|
||||
"product_id": "PROD002",
|
||||
"name": "Python Programming Guide",
|
||||
"price": "49.99",
|
||||
"category": "books"
|
||||
}
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{
|
||||
"product_id": "PROD002",
|
||||
"name": "Python Programming Guide",
|
||||
"price": "49.99",
|
||||
"category": "books"
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
return []
|
||||
|
||||
return []
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
return PromptResult(response_type="jsonl", objects=[])
|
||||
|
||||
prompt_client.extract_objects.side_effect = mock_extract_objects
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
from trustgraph.prompt.template.service import Processor
|
||||
from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
|
||||
from trustgraph.base.text_completion_client import TextCompletionResult
|
||||
from tests.utils.streaming_assertions import (
|
||||
assert_streaming_chunks_valid,
|
||||
assert_callback_invoked,
|
||||
|
|
@ -27,34 +28,52 @@ class TestPromptStreaming:
|
|||
# Mock text completion client with streaming
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def streaming_request(request, recipient=None, timeout=600):
|
||||
"""Simulate streaming text completion"""
|
||||
if request.streaming and recipient:
|
||||
# Simulate streaming chunks
|
||||
chunks = [
|
||||
"Machine", " learning", " is", " a", " field",
|
||||
" of", " artificial", " intelligence", "."
|
||||
]
|
||||
# Streaming chunks to send
|
||||
chunks = [
|
||||
"Machine", " learning", " is", " a", " field",
|
||||
" of", " artificial", " intelligence", "."
|
||||
]
|
||||
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
is_final = (i == len(chunks) - 1)
|
||||
response = TextCompletionResponse(
|
||||
response=chunk_text,
|
||||
error=None,
|
||||
end_of_stream=is_final
|
||||
)
|
||||
final = await recipient(response)
|
||||
if final:
|
||||
break
|
||||
|
||||
# Final empty chunk
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
|
||||
"""Simulate streaming text completion via text_completion_stream"""
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
response = TextCompletionResponse(
|
||||
response=chunk_text,
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
end_of_stream=False
|
||||
)
|
||||
await handler(response)
|
||||
|
||||
text_completion_client.request = streaming_request
|
||||
# Send final empty chunk with end_of_stream
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
|
||||
return TextCompletionResult(
|
||||
text=None,
|
||||
in_token=10,
|
||||
out_token=9,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
async def non_streaming_text_completion(system, prompt, timeout=600):
|
||||
"""Simulate non-streaming text completion"""
|
||||
full_text = "Machine learning is a field of artificial intelligence."
|
||||
return TextCompletionResult(
|
||||
text=full_text,
|
||||
in_token=10,
|
||||
out_token=9,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=streaming_text_completion_stream
|
||||
)
|
||||
text_completion_client.text_completion = AsyncMock(
|
||||
side_effect=non_streaming_text_completion
|
||||
)
|
||||
|
||||
# Mock response producer
|
||||
response_producer = AsyncMock()
|
||||
|
|
@ -156,14 +175,6 @@ class TestPromptStreaming:
|
|||
|
||||
consumer = MagicMock()
|
||||
|
||||
# Mock non-streaming text completion
|
||||
text_completion_client = mock_flow_context_streaming("text-completion-request")
|
||||
|
||||
async def non_streaming_text_completion(system, prompt, streaming=False):
|
||||
return "AI is the simulation of human intelligence in machines."
|
||||
|
||||
text_completion_client.text_completion = non_streaming_text_completion
|
||||
|
||||
# Act
|
||||
await prompt_processor_streaming.on_request(
|
||||
message, consumer, mock_flow_context_streaming
|
||||
|
|
@ -218,17 +229,12 @@ class TestPromptStreaming:
|
|||
# Mock text completion client that raises an error
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def failing_request(request, recipient=None, timeout=600):
|
||||
if recipient:
|
||||
# Send error response with proper Error schema
|
||||
error_response = TextCompletionResponse(
|
||||
response="",
|
||||
error=Error(message="Text completion error", type="processing_error"),
|
||||
end_of_stream=True
|
||||
)
|
||||
await recipient(error_response)
|
||||
async def failing_stream(system, prompt, handler, timeout=600):
|
||||
raise RuntimeError("Text completion error")
|
||||
|
||||
text_completion_client.request = failing_request
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=failing_stream
|
||||
)
|
||||
|
||||
# Mock response producer to capture error response
|
||||
response_producer = AsyncMock()
|
||||
|
|
@ -255,22 +261,15 @@ class TestPromptStreaming:
|
|||
|
||||
consumer = MagicMock()
|
||||
|
||||
# Act - The service catches errors and sends error responses, doesn't raise
|
||||
# Act - The service catches errors and sends an error PromptResponse
|
||||
await prompt_processor_streaming.on_request(message, consumer, context)
|
||||
|
||||
# Assert - Verify error response was sent
|
||||
assert response_producer.send.call_count > 0
|
||||
|
||||
# Check that at least one response contains an error
|
||||
error_sent = False
|
||||
for call in response_producer.send.call_args_list:
|
||||
response = call.args[0]
|
||||
if hasattr(response, 'error') and response.error:
|
||||
error_sent = True
|
||||
assert "Text completion error" in response.error.message
|
||||
break
|
||||
|
||||
assert error_sent, "Expected error response to be sent"
|
||||
# Assert - error response was sent
|
||||
calls = response_producer.send.call_args_list
|
||||
assert len(calls) > 0
|
||||
error_response = calls[-1].args[0]
|
||||
assert error_response.error is not None
|
||||
assert "Text completion error" in error_response.error.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
|
||||
|
|
@ -315,21 +314,22 @@ class TestPromptStreaming:
|
|||
# Mock text completion that sends empty chunks
|
||||
text_completion_client = AsyncMock()
|
||||
|
||||
async def empty_streaming_request(request, recipient=None, timeout=600):
|
||||
if request.streaming and recipient:
|
||||
# Send empty chunk followed by final marker
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=False
|
||||
))
|
||||
await recipient(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
async def empty_streaming(system, prompt, handler, timeout=600):
|
||||
# Send empty chunk followed by final marker
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=False
|
||||
))
|
||||
await handler(TextCompletionResponse(
|
||||
response="",
|
||||
error=None,
|
||||
end_of_stream=True
|
||||
))
|
||||
|
||||
text_completion_client.request = empty_streaming_request
|
||||
text_completion_client.text_completion_stream = AsyncMock(
|
||||
side_effect=empty_streaming
|
||||
)
|
||||
response_producer = AsyncMock()
|
||||
|
||||
def context_router(service_name):
|
||||
|
|
@ -401,4 +401,4 @@ class TestPromptStreaming:
|
|||
|
||||
# Verify chunks concatenate to expected result
|
||||
full_text = "".join(chunk_texts)
|
||||
assert full_text == "Machine learning is a field of artificial intelligence"
|
||||
assert full_text == "Machine learning is a field of artificial intelligence."
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
|
|||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
class TestGraphRagStreamingProtocol:
|
||||
|
|
@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
|
|||
|
||||
async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
# Edge selection returns empty (no edges selected)
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
# Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
|
||||
|
|
@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
|
|||
await chunk_callback(" answer", False)
|
||||
await chunk_callback(" is here.", False)
|
||||
await chunk_callback("", True) # Empty final chunk with end_of_stream=True
|
||||
return "" # Return value not used since callback handles everything
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "The answer is here."
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="The answer is here.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_side_effect
|
||||
return client
|
||||
|
|
@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol:
|
|||
await chunk_callback("Document", False)
|
||||
await chunk_callback(" summary", False)
|
||||
await chunk_callback(".", True) # Non-empty final chunk
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "Document summary."
|
||||
return PromptResult(response_type="text", text="Document summary.")
|
||||
|
||||
client.document_prompt.side_effect = document_prompt_side_effect
|
||||
# Mock prompt() for extract-concepts call in DocumentRag
|
||||
client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases:
|
|||
|
||||
async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "kg-edge-selection":
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-synthesis":
|
||||
if streaming and chunk_callback:
|
||||
await chunk_callback("text", False)
|
||||
await chunk_callback("", False) # Empty but not final
|
||||
await chunk_callback("more", False)
|
||||
await chunk_callback("", True) # Empty and final
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="")
|
||||
else:
|
||||
return "textmore"
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="textmore")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
client.prompt.side_effect = prompt_with_empties
|
||||
|
||||
|
|
|
|||
|
|
@ -4,14 +4,11 @@ python_paths = .
|
|||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts =
|
||||
addopts =
|
||||
-v
|
||||
--tb=short
|
||||
--strict-markers
|
||||
--disable-warnings
|
||||
--cov=trustgraph
|
||||
--cov-report=html
|
||||
--cov-report=term-missing
|
||||
# --cov-fail-under=80
|
||||
asyncio_mode = auto
|
||||
markers =
|
||||
|
|
|
|||
|
|
@ -78,10 +78,10 @@ class TestAgentServiceNonStreaming:
|
|||
|
||||
# Filter out explain events — those are always sent now
|
||||
content_responses = [
|
||||
r for r in sent_responses if r.chunk_type != "explain"
|
||||
r for r in sent_responses if r.message_type != "explain"
|
||||
]
|
||||
explain_responses = [
|
||||
r for r in sent_responses if r.chunk_type == "explain"
|
||||
r for r in sent_responses if r.message_type == "explain"
|
||||
]
|
||||
|
||||
# Should have explain events for session, iteration, observation, and final
|
||||
|
|
@ -93,7 +93,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Check thought message
|
||||
thought_response = content_responses[0]
|
||||
assert isinstance(thought_response, AgentResponse)
|
||||
assert thought_response.chunk_type == "thought"
|
||||
assert thought_response.message_type == "thought"
|
||||
assert thought_response.content == "I need to solve this."
|
||||
assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
|
||||
assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
|
||||
|
|
@ -101,7 +101,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Check observation message
|
||||
observation_response = content_responses[1]
|
||||
assert isinstance(observation_response, AgentResponse)
|
||||
assert observation_response.chunk_type == "observation"
|
||||
assert observation_response.message_type == "observation"
|
||||
assert observation_response.content == "The answer is 4."
|
||||
assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
|
||||
assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
|
||||
|
|
@ -168,10 +168,10 @@ class TestAgentServiceNonStreaming:
|
|||
|
||||
# Filter out explain events — those are always sent now
|
||||
content_responses = [
|
||||
r for r in sent_responses if r.chunk_type != "explain"
|
||||
r for r in sent_responses if r.message_type != "explain"
|
||||
]
|
||||
explain_responses = [
|
||||
r for r in sent_responses if r.chunk_type == "explain"
|
||||
r for r in sent_responses if r.message_type == "explain"
|
||||
]
|
||||
|
||||
# Should have explain events for session and final
|
||||
|
|
@ -183,7 +183,7 @@ class TestAgentServiceNonStreaming:
|
|||
# Check final answer message
|
||||
answer_response = content_responses[0]
|
||||
assert isinstance(answer_response, AgentResponse)
|
||||
assert answer_response.chunk_type == "answer"
|
||||
assert answer_response.message_type == "answer"
|
||||
assert answer_response.content == "4"
|
||||
assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
|
||||
assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class TestThinkCallbackMessageId:
|
|||
|
||||
assert len(responses) == 1
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "thought"
|
||||
assert responses[0].message_type == "thought"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_think_has_message_id(self, pattern):
|
||||
|
|
@ -58,7 +58,7 @@ class TestObserveCallbackMessageId:
|
|||
await observe("result", is_final=True)
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "observation"
|
||||
assert responses[0].message_type == "observation"
|
||||
|
||||
|
||||
class TestAnswerCallbackMessageId:
|
||||
|
|
@ -74,7 +74,7 @@ class TestAnswerCallbackMessageId:
|
|||
await answer("the answer")
|
||||
|
||||
assert responses[0].message_id == msg_id
|
||||
assert responses[0].chunk_type == "answer"
|
||||
assert responses[0].message_type == "answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_id_default(self, pattern):
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.agent.orchestrator.meta_router import (
|
||||
MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
def _make_config(patterns=None, task_types=None):
|
||||
|
|
@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
|
|||
def _make_context(prompt_response):
|
||||
"""Build a mock context that returns a mock prompt client."""
|
||||
client = AsyncMock()
|
||||
client.prompt = AsyncMock(return_value=prompt_response)
|
||||
client.prompt = AsyncMock(
|
||||
return_value=PromptResult(response_type="text", text=prompt_response)
|
||||
)
|
||||
|
||||
def context(service_name):
|
||||
return client
|
||||
|
|
@ -274,8 +277,8 @@ class TestRoute:
|
|||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return "research" # task type
|
||||
return "plan-then-execute" # pattern
|
||||
return PromptResult(response_type="text", text="research")
|
||||
return PromptResult(response_type="text", text="plan-then-execute")
|
||||
|
||||
client.prompt = mock_prompt
|
||||
context = lambda name: client
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from dataclasses import dataclass, field
|
|||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -68,7 +69,7 @@ def collect_explain_events(respond_mock):
|
|||
events = []
|
||||
for call in respond_mock.call_args_list:
|
||||
resp = call[0][0]
|
||||
if isinstance(resp, AgentResponse) and resp.chunk_type == "explain":
|
||||
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
|
||||
events.append({
|
||||
"explain_id": resp.explain_id,
|
||||
"explain_graph": resp.explain_graph,
|
||||
|
|
@ -183,7 +184,7 @@ class TestReactPatternProvenance:
|
|||
)
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
# Simulate the on_action callback before returning Final
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
|
|
@ -267,7 +268,7 @@ class TestReactPatternProvenance:
|
|||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
if on_action:
|
||||
await on_action(action)
|
||||
return action
|
||||
|
|
@ -309,7 +310,7 @@ class TestReactPatternProvenance:
|
|||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react(question, history, think, observe, answer,
|
||||
context, streaming, on_action):
|
||||
context, streaming, on_action, **kwargs):
|
||||
if on_action:
|
||||
await on_action(Action(
|
||||
thought="done", name="final",
|
||||
|
|
@ -355,10 +356,13 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt client for plan creation
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
|
||||
{"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
|
||||
],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -418,10 +422,13 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt for step execution
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = {
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
}
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="json",
|
||||
object={
|
||||
"tool": "knowledge-query",
|
||||
"arguments": {"question": "quantum computing"},
|
||||
},
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -475,7 +482,7 @@ class TestPlanPatternProvenance:
|
|||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The synthesised answer."
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance:
|
|||
|
||||
# Mock prompt for decomposition
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = [
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
"What is quantum computing?",
|
||||
"What are qubits?",
|
||||
],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance:
|
|||
|
||||
# Mock prompt for synthesis
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = "The combined answer."
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance:
|
|||
flow = make_mock_flow()
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=["Goal A", "Goal B", "Goal C"],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_thought_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/thought",
|
||||
|
|
@ -31,7 +31,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_observation_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "observation",
|
||||
"message_type": "observation",
|
||||
"content": "result",
|
||||
"end_of_message": True,
|
||||
"message_id": "urn:trustgraph:agent:sess/i1/observation",
|
||||
|
|
@ -42,7 +42,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_answer_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"message_type": "answer",
|
||||
"content": "the answer",
|
||||
"end_of_message": False,
|
||||
"end_of_dialog": False,
|
||||
|
|
@ -54,7 +54,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_thought_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "thought",
|
||||
"message_type": "thought",
|
||||
"content": "thinking...",
|
||||
"end_of_message": False,
|
||||
}
|
||||
|
|
@ -64,7 +64,7 @@ class TestParseChunkMessageId:
|
|||
|
||||
def test_answer_missing_message_id(self, client):
|
||||
resp = {
|
||||
"chunk_type": "answer",
|
||||
"message_type": "answer",
|
||||
"content": "answer",
|
||||
"end_of_message": True,
|
||||
"end_of_dialog": True,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,14 @@ from trustgraph.base import metrics
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_metric_singletons():
|
||||
"""Temporarily remove metric singletons so each test can inject mocks.
|
||||
|
||||
Saves any existing class-level metrics and restores them after the test
|
||||
so that later tests in the same process still find the hasattr() guard
|
||||
intact — deleting without restoring causes every subsequent Processor()
|
||||
construction to re-register the same Prometheus metric name, which raises
|
||||
ValueError: Duplicated timeseries.
|
||||
"""
|
||||
classes_and_attrs = {
|
||||
metrics.ConsumerMetrics: [
|
||||
"state_metric",
|
||||
|
|
@ -23,18 +31,24 @@ def reset_metric_singletons():
|
|||
],
|
||||
}
|
||||
|
||||
saved = {}
|
||||
for cls, attrs in classes_and_attrs.items():
|
||||
for attr in attrs:
|
||||
if hasattr(cls, attr):
|
||||
saved[(cls, attr)] = getattr(cls, attr)
|
||||
delattr(cls, attr)
|
||||
|
||||
yield
|
||||
|
||||
# Remove anything the test may have set, then restore originals
|
||||
for cls, attrs in classes_and_attrs.items():
|
||||
for attr in attrs:
|
||||
if hasattr(cls, attr):
|
||||
delattr(cls, attr)
|
||||
|
||||
for (cls, attr), value in saved.items():
|
||||
setattr(cls, attr, value)
|
||||
|
||||
|
||||
def test_consumer_metrics_reuses_singletons_and_records_events(monkeypatch):
|
||||
enum_factory = MagicMock()
|
||||
|
|
|
|||
|
|
@ -236,6 +236,10 @@ async def test_subscriber_graceful_shutdown():
|
|||
with patch.object(subscriber, 'run') as mock_run:
|
||||
# Mock run that simulates graceful shutdown
|
||||
async def mock_run_graceful():
|
||||
# Honor the readiness contract: real run() signals _ready
|
||||
# after binding the consumer, so start() can unblock. Mocks
|
||||
# of run() must do the same or start() hangs forever.
|
||||
subscriber._ready.set_result(None)
|
||||
# Process messages while running, then drain
|
||||
while subscriber.running or subscriber.draining:
|
||||
if subscriber.draining:
|
||||
|
|
@ -337,6 +341,8 @@ async def test_subscriber_pending_acks_cleanup():
|
|||
with patch.object(subscriber, 'run') as mock_run:
|
||||
# Mock run that simulates cleanup of pending acks
|
||||
async def mock_run_cleanup():
|
||||
# Honor the readiness contract — see test_subscriber_graceful_shutdown.
|
||||
subscriber._ready.set_result(None)
|
||||
while subscriber.running or subscriber.draining:
|
||||
await asyncio.sleep(0.05)
|
||||
if subscriber.draining:
|
||||
|
|
@ -406,4 +412,4 @@ async def test_subscriber_multiple_subscribers():
|
|||
msg1 = await queue1.get()
|
||||
msg_all = await queue_all.get()
|
||||
assert msg1 == {"data": "broadcast"}
|
||||
assert msg_all == {"data": "broadcast"}
|
||||
assert msg_all == {"data": "broadcast"}
|
||||
|
|
|
|||
189
tests/unit/test_base/test_subscriber_readiness.py
Normal file
189
tests/unit/test_base/test_subscriber_readiness.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
"""
|
||||
Regression tests for Subscriber.start() readiness barrier.
|
||||
|
||||
Background: prior to the eager-connect fix, Subscriber.start() created
|
||||
the run() task and returned immediately. The underlying backend consumer
|
||||
was lazily connected on its first receive() call, which left a setup
|
||||
race for request/response clients using ephemeral per-subscriber response
|
||||
queues (RabbitMQ auto-delete exclusive queues): the request would be
|
||||
published before the response queue was bound, and the broker would
|
||||
silently drop the reply. fetch_config(), document-embeddings, and
|
||||
api-gateway all hit this with "Failed to fetch config on notify" /
|
||||
"Request timeout exception" symptoms.
|
||||
|
||||
These tests pin the readiness contract:
|
||||
|
||||
await subscriber.start()
|
||||
# at this point, consumer.ensure_connected() MUST have run
|
||||
|
||||
so that any future change which removes the eager bind, or moves it
|
||||
back to lazy initialisation, fails CI loudly.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from trustgraph.base.subscriber import Subscriber
|
||||
|
||||
|
||||
def _make_backend(ensure_connected_side_effect=None,
|
||||
receive_side_effect=None):
|
||||
"""Build a fake backend whose consumer records ensure_connected /
|
||||
receive calls. ensure_connected_side_effect lets a test inject a
|
||||
delay or exception."""
|
||||
backend = MagicMock()
|
||||
consumer = MagicMock()
|
||||
|
||||
consumer.ensure_connected = MagicMock(
|
||||
side_effect=ensure_connected_side_effect,
|
||||
)
|
||||
|
||||
# By default receive raises a timeout-style exception that the
|
||||
# subscriber loop is supposed to swallow as a "no message yet" — this
|
||||
# keeps the subscriber idling cleanly while the test inspects state.
|
||||
if receive_side_effect is None:
|
||||
receive_side_effect = TimeoutError("No message received within timeout")
|
||||
consumer.receive = MagicMock(side_effect=receive_side_effect)
|
||||
|
||||
consumer.acknowledge = MagicMock()
|
||||
consumer.negative_acknowledge = MagicMock()
|
||||
consumer.pause_message_listener = MagicMock()
|
||||
consumer.unsubscribe = MagicMock()
|
||||
consumer.close = MagicMock()
|
||||
|
||||
backend.create_consumer.return_value = consumer
|
||||
return backend, consumer
|
||||
|
||||
|
||||
def _make_subscriber(backend):
|
||||
return Subscriber(
|
||||
backend=backend,
|
||||
topic="response:tg:config",
|
||||
subscription="test-sub",
|
||||
consumer_name="test-consumer",
|
||||
schema=dict,
|
||||
max_size=10,
|
||||
drain_timeout=1.0,
|
||||
backpressure_strategy="block",
|
||||
)
|
||||
|
||||
|
||||
class TestSubscriberReadiness:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_ensure_connected_before_returning(self):
|
||||
"""The barrier: ensure_connected must have been invoked at least
|
||||
once by the time start() returns."""
|
||||
backend, consumer = _make_backend()
|
||||
subscriber = _make_subscriber(backend)
|
||||
|
||||
await subscriber.start()
|
||||
|
||||
try:
|
||||
consumer.ensure_connected.assert_called_once()
|
||||
finally:
|
||||
await subscriber.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_blocks_until_ensure_connected_completes(self):
|
||||
"""If ensure_connected is slow, start() must wait for it. This is
|
||||
the actual race-condition guard — it would have failed against
|
||||
the buggy version where start() returned before run() had even
|
||||
scheduled the consumer creation."""
|
||||
connect_started = asyncio.Event()
|
||||
release_connect = asyncio.Event()
|
||||
|
||||
# ensure_connected runs in the executor thread, so we need a
|
||||
# threading-safe gate. Use a simple busy-wait on a flag set by
|
||||
# the asyncio side via call_soon_threadsafe — but the simpler
|
||||
# path is to give it a sleep and observe ordering.
|
||||
import threading
|
||||
gate = threading.Event()
|
||||
|
||||
def slow_connect():
|
||||
connect_started.set() # safe: only mutates the Event flag
|
||||
gate.wait(timeout=2.0)
|
||||
|
||||
backend, consumer = _make_backend(
|
||||
ensure_connected_side_effect=slow_connect,
|
||||
)
|
||||
subscriber = _make_subscriber(backend)
|
||||
|
||||
start_task = asyncio.create_task(subscriber.start())
|
||||
|
||||
# Wait until ensure_connected has begun executing.
|
||||
await asyncio.wait_for(connect_started.wait(), timeout=2.0)
|
||||
|
||||
# ensure_connected is in flight — start() must NOT have returned.
|
||||
assert not start_task.done(), (
|
||||
"start() returned before ensure_connected() completed — "
|
||||
"the readiness barrier is broken and the request/response "
|
||||
"race condition is back."
|
||||
)
|
||||
|
||||
# Release the gate; start() should now complete promptly.
|
||||
gate.set()
|
||||
await asyncio.wait_for(start_task, timeout=2.0)
|
||||
|
||||
consumer.ensure_connected.assert_called_once()
|
||||
|
||||
await subscriber.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_propagates_consumer_creation_failure(self):
|
||||
"""If create_consumer() raises, start() must surface the error
|
||||
rather than hang on the readiness future. The old code path
|
||||
retried indefinitely inside run() and never let start() unblock."""
|
||||
backend = MagicMock()
|
||||
backend.create_consumer.side_effect = RuntimeError("broker down")
|
||||
|
||||
subscriber = _make_subscriber(backend)
|
||||
|
||||
with pytest.raises(RuntimeError, match="broker down"):
|
||||
await asyncio.wait_for(subscriber.start(), timeout=2.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_propagates_ensure_connected_failure(self):
|
||||
"""Same contract for an ensure_connected() that raises (e.g. the
|
||||
broker is up but the queue declare/bind fails)."""
|
||||
backend, consumer = _make_backend(
|
||||
ensure_connected_side_effect=RuntimeError("queue declare failed"),
|
||||
)
|
||||
subscriber = _make_subscriber(backend)
|
||||
|
||||
with pytest.raises(RuntimeError, match="queue declare failed"):
|
||||
await asyncio.wait_for(subscriber.start(), timeout=2.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_connected_runs_before_subscriber_running_log(self):
|
||||
"""Subtle ordering: ensure_connected MUST happen before the
|
||||
receive loop, so that any reply is captured. We assert this by
|
||||
checking ensure_connected was called before any receive call."""
|
||||
call_order = []
|
||||
|
||||
def record_ensure():
|
||||
call_order.append("ensure_connected")
|
||||
|
||||
def record_receive(*args, **kwargs):
|
||||
call_order.append("receive")
|
||||
raise TimeoutError("No message received within timeout")
|
||||
|
||||
backend, consumer = _make_backend(
|
||||
ensure_connected_side_effect=record_ensure,
|
||||
receive_side_effect=record_receive,
|
||||
)
|
||||
subscriber = _make_subscriber(backend)
|
||||
|
||||
await subscriber.start()
|
||||
|
||||
# Give the receive loop a tick to run at least once.
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
await subscriber.stop()
|
||||
|
||||
# ensure_connected must come first; receive may not have happened
|
||||
# yet on a fast machine, but if it did, it must come after.
|
||||
assert call_order, "neither ensure_connected nor receive was called"
|
||||
assert call_order[0] == "ensure_connected"
|
||||
|
|
@ -70,11 +70,12 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
# Flow exposes parameter lookup via __call__: flow("chunk-size")
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 2000, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -105,10 +106,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 200 # Override chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -139,10 +140,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 1500, # Override chunk size
|
||||
"chunk-overlap": 150 # Override chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -195,15 +196,15 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
# Flow.__call__ resolves parameters and producers/consumers from the
|
||||
# same dict — merge both kinds here.
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 1500,
|
||||
"chunk-overlap": 150,
|
||||
}.get(param)
|
||||
mock_flow.side_effect = lambda name: {
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(name)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
|
@ -241,7 +242,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.return_value = None # No overrides
|
||||
mock_flow.side_effect = lambda key: None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
|
|||
|
|
@ -70,11 +70,12 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
# Mock message and flow
|
||||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
# Flow exposes parameter lookup via __call__: flow("chunk-size")
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 400, # Override chunk size
|
||||
"chunk-overlap": None # Use default chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -105,10 +106,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": None, # Use default chunk size
|
||||
"chunk-overlap": 25 # Override chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -139,10 +140,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 350, # Override chunk size
|
||||
"chunk-overlap": 30 # Override chunk overlap
|
||||
}.get(param)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
@ -195,15 +196,15 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
# Flow.__call__ resolves parameters and producers/consumers from the
|
||||
# same dict — merge both kinds here.
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.side_effect = lambda param: {
|
||||
mock_flow.side_effect = lambda key: {
|
||||
"chunk-size": 400,
|
||||
"chunk-overlap": 40,
|
||||
}.get(param)
|
||||
mock_flow.side_effect = lambda name: {
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(name)
|
||||
}.get(key)
|
||||
|
||||
# Act
|
||||
await processor.on_message(mock_message, mock_consumer, mock_flow)
|
||||
|
|
@ -245,7 +246,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
mock_message = MagicMock()
|
||||
mock_consumer = MagicMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.parameters.get.return_value = None # No overrides
|
||||
mock_flow.side_effect = lambda key: None # No overrides
|
||||
|
||||
# Act
|
||||
chunk_size, chunk_overlap = await processor.chunk_document(
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||
from trustgraph.extract.kg.definitions.extract import (
|
||||
Processor, default_triples_batch_size, default_entity_batch_size,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
from trustgraph.schema import (
|
||||
Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
|
|
@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
|||
mock_triples_pub = AsyncMock()
|
||||
mock_ecs_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
if isinstance(prompt_result, list):
|
||||
wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
|
||||
else:
|
||||
wrapped = PromptResult(response_type="text", text=prompt_result)
|
||||
mock_prompt_client.extract_definitions = AsyncMock(
|
||||
return_value=prompt_result
|
||||
return_value=wrapped
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
|
|||
from trustgraph.schema import (
|
||||
Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
|
|||
mock_triples_pub = AsyncMock()
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.extract_relationships = AsyncMock(
|
||||
return_value=prompt_result
|
||||
return_value=PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=prompt_result,
|
||||
)
|
||||
)
|
||||
|
||||
def flow(name):
|
||||
|
|
|
|||
418
tests/unit/test_gateway/test_core_import_export_roundtrip.py
Normal file
418
tests/unit/test_gateway/test_core_import_export_roundtrip.py
Normal file
|
|
@ -0,0 +1,418 @@
|
|||
"""
|
||||
Round-trip unit tests for the core msgpack import/export gateway endpoints.
|
||||
|
||||
The kg-core export endpoint receives KnowledgeResponse-shaped dicts from
|
||||
the responder callback and packs them into msgpack tuples. The kg-core
|
||||
import endpoint takes msgpack tuples back off the wire and rebuilds
|
||||
KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor
|
||||
(whose translator decodes them into real dataclasses).
|
||||
|
||||
Regression coverage: the previous wire format used `"vectors"` (plural)
|
||||
in the entity blobs and embedded a stale `"m"` field that referenced the
|
||||
removed `Metadata.metadata` triples-list field. The export side hit a
|
||||
KeyError on first message; the import side built dicts that the
|
||||
KnowledgeRequestTranslator (separately fixed) couldn't decode. These
|
||||
tests pin both halves of the wire protocol.
|
||||
"""
|
||||
|
||||
import msgpack
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.core_export import CoreExport
|
||||
from trustgraph.gateway.dispatch.core_import import CoreImport
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator
|
||||
# would emit). The vector wire key is *singular* on purpose; the export
|
||||
# side previously read the wrong key and crashed.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ge_response_dict():
|
||||
return {
|
||||
"graph-embeddings": {
|
||||
"metadata": {
|
||||
"id": "doc-1",
|
||||
"root": "",
|
||||
"user": "alice",
|
||||
"collection": "testcoll",
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": {"t": "i", "i": "http://example.org/alice"},
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
},
|
||||
{
|
||||
"entity": {"t": "i", "i": "http://example.org/bob"},
|
||||
"vector": [0.4, 0.5, 0.6],
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _triples_response_dict():
|
||||
return {
|
||||
"triples": {
|
||||
"metadata": {
|
||||
"id": "doc-1",
|
||||
"root": "",
|
||||
"user": "alice",
|
||||
"collection": "testcoll",
|
||||
},
|
||||
"triples": [
|
||||
{
|
||||
"s": {"t": "i", "i": "http://example.org/alice"},
|
||||
"p": {"t": "i", "i": "http://example.org/knows"},
|
||||
"o": {"t": "i", "i": "http://example.org/bob"},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _make_request(id_="doc-1", user="alice"):
|
||||
request = Mock()
|
||||
request.query = {"id": id_, "user": user}
|
||||
return request
|
||||
|
||||
|
||||
def _make_data_reader(payload: bytes):
|
||||
"""Mock the aiohttp StreamReader: returns payload once, then EOF."""
|
||||
chunks = [payload, b""]
|
||||
|
||||
data = Mock()
|
||||
|
||||
async def fake_read(n):
|
||||
return chunks.pop(0) if chunks else b""
|
||||
|
||||
data.read = fake_read
|
||||
return data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export side: translator-shaped dict -> msgpack bytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreExportWireFormat:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
|
||||
async def test_export_packs_graph_embeddings_with_singular_vector(
|
||||
self, mock_kr_class,
|
||||
):
|
||||
"""The export side must read `ent["vector"]` and emit `v`. The
|
||||
previous bug was reading `ent["vectors"]` which KeyErrored against
|
||||
the translator output."""
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict, responder):
|
||||
await responder(_ge_response_dict(), True)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
response = AsyncMock()
|
||||
|
||||
async def fake_write(b):
|
||||
captured.append(b)
|
||||
|
||||
response.write = fake_write
|
||||
response.write_eof = AsyncMock()
|
||||
|
||||
ok = AsyncMock(return_value=response)
|
||||
error = AsyncMock()
|
||||
|
||||
exporter = CoreExport(backend=Mock())
|
||||
await exporter.process(
|
||||
data=Mock(),
|
||||
error=error,
|
||||
ok=ok,
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
# Did not raise, did not call error()
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
unpacker = msgpack.Unpacker()
|
||||
unpacker.feed(captured[0])
|
||||
items = list(unpacker)
|
||||
|
||||
assert len(items) == 1
|
||||
msg_type, payload = items[0]
|
||||
assert msg_type == "ge"
|
||||
|
||||
# Metadata envelope: only id/user/collection — no stale `m["m"]`.
|
||||
assert payload["m"] == {
|
||||
"i": "doc-1",
|
||||
"u": "alice",
|
||||
"c": "testcoll",
|
||||
}
|
||||
|
||||
# Entities: each carries the *singular* `v` and the term envelope
|
||||
assert len(payload["e"]) == 2
|
||||
assert payload["e"][0]["v"] == [0.1, 0.2, 0.3]
|
||||
assert payload["e"][1]["v"] == [0.4, 0.5, 0.6]
|
||||
assert payload["e"][0]["e"]["i"] == "http://example.org/alice"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
|
||||
async def test_export_packs_triples(self, mock_kr_class):
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict, responder):
|
||||
await responder(_triples_response_dict(), True)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
response = AsyncMock()
|
||||
|
||||
async def fake_write(b):
|
||||
captured.append(b)
|
||||
|
||||
response.write = fake_write
|
||||
response.write_eof = AsyncMock()
|
||||
|
||||
ok = AsyncMock(return_value=response)
|
||||
error = AsyncMock()
|
||||
|
||||
exporter = CoreExport(backend=Mock())
|
||||
await exporter.process(
|
||||
data=Mock(), error=error, ok=ok, request=_make_request(),
|
||||
)
|
||||
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
unpacker = msgpack.Unpacker()
|
||||
unpacker.feed(captured[0])
|
||||
items = list(unpacker)
|
||||
assert len(items) == 1
|
||||
|
||||
msg_type, payload = items[0]
|
||||
assert msg_type == "t"
|
||||
assert payload["m"] == {
|
||||
"i": "doc-1",
|
||||
"u": "alice",
|
||||
"c": "testcoll",
|
||||
}
|
||||
assert len(payload["t"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import side: msgpack bytes -> translator-shaped dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreImportWireFormat:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
|
||||
async def test_import_unpacks_graph_embeddings_to_singular_vector(
|
||||
self, mock_kr_class,
|
||||
):
|
||||
"""The import side must build dicts whose entity blobs have the
|
||||
singular `vector` key — that's what the KnowledgeRequestTranslator
|
||||
decode side reads. Previous bug emitted `vectors`."""
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict):
|
||||
captured.append(req_dict)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
# Build a msgpack tuple matching the new wire format
|
||||
payload = msgpack.packb((
|
||||
"ge",
|
||||
{
|
||||
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
|
||||
"e": [
|
||||
{
|
||||
"e": {"t": "i", "i": "http://example.org/alice"},
|
||||
"v": [0.1, 0.2, 0.3],
|
||||
},
|
||||
],
|
||||
},
|
||||
))
|
||||
|
||||
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
|
||||
error = AsyncMock()
|
||||
|
||||
importer = CoreImport(backend=Mock())
|
||||
await importer.process(
|
||||
data=_make_data_reader(payload),
|
||||
error=error,
|
||||
ok=ok,
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
req = captured[0]
|
||||
assert req["operation"] == "put-kg-core"
|
||||
assert req["user"] == "alice"
|
||||
assert req["id"] == "doc-1"
|
||||
|
||||
ge = req["graph-embeddings"]
|
||||
# Metadata envelope must NOT contain a stale `metadata` key
|
||||
# referencing the removed Metadata.metadata field.
|
||||
assert "metadata" not in ge["metadata"]
|
||||
assert ge["metadata"] == {
|
||||
"id": "doc-1",
|
||||
"user": "alice",
|
||||
"collection": "default",
|
||||
}
|
||||
|
||||
# Entity blob carries the singular `vector` key
|
||||
assert len(ge["entities"]) == 1
|
||||
ent = ge["entities"][0]
|
||||
assert ent["vector"] == [0.1, 0.2, 0.3]
|
||||
assert "vectors" not in ent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
|
||||
async def test_import_unpacks_triples(self, mock_kr_class):
|
||||
captured = []
|
||||
|
||||
async def fake_kr_process(req_dict):
|
||||
captured.append(req_dict)
|
||||
|
||||
mock_kr = AsyncMock()
|
||||
mock_kr.start = AsyncMock()
|
||||
mock_kr.stop = AsyncMock()
|
||||
mock_kr.process = fake_kr_process
|
||||
mock_kr_class.return_value = mock_kr
|
||||
|
||||
payload = msgpack.packb((
|
||||
"t",
|
||||
{
|
||||
"m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
|
||||
"t": [
|
||||
{
|
||||
"s": {"t": "i", "i": "http://example.org/alice"},
|
||||
"p": {"t": "i", "i": "http://example.org/knows"},
|
||||
"o": {"t": "i", "i": "http://example.org/bob"},
|
||||
},
|
||||
],
|
||||
},
|
||||
))
|
||||
|
||||
ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
|
||||
error = AsyncMock()
|
||||
|
||||
importer = CoreImport(backend=Mock())
|
||||
await importer.process(
|
||||
data=_make_data_reader(payload),
|
||||
error=error,
|
||||
ok=ok,
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
error.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
req = captured[0]
|
||||
triples = req["triples"]
|
||||
assert "metadata" not in triples["metadata"] # no stale field
|
||||
assert len(triples["triples"]) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full round-trip: export bytes feed directly into import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoreImportExportRoundTrip:
|
||||
"""End-to-end: produce bytes via core_export, consume them via
|
||||
core_import, and verify the dict that lands at the import-side
|
||||
translator is structurally equivalent to what went in. This is the
|
||||
test that catches asymmetries between the two halves."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
|
||||
@patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
|
||||
async def test_graph_embeddings_round_trip(
|
||||
self, mock_export_kr_class, mock_import_kr_class,
|
||||
):
|
||||
# ----- export side: capture bytes -----
|
||||
export_bytes = []
|
||||
|
||||
async def fake_export_process(req_dict, responder):
|
||||
await responder(_ge_response_dict(), True)
|
||||
|
||||
export_kr = AsyncMock()
|
||||
export_kr.start = AsyncMock()
|
||||
export_kr.stop = AsyncMock()
|
||||
export_kr.process = fake_export_process
|
||||
mock_export_kr_class.return_value = export_kr
|
||||
|
||||
response = AsyncMock()
|
||||
|
||||
async def fake_write(b):
|
||||
export_bytes.append(b)
|
||||
|
||||
response.write = fake_write
|
||||
response.write_eof = AsyncMock()
|
||||
|
||||
exporter = CoreExport(backend=Mock())
|
||||
await exporter.process(
|
||||
data=Mock(),
|
||||
error=AsyncMock(),
|
||||
ok=AsyncMock(return_value=response),
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
assert len(export_bytes) == 1
|
||||
|
||||
# ----- import side: feed those bytes back in -----
|
||||
import_captured = []
|
||||
|
||||
async def fake_import_process(req_dict):
|
||||
import_captured.append(req_dict)
|
||||
|
||||
import_kr = AsyncMock()
|
||||
import_kr.start = AsyncMock()
|
||||
import_kr.stop = AsyncMock()
|
||||
import_kr.process = fake_import_process
|
||||
mock_import_kr_class.return_value = import_kr
|
||||
|
||||
importer = CoreImport(backend=Mock())
|
||||
await importer.process(
|
||||
data=_make_data_reader(export_bytes[0]),
|
||||
error=AsyncMock(),
|
||||
ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())),
|
||||
request=_make_request(),
|
||||
)
|
||||
|
||||
# ----- verify the dict the importer would hand to the translator -----
|
||||
assert len(import_captured) == 1
|
||||
req = import_captured[0]
|
||||
|
||||
original = _ge_response_dict()["graph-embeddings"]
|
||||
|
||||
ge = req["graph-embeddings"]
|
||||
# The import side overrides id/user from the URL query (intentional),
|
||||
# so we only round-trip the entity payload itself.
|
||||
assert ge["metadata"]["id"] == original["metadata"]["id"]
|
||||
assert ge["metadata"]["user"] == original["metadata"]["user"]
|
||||
|
||||
assert len(ge["entities"]) == len(original["entities"])
|
||||
for got, want in zip(ge["entities"], original["entities"]):
|
||||
assert got["vector"] == want["vector"]
|
||||
assert got["entity"] == want["entity"]
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
"""
|
||||
Unit tests for entity contexts import dispatcher.
|
||||
|
||||
Tests the business logic of EntityContextsImport while mocking the
|
||||
Publisher and websocket components.
|
||||
|
||||
Regression coverage: a previous version constructed Metadata(metadata=...)
|
||||
which raised TypeError at runtime as soon as a message was received. These
|
||||
tests exercise receive() end-to-end so any future schema/kwarg drift in
|
||||
the Metadata or EntityContexts construction is caught immediately.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport
|
||||
from trustgraph.schema import EntityContexts, EntityContext, Metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backend():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_running():
|
||||
running = Mock()
|
||||
running.get.return_value = True
|
||||
running.stop = Mock()
|
||||
return running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
ws = Mock()
|
||||
ws.close = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message():
|
||||
"""Sample entity-contexts websocket message."""
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-123",
|
||||
"user": "testuser",
|
||||
"collection": "testcollection",
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": {"v": "http://example.org/alice", "e": True},
|
||||
"context": "Alice is a person.",
|
||||
},
|
||||
{
|
||||
"entity": {"v": "http://example.org/bob", "e": True},
|
||||
"context": "Bob is a person.",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_entities_message():
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-empty",
|
||||
"user": "u",
|
||||
"collection": "c",
|
||||
},
|
||||
"entities": [],
|
||||
}
|
||||
|
||||
|
||||
class TestEntityContextsImportInitialization:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
def test_init_creates_publisher_with_correct_params(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="ec-queue",
|
||||
)
|
||||
|
||||
mock_publisher_class.assert_called_once_with(
|
||||
mock_backend,
|
||||
topic="ec-queue",
|
||||
schema=EntityContexts,
|
||||
)
|
||||
assert dispatcher.ws is mock_websocket
|
||||
assert dispatcher.running is mock_running
|
||||
assert dispatcher.publisher is instance
|
||||
|
||||
|
||||
class TestEntityContextsImportLifecycle:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_publisher_start(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.start = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.start()
|
||||
instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_stops_and_closes_properly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_handles_none_websocket(
|
||||
self, mock_publisher_class, mock_backend, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=None, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestEntityContextsImportMessageProcessing:
|
||||
"""Regression coverage for receive(): catches Metadata/schema drift."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_constructs_entity_contexts_correctly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
# If Metadata or EntityContexts gain/lose kwargs, this raises
|
||||
# TypeError — exactly the regression we want to catch.
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
call_args = instance.send.call_args
|
||||
assert call_args[0][0] is None
|
||||
|
||||
sent = call_args[0][1]
|
||||
assert isinstance(sent, EntityContexts)
|
||||
assert isinstance(sent.metadata, Metadata)
|
||||
assert sent.metadata.id == "doc-123"
|
||||
assert sent.metadata.user == "testuser"
|
||||
assert sent.metadata.collection == "testcollection"
|
||||
|
||||
assert len(sent.entities) == 2
|
||||
assert all(isinstance(e, EntityContext) for e in sent.entities)
|
||||
assert sent.entities[0].context == "Alice is a person."
|
||||
assert sent.entities[1].context == "Bob is a person."
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_entities(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, empty_entities_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = empty_entities_message
|
||||
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
sent = instance.send.call_args[0][1]
|
||||
assert isinstance(sent, EntityContexts)
|
||||
assert sent.entities == []
|
||||
assert sent.metadata.id == "doc-empty"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = EntityContextsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
with pytest.raises(RuntimeError, match="publish failed"):
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
|
@ -158,7 +158,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
message_type="explain",
|
||||
content="",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_graph="urn:graph:retrieval",
|
||||
|
|
@ -179,7 +179,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="thought",
|
||||
message_type="thought",
|
||||
content="I need to think...",
|
||||
)
|
||||
|
||||
|
|
@ -190,7 +190,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="explain",
|
||||
message_type="explain",
|
||||
explain_id="urn:trustgraph:agent:session:abc123",
|
||||
explain_triples=sample_triples(),
|
||||
end_of_dialog=False,
|
||||
|
|
@ -203,7 +203,7 @@ class TestAgentExplainTriples:
|
|||
translator = AgentResponseTranslator()
|
||||
|
||||
response = AgentResponse(
|
||||
chunk_type="answer",
|
||||
message_type="answer",
|
||||
content="The answer is...",
|
||||
end_of_dialog=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,247 @@
|
|||
"""
|
||||
Unit tests for graph embeddings import dispatcher.
|
||||
|
||||
Tests the business logic of GraphEmbeddingsImport while mocking the
|
||||
Publisher and websocket components.
|
||||
|
||||
Regression coverage: a previous version of EntityContextsImport
|
||||
constructed Metadata(metadata=...) which raised TypeError at runtime as
|
||||
soon as a message was received. The same shape of bug can occur here, so
|
||||
these tests exercise receive() end-to-end to catch any future schema or
|
||||
kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport
|
||||
from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backend():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_running():
|
||||
running = Mock()
|
||||
running.get.return_value = True
|
||||
running.stop = Mock()
|
||||
return running
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
ws = Mock()
|
||||
ws.close = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message():
|
||||
"""Sample graph-embeddings websocket message."""
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-123",
|
||||
"user": "testuser",
|
||||
"collection": "testcollection",
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": {"v": "http://example.org/alice", "e": True},
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
},
|
||||
{
|
||||
"entity": {"v": "http://example.org/bob", "e": True},
|
||||
"vector": [0.4, 0.5, 0.6],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_entities_message():
|
||||
return {
|
||||
"metadata": {
|
||||
"id": "doc-empty",
|
||||
"user": "u",
|
||||
"collection": "c",
|
||||
},
|
||||
"entities": [],
|
||||
}
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportInitialization:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
def test_init_creates_publisher_with_correct_params(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket,
|
||||
running=mock_running,
|
||||
backend=mock_backend,
|
||||
queue="ge-queue",
|
||||
)
|
||||
|
||||
mock_publisher_class.assert_called_once_with(
|
||||
mock_backend,
|
||||
topic="ge-queue",
|
||||
schema=GraphEmbeddings,
|
||||
)
|
||||
assert dispatcher.ws is mock_websocket
|
||||
assert dispatcher.running is mock_running
|
||||
assert dispatcher.publisher is instance
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportLifecycle:
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_calls_publisher_start(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.start = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.start()
|
||||
instance.start.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_stops_and_closes_properly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
mock_websocket.close.assert_called_once()
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_destroy_handles_none_websocket(
|
||||
self, mock_publisher_class, mock_backend, mock_running
|
||||
):
|
||||
instance = Mock()
|
||||
instance.stop = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=None, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
await dispatcher.destroy()
|
||||
|
||||
mock_running.stop.assert_called_once()
|
||||
instance.stop.assert_called_once()
|
||||
|
||||
|
||||
class TestGraphEmbeddingsImportMessageProcessing:
|
||||
"""Regression coverage for receive(): catches Metadata/schema drift."""
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_constructs_graph_embeddings_correctly(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
# If Metadata, GraphEmbeddings, or EntityEmbeddings gain/lose
|
||||
# kwargs, this raises TypeError — exactly the regression we want
|
||||
# to catch.
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
call_args = instance.send.call_args
|
||||
assert call_args[0][0] is None
|
||||
|
||||
sent = call_args[0][1]
|
||||
assert isinstance(sent, GraphEmbeddings)
|
||||
assert isinstance(sent.metadata, Metadata)
|
||||
assert sent.metadata.id == "doc-123"
|
||||
assert sent.metadata.user == "testuser"
|
||||
assert sent.metadata.collection == "testcollection"
|
||||
|
||||
assert len(sent.entities) == 2
|
||||
assert all(isinstance(e, EntityEmbeddings) for e in sent.entities)
|
||||
# Lock in the wire format: incoming "vector" key (singular,
|
||||
# list[float]) maps to EntityEmbeddings.vector. This mirrors
|
||||
# serialize_graph_embeddings() on the export side.
|
||||
assert sent.entities[0].vector == [0.1, 0.2, 0.3]
|
||||
assert sent.entities[1].vector == [0.4, 0.5, 0.6]
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_handles_empty_entities(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, empty_entities_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock()
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = empty_entities_message
|
||||
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
||||
instance.send.assert_called_once()
|
||||
sent = instance.send.call_args[0][1]
|
||||
assert isinstance(sent, GraphEmbeddings)
|
||||
assert sent.entities == []
|
||||
assert sent.metadata.id == "doc-empty"
|
||||
|
||||
@patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher')
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_propagates_publisher_errors(
|
||||
self, mock_publisher_class, mock_backend, mock_websocket,
|
||||
mock_running, sample_message,
|
||||
):
|
||||
instance = Mock()
|
||||
instance.send = AsyncMock(side_effect=RuntimeError("publish failed"))
|
||||
mock_publisher_class.return_value = instance
|
||||
|
||||
dispatcher = GraphEmbeddingsImport(
|
||||
ws=mock_websocket, running=mock_running,
|
||||
backend=mock_backend, queue="q",
|
||||
)
|
||||
|
||||
mock_msg = Mock()
|
||||
mock_msg.json.return_value = sample_message
|
||||
|
||||
with pytest.raises(RuntimeError, match="publish failed"):
|
||||
await dispatcher.receive(mock_msg)
|
||||
|
|
@ -171,6 +171,14 @@ class TestApi:
|
|||
patch('aiohttp.web.run_app') as mock_run_app:
|
||||
mock_get_pubsub.return_value = Mock()
|
||||
|
||||
# Api.run() passes self.app_factory() — a coroutine — to
|
||||
# web.run_app, which would normally consume it inside its own
|
||||
# event loop. Since we mock run_app, close the coroutine here
|
||||
# so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
|
||||
def _consume_coro(coro, **kwargs):
|
||||
coro.close()
|
||||
mock_run_app.side_effect = _consume_coro
|
||||
|
||||
api = Api(port=8080)
|
||||
api.run()
|
||||
|
||||
|
|
|
|||
592
tests/unit/test_provenance/test_dag_structure.py
Normal file
592
tests/unit/test_provenance/test_dag_structure.py
Normal file
|
|
@ -0,0 +1,592 @@
|
|||
"""
|
||||
DAG structure tests for provenance chains.
|
||||
|
||||
Verifies that the wasDerivedFrom chain has the expected shape for each
|
||||
service. These tests catch structural regressions when new entities are
|
||||
inserted into the chain (e.g. PatternDecision between session and first
|
||||
iteration).
|
||||
|
||||
Expected chains:
|
||||
|
||||
GraphRAG: question → grounding → exploration → focus → synthesis
|
||||
DocumentRAG: question → grounding → exploration → synthesis
|
||||
Agent React: session → pattern-decision → iteration → (observation → iteration)* → final
|
||||
Agent Plan: session → pattern-decision → plan → step-result(s) → synthesis
|
||||
Agent Super: session → pattern-decision → decomposition → (fan-out) → finding(s) → synthesis
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.schema import (
|
||||
AgentRequest, AgentResponse, AgentStep, PlanStep,
|
||||
Triple, Term, IRI, LITERAL,
|
||||
)
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL,
|
||||
TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION,
|
||||
TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS,
|
||||
TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION,
|
||||
TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION,
|
||||
TG_OBSERVATION_TYPE,
|
||||
TG_PATTERN, TG_TASK_TYPE,
|
||||
)
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _collect_events(events):
|
||||
"""Build a dict of explain_id → {types, derived_from, triples}."""
|
||||
result = {}
|
||||
for ev in events:
|
||||
eid = ev["explain_id"]
|
||||
triples = ev["triples"]
|
||||
types = {
|
||||
t.o.iri for t in triples
|
||||
if t.s.iri == eid and t.p.iri == RDF_TYPE
|
||||
}
|
||||
parents = [
|
||||
t.o.iri for t in triples
|
||||
if t.s.iri == eid and t.p.iri == PROV_WAS_DERIVED_FROM
|
||||
]
|
||||
result[eid] = {
|
||||
"types": types,
|
||||
"derived_from": parents[0] if parents else None,
|
||||
"triples": triples,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def _find_by_type(dag, rdf_type):
|
||||
"""Find all event IDs that have the given rdf:type."""
|
||||
return [eid for eid, info in dag.items() if rdf_type in info["types"]]
|
||||
|
||||
|
||||
def _assert_chain(dag, chain_types):
|
||||
"""Assert that a linear wasDerivedFrom chain exists through the given types."""
|
||||
for i in range(1, len(chain_types)):
|
||||
parent_type = chain_types[i - 1]
|
||||
child_type = chain_types[i]
|
||||
parents = _find_by_type(dag, parent_type)
|
||||
children = _find_by_type(dag, child_type)
|
||||
assert parents, f"No entity with type {parent_type}"
|
||||
assert children, f"No entity with type {child_type}"
|
||||
# At least one child must derive from at least one parent
|
||||
linked = False
|
||||
for child_id in children:
|
||||
derived = dag[child_id]["derived_from"]
|
||||
if derived in parents:
|
||||
linked = True
|
||||
break
|
||||
assert linked, (
|
||||
f"No {child_type} derives from {parent_type}. "
|
||||
f"Children derive from: "
|
||||
f"{[dag[c]['derived_from'] for c in children]}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GraphRAG DAG structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGraphRagDagStructure:
|
||||
"""Verify: question → grounding → exploration → focus → synthesis"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clients(self):
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
graph_embeddings_client = AsyncMock()
|
||||
triples_client = AsyncMock()
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
graph_embeddings_client.query.return_value = [
|
||||
MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")),
|
||||
]
|
||||
triples_client.query_stream.return_value = [
|
||||
Triple(
|
||||
s=Term(type=IRI, iri="http://example.com/e1"),
|
||||
p=Term(type=IRI, iri="http://example.com/p"),
|
||||
o=Term(type=LITERAL, value="value"),
|
||||
)
|
||||
]
|
||||
triples_client.query.return_value = []
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="concept")
|
||||
elif template_id == "kg-edge-scoring":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "score": 10} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
edges = variables.get("knowledge", [])
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return PromptResult(response_type="text", text="Answer.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
return prompt_client, embeddings_client, graph_embeddings_client, triples_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self, mock_clients):
|
||||
rag = GraphRag(*mock_clients)
|
||||
events = []
|
||||
|
||||
async def explain_cb(triples, explain_id):
|
||||
events.append({"explain_id": explain_id, "triples": triples})
|
||||
|
||||
await rag.query(
|
||||
query="test", explain_callback=explain_cb, edge_score_limit=0,
|
||||
)
|
||||
|
||||
dag = _collect_events(events)
|
||||
assert len(dag) == 5, f"Expected 5 events, got {len(dag)}"
|
||||
|
||||
_assert_chain(dag, [
|
||||
TG_GRAPH_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_FOCUS,
|
||||
TG_SYNTHESIS,
|
||||
])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DocumentRAG DAG structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDocumentRagDagStructure:
|
||||
"""Verify: question → grounding → exploration → synthesis"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clients(self):
|
||||
from trustgraph.schema import ChunkMatch
|
||||
|
||||
prompt_client = AsyncMock()
|
||||
embeddings_client = AsyncMock()
|
||||
doc_embeddings_client = AsyncMock()
|
||||
fetch_chunk = AsyncMock(return_value="Chunk content.")
|
||||
|
||||
embeddings_client.embed.return_value = [[0.1, 0.2]]
|
||||
doc_embeddings_client.query.return_value = [
|
||||
ChunkMatch(chunk_id="doc/c1", score=0.9),
|
||||
]
|
||||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return PromptResult(response_type="text", text="concept")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text", text="Answer.",
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self, mock_clients):
|
||||
rag = DocumentRag(*mock_clients)
|
||||
events = []
|
||||
|
||||
async def explain_cb(triples, explain_id):
|
||||
events.append({"explain_id": explain_id, "triples": triples})
|
||||
|
||||
await rag.query(
|
||||
query="test", explain_callback=explain_cb,
|
||||
)
|
||||
|
||||
dag = _collect_events(events)
|
||||
assert len(dag) == 4, f"Expected 4 events, got {len(dag)}"
|
||||
|
||||
_assert_chain(dag, [
|
||||
TG_DOC_RAG_QUESTION,
|
||||
TG_GROUNDING,
|
||||
TG_EXPLORATION,
|
||||
TG_SYNTHESIS,
|
||||
])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent DAG structure — tested via service.agent_request()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_processor(tools=None):
|
||||
processor = MagicMock()
|
||||
processor.max_iterations = 10
|
||||
processor.save_answer_content = AsyncMock()
|
||||
|
||||
def mock_session_uri(sid):
|
||||
return f"urn:trustgraph:agent:session:{sid}"
|
||||
processor.provenance_session_uri.side_effect = mock_session_uri
|
||||
|
||||
agent = MagicMock()
|
||||
agent.tools = tools or {}
|
||||
agent.additional_context = ""
|
||||
processor.agent = agent
|
||||
processor.aggregator = MagicMock()
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def _make_flow():
|
||||
producers = {}
|
||||
|
||||
def factory(name):
|
||||
if name not in producers:
|
||||
producers[name] = AsyncMock()
|
||||
return producers[name]
|
||||
|
||||
flow = MagicMock(side_effect=factory)
|
||||
return flow
|
||||
|
||||
|
||||
def _collect_agent_events(respond_mock):
|
||||
events = []
|
||||
for call in respond_mock.call_args_list:
|
||||
resp = call[0][0]
|
||||
if isinstance(resp, AgentResponse) and resp.message_type == "explain":
|
||||
events.append({
|
||||
"explain_id": resp.explain_id,
|
||||
"triples": resp.explain_triples,
|
||||
})
|
||||
return events
|
||||
|
||||
|
||||
class TestAgentReactDagStructure:
|
||||
"""
|
||||
Via service.agent_request(), full two-iteration react chain:
|
||||
session → pattern-decision → iteration(1) → observation(1) → final
|
||||
|
||||
Iteration 1: tool call → observation
|
||||
Iteration 2: final answer
|
||||
"""
|
||||
|
||||
def _make_service(self):
|
||||
from trustgraph.agent.orchestrator.service import Processor
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "lookup"
|
||||
mock_tool.description = "Look things up"
|
||||
mock_tool.arguments = []
|
||||
mock_tool.groups = []
|
||||
mock_tool.states = {}
|
||||
mock_tool_impl = AsyncMock(return_value="42")
|
||||
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
|
||||
|
||||
processor = _make_processor(tools={"lookup": mock_tool})
|
||||
|
||||
service = Processor.__new__(Processor)
|
||||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
service.plan_pattern = PlanThenExecutePattern(service)
|
||||
service.supervisor_pattern = SupervisorPattern(service)
|
||||
service.meta_router = None
|
||||
|
||||
return service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self):
|
||||
from trustgraph.agent.react.types import Action, Final
|
||||
|
||||
service = self._make_service()
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = _make_flow()
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Iteration 1: tool call → returns Action, triggers on_action + tool exec
|
||||
action = Action(
|
||||
thought="I need to look this up",
|
||||
name="lookup",
|
||||
arguments={"question": "6x7"},
|
||||
observation="",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
|
||||
) as MockAM:
|
||||
mock_am = AsyncMock()
|
||||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react_iter1(on_action=None, **kwargs):
|
||||
if on_action:
|
||||
await on_action(action)
|
||||
action.observation = "42"
|
||||
return action
|
||||
|
||||
mock_am.react.side_effect = mock_react_iter1
|
||||
|
||||
request1 = AgentRequest(
|
||||
question="What is 6x7?",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=session_id,
|
||||
pattern="react",
|
||||
history=[],
|
||||
)
|
||||
|
||||
await service.agent_request(request1, respond, next_fn, flow)
|
||||
|
||||
# next_fn should have been called with updated history
|
||||
assert next_fn.called
|
||||
|
||||
# Iteration 2: final answer
|
||||
final = Final(thought="The answer is 42", final="42")
|
||||
next_request = next_fn.call_args[0][0]
|
||||
|
||||
with patch(
|
||||
"trustgraph.agent.orchestrator.react_pattern.AgentManager"
|
||||
) as MockAM:
|
||||
mock_am = AsyncMock()
|
||||
MockAM.return_value = mock_am
|
||||
|
||||
async def mock_react_iter2(**kwargs):
|
||||
return final
|
||||
|
||||
mock_am.react.side_effect = mock_react_iter2
|
||||
|
||||
await service.agent_request(next_request, respond, next_fn, flow)
|
||||
|
||||
# Collect and verify DAG
|
||||
events = _collect_agent_events(respond)
|
||||
dag = _collect_events(events)
|
||||
|
||||
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
|
||||
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
|
||||
analysis_ids = _find_by_type(dag, TG_ANALYSIS)
|
||||
observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE)
|
||||
final_ids = _find_by_type(dag, TG_CONCLUSION)
|
||||
|
||||
assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}"
|
||||
assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}"
|
||||
assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}"
|
||||
assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}"
|
||||
assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}"
|
||||
|
||||
# Full chain:
|
||||
# session → pattern-decision
|
||||
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
|
||||
|
||||
# pattern-decision → iteration(1)
|
||||
assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0]
|
||||
|
||||
# iteration(1) → observation(1)
|
||||
assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0]
|
||||
|
||||
# observation(1) → final
|
||||
assert dag[final_ids[0]]["derived_from"] == observation_ids[0]
|
||||
|
||||
|
||||
class TestAgentPlanDagStructure:
|
||||
"""
|
||||
Via service.agent_request():
|
||||
session → pattern-decision → plan → step-result → synthesis
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self):
|
||||
from trustgraph.agent.orchestrator.service import Processor
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
# Mock tool
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "knowledge-query"
|
||||
mock_tool.description = "Query KB"
|
||||
mock_tool.arguments = []
|
||||
mock_tool.groups = []
|
||||
mock_tool.states = {}
|
||||
mock_tool_impl = AsyncMock(return_value="Found it")
|
||||
mock_tool.implementation = MagicMock(return_value=mock_tool_impl)
|
||||
|
||||
processor = _make_processor(tools={"knowledge-query": mock_tool})
|
||||
|
||||
service = Processor.__new__(Processor)
|
||||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
service.plan_pattern = PlanThenExecutePattern(service)
|
||||
service.supervisor_pattern = SupervisorPattern(service)
|
||||
service.meta_router = None
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = _make_flow()
|
||||
|
||||
# Mock prompt client
|
||||
mock_prompt_client = AsyncMock()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_prompt(id, variables=None, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if id == "plan-create":
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}],
|
||||
)
|
||||
elif id == "plan-step-execute":
|
||||
return PromptResult(
|
||||
response_type="json",
|
||||
object={"tool": "knowledge-query", "arguments": {"question": "test"}},
|
||||
)
|
||||
elif id == "plan-synthesise":
|
||||
return PromptResult(response_type="text", text="Final answer.")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
mock_prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Iteration 1: planning
|
||||
request1 = AgentRequest(
|
||||
question="Test?",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=session_id,
|
||||
pattern="plan-then-execute",
|
||||
history=[],
|
||||
)
|
||||
await service.agent_request(request1, respond, next_fn, flow)
|
||||
|
||||
# Iteration 2: execute step (next_fn was called with updated request)
|
||||
assert next_fn.called
|
||||
next_request = next_fn.call_args[0][0]
|
||||
|
||||
# Iteration 3: all steps done → synthesis
|
||||
# Simulate completed step in history
|
||||
next_request.history[-1].plan[0].status = "completed"
|
||||
next_request.history[-1].plan[0].result = "Found it"
|
||||
|
||||
await service.agent_request(next_request, respond, next_fn, flow)
|
||||
|
||||
events = _collect_agent_events(respond)
|
||||
dag = _collect_events(events)
|
||||
|
||||
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
|
||||
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
|
||||
plan_ids = _find_by_type(dag, TG_PLAN_TYPE)
|
||||
synthesis_ids = _find_by_type(dag, TG_SYNTHESIS)
|
||||
|
||||
assert len(session_ids) == 1
|
||||
assert len(pd_ids) == 1
|
||||
assert len(plan_ids) == 1
|
||||
assert len(synthesis_ids) == 1
|
||||
|
||||
# Chain: session → pattern-decision → plan → ... → synthesis
|
||||
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
|
||||
assert dag[plan_ids[0]]["derived_from"] == pd_ids[0]
|
||||
|
||||
|
||||
class TestAgentSupervisorDagStructure:
|
||||
"""
|
||||
Via service.agent_request():
|
||||
session → pattern-decision → decomposition → (fan-out)
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_chain(self):
|
||||
from trustgraph.agent.orchestrator.service import Processor
|
||||
from trustgraph.agent.orchestrator.react_pattern import ReactPattern
|
||||
from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
|
||||
from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
|
||||
|
||||
processor = _make_processor()
|
||||
|
||||
service = Processor.__new__(Processor)
|
||||
service.max_iterations = 10
|
||||
service.save_answer_content = AsyncMock()
|
||||
service.provenance_session_uri = processor.provenance_session_uri
|
||||
service.agent = processor.agent
|
||||
service.aggregator = processor.aggregator
|
||||
|
||||
service.react_pattern = ReactPattern(service)
|
||||
service.plan_pattern = PlanThenExecutePattern(service)
|
||||
service.supervisor_pattern = SupervisorPattern(service)
|
||||
service.meta_router = None
|
||||
|
||||
respond = AsyncMock()
|
||||
next_fn = AsyncMock()
|
||||
flow = _make_flow()
|
||||
|
||||
mock_prompt_client = AsyncMock()
|
||||
mock_prompt_client.prompt.return_value = PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=["Goal A", "Goal B"],
|
||||
)
|
||||
|
||||
def flow_factory(name):
|
||||
if name == "prompt-request":
|
||||
return mock_prompt_client
|
||||
return AsyncMock()
|
||||
flow.side_effect = flow_factory
|
||||
|
||||
request = AgentRequest(
|
||||
question="Research quantum computing",
|
||||
user="testuser",
|
||||
collection="default",
|
||||
streaming=False,
|
||||
session_id=str(uuid.uuid4()),
|
||||
pattern="supervisor",
|
||||
history=[],
|
||||
)
|
||||
|
||||
await service.agent_request(request, respond, next_fn, flow)
|
||||
|
||||
events = _collect_agent_events(respond)
|
||||
dag = _collect_events(events)
|
||||
|
||||
session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
|
||||
pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
|
||||
decomp_ids = _find_by_type(dag, TG_DECOMPOSITION)
|
||||
|
||||
assert len(session_ids) == 1
|
||||
assert len(pd_ids) == 1
|
||||
assert len(decomp_ids) == 1
|
||||
|
||||
# Chain: session → pattern-decision → decomposition
|
||||
assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
|
||||
assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0]
|
||||
|
||||
# Fan-out should have been called
|
||||
assert next_fn.call_count == 2 # One per goal
|
||||
|
|
@ -223,7 +223,7 @@ class TestDerivedEntityTriples:
|
|||
assert has_type(triples, self.ENTITY_URI, PROV_ENTITY)
|
||||
assert has_type(triples, self.ENTITY_URI, TG_PAGE_TYPE)
|
||||
|
||||
def test_chunk_entity_has_chunk_type(self):
|
||||
def test_chunk_entity_has_message_type(self):
|
||||
triples = derived_entity_triples(
|
||||
self.ENTITY_URI, self.PARENT_URI,
|
||||
"chunker", "1.0",
|
||||
|
|
|
|||
|
|
@ -304,14 +304,14 @@ class TestStreamingTypes:
|
|||
|
||||
assert chunk.content == "thinking..."
|
||||
assert chunk.end_of_message is False
|
||||
assert chunk.chunk_type == "thought"
|
||||
assert chunk.message_type == "thought"
|
||||
|
||||
def test_agent_observation_creation(self):
|
||||
"""Test creating AgentObservation chunk"""
|
||||
chunk = AgentObservation(content="observing...", end_of_message=False)
|
||||
|
||||
assert chunk.content == "observing..."
|
||||
assert chunk.chunk_type == "observation"
|
||||
assert chunk.message_type == "observation"
|
||||
|
||||
def test_agent_answer_creation(self):
|
||||
"""Test creating AgentAnswer chunk"""
|
||||
|
|
@ -324,7 +324,7 @@ class TestStreamingTypes:
|
|||
assert chunk.content == "answer"
|
||||
assert chunk.end_of_message is True
|
||||
assert chunk.end_of_dialog is True
|
||||
assert chunk.chunk_type == "final-answer"
|
||||
assert chunk.message_type == "final-answer"
|
||||
|
||||
def test_rag_chunk_creation(self):
|
||||
"""Test creating RAGChunk"""
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
# Sample chunk content mapping for tests
|
||||
|
|
@ -132,7 +133,7 @@ class TestQuery:
|
|||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock the prompt response with concept lines
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -157,7 +158,7 @@ class TestQuery:
|
|||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
# Mock empty response
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -258,7 +259,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "test concept"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept")
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
|
|
@ -273,7 +274,7 @@ class TestQuery:
|
|||
expected_response = "This is the document RAG response"
|
||||
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2]
|
||||
mock_prompt_client.document_prompt.return_value = expected_response
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response)
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -315,7 +316,8 @@ class TestQuery:
|
|||
assert "Relevant document content" in docs
|
||||
assert "Another document" in docs
|
||||
|
||||
assert result == expected_response
|
||||
result_text, usage = result
|
||||
assert result_text == expected_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_rag_query_with_defaults(self, mock_fetch_chunk):
|
||||
|
|
@ -325,7 +327,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction fallback (empty → raw query)
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]]
|
||||
|
|
@ -333,7 +335,7 @@ class TestQuery:
|
|||
mock_match.chunk_id = "doc/c5"
|
||||
mock_match.score = 0.9
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Default response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -352,7 +354,8 @@ class TestQuery:
|
|||
collection="default" # Default collection
|
||||
)
|
||||
|
||||
assert result == "Default response"
|
||||
result_text, usage = result
|
||||
assert result_text == "Default response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_verbose_output(self):
|
||||
|
|
@ -401,7 +404,7 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "verbose query test"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test")
|
||||
|
||||
# Mock responses
|
||||
mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]]
|
||||
|
|
@ -409,7 +412,7 @@ class TestQuery:
|
|||
mock_match.chunk_id = "doc/c7"
|
||||
mock_match.score = 0.92
|
||||
mock_doc_embeddings_client.query.return_value = [mock_match]
|
||||
mock_prompt_client.document_prompt.return_value = "Verbose RAG response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -428,7 +431,8 @@ class TestQuery:
|
|||
assert call_args.kwargs["query"] == "verbose query test"
|
||||
assert "Verbose doc content" in call_args.kwargs["documents"]
|
||||
|
||||
assert result == "Verbose RAG response"
|
||||
result_text, usage = result
|
||||
assert result_text == "Verbose RAG response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_with_empty_results(self):
|
||||
|
|
@ -469,11 +473,11 @@ class TestQuery:
|
|||
mock_doc_embeddings_client = AsyncMock()
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "query with no matching docs"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs")
|
||||
|
||||
mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]]
|
||||
mock_doc_embeddings_client.query.return_value = []
|
||||
mock_prompt_client.document_prompt.return_value = "No documents found response"
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response")
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -490,7 +494,8 @@ class TestQuery:
|
|||
documents=[]
|
||||
)
|
||||
|
||||
assert result == "No documents found response"
|
||||
result_text, usage = result
|
||||
assert result_text == "No documents found response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_vectors_with_verbose(self):
|
||||
|
|
@ -525,7 +530,7 @@ class TestQuery:
|
|||
final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed."
|
||||
|
||||
# Mock concept extraction
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence")
|
||||
|
||||
# Mock embeddings - one vector per concept
|
||||
query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]]
|
||||
|
|
@ -541,7 +546,7 @@ class TestQuery:
|
|||
MagicMock(chunk_id="doc/ml3", score=0.82),
|
||||
]
|
||||
mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2]
|
||||
mock_prompt_client.document_prompt.return_value = final_response
|
||||
mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response)
|
||||
|
||||
document_rag = DocumentRag(
|
||||
prompt_client=mock_prompt_client,
|
||||
|
|
@ -584,7 +589,8 @@ class TestQuery:
|
|||
assert "Common ML techniques include supervised and unsupervised learning..." in docs
|
||||
assert len(docs) == 3 # doc/ml2 deduplicated
|
||||
|
||||
assert result == final_response
|
||||
result_text, usage = result
|
||||
assert result_text == final_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_docs_deduplicates_across_concepts(self):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from unittest.mock import AsyncMock
|
|||
from dataclasses import dataclass
|
||||
|
||||
from trustgraph.retrieval.document_rag.document_rag import DocumentRag
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -89,8 +90,8 @@ def build_mock_clients():
|
|||
# 1. Concept extraction
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return "return policy\nrefund"
|
||||
return ""
|
||||
return PromptResult(response_type="text", text="return policy\nrefund")
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
|
|
@ -113,8 +114,9 @@ def build_mock_clients():
|
|||
fetch_chunk.side_effect = mock_fetch
|
||||
|
||||
# 5. Synthesis
|
||||
prompt_client.document_prompt.return_value = (
|
||||
"Items can be returned within 30 days for a full refund."
|
||||
prompt_client.document_prompt.return_value = PromptResult(
|
||||
response_type="text",
|
||||
text="Items can be returned within 30 days for a full refund.",
|
||||
)
|
||||
|
||||
return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk
|
||||
|
|
@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is the return policy?",
|
||||
explain_callback=AsyncMock(),
|
||||
)
|
||||
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
assert result_text == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_explain_callback_still_works(self):
|
||||
|
|
@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = DocumentRag(*clients)
|
||||
|
||||
result = await rag.query(query="What is the return policy?")
|
||||
assert result == "Items can be returned within 30 days for a full refund."
|
||||
result_text, usage = await rag.query(query="What is the return policy?")
|
||||
assert result_text == "Items can be returned within 30 days for a full refund."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class TestDocumentRagService:
|
|||
# Setup mock DocumentRag instance
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_document_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = "test response"
|
||||
mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None})
|
||||
|
||||
# Setup message with custom user/collection
|
||||
msg = MagicMock()
|
||||
|
|
@ -97,7 +97,7 @@ class TestDocumentRagService:
|
|||
# Setup mock DocumentRag instance
|
||||
mock_rag_instance = AsyncMock()
|
||||
mock_document_rag_class.return_value = mock_rag_instance
|
||||
mock_rag_instance.query.return_value = "A document about cats."
|
||||
mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None})
|
||||
|
||||
# Setup message with non-streaming request
|
||||
msg = MagicMock()
|
||||
|
|
@ -130,4 +130,5 @@ class TestDocumentRagService:
|
|||
assert isinstance(sent_response, DocumentRagResponse)
|
||||
assert sent_response.response == "A document about cats."
|
||||
assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True"
|
||||
assert sent_response.end_of_session is True
|
||||
assert sent_response.error is None
|
||||
|
|
@ -7,6 +7,7 @@ import unittest.mock
|
|||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
|
||||
class TestGraphRag:
|
||||
|
|
@ -172,7 +173,7 @@ class TestQuery:
|
|||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n"
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -196,7 +197,7 @@ class TestQuery:
|
|||
mock_prompt_client = AsyncMock()
|
||||
mock_rag.prompt_client = mock_prompt_client
|
||||
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
query = Query(
|
||||
rag=mock_rag,
|
||||
|
|
@ -220,7 +221,7 @@ class TestQuery:
|
|||
mock_rag.graph_embeddings_client = mock_graph_embeddings_client
|
||||
|
||||
# extract_concepts returns empty -> falls back to [query]
|
||||
mock_prompt_client.prompt.return_value = ""
|
||||
mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="")
|
||||
|
||||
# embed returns one vector set for the single concept
|
||||
test_vectors = [[0.1, 0.2, 0.3]]
|
||||
|
|
@ -565,14 +566,14 @@ class TestQuery:
|
|||
# Mock prompt responses for the multi-step process
|
||||
async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None):
|
||||
if prompt_name == "extract-concepts":
|
||||
return "" # Falls back to raw query
|
||||
return PromptResult(response_type="text", text="")
|
||||
elif prompt_name == "kg-edge-scoring":
|
||||
return json.dumps({"id": test_edge_id, "score": 0.9})
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}])
|
||||
elif prompt_name == "kg-edge-reasoning":
|
||||
return json.dumps({"id": test_edge_id, "reasoning": "relevant"})
|
||||
return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}])
|
||||
elif prompt_name == "kg-synthesis":
|
||||
return expected_response
|
||||
return ""
|
||||
return PromptResult(response_type="text", text=expected_response)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
mock_prompt_client.prompt = mock_prompt
|
||||
|
||||
|
|
@ -607,7 +608,8 @@ class TestQuery:
|
|||
explain_callback=collect_provenance
|
||||
)
|
||||
|
||||
assert response == expected_response
|
||||
response_text, usage = response
|
||||
assert response_text == expected_response
|
||||
|
||||
# 5 events: question, grounding, exploration, focus, synthesis
|
||||
assert len(provenance_events) == 5
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from dataclasses import dataclass
|
|||
|
||||
from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id
|
||||
from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL
|
||||
from trustgraph.base import PromptResult
|
||||
|
||||
from trustgraph.provenance.namespaces import (
|
||||
RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
|
||||
|
|
@ -136,24 +137,36 @@ def build_mock_clients():
|
|||
|
||||
async def mock_prompt(template_id, variables=None, **kwargs):
|
||||
if template_id == "extract-concepts":
|
||||
return prompt_responses["extract-concepts"]
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=prompt_responses["extract-concepts"],
|
||||
)
|
||||
elif template_id == "kg-edge-scoring":
|
||||
# Score all edges highly, using the IDs that GraphRag computed
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "score": 10 - i}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-edge-reasoning":
|
||||
# Provide reasoning for each edge
|
||||
edges = variables.get("knowledge", [])
|
||||
return [
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
]
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=[
|
||||
{"id": e["id"], "reasoning": f"Relevant edge {i}"}
|
||||
for i, e in enumerate(edges)
|
||||
],
|
||||
)
|
||||
elif template_id == "kg-synthesis":
|
||||
return synthesis_answer
|
||||
return ""
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=synthesis_answer,
|
||||
)
|
||||
return PromptResult(response_type="text", text="")
|
||||
|
||||
prompt_client.prompt.side_effect = mock_prompt
|
||||
|
||||
|
|
@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance:
|
|||
async def explain_callback(triples, explain_id):
|
||||
events.append({"triples": triples, "explain_id": explain_id})
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
explain_callback=explain_callback,
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parent_uri_links_question_to_parent(self):
|
||||
|
|
@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance:
|
|||
clients = build_mock_clients()
|
||||
rag = GraphRag(*clients)
|
||||
|
||||
result = await rag.query(
|
||||
result_text, usage = await rag.query(
|
||||
query="What is quantum computing?",
|
||||
edge_score_limit=0,
|
||||
)
|
||||
|
||||
assert result == "Quantum computing applies physics principles to computation."
|
||||
assert result_text == "Quantum computing applies physics principles to computation."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_triples_in_retrieval_graph(self):
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class TestGraphRagService:
|
|||
await explain_callback([], "urn:trustgraph:prov:retrieval:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:selection:test")
|
||||
await explain_callback([], "urn:trustgraph:prov:answer:test")
|
||||
return "A small domesticated mammal."
|
||||
return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None}
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
|
|
@ -79,8 +79,8 @@ class TestGraphRagService:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session)
|
||||
assert mock_response_producer.send.call_count == 6
|
||||
# Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session)
|
||||
assert mock_response_producer.send.call_count == 5
|
||||
|
||||
# First 4 messages are explain (emitted in real-time during query)
|
||||
for i in range(4):
|
||||
|
|
@ -88,17 +88,12 @@ class TestGraphRagService:
|
|||
assert prov_msg.message_type == "explain"
|
||||
assert prov_msg.explain_id is not None
|
||||
|
||||
# 5th message is chunk with response
|
||||
# 5th message is chunk with response and end_of_session
|
||||
chunk_msg = mock_response_producer.send.call_args_list[4][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "A small domesticated mammal."
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# 6th message is empty chunk with end_of_session=True
|
||||
close_msg = mock_response_producer.send.call_args_list[5][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
assert chunk_msg.end_of_session is True
|
||||
|
||||
# Verify provenance triples were sent to provenance queue
|
||||
assert mock_provenance_producer.send.call_count == 4
|
||||
|
|
@ -187,7 +182,7 @@ class TestGraphRagService:
|
|||
|
||||
async def mock_query(**kwargs):
|
||||
# Don't call explain_callback
|
||||
return "Response text"
|
||||
return "Response text", {"in_token": None, "out_token": None, "model": None}
|
||||
|
||||
mock_rag_instance.query.side_effect = mock_query
|
||||
|
||||
|
|
@ -218,17 +213,12 @@ class TestGraphRagService:
|
|||
# Execute
|
||||
await processor.on_request(msg, consumer, flow)
|
||||
|
||||
# Verify: 2 messages (chunk + empty chunk to close)
|
||||
assert mock_response_producer.send.call_count == 2
|
||||
# Verify: 1 combined message (chunk with end_of_session)
|
||||
assert mock_response_producer.send.call_count == 1
|
||||
|
||||
# First is the response chunk
|
||||
# Single message has response and end_of_session
|
||||
chunk_msg = mock_response_producer.send.call_args_list[0][0][0]
|
||||
assert chunk_msg.message_type == "chunk"
|
||||
assert chunk_msg.response == "Response text"
|
||||
assert chunk_msg.end_of_stream is True
|
||||
|
||||
# Second is empty chunk to close session
|
||||
close_msg = mock_response_producer.send.call_args_list[1][0][0]
|
||||
assert close_msg.message_type == "chunk"
|
||||
assert close_msg.response == ""
|
||||
assert close_msg.end_of_session is True
|
||||
assert chunk_msg.end_of_session is True
|
||||
|
|
|
|||
0
tests/unit/test_tables/__init__.py
Normal file
0
tests/unit/test_tables/__init__.py
Normal file
197
tests/unit/test_tables/test_knowledge_table_store.py
Normal file
197
tests/unit/test_tables/test_knowledge_table_store.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
"""
|
||||
Unit tests for KnowledgeTableStore row deserialization.
|
||||
|
||||
Regression coverage: a previous version of get_graph_embeddings constructed
|
||||
EntityEmbeddings(vectors=ent[1]) — the schema field is `vector` (singular),
|
||||
so any real Cassandra row would crash on read. These tests bypass the live
|
||||
Cassandra connection entirely and exercise the row -> schema conversion
|
||||
with hand-built fake rows.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from trustgraph.tables.knowledge import KnowledgeTableStore
|
||||
from trustgraph.schema import (
|
||||
EntityEmbeddings,
|
||||
GraphEmbeddings,
|
||||
Triples,
|
||||
Triple,
|
||||
Metadata,
|
||||
IRI,
|
||||
LITERAL,
|
||||
)
|
||||
|
||||
|
||||
def _make_store():
|
||||
"""
|
||||
Build a KnowledgeTableStore without invoking __init__ (which connects
|
||||
to Cassandra). Tests inject only the attributes the method under test
|
||||
actually touches.
|
||||
"""
|
||||
return KnowledgeTableStore.__new__(KnowledgeTableStore)
|
||||
|
||||
|
||||
class TestGetGraphEmbeddings:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_row_converts_to_entity_embeddings_with_singular_vector(self):
|
||||
"""
|
||||
Cassandra rows return entities as a list of [entity_tuple, vector]
|
||||
pairs in row[3]. The deserializer must construct EntityEmbeddings
|
||||
with `vector=` (singular) — the schema field name. A previous
|
||||
version used `vectors=` and TypeError'd at runtime.
|
||||
"""
|
||||
# Arrange — fake row matching the get_triples_stmt result shape:
|
||||
# row[0..2] are unused by the method, row[3] is the entities blob
|
||||
fake_row = (
|
||||
None, None, None,
|
||||
[
|
||||
# ((value, is_uri), vector)
|
||||
(("http://example.org/alice", True), [0.1, 0.2, 0.3]),
|
||||
(("http://example.org/bob", True), [0.4, 0.5, 0.6]),
|
||||
(("a literal entity", False), [0.7, 0.8, 0.9]),
|
||||
],
|
||||
)
|
||||
|
||||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.cassandra.execute = Mock(return_value=[fake_row])
|
||||
store.get_graph_embeddings_stmt = Mock()
|
||||
|
||||
received = []
|
||||
|
||||
async def receiver(msg):
|
||||
received.append(msg)
|
||||
|
||||
# Act
|
||||
await store.get_graph_embeddings(
|
||||
user="alice",
|
||||
document_id="doc-1",
|
||||
receiver=receiver,
|
||||
)
|
||||
|
||||
# Assert
|
||||
store.cassandra.execute.assert_called_once_with(
|
||||
store.get_graph_embeddings_stmt,
|
||||
("alice", "doc-1"),
|
||||
)
|
||||
|
||||
assert len(received) == 1
|
||||
ge = received[0]
|
||||
assert isinstance(ge, GraphEmbeddings)
|
||||
assert isinstance(ge.metadata, Metadata)
|
||||
assert ge.metadata.id == "doc-1"
|
||||
assert ge.metadata.user == "alice"
|
||||
|
||||
assert len(ge.entities) == 3
|
||||
assert all(isinstance(e, EntityEmbeddings) for e in ge.entities)
|
||||
|
||||
# Vectors land in the singular `vector` field — this is the
|
||||
# explicit regression assertion for the original bug.
|
||||
assert ge.entities[0].vector == [0.1, 0.2, 0.3]
|
||||
assert ge.entities[1].vector == [0.4, 0.5, 0.6]
|
||||
assert ge.entities[2].vector == [0.7, 0.8, 0.9]
|
||||
|
||||
# Term type round-trips through tuple_to_term
|
||||
assert ge.entities[0].entity.type == IRI
|
||||
assert ge.entities[0].entity.iri == "http://example.org/alice"
|
||||
assert ge.entities[1].entity.type == IRI
|
||||
assert ge.entities[1].entity.iri == "http://example.org/bob"
|
||||
assert ge.entities[2].entity.type == LITERAL
|
||||
assert ge.entities[2].entity.value == "a literal entity"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_entities_blob_yields_empty_list(self):
|
||||
"""row[3] being None / empty must produce a GraphEmbeddings with
|
||||
no entities, not raise."""
|
||||
fake_row = (None, None, None, None)
|
||||
|
||||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.cassandra.execute = Mock(return_value=[fake_row])
|
||||
store.get_graph_embeddings_stmt = Mock()
|
||||
|
||||
received = []
|
||||
|
||||
async def receiver(msg):
|
||||
received.append(msg)
|
||||
|
||||
await store.get_graph_embeddings("u", "d", receiver)
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0].entities == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_rows_each_emit_one_message(self):
|
||||
fake_rows = [
|
||||
(None, None, None, [
|
||||
(("http://example.org/a", True), [1.0]),
|
||||
]),
|
||||
(None, None, None, [
|
||||
(("http://example.org/b", True), [2.0]),
|
||||
]),
|
||||
]
|
||||
|
||||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.cassandra.execute = Mock(return_value=fake_rows)
|
||||
store.get_graph_embeddings_stmt = Mock()
|
||||
|
||||
received = []
|
||||
|
||||
async def receiver(msg):
|
||||
received.append(msg)
|
||||
|
||||
await store.get_graph_embeddings("u", "d", receiver)
|
||||
|
||||
assert len(received) == 2
|
||||
assert received[0].entities[0].entity.iri == "http://example.org/a"
|
||||
assert received[0].entities[0].vector == [1.0]
|
||||
assert received[1].entities[0].entity.iri == "http://example.org/b"
|
||||
assert received[1].entities[0].vector == [2.0]
|
||||
|
||||
|
||||
class TestGetTriples:
|
||||
"""Bonus: the sibling get_triples path uses the same row[3] shape and
|
||||
the same Metadata construction. Cover it for parity."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_row_converts_to_triples(self):
|
||||
# row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri)
|
||||
fake_row = (
|
||||
None, None, None,
|
||||
[
|
||||
(
|
||||
"http://example.org/alice", True,
|
||||
"http://example.org/knows", True,
|
||||
"http://example.org/bob", True,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
store = _make_store()
|
||||
store.cassandra = Mock()
|
||||
store.cassandra.execute = Mock(return_value=[fake_row])
|
||||
store.get_triples_stmt = Mock()
|
||||
|
||||
received = []
|
||||
|
||||
async def receiver(msg):
|
||||
received.append(msg)
|
||||
|
||||
await store.get_triples("alice", "doc-1", receiver)
|
||||
|
||||
assert len(received) == 1
|
||||
triples_msg = received[0]
|
||||
assert isinstance(triples_msg, Triples)
|
||||
assert isinstance(triples_msg.metadata, Metadata)
|
||||
assert triples_msg.metadata.id == "doc-1"
|
||||
assert triples_msg.metadata.user == "alice"
|
||||
|
||||
assert len(triples_msg.triples) == 1
|
||||
t = triples_msg.triples[0]
|
||||
assert isinstance(t, Triple)
|
||||
assert t.s.iri == "http://example.org/alice"
|
||||
assert t.p.iri == "http://example.org/knows"
|
||||
assert t.o.iri == "http://example.org/bob"
|
||||
0
tests/unit/test_translators/__init__.py
Normal file
0
tests/unit/test_translators/__init__.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""
|
||||
Round-trip unit tests for DocumentEmbeddingsTranslator.
|
||||
|
||||
Regression coverage: a previous version of the decode side constructed
|
||||
ChunkEmbeddings(vectors=...) — the schema field is `vector` (singular),
|
||||
so any real DocumentEmbeddings message would crash on decode. The encode
|
||||
side already wrote `"vector"`, so encode→decode was asymmetric.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.messaging.translators.document_loading import (
|
||||
DocumentEmbeddingsTranslator,
|
||||
)
|
||||
from trustgraph.schema import (
|
||||
DocumentEmbeddings,
|
||||
ChunkEmbeddings,
|
||||
Metadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def translator():
|
||||
return DocumentEmbeddingsTranslator()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample():
|
||||
return DocumentEmbeddings(
|
||||
metadata=Metadata(
|
||||
id="doc-1",
|
||||
root="",
|
||||
user="alice",
|
||||
collection="testcoll",
|
||||
),
|
||||
chunks=[
|
||||
ChunkEmbeddings(chunk_id="c1", vector=[0.1, 0.2, 0.3]),
|
||||
ChunkEmbeddings(chunk_id="c2", vector=[0.4, 0.5, 0.6]),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestDocumentEmbeddingsTranslator:
|
||||
|
||||
def test_encode_uses_singular_vector_key(self, translator, sample):
|
||||
encoded = translator.encode(sample)
|
||||
chunks = encoded["chunks"]
|
||||
assert all("vector" in c for c in chunks)
|
||||
assert all("vectors" not in c for c in chunks)
|
||||
assert chunks[0]["vector"] == [0.1, 0.2, 0.3]
|
||||
|
||||
def test_roundtrip_preserves_document_embeddings(self, translator, sample):
|
||||
encoded = translator.encode(sample)
|
||||
decoded = translator.decode(encoded)
|
||||
|
||||
assert isinstance(decoded, DocumentEmbeddings)
|
||||
assert isinstance(decoded.metadata, Metadata)
|
||||
assert decoded.metadata.id == "doc-1"
|
||||
assert decoded.metadata.user == "alice"
|
||||
assert decoded.metadata.collection == "testcoll"
|
||||
|
||||
assert len(decoded.chunks) == 2
|
||||
assert decoded.chunks[0].chunk_id == "c1"
|
||||
assert decoded.chunks[0].vector == [0.1, 0.2, 0.3]
|
||||
assert decoded.chunks[1].chunk_id == "c2"
|
||||
assert decoded.chunks[1].vector == [0.4, 0.5, 0.6]
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
"""
|
||||
Round-trip unit tests for KnowledgeRequestTranslator.
|
||||
|
||||
Regression coverage: a previous version of the decode side constructed
|
||||
EntityEmbeddings(vectors=...) — the schema field is `vector` (singular),
|
||||
so any real graph-embeddings KnowledgeRequest would crash on first
|
||||
message. The encode side already wrote `"vector"`, so encode→decode was
|
||||
asymmetric.
|
||||
|
||||
These tests build a real KnowledgeRequest with graph-embeddings, encode
|
||||
it, decode the result, and assert the round-trip is lossless. They also
|
||||
exercise the triples path so any future schema drift in Metadata or
|
||||
Triples breaks the test.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator
|
||||
from trustgraph.schema import (
|
||||
KnowledgeRequest,
|
||||
GraphEmbeddings,
|
||||
EntityEmbeddings,
|
||||
Triples,
|
||||
Triple,
|
||||
Metadata,
|
||||
Term,
|
||||
IRI,
|
||||
)
|
||||
|
||||
|
||||
def _term_iri(uri):
|
||||
return Term(type=IRI, iri=uri)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def translator():
|
||||
return KnowledgeRequestTranslator()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph_embeddings_request():
|
||||
return KnowledgeRequest(
|
||||
operation="put-kg-core",
|
||||
user="alice",
|
||||
id="doc-1",
|
||||
flow="default",
|
||||
collection="testcoll",
|
||||
graph_embeddings=GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id="doc-1",
|
||||
root="",
|
||||
user="alice",
|
||||
collection="testcoll",
|
||||
),
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=_term_iri("http://example.org/alice"),
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
),
|
||||
EntityEmbeddings(
|
||||
entity=_term_iri("http://example.org/bob"),
|
||||
vector=[0.4, 0.5, 0.6],
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def triples_request():
|
||||
return KnowledgeRequest(
|
||||
operation="put-kg-core",
|
||||
user="alice",
|
||||
id="doc-1",
|
||||
flow="default",
|
||||
collection="testcoll",
|
||||
triples=Triples(
|
||||
metadata=Metadata(
|
||||
id="doc-1",
|
||||
root="",
|
||||
user="alice",
|
||||
collection="testcoll",
|
||||
),
|
||||
triples=[
|
||||
Triple(
|
||||
s=_term_iri("http://example.org/alice"),
|
||||
p=_term_iri("http://example.org/knows"),
|
||||
o=_term_iri("http://example.org/bob"),
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestKnowledgeRequestTranslatorGraphEmbeddings:
|
||||
|
||||
def test_encode_produces_singular_vector_key(
|
||||
self, translator, graph_embeddings_request,
|
||||
):
|
||||
"""The wire key must be `vector`, never `vectors`."""
|
||||
encoded = translator.encode(graph_embeddings_request)
|
||||
entities = encoded["graph-embeddings"]["entities"]
|
||||
assert all("vector" in e for e in entities)
|
||||
assert all("vectors" not in e for e in entities)
|
||||
assert entities[0]["vector"] == [0.1, 0.2, 0.3]
|
||||
|
||||
def test_roundtrip_preserves_graph_embeddings(
|
||||
self, translator, graph_embeddings_request,
|
||||
):
|
||||
"""encode -> decode must be lossless for the GE branch."""
|
||||
encoded = translator.encode(graph_embeddings_request)
|
||||
decoded = translator.decode(encoded)
|
||||
|
||||
assert isinstance(decoded, KnowledgeRequest)
|
||||
assert decoded.operation == "put-kg-core"
|
||||
assert decoded.user == "alice"
|
||||
assert decoded.id == "doc-1"
|
||||
assert decoded.flow == "default"
|
||||
assert decoded.collection == "testcoll"
|
||||
|
||||
assert decoded.graph_embeddings is not None
|
||||
ge = decoded.graph_embeddings
|
||||
assert isinstance(ge, GraphEmbeddings)
|
||||
assert isinstance(ge.metadata, Metadata)
|
||||
assert ge.metadata.id == "doc-1"
|
||||
assert ge.metadata.user == "alice"
|
||||
assert ge.metadata.collection == "testcoll"
|
||||
|
||||
assert len(ge.entities) == 2
|
||||
assert ge.entities[0].vector == [0.1, 0.2, 0.3]
|
||||
assert ge.entities[1].vector == [0.4, 0.5, 0.6]
|
||||
assert ge.entities[0].entity.iri == "http://example.org/alice"
|
||||
assert ge.entities[1].entity.iri == "http://example.org/bob"
|
||||
|
||||
|
||||
class TestKnowledgeRequestTranslatorTriples:
|
||||
|
||||
def test_roundtrip_preserves_triples(self, translator, triples_request):
|
||||
encoded = translator.encode(triples_request)
|
||||
decoded = translator.decode(encoded)
|
||||
|
||||
assert isinstance(decoded, KnowledgeRequest)
|
||||
assert decoded.triples is not None
|
||||
assert isinstance(decoded.triples.metadata, Metadata)
|
||||
assert decoded.triples.metadata.id == "doc-1"
|
||||
assert decoded.triples.metadata.user == "alice"
|
||||
assert decoded.triples.metadata.collection == "testcoll"
|
||||
|
||||
assert len(decoded.triples.triples) == 1
|
||||
t = decoded.triples.triples[0]
|
||||
assert t.s.iri == "http://example.org/alice"
|
||||
assert t.p.iri == "http://example.org/knows"
|
||||
assert t.o.iri == "http://example.org/bob"
|
||||
|
|
@ -9,7 +9,7 @@ from .streaming_assertions import (
|
|||
assert_streaming_content_matches,
|
||||
assert_no_empty_chunks,
|
||||
assert_streaming_error_handled,
|
||||
assert_chunk_types_valid,
|
||||
assert_message_types_valid,
|
||||
assert_streaming_latency_acceptable,
|
||||
assert_callback_invoked,
|
||||
)
|
||||
|
|
@ -23,7 +23,7 @@ __all__ = [
|
|||
"assert_streaming_content_matches",
|
||||
"assert_no_empty_chunks",
|
||||
"assert_streaming_error_handled",
|
||||
"assert_chunk_types_valid",
|
||||
"assert_message_types_valid",
|
||||
"assert_streaming_latency_acceptable",
|
||||
"assert_callback_invoked",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -20,14 +20,14 @@ def assert_streaming_chunks_valid(chunks: List[Any], min_chunks: int = 1):
|
|||
assert all(chunk is not None for chunk in chunks), "All chunks should be non-None"
|
||||
|
||||
|
||||
def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "chunk_type"):
|
||||
def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "message_type"):
|
||||
"""
|
||||
Assert that streaming chunks follow an expected sequence.
|
||||
|
||||
Args:
|
||||
chunks: List of chunk dictionaries
|
||||
expected_sequence: Expected sequence of chunk types/values
|
||||
key: Dictionary key to check (default: "chunk_type")
|
||||
key: Dictionary key to check (default: "message_type")
|
||||
"""
|
||||
actual_sequence = [chunk.get(key) for chunk in chunks if key in chunk]
|
||||
assert actual_sequence == expected_sequence, \
|
||||
|
|
@ -39,7 +39,7 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
|
|||
Assert that agent streaming chunks have valid structure.
|
||||
|
||||
Validates:
|
||||
- All chunks have chunk_type field
|
||||
- All chunks have message_type field
|
||||
- All chunks have content field
|
||||
- All chunks have end_of_message field
|
||||
- All chunks have end_of_dialog field
|
||||
|
|
@ -51,15 +51,15 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]):
|
|||
assert len(chunks) > 0, "Expected at least one chunk"
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
assert "chunk_type" in chunk, f"Chunk {i} missing chunk_type"
|
||||
assert "message_type" in chunk, f"Chunk {i} missing message_type"
|
||||
assert "content" in chunk, f"Chunk {i} missing content"
|
||||
assert "end_of_message" in chunk, f"Chunk {i} missing end_of_message"
|
||||
assert "end_of_dialog" in chunk, f"Chunk {i} missing end_of_dialog"
|
||||
|
||||
# Validate chunk_type values
|
||||
# Validate message_type values
|
||||
valid_types = ["thought", "action", "observation", "final-answer"]
|
||||
assert chunk["chunk_type"] in valid_types, \
|
||||
f"Invalid chunk_type '{chunk['chunk_type']}' at index {i}"
|
||||
assert chunk["message_type"] in valid_types, \
|
||||
f"Invalid message_type '{chunk['message_type']}' at index {i}"
|
||||
|
||||
# Last chunk should signal end of dialog
|
||||
assert chunks[-1]["end_of_dialog"] is True, \
|
||||
|
|
@ -175,7 +175,7 @@ def assert_streaming_error_handled(chunks: List[Dict[str, Any]], error_flag: str
|
|||
"Error chunk should have completion flag set to True"
|
||||
|
||||
|
||||
def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "chunk_type"):
|
||||
def assert_message_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "message_type"):
|
||||
"""
|
||||
Assert that all chunk types are from a valid set.
|
||||
|
||||
|
|
@ -185,9 +185,9 @@ def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str
|
|||
type_key: Dictionary key for chunk type
|
||||
"""
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_type = chunk.get(type_key)
|
||||
assert chunk_type in valid_types, \
|
||||
f"Chunk {i} has invalid type '{chunk_type}', expected one of {valid_types}"
|
||||
message_type = chunk.get(type_key)
|
||||
assert message_type in valid_types, \
|
||||
f"Chunk {i} has invalid type '{message_type}', expected one of {valid_types}"
|
||||
|
||||
|
||||
def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ dependencies = [
|
|||
"requests",
|
||||
"python-logging-loki",
|
||||
"pika",
|
||||
"pyyaml",
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
|
|
@ -24,6 +25,9 @@ classifiers = [
|
|||
[project.urls]
|
||||
Homepage = "https://github.com/trustgraph-ai/trustgraph"
|
||||
|
||||
[project.scripts]
|
||||
processor-group = "trustgraph.base.processor_group:run"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["trustgraph*"]
|
||||
|
||||
|
|
@ -31,4 +35,4 @@ include = ["trustgraph*"]
|
|||
"trustgraph.i18n.packs" = ["*.json"]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {attr = "trustgraph.base_version.__version__"}
|
||||
version = {attr = "trustgraph.base_version.__version__"}
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ from .types import (
|
|||
AgentObservation,
|
||||
AgentAnswer,
|
||||
RAGChunk,
|
||||
TextCompletionResult,
|
||||
ProvenanceEvent,
|
||||
)
|
||||
|
||||
|
|
@ -185,6 +186,7 @@ __all__ = [
|
|||
"AgentObservation",
|
||||
"AgentAnswer",
|
||||
"RAGChunk",
|
||||
"TextCompletionResult",
|
||||
"ProvenanceEvent",
|
||||
|
||||
# Exceptions
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ import aiohttp
|
|||
import json
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from . types import TextCompletionResult
|
||||
|
||||
from . exceptions import ProtocolException, ApplicationException
|
||||
|
||||
|
||||
|
|
@ -434,12 +436,11 @@ class AsyncFlowInstance:
|
|||
|
||||
return await self.request("agent", request_data)
|
||||
|
||||
async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str:
|
||||
async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> TextCompletionResult:
|
||||
"""
|
||||
Generate text completion (non-streaming).
|
||||
|
||||
Generates a text response from an LLM given a system prompt and user prompt.
|
||||
Returns the complete response text.
|
||||
|
||||
Note: This method does not support streaming. For streaming text generation,
|
||||
use AsyncSocketFlowInstance.text_completion() instead.
|
||||
|
|
@ -450,19 +451,19 @@ class AsyncFlowInstance:
|
|||
**kwargs: Additional service-specific parameters
|
||||
|
||||
Returns:
|
||||
str: Complete generated text response
|
||||
TextCompletionResult: Result with text, in_token, out_token, model
|
||||
|
||||
Example:
|
||||
```python
|
||||
async_flow = await api.async_flow()
|
||||
flow = async_flow.id("default")
|
||||
|
||||
# Generate text
|
||||
response = await flow.text_completion(
|
||||
result = await flow.text_completion(
|
||||
system="You are a helpful assistant.",
|
||||
prompt="Explain quantum computing in simple terms."
|
||||
)
|
||||
print(response)
|
||||
print(result.text)
|
||||
print(f"Tokens: {result.in_token} in, {result.out_token} out")
|
||||
```
|
||||
"""
|
||||
request_data = {
|
||||
|
|
@ -473,7 +474,12 @@ class AsyncFlowInstance:
|
|||
request_data.update(kwargs)
|
||||
|
||||
result = await self.request("text-completion", request_data)
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
async def graph_rag(self, query: str, user: str, collection: str,
|
||||
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import asyncio
|
|||
import websockets
|
||||
from typing import Optional, Dict, Any, AsyncIterator, Union
|
||||
|
||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk
|
||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, TextCompletionResult
|
||||
from . exceptions import ProtocolException, ApplicationException
|
||||
|
||||
|
||||
|
|
@ -178,30 +178,32 @@ class AsyncSocketClient:
|
|||
|
||||
def _parse_chunk(self, resp: Dict[str, Any]):
|
||||
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
|
||||
chunk_type = resp.get("chunk_type")
|
||||
message_type = resp.get("message_type")
|
||||
|
||||
# Handle new GraphRAG message format with message_type
|
||||
if message_type == "provenance":
|
||||
return None
|
||||
|
||||
if chunk_type == "thought":
|
||||
if message_type == "thought":
|
||||
return AgentThought(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False)
|
||||
)
|
||||
elif chunk_type == "observation":
|
||||
elif message_type == "observation":
|
||||
return AgentObservation(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False)
|
||||
)
|
||||
elif chunk_type == "answer" or chunk_type == "final-answer":
|
||||
elif message_type == "answer" or message_type == "final-answer":
|
||||
return AgentAnswer(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False),
|
||||
end_of_dialog=resp.get("end_of_dialog", False)
|
||||
end_of_dialog=resp.get("end_of_dialog", False),
|
||||
in_token=resp.get("in_token"),
|
||||
out_token=resp.get("out_token"),
|
||||
model=resp.get("model"),
|
||||
)
|
||||
elif chunk_type == "action":
|
||||
elif message_type == "action":
|
||||
return AgentThought(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False)
|
||||
|
|
@ -211,7 +213,10 @@ class AsyncSocketClient:
|
|||
return RAGChunk(
|
||||
content=content,
|
||||
end_of_stream=resp.get("end_of_stream", False),
|
||||
error=None
|
||||
error=None,
|
||||
in_token=resp.get("in_token"),
|
||||
out_token=resp.get("out_token"),
|
||||
model=resp.get("model"),
|
||||
)
|
||||
|
||||
async def aclose(self):
|
||||
|
|
@ -269,7 +274,11 @@ class AsyncSocketFlowInstance:
|
|||
return await self.client._send_request("agent", self.flow_id, request)
|
||||
|
||||
async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs):
|
||||
"""Text completion with optional streaming"""
|
||||
"""Text completion with optional streaming.
|
||||
|
||||
Non-streaming: returns a TextCompletionResult with text and token counts.
|
||||
Streaming: returns an async iterator of RAGChunk (with token counts on the final chunk).
|
||||
"""
|
||||
request = {
|
||||
"system": system,
|
||||
"prompt": prompt,
|
||||
|
|
@ -281,13 +290,18 @@ class AsyncSocketFlowInstance:
|
|||
return self._text_completion_streaming(request)
|
||||
else:
|
||||
result = await self.client._send_request("text-completion", self.flow_id, request)
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
async def _text_completion_streaming(self, request):
|
||||
"""Helper for streaming text completion"""
|
||||
"""Helper for streaming text completion. Yields RAGChunk objects."""
|
||||
async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request):
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
if isinstance(chunk, RAGChunk):
|
||||
yield chunk
|
||||
|
||||
async def graph_rag(self, query: str, user: str, collection: str,
|
||||
max_subgraph_size: int = 1000, max_subgraph_count: int = 5,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import base64
|
|||
|
||||
from .. knowledge import hash, Uri, Literal, QuotedTriple
|
||||
from .. schema import IRI, LITERAL, TRIPLE
|
||||
from . types import Triple
|
||||
from . types import Triple, TextCompletionResult
|
||||
from . exceptions import ProtocolException
|
||||
|
||||
|
||||
|
|
@ -360,16 +360,17 @@ class FlowInstance:
|
|||
prompt: User prompt/question
|
||||
|
||||
Returns:
|
||||
str: Generated response text
|
||||
TextCompletionResult: Result with text, in_token, out_token, model
|
||||
|
||||
Example:
|
||||
```python
|
||||
flow = api.flow().id("default")
|
||||
response = flow.text_completion(
|
||||
result = flow.text_completion(
|
||||
system="You are a helpful assistant",
|
||||
prompt="What is quantum computing?"
|
||||
)
|
||||
print(response)
|
||||
print(result.text)
|
||||
print(f"Tokens: {result.in_token} in, {result.out_token} out")
|
||||
```
|
||||
"""
|
||||
|
||||
|
|
@ -379,10 +380,17 @@ class FlowInstance:
|
|||
"prompt": prompt
|
||||
}
|
||||
|
||||
return self.request(
|
||||
result = self.request(
|
||||
"service/text-completion",
|
||||
input
|
||||
)["response"]
|
||||
)
|
||||
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def agent(self, question, user="trustgraph", state=None, group=None, history=None):
|
||||
"""
|
||||
|
|
@ -498,10 +506,17 @@ class FlowInstance:
|
|||
"edge-limit": edge_limit,
|
||||
}
|
||||
|
||||
return self.request(
|
||||
result = self.request(
|
||||
"service/graph-rag",
|
||||
input
|
||||
)["response"]
|
||||
)
|
||||
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def document_rag(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
|
|
@ -543,10 +558,17 @@ class FlowInstance:
|
|||
"doc-limit": doc_limit,
|
||||
}
|
||||
|
||||
return self.request(
|
||||
result = self.request(
|
||||
"service/document-rag",
|
||||
input
|
||||
)["response"]
|
||||
)
|
||||
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def embeddings(self, texts):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import websockets
|
|||
from typing import Optional, Dict, Any, Iterator, Union, List
|
||||
from threading import Lock
|
||||
|
||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent
|
||||
from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent, TextCompletionResult
|
||||
from . exceptions import ProtocolException, raise_from_error_dict
|
||||
|
||||
|
||||
|
|
@ -360,41 +360,36 @@ class SocketClient:
|
|||
|
||||
def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]:
|
||||
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
|
||||
chunk_type = resp.get("chunk_type")
|
||||
message_type = resp.get("message_type")
|
||||
|
||||
# Handle GraphRAG/DocRAG message format with message_type
|
||||
if message_type == "explain":
|
||||
if include_provenance:
|
||||
return self._build_provenance_event(resp)
|
||||
return None
|
||||
|
||||
# Handle Agent message format with chunk_type="explain"
|
||||
if chunk_type == "explain":
|
||||
if include_provenance:
|
||||
return self._build_provenance_event(resp)
|
||||
return None
|
||||
|
||||
if chunk_type == "thought":
|
||||
if message_type == "thought":
|
||||
return AgentThought(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False),
|
||||
message_id=resp.get("message_id", ""),
|
||||
)
|
||||
elif chunk_type == "observation":
|
||||
elif message_type == "observation":
|
||||
return AgentObservation(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False),
|
||||
message_id=resp.get("message_id", ""),
|
||||
)
|
||||
elif chunk_type == "answer" or chunk_type == "final-answer":
|
||||
elif message_type == "answer" or message_type == "final-answer":
|
||||
return AgentAnswer(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False),
|
||||
end_of_dialog=resp.get("end_of_dialog", False),
|
||||
message_id=resp.get("message_id", ""),
|
||||
in_token=resp.get("in_token"),
|
||||
out_token=resp.get("out_token"),
|
||||
model=resp.get("model"),
|
||||
)
|
||||
elif chunk_type == "action":
|
||||
elif message_type == "action":
|
||||
return AgentThought(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False)
|
||||
|
|
@ -404,7 +399,10 @@ class SocketClient:
|
|||
return RAGChunk(
|
||||
content=content,
|
||||
end_of_stream=resp.get("end_of_stream", False),
|
||||
error=None
|
||||
error=None,
|
||||
in_token=resp.get("in_token"),
|
||||
out_token=resp.get("out_token"),
|
||||
model=resp.get("model"),
|
||||
)
|
||||
|
||||
def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent:
|
||||
|
|
@ -543,8 +541,12 @@ class SocketFlowInstance:
|
|||
streaming=True, include_provenance=True
|
||||
)
|
||||
|
||||
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]:
|
||||
"""Execute text completion with optional streaming."""
|
||||
def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
"""Execute text completion with optional streaming.
|
||||
|
||||
Non-streaming: returns a TextCompletionResult with text and token counts.
|
||||
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
||||
"""
|
||||
request = {
|
||||
"system": system,
|
||||
"prompt": prompt,
|
||||
|
|
@ -557,12 +559,17 @@ class SocketFlowInstance:
|
|||
if streaming:
|
||||
return self._text_completion_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
|
||||
def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
if isinstance(chunk, RAGChunk):
|
||||
yield chunk
|
||||
|
||||
def graph_rag(
|
||||
self,
|
||||
|
|
@ -577,8 +584,12 @@ class SocketFlowInstance:
|
|||
edge_limit: int = 25,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[str, Iterator[str]]:
|
||||
"""Execute graph-based RAG query with optional streaming."""
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
"""Execute graph-based RAG query with optional streaming.
|
||||
|
||||
Non-streaming: returns a TextCompletionResult with text and token counts.
|
||||
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
||||
"""
|
||||
request = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
|
|
@ -598,7 +609,12 @@ class SocketFlowInstance:
|
|||
if streaming:
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def graph_rag_explain(
|
||||
self,
|
||||
|
|
@ -642,8 +658,12 @@ class SocketFlowInstance:
|
|||
doc_limit: int = 10,
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[str, Iterator[str]]:
|
||||
"""Execute document-based RAG query with optional streaming."""
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
"""Execute document-based RAG query with optional streaming.
|
||||
|
||||
Non-streaming: returns a TextCompletionResult with text and token counts.
|
||||
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
||||
"""
|
||||
request = {
|
||||
"query": query,
|
||||
"user": user,
|
||||
|
|
@ -658,7 +678,12 @@ class SocketFlowInstance:
|
|||
if streaming:
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("response", ""),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def document_rag_explain(
|
||||
self,
|
||||
|
|
@ -684,10 +709,10 @@ class SocketFlowInstance:
|
|||
streaming=True, include_provenance=True
|
||||
)
|
||||
|
||||
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]:
|
||||
def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]:
|
||||
for chunk in result:
|
||||
if hasattr(chunk, 'content'):
|
||||
yield chunk.content
|
||||
if isinstance(chunk, RAGChunk):
|
||||
yield chunk
|
||||
|
||||
def prompt(
|
||||
self,
|
||||
|
|
@ -695,8 +720,12 @@ class SocketFlowInstance:
|
|||
variables: Dict[str, str],
|
||||
streaming: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[str, Iterator[str]]:
|
||||
"""Execute a prompt template with optional streaming."""
|
||||
) -> Union[TextCompletionResult, Iterator[RAGChunk]]:
|
||||
"""Execute a prompt template with optional streaming.
|
||||
|
||||
Non-streaming: returns a TextCompletionResult with text and token counts.
|
||||
Streaming: returns an iterator of RAGChunk (with token counts on the final chunk).
|
||||
"""
|
||||
request = {
|
||||
"id": id,
|
||||
"variables": variables,
|
||||
|
|
@ -709,7 +738,12 @@ class SocketFlowInstance:
|
|||
if streaming:
|
||||
return self._rag_generator(result)
|
||||
else:
|
||||
return result.get("response", "")
|
||||
return TextCompletionResult(
|
||||
text=result.get("text", result.get("response", "")),
|
||||
in_token=result.get("in_token"),
|
||||
out_token=result.get("out_token"),
|
||||
model=result.get("model"),
|
||||
)
|
||||
|
||||
def graph_embeddings_query(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -149,10 +149,10 @@ class AgentThought(StreamingChunk):
|
|||
Attributes:
|
||||
content: Agent's thought text
|
||||
end_of_message: True if this completes the current thought
|
||||
chunk_type: Always "thought"
|
||||
message_type: Always "thought"
|
||||
message_id: Provenance URI of the entity being built
|
||||
"""
|
||||
chunk_type: str = "thought"
|
||||
message_type: str = "thought"
|
||||
message_id: str = ""
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
@ -166,10 +166,10 @@ class AgentObservation(StreamingChunk):
|
|||
Attributes:
|
||||
content: Observation text describing tool results
|
||||
end_of_message: True if this completes the current observation
|
||||
chunk_type: Always "observation"
|
||||
message_type: Always "observation"
|
||||
message_id: Provenance URI of the entity being built
|
||||
"""
|
||||
chunk_type: str = "observation"
|
||||
message_type: str = "observation"
|
||||
message_id: str = ""
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
@ -184,11 +184,14 @@ class AgentAnswer(StreamingChunk):
|
|||
content: Answer text
|
||||
end_of_message: True if this completes the current answer segment
|
||||
end_of_dialog: True if this completes the entire agent interaction
|
||||
chunk_type: Always "final-answer"
|
||||
message_type: Always "final-answer"
|
||||
"""
|
||||
chunk_type: str = "final-answer"
|
||||
message_type: str = "final-answer"
|
||||
end_of_dialog: bool = False
|
||||
message_id: str = ""
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RAGChunk(StreamingChunk):
|
||||
|
|
@ -202,11 +205,37 @@ class RAGChunk(StreamingChunk):
|
|||
content: Generated text content
|
||||
end_of_stream: True if this is the final chunk of the stream
|
||||
error: Optional error information if an error occurred
|
||||
chunk_type: Always "rag"
|
||||
in_token: Input token count (populated on the final chunk, 0 otherwise)
|
||||
out_token: Output token count (populated on the final chunk, 0 otherwise)
|
||||
model: Model identifier (populated on the final chunk, empty otherwise)
|
||||
message_type: Always "rag"
|
||||
"""
|
||||
chunk_type: str = "rag"
|
||||
message_type: str = "rag"
|
||||
end_of_stream: bool = False
|
||||
error: Optional[Dict[str, str]] = None
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TextCompletionResult:
|
||||
"""
|
||||
Result from a text completion request.
|
||||
|
||||
Returned by text_completion() in both streaming and non-streaming modes.
|
||||
In streaming mode, text is None (chunks are delivered via the iterator).
|
||||
In non-streaming mode, text contains the complete response.
|
||||
|
||||
Attributes:
|
||||
text: Complete response text (None in streaming mode)
|
||||
in_token: Input token count (None if not available)
|
||||
out_token: Output token count (None if not available)
|
||||
model: Model identifier (None if not available)
|
||||
"""
|
||||
text: Optional[str]
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProvenanceEvent:
|
||||
|
|
|
|||
|
|
@ -18,8 +18,10 @@ from . librarian_client import LibrarianClient
|
|||
from . chunking_service import ChunkingService
|
||||
from . embeddings_service import EmbeddingsService
|
||||
from . embeddings_client import EmbeddingsClientSpec
|
||||
from . text_completion_client import TextCompletionClientSpec
|
||||
from . prompt_client import PromptClientSpec
|
||||
from . text_completion_client import (
|
||||
TextCompletionClientSpec, TextCompletionClient, TextCompletionResult,
|
||||
)
|
||||
from . prompt_client import PromptClientSpec, PromptClient, PromptResult
|
||||
from . triples_store_service import TriplesStoreService
|
||||
from . graph_embeddings_store_service import GraphEmbeddingsStoreService
|
||||
from . document_embeddings_store_service import DocumentEmbeddingsStoreService
|
||||
|
|
|
|||
|
|
@ -30,19 +30,19 @@ class AgentClient(RequestResponse):
|
|||
raise RuntimeError(resp.error.message)
|
||||
|
||||
# Handle thought chunks
|
||||
if resp.chunk_type == 'thought':
|
||||
if resp.message_type == 'thought':
|
||||
if think:
|
||||
await think(resp.content, resp.end_of_message)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle observation chunks
|
||||
if resp.chunk_type == 'observation':
|
||||
if resp.message_type == 'observation':
|
||||
if observe:
|
||||
await observe(resp.content, resp.end_of_message)
|
||||
return False # Continue receiving
|
||||
|
||||
# Handle answer chunks
|
||||
if resp.chunk_type == 'answer':
|
||||
if resp.message_type == 'answer':
|
||||
if resp.content:
|
||||
accumulated_answer.append(resp.content)
|
||||
if answer_callback:
|
||||
|
|
|
|||
|
|
@ -58,6 +58,18 @@ class BackendProducer(Protocol):
|
|||
class BackendConsumer(Protocol):
|
||||
"""Protocol for backend-specific consumer."""
|
||||
|
||||
def ensure_connected(self) -> None:
|
||||
"""
|
||||
Eagerly establish the underlying connection and bind the queue.
|
||||
|
||||
Backends that lazily connect on first receive() must implement this
|
||||
so that callers can guarantee the consumer is fully bound — and
|
||||
therefore able to receive responses — before any related request is
|
||||
published. Backends that connect at construction time may make this
|
||||
a no-op.
|
||||
"""
|
||||
...
|
||||
|
||||
def receive(self, timeout_millis: int = 2000) -> Message:
|
||||
"""
|
||||
Receive a message from the topic.
|
||||
|
|
|
|||
|
|
@ -88,14 +88,14 @@ class ChunkingService(FlowProcessor):
|
|||
chunk_overlap = default_chunk_overlap
|
||||
|
||||
try:
|
||||
cs = flow.parameters.get("chunk-size")
|
||||
cs = flow("chunk-size")
|
||||
if cs is not None:
|
||||
chunk_size = int(cs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse chunk-size parameter: {e}")
|
||||
|
||||
try:
|
||||
co = flow.parameters.get("chunk-overlap")
|
||||
co = flow("chunk-overlap")
|
||||
if co is not None:
|
||||
chunk_overlap = int(co)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -8,12 +8,51 @@ ensuring consistent log formats, levels, and command-line arguments.
|
|||
Supports dual output to console and Loki for centralized log aggregation.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import logging.handlers
|
||||
from queue import Queue
|
||||
import os
|
||||
|
||||
|
||||
# The current processor id for this task context. Read by
|
||||
# _ProcessorIdFilter to stamp every LogRecord with its owning
|
||||
# processor, and read by logging_loki's emitter via record.tags
|
||||
# to label log lines in Loki. ContextVar so asyncio subtasks
|
||||
# inherit their parent supervisor's processor id automatically.
|
||||
current_processor_id = contextvars.ContextVar(
|
||||
"current_processor_id", default="unknown"
|
||||
)
|
||||
|
||||
|
||||
def set_processor_id(pid):
|
||||
"""Set the processor id for the current task context.
|
||||
|
||||
All subsequent log records emitted from this task — and any
|
||||
asyncio tasks spawned from it — will be tagged with this id
|
||||
in the console format and in Loki labels.
|
||||
"""
|
||||
current_processor_id.set(pid)
|
||||
|
||||
|
||||
class _ProcessorIdFilter(logging.Filter):
|
||||
"""Stamps every LogRecord with processor_id from the contextvar.
|
||||
|
||||
Attaches two fields to each record:
|
||||
record.processor_id — used by the console format string
|
||||
record.tags — merged into Loki labels by logging_loki's
|
||||
emitter (it reads record.tags and combines
|
||||
with the handler's static tags)
|
||||
"""
|
||||
|
||||
def filter(self, record):
|
||||
pid = current_processor_id.get()
|
||||
record.processor_id = pid
|
||||
existing = getattr(record, "tags", None) or {}
|
||||
record.tags = {**existing, "processor": pid}
|
||||
return True
|
||||
|
||||
|
||||
def add_logging_args(parser):
|
||||
"""
|
||||
Add standard logging arguments to an argument parser.
|
||||
|
|
@ -87,12 +126,15 @@ def setup_logging(args):
|
|||
loki_url = args.get('loki_url', 'http://loki:3100/loki/api/v1/push')
|
||||
loki_username = args.get('loki_username')
|
||||
loki_password = args.get('loki_password')
|
||||
processor_id = args.get('id') # Processor identity (e.g., "config-svc", "text-completion")
|
||||
|
||||
try:
|
||||
from logging_loki import LokiHandler
|
||||
|
||||
# Create Loki handler with optional authentication and processor label
|
||||
# Create Loki handler with optional authentication. The
|
||||
# processor label is NOT baked in here — it's stamped onto
|
||||
# each record by _ProcessorIdFilter reading the task-local
|
||||
# contextvar, and logging_loki's emitter reads record.tags
|
||||
# to build per-record Loki labels.
|
||||
loki_handler_kwargs = {
|
||||
'url': loki_url,
|
||||
'version': "1",
|
||||
|
|
@ -101,10 +143,6 @@ def setup_logging(args):
|
|||
if loki_username and loki_password:
|
||||
loki_handler_kwargs['auth'] = (loki_username, loki_password)
|
||||
|
||||
# Add processor label if available (for consistency with Prometheus metrics)
|
||||
if processor_id:
|
||||
loki_handler_kwargs['tags'] = {'processor': processor_id}
|
||||
|
||||
loki_handler = LokiHandler(**loki_handler_kwargs)
|
||||
|
||||
# Wrap in QueueHandler for non-blocking operation
|
||||
|
|
@ -133,23 +171,44 @@ def setup_logging(args):
|
|||
print(f"WARNING: Failed to setup Loki logging: {e}")
|
||||
print("Continuing with console-only logging")
|
||||
|
||||
# Get processor ID for log formatting (use 'unknown' if not available)
|
||||
processor_id = args.get('id', 'unknown')
|
||||
|
||||
# Configure logging with all handlers
|
||||
# Use processor ID as the primary identifier in logs
|
||||
# Configure logging with all handlers. The processor id comes
|
||||
# from _ProcessorIdFilter (via contextvar) and is injected into
|
||||
# each record as record.processor_id. The format string reads
|
||||
# that attribute on every emit.
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level.upper()),
|
||||
format=f'%(asctime)s - {processor_id} - %(levelname)s - %(message)s',
|
||||
format='%(asctime)s - %(processor_id)s - %(levelname)s - %(message)s',
|
||||
handlers=handlers,
|
||||
force=True # Force reconfiguration if already configured
|
||||
)
|
||||
|
||||
# Prevent recursive logging from Loki's HTTP client
|
||||
if loki_enabled and queue_listener:
|
||||
# Disable urllib3 logging to prevent infinite loop
|
||||
logging.getLogger('urllib3').setLevel(logging.WARNING)
|
||||
logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING)
|
||||
# Attach the processor-id filter to every handler so all records
|
||||
# passing through any sink get stamped (console, queue→loki,
|
||||
# future handlers). Filters on handlers run regardless of which
|
||||
# logger originated the record, so logs from pika, cassandra,
|
||||
# processor code, etc. all pass through it.
|
||||
processor_filter = _ProcessorIdFilter()
|
||||
for h in handlers:
|
||||
h.addFilter(processor_filter)
|
||||
|
||||
# Seed the contextvar from --id if one was supplied. In group
|
||||
# mode --id isn't present; the processor_group supervisor sets
|
||||
# it per task. In standalone mode AsyncProcessor.launch provides
|
||||
# it via argparse default.
|
||||
if args.get('id'):
|
||||
set_processor_id(args['id'])
|
||||
|
||||
# Silence noisy third-party library loggers. These emit INFO-level
|
||||
# chatter (connection churn, channel open/close, driver warnings) that
|
||||
# drowns the useful signal and can't be attributed to a specific
|
||||
# processor anyway. WARNING and above still propagate.
|
||||
for noisy in (
|
||||
'pika',
|
||||
'cassandra',
|
||||
'urllib3',
|
||||
'urllib3.connectionpool',
|
||||
):
|
||||
logging.getLogger(noisy).setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Logging configured with level: {log_level}")
|
||||
|
|
|
|||
204
trustgraph-base/trustgraph/base/processor_group.py
Normal file
204
trustgraph-base/trustgraph/base/processor_group.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
|
||||
# Multi-processor group runner. Runs multiple AsyncProcessor descendants
|
||||
# as concurrent tasks inside a single process, sharing one event loop,
|
||||
# one Prometheus HTTP server, and one pub/sub backend pool.
|
||||
#
|
||||
# Intended for dev and resource-constrained deployments. Scale deployments
|
||||
# should continue to use per-processor endpoints.
|
||||
#
|
||||
# Group config is a YAML or JSON file with shape:
|
||||
#
|
||||
# processors:
|
||||
# - class: trustgraph.extract.kg.definitions.extract.Processor
|
||||
# params:
|
||||
# id: kg-extract-definitions
|
||||
# triples_batch_size: 1000
|
||||
# - class: trustgraph.chunking.recursive.Processor
|
||||
# params:
|
||||
# id: chunker-recursive
|
||||
#
|
||||
# Each entry's params are passed directly to the class constructor alongside
|
||||
# the shared taskgroup. Defaults live inside each processor class.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from . logging import add_logging_args, setup_logging, set_processor_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_config(path):
|
||||
with open(path) as f:
|
||||
text = f.read()
|
||||
if path.endswith((".yaml", ".yml")):
|
||||
import yaml
|
||||
return yaml.safe_load(text)
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
def _resolve_class(dotted):
|
||||
module_path, _, class_name = dotted.rpartition(".")
|
||||
if not module_path:
|
||||
raise ValueError(
|
||||
f"Processor class must be a dotted path, got {dotted!r}"
|
||||
)
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
RESTART_DELAY_SECONDS = 4
|
||||
|
||||
|
||||
async def _supervise(entry):
|
||||
"""Run one processor with its own nested TaskGroup, restarting on any
|
||||
failure. Each processor is isolated from its siblings — a crash here
|
||||
does not propagate to the outer group."""
|
||||
|
||||
pid = entry["params"]["id"]
|
||||
class_path = entry["class"]
|
||||
|
||||
# Stamp the contextvar for this supervisor task. Every log
|
||||
# record emitted from this task — and from any inner TaskGroup
|
||||
# child created by the processor — inherits this id via
|
||||
# contextvar propagation. Siblings in the outer group set
|
||||
# their own id in their own task context and do not interfere.
|
||||
set_processor_id(pid)
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
|
||||
async with asyncio.TaskGroup() as inner_tg:
|
||||
|
||||
cls = _resolve_class(class_path)
|
||||
params = dict(entry.get("params", {}))
|
||||
params["taskgroup"] = inner_tg
|
||||
|
||||
logger.info(f"Starting {class_path} as {pid}")
|
||||
|
||||
p = cls(**params)
|
||||
await p.start()
|
||||
inner_tg.create_task(p.run())
|
||||
|
||||
# Clean exit — processor's run() returned without raising.
|
||||
# Treat as a transient shutdown and restart, matching the
|
||||
# behaviour of per-container `restart: on-failure`.
|
||||
logger.warning(
|
||||
f"Processor {pid} exited cleanly, will restart"
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Processor {pid} cancelled")
|
||||
raise
|
||||
|
||||
except BaseExceptionGroup as eg:
|
||||
for e in eg.exceptions:
|
||||
logger.error(
|
||||
f"Processor {pid} failure: {type(e).__name__}: {e}",
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Processor {pid} failure: {type(e).__name__}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Restarting {pid} in {RESTART_DELAY_SECONDS}s..."
|
||||
)
|
||||
await asyncio.sleep(RESTART_DELAY_SECONDS)
|
||||
|
||||
|
||||
async def run_group(config):
|
||||
|
||||
entries = config.get("processors", [])
|
||||
if not entries:
|
||||
raise RuntimeError("Group config has no processors")
|
||||
|
||||
seen_ids = set()
|
||||
for entry in entries:
|
||||
pid = entry.get("params", {}).get("id")
|
||||
if pid is None:
|
||||
raise RuntimeError(
|
||||
f"Entry {entry.get('class')!r} missing params.id — "
|
||||
f"required for metrics labelling"
|
||||
)
|
||||
if pid in seen_ids:
|
||||
raise RuntimeError(f"Duplicate processor id {pid!r} in group")
|
||||
seen_ids.add(pid)
|
||||
|
||||
async with asyncio.TaskGroup() as outer_tg:
|
||||
for entry in entries:
|
||||
outer_tg.create_task(_supervise(entry))
|
||||
|
||||
|
||||
def run():
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="processor-group",
|
||||
description="Run multiple processors as tasks in one process",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-c", "--config",
|
||||
required=True,
|
||||
help="Path to group config file (JSON or YAML)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--metrics",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help="Metrics enabled (default: true)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-P", "--metrics-port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Prometheus metrics port (default: 8000)",
|
||||
)
|
||||
|
||||
add_logging_args(parser)
|
||||
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
setup_logging(args)
|
||||
|
||||
config = _load_config(args["config"])
|
||||
|
||||
if args["metrics"]:
|
||||
start_http_server(args["metrics_port"])
|
||||
|
||||
while True:
|
||||
|
||||
logger.info("Starting group...")
|
||||
|
||||
try:
|
||||
asyncio.run(run_group(config))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Keyboard interrupt.")
|
||||
return
|
||||
|
||||
except ExceptionGroup as e:
|
||||
logger.error("Exception group:")
|
||||
for se in e.exceptions:
|
||||
logger.error(f" Type: {type(se)}")
|
||||
logger.error(f" Exception: {se}", exc_info=se)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Type: {type(e)}")
|
||||
logger.error(f"Exception: {e}", exc_info=True)
|
||||
|
||||
logger.warning("Will retry...")
|
||||
time.sleep(4)
|
||||
logger.info("Retrying...")
|
||||
|
|
@ -1,10 +1,22 @@
|
|||
|
||||
import json
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
|
||||
@dataclass
|
||||
class PromptResult:
|
||||
response_type: str # "text", "json", or "jsonl"
|
||||
text: Optional[str] = None # populated for "text"
|
||||
object: Any = None # populated for "json"
|
||||
objects: Optional[list] = None # populated for "jsonl"
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
class PromptClient(RequestResponse):
|
||||
|
||||
async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None):
|
||||
|
|
@ -26,17 +38,40 @@ class PromptClient(RequestResponse):
|
|||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.text: return resp.text
|
||||
if resp.text:
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=resp.text,
|
||||
in_token=resp.in_token,
|
||||
out_token=resp.out_token,
|
||||
model=resp.model,
|
||||
)
|
||||
|
||||
return json.loads(resp.object)
|
||||
parsed = json.loads(resp.object)
|
||||
|
||||
if isinstance(parsed, list):
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=parsed,
|
||||
in_token=resp.in_token,
|
||||
out_token=resp.out_token,
|
||||
model=resp.model,
|
||||
)
|
||||
|
||||
return PromptResult(
|
||||
response_type="json",
|
||||
object=parsed,
|
||||
in_token=resp.in_token,
|
||||
out_token=resp.out_token,
|
||||
model=resp.model,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
last_text = ""
|
||||
last_object = None
|
||||
last_resp = None
|
||||
|
||||
async def forward_chunks(resp):
|
||||
nonlocal last_text, last_object
|
||||
nonlocal last_resp
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
|
@ -44,14 +79,13 @@ class PromptClient(RequestResponse):
|
|||
end_stream = getattr(resp, 'end_of_stream', False)
|
||||
|
||||
if resp.text is not None:
|
||||
last_text = resp.text
|
||||
if chunk_callback:
|
||||
if asyncio.iscoroutinefunction(chunk_callback):
|
||||
await chunk_callback(resp.text, end_stream)
|
||||
else:
|
||||
chunk_callback(resp.text, end_stream)
|
||||
elif resp.object:
|
||||
last_object = resp.object
|
||||
|
||||
last_resp = resp
|
||||
|
||||
return end_stream
|
||||
|
||||
|
|
@ -70,10 +104,36 @@ class PromptClient(RequestResponse):
|
|||
timeout=timeout
|
||||
)
|
||||
|
||||
if last_text:
|
||||
return last_text
|
||||
if last_resp is None:
|
||||
return PromptResult(response_type="text")
|
||||
|
||||
return json.loads(last_object) if last_object else None
|
||||
if last_resp.object:
|
||||
parsed = json.loads(last_resp.object)
|
||||
|
||||
if isinstance(parsed, list):
|
||||
return PromptResult(
|
||||
response_type="jsonl",
|
||||
objects=parsed,
|
||||
in_token=last_resp.in_token,
|
||||
out_token=last_resp.out_token,
|
||||
model=last_resp.model,
|
||||
)
|
||||
|
||||
return PromptResult(
|
||||
response_type="json",
|
||||
object=parsed,
|
||||
in_token=last_resp.in_token,
|
||||
out_token=last_resp.out_token,
|
||||
model=last_resp.model,
|
||||
)
|
||||
|
||||
return PromptResult(
|
||||
response_type="text",
|
||||
text=last_resp.text,
|
||||
in_token=last_resp.in_token,
|
||||
out_token=last_resp.out_token,
|
||||
model=last_resp.model,
|
||||
)
|
||||
|
||||
async def extract_definitions(self, text, timeout=600):
|
||||
return await self.prompt(
|
||||
|
|
@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec):
|
|||
response_schema = PromptResponse,
|
||||
impl = PromptClient,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -72,6 +72,16 @@ class PulsarBackendConsumer:
|
|||
self._consumer = pulsar_consumer
|
||||
self._schema_cls = schema_cls
|
||||
|
||||
def ensure_connected(self) -> None:
|
||||
"""No-op for Pulsar.
|
||||
|
||||
PulsarBackend.create_consumer() calls client.subscribe() which is
|
||||
synchronous and returns a fully-subscribed consumer, so the
|
||||
consumer is already ready by the time this object is constructed.
|
||||
Defined for parity with the BackendConsumer protocol used by
|
||||
Subscriber.start()'s readiness barrier."""
|
||||
pass
|
||||
|
||||
def receive(self, timeout_millis: int = 2000) -> Message:
|
||||
"""Receive a message. Raises TimeoutError if no message available."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -214,16 +214,43 @@ class RabbitMQBackendConsumer:
|
|||
and self._channel.is_open
|
||||
)
|
||||
|
||||
def ensure_connected(self) -> None:
|
||||
"""Eagerly declare and bind the queue.
|
||||
|
||||
Without this, the queue is only declared lazily on the first
|
||||
receive() call. For request/response with ephemeral per-subscriber
|
||||
response queues that is a race: a request published before the
|
||||
response queue is bound will have its reply silently dropped by
|
||||
the broker. Subscriber.start() calls this so callers get a hard
|
||||
readiness barrier."""
|
||||
if not self._is_alive():
|
||||
self._connect()
|
||||
|
||||
def receive(self, timeout_millis: int = 2000) -> Message:
|
||||
"""Receive a message. Raises TimeoutError if none available."""
|
||||
"""Receive a message. Raises TimeoutError if none available.
|
||||
|
||||
Loop ordering matters: check _incoming at the TOP of each
|
||||
iteration, not as the loop condition. process_data_events
|
||||
may dispatch a message via the _on_message callback during
|
||||
the pump; we must re-check _incoming on the next iteration
|
||||
before giving up on the deadline. The previous control
|
||||
flow (`while deadline: check; pump`) could lose a wakeup if
|
||||
the pump consumed the remainder of the window — the
|
||||
`while` check would fail before `_incoming` was re-read,
|
||||
leaving a just-dispatched message stranded until the next
|
||||
receive() call one full poll cycle later.
|
||||
"""
|
||||
if not self._is_alive():
|
||||
self._connect()
|
||||
|
||||
timeout_seconds = timeout_millis / 1000.0
|
||||
deadline = time.monotonic() + timeout_seconds
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
# Check if a message was already delivered
|
||||
while True:
|
||||
# Check if a message has been dispatched to our queue.
|
||||
# This catches both (a) messages dispatched before this
|
||||
# receive() was called and (b) messages dispatched
|
||||
# during the previous iteration's process_data_events.
|
||||
try:
|
||||
method, properties, body = self._incoming.get_nowait()
|
||||
return RabbitMQMessage(
|
||||
|
|
@ -232,14 +259,16 @@ class RabbitMQBackendConsumer:
|
|||
except queue.Empty:
|
||||
pass
|
||||
|
||||
# Drive pika's I/O — delivers messages and processes heartbeats
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining > 0:
|
||||
self._connection.process_data_events(
|
||||
time_limit=min(0.1, remaining),
|
||||
)
|
||||
if remaining <= 0:
|
||||
raise TimeoutError("No message received within timeout")
|
||||
|
||||
raise TimeoutError("No message received within timeout")
|
||||
# Drive pika's I/O. Any messages delivered during this
|
||||
# call land in _incoming via _on_message; the next
|
||||
# iteration of this loop catches them at the top.
|
||||
self._connection.process_data_events(
|
||||
time_limit=min(0.1, remaining),
|
||||
)
|
||||
|
||||
def acknowledge(self, message: Message) -> None:
|
||||
if isinstance(message, RabbitMQMessage) and message._method:
|
||||
|
|
|
|||
|
|
@ -41,14 +41,55 @@ class Subscriber:
|
|||
self.consumer = None
|
||||
self.executor = None
|
||||
|
||||
# Readiness barrier — completed by run() once the underlying
|
||||
# backend consumer is fully connected and bound. start() awaits
|
||||
# this so callers know any subsequently published request will
|
||||
# have a queue ready to receive its response. Without this,
|
||||
# ephemeral per-subscriber response queues (RabbitMQ auto-delete
|
||||
# exclusive queues) would race the request and lose the reply.
|
||||
# A Future is used (rather than an Event) so that a first-attempt
|
||||
# connection failure can be propagated to start() as an exception.
|
||||
self._ready = None # created in start() so we have a running loop
|
||||
|
||||
def __del__(self):
|
||||
|
||||
self.running = False
|
||||
|
||||
async def start(self):
|
||||
|
||||
self._ready = asyncio.get_event_loop().create_future()
|
||||
self.task = asyncio.create_task(self.run())
|
||||
|
||||
# Block until run() signals readiness OR exits. The future
|
||||
# carries the outcome of the first connect attempt: a value on
|
||||
# success, an exception on first-attempt failure. If run() exits
|
||||
# without ever signalling (e.g. cancelled, or a code path bug),
|
||||
# we surface that as a clear RuntimeError rather than hanging
|
||||
# forever waiting on the future.
|
||||
ready_wait = asyncio.ensure_future(
|
||||
asyncio.shield(self._ready)
|
||||
)
|
||||
try:
|
||||
await asyncio.wait(
|
||||
{self.task, ready_wait},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
finally:
|
||||
ready_wait.cancel()
|
||||
|
||||
if self._ready.done():
|
||||
# Re-raise first-attempt connect failure if any.
|
||||
self._ready.result()
|
||||
return
|
||||
|
||||
# run() exited before _ready was settled. Propagate its exception
|
||||
# if it had one, otherwise raise a generic readiness error.
|
||||
if self.task.done() and self.task.exception() is not None:
|
||||
raise self.task.exception()
|
||||
raise RuntimeError(
|
||||
"Subscriber.run() exited before signalling readiness"
|
||||
)
|
||||
|
||||
async def stop(self):
|
||||
"""Initiate graceful shutdown with draining"""
|
||||
self.running = False
|
||||
|
|
@ -66,6 +107,7 @@ class Subscriber:
|
|||
|
||||
async def run(self):
|
||||
"""Enhanced run method with integrated draining logic"""
|
||||
first_attempt = True
|
||||
while self.running or self.draining:
|
||||
|
||||
if self.metrics:
|
||||
|
|
@ -87,10 +129,27 @@ class Subscriber:
|
|||
),
|
||||
)
|
||||
|
||||
# Eagerly bind the queue. For backends that connect
|
||||
# lazily on first receive (RabbitMQ), this is what
|
||||
# closes the request/response setup race — without
|
||||
# it the response queue is not bound until later and
|
||||
# any reply published in the meantime is dropped.
|
||||
await loop.run_in_executor(
|
||||
self.executor,
|
||||
lambda: self.consumer.ensure_connected(),
|
||||
)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.state("running")
|
||||
|
||||
logger.info("Subscriber running...")
|
||||
|
||||
# Signal start() that the consumer is ready. This must
|
||||
# happen AFTER ensure_connected() above so callers can
|
||||
# safely publish requests immediately after start() returns.
|
||||
if first_attempt and not self._ready.done():
|
||||
self._ready.set_result(None)
|
||||
first_attempt = False
|
||||
drain_end_time = None
|
||||
|
||||
while self.running or self.draining:
|
||||
|
|
@ -162,6 +221,16 @@ class Subscriber:
|
|||
except Exception as e:
|
||||
logger.error(f"Subscriber exception: {e}", exc_info=True)
|
||||
|
||||
# First-attempt connection failure: propagate to start()
|
||||
# so the caller can decide what to do (retry, give up).
|
||||
# Subsequent failures use the existing retry-with-backoff
|
||||
# path so a long-lived subscriber survives broker blips.
|
||||
if first_attempt and not self._ready.done():
|
||||
self._ready.set_exception(e)
|
||||
first_attempt = False
|
||||
# Falls through into finally for cleanup, then the
|
||||
# outer return below ends run() so start() unblocks.
|
||||
|
||||
finally:
|
||||
# Negative acknowledge any pending messages
|
||||
for msg in self.pending_acks.values():
|
||||
|
|
@ -193,6 +262,11 @@ class Subscriber:
|
|||
if not self.running and not self.draining:
|
||||
return
|
||||
|
||||
# If start() has already returned with an exception there is
|
||||
# nothing more to do — exit run() rather than busy-retry.
|
||||
if self._ready.done() and self._ready.exception() is not None:
|
||||
return
|
||||
|
||||
# Sleep before retry
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,47 +1,71 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from . request_response_spec import RequestResponse, RequestResponseSpec
|
||||
from .. schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
@dataclass
|
||||
class TextCompletionResult:
|
||||
text: Optional[str]
|
||||
in_token: Optional[int] = None
|
||||
out_token: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
class TextCompletionClient(RequestResponse):
|
||||
async def text_completion(self, system, prompt, streaming=False, timeout=600):
|
||||
# If not streaming, use original behavior
|
||||
if not streaming:
|
||||
resp = await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt, streaming = False
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
async def text_completion(self, system, prompt, timeout=600):
|
||||
|
||||
return resp.response
|
||||
|
||||
# For streaming: collect all chunks and return complete response
|
||||
full_response = ""
|
||||
|
||||
async def collect_chunks(resp):
|
||||
nonlocal full_response
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
if resp.response:
|
||||
full_response += resp.response
|
||||
|
||||
# Return True when end_of_stream is reached
|
||||
return getattr(resp, 'end_of_stream', False)
|
||||
|
||||
await self.request(
|
||||
resp = await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt, streaming = True
|
||||
system = system, prompt = prompt, streaming = False
|
||||
),
|
||||
recipient=collect_chunks,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return full_response
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
return TextCompletionResult(
|
||||
text = resp.response,
|
||||
in_token = resp.in_token,
|
||||
out_token = resp.out_token,
|
||||
model = resp.model,
|
||||
)
|
||||
|
||||
async def text_completion_stream(
|
||||
self, system, prompt, handler, timeout=600,
|
||||
):
|
||||
"""
|
||||
Streaming text completion. `handler` is an async callable invoked
|
||||
once per chunk with the chunk's TextCompletionResponse. Returns a
|
||||
TextCompletionResult with text=None and token counts / model taken
|
||||
from the end_of_stream message.
|
||||
"""
|
||||
|
||||
async def on_chunk(resp):
|
||||
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.message)
|
||||
|
||||
await handler(resp)
|
||||
|
||||
return getattr(resp, "end_of_stream", False)
|
||||
|
||||
final = await self.request(
|
||||
TextCompletionRequest(
|
||||
system = system, prompt = prompt, streaming = True
|
||||
),
|
||||
recipient=on_chunk,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return TextCompletionResult(
|
||||
text = None,
|
||||
in_token = final.in_token,
|
||||
out_token = final.out_token,
|
||||
model = final.model,
|
||||
)
|
||||
|
||||
class TextCompletionClientSpec(RequestResponseSpec):
|
||||
def __init__(
|
||||
|
|
@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec):
|
|||
response_schema = TextCompletionResponse,
|
||||
impl = TextCompletionClient,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -58,23 +58,23 @@ class AgentClient(BaseClient):
|
|||
|
||||
def inspect(x):
|
||||
# Handle errors
|
||||
if x.chunk_type == 'error' or x.error:
|
||||
if x.message_type == 'error' or x.error:
|
||||
if error_callback:
|
||||
error_callback(x.content or (x.error.message if x.error else ""))
|
||||
# Continue to check end_of_dialog
|
||||
|
||||
# Handle thought chunks
|
||||
elif x.chunk_type == 'thought':
|
||||
elif x.message_type == 'thought':
|
||||
if think:
|
||||
think(x.content, x.end_of_message)
|
||||
|
||||
# Handle observation chunks
|
||||
elif x.chunk_type == 'observation':
|
||||
elif x.message_type == 'observation':
|
||||
if observe:
|
||||
observe(x.content, x.end_of_message)
|
||||
|
||||
# Handle answer chunks
|
||||
elif x.chunk_type == 'answer':
|
||||
elif x.message_type == 'answer':
|
||||
if x.content:
|
||||
accumulated_answer.append(x.content)
|
||||
if answer_callback:
|
||||
|
|
|
|||
|
|
@ -60,8 +60,8 @@ class AgentResponseTranslator(MessageTranslator):
|
|||
def encode(self, obj: AgentResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
if obj.chunk_type:
|
||||
result["chunk_type"] = obj.chunk_type
|
||||
if obj.message_type:
|
||||
result["message_type"] = obj.message_type
|
||||
if obj.content:
|
||||
result["content"] = obj.content
|
||||
result["end_of_message"] = getattr(obj, "end_of_message", False)
|
||||
|
|
@ -90,6 +90,13 @@ class AgentResponseTranslator(MessageTranslator):
|
|||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "code": obj.error.code}
|
||||
|
||||
if obj.in_token is not None:
|
||||
result["in_token"] = obj.in_token
|
||||
if obj.out_token is not None:
|
||||
result["out_token"] = obj.out_token
|
||||
if obj.model is not None:
|
||||
result["model"] = obj.model
|
||||
|
||||
return result
|
||||
|
||||
def encode_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ class DocumentEmbeddingsTranslator(SendTranslator):
|
|||
chunks = [
|
||||
ChunkEmbeddings(
|
||||
chunk_id=chunk["chunk_id"],
|
||||
vectors=chunk["vectors"]
|
||||
vector=chunk["vector"]
|
||||
)
|
||||
for chunk in data.get("chunks", [])
|
||||
]
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class KnowledgeRequestTranslator(MessageTranslator):
|
|||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=self.value_translator.decode(ent["entity"]),
|
||||
vectors=ent["vectors"],
|
||||
vector=ent["vector"],
|
||||
)
|
||||
for ent in data["graph-embeddings"]["entities"]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -53,6 +53,13 @@ class PromptResponseTranslator(MessageTranslator):
|
|||
# Always include end_of_stream flag for streaming support
|
||||
result["end_of_stream"] = getattr(obj, "end_of_stream", False)
|
||||
|
||||
if obj.in_token is not None:
|
||||
result["in_token"] = obj.in_token
|
||||
if obj.out_token is not None:
|
||||
result["out_token"] = obj.out_token
|
||||
if obj.model is not None:
|
||||
result["model"] = obj.model
|
||||
|
||||
return result
|
||||
|
||||
def encode_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
|
|||
|
|
@ -74,6 +74,13 @@ class DocumentRagResponseTranslator(MessageTranslator):
|
|||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "type": obj.error.type}
|
||||
|
||||
if obj.in_token is not None:
|
||||
result["in_token"] = obj.in_token
|
||||
if obj.out_token is not None:
|
||||
result["out_token"] = obj.out_token
|
||||
if obj.model is not None:
|
||||
result["model"] = obj.model
|
||||
|
||||
return result
|
||||
|
||||
def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
@ -163,6 +170,13 @@ class GraphRagResponseTranslator(MessageTranslator):
|
|||
if hasattr(obj, 'error') and obj.error and obj.error.message:
|
||||
result["error"] = {"message": obj.error.message, "type": obj.error.type}
|
||||
|
||||
if obj.in_token is not None:
|
||||
result["in_token"] = obj.in_token
|
||||
if obj.out_token is not None:
|
||||
result["out_token"] = obj.out_token
|
||||
if obj.model is not None:
|
||||
result["model"] = obj.model
|
||||
|
||||
return result
|
||||
|
||||
def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
|
|
|
|||
|
|
@ -29,11 +29,11 @@ class TextCompletionResponseTranslator(MessageTranslator):
|
|||
def encode(self, obj: TextCompletionResponse) -> Dict[str, Any]:
|
||||
result = {"response": obj.response}
|
||||
|
||||
if obj.in_token:
|
||||
if obj.in_token is not None:
|
||||
result["in_token"] = obj.in_token
|
||||
if obj.out_token:
|
||||
if obj.out_token is not None:
|
||||
result["out_token"] = obj.out_token
|
||||
if obj.model:
|
||||
if obj.model is not None:
|
||||
result["model"] = obj.model
|
||||
|
||||
# Always include end_of_stream flag for streaming support
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ from . uris import (
|
|||
agent_plan_uri,
|
||||
agent_step_result_uri,
|
||||
agent_synthesis_uri,
|
||||
agent_pattern_decision_uri,
|
||||
# Document RAG provenance URIs
|
||||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
|
|
@ -102,6 +103,11 @@ from . namespaces import (
|
|||
# Agent provenance predicates
|
||||
TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION,
|
||||
TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
TG_TOOL_CANDIDATE, TG_TERMINATION_REASON,
|
||||
TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE,
|
||||
TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR,
|
||||
TG_IN_TOKEN, TG_OUT_TOKEN,
|
||||
TG_ERROR_TYPE,
|
||||
# Orchestrator entity types
|
||||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
# Document reference predicate
|
||||
|
|
@ -141,6 +147,7 @@ from . agent import (
|
|||
agent_plan_triples,
|
||||
agent_step_result_triples,
|
||||
agent_synthesis_triples,
|
||||
agent_pattern_decision_triples,
|
||||
)
|
||||
|
||||
# Vocabulary bootstrap
|
||||
|
|
@ -182,6 +189,7 @@ __all__ = [
|
|||
"agent_plan_uri",
|
||||
"agent_step_result_uri",
|
||||
"agent_synthesis_uri",
|
||||
"agent_pattern_decision_uri",
|
||||
# Document RAG provenance URIs
|
||||
"docrag_question_uri",
|
||||
"docrag_grounding_uri",
|
||||
|
|
@ -218,6 +226,11 @@ __all__ = [
|
|||
# Agent provenance predicates
|
||||
"TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION",
|
||||
"TG_SUBAGENT_GOAL", "TG_PLAN_STEP",
|
||||
"TG_TOOL_CANDIDATE", "TG_TERMINATION_REASON",
|
||||
"TG_STEP_NUMBER", "TG_PATTERN_DECISION", "TG_PATTERN", "TG_TASK_TYPE",
|
||||
"TG_LLM_DURATION_MS", "TG_TOOL_DURATION_MS", "TG_TOOL_ERROR",
|
||||
"TG_IN_TOKEN", "TG_OUT_TOKEN",
|
||||
"TG_ERROR_TYPE",
|
||||
# Orchestrator entity types
|
||||
"TG_DECOMPOSITION", "TG_FINDING", "TG_PLAN_TYPE", "TG_STEP_RESULT",
|
||||
# Document reference predicate
|
||||
|
|
@ -249,6 +262,7 @@ __all__ = [
|
|||
"agent_plan_triples",
|
||||
"agent_step_result_triples",
|
||||
"agent_synthesis_triples",
|
||||
"agent_pattern_decision_triples",
|
||||
# Utility
|
||||
"set_graph",
|
||||
# Vocabulary
|
||||
|
|
|
|||
|
|
@ -29,6 +29,11 @@ from . namespaces import (
|
|||
TG_AGENT_QUESTION,
|
||||
TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT,
|
||||
TG_SYNTHESIS, TG_SUBAGENT_GOAL, TG_PLAN_STEP,
|
||||
TG_TOOL_CANDIDATE, TG_TERMINATION_REASON,
|
||||
TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE,
|
||||
TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR,
|
||||
TG_ERROR_TYPE,
|
||||
TG_IN_TOKEN, TG_OUT_TOKEN, TG_LLM_MODEL,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -47,6 +52,17 @@ def _triple(s: str, p: str, o_term: Term) -> Triple:
|
|||
return Triple(s=_iri(s), p=_iri(p), o=o_term)
|
||||
|
||||
|
||||
def _append_token_triples(triples, uri, in_token=None, out_token=None,
|
||||
model=None):
|
||||
"""Append in_token/out_token/model triples when values are present."""
|
||||
if in_token is not None:
|
||||
triples.append(_triple(uri, TG_IN_TOKEN, _literal(str(in_token))))
|
||||
if out_token is not None:
|
||||
triples.append(_triple(uri, TG_OUT_TOKEN, _literal(str(out_token))))
|
||||
if model is not None:
|
||||
triples.append(_triple(uri, TG_LLM_MODEL, _literal(model)))
|
||||
|
||||
|
||||
def agent_session_triples(
|
||||
session_uri: str,
|
||||
query: str,
|
||||
|
|
@ -90,6 +106,43 @@ def agent_session_triples(
|
|||
return triples
|
||||
|
||||
|
||||
def agent_pattern_decision_triples(
|
||||
uri: str,
|
||||
session_uri: str,
|
||||
pattern: str,
|
||||
task_type: str = "",
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for a meta-router pattern decision.
|
||||
|
||||
Creates:
|
||||
- Entity declaration with tg:PatternDecision type
|
||||
- wasDerivedFrom link to session
|
||||
- Pattern and task type predicates
|
||||
|
||||
Args:
|
||||
uri: URI of this decision (from agent_pattern_decision_uri)
|
||||
session_uri: URI of the parent session
|
||||
pattern: Selected execution pattern (e.g. "react", "plan-then-execute")
|
||||
task_type: Identified task type (e.g. "general", "research")
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
"""
|
||||
triples = [
|
||||
_triple(uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
_triple(uri, RDF_TYPE, _iri(TG_PATTERN_DECISION)),
|
||||
_triple(uri, RDFS_LABEL, _literal(f"Pattern: {pattern}")),
|
||||
_triple(uri, TG_PATTERN, _literal(pattern)),
|
||||
_triple(uri, PROV_WAS_DERIVED_FROM, _iri(session_uri)),
|
||||
]
|
||||
|
||||
if task_type:
|
||||
triples.append(_triple(uri, TG_TASK_TYPE, _literal(task_type)))
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
def agent_iteration_triples(
|
||||
iteration_uri: str,
|
||||
question_uri: Optional[str] = None,
|
||||
|
|
@ -98,6 +151,12 @@ def agent_iteration_triples(
|
|||
arguments: Dict[str, Any] = None,
|
||||
thought_uri: Optional[str] = None,
|
||||
thought_document_id: Optional[str] = None,
|
||||
tool_candidates: Optional[List[str]] = None,
|
||||
step_number: Optional[int] = None,
|
||||
llm_duration_ms: Optional[int] = None,
|
||||
in_token: Optional[int] = None,
|
||||
out_token: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for one agent iteration (Analysis+ToolUse).
|
||||
|
|
@ -106,6 +165,7 @@ def agent_iteration_triples(
|
|||
- Entity declaration with tg:Analysis and tg:ToolUse types
|
||||
- wasDerivedFrom link to question (if first iteration) or previous
|
||||
- Action and arguments metadata
|
||||
- Tool candidates (names of tools visible to the LLM)
|
||||
- Thought sub-entity (tg:Reflection, tg:Thought) with librarian document
|
||||
|
||||
Args:
|
||||
|
|
@ -116,6 +176,7 @@ def agent_iteration_triples(
|
|||
arguments: Arguments passed to the tool (will be JSON-encoded)
|
||||
thought_uri: URI for the thought sub-entity
|
||||
thought_document_id: Document URI for thought in librarian
|
||||
tool_candidates: List of tool names available to the LLM
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
|
|
@ -132,6 +193,23 @@ def agent_iteration_triples(
|
|||
_triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))),
|
||||
]
|
||||
|
||||
if tool_candidates:
|
||||
for name in tool_candidates:
|
||||
triples.append(
|
||||
_triple(iteration_uri, TG_TOOL_CANDIDATE, _literal(name))
|
||||
)
|
||||
|
||||
if step_number is not None:
|
||||
triples.append(
|
||||
_triple(iteration_uri, TG_STEP_NUMBER, _literal(str(step_number)))
|
||||
)
|
||||
|
||||
if llm_duration_ms is not None:
|
||||
triples.append(
|
||||
_triple(iteration_uri, TG_LLM_DURATION_MS,
|
||||
_literal(str(llm_duration_ms)))
|
||||
)
|
||||
|
||||
if question_uri:
|
||||
triples.append(
|
||||
_triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(question_uri))
|
||||
|
|
@ -155,6 +233,8 @@ def agent_iteration_triples(
|
|||
_triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id))
|
||||
)
|
||||
|
||||
_append_token_triples(triples, iteration_uri, in_token, out_token, model)
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
|
|
@ -162,6 +242,8 @@ def agent_observation_triples(
|
|||
observation_uri: str,
|
||||
iteration_uri: str,
|
||||
document_id: Optional[str] = None,
|
||||
tool_duration_ms: Optional[int] = None,
|
||||
tool_error: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for an agent observation (standalone entity).
|
||||
|
|
@ -170,11 +252,15 @@ def agent_observation_triples(
|
|||
- Entity declaration with prov:Entity and tg:Observation types
|
||||
- wasDerivedFrom link to the iteration (Analysis+ToolUse)
|
||||
- Document reference to librarian (if provided)
|
||||
- Tool execution duration (if provided)
|
||||
- Tool error message (if the tool failed)
|
||||
|
||||
Args:
|
||||
observation_uri: URI of the observation entity
|
||||
iteration_uri: URI of the iteration this observation derives from
|
||||
document_id: Librarian document ID for the observation content
|
||||
tool_duration_ms: Tool execution time in milliseconds
|
||||
tool_error: Error message if the tool failed
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
|
|
@ -191,6 +277,20 @@ def agent_observation_triples(
|
|||
_triple(observation_uri, TG_DOCUMENT, _iri(document_id))
|
||||
)
|
||||
|
||||
if tool_duration_ms is not None:
|
||||
triples.append(
|
||||
_triple(observation_uri, TG_TOOL_DURATION_MS,
|
||||
_literal(str(tool_duration_ms)))
|
||||
)
|
||||
|
||||
if tool_error:
|
||||
triples.append(
|
||||
_triple(observation_uri, TG_TOOL_ERROR, _literal(tool_error))
|
||||
)
|
||||
triples.append(
|
||||
_triple(observation_uri, RDF_TYPE, _iri(TG_ERROR_TYPE))
|
||||
)
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
|
|
@ -199,6 +299,10 @@ def agent_final_triples(
|
|||
question_uri: Optional[str] = None,
|
||||
previous_uri: Optional[str] = None,
|
||||
document_id: Optional[str] = None,
|
||||
termination_reason: Optional[str] = None,
|
||||
in_token: Optional[int] = None,
|
||||
out_token: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""
|
||||
Build triples for an agent final answer (Conclusion).
|
||||
|
|
@ -208,12 +312,15 @@ def agent_final_triples(
|
|||
- wasGeneratedBy link to question (if no iterations)
|
||||
- wasDerivedFrom link to last iteration (if iterations exist)
|
||||
- Document reference to librarian
|
||||
- Termination reason (why the agent loop stopped)
|
||||
|
||||
Args:
|
||||
final_uri: URI of the final answer (from agent_final_uri)
|
||||
question_uri: URI of the question activity (if no iterations)
|
||||
previous_uri: URI of the last iteration (if iterations exist)
|
||||
document_id: Librarian document ID for the answer content
|
||||
termination_reason: Why the loop stopped, e.g. "final-answer",
|
||||
"max-iterations", "error"
|
||||
|
||||
Returns:
|
||||
List of Triple objects
|
||||
|
|
@ -237,6 +344,14 @@ def agent_final_triples(
|
|||
if document_id:
|
||||
triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id)))
|
||||
|
||||
if termination_reason:
|
||||
triples.append(
|
||||
_triple(final_uri, TG_TERMINATION_REASON,
|
||||
_literal(termination_reason))
|
||||
)
|
||||
|
||||
_append_token_triples(triples, final_uri, in_token, out_token, model)
|
||||
|
||||
return triples
|
||||
|
||||
|
||||
|
|
@ -244,6 +359,9 @@ def agent_decomposition_triples(
|
|||
uri: str,
|
||||
session_uri: str,
|
||||
goals: List[str],
|
||||
in_token: Optional[int] = None,
|
||||
out_token: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""Build triples for a supervisor decomposition step."""
|
||||
triples = [
|
||||
|
|
@ -255,6 +373,7 @@ def agent_decomposition_triples(
|
|||
]
|
||||
for goal in goals:
|
||||
triples.append(_triple(uri, TG_SUBAGENT_GOAL, _literal(goal)))
|
||||
_append_token_triples(triples, uri, in_token, out_token, model)
|
||||
return triples
|
||||
|
||||
|
||||
|
|
@ -282,6 +401,9 @@ def agent_plan_triples(
|
|||
uri: str,
|
||||
session_uri: str,
|
||||
steps: List[str],
|
||||
in_token: Optional[int] = None,
|
||||
out_token: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""Build triples for a plan-then-execute plan."""
|
||||
triples = [
|
||||
|
|
@ -293,6 +415,7 @@ def agent_plan_triples(
|
|||
]
|
||||
for step in steps:
|
||||
triples.append(_triple(uri, TG_PLAN_STEP, _literal(step)))
|
||||
_append_token_triples(triples, uri, in_token, out_token, model)
|
||||
return triples
|
||||
|
||||
|
||||
|
|
@ -301,6 +424,9 @@ def agent_step_result_triples(
|
|||
plan_uri: str,
|
||||
goal: str,
|
||||
document_id: Optional[str] = None,
|
||||
in_token: Optional[int] = None,
|
||||
out_token: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""Build triples for a plan step result."""
|
||||
triples = [
|
||||
|
|
@ -313,6 +439,7 @@ def agent_step_result_triples(
|
|||
]
|
||||
if document_id:
|
||||
triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id)))
|
||||
_append_token_triples(triples, uri, in_token, out_token, model)
|
||||
return triples
|
||||
|
||||
|
||||
|
|
@ -320,6 +447,10 @@ def agent_synthesis_triples(
|
|||
uri: str,
|
||||
previous_uris,
|
||||
document_id: Optional[str] = None,
|
||||
termination_reason: Optional[str] = None,
|
||||
in_token: Optional[int] = None,
|
||||
out_token: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> List[Triple]:
|
||||
"""Build triples for a synthesis answer.
|
||||
|
||||
|
|
@ -327,6 +458,8 @@ def agent_synthesis_triples(
|
|||
uri: URI of the synthesis entity
|
||||
previous_uris: Single URI string or list of URIs to derive from
|
||||
document_id: Librarian document ID for the answer content
|
||||
termination_reason: Why the agent loop stopped
|
||||
in_token/out_token/model: Token usage for the synthesis LLM call
|
||||
"""
|
||||
triples = [
|
||||
_triple(uri, RDF_TYPE, _iri(PROV_ENTITY)),
|
||||
|
|
@ -342,4 +475,12 @@ def agent_synthesis_triples(
|
|||
|
||||
if document_id:
|
||||
triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id)))
|
||||
|
||||
if termination_reason:
|
||||
triples.append(
|
||||
_triple(uri, TG_TERMINATION_REASON, _literal(termination_reason))
|
||||
)
|
||||
|
||||
_append_token_triples(triples, uri, in_token, out_token, model)
|
||||
|
||||
return triples
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue