diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 2243da10..b1ae8611 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -22,7 +22,7 @@ jobs: uses: actions/checkout@v3 - name: Setup packages - run: make update-package-versions VERSION=2.2.999 + run: make update-package-versions VERSION=2.3.999 - name: Setup environment run: python3 -m venv env diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index dc2fc89b..07af8db9 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -40,10 +40,9 @@ jobs: - name: Publish release distributions to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - deploy-container-image: + build-platform-image: - name: Release container images - runs-on: ubuntu-24.04 + name: Build ${{ matrix.container }} (${{ matrix.platform }}) permissions: contents: write id-token: write @@ -52,14 +51,24 @@ jobs: strategy: matrix: container: - - trustgraph-base - - trustgraph-flow - - trustgraph-bedrock - - trustgraph-vertexai - - trustgraph-hf - - trustgraph-ocr - - trustgraph-unstructured - - trustgraph-mcp + - base + - flow + - bedrock + - vertexai + - hf + - ocr + - unstructured + - mcp + platform: + - amd64 + - arm64 + include: + - platform: amd64 + runner: ubuntu-24.04 + - platform: arm64 + runner: ubuntu-24.04-arm + + runs-on: ${{ matrix.runner }} steps: @@ -76,12 +85,48 @@ jobs: id: version run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT - - name: Put version into package manifests - run: make update-package-versions VERSION=${{ steps.version.outputs.VERSION }} + - name: Build container + run: make platform-${{ matrix.container }}-${{ matrix.platform }} VERSION=${{ steps.version.outputs.VERSION }} - - name: Build container - ${{ matrix.container }} - run: make container-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }} + - name: Push container + run: make push-platform-${{ matrix.container }}-${{ matrix.platform }} VERSION=${{ steps.version.outputs.VERSION }} - - name: Push container - ${{ matrix.container }} - run: make push-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }} + combine-manifests: + name: Combine manifest ${{ matrix.container }} + runs-on: ubuntu-24.04 + needs: build-platform-image + permissions: + contents: write + id-token: write + environment: + name: release + strategy: + matrix: + container: + - base + - flow + - bedrock + - vertexai + - hf + - ocr + - unstructured + - mcp + + steps: + + - name: Checkout + uses: actions/checkout@v4 + + - name: Docker Hub token + run: echo ${{ secrets.DOCKER_SECRET }} > docker-token.txt + + - name: Authenticate with Docker hub + run: make docker-hub-login + + - name: Get version + id: version + run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT + + - name: Combine and push manifest + run: make combine-manifest-${{ matrix.container }} VERSION=${{ steps.version.outputs.VERSION }} diff --git a/Makefile b/Makefile index 197a6c63..85f10fdd 100644 --- a/Makefile +++ b/Makefile @@ -52,51 +52,12 @@ update-package-versions: echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py echo __version__ = \"${VERSION}\" > trustgraph-mcp/trustgraph/mcp_version.py -FORCE: +containers: container-base container-flow \ +container-bedrock container-vertexai \ +container-hf container-ocr \ +container-unstructured container-mcp -containers: FORCE - ${DOCKER} build -f 
containers/Containerfile.base \ - -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . - ${DOCKER} build -f containers/Containerfile.flow \ - -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . - ${DOCKER} build -f containers/Containerfile.bedrock \ - -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} . - ${DOCKER} build -f containers/Containerfile.vertexai \ - -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . - ${DOCKER} build -f containers/Containerfile.hf \ - -t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} . - ${DOCKER} build -f containers/Containerfile.ocr \ - -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} . - ${DOCKER} build -f containers/Containerfile.unstructured \ - -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . - ${DOCKER} build -f containers/Containerfile.mcp \ - -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} . - -some-containers: - ${DOCKER} build -f containers/Containerfile.base \ - -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . - ${DOCKER} build -f containers/Containerfile.flow \ - -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . -# ${DOCKER} build -f containers/Containerfile.unstructured \ -# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . -# ${DOCKER} build -f containers/Containerfile.vertexai \ -# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . -# ${DOCKER} build -f containers/Containerfile.mcp \ -# -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} . -# ${DOCKER} build -f containers/Containerfile.vertexai \ -# -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . -# ${DOCKER} build -f containers/Containerfile.bedrock \ -# -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} . - -basic-containers: update-package-versions - ${DOCKER} build -f containers/Containerfile.base \ - -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . - ${DOCKER} build -f containers/Containerfile.flow \ - -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . - -container.ocr: - ${DOCKER} build -f containers/Containerfile.ocr \ - -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} . +some-containers: container-base container-flow push: ${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION} @@ -109,54 +70,60 @@ push: ${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} # Individual container build targets -container-trustgraph-base: update-package-versions - ${DOCKER} build -f containers/Containerfile.base -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . +container-%: update-package-versions + ${DOCKER} build \ + -f containers/Containerfile.${@:container-%=%} \ + -t ${CONTAINER_BASE}/trustgraph-${@:container-%=%}:${VERSION} . -container-trustgraph-flow: update-package-versions - ${DOCKER} build -f containers/Containerfile.flow -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . +# Multi-arch: build both platforms sequentially into one manifest (local use) +manifest-%: update-package-versions + -@${DOCKER} manifest rm \ + ${CONTAINER_BASE}/trustgraph-${@:manifest-%=%}:${VERSION} + ${DOCKER} build --platform linux/amd64,linux/arm64 \ + -f containers/Containerfile.${@:manifest-%=%} \ + --manifest \ + ${CONTAINER_BASE}/trustgraph-${@:manifest-%=%}:${VERSION} . -container-trustgraph-bedrock: update-package-versions - ${DOCKER} build -f containers/Containerfile.bedrock -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} . 
+# Multi-arch: build a single platform image (for parallel CI) +platform-%-amd64: update-package-versions + ${DOCKER} build --platform linux/amd64 \ + -f containers/Containerfile.${@:platform-%-amd64=%} \ + -t ${CONTAINER_BASE}/trustgraph-${@:platform-%-amd64=%}:${VERSION}-amd64 . -container-trustgraph-vertexai: update-package-versions - ${DOCKER} build -f containers/Containerfile.vertexai -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . +platform-%-arm64: update-package-versions + ${DOCKER} build --platform linux/arm64 \ + -f containers/Containerfile.${@:platform-%-arm64=%} \ + -t ${CONTAINER_BASE}/trustgraph-${@:platform-%-arm64=%}:${VERSION}-arm64 . -container-trustgraph-hf: update-package-versions - ${DOCKER} build -f containers/Containerfile.hf -t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} . +# Push a single platform image +push-platform-%-amd64: + ${DOCKER} push \ + ${CONTAINER_BASE}/trustgraph-${@:push-platform-%-amd64=%}:${VERSION}-amd64 -container-trustgraph-ocr: update-package-versions - ${DOCKER} build -f containers/Containerfile.ocr -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} . +push-platform-%-arm64: + ${DOCKER} push \ + ${CONTAINER_BASE}/trustgraph-${@:push-platform-%-arm64=%}:${VERSION}-arm64 -container-trustgraph-unstructured: update-package-versions - ${DOCKER} build -f containers/Containerfile.unstructured -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . +# Combine per-platform images into a multi-arch manifest +combine-manifest-%: + -@${DOCKER} manifest rm \ + ${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION} + ${DOCKER} manifest create \ + ${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION} \ + docker://${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}-amd64 \ + docker://${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION}-arm64 + ${DOCKER} manifest push \ + ${CONTAINER_BASE}/trustgraph-${@:combine-manifest-%=%}:${VERSION} -container-trustgraph-mcp: update-package-versions - ${DOCKER} build -f containers/Containerfile.mcp -t ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} . 
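+# Illustrative expansion of the pattern rules above: for a target such as
+# container-flow, $@ is "container-flow" and the substitution reference
+# ${@:container-%=%} strips the prefix to leave "flow", so the recipe
+# runs roughly:
+#
+#   ${DOCKER} build \
+#       -f containers/Containerfile.flow \
+#       -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
+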
+# Push a container +push-container-%: + ${DOCKER} push \ + ${CONTAINER_BASE}/trustgraph-${@:push-container-%=%}:${VERSION} -# Individual container push targets -push-trustgraph-base: - ${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION} - -push-trustgraph-flow: - ${DOCKER} push ${CONTAINER_BASE}/trustgraph-flow:${VERSION} - -push-trustgraph-bedrock: - ${DOCKER} push ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} - -push-trustgraph-vertexai: - ${DOCKER} push ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} - -push-trustgraph-hf: - ${DOCKER} push ${CONTAINER_BASE}/trustgraph-hf:${VERSION} - -push-trustgraph-ocr: - ${DOCKER} push ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} - -push-trustgraph-unstructured: - ${DOCKER} push ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} - -push-trustgraph-mcp: - ${DOCKER} push ${CONTAINER_BASE}/trustgraph-mcp:${VERSION} +# Push a manifest (from local multi-arch build) +push-manifest-%: + ${DOCKER} manifest push \ + ${CONTAINER_BASE}/trustgraph-${@:push-manifest-%=%}:${VERSION} clean: rm -rf wheels/ @@ -164,52 +131,6 @@ clean: set-version: echo '"${VERSION}"' > templates/values/version.jsonnet -TEMPLATES=azure bedrock claude cohere mix llamafile mistral ollama openai vertexai \ - openai-neo4j storage - -DCS=$(foreach template,${TEMPLATES},${template:%=tg-launch-%.yaml}) - -MODELS=azure bedrock claude cohere llamafile mistral ollama openai vertexai -GRAPHS=cassandra neo4j falkordb memgraph - -# tg-launch-%.yaml: templates/%.jsonnet templates/components/version.jsonnet -# jsonnet -Jtemplates \ -# -S ${@:tg-launch-%.yaml=templates/%.jsonnet} > $@ - -# VECTORDB=milvus -VECTORDB=qdrant - -JSONNET_FLAGS=-J templates -J . - -# Temporarily going back to how templates were built in 0.9 because this -# is going away in 0.11. - -update-templates: update-dcs - -JSON_TO_YAML=python -c 'import sys, yaml, json; j=json.loads(sys.stdin.read()); print(yaml.safe_dump(j))' - -update-dcs: set-version - for graph in ${GRAPHS}; do \ - cm=$${graph},pulsar,${VECTORDB},grafana; \ - input=templates/opts-to-docker-compose.jsonnet; \ - output=tg-storage-$${graph}.yaml; \ - echo $${graph} '->' $${output}; \ - jsonnet ${JSONNET_FLAGS} \ - --ext-str options=$${cm} $${input} | \ - ${JSON_TO_YAML} > $${output}; \ - done - for model in ${MODELS}; do \ - for graph in ${GRAPHS}; do \ - cm=$${graph},pulsar,${VECTORDB},embeddings-hf,graph-rag,grafana,trustgraph,$${model}; \ - input=templates/opts-to-docker-compose.jsonnet; \ - output=tg-launch-$${model}-$${graph}.yaml; \ - echo $${model} + $${graph} '->' $${output}; \ - jsonnet ${JSONNET_FLAGS} \ - --ext-str options=$${cm} $${input} | \ - ${JSON_TO_YAML} > $${output}; \ - done; \ - done - docker-hub-login: cat docker-token.txt | \ ${DOCKER} login -u trustgraph --password-stdin registry-1.docker.io diff --git a/containers/Containerfile.hf b/containers/Containerfile.hf index 351300ae..a1ec5346 100644 --- a/containers/Containerfile.hf +++ b/containers/Containerfile.hf @@ -1,22 +1,22 @@ -# ---------------------------------------------------------------------------- -# Build an AI container. This does the torch install which is huge, and I -# like to avoid re-doing this. 
-# ---------------------------------------------------------------------------- +# Torch is stable and compiles for ARM64 and AMD64 on Python 3.12 FROM docker.io/fedora:42 AS ai ENV PIP_BREAK_SYSTEM_PACKAGES=1 -RUN dnf install -y python3.13 && \ - alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ +RUN dnf install -y python3.12 && \ + alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir pulsar-client==3.7.0 && \ dnf clean all -RUN pip3 install torch==2.5.1+cpu \ - --index-url https://download.pytorch.org/whl/cpu +# This won't work on ARM +#RUN pip3 install torch==2.5.1+cpu \ +# --index-url https://download.pytorch.org/whl/cpu + +RUN pip3 install torch RUN pip3 install --no-cache-dir \ langchain==0.3.25 langchain-core==0.3.60 langchain-huggingface==0.2.0 \ diff --git a/dev-tools/proc-group/README.md b/dev-tools/proc-group/README.md new file mode 100644 index 00000000..1874ea36 --- /dev/null +++ b/dev-tools/proc-group/README.md @@ -0,0 +1,117 @@ +# proc-group — run TrustGraph as a single process + +A dev-focused alternative to the per-container deployment. Instead of 30+ +containers each running a single processor, `processor-group` runs all the +processors as asyncio tasks inside one Python process, sharing the event +loop, Prometheus registry, and (importantly) resources on your laptop. + +This is **not** for production. Scale deployments should keep using +per-processor containers — one failure bringing down the whole process, +no horizontal scaling, and a single giant log are fine for dev and a +bad idea in prod. + +## What this directory contains + +- `group.yaml` — the group runner config. One entry per processor, each + with the dotted class path and a params dict. Defaults (pubsub backend, + rabbitmq host, log level) are pulled in per-entry with a YAML anchor. +- `README.md` — this file. + +## Prerequisites + +Install the TrustGraph packages into a venv: + +``` +pip install trustgraph-base trustgraph-flow trustgraph-unstructured +``` + +`trustgraph-base` provides the `processor-group` endpoint. The others +provide the processor classes that `group.yaml` imports at runtime. +`trustgraph-unstructured` is only needed if you want `document-decoder` +(the `universal-decoder` processor). + +## Running it + +Start infrastructure (cassandra, qdrant, rabbitmq, garage, observability +stack) with a working compose file. These aren't packable into the group - +they're third-party services. You may be able to run these as standalone +services. + +To get Cassandra to be accessible from the host, you need to +set a couple of environment variables: +``` + CASSANDRA_BROADCAST_ADDRESS: 127.0.0.1 + CASSANDRA_LISTEN_ADDRESS: 127.0.0.1 +``` +and also set `network: host`. Then start services: + +``` +podman-compose up -d cassandra qdrant rabbitmq +podman-compose up -d garage garage-init +podman-compose up -d loki prometheus grafana +podman-compose up -d init-trustgraph +``` + +`init-trustgraph` is a one-shot that seeds config and the default flow +into cassandra/rabbitmq. Don't leave too long a delay between starting +`init-trustgraph` and running the processor-group, because it needs to +talk to the config service. 
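+
+For reference, the Cassandra host-access settings above can be collected
+into a compose override file (a sketch; the service name `cassandra` and
+the `network_mode` key are assumptions about your compose file):
+
+```
+# docker-compose.override.yaml (illustrative)
+services:
+  cassandra:
+    network_mode: host
+    environment:
+      CASSANDRA_BROADCAST_ADDRESS: 127.0.0.1
+      CASSANDRA_LISTEN_ADDRESS: 127.0.0.1
+```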
+
+Run the api-gateway separately — it's an aiohttp HTTP server, not an
+`AsyncProcessor`, so the group runner doesn't host it. The command is
+shown below, once the group is running.
+
+Raise the file descriptor limit — 30+ processors sharing one process
+open far more sockets than the default 1024 allows:
+
+```
+ulimit -n 65536
+```
+
+Then start the group from a terminal:
+
+```
+processor-group -c group.yaml --no-loki-enabled
+```
+
+You'll see every processor's startup messages interleaved in one log.
+Each processor has a supervisor that restarts it independently on
+failure, so a transient crash (or a dependency that isn't ready yet)
+only affects that one processor — siblings keep running and the failing
+one self-heals on the next retry.
+
+Finally, when everything is running, start the API gateway from its
+own terminal:
+
+```
+api-gateway \
+    --pubsub-backend rabbitmq --rabbitmq-host localhost \
+    --loki-url http://localhost:3100/loki/api/v1/push \
+    --no-metrics
+```
+
+## When things go wrong
+
+- **"Too many open files"** — raise `ulimit -n` further. 65536 is
+  usually plenty but some workflows need more.
+- **One processor failing repeatedly** — look for its id in the log. The
+  supervisor will log each failure before restarting. Fix the cause
+  (missing env var, unreachable dependency, bad params) and the
+  processor self-heals on the next 4-second retry without restarting
+  the whole group.
+- **Ctrl-C leaves the process hung** — the pika and cassandra drivers
+  spawn non-cooperative threads that asyncio can't cancel. Use Ctrl-\
+  (SIGQUIT) to force-kill. Not a bug in the group runner, just a
+  limitation of those libraries.
+
+## Environment variables
+
+Processors that talk to external LLMs or APIs read their credentials
+from env vars, same as in the per-container deployment:
+
+- `OPENAI_TOKEN`, `OPENAI_BASE_URL` — for `text-completion` /
+  `text-completion-rag`
+
+Export whatever your particular `group.yaml` needs before running.
+
diff --git a/dev-tools/proc-group/group.yaml b/dev-tools/proc-group/group.yaml
new file mode 100644
index 00000000..98ef5016
--- /dev/null
+++ b/dev-tools/proc-group/group.yaml
@@ -0,0 +1,257 @@
+# Multi-processor group config, derived from docker-compose.yaml.
+#
+# Covers every AsyncProcessor-based service from the compose file.
+# Out of scope: +# - api-gateway (aiohttp, not AsyncProcessor) +# - init-trustgraph (one-shot init, not a processor) +# - document-decoder (universal-decoder, trustgraph-unstructured package — +# packable but lives in a separate image/package) +# - mcp-server (trustgraph-mcp package, separate image) +# - ddg-mcp-server (third-party image) +# - infrastructure (cassandra, rabbitmq, qdrant, garage, grafana, +# prometheus, loki, workbench-ui) +# +# Run with: +# processor-group -c group.yaml + +_defaults: &defaults + pubsub_backend: rabbitmq + rabbitmq_host: localhost + log_level: INFO + +processors: + + - class: trustgraph.agent.orchestrator.Processor + params: + <<: *defaults + id: agent-manager + + - class: trustgraph.chunking.recursive.Processor + params: + <<: *defaults + id: chunker + chunk_size: 2000 + chunk_overlap: 50 + + - class: trustgraph.config.service.Processor + params: + <<: *defaults + id: config-svc + cassandra_host: localhost + + - class: trustgraph.decoding.universal.Processor + params: + <<: *defaults + id: document-decoder + + - class: trustgraph.embeddings.document_embeddings.Processor + params: + <<: *defaults + id: document-embeddings + + - class: trustgraph.retrieval.document_rag.Processor + params: + <<: *defaults + id: document-rag + doc_limit: 20 + + - class: trustgraph.embeddings.fastembed.Processor + params: + <<: *defaults + id: embeddings + concurrency: 1 + + - class: trustgraph.embeddings.graph_embeddings.Processor + params: + <<: *defaults + id: graph-embeddings + + - class: trustgraph.retrieval.graph_rag.Processor + params: + <<: *defaults + id: graph-rag + concurrency: 1 + entity_limit: 50 + triple_limit: 30 + edge_limit: 30 + edge_score_limit: 10 + max_subgraph_size: 100 + max_path_length: 2 + + - class: trustgraph.extract.kg.agent.Processor + params: + <<: *defaults + id: kg-extract-agent + concurrency: 1 + + - class: trustgraph.extract.kg.definitions.Processor + params: + <<: *defaults + id: kg-extract-definitions + concurrency: 1 + + - class: trustgraph.extract.kg.ontology.Processor + params: + <<: *defaults + id: kg-extract-ontology + concurrency: 1 + + - class: trustgraph.extract.kg.relationships.Processor + params: + <<: *defaults + id: kg-extract-relationships + concurrency: 1 + + - class: trustgraph.extract.kg.rows.Processor + params: + <<: *defaults + id: kg-extract-rows + concurrency: 1 + + - class: trustgraph.cores.service.Processor + params: + <<: *defaults + id: knowledge + cassandra_host: localhost + + - class: trustgraph.storage.knowledge.store.Processor + params: + <<: *defaults + id: kg-store + cassandra_host: localhost + + - class: trustgraph.librarian.Processor + params: + <<: *defaults + id: librarian + cassandra_host: localhost + object_store_endpoint: localhost:3900 + object_store_access_key: GK000000000000000000000001 + object_store_secret_key: b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427 + object_store_region: garage + + - class: trustgraph.agent.mcp_tool.Service + params: + <<: *defaults + id: mcp-tool + + - class: trustgraph.metering.Processor + params: + <<: *defaults + id: metering + + - class: trustgraph.metering.Processor + params: + <<: *defaults + id: metering-rag + + - class: trustgraph.retrieval.nlp_query.Processor + params: + <<: *defaults + id: nlp-query + + - class: trustgraph.prompt.template.Processor + params: + <<: *defaults + id: prompt + concurrency: 1 + + - class: trustgraph.prompt.template.Processor + params: + <<: *defaults + id: prompt-rag + concurrency: 1 + + - class: 
trustgraph.query.doc_embeddings.qdrant.Processor + params: + <<: *defaults + id: doc-embeddings-query + store_uri: http://localhost:6333 + + - class: trustgraph.query.graph_embeddings.qdrant.Processor + params: + <<: *defaults + id: graph-embeddings-query + store_uri: http://localhost:6333 + + - class: trustgraph.query.row_embeddings.qdrant.Processor + params: + <<: *defaults + id: row-embeddings-query + store_uri: http://localhost:6333 + + - class: trustgraph.query.rows.cassandra.Processor + params: + <<: *defaults + id: rows-query + cassandra_host: localhost + + - class: trustgraph.query.triples.cassandra.Processor + params: + <<: *defaults + id: triples-query + cassandra_host: localhost + + - class: trustgraph.embeddings.row_embeddings.Processor + params: + <<: *defaults + id: row-embeddings + + - class: trustgraph.query.sparql.Processor + params: + <<: *defaults + id: sparql-query + + - class: trustgraph.storage.doc_embeddings.qdrant.Processor + params: + <<: *defaults + id: doc-embeddings-write + store_uri: http://localhost:6333 + + - class: trustgraph.storage.graph_embeddings.qdrant.Processor + params: + <<: *defaults + id: graph-embeddings-write + store_uri: http://localhost:6333 + + - class: trustgraph.storage.row_embeddings.qdrant.Processor + params: + <<: *defaults + id: row-embeddings-write + store_uri: http://localhost:6333 + + - class: trustgraph.storage.rows.cassandra.Processor + params: + <<: *defaults + id: rows-write + cassandra_host: localhost + + - class: trustgraph.storage.triples.cassandra.Processor + params: + <<: *defaults + id: triples-write + cassandra_host: localhost + + - class: trustgraph.retrieval.structured_diag.Processor + params: + <<: *defaults + id: structured-diag + + - class: trustgraph.retrieval.structured_query.Processor + params: + <<: *defaults + id: structured-query + + - class: trustgraph.model.text_completion.openai.Processor + params: + <<: *defaults + id: text-completion + max_output: 8192 + temperature: 0.0 + + - class: trustgraph.model.text_completion.openai.Processor + params: + <<: *defaults + id: text-completion-rag + max_output: 8192 + temperature: 0.0 diff --git a/dev-tools/proc-group/groups/control.yaml b/dev-tools/proc-group/groups/control.yaml new file mode 100644 index 00000000..b9ee9bfa --- /dev/null +++ b/dev-tools/proc-group/groups/control.yaml @@ -0,0 +1,47 @@ +# Control plane. Stateful "always on" services that every flow depends on. +# Cassandra-heavy, low traffic. 
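+#
+# Run with (mirroring group.yaml's convention; the path assumes you
+# launch from dev-tools/proc-group):
+#
+#   processor-group -c groups/control.yaml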
+ +_defaults: &defaults + pubsub_backend: rabbitmq + rabbitmq_host: localhost + log_level: INFO + +processors: + + - class: trustgraph.config.service.Processor + params: + <<: *defaults + id: config-svc + cassandra_host: localhost + + - class: trustgraph.librarian.Processor + params: + <<: *defaults + id: librarian + cassandra_host: localhost + object_store_endpoint: localhost:3900 + object_store_access_key: GK000000000000000000000001 + object_store_secret_key: b171f00be9be4c32c734f4c05fe64c527a8ab5eb823b376cfa8c2531f70fc427 + object_store_region: garage + + - class: trustgraph.cores.service.Processor + params: + <<: *defaults + id: knowledge + cassandra_host: localhost + + - class: trustgraph.storage.knowledge.store.Processor + params: + <<: *defaults + id: kg-store + cassandra_host: localhost + + - class: trustgraph.metering.Processor + params: + <<: *defaults + id: metering + + - class: trustgraph.metering.Processor + params: + <<: *defaults + id: metering-rag diff --git a/dev-tools/proc-group/groups/embeddings-store.yaml b/dev-tools/proc-group/groups/embeddings-store.yaml new file mode 100644 index 00000000..b5d4a6c8 --- /dev/null +++ b/dev-tools/proc-group/groups/embeddings-store.yaml @@ -0,0 +1,45 @@ +# Embeddings store. All Qdrant-backed vector query/write processors. +# One process owns the Qdrant driver pool for the whole group. + +_defaults: &defaults + pubsub_backend: rabbitmq + rabbitmq_host: localhost + log_level: INFO + +processors: + + - class: trustgraph.query.doc_embeddings.qdrant.Processor + params: + <<: *defaults + id: doc-embeddings-query + store_uri: http://localhost:6333 + + - class: trustgraph.storage.doc_embeddings.qdrant.Processor + params: + <<: *defaults + id: doc-embeddings-write + store_uri: http://localhost:6333 + + - class: trustgraph.query.graph_embeddings.qdrant.Processor + params: + <<: *defaults + id: graph-embeddings-query + store_uri: http://localhost:6333 + + - class: trustgraph.storage.graph_embeddings.qdrant.Processor + params: + <<: *defaults + id: graph-embeddings-write + store_uri: http://localhost:6333 + + - class: trustgraph.query.row_embeddings.qdrant.Processor + params: + <<: *defaults + id: row-embeddings-query + store_uri: http://localhost:6333 + + - class: trustgraph.storage.row_embeddings.qdrant.Processor + params: + <<: *defaults + id: row-embeddings-write + store_uri: http://localhost:6333 diff --git a/dev-tools/proc-group/groups/embeddings.yaml b/dev-tools/proc-group/groups/embeddings.yaml new file mode 100644 index 00000000..a4e0298b --- /dev/null +++ b/dev-tools/proc-group/groups/embeddings.yaml @@ -0,0 +1,31 @@ +# Embeddings. Memory-hungry — fastembed loads an ML model at startup. +# Keep isolated from other groups so its memory footprint and restart +# latency don't affect siblings. 
+ +_defaults: &defaults + pubsub_backend: rabbitmq + rabbitmq_host: localhost + log_level: INFO + +processors: + + - class: trustgraph.embeddings.fastembed.Processor + params: + <<: *defaults + id: embeddings + concurrency: 1 + + - class: trustgraph.embeddings.document_embeddings.Processor + params: + <<: *defaults + id: document-embeddings + + - class: trustgraph.embeddings.graph_embeddings.Processor + params: + <<: *defaults + id: graph-embeddings + + - class: trustgraph.embeddings.row_embeddings.Processor + params: + <<: *defaults + id: row-embeddings diff --git a/dev-tools/proc-group/groups/ingest.yaml b/dev-tools/proc-group/groups/ingest.yaml new file mode 100644 index 00000000..146a6339 --- /dev/null +++ b/dev-tools/proc-group/groups/ingest.yaml @@ -0,0 +1,52 @@ +# Ingest pipeline. Document-processing hot path. Bursty, correlated +# failures — if chunker dies the extractors have nothing to do anyway. + +_defaults: &defaults + pubsub_backend: rabbitmq + rabbitmq_host: localhost + log_level: INFO + +processors: + + - class: trustgraph.chunking.recursive.Processor + params: + <<: *defaults + id: chunker + chunk_size: 2000 + chunk_overlap: 50 + + - class: trustgraph.extract.kg.agent.Processor + params: + <<: *defaults + id: kg-extract-agent + concurrency: 1 + + - class: trustgraph.extract.kg.definitions.Processor + params: + <<: *defaults + id: kg-extract-definitions + concurrency: 1 + + - class: trustgraph.extract.kg.ontology.Processor + params: + <<: *defaults + id: kg-extract-ontology + concurrency: 1 + + - class: trustgraph.extract.kg.relationships.Processor + params: + <<: *defaults + id: kg-extract-relationships + concurrency: 1 + + - class: trustgraph.extract.kg.rows.Processor + params: + <<: *defaults + id: kg-extract-rows + concurrency: 1 + + - class: trustgraph.prompt.template.Processor + params: + <<: *defaults + id: prompt + concurrency: 1 diff --git a/dev-tools/proc-group/groups/llm.yaml b/dev-tools/proc-group/groups/llm.yaml new file mode 100644 index 00000000..35930dbf --- /dev/null +++ b/dev-tools/proc-group/groups/llm.yaml @@ -0,0 +1,24 @@ +# LLM. Outbound text-completion calls. Isolated because the upstream +# LLM API is often the bottleneck and the most likely thing to need +# restart (provider changes, model changes, API flakiness). + +_defaults: &defaults + pubsub_backend: rabbitmq + rabbitmq_host: localhost + log_level: INFO + +processors: + + - class: trustgraph.model.text_completion.openai.Processor + params: + <<: *defaults + id: text-completion + max_output: 8192 + temperature: 0.0 + + - class: trustgraph.model.text_completion.openai.Processor + params: + <<: *defaults + id: text-completion-rag + max_output: 8192 + temperature: 0.0 diff --git a/dev-tools/proc-group/groups/rag.yaml b/dev-tools/proc-group/groups/rag.yaml new file mode 100644 index 00000000..be27086b --- /dev/null +++ b/dev-tools/proc-group/groups/rag.yaml @@ -0,0 +1,64 @@ +# RAG / retrieval / agent. Query-time serving path. Drives outbound +# LLM calls via prompt-rag. sparql-query lives here because it's a +# read-side serving endpoint, not a backend writer. 
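+#
+# Note: prompt-rag only renders prompt templates; the completion calls it
+# drives are served elsewhere (presumably text-completion-rag in
+# groups/llm.yaml), so run the llm group alongside this one.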
+ +_defaults: &defaults + pubsub_backend: rabbitmq + rabbitmq_host: localhost + log_level: INFO + +processors: + + - class: trustgraph.agent.orchestrator.Processor + params: + <<: *defaults + id: agent-manager + + - class: trustgraph.retrieval.graph_rag.Processor + params: + <<: *defaults + id: graph-rag + concurrency: 1 + entity_limit: 50 + triple_limit: 30 + edge_limit: 30 + edge_score_limit: 10 + max_subgraph_size: 100 + max_path_length: 2 + + - class: trustgraph.retrieval.document_rag.Processor + params: + <<: *defaults + id: document-rag + doc_limit: 20 + + - class: trustgraph.retrieval.nlp_query.Processor + params: + <<: *defaults + id: nlp-query + + - class: trustgraph.retrieval.structured_query.Processor + params: + <<: *defaults + id: structured-query + + - class: trustgraph.retrieval.structured_diag.Processor + params: + <<: *defaults + id: structured-diag + + - class: trustgraph.query.sparql.Processor + params: + <<: *defaults + id: sparql-query + + - class: trustgraph.prompt.template.Processor + params: + <<: *defaults + id: prompt-rag + concurrency: 1 + + - class: trustgraph.agent.mcp_tool.Service + params: + <<: *defaults + id: mcp-tool diff --git a/dev-tools/proc-group/groups/rows-store.yaml b/dev-tools/proc-group/groups/rows-store.yaml new file mode 100644 index 00000000..ed52556d --- /dev/null +++ b/dev-tools/proc-group/groups/rows-store.yaml @@ -0,0 +1,20 @@ +# Rows store. Cassandra-backed structured row query/write. + +_defaults: &defaults + pubsub_backend: rabbitmq + rabbitmq_host: localhost + log_level: INFO + +processors: + + - class: trustgraph.query.rows.cassandra.Processor + params: + <<: *defaults + id: rows-query + cassandra_host: localhost + + - class: trustgraph.storage.rows.cassandra.Processor + params: + <<: *defaults + id: rows-write + cassandra_host: localhost diff --git a/dev-tools/proc-group/groups/triples-store.yaml b/dev-tools/proc-group/groups/triples-store.yaml new file mode 100644 index 00000000..4e32bfbd --- /dev/null +++ b/dev-tools/proc-group/groups/triples-store.yaml @@ -0,0 +1,20 @@ +# Triples store. Cassandra-backed RDF triple query/write. 
+ +_defaults: &defaults + pubsub_backend: rabbitmq + rabbitmq_host: localhost + log_level: INFO + +processors: + + - class: trustgraph.query.triples.cassandra.Processor + params: + <<: *defaults + id: triples-query + cassandra_host: localhost + + - class: trustgraph.storage.triples.cassandra.Processor + params: + <<: *defaults + id: triples-write + cassandra_host: localhost diff --git a/dev-tools/tests/agent_dag/analyse_trace.py b/dev-tools/tests/agent_dag/analyse_trace.py index b71cdebe..42cca118 100644 --- a/dev-tools/tests/agent_dag/analyse_trace.py +++ b/dev-tools/tests/agent_dag/analyse_trace.py @@ -131,21 +131,21 @@ async def analyse(path, url, flow, user, collection): for i, msg in enumerate(messages): resp = msg.get("response", {}) - chunk_type = resp.get("chunk_type", "?") + message_type = resp.get("message_type", "?") - if chunk_type == "explain": + if message_type == "explain": explain_id = resp.get("explain_id", "") explain_ids.append(explain_id) - print(f" {i:3d} {chunk_type} {explain_id}") + print(f" {i:3d} {message_type} {explain_id}") else: - print(f" {i:3d} {chunk_type}") + print(f" {i:3d} {message_type}") # Rule 7: message_id on content chunks - if chunk_type in ("thought", "observation", "answer"): + if message_type in ("thought", "observation", "answer"): mid = resp.get("message_id", "") if not mid: errors.append( - f"[msg {i}] {chunk_type} chunk missing message_id" + f"[msg {i}] {message_type} chunk missing message_id" ) print() diff --git a/docs/tech-specs/agent-explainability.md b/docs/tech-specs/agent-explainability.md index e054f023..add39fd7 100644 --- a/docs/tech-specs/agent-explainability.md +++ b/docs/tech-specs/agent-explainability.md @@ -6,273 +6,212 @@ parent: "Tech Specs" # Agent Explainability: Provenance Recording +## Status + +Implemented + ## Overview -Add provenance recording to the React agent loop so agent sessions can be traced and debugged using the same explainability infrastructure as GraphRAG. +Agent sessions are traced and debugged using the same explainability infrastructure as GraphRAG and Document RAG. Provenance is written to `urn:graph:retrieval` and delivered inline on the explain stream. -**Design Decisions:** -- Write to `urn:graph:retrieval` (generic explainability graph) -- Linear dependency chain for now (analysis N → wasDerivedFrom → analysis N-1) -- Tools are opaque black boxes (record input/output only) -- DAG support deferred to future iteration +The canonical vocabulary for all predicates and types is published as an OWL ontology at `specs/ontology/trustgraph.ttl`. ## Entity Types -Both GraphRAG and Agent use PROV-O as the base ontology with TrustGraph-specific subtypes: +All services use PROV-O as the base ontology with TrustGraph-specific subtypes. 
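+
+As an illustrative sketch (the URI and literals here are invented, not
+taken from a real trace), dual typing on a session root looks like this:
+
+```
+@prefix tg:   <https://trustgraph.ai/ns/> .
+@prefix prov: <http://www.w3.org/ns/prov#> .
+
+<urn:trustgraph:question:123e4567-e89b-12d3-a456-426614174000>
+    a prov:Entity, tg:Question, tg:GraphRagQuestion ;
+    tg:query "What is quantum computing?" ;
+    prov:startedAtTime "2025-06-01T12:00:00Z" .
+```
+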
### GraphRAG Types -| Entity | PROV-O Type | TG Types | Description | -|--------|-------------|----------|-------------| -| Question | `prov:Activity` | `tg:Question`, `tg:GraphRagQuestion` | The user's query | -| Exploration | `prov:Entity` | `tg:Exploration` | Edges retrieved from knowledge graph | -| Focus | `prov:Entity` | `tg:Focus` | Selected edges with reasoning | -| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer | - -### Agent Types -| Entity | PROV-O Type | TG Types | Description | -|--------|-------------|----------|-------------| -| Question | `prov:Activity` | `tg:Question`, `tg:AgentQuestion` | The user's query | -| Analysis | `prov:Entity` | `tg:Analysis` | Each think/act/observe cycle | -| Conclusion | `prov:Entity` | `tg:Conclusion` | Final answer | +| Entity | TG Types | Description | +|--------|----------|-------------| +| Question | `tg:Question`, `tg:GraphRagQuestion` | The user's query | +| Grounding | `tg:Grounding` | Concept extraction from query | +| Exploration | `tg:Exploration` | Edges retrieved from knowledge graph | +| Focus | `tg:Focus` | Selected edges with reasoning | +| Synthesis | `tg:Synthesis`, `tg:Answer` | Final answer | ### Document RAG Types -| Entity | PROV-O Type | TG Types | Description | -|--------|-------------|----------|-------------| -| Question | `prov:Activity` | `tg:Question`, `tg:DocRagQuestion` | The user's query | -| Exploration | `prov:Entity` | `tg:Exploration` | Chunks retrieved from document store | -| Synthesis | `prov:Entity` | `tg:Synthesis` | Final answer | +| Entity | TG Types | Description | +|--------|----------|-------------| +| Question | `tg:Question`, `tg:DocRagQuestion` | The user's query | +| Grounding | `tg:Grounding` | Concept extraction from query | +| Exploration | `tg:Exploration` | Chunks retrieved from document store | +| Synthesis | `tg:Synthesis`, `tg:Answer` | Final answer | -**Note:** Document RAG uses a subset of GraphRAG's types (no Focus step since there's no edge selection/reasoning phase). 
+### Agent Types (React) +| Entity | TG Types | Description | +|--------|----------|-------------| +| Question | `tg:Question`, `tg:AgentQuestion` | The user's query (session start) | +| PatternDecision | `tg:PatternDecision` | Meta-router routing decision | +| Analysis | `tg:Analysis`, `tg:ToolUse` | One think/act cycle | +| Thought | `tg:Reflection`, `tg:Thought` | Agent reasoning (sub-entity of Analysis) | +| Observation | `tg:Reflection`, `tg:Observation` | Tool result (standalone entity) | +| Conclusion | `tg:Conclusion`, `tg:Answer` | Final answer | + +### Agent Types (Orchestrator — Plan) +| Entity | TG Types | Description | +|--------|----------|-------------| +| Plan | `tg:Plan` | Structured plan of steps | +| StepResult | `tg:StepResult`, `tg:Answer` | Result from executing one plan step | +| Synthesis | `tg:Synthesis`, `tg:Answer` | Final synthesised answer | + +### Agent Types (Orchestrator — Supervisor) +| Entity | TG Types | Description | +|--------|----------|-------------| +| Decomposition | `tg:Decomposition` | Question decomposed into sub-goals | +| Finding | `tg:Finding`, `tg:Answer` | Result from a sub-agent | +| Synthesis | `tg:Synthesis`, `tg:Answer` | Final synthesised answer | + +### Mixin Types +| Type | Description | +|------|-------------| +| `tg:Answer` | Unifying type for terminal answers (Synthesis, Conclusion, Finding, StepResult) | +| `tg:Reflection` | Unifying type for intermediate commentary (Thought, Observation) | +| `tg:ToolUse` | Applied to Analysis when a tool is invoked | +| `tg:Error` | Applied to Observation events where a failure occurred (tool error or LLM parse error) | ### Question Subtypes -All Question entities share `tg:Question` as a base type but have a specific subtype to identify the retrieval mechanism: - | Subtype | URI Pattern | Mechanism | |---------|-------------|-----------| | `tg:GraphRagQuestion` | `urn:trustgraph:question:{uuid}` | Knowledge graph RAG | | `tg:DocRagQuestion` | `urn:trustgraph:docrag:{uuid}` | Document/chunk RAG | -| `tg:AgentQuestion` | `urn:trustgraph:agent:{uuid}` | ReAct agent | +| `tg:AgentQuestion` | `urn:trustgraph:agent:session:{uuid}` | Agent orchestrator | -This allows querying all questions via `tg:Question` while filtering by specific mechanism via the subtype. +## Provenance Chains -## Provenance Model +All chains use `prov:wasDerivedFrom` links. Each entity is a `prov:Entity`. + +### GraphRAG ``` -Question (urn:trustgraph:agent:{uuid}) - │ - │ tg:query = "User's question" - │ prov:startedAtTime = timestamp - │ rdf:type = prov:Activity, tg:Question - │ - ↓ prov:wasDerivedFrom - │ -Analysis1 (urn:trustgraph:agent:{uuid}/i1) - │ - │ tg:thought = "I need to query the knowledge base..." - │ tg:action = "knowledge-query" - │ tg:arguments = {"question": "..."} - │ tg:observation = "Result from tool..." - │ rdf:type = prov:Entity, tg:Analysis - │ - ↓ prov:wasDerivedFrom - │ -Analysis2 (urn:trustgraph:agent:{uuid}/i2) - │ ... - ↓ prov:wasDerivedFrom - │ -Conclusion (urn:trustgraph:agent:{uuid}/final) - │ - │ tg:answer = "The final response..." 
- │ rdf:type = prov:Entity, tg:Conclusion +Question → Grounding → Exploration → Focus → Synthesis ``` -### Document RAG Provenance Model +### Document RAG ``` -Question (urn:trustgraph:docrag:{uuid}) - │ - │ tg:query = "User's question" - │ prov:startedAtTime = timestamp - │ rdf:type = prov:Activity, tg:Question - │ - ↓ prov:wasGeneratedBy - │ -Exploration (urn:trustgraph:docrag:{uuid}/exploration) - │ - │ tg:chunkCount = 5 - │ tg:selectedChunk = "chunk-id-1" - │ tg:selectedChunk = "chunk-id-2" - │ ... - │ rdf:type = prov:Entity, tg:Exploration - │ - ↓ prov:wasDerivedFrom - │ -Synthesis (urn:trustgraph:docrag:{uuid}/synthesis) - │ - │ tg:content = "The synthesized answer..." - │ rdf:type = prov:Entity, tg:Synthesis +Question → Grounding → Exploration → Synthesis ``` -## Changes Required +### Agent React -### 1. Schema Changes - -**File:** `trustgraph-base/trustgraph/schema/services/agent.py` - -Add `session_id` and `collection` fields to `AgentRequest`: -```python -@dataclass -class AgentRequest: - question: str = "" - state: str = "" - group: list[str] | None = None - history: list[AgentStep] = field(default_factory=list) - user: str = "" - collection: str = "default" # NEW: Collection for provenance traces - streaming: bool = False - session_id: str = "" # NEW: For provenance tracking across iterations +``` +Question → PatternDecision → Analysis(1) → Observation(1) → Analysis(2) → ... → Conclusion ``` -**File:** `trustgraph-base/trustgraph/messaging/translators/agent.py` +The PatternDecision entity records which execution pattern the meta-router selected. It is only emitted on the first iteration when routing occurs. -Update translator to handle `session_id` and `collection` in both `to_pulsar()` and `from_pulsar()`. +Thought sub-entities derive from their parent Analysis. Observation entities derive from their parent Analysis (or from a sub-trace entity if the tool produced its own explainability, e.g. a GraphRAG query). -### 2. Add Explainability Producer to Agent Service +### Agent Plan-then-Execute -**File:** `trustgraph-flow/trustgraph/agent/react/service.py` - -Register an "explainability" producer (same pattern as GraphRAG): -```python -from ... base import ProducerSpec -from ... schema import Triples - -# In __init__: -self.register_specification( - ProducerSpec( - name = "explainability", - schema = Triples, - ) -) +``` +Question → PatternDecision → Plan → StepResult(0) → StepResult(1) → ... → Synthesis ``` -### 3. 
Provenance Triple Generation +### Agent Supervisor -**File:** `trustgraph-base/trustgraph/provenance/agent.py` - -Create helper functions (similar to GraphRAG's `question_triples`, `exploration_triples`, etc.): -```python -def agent_session_triples(session_uri, query, timestamp): - """Generate triples for agent Question.""" - return [ - Triple(s=session_uri, p=RDF_TYPE, o=PROV_ACTIVITY), - Triple(s=session_uri, p=RDF_TYPE, o=TG_QUESTION), - Triple(s=session_uri, p=TG_QUERY, o=query), - Triple(s=session_uri, p=PROV_STARTED_AT_TIME, o=timestamp), - ] - -def agent_iteration_triples(iteration_uri, parent_uri, thought, action, arguments, observation): - """Generate triples for one Analysis step.""" - return [ - Triple(s=iteration_uri, p=RDF_TYPE, o=PROV_ENTITY), - Triple(s=iteration_uri, p=RDF_TYPE, o=TG_ANALYSIS), - Triple(s=iteration_uri, p=TG_THOUGHT, o=thought), - Triple(s=iteration_uri, p=TG_ACTION, o=action), - Triple(s=iteration_uri, p=TG_ARGUMENTS, o=json.dumps(arguments)), - Triple(s=iteration_uri, p=TG_OBSERVATION, o=observation), - Triple(s=iteration_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri), - ] - -def agent_final_triples(final_uri, parent_uri, answer): - """Generate triples for Conclusion.""" - return [ - Triple(s=final_uri, p=RDF_TYPE, o=PROV_ENTITY), - Triple(s=final_uri, p=RDF_TYPE, o=TG_CONCLUSION), - Triple(s=final_uri, p=TG_ANSWER, o=answer), - Triple(s=final_uri, p=PROV_WAS_DERIVED_FROM, o=parent_uri), - ] +``` +Question → PatternDecision → Decomposition → [fan-out sub-agents] + → Finding(0) → Finding(1) → ... → Synthesis ``` -### 4. Type Definitions +Each sub-agent runs its own session with `wasDerivedFrom` linking back to the parent's Decomposition. Findings derive from their sub-agent's Conclusion. -**File:** `trustgraph-base/trustgraph/provenance/namespaces.py` +## Predicates -Add explainability entity types and agent predicates: -```python -# Explainability entity types (used by both GraphRAG and Agent) -TG_QUESTION = TG + "Question" -TG_EXPLORATION = TG + "Exploration" -TG_FOCUS = TG + "Focus" -TG_SYNTHESIS = TG + "Synthesis" -TG_ANALYSIS = TG + "Analysis" -TG_CONCLUSION = TG + "Conclusion" +### Session / Question +| Predicate | Type | Description | +|-----------|------|-------------| +| `tg:query` | string | The user's query text | +| `prov:startedAtTime` | string | ISO timestamp | -# Agent predicates -TG_THOUGHT = TG + "thought" -TG_ACTION = TG + "action" -TG_ARGUMENTS = TG + "arguments" -TG_OBSERVATION = TG + "observation" -TG_ANSWER = TG + "answer" -``` +### Pattern Decision +| Predicate | Type | Description | +|-----------|------|-------------| +| `tg:pattern` | string | Selected pattern (react, plan-then-execute, supervisor) | +| `tg:taskType` | string | Identified task type (general, research, etc.) 
| -## Files Modified +### Analysis (Iteration) +| Predicate | Type | Description | +|-----------|------|-------------| +| `tg:action` | string | Tool name selected by the agent | +| `tg:arguments` | string | JSON-encoded arguments | +| `tg:thought` | IRI | Link to Thought sub-entity | +| `tg:toolCandidate` | string | Tool name available to the LLM (one per candidate) | +| `tg:stepNumber` | integer | 1-based iteration counter | +| `tg:llmDurationMs` | integer | LLM call duration in milliseconds | +| `tg:inToken` | integer | Input token count | +| `tg:outToken` | integer | Output token count | +| `tg:llmModel` | string | Model identifier | -| File | Change | -|------|--------| -| `trustgraph-base/trustgraph/schema/services/agent.py` | Add session_id and collection to AgentRequest | -| `trustgraph-base/trustgraph/messaging/translators/agent.py` | Update translator for new fields | -| `trustgraph-base/trustgraph/provenance/namespaces.py` | Add entity types, agent predicates, and Document RAG predicates | -| `trustgraph-base/trustgraph/provenance/triples.py` | Add TG types to GraphRAG triple builders, add Document RAG triple builders | -| `trustgraph-base/trustgraph/provenance/uris.py` | Add Document RAG URI generators | -| `trustgraph-base/trustgraph/provenance/__init__.py` | Export new types, predicates, and Document RAG functions | -| `trustgraph-base/trustgraph/schema/services/retrieval.py` | Add explain_id, explain_graph, and explain_triples to DocumentRagResponse | -| `trustgraph-base/trustgraph/messaging/translators/retrieval.py` | Update DocumentRagResponseTranslator for explainability fields including inline triples | -| `trustgraph-flow/trustgraph/agent/react/service.py` | Add explainability producer + recording logic | -| `trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py` | Add explainability callback and emit provenance triples | -| `trustgraph-flow/trustgraph/retrieval/document_rag/rag.py` | Add explainability producer and wire up callback | -| `trustgraph-cli/trustgraph/cli/show_explain_trace.py` | Handle agent trace types | -| `trustgraph-cli/trustgraph/cli/list_explain_traces.py` | List agent sessions alongside GraphRAG | +### Observation +| Predicate | Type | Description | +|-----------|------|-------------| +| `tg:document` | IRI | Librarian document reference | +| `tg:toolDurationMs` | integer | Tool execution time in milliseconds | +| `tg:toolError` | string | Error message (tool failure or LLM parse error) | -## Files Created +When `tg:toolError` is present, the Observation also carries the `tg:Error` mixin type. -| File | Purpose | -|------|---------| -| `trustgraph-base/trustgraph/provenance/agent.py` | Agent-specific triple generators | +### Conclusion / Synthesis +| Predicate | Type | Description | +|-----------|------|-------------| +| `tg:document` | IRI | Librarian document reference | +| `tg:terminationReason` | string | Why the loop stopped | +| `tg:inToken` | integer | Input token count (synthesis LLM call) | +| `tg:outToken` | integer | Output token count | +| `tg:llmModel` | string | Model identifier | -## CLI Updates +Termination reason values: +- `final-answer` -- LLM produced a confident answer (react) +- `plan-complete` -- all plan steps executed (plan-then-execute) +- `subagents-complete` -- all sub-agents reported back (supervisor) -**Detection:** Both GraphRAG and Agent Questions have `tg:Question` type. Distinguished by: -1. URI pattern: `urn:trustgraph:agent:` vs `urn:trustgraph:question:` -2. 
Derived entities: `tg:Analysis` (agent) vs `tg:Exploration` (GraphRAG) +### Decomposition +| Predicate | Type | Description | +|-----------|------|-------------| +| `tg:subagentGoal` | string | Goal assigned to a sub-agent (one per goal) | +| `tg:inToken` | integer | Input token count | +| `tg:outToken` | integer | Output token count | -**`list_explain_traces.py`:** -- Shows Type column (Agent vs GraphRAG) +### Plan +| Predicate | Type | Description | +|-----------|------|-------------| +| `tg:planStep` | string | Goal for a plan step (one per step) | +| `tg:inToken` | integer | Input token count | +| `tg:outToken` | integer | Output token count | -**`show_explain_trace.py`:** -- Auto-detects trace type -- Agent rendering shows: Question → Analysis step(s) → Conclusion +### Token Counts on RAG Events -## Backwards Compatibility +Grounding, Focus, and Synthesis events on GraphRAG and Document RAG also carry `tg:inToken`, `tg:outToken`, and `tg:llmModel` for the LLM calls associated with that step. -- `session_id` defaults to `""` - old requests work, just won't have provenance -- `collection` defaults to `"default"` - reasonable fallback -- CLI gracefully handles both trace types +## Error Handling + +Tool execution errors and LLM parse errors are captured as Observation events rather than crashing the agent: + +- The error message is recorded on `tg:toolError` +- The Observation carries the `tg:Error` mixin type +- The error text becomes the observation content, visible to the LLM on the next iteration +- The provenance chain is preserved (Observation derives from Analysis) +- The agent gets another iteration to retry or choose a different approach + +## Vocabulary Reference + +The full OWL ontology covering all classes and predicates is at `specs/ontology/trustgraph.ttl`. ## Verification ```bash -# Run an agent query -tg-invoke-agent -q "What is the capital of France?" +# Run an agent query with explainability +tg-invoke-agent -q "What is quantum computing?" -x -# List traces (should show agent sessions with Type column) -tg-list-explain-traces -U trustgraph -C default +# Run with token usage +tg-invoke-agent -q "What is quantum computing?" --show-usage -# Show agent trace -tg-show-explain-trace "urn:trustgraph:agent:xxx" +# GraphRAG with explainability +tg-invoke-graph-rag -q "Tell me about AI" -x + +# Document RAG with explainability +tg-invoke-document-rag -q "Summarize the findings" -x ``` - -## Future Work (Not This PR) - -- DAG dependencies (when analysis N uses results from multiple prior analyses) -- Tool-specific provenance linking (KnowledgeQuery → its GraphRAG trace) -- Streaming provenance emission (emit as we go, not batch at end) diff --git a/docs/tech-specs/agent-orchestration.md b/docs/tech-specs/agent-orchestration.md index 19261af0..ab1569ed 100644 --- a/docs/tech-specs/agent-orchestration.md +++ b/docs/tech-specs/agent-orchestration.md @@ -868,7 +868,7 @@ independently. Response chunk fields: message_id UUID for this message (groups chunks) session_id Which agent session produced this chunk - chunk_type "thought" | "observation" | "answer" | ... + message_type "thought" | "observation" | "answer" | ... 
content The chunk text end_of_message True on the final chunk of this message end_of_dialog True on the final message of the entire execution diff --git a/docs/tech-specs/extraction-provenance-subgraph.md b/docs/tech-specs/extraction-provenance-subgraph.md index 197691f2..62d3a701 100644 --- a/docs/tech-specs/extraction-provenance-subgraph.md +++ b/docs/tech-specs/extraction-provenance-subgraph.md @@ -209,3 +209,7 @@ def subgraph_provenance_triples( This is a breaking change to the provenance model. Provenance has not been released, so no migration is needed. The old `tg:reifies` / `statement_uri` code can be removed outright. + +## Vocabulary Reference + +The full OWL ontology covering all extraction and query-time classes and predicates is at `specs/ontology/trustgraph.ttl`. diff --git a/docs/tech-specs/extraction-time-provenance.md b/docs/tech-specs/extraction-time-provenance.md index e6197f8a..6c4b4513 100644 --- a/docs/tech-specs/extraction-time-provenance.md +++ b/docs/tech-specs/extraction-time-provenance.md @@ -618,8 +618,13 @@ Link embedding entity IDs to chunk. | Triple store | Use reification for triple → chunk provenance | | Embedding provenance | Link entity ID → chunk ID | +## Vocabulary Reference + +The full OWL ontology covering all extraction and query-time classes and predicates is at `specs/ontology/trustgraph.ttl`. + ## References -- Query-time provenance: `docs/tech-specs/query-time-provenance.md` +- Query-time provenance: `docs/tech-specs/query-time-explainability.md` +- Agent explainability: `docs/tech-specs/agent-explainability.md` - PROV-O standard for provenance modeling - Existing source metadata in knowledge graph (needs audit) diff --git a/docs/tech-specs/python-api-refactor.md b/docs/tech-specs/python-api-refactor.md index 97ebe2f7..dd0022df 100644 --- a/docs/tech-specs/python-api-refactor.md +++ b/docs/tech-specs/python-api-refactor.md @@ -710,17 +710,17 @@ class StreamingChunk: @dataclasses.dataclass class AgentThought(StreamingChunk): """Agent reasoning chunk""" - chunk_type: str = "thought" + message_type: str = "thought" @dataclasses.dataclass class AgentObservation(StreamingChunk): """Agent tool observation chunk""" - chunk_type: str = "observation" + message_type: str = "observation" @dataclasses.dataclass class AgentAnswer(StreamingChunk): """Agent final answer chunk""" - chunk_type: str = "final-answer" + message_type: str = "final-answer" end_of_dialog: bool = False @dataclasses.dataclass diff --git a/docs/tech-specs/query-time-explainability.md b/docs/tech-specs/query-time-explainability.md index 0e1d18f6..02598563 100644 --- a/docs/tech-specs/query-time-explainability.md +++ b/docs/tech-specs/query-time-explainability.md @@ -144,6 +144,25 @@ Defined in `trustgraph-base/trustgraph/provenance/namespaces.py`: | `TG_REASONING` | `https://trustgraph.ai/ns/reasoning` | | `TG_CONTENT` | `https://trustgraph.ai/ns/content` | | `TG_DOCUMENT` | `https://trustgraph.ai/ns/document` | +| `TG_IN_TOKEN` | `https://trustgraph.ai/ns/inToken` | +| `TG_OUT_TOKEN` | `https://trustgraph.ai/ns/outToken` | +| `TG_LLM_MODEL` | `https://trustgraph.ai/ns/llmModel` | + +### Token Usage on Events + +Grounding, Focus, and Synthesis events carry per-event LLM token counts: + +| Predicate | Type | Present on | +|-----------|------|------------| +| `tg:inToken` | integer | Grounding, Focus, Synthesis | +| `tg:outToken` | integer | Grounding, Focus, Synthesis | +| `tg:llmModel` | string | Grounding, Focus, Synthesis | + +- **Grounding**: tokens from the extract-concepts LLM call +- 
**Focus**: summed tokens from edge-scoring + edge-reasoning LLM calls +- **Synthesis**: tokens from the synthesis LLM call + +Values are absent (not zero) when token counts are unavailable. ## GraphRagResponse Schema @@ -267,8 +286,13 @@ Based on the provided knowledge statements... | `trustgraph-flow/trustgraph/query/triples/cassandra/service.py` | Quoted triple query support | | `trustgraph-cli/trustgraph/cli/invoke_graph_rag.py` | CLI with explainability display | +## Vocabulary Reference + +The full OWL ontology covering all classes and predicates is at `specs/ontology/trustgraph.ttl`. + ## References - PROV-O (W3C Provenance Ontology): https://www.w3.org/TR/prov-o/ - RDF-star: https://w3c.github.io/rdf-star/ - Extraction-time provenance: `docs/tech-specs/extraction-time-provenance.md` +- Agent explainability: `docs/tech-specs/agent-explainability.md` diff --git a/docs/tech-specs/streaming-llm-responses.md b/docs/tech-specs/streaming-llm-responses.md index 7ecea5a5..5f6d9877 100644 --- a/docs/tech-specs/streaming-llm-responses.md +++ b/docs/tech-specs/streaming-llm-responses.md @@ -399,28 +399,28 @@ The agent produces multiple types of output during its reasoning cycle: - Answer (final response) - Errors -Since `chunk_type` identifies what kind of content is being sent, the separate +Since `message_type` identifies what kind of content is being sent, the separate `answer`, `error`, `thought`, and `observation` fields can be collapsed into a single `content` field: ```python class AgentResponse(Record): - chunk_type = String() # "thought", "action", "observation", "answer", "error" - content = String() # The actual content (interpretation depends on chunk_type) + message_type = String() # "thought", "action", "observation", "answer", "error" + content = String() # The actual content (interpretation depends on message_type) end_of_message = Boolean() # Current thought/action/observation/answer is complete end_of_dialog = Boolean() # Entire agent dialog is complete ``` **Field Semantics:** -- `chunk_type`: Indicates what type of content is in the `content` field +- `message_type`: Indicates what type of content is in the `content` field - `"thought"`: Agent reasoning/thinking - `"action"`: Tool/action being invoked - `"observation"`: Result from tool execution - `"answer"`: Final answer to the user's question - `"error"`: Error message -- `content`: The actual streamed content, interpreted based on `chunk_type` +- `content`: The actual streamed content, interpreted based on `message_type` - `end_of_message`: When `true`, the current chunk type is complete - Example: All tokens for the current thought have been sent @@ -434,27 +434,27 @@ class AgentResponse(Record): When `streaming=true`: 1. **Thought streaming**: - - Multiple chunks with `chunk_type="thought"`, `end_of_message=false` + - Multiple chunks with `message_type="thought"`, `end_of_message=false` - Final thought chunk has `end_of_message=true` 2. **Action notification**: - - Single chunk with `chunk_type="action"`, `end_of_message=true` + - Single chunk with `message_type="action"`, `end_of_message=true` 3. **Observation**: - - Chunk(s) with `chunk_type="observation"`, final has `end_of_message=true` + - Chunk(s) with `message_type="observation"`, final has `end_of_message=true` 4. **Repeat** steps 1-3 as the agent reasons 5. 
**Final answer**:
-   - `chunk_type="answer"` with the final response in `content`
+   - `message_type="answer"` with the final response in `content`
    - Last chunk has `end_of_message=true`, `end_of_dialog=true`
 
 **Example Stream Sequence:**
 
 ```
-{chunk_type: "thought", content: "I need to", end_of_message: false, end_of_dialog: false}
-{chunk_type: "thought", content: " search for...", end_of_message: true, end_of_dialog: false}
-{chunk_type: "action", content: "search", end_of_message: true, end_of_dialog: false}
-{chunk_type: "observation", content: "Found: ...", end_of_message: true, end_of_dialog: false}
-{chunk_type: "thought", content: "Based on this", end_of_message: false, end_of_dialog: false}
-{chunk_type: "thought", content: " I can answer...", end_of_message: true, end_of_dialog: false}
-{chunk_type: "answer", content: "The answer is...", end_of_message: true, end_of_dialog: true}
+{message_type: "thought", content: "I need to", end_of_message: false, end_of_dialog: false}
+{message_type: "thought", content: " search for...", end_of_message: true, end_of_dialog: false}
+{message_type: "action", content: "search", end_of_message: true, end_of_dialog: false}
+{message_type: "observation", content: "Found: ...", end_of_message: true, end_of_dialog: false}
+{message_type: "thought", content: "Based on this", end_of_message: false, end_of_dialog: false}
+{message_type: "thought", content: " I can answer...", end_of_message: true, end_of_dialog: false}
+{message_type: "answer", content: "The answer is...", end_of_message: true, end_of_dialog: true}
 ```
 
 When `streaming=false`:
@@ -547,7 +547,7 @@ The following questions were resolved during specification:
    populated and no other fields are needed. An error is always the final
    communication - no subsequent messages are permitted or expected after an
    error. For LLM/Prompt streams, `end_of_stream=true`. For Agent streams,
-   `chunk_type="error"` with `end_of_dialog=true`.
+   `message_type="error"` with `end_of_dialog=true`.
 
 3. **Partial Response Recovery**: The messaging protocol (Pulsar) is resilient,
    so message-level retry is not needed. If a client loses track of the stream
diff --git a/specs/ontology/trustgraph.ttl b/specs/ontology/trustgraph.ttl
new file mode 100644
index 00000000..4c7de612
--- /dev/null
+++ b/specs/ontology/trustgraph.ttl
@@ -0,0 +1,415 @@
+@prefix tg: <https://trustgraph.ai/ns/> .
+@prefix owl: <http://www.w3.org/2002/07/owl#> .
+@prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> .
+@prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .
+@prefix xsd: <http://www.w3.org/2001/XMLSchema#> .
+@prefix prov: <http://www.w3.org/ns/prov#> .
+
+# =============================================================================
+# Ontology declaration
+# =============================================================================
+
+<https://trustgraph.ai/ns/>
+    a owl:Ontology ;
+    rdfs:label "TrustGraph Ontology" ;
+    rdfs:comment "Vocabulary for TrustGraph provenance, extraction metadata, and explainability." ;
+    owl:versionInfo "2.3" .
+
+# =============================================================================
+# Classes — Extraction provenance
+# =============================================================================
+
+tg:Document a owl:Class ;
+    rdfs:subClassOf prov:Entity ;
+    rdfs:label "Document" ;
+    rdfs:comment "A loaded document (PDF, text, etc.)." .
+
+tg:Page a owl:Class ;
+    rdfs:subClassOf prov:Entity ;
+    rdfs:label "Page" ;
+    rdfs:comment "A page within a document." .
+
+tg:Section a owl:Class ;
+    rdfs:subClassOf prov:Entity ;
+    rdfs:label "Section" ;
+    rdfs:comment "A structural section within a document." .
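+
+# Illustrative example (non-normative; the identifiers are hypothetical).
+# A containment chain using tg:contains and the datatype properties
+# defined further down in this file:
+#
+#   <urn:tg:doc1> a tg:Document ;
+#       tg:pageCount 12 ;
+#       tg:contains <urn:tg:doc1:page1> .
+#   <urn:tg:doc1:page1> a tg:Page ;
+#       tg:pageNumber 1 ;
+#       tg:contains <urn:tg:doc1:chunk0> .
+#   <urn:tg:doc1:chunk0> a tg:Chunk ;
+#       tg:chunkIndex 0 .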
+ +tg:Chunk a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Chunk" ; + rdfs:comment "A text chunk produced by the chunker." . + +tg:Image a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Image" ; + rdfs:comment "An image extracted from a document." . + +tg:Subgraph a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Subgraph" ; + rdfs:comment "A set of triples extracted from a chunk." . + +# ============================================================================= +# Classes — Query-time explainability (shared) +# ============================================================================= + +tg:Question a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Question" ; + rdfs:comment "Root entity for a query session." . + +tg:GraphRagQuestion a owl:Class ; + rdfs:subClassOf tg:Question ; + rdfs:label "Graph RAG Question" ; + rdfs:comment "A question answered via graph-based RAG." . + +tg:DocRagQuestion a owl:Class ; + rdfs:subClassOf tg:Question ; + rdfs:label "Document RAG Question" ; + rdfs:comment "A question answered via document-based RAG." . + +tg:AgentQuestion a owl:Class ; + rdfs:subClassOf tg:Question ; + rdfs:label "Agent Question" ; + rdfs:comment "A question answered via the agent orchestrator." . + +tg:Grounding a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Grounding" ; + rdfs:comment "Concept extraction step (query decomposition into search terms)." . + +tg:Exploration a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Exploration" ; + rdfs:comment "Entity/chunk retrieval step." . + +tg:Focus a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Focus" ; + rdfs:comment "Edge selection and scoring step (GraphRAG)." . + +tg:Synthesis a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Synthesis" ; + rdfs:comment "Final answer synthesis from retrieved context." . + +# ============================================================================= +# Classes — Agent provenance +# ============================================================================= + +tg:Analysis a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Analysis" ; + rdfs:comment "One agent iteration: reasoning followed by tool selection." . + +tg:ToolUse a owl:Class ; + rdfs:label "ToolUse" ; + rdfs:comment "Mixin type applied to Analysis when a tool is invoked." . + +tg:Error a owl:Class ; + rdfs:label "Error" ; + rdfs:comment "Mixin type applied to events where a failure occurred (tool error, parse error)." . + +tg:Conclusion a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Conclusion" ; + rdfs:comment "Agent final answer (ReAct pattern)." . + +tg:PatternDecision a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Pattern Decision" ; + rdfs:comment "Meta-router decision recording which execution pattern was selected." . + +# --- Unifying types --- + +tg:Answer a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Answer" ; + rdfs:comment "Unifying type for any terminal answer (Synthesis, Conclusion, Finding, StepResult)." . + +tg:Reflection a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Reflection" ; + rdfs:comment "Unifying type for intermediate commentary (Thought, Observation)." . + +tg:Thought a owl:Class ; + rdfs:subClassOf tg:Reflection ; + rdfs:label "Thought" ; + rdfs:comment "Agent reasoning text within an iteration." . + +tg:Observation a owl:Class ; + rdfs:subClassOf tg:Reflection ; + rdfs:label "Observation" ; + rdfs:comment "Tool execution result." . 
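+
+# Illustrative example (non-normative; the identifiers are hypothetical).
+# One ReAct iteration, using the agent properties defined further down
+# in this file:
+#
+#   <urn:tg:q1:step1> a tg:Analysis ;
+#       tg:stepNumber 1 ;
+#       tg:action "knowledge_query" ;
+#       tg:arguments "{\"question\": \"What is ML?\"}" ;
+#       tg:thought <urn:tg:q1:step1:thought> ;
+#       tg:observation <urn:tg:q1:step1:obs> .
+#   <urn:tg:q1:step1:thought> a tg:Thought .
+#   <urn:tg:q1:step1:obs> a tg:Observation .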
+ +# --- Orchestrator types --- + +tg:Decomposition a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Decomposition" ; + rdfs:comment "Supervisor pattern: question decomposed into sub-goals." . + +tg:Finding a owl:Class ; + rdfs:subClassOf tg:Answer ; + rdfs:label "Finding" ; + rdfs:comment "Result from a sub-agent execution." . + +tg:Plan a owl:Class ; + rdfs:subClassOf prov:Entity ; + rdfs:label "Plan" ; + rdfs:comment "Plan-then-execute pattern: structured plan of steps." . + +tg:StepResult a owl:Class ; + rdfs:subClassOf tg:Answer ; + rdfs:label "Step Result" ; + rdfs:comment "Result from executing one plan step." . + +# ============================================================================= +# Properties — Extraction metadata +# ============================================================================= + +tg:contains a owl:ObjectProperty ; + rdfs:label "contains" ; + rdfs:comment "Links a parent entity to a child (e.g. Document contains Page, Subgraph contains triple)." . + +tg:pageCount a owl:DatatypeProperty ; + rdfs:label "page count" ; + rdfs:range xsd:integer ; + rdfs:domain tg:Document . + +tg:mimeType a owl:DatatypeProperty ; + rdfs:label "MIME type" ; + rdfs:range xsd:string ; + rdfs:domain tg:Document . + +tg:pageNumber a owl:DatatypeProperty ; + rdfs:label "page number" ; + rdfs:range xsd:integer ; + rdfs:domain tg:Page . + +tg:chunkIndex a owl:DatatypeProperty ; + rdfs:label "chunk index" ; + rdfs:range xsd:integer ; + rdfs:domain tg:Chunk . + +tg:charOffset a owl:DatatypeProperty ; + rdfs:label "character offset" ; + rdfs:range xsd:integer . + +tg:charLength a owl:DatatypeProperty ; + rdfs:label "character length" ; + rdfs:range xsd:integer . + +tg:chunkSize a owl:DatatypeProperty ; + rdfs:label "chunk size" ; + rdfs:range xsd:integer . + +tg:chunkOverlap a owl:DatatypeProperty ; + rdfs:label "chunk overlap" ; + rdfs:range xsd:integer . + +tg:componentVersion a owl:DatatypeProperty ; + rdfs:label "component version" ; + rdfs:range xsd:string . + +tg:llmModel a owl:DatatypeProperty ; + rdfs:label "LLM model" ; + rdfs:range xsd:string . + +tg:ontology a owl:DatatypeProperty ; + rdfs:label "ontology" ; + rdfs:range xsd:string . + +tg:embeddingModel a owl:DatatypeProperty ; + rdfs:label "embedding model" ; + rdfs:range xsd:string . + +tg:sourceText a owl:DatatypeProperty ; + rdfs:label "source text" ; + rdfs:range xsd:string . + +tg:sourceCharOffset a owl:DatatypeProperty ; + rdfs:label "source character offset" ; + rdfs:range xsd:integer . + +tg:sourceCharLength a owl:DatatypeProperty ; + rdfs:label "source character length" ; + rdfs:range xsd:integer . + +tg:elementTypes a owl:DatatypeProperty ; + rdfs:label "element types" ; + rdfs:range xsd:string . + +tg:tableCount a owl:DatatypeProperty ; + rdfs:label "table count" ; + rdfs:range xsd:integer . + +tg:imageCount a owl:DatatypeProperty ; + rdfs:label "image count" ; + rdfs:range xsd:integer . + +# ============================================================================= +# Properties — Query-time provenance (GraphRAG / DocumentRAG) +# ============================================================================= + +tg:query a owl:DatatypeProperty ; + rdfs:label "query" ; + rdfs:comment "The user's query text." ; + rdfs:range xsd:string ; + rdfs:domain tg:Question . + +tg:concept a owl:DatatypeProperty ; + rdfs:label "concept" ; + rdfs:comment "An extracted concept from the query." ; + rdfs:range xsd:string ; + rdfs:domain tg:Grounding . 
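+
+# Illustrative example (non-normative; the identifiers and model name
+# are hypothetical). A Grounding event carrying extracted concepts and
+# the per-event token-usage predicates defined further down in this file:
+#
+#   <urn:tg:q1:grounding> a tg:Grounding ;
+#       tg:concept "machine learning" ;
+#       tg:concept "neural networks" ;
+#       tg:inToken 120 ;
+#       tg:outToken 14 ;
+#       tg:llmModel "example-model" .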
+ +tg:entity a owl:ObjectProperty ; + rdfs:label "entity" ; + rdfs:comment "A seed entity retrieved during exploration." ; + rdfs:domain tg:Exploration . + +tg:edgeCount a owl:DatatypeProperty ; + rdfs:label "edge count" ; + rdfs:comment "Number of edges explored." ; + rdfs:range xsd:integer ; + rdfs:domain tg:Exploration . + +tg:selectedEdge a owl:ObjectProperty ; + rdfs:label "selected edge" ; + rdfs:comment "Link to an edge selection entity within a Focus event." ; + rdfs:domain tg:Focus . + +tg:edge a owl:ObjectProperty ; + rdfs:label "edge" ; + rdfs:comment "A quoted triple representing a knowledge graph edge." ; + rdfs:domain tg:Focus . + +tg:reasoning a owl:DatatypeProperty ; + rdfs:label "reasoning" ; + rdfs:comment "LLM-generated reasoning for an edge selection." ; + rdfs:range xsd:string . + +tg:document a owl:ObjectProperty ; + rdfs:label "document" ; + rdfs:comment "Reference to a document stored in the librarian." . + +tg:chunkCount a owl:DatatypeProperty ; + rdfs:label "chunk count" ; + rdfs:comment "Number of document chunks retrieved (DocumentRAG)." ; + rdfs:range xsd:integer ; + rdfs:domain tg:Exploration . + +tg:selectedChunk a owl:DatatypeProperty ; + rdfs:label "selected chunk" ; + rdfs:comment "A selected chunk ID (DocumentRAG)." ; + rdfs:range xsd:string ; + rdfs:domain tg:Exploration . + +# ============================================================================= +# Properties — Agent provenance +# ============================================================================= + +tg:thought a owl:ObjectProperty ; + rdfs:label "thought" ; + rdfs:comment "Links an Analysis iteration to its Thought sub-entity." ; + rdfs:domain tg:Analysis ; + rdfs:range tg:Thought . + +tg:action a owl:DatatypeProperty ; + rdfs:label "action" ; + rdfs:comment "The tool/action name selected by the agent." ; + rdfs:range xsd:string ; + rdfs:domain tg:Analysis . + +tg:arguments a owl:DatatypeProperty ; + rdfs:label "arguments" ; + rdfs:comment "JSON-encoded arguments passed to the tool." ; + rdfs:range xsd:string ; + rdfs:domain tg:Analysis . + +tg:observation a owl:ObjectProperty ; + rdfs:label "observation" ; + rdfs:comment "Links an Analysis iteration to its Observation sub-entity." ; + rdfs:domain tg:Analysis ; + rdfs:range tg:Observation . + +tg:toolCandidate a owl:DatatypeProperty ; + rdfs:label "tool candidate" ; + rdfs:comment "Name of a tool available to the LLM for this iteration. One triple per candidate." ; + rdfs:range xsd:string ; + rdfs:domain tg:Analysis . + +tg:stepNumber a owl:DatatypeProperty ; + rdfs:label "step number" ; + rdfs:comment "Explicit 1-based step counter for iteration events." ; + rdfs:range xsd:integer ; + rdfs:domain tg:Analysis . + +tg:terminationReason a owl:DatatypeProperty ; + rdfs:label "termination reason" ; + rdfs:comment "Why the agent loop stopped: final-answer, plan-complete, subagents-complete, max-iterations, error." ; + rdfs:range xsd:string . + +tg:pattern a owl:DatatypeProperty ; + rdfs:label "pattern" ; + rdfs:comment "Selected execution pattern (react, plan-then-execute, supervisor)." ; + rdfs:range xsd:string ; + rdfs:domain tg:PatternDecision . + +tg:taskType a owl:DatatypeProperty ; + rdfs:label "task type" ; + rdfs:comment "Identified task type from the meta-router (general, research, etc.)." ; + rdfs:range xsd:string ; + rdfs:domain tg:PatternDecision . + +tg:llmDurationMs a owl:DatatypeProperty ; + rdfs:label "LLM duration (ms)" ; + rdfs:comment "Time spent in the LLM prompt call, in milliseconds." 
; + rdfs:range xsd:integer ; + rdfs:domain tg:Analysis . + +tg:toolDurationMs a owl:DatatypeProperty ; + rdfs:label "tool duration (ms)" ; + rdfs:comment "Time spent executing the tool, in milliseconds." ; + rdfs:range xsd:integer ; + rdfs:domain tg:Observation . + +tg:toolError a owl:DatatypeProperty ; + rdfs:label "tool error" ; + rdfs:comment "Error message from a failed tool execution." ; + rdfs:range xsd:string ; + rdfs:domain tg:Observation . + +# --- Token usage predicates (on any event that involves an LLM call) --- + +tg:inToken a owl:DatatypeProperty ; + rdfs:label "input tokens" ; + rdfs:comment "Input token count for the LLM call associated with this event." ; + rdfs:range xsd:integer . + +tg:outToken a owl:DatatypeProperty ; + rdfs:label "output tokens" ; + rdfs:comment "Output token count for the LLM call associated with this event." ; + rdfs:range xsd:integer . + +# --- Orchestrator predicates --- + +tg:subagentGoal a owl:DatatypeProperty ; + rdfs:label "sub-agent goal" ; + rdfs:comment "Goal string assigned to a sub-agent (Decomposition, Finding)." ; + rdfs:range xsd:string . + +tg:planStep a owl:DatatypeProperty ; + rdfs:label "plan step" ; + rdfs:comment "Goal string for a plan step (Plan, StepResult)." ; + rdfs:range xsd:string . + +# ============================================================================= +# Named graphs +# ============================================================================= +# These are not OWL classes but documented here for reference: +# +# (default graph) — Core knowledge facts (extracted triples) +# urn:graph:source — Extraction provenance (document → chunk → triple) +# urn:graph:retrieval — Query-time explainability (question → exploration → synthesis) diff --git a/tests/contract/conftest.py b/tests/contract/conftest.py index 15082437..4fdfe83b 100644 --- a/tests/contract/conftest.py +++ b/tests/contract/conftest.py @@ -87,7 +87,7 @@ def sample_message_data(): "history": [] }, "AgentResponse": { - "chunk_type": "answer", + "message_type": "answer", "content": "Machine learning is a subset of AI.", "end_of_message": True, "end_of_dialog": True, diff --git a/tests/contract/test_message_contracts.py b/tests/contract/test_message_contracts.py index bc5bece1..6b7f82e7 100644 --- a/tests/contract/test_message_contracts.py +++ b/tests/contract/test_message_contracts.py @@ -212,7 +212,7 @@ class TestAgentMessageContracts: # Test required fields response = AgentResponse(**response_data) - assert hasattr(response, 'chunk_type') + assert hasattr(response, 'message_type') assert hasattr(response, 'content') assert hasattr(response, 'end_of_message') assert hasattr(response, 'end_of_dialog') diff --git a/tests/contract/test_schema_field_contracts.py b/tests/contract/test_schema_field_contracts.py new file mode 100644 index 00000000..4b7c3da5 --- /dev/null +++ b/tests/contract/test_schema_field_contracts.py @@ -0,0 +1,73 @@ +""" +Contract tests for schema dataclass field sets. + +These pin the *field names* of small, widely-constructed schema dataclasses +so that any rename, removal, or accidental addition fails CI loudly instead +of waiting for a runtime TypeError on the next websocket message. + +Background: in v2.2 the `Metadata` dataclass dropped a `metadata: list[Triple]` +field but several call sites kept passing `Metadata(metadata=...)`. The bug +was only discovered when a websocket import dispatcher received its first +real message in production. A trivial structural assertion of the kind +below would have caught it at unit-test time. 
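+
+For illustration (hypothetical session; the exact wording varies by
+Python version), the failure mode looked roughly like:
+
+    >>> Metadata(metadata=[])
+    TypeError: __init__() got an unexpected keyword argument 'metadata'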
+ +Add to this file whenever a schema rename burns you. The cost of a frozen +field set is a one-line update when you intentionally evolve the schema; the +benefit is that every call site is forced to come along for the ride. +""" + +import dataclasses +import pytest + +from trustgraph.schema import ( + Metadata, + EntityContext, + EntityEmbeddings, + ChunkEmbeddings, +) + + +def _field_names(dc): + return {f.name for f in dataclasses.fields(dc)} + + +@pytest.mark.contract +class TestSchemaFieldContracts: + """Pin the field set of dataclasses that get constructed all over the + codebase. If you intentionally change one of these, update the + expected set in the same commit — that diff will surface every call + site that needs to come along.""" + + def test_metadata_fields(self): + # NOTE: there is no `metadata` field. A previous regression + # constructed Metadata(metadata=...) and crashed at runtime. + assert _field_names(Metadata) == { + "id", + "root", + "user", + "collection", + } + + def test_entity_embeddings_fields(self): + # NOTE: the embedding field is `vector` (singular, list[float]). + # There is no `vectors` field. Several call sites historically + # passed `vectors=` and crashed at runtime. + assert _field_names(EntityEmbeddings) == { + "entity", + "vector", + "chunk_id", + } + + def test_chunk_embeddings_fields(self): + # Same `vector` (singular) convention as EntityEmbeddings. + assert _field_names(ChunkEmbeddings) == { + "chunk_id", + "vector", + } + + def test_entity_context_fields(self): + assert _field_names(EntityContext) == { + "entity", + "context", + "chunk_id", + } diff --git a/tests/contract/test_translator_completion_flags.py b/tests/contract/test_translator_completion_flags.py index 91ce1b77..606061f9 100644 --- a/tests/contract/test_translator_completion_flags.py +++ b/tests/contract/test_translator_completion_flags.py @@ -188,7 +188,7 @@ class TestAgentTranslatorCompletionFlags: # Arrange translator = TranslatorRegistry.get_response_translator("agent") response = AgentResponse( - chunk_type="answer", + message_type="answer", content="4", end_of_message=True, end_of_dialog=True, @@ -210,7 +210,7 @@ class TestAgentTranslatorCompletionFlags: # Arrange translator = TranslatorRegistry.get_response_translator("agent") response = AgentResponse( - chunk_type="thought", + message_type="thought", content="I need to solve this.", end_of_message=True, end_of_dialog=False, @@ -233,7 +233,7 @@ class TestAgentTranslatorCompletionFlags: # Test thought message thought_response = AgentResponse( - chunk_type="thought", + message_type="thought", content="Processing...", end_of_message=True, end_of_dialog=False, @@ -247,7 +247,7 @@ class TestAgentTranslatorCompletionFlags: # Test observation message observation_response = AgentResponse( - chunk_type="observation", + message_type="observation", content="Result found", end_of_message=True, end_of_dialog=False, @@ -268,7 +268,7 @@ class TestAgentTranslatorCompletionFlags: # Streaming format with end_of_dialog=True response = AgentResponse( - chunk_type="answer", + message_type="answer", content="", end_of_message=True, end_of_dialog=True, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7e18f0de..44a9f127 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -418,55 +418,55 @@ def sample_streaming_agent_response(): """Sample streaming agent response chunks""" return [ { - "chunk_type": "thought", + "message_type": "thought", "content": "I need to search", "end_of_message": 
False, "end_of_dialog": False }, { - "chunk_type": "thought", + "message_type": "thought", "content": " for information", "end_of_message": False, "end_of_dialog": False }, { - "chunk_type": "thought", + "message_type": "thought", "content": " about machine learning.", "end_of_message": True, "end_of_dialog": False }, { - "chunk_type": "action", + "message_type": "action", "content": "knowledge_query", "end_of_message": True, "end_of_dialog": False }, { - "chunk_type": "observation", + "message_type": "observation", "content": "Machine learning is", "end_of_message": False, "end_of_dialog": False }, { - "chunk_type": "observation", + "message_type": "observation", "content": " a subset of AI.", "end_of_message": True, "end_of_dialog": False }, { - "chunk_type": "final-answer", + "message_type": "final-answer", "content": "Machine learning", "end_of_message": False, "end_of_dialog": False }, { - "chunk_type": "final-answer", + "message_type": "final-answer", "content": " is a subset", "end_of_message": False, "end_of_dialog": False }, { - "chunk_type": "final-answer", + "message_type": "final-answer", "content": " of artificial intelligence.", "end_of_message": True, "end_of_dialog": True @@ -494,10 +494,10 @@ def streaming_chunk_collector(): """Concatenate all chunk content""" return "".join(self.chunks) - def get_chunk_types(self): + def get_message_types(self): """Get list of chunk types if chunks are dicts""" if self.chunks and isinstance(self.chunks[0], dict): - return [c.get("chunk_type") for c in self.chunks] + return [c.get("message_type") for c in self.chunks] return [] def verify_streaming_protocol(self): diff --git a/tests/integration/test_agent_manager_integration.py b/tests/integration/test_agent_manager_integration.py index 652894a2..743ab4d2 100644 --- a/tests/integration/test_agent_manager_integration.py +++ b/tests/integration/test_agent_manager_integration.py @@ -15,6 +15,7 @@ from trustgraph.agent.react.agent_manager import AgentManager from trustgraph.agent.react.tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl from trustgraph.agent.react.types import Action, Final, Tool, Argument from trustgraph.schema import AgentRequest, AgentResponse, AgentStep, Error +from trustgraph.base import PromptResult @pytest.mark.integration @@ -28,19 +29,25 @@ class TestAgentManagerIntegration: # Mock prompt client prompt_client = AsyncMock() - prompt_client.agent_react.return_value = """Thought: I need to search for information about machine learning + prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to search for information about machine learning Action: knowledge_query Args: { "question": "What is machine learning?" }""" - + ) + # Mock graph RAG client graph_rag_client = AsyncMock() graph_rag_client.rag.return_value = "Machine learning is a subset of AI that enables computers to learn from data." - + # Mock text completion client text_completion_client = AsyncMock() - text_completion_client.question.return_value = "Machine learning involves algorithms that improve through experience." + text_completion_client.question.return_value = PromptResult( + response_type="text", + text="Machine learning involves algorithms that improve through experience." 
+ ) # Mock MCP tool client mcp_tool_client = AsyncMock() @@ -147,8 +154,11 @@ Args: { async def test_agent_manager_final_answer(self, agent_manager, mock_flow_context): """Test agent manager returning final answer""" # Arrange - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have enough information to answer the question + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I have enough information to answer the question Final Answer: Machine learning is a field of AI that enables computers to learn from data.""" + ) question = "What is machine learning?" history = [] @@ -193,8 +203,11 @@ Final Answer: Machine learning is a field of AI that enables computers to learn async def test_agent_manager_react_with_final_answer(self, agent_manager, mock_flow_context): """Test ReAct cycle ending with final answer""" # Arrange - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide a direct answer + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I can provide a direct answer Final Answer: Machine learning is a branch of artificial intelligence.""" + ) question = "What is machine learning?" history = [] @@ -254,11 +267,14 @@ Final Answer: Machine learning is a branch of artificial intelligence.""" for tool_name, expected_service in tool_scenarios: # Arrange - mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: I need to use {tool_name} + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text=f"""Thought: I need to use {tool_name} Action: {tool_name} Args: {{ "question": "test question" }}""" + ) think_callback = AsyncMock() observe_callback = AsyncMock() @@ -284,11 +300,14 @@ Args: {{ async def test_agent_manager_unknown_tool_error(self, agent_manager, mock_flow_context): """Test agent manager error handling for unknown tool""" # Arrange - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to use an unknown tool + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to use an unknown tool Action: unknown_tool Args: { "param": "value" }""" + ) think_callback = AsyncMock() observe_callback = AsyncMock() @@ -308,11 +327,13 @@ Args: { think_callback = AsyncMock() observe_callback = AsyncMock() - # Act & Assert - with pytest.raises(Exception) as exc_info: - await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context) - - assert "Tool execution failed" in str(exc_info.value) + # Act - tool errors are now caught and returned as observations + result = await agent_manager.react("test question", [], think_callback, observe_callback, mock_flow_context) + + # Assert - error captured on the action, not raised + assert result.tool_error is not None + assert "Tool execution failed" in result.tool_error + assert "Error:" in result.observation @pytest.mark.asyncio async def test_agent_manager_multiple_tools_coordination(self, agent_manager, mock_flow_context): @@ -321,11 +342,14 @@ Args: { question = "Find information about AI and summarize it" # Mock multi-step reasoning - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search for AI information first + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + 
text="""Thought: I need to search for AI information first Action: knowledge_query Args: { "question": "What is artificial intelligence?" }""" + ) # Act action = await agent_manager.reason(question, [], mock_flow_context) @@ -372,9 +396,12 @@ Args: { # Format arguments as JSON import json args_json = json.dumps(test_case['arguments'], indent=4) - mock_flow_context("prompt-request").agent_react.return_value = f"""Thought: Using {test_case['action']} + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text=f"""Thought: Using {test_case['action']} Action: {test_case['action']} Args: {args_json}""" + ) think_callback = AsyncMock() observe_callback = AsyncMock() @@ -507,15 +534,17 @@ Args: { ] for test_case in test_cases: - mock_flow_context("prompt-request").agent_react.return_value = test_case["response"] + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text=test_case["response"] + ) if test_case["error_contains"]: - # Should raise an error - with pytest.raises(RuntimeError) as exc_info: - await agent_manager.reason("test question", [], mock_flow_context) - - assert "Failed to parse agent response" in str(exc_info.value) - assert test_case["error_contains"] in str(exc_info.value) + # Parse errors now return an Action with tool_error + result = await agent_manager.reason("test question", [], mock_flow_context) + assert isinstance(result, Action) + assert result.name == "__parse_error__" + assert result.tool_error is not None else: # Should succeed action = await agent_manager.reason("test question", [], mock_flow_context) @@ -527,13 +556,16 @@ Args: { async def test_agent_manager_text_parsing_edge_cases(self, agent_manager, mock_flow_context): """Test edge cases in text parsing""" # Test response with markdown code blocks - mock_flow_context("prompt-request").agent_react.return_value = """``` + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""``` Thought: I need to search for information Action: knowledge_query Args: { "question": "What is AI?" } ```""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Action) @@ -541,15 +573,18 @@ Args: { assert action.name == "knowledge_query" # Test response with extra whitespace - mock_flow_context("prompt-request").agent_react.return_value = """ + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text=""" -Thought: I need to think about this -Action: knowledge_query +Thought: I need to think about this +Action: knowledge_query Args: { "question": "test" } """ + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Action) @@ -560,7 +595,9 @@ Args: { async def test_agent_manager_multiline_content(self, agent_manager, mock_flow_context): """Test handling of multi-line thoughts and final answers""" # Multi-line thought - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to consider multiple factors: + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to consider multiple factors: 1. The user's question is complex 2. I should search for comprehensive information 3. 
This requires using the knowledge query tool @@ -568,6 +605,7 @@ Action: knowledge_query Args: { "question": "complex query" }""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Action) @@ -575,13 +613,16 @@ Args: { assert "knowledge query tool" in action.thought # Multi-line final answer - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I have gathered enough information + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I have gathered enough information Final Answer: Here is a comprehensive answer: 1. First point about the topic 2. Second point with details 3. Final conclusion This covers all aspects of the question.""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Final) @@ -593,13 +634,16 @@ This covers all aspects of the question.""" async def test_agent_manager_json_args_special_characters(self, agent_manager, mock_flow_context): """Test JSON arguments with special characters and edge cases""" # Test with special characters in JSON (properly escaped) - mock_flow_context("prompt-request").agent_react.return_value = """Thought: Processing special characters + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: Processing special characters Action: knowledge_query Args: { "question": "What about \\"quotes\\" and 'apostrophes'?", "context": "Line 1\\nLine 2\\tTabbed", "special": "Symbols: @#$%^&*()_+-=[]{}|;':,.<>?" }""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Action) @@ -608,7 +652,9 @@ Args: { assert "@#$%^&*" in action.arguments["special"] # Test with nested JSON - mock_flow_context("prompt-request").agent_react.return_value = """Thought: Complex arguments + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: Complex arguments Action: web_search Args: { "query": "test", @@ -621,6 +667,7 @@ Args: { } } }""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Action) @@ -632,7 +679,9 @@ Args: { async def test_agent_manager_final_answer_json_format(self, agent_manager, mock_flow_context): """Test final answers that contain JSON-like content""" # Final answer with JSON content - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I can provide the data in JSON format + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I can provide the data in JSON format Final Answer: { "result": "success", "data": { @@ -642,6 +691,7 @@ Final Answer: { }, "confidence": 0.95 }""" + ) action = await agent_manager.reason("test", [], mock_flow_context) assert isinstance(action, Final) @@ -792,11 +842,14 @@ Final Answer: { agent = AgentManager(tools=custom_tools, additional_context="") # Mock response for custom collection query - mock_flow_context("prompt-request").agent_react.return_value = """Thought: I need to search in the research papers + mock_flow_context("prompt-request").agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to search in the research papers Action: knowledge_query_custom Args: { "question": "Latest AI research?" 
}""" + ) think_callback = AsyncMock() observe_callback = AsyncMock() diff --git a/tests/integration/test_agent_streaming_integration.py b/tests/integration/test_agent_streaming_integration.py index d6004c21..de7372f1 100644 --- a/tests/integration/test_agent_streaming_integration.py +++ b/tests/integration/test_agent_streaming_integration.py @@ -10,11 +10,12 @@ from unittest.mock import AsyncMock, MagicMock from trustgraph.agent.react.agent_manager import AgentManager from trustgraph.agent.react.tools import KnowledgeQueryImpl from trustgraph.agent.react.types import Tool, Argument +from trustgraph.base import PromptResult from tests.utils.streaming_assertions import ( assert_agent_streaming_chunks, assert_streaming_chunks_valid, assert_callback_invoked, - assert_chunk_types_valid, + assert_message_types_valid, ) @@ -51,10 +52,10 @@ Args: { is_final = (i == len(chunks) - 1) await chunk_callback(chunk, is_final) - return full_text + return PromptResult(response_type="text", text=full_text) else: # Non-streaming response - same text - return full_text + return PromptResult(response_type="text", text=full_text) client.agent_react.side_effect = agent_react_streaming return client @@ -317,8 +318,8 @@ Final Answer: AI is the simulation of human intelligence in machines.""" for i, chunk in enumerate(chunks): is_final = (i == len(chunks) - 1) await chunk_callback(chunk + " ", is_final) - return response - return response + return PromptResult(response_type="text", text=response) + return PromptResult(response_type="text", text=response) mock_prompt_client_streaming.agent_react.side_effect = multi_step_agent_react diff --git a/tests/integration/test_agent_structured_query_integration.py b/tests/integration/test_agent_structured_query_integration.py index 0fedd2b5..2442bf10 100644 --- a/tests/integration/test_agent_structured_query_integration.py +++ b/tests/integration/test_agent_structured_query_integration.py @@ -16,6 +16,7 @@ from trustgraph.schema import ( Error ) from trustgraph.agent.react.service import Processor +from trustgraph.base import PromptResult @pytest.mark.integration @@ -95,11 +96,14 @@ class TestAgentStructuredQueryIntegration: # Mock the prompt client that agent calls for reasoning mock_prompt_client = AsyncMock() - mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from New York using structured query + mock_prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to find customers from New York using structured query Action: structured-query Args: { "question": "Find all customers from New York" }""" + ) # Set up flow context routing def flow_context(service_name): @@ -173,11 +177,14 @@ Args: { # Mock the prompt client that agent calls for reasoning mock_prompt_client = AsyncMock() - mock_prompt_client.agent_react.return_value = """Thought: I need to query for a table that might not exist + mock_prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to query for a table that might not exist Action: structured-query Args: { "question": "Find data from a table that doesn't exist" }""" + ) # Set up flow context routing def flow_context(service_name): @@ -250,11 +257,14 @@ Args: { # Mock the prompt client that agent calls for reasoning mock_prompt_client = AsyncMock() - mock_prompt_client.agent_react.return_value = """Thought: I need to find customers from California first + mock_prompt_client.agent_react.return_value = PromptResult( + 
response_type="text", + text="""Thought: I need to find customers from California first Action: structured-query Args: { "question": "Find all customers from California" }""" + ) # Set up flow context routing def flow_context(service_name): @@ -339,11 +349,14 @@ Args: { # Mock the prompt client that agent calls for reasoning mock_prompt_client = AsyncMock() - mock_prompt_client.agent_react.return_value = """Thought: I need to query the sales data + mock_prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to query the sales data Action: structured-query Args: { "question": "Query the sales data for recent transactions" }""" + ) # Set up flow context routing def flow_context(service_name): @@ -447,11 +460,14 @@ Args: { # Mock the prompt client that agent calls for reasoning mock_prompt_client = AsyncMock() - mock_prompt_client.agent_react.return_value = """Thought: I need to get customer information + mock_prompt_client.agent_react.return_value = PromptResult( + response_type="text", + text="""Thought: I need to get customer information Action: structured-query Args: { "question": "Get customer information and format it nicely" }""" + ) # Set up flow context routing def flow_context(service_name): diff --git a/tests/integration/test_document_rag_integration.py b/tests/integration/test_document_rag_integration.py index e9df05cf..8c165385 100644 --- a/tests/integration/test_document_rag_integration.py +++ b/tests/integration/test_document_rag_integration.py @@ -10,6 +10,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.schema import ChunkMatch +from trustgraph.base import PromptResult # Sample chunk content for testing - maps chunk_id to content @@ -61,11 +62,16 @@ class TestDocumentRagIntegration: def mock_prompt_client(self): """Mock prompt client that generates realistic responses""" client = AsyncMock() - client.document_prompt.return_value = ( - "Machine learning is a field of artificial intelligence that enables computers to learn " - "and improve from experience without being explicitly programmed. It uses algorithms " - "to find patterns in data and make predictions or decisions." + client.document_prompt.return_value = PromptResult( + response_type="text", + text=( + "Machine learning is a field of artificial intelligence that enables computers to learn " + "and improve from experience without being explicitly programmed. It uses algorithms " + "to find patterns in data and make predictions or decisions." + ) ) + # Mock prompt() for extract-concepts call in DocumentRag + client.prompt.return_value = PromptResult(response_type="text", text="") return client @pytest.fixture @@ -119,6 +125,7 @@ class TestDocumentRagIntegration: ) # Verify final response + result, usage = result assert result is not None assert isinstance(result, str) assert "machine learning" in result.lower() @@ -131,7 +138,11 @@ class TestDocumentRagIntegration: """Test DocumentRAG behavior when no documents are retrieved""" # Arrange mock_doc_embeddings_client.query.return_value = [] # No chunk_ids found - mock_prompt_client.document_prompt.return_value = "I couldn't find any relevant documents for your query." + mock_prompt_client.document_prompt.return_value = PromptResult( + response_type="text", + text="I couldn't find any relevant documents for your query." 
+ ) + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="") document_rag = DocumentRag( embeddings_client=mock_embeddings_client, @@ -152,7 +163,8 @@ class TestDocumentRagIntegration: documents=[] ) - assert result == "I couldn't find any relevant documents for your query." + result_text, usage = result + assert result_text == "I couldn't find any relevant documents for your query." @pytest.mark.asyncio async def test_document_rag_embeddings_service_failure(self, mock_embeddings_client, diff --git a/tests/integration/test_document_rag_streaming_integration.py b/tests/integration/test_document_rag_streaming_integration.py index dad30a8f..e2c032ad 100644 --- a/tests/integration/test_document_rag_streaming_integration.py +++ b/tests/integration/test_document_rag_streaming_integration.py @@ -9,6 +9,7 @@ import pytest from unittest.mock import AsyncMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag from trustgraph.schema import ChunkMatch +from trustgraph.base import PromptResult from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_callback_invoked, @@ -74,12 +75,14 @@ class TestDocumentRagStreaming: is_final = (i == len(chunks) - 1) await chunk_callback(chunk, is_final) - return full_text + return PromptResult(response_type="text", text=full_text) else: # Non-streaming response - same text - return full_text + return PromptResult(response_type="text", text=full_text) client.document_prompt.side_effect = document_prompt_side_effect + # Mock prompt() for extract-concepts call in DocumentRag + client.prompt.return_value = PromptResult(response_type="text", text="") return client @pytest.fixture @@ -119,11 +122,12 @@ class TestDocumentRagStreaming: collector.verify_streaming_protocol() # Verify full response matches concatenated chunks + result_text, usage = result full_from_chunks = collector.get_full_text() - assert result == full_from_chunks + assert result_text == full_from_chunks # Verify content is reasonable - assert len(result) > 0 + assert len(result_text) > 0 @pytest.mark.asyncio async def test_document_rag_streaming_vs_non_streaming(self, document_rag_streaming): @@ -159,9 +163,11 @@ class TestDocumentRagStreaming: ) # Assert - Results should be equivalent - assert streaming_result == non_streaming_result + non_streaming_text, _ = non_streaming_result + streaming_text, _ = streaming_result + assert streaming_text == non_streaming_text assert len(streaming_chunks) > 0 - assert "".join(streaming_chunks) == streaming_result + assert "".join(streaming_chunks) == streaming_text @pytest.mark.asyncio async def test_document_rag_streaming_callback_invocation(self, document_rag_streaming): @@ -180,8 +186,9 @@ class TestDocumentRagStreaming: ) # Assert + result_text, usage = result assert callback.call_count > 0 - assert result is not None + assert result_text is not None # Verify all callback invocations had string arguments for call in callback.call_args_list: @@ -202,7 +209,8 @@ class TestDocumentRagStreaming: # Assert - Should complete without error assert result is not None - assert isinstance(result, str) + result_text, usage = result + assert isinstance(result_text, str) @pytest.mark.asyncio async def test_document_rag_streaming_with_no_documents(self, document_rag_streaming, @@ -223,7 +231,8 @@ class TestDocumentRagStreaming: ) # Assert - Should still produce streamed response - assert result is not None + result_text, usage = result + assert result_text is not None assert callback.call_count > 0 
@pytest.mark.asyncio @@ -271,7 +280,8 @@ class TestDocumentRagStreaming: ) # Assert - assert result is not None + result_text, usage = result + assert result_text is not None assert callback.call_count > 0 # Verify doc_limit was passed correctly diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 5e3279e3..9c3cdf45 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -12,6 +12,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.schema import EntityMatch, Term, IRI +from trustgraph.base import PromptResult @pytest.mark.integration @@ -93,18 +94,21 @@ class TestGraphRagIntegration: # 4. kg-synthesis returns the final answer async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "extract-concepts": - return "" # Falls back to raw query + return PromptResult(response_type="text", text="") elif prompt_name == "kg-edge-scoring": - return "" # No edges scored + return PromptResult(response_type="text", text="") elif prompt_name == "kg-edge-reasoning": - return "" # No reasoning + return PromptResult(response_type="text", text="") elif prompt_name == "kg-synthesis": - return ( - "Machine learning is a subset of artificial intelligence that enables computers " - "to learn from data without being explicitly programmed. It uses algorithms " - "and statistical models to find patterns in data." + return PromptResult( + response_type="text", + text=( + "Machine learning is a subset of artificial intelligence that enables computers " + "to learn from data without being explicitly programmed. It uses algorithms " + "and statistical models to find patterns in data." 
+ ) ) - return "" + return PromptResult(response_type="text", text="") client.prompt.side_effect = mock_prompt return client @@ -169,6 +173,7 @@ class TestGraphRagIntegration: assert mock_prompt_client.prompt.call_count == 4 # Verify final response + response, usage = response assert response is not None assert isinstance(response, str) assert "machine learning" in response.lower() diff --git a/tests/integration/test_graph_rag_streaming_integration.py b/tests/integration/test_graph_rag_streaming_integration.py index b66c5289..95c494bb 100644 --- a/tests/integration/test_graph_rag_streaming_integration.py +++ b/tests/integration/test_graph_rag_streaming_integration.py @@ -9,6 +9,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag from trustgraph.schema import EntityMatch, Term, IRI +from trustgraph.base import PromptResult from tests.utils.streaming_assertions import ( assert_streaming_chunks_valid, assert_rag_streaming_chunks, @@ -61,12 +62,12 @@ class TestGraphRagStreaming: async def prompt_side_effect(prompt_id, variables, streaming=False, chunk_callback=None, **kwargs): if prompt_id == "extract-concepts": - return "" # Falls back to raw query + return PromptResult(response_type="text", text="") elif prompt_id == "kg-edge-scoring": # Edge scoring returns JSONL with IDs and scores - return '{"id": "abc12345", "score": 0.9}\n' + return PromptResult(response_type="text", text='{"id": "abc12345", "score": 0.9}\n') elif prompt_id == "kg-edge-reasoning": - return '{"id": "abc12345", "reasoning": "Relevant to query"}\n' + return PromptResult(response_type="text", text='{"id": "abc12345", "reasoning": "Relevant to query"}\n') elif prompt_id == "kg-synthesis": if streaming and chunk_callback: # Simulate streaming chunks with end_of_stream flags @@ -79,10 +80,10 @@ class TestGraphRagStreaming: is_final = (i == len(chunks) - 1) await chunk_callback(chunk, is_final) - return full_text + return PromptResult(response_type="text", text=full_text) else: - return full_text - return "" + return PromptResult(response_type="text", text=full_text) + return PromptResult(response_type="text", text="") client.prompt.side_effect = prompt_side_effect return client @@ -123,6 +124,7 @@ class TestGraphRagStreaming: ) # Assert + response, usage = response assert_streaming_chunks_valid(collector.chunks, min_chunks=1) assert_callback_invoked(AsyncMock(call_count=len(collector.chunks)), min_calls=1) @@ -172,9 +174,11 @@ class TestGraphRagStreaming: ) # Assert - Results should be equivalent - assert streaming_response == non_streaming_response + non_streaming_text, _ = non_streaming_response + streaming_text, _ = streaming_response + assert streaming_text == non_streaming_text assert len(streaming_chunks) > 0 - assert "".join(streaming_chunks) == streaming_response + assert "".join(streaming_chunks) == streaming_text @pytest.mark.asyncio async def test_graph_rag_streaming_callback_invocation(self, graph_rag_streaming): @@ -213,7 +217,8 @@ class TestGraphRagStreaming: # Assert - Should complete without error assert response is not None - assert isinstance(response, str) + response_text, usage = response + assert isinstance(response_text, str) @pytest.mark.asyncio async def test_graph_rag_streaming_with_empty_kg(self, graph_rag_streaming, diff --git a/tests/integration/test_kg_extract_store_integration.py b/tests/integration/test_kg_extract_store_integration.py index 4d8b60ad..84c0905d 100644 --- 
a/tests/integration/test_kg_extract_store_integration.py +++ b/tests/integration/test_kg_extract_store_integration.py @@ -18,6 +18,7 @@ from trustgraph.storage.knowledge.store import Processor as KnowledgeStoreProces from trustgraph.schema import Chunk, Triple, Triples, Metadata, Term, Error, IRI, LITERAL from trustgraph.schema import EntityContext, EntityContexts, GraphEmbeddings, EntityEmbeddings from trustgraph.rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL +from trustgraph.base import PromptResult @pytest.mark.integration @@ -31,32 +32,38 @@ class TestKnowledgeGraphPipelineIntegration: # Mock prompt client for definitions extraction prompt_client = AsyncMock() - prompt_client.extract_definitions.return_value = [ - { - "entity": "Machine Learning", - "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming." - }, - { - "entity": "Neural Networks", - "definition": "Computing systems inspired by biological neural networks that process information." - } - ] - + prompt_client.extract_definitions.return_value = PromptResult( + response_type="jsonl", + objects=[ + { + "entity": "Machine Learning", + "definition": "A subset of artificial intelligence that enables computers to learn from data without explicit programming." + }, + { + "entity": "Neural Networks", + "definition": "Computing systems inspired by biological neural networks that process information." + } + ] + ) + # Mock prompt client for relationships extraction - prompt_client.extract_relationships.return_value = [ - { - "subject": "Machine Learning", - "predicate": "is_subset_of", - "object": "Artificial Intelligence", - "object-entity": True - }, - { - "subject": "Neural Networks", - "predicate": "is_used_in", - "object": "Machine Learning", - "object-entity": True - } - ] + prompt_client.extract_relationships.return_value = PromptResult( + response_type="jsonl", + objects=[ + { + "subject": "Machine Learning", + "predicate": "is_subset_of", + "object": "Artificial Intelligence", + "object-entity": True + }, + { + "subject": "Neural Networks", + "predicate": "is_used_in", + "object": "Machine Learning", + "object-entity": True + } + ] + ) # Mock producers for output streams triples_producer = AsyncMock() @@ -489,7 +496,10 @@ class TestKnowledgeGraphPipelineIntegration: async def test_empty_extraction_results_handling(self, definitions_processor, mock_flow_context, sample_chunk): """Test handling of empty extraction results""" # Arrange - mock_flow_context("prompt-request").extract_definitions.return_value = [] + mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult( + response_type="jsonl", + objects=[] + ) mock_msg = MagicMock() mock_msg.value.return_value = sample_chunk @@ -510,7 +520,10 @@ class TestKnowledgeGraphPipelineIntegration: async def test_invalid_extraction_format_handling(self, definitions_processor, mock_flow_context, sample_chunk): """Test handling of invalid extraction response format""" # Arrange - mock_flow_context("prompt-request").extract_definitions.return_value = "invalid format" # Should be list + mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult( + response_type="text", + text="invalid format" + ) # Should be jsonl with objects list mock_msg = MagicMock() mock_msg.value.return_value = sample_chunk @@ -528,13 +541,16 @@ class TestKnowledgeGraphPipelineIntegration: async def test_entity_filtering_and_validation(self, definitions_processor, mock_flow_context): """Test entity 
filtering and validation in extraction""" # Arrange - mock_flow_context("prompt-request").extract_definitions.return_value = [ - {"entity": "Valid Entity", "definition": "Valid definition"}, - {"entity": "", "definition": "Empty entity"}, # Should be filtered - {"entity": "Valid Entity 2", "definition": ""}, # Should be filtered - {"entity": None, "definition": "None entity"}, # Should be filtered - {"entity": "Valid Entity 3", "definition": None}, # Should be filtered - ] + mock_flow_context("prompt-request").extract_definitions.return_value = PromptResult( + response_type="jsonl", + objects=[ + {"entity": "Valid Entity", "definition": "Valid definition"}, + {"entity": "", "definition": "Empty entity"}, # Should be filtered + {"entity": "Valid Entity 2", "definition": ""}, # Should be filtered + {"entity": None, "definition": "None entity"}, # Should be filtered + {"entity": "Valid Entity 3", "definition": None}, # Should be filtered + ] + ) sample_chunk = Chunk( metadata=Metadata(id="test", user="user", collection="collection"), diff --git a/tests/integration/test_object_extraction_integration.py b/tests/integration/test_object_extraction_integration.py index faa63381..22ba9a3f 100644 --- a/tests/integration/test_object_extraction_integration.py +++ b/tests/integration/test_object_extraction_integration.py @@ -16,6 +16,7 @@ from trustgraph.schema import ( Chunk, ExtractedObject, Metadata, RowSchema, Field, PromptRequest, PromptResponse ) +from trustgraph.base import PromptResult @pytest.mark.integration @@ -114,49 +115,61 @@ class TestObjectExtractionServiceIntegration: schema_name = schema.get("name") if isinstance(schema, dict) else schema.name if schema_name == "customer_records": if "john" in text.lower(): - return [ - { - "customer_id": "CUST001", - "name": "John Smith", - "email": "john.smith@email.com", - "phone": "555-0123" - } - ] + return PromptResult( + response_type="jsonl", + objects=[ + { + "customer_id": "CUST001", + "name": "John Smith", + "email": "john.smith@email.com", + "phone": "555-0123" + } + ] + ) elif "jane" in text.lower(): - return [ - { - "customer_id": "CUST002", - "name": "Jane Doe", - "email": "jane.doe@email.com", - "phone": "" - } - ] + return PromptResult( + response_type="jsonl", + objects=[ + { + "customer_id": "CUST002", + "name": "Jane Doe", + "email": "jane.doe@email.com", + "phone": "" + } + ] + ) else: - return [] - + return PromptResult(response_type="jsonl", objects=[]) + elif schema_name == "product_catalog": if "laptop" in text.lower(): - return [ - { - "product_id": "PROD001", - "name": "Gaming Laptop", - "price": "1299.99", - "category": "electronics" - } - ] + return PromptResult( + response_type="jsonl", + objects=[ + { + "product_id": "PROD001", + "name": "Gaming Laptop", + "price": "1299.99", + "category": "electronics" + } + ] + ) elif "book" in text.lower(): - return [ - { - "product_id": "PROD002", - "name": "Python Programming Guide", - "price": "49.99", - "category": "books" - } - ] + return PromptResult( + response_type="jsonl", + objects=[ + { + "product_id": "PROD002", + "name": "Python Programming Guide", + "price": "49.99", + "category": "books" + } + ] + ) else: - return [] - - return [] + return PromptResult(response_type="jsonl", objects=[]) + + return PromptResult(response_type="jsonl", objects=[]) prompt_client.extract_objects.side_effect = mock_extract_objects diff --git a/tests/integration/test_prompt_streaming_integration.py b/tests/integration/test_prompt_streaming_integration.py index 9b1a06b6..a1414e2d 100644 --- 
diff --git a/tests/integration/test_prompt_streaming_integration.py b/tests/integration/test_prompt_streaming_integration.py
index 9b1a06b6..a1414e2d 100644
--- a/tests/integration/test_prompt_streaming_integration.py
+++ b/tests/integration/test_prompt_streaming_integration.py
@@ -9,6 +9,7 @@ import pytest
 from unittest.mock import AsyncMock, MagicMock
 from trustgraph.prompt.template.service import Processor
 from trustgraph.schema import PromptRequest, PromptResponse, TextCompletionResponse
+from trustgraph.base.text_completion_client import TextCompletionResult
 from tests.utils.streaming_assertions import (
     assert_streaming_chunks_valid,
     assert_callback_invoked,
@@ -27,34 +28,52 @@ class TestPromptStreaming:
 
         # Mock text completion client with streaming
         text_completion_client = AsyncMock()
 
-        async def streaming_request(request, recipient=None, timeout=600):
-            """Simulate streaming text completion"""
-            if request.streaming and recipient:
-                # Simulate streaming chunks
-                chunks = [
-                    "Machine", " learning", " is", " a", " field",
-                    " of", " artificial", " intelligence", "."
-                ]
+        # Streaming chunks to send
+        chunks = [
+            "Machine", " learning", " is", " a", " field",
+            " of", " artificial", " intelligence", "."
+        ]
 
-                for i, chunk_text in enumerate(chunks):
-                    is_final = (i == len(chunks) - 1)
-                    response = TextCompletionResponse(
-                        response=chunk_text,
-                        error=None,
-                        end_of_stream=is_final
-                    )
-                    final = await recipient(response)
-                    if final:
-                        break
-
-                # Final empty chunk
-                await recipient(TextCompletionResponse(
-                    response="",
+        async def streaming_text_completion_stream(system, prompt, handler, timeout=600):
+            """Simulate streaming text completion via text_completion_stream"""
+            for i, chunk_text in enumerate(chunks):
+                response = TextCompletionResponse(
+                    response=chunk_text,
                     error=None,
-                    end_of_stream=True
-                ))
+                    end_of_stream=False
+                )
+                await handler(response)
 
-        text_completion_client.request = streaming_request
+            # Send final empty chunk with end_of_stream
+            await handler(TextCompletionResponse(
+                response="",
+                error=None,
+                end_of_stream=True
+            ))
+
+            return TextCompletionResult(
+                text=None,
+                in_token=10,
+                out_token=9,
+                model="test-model",
+            )
+
+        async def non_streaming_text_completion(system, prompt, timeout=600):
+            """Simulate non-streaming text completion"""
+            full_text = "Machine learning is a field of artificial intelligence."
+            return TextCompletionResult(
+                text=full_text,
+                in_token=10,
+                out_token=9,
+                model="test-model",
+            )
+
+        text_completion_client.text_completion_stream = AsyncMock(
+            side_effect=streaming_text_completion_stream
+        )
+        text_completion_client.text_completion = AsyncMock(
+            side_effect=non_streaming_text_completion
+        )
 
         # Mock response producer
         response_producer = AsyncMock()
@@ -156,14 +175,6 @@ class TestPromptStreaming:
 
         consumer = MagicMock()
 
-        # Mock non-streaming text completion
-        text_completion_client = mock_flow_context_streaming("text-completion-request")
-
-        async def non_streaming_text_completion(system, prompt, streaming=False):
-            return "AI is the simulation of human intelligence in machines."
-
-        text_completion_client.text_completion = non_streaming_text_completion
-
         # Act
         await prompt_processor_streaming.on_request(
             message, consumer, mock_flow_context_streaming
@@ -218,17 +229,12 @@ class TestPromptStreaming:
         # Mock text completion client that raises an error
         text_completion_client = AsyncMock()
 
-        async def failing_request(request, recipient=None, timeout=600):
-            if recipient:
-                # Send error response with proper Error schema
-                error_response = TextCompletionResponse(
-                    response="",
-                    error=Error(message="Text completion error", type="processing_error"),
-                    end_of_stream=True
-                )
-                await recipient(error_response)
+        async def failing_stream(system, prompt, handler, timeout=600):
+            raise RuntimeError("Text completion error")
 
-        text_completion_client.request = failing_request
+        text_completion_client.text_completion_stream = AsyncMock(
+            side_effect=failing_stream
+        )
 
         # Mock response producer to capture error response
         response_producer = AsyncMock()
@@ -255,22 +261,15 @@ class TestPromptStreaming:
 
         consumer = MagicMock()
 
-        # Act - The service catches errors and sends error responses, doesn't raise
+        # Act - The service catches errors and sends an error PromptResponse
         await prompt_processor_streaming.on_request(message, consumer, context)
 
-        # Assert - Verify error response was sent
-        assert response_producer.send.call_count > 0
-
-        # Check that at least one response contains an error
-        error_sent = False
-        for call in response_producer.send.call_args_list:
-            response = call.args[0]
-            if hasattr(response, 'error') and response.error:
-                error_sent = True
-                assert "Text completion error" in response.error.message
-                break
-
-        assert error_sent, "Expected error response to be sent"
+        # Assert - error response was sent
+        calls = response_producer.send.call_args_list
+        assert len(calls) > 0
+        error_response = calls[-1].args[0]
+        assert error_response.error is not None
+        assert "Text completion error" in error_response.error.message
 
     @pytest.mark.asyncio
     async def test_prompt_streaming_preserves_message_id(self, prompt_processor_streaming,
@@ -315,21 +314,22 @@ class TestPromptStreaming:
         # Mock text completion that sends empty chunks
         text_completion_client = AsyncMock()
 
-        async def empty_streaming_request(request, recipient=None, timeout=600):
-            if request.streaming and recipient:
-                # Send empty chunk followed by final marker
-                await recipient(TextCompletionResponse(
-                    response="",
-                    error=None,
-                    end_of_stream=False
-                ))
-                await recipient(TextCompletionResponse(
-                    response="",
-                    error=None,
-                    end_of_stream=True
-                ))
+        async def empty_streaming(system, prompt, handler, timeout=600):
+            # Send empty chunk followed by final marker
+            await handler(TextCompletionResponse(
+                response="",
+                error=None,
+                end_of_stream=False
+            ))
+            await handler(TextCompletionResponse(
+                response="",
+                error=None,
+                end_of_stream=True
+            ))
 
-        text_completion_client.request = empty_streaming_request
+        text_completion_client.text_completion_stream = AsyncMock(
+            side_effect=empty_streaming
+        )
 
         response_producer = AsyncMock()
 
         def context_router(service_name):
@@ -401,4 +401,4 @@ class TestPromptStreaming:
 
         # Verify chunks concatenate to expected result
         full_text = "".join(chunk_texts)
-        assert full_text == "Machine learning is a field of artificial intelligence"
+        assert full_text == "Machine learning is a field of artificial intelligence."
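The rework above moves the mocks from the transport-level request(request, recipient) API to the higher-level text_completion_stream(system, prompt, handler) contract: chunk payloads arrive through the handler callback with end_of_stream=False, a final empty chunk carries end_of_stream=True, and token accounting comes back in the TextCompletionResult return value. A sketch of a caller under that contract; collect_stream is an illustrative helper, not a TrustGraph API.

async def collect_stream(client, system, prompt):
    parts = []

    async def handler(resp):
        # Each TextCompletionResponse carries a chunk; errors abort the stream.
        if resp.error:
            raise RuntimeError(resp.error.message)
        if resp.response:
            parts.append(resp.response)

    result = await client.text_completion_stream(system, prompt, handler)
    return "".join(parts), result  # full text plus usage metadata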
diff --git a/tests/integration/test_rag_streaming_protocol.py b/tests/integration/test_rag_streaming_protocol.py
index f5fe14b5..83a90412 100644
--- a/tests/integration/test_rag_streaming_protocol.py
+++ b/tests/integration/test_rag_streaming_protocol.py
@@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, call
 from trustgraph.retrieval.graph_rag.graph_rag import GraphRag
 from trustgraph.retrieval.document_rag.document_rag import DocumentRag
 from trustgraph.schema import EntityMatch, ChunkMatch, Term, IRI
+from trustgraph.base import PromptResult
 
 
 class TestGraphRagStreamingProtocol:
@@ -46,8 +47,7 @@ class TestGraphRagStreamingProtocol:
 
         async def prompt_side_effect(prompt_name, variables=None, streaming=False, chunk_callback=None):
             if prompt_name == "kg-edge-selection":
-                # Edge selection returns empty (no edges selected)
-                return ""
+                return PromptResult(response_type="text", text="")
             elif prompt_name == "kg-synthesis":
                 if streaming and chunk_callback:
                     # Simulate realistic streaming: chunks with end_of_stream=False, then final with end_of_stream=True
@@ -55,10 +55,10 @@ class TestGraphRagStreamingProtocol:
                     await chunk_callback(" answer", False)
                     await chunk_callback(" is here.", False)
                     await chunk_callback("", True)  # Empty final chunk with end_of_stream=True
-                    return ""  # Return value not used since callback handles everything
+                    return PromptResult(response_type="text", text="")
                 else:
-                    return "The answer is here."
-            return ""
+                    return PromptResult(response_type="text", text="The answer is here.")
+            return PromptResult(response_type="text", text="")
 
         client.prompt.side_effect = prompt_side_effect
         return client
@@ -237,11 +237,13 @@ class TestDocumentRagStreamingProtocol:
                 await chunk_callback("Document", False)
                 await chunk_callback(" summary", False)
                 await chunk_callback(".", True)  # Non-empty final chunk
-                return ""
+                return PromptResult(response_type="text", text="")
             else:
-                return "Document summary."
+                return PromptResult(response_type="text", text="Document summary.")
 
         client.document_prompt.side_effect = document_prompt_side_effect
+        # Mock prompt() for extract-concepts call in DocumentRag
+        client.prompt.return_value = PromptResult(response_type="text", text="")
         return client
 
     @pytest.fixture
@@ -334,17 +336,17 @@ class TestStreamingProtocolEdgeCases:
 
         async def prompt_with_empties(prompt_name, variables=None, streaming=False, chunk_callback=None):
             if prompt_name == "kg-edge-selection":
-                return ""
+                return PromptResult(response_type="text", text="")
             elif prompt_name == "kg-synthesis":
                 if streaming and chunk_callback:
                     await chunk_callback("text", False)
                     await chunk_callback("", False)  # Empty but not final
                     await chunk_callback("more", False)
                     await chunk_callback("", True)  # Empty and final
-                    return ""
+                    return PromptResult(response_type="text", text="")
                 else:
-                    return "textmore"
-            return ""
+                    return PromptResult(response_type="text", text="textmore")
+            return PromptResult(response_type="text", text="")
 
         client.prompt.side_effect = prompt_with_empties
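The RAG mocks above simulate a two-argument callback protocol: chunk_callback(text, end_of_stream) is awaited once per chunk, the final call has end_of_stream=True, and that final chunk may legitimately be empty or carry text. A sketch of a consumer that reassembles the answer under those assumptions; collect_answer is illustrative, not a TrustGraph API.

import asyncio

async def collect_answer(prompt, name, variables):
    parts = []
    done = asyncio.Event()

    async def chunk_callback(text, end_of_stream):
        # Empty non-final chunks are allowed and simply contribute nothing.
        if text:
            parts.append(text)
        if end_of_stream:
            done.set()

    await prompt(name, variables=variables, streaming=True,
                 chunk_callback=chunk_callback)
    await done.wait()  # final chunk has already set this in the mocks above
    return "".join(parts)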
diff --git a/tests/pytest.ini b/tests/pytest.ini
index b032a9d4..8541bd8f 100644
--- a/tests/pytest.ini
+++ b/tests/pytest.ini
@@ -4,14 +4,11 @@ python_paths = .
 python_files = test_*.py
 python_classes = Test*
 python_functions = test_*
-addopts = 
+addopts =
     -v
     --tb=short
    --strict-markers
     --disable-warnings
-    --cov=trustgraph
-    --cov-report=html
-    --cov-report=term-missing
 #    --cov-fail-under=80
 asyncio_mode = auto
 markers =
diff --git a/tests/unit/test_agent/test_agent_service_non_streaming.py b/tests/unit/test_agent/test_agent_service_non_streaming.py
index 0b9b283a..bb58e5ee 100644
--- a/tests/unit/test_agent/test_agent_service_non_streaming.py
+++ b/tests/unit/test_agent/test_agent_service_non_streaming.py
@@ -78,10 +78,10 @@ class TestAgentServiceNonStreaming:
 
         # Filter out explain events — those are always sent now
         content_responses = [
-            r for r in sent_responses if r.chunk_type != "explain"
+            r for r in sent_responses if r.message_type != "explain"
         ]
         explain_responses = [
-            r for r in sent_responses if r.chunk_type == "explain"
+            r for r in sent_responses if r.message_type == "explain"
         ]
 
         # Should have explain events for session, iteration, observation, and final
@@ -93,7 +93,7 @@ class TestAgentServiceNonStreaming:
         # Check thought message
         thought_response = content_responses[0]
         assert isinstance(thought_response, AgentResponse)
-        assert thought_response.chunk_type == "thought"
+        assert thought_response.message_type == "thought"
         assert thought_response.content == "I need to solve this."
         assert thought_response.end_of_message is True, "Thought message must have end_of_message=True"
         assert thought_response.end_of_dialog is False, "Thought message must have end_of_dialog=False"
@@ -101,7 +101,7 @@ class TestAgentServiceNonStreaming:
         # Check observation message
         observation_response = content_responses[1]
         assert isinstance(observation_response, AgentResponse)
-        assert observation_response.chunk_type == "observation"
+        assert observation_response.message_type == "observation"
         assert observation_response.content == "The answer is 4."
         assert observation_response.end_of_message is True, "Observation message must have end_of_message=True"
         assert observation_response.end_of_dialog is False, "Observation message must have end_of_dialog=False"
@@ -168,10 +168,10 @@ class TestAgentServiceNonStreaming:
 
         # Filter out explain events — those are always sent now
         content_responses = [
-            r for r in sent_responses if r.chunk_type != "explain"
+            r for r in sent_responses if r.message_type != "explain"
         ]
         explain_responses = [
-            r for r in sent_responses if r.chunk_type == "explain"
+            r for r in sent_responses if r.message_type == "explain"
         ]
 
         # Should have explain events for session and final
@@ -183,7 +183,7 @@ class TestAgentServiceNonStreaming:
         # Check final answer message
         answer_response = content_responses[0]
         assert isinstance(answer_response, AgentResponse)
-        assert answer_response.chunk_type == "answer"
+        assert answer_response.message_type == "answer"
         assert answer_response.content == "4"
         assert answer_response.end_of_message is True, "Final answer must have end_of_message=True"
         assert answer_response.end_of_dialog is True, "Final answer must have end_of_dialog=True"
diff --git a/tests/unit/test_agent/test_callback_message_id.py b/tests/unit/test_agent/test_callback_message_id.py
index 7cb0ee54..2c4964a5 100644
--- a/tests/unit/test_agent/test_callback_message_id.py
+++ b/tests/unit/test_agent/test_callback_message_id.py
@@ -29,7 +29,7 @@ class TestThinkCallbackMessageId:
 
         assert len(responses) == 1
         assert responses[0].message_id == msg_id
-        assert responses[0].chunk_type == "thought"
+        assert responses[0].message_type == "thought"
 
     @pytest.mark.asyncio
     async def test_non_streaming_think_has_message_id(self, pattern):
@@ -58,7 +58,7 @@ class TestObserveCallbackMessageId:
         await observe("result", is_final=True)
 
         assert responses[0].message_id == msg_id
-        assert responses[0].chunk_type == "observation"
+        assert responses[0].message_type == "observation"
 
 
 class TestAnswerCallbackMessageId:
@@ -74,7 +74,7 @@ class TestAnswerCallbackMessageId:
         await answer("the answer")
 
         assert responses[0].message_id == msg_id
-        assert responses[0].chunk_type == "answer"
+        assert responses[0].message_type == "answer"
 
     @pytest.mark.asyncio
     async def test_no_message_id_default(self, pattern):
diff --git a/tests/unit/test_agent/test_meta_router.py b/tests/unit/test_agent/test_meta_router.py
index da0c634c..da8c6c79 100644
--- a/tests/unit/test_agent/test_meta_router.py
+++ b/tests/unit/test_agent/test_meta_router.py
@@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, MagicMock
 from trustgraph.agent.orchestrator.meta_router import (
     MetaRouter, DEFAULT_PATTERN, DEFAULT_TASK_TYPE,
 )
+from trustgraph.base import PromptResult
 
 
 def _make_config(patterns=None, task_types=None):
@@ -28,7 +29,9 @@ def _make_config(patterns=None, task_types=None):
 def _make_context(prompt_response):
     """Build a mock context that returns a mock prompt client."""
     client = AsyncMock()
-    client.prompt = AsyncMock(return_value=prompt_response)
+    client.prompt = AsyncMock(
+        return_value=PromptResult(response_type="text", text=prompt_response)
+    )
 
     def context(service_name):
         return client
@@ -274,8 +277,8 @@ class TestRoute:
             nonlocal call_count
             call_count += 1
             if call_count == 1:
-                return "research"  # task type
-            return "plan-then-execute"  # pattern
+                return PromptResult(response_type="text", text="research")
+            return PromptResult(response_type="text", text="plan-then-execute")
 
         client.prompt = mock_prompt
         context = lambda name: client
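The meta-router hunk above encodes a two-step protocol: the first prompt call classifies the task type, the second selects the execution pattern, and both now come back as text-typed PromptResults. A sketch of that call sequence; the template names used here are invented for illustration and the route() shape is hypothetical, not the real MetaRouter.

async def route(prompt, question):
    # First call: classify the task; second call: pick the agent pattern.
    task_type = (await prompt("task-classification",
                              variables={"question": question})).text
    pattern = (await prompt("pattern-selection",
                            variables={"question": question,
                                       "task_type": task_type})).text
    return task_type or DEFAULT_TASK_TYPE, pattern or DEFAULT_PATTERN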
diff --git a/tests/unit/test_agent/test_orchestrator_provenance_integration.py b/tests/unit/test_agent/test_orchestrator_provenance_integration.py
index 96d41259..63d87ba1 100644
--- a/tests/unit/test_agent/test_orchestrator_provenance_integration.py
+++ b/tests/unit/test_agent/test_orchestrator_provenance_integration.py
@@ -18,6 +18,7 @@ from dataclasses import dataclass, field
 from trustgraph.schema import (
     AgentRequest, AgentResponse, AgentStep, PlanStep,
 )
+from trustgraph.base import PromptResult
 from trustgraph.provenance.namespaces import (
     RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM,
@@ -68,7 +69,7 @@ def collect_explain_events(respond_mock):
     events = []
     for call in respond_mock.call_args_list:
         resp = call[0][0]
-        if isinstance(resp, AgentResponse) and resp.chunk_type == "explain":
+        if isinstance(resp, AgentResponse) and resp.message_type == "explain":
             events.append({
                 "explain_id": resp.explain_id,
                 "explain_graph": resp.explain_graph,
@@ -183,7 +184,7 @@ class TestReactPatternProvenance:
         )
 
         async def mock_react(question, history, think, observe, answer,
-                             context, streaming, on_action):
+                             context, streaming, on_action, **kwargs):
             # Simulate the on_action callback before returning Final
             if on_action:
                 await on_action(Action(
@@ -267,7 +268,7 @@ class TestReactPatternProvenance:
             MockAM.return_value = mock_am
 
             async def mock_react(question, history, think, observe, answer,
-                                 context, streaming, on_action):
+                                 context, streaming, on_action, **kwargs):
                 if on_action:
                     await on_action(action)
                 return action
@@ -309,7 +310,7 @@ class TestReactPatternProvenance:
             MockAM.return_value = mock_am
 
             async def mock_react(question, history, think, observe, answer,
-                                 context, streaming, on_action):
+                                 context, streaming, on_action, **kwargs):
                 if on_action:
                     await on_action(Action(
                         thought="done", name="final",
@@ -355,10 +356,13 @@ class TestPlanPatternProvenance:
 
         # Mock prompt client for plan creation
         mock_prompt_client = AsyncMock()
-        mock_prompt_client.prompt.return_value = [
-            {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
-            {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
-        ]
+        mock_prompt_client.prompt.return_value = PromptResult(
+            response_type="jsonl",
+            objects=[
+                {"goal": "Find information", "tool_hint": "knowledge-query", "depends_on": []},
+                {"goal": "Summarise findings", "tool_hint": "", "depends_on": [0]},
+            ],
+        )
 
         def flow_factory(name):
             if name == "prompt-request":
@@ -418,10 +422,13 @@ class TestPlanPatternProvenance:
 
         # Mock prompt for step execution
         mock_prompt_client = AsyncMock()
-        mock_prompt_client.prompt.return_value = {
-            "tool": "knowledge-query",
-            "arguments": {"question": "quantum computing"},
-        }
+        mock_prompt_client.prompt.return_value = PromptResult(
+            response_type="json",
+            object={
+                "tool": "knowledge-query",
+                "arguments": {"question": "quantum computing"},
+            },
+        )
 
         def flow_factory(name):
             if name == "prompt-request":
@@ -475,7 +482,7 @@ class TestPlanPatternProvenance:
 
         # Mock prompt for synthesis
         mock_prompt_client = AsyncMock()
-        mock_prompt_client.prompt.return_value = "The synthesised answer."
+        mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The synthesised answer.")
 
         def flow_factory(name):
             if name == "prompt-request":
@@ -542,10 +549,13 @@ class TestSupervisorPatternProvenance:
 
         # Mock prompt for decomposition
         mock_prompt_client = AsyncMock()
-        mock_prompt_client.prompt.return_value = [
-            "What is quantum computing?",
-            "What are qubits?",
-        ]
+        mock_prompt_client.prompt.return_value = PromptResult(
+            response_type="jsonl",
+            objects=[
+                "What is quantum computing?",
+                "What are qubits?",
+            ],
+        )
 
         def flow_factory(name):
             if name == "prompt-request":
@@ -590,7 +600,7 @@ class TestSupervisorPatternProvenance:
 
         # Mock prompt for synthesis
         mock_prompt_client = AsyncMock()
-        mock_prompt_client.prompt.return_value = "The combined answer."
+        mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="The combined answer.")
 
         def flow_factory(name):
             if name == "prompt-request":
@@ -639,7 +649,10 @@ class TestSupervisorPatternProvenance:
         flow = make_mock_flow()
 
         mock_prompt_client = AsyncMock()
-        mock_prompt_client.prompt.return_value = ["Goal A", "Goal B", "Goal C"]
+        mock_prompt_client.prompt.return_value = PromptResult(
+            response_type="jsonl",
+            objects=["Goal A", "Goal B", "Goal C"],
+        )
 
         def flow_factory(name):
             if name == "prompt-request":
diff --git a/tests/unit/test_agent/test_parse_chunk_message_id.py b/tests/unit/test_agent/test_parse_chunk_message_id.py
index 38942f1e..36d2220e 100644
--- a/tests/unit/test_agent/test_parse_chunk_message_id.py
+++ b/tests/unit/test_agent/test_parse_chunk_message_id.py
@@ -20,7 +20,7 @@ class TestParseChunkMessageId:
 
     def test_thought_message_id(self, client):
         resp = {
-            "chunk_type": "thought",
+            "message_type": "thought",
             "content": "thinking...",
             "end_of_message": False,
             "message_id": "urn:trustgraph:agent:sess/i1/thought",
@@ -31,7 +31,7 @@ class TestParseChunkMessageId:
 
     def test_observation_message_id(self, client):
         resp = {
-            "chunk_type": "observation",
+            "message_type": "observation",
             "content": "result",
             "end_of_message": True,
             "message_id": "urn:trustgraph:agent:sess/i1/observation",
@@ -42,7 +42,7 @@ class TestParseChunkMessageId:
 
     def test_answer_message_id(self, client):
         resp = {
-            "chunk_type": "answer",
+            "message_type": "answer",
             "content": "the answer",
             "end_of_message": False,
             "end_of_dialog": False,
@@ -54,7 +54,7 @@ class TestParseChunkMessageId:
 
     def test_thought_missing_message_id(self, client):
         resp = {
-            "chunk_type": "thought",
+            "message_type": "thought",
             "content": "thinking...",
            "end_of_message": False,
         }
@@ -64,7 +64,7 @@ class TestParseChunkMessageId:
 
     def test_answer_missing_message_id(self, client):
         resp = {
-            "chunk_type": "answer",
+            "message_type": "answer",
             "content": "answer",
             "end_of_message": True,
             "end_of_dialog": True,
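The provenance tests above exercise all three PromptResult variants: jsonl for plan steps and decomposed questions, json for a single tool-selection object, and text for synthesis answers. A sketch of how a pattern implementation might unpack them, assuming the field names used by these tests; unpack() is a hypothetical helper, not a TrustGraph API.

def unpack(result):
    # Dispatch on the response_type discriminator rather than isinstance
    # checks against bare lists/dicts/strings, as the old mocks required.
    if result.response_type == "text":
        return result.text        # synthesis answers
    if result.response_type == "json":
        return result.object      # single tool-selection object
    if result.response_type == "jsonl":
        return result.objects     # plan steps / decomposed questions
    raise ValueError(f"unknown response_type: {result.response_type}")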
+ """ classes_and_attrs = { metrics.ConsumerMetrics: [ "state_metric", @@ -23,18 +31,24 @@ def reset_metric_singletons(): ], } + saved = {} for cls, attrs in classes_and_attrs.items(): for attr in attrs: if hasattr(cls, attr): + saved[(cls, attr)] = getattr(cls, attr) delattr(cls, attr) yield + # Remove anything the test may have set, then restore originals for cls, attrs in classes_and_attrs.items(): for attr in attrs: if hasattr(cls, attr): delattr(cls, attr) + for (cls, attr), value in saved.items(): + setattr(cls, attr, value) + def test_consumer_metrics_reuses_singletons_and_records_events(monkeypatch): enum_factory = MagicMock() diff --git a/tests/unit/test_base/test_subscriber_graceful_shutdown.py b/tests/unit/test_base/test_subscriber_graceful_shutdown.py index ec14f66b..cbd4d535 100644 --- a/tests/unit/test_base/test_subscriber_graceful_shutdown.py +++ b/tests/unit/test_base/test_subscriber_graceful_shutdown.py @@ -236,6 +236,10 @@ async def test_subscriber_graceful_shutdown(): with patch.object(subscriber, 'run') as mock_run: # Mock run that simulates graceful shutdown async def mock_run_graceful(): + # Honor the readiness contract: real run() signals _ready + # after binding the consumer, so start() can unblock. Mocks + # of run() must do the same or start() hangs forever. + subscriber._ready.set_result(None) # Process messages while running, then drain while subscriber.running or subscriber.draining: if subscriber.draining: @@ -337,6 +341,8 @@ async def test_subscriber_pending_acks_cleanup(): with patch.object(subscriber, 'run') as mock_run: # Mock run that simulates cleanup of pending acks async def mock_run_cleanup(): + # Honor the readiness contract — see test_subscriber_graceful_shutdown. + subscriber._ready.set_result(None) while subscriber.running or subscriber.draining: await asyncio.sleep(0.05) if subscriber.draining: @@ -406,4 +412,4 @@ async def test_subscriber_multiple_subscribers(): msg1 = await queue1.get() msg_all = await queue_all.get() assert msg1 == {"data": "broadcast"} - assert msg_all == {"data": "broadcast"} \ No newline at end of file + assert msg_all == {"data": "broadcast"} diff --git a/tests/unit/test_base/test_subscriber_readiness.py b/tests/unit/test_base/test_subscriber_readiness.py new file mode 100644 index 00000000..e1ef47de --- /dev/null +++ b/tests/unit/test_base/test_subscriber_readiness.py @@ -0,0 +1,189 @@ +""" +Regression tests for Subscriber.start() readiness barrier. + +Background: prior to the eager-connect fix, Subscriber.start() created +the run() task and returned immediately. The underlying backend consumer +was lazily connected on its first receive() call, which left a setup +race for request/response clients using ephemeral per-subscriber response +queues (RabbitMQ auto-delete exclusive queues): the request would be +published before the response queue was bound, and the broker would +silently drop the reply. fetch_config(), document-embeddings, and +api-gateway all hit this with "Failed to fetch config on notify" / +"Request timeout exception" symptoms. + +These tests pin the readiness contract: + + await subscriber.start() + # at this point, consumer.ensure_connected() MUST have run + +so that any future change which removes the eager bind, or moves it +back to lazy initialisation, fails CI loudly. 
+""" + +import asyncio + +import pytest +from unittest.mock import MagicMock + +from trustgraph.base.subscriber import Subscriber + + +def _make_backend(ensure_connected_side_effect=None, + receive_side_effect=None): + """Build a fake backend whose consumer records ensure_connected / + receive calls. ensure_connected_side_effect lets a test inject a + delay or exception.""" + backend = MagicMock() + consumer = MagicMock() + + consumer.ensure_connected = MagicMock( + side_effect=ensure_connected_side_effect, + ) + + # By default receive raises a timeout-style exception that the + # subscriber loop is supposed to swallow as a "no message yet" — this + # keeps the subscriber idling cleanly while the test inspects state. + if receive_side_effect is None: + receive_side_effect = TimeoutError("No message received within timeout") + consumer.receive = MagicMock(side_effect=receive_side_effect) + + consumer.acknowledge = MagicMock() + consumer.negative_acknowledge = MagicMock() + consumer.pause_message_listener = MagicMock() + consumer.unsubscribe = MagicMock() + consumer.close = MagicMock() + + backend.create_consumer.return_value = consumer + return backend, consumer + + +def _make_subscriber(backend): + return Subscriber( + backend=backend, + topic="response:tg:config", + subscription="test-sub", + consumer_name="test-consumer", + schema=dict, + max_size=10, + drain_timeout=1.0, + backpressure_strategy="block", + ) + + +class TestSubscriberReadiness: + + @pytest.mark.asyncio + async def test_start_calls_ensure_connected_before_returning(self): + """The barrier: ensure_connected must have been invoked at least + once by the time start() returns.""" + backend, consumer = _make_backend() + subscriber = _make_subscriber(backend) + + await subscriber.start() + + try: + consumer.ensure_connected.assert_called_once() + finally: + await subscriber.stop() + + @pytest.mark.asyncio + async def test_start_blocks_until_ensure_connected_completes(self): + """If ensure_connected is slow, start() must wait for it. This is + the actual race-condition guard — it would have failed against + the buggy version where start() returned before run() had even + scheduled the consumer creation.""" + connect_started = asyncio.Event() + release_connect = asyncio.Event() + + # ensure_connected runs in the executor thread, so we need a + # threading-safe gate. Use a simple busy-wait on a flag set by + # the asyncio side via call_soon_threadsafe — but the simpler + # path is to give it a sleep and observe ordering. + import threading + gate = threading.Event() + + def slow_connect(): + connect_started.set() # safe: only mutates the Event flag + gate.wait(timeout=2.0) + + backend, consumer = _make_backend( + ensure_connected_side_effect=slow_connect, + ) + subscriber = _make_subscriber(backend) + + start_task = asyncio.create_task(subscriber.start()) + + # Wait until ensure_connected has begun executing. + await asyncio.wait_for(connect_started.wait(), timeout=2.0) + + # ensure_connected is in flight — start() must NOT have returned. + assert not start_task.done(), ( + "start() returned before ensure_connected() completed — " + "the readiness barrier is broken and the request/response " + "race condition is back." + ) + + # Release the gate; start() should now complete promptly. 
+ gate.set() + await asyncio.wait_for(start_task, timeout=2.0) + + consumer.ensure_connected.assert_called_once() + + await subscriber.stop() + + @pytest.mark.asyncio + async def test_start_propagates_consumer_creation_failure(self): + """If create_consumer() raises, start() must surface the error + rather than hang on the readiness future. The old code path + retried indefinitely inside run() and never let start() unblock.""" + backend = MagicMock() + backend.create_consumer.side_effect = RuntimeError("broker down") + + subscriber = _make_subscriber(backend) + + with pytest.raises(RuntimeError, match="broker down"): + await asyncio.wait_for(subscriber.start(), timeout=2.0) + + @pytest.mark.asyncio + async def test_start_propagates_ensure_connected_failure(self): + """Same contract for an ensure_connected() that raises (e.g. the + broker is up but the queue declare/bind fails).""" + backend, consumer = _make_backend( + ensure_connected_side_effect=RuntimeError("queue declare failed"), + ) + subscriber = _make_subscriber(backend) + + with pytest.raises(RuntimeError, match="queue declare failed"): + await asyncio.wait_for(subscriber.start(), timeout=2.0) + + @pytest.mark.asyncio + async def test_ensure_connected_runs_before_subscriber_running_log(self): + """Subtle ordering: ensure_connected MUST happen before the + receive loop, so that any reply is captured. We assert this by + checking ensure_connected was called before any receive call.""" + call_order = [] + + def record_ensure(): + call_order.append("ensure_connected") + + def record_receive(*args, **kwargs): + call_order.append("receive") + raise TimeoutError("No message received within timeout") + + backend, consumer = _make_backend( + ensure_connected_side_effect=record_ensure, + receive_side_effect=record_receive, + ) + subscriber = _make_subscriber(backend) + + await subscriber.start() + + # Give the receive loop a tick to run at least once. + await asyncio.sleep(0.05) + + await subscriber.stop() + + # ensure_connected must come first; receive may not have happened + # yet on a fast machine, but if it did, it must come after. 
+ assert call_order, "neither ensure_connected nor receive was called" + assert call_order[0] == "ensure_connected" diff --git a/tests/unit/test_chunking/test_recursive_chunker.py b/tests/unit/test_chunking/test_recursive_chunker.py index a5ec59c8..d1a5d247 100644 --- a/tests/unit/test_chunking/test_recursive_chunker.py +++ b/tests/unit/test_chunking/test_recursive_chunker.py @@ -70,11 +70,12 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): # Mock message and flow mock_message = MagicMock() mock_consumer = MagicMock() + # Flow exposes parameter lookup via __call__: flow("chunk-size") mock_flow = MagicMock() - mock_flow.parameters.get.side_effect = lambda param: { + mock_flow.side_effect = lambda key: { "chunk-size": 2000, # Override chunk size "chunk-overlap": None # Use default chunk overlap - }.get(param) + }.get(key) # Act chunk_size, chunk_overlap = await processor.chunk_document( @@ -105,10 +106,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.parameters.get.side_effect = lambda param: { + mock_flow.side_effect = lambda key: { "chunk-size": None, # Use default chunk size "chunk-overlap": 200 # Override chunk overlap - }.get(param) + }.get(key) # Act chunk_size, chunk_overlap = await processor.chunk_document( @@ -139,10 +140,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.parameters.get.side_effect = lambda param: { + mock_flow.side_effect = lambda key: { "chunk-size": 1500, # Override chunk size "chunk-overlap": 150 # Override chunk overlap - }.get(param) + }.get(key) # Act chunk_size, chunk_overlap = await processor.chunk_document( @@ -195,15 +196,15 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_consumer = MagicMock() mock_producer = AsyncMock() mock_triples_producer = AsyncMock() + # Flow.__call__ resolves parameters and producers/consumers from the + # same dict — merge both kinds here. 
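To make the contract pinned above concrete: start() must not return until the consumer is bound, and connection errors must surface through start() rather than being swallowed by the run() loop. A minimal sketch of that barrier shape, assuming a _ready future and an eager ensure_connected(); this is hypothetical, the real Subscriber lives in trustgraph.base.subscriber.

import asyncio

class SubscriberSketch:
    """Hypothetical skeleton of the start()/run() readiness barrier."""

    def __init__(self, backend):
        self.backend = backend

    async def run(self):
        try:
            consumer = self.backend.create_consumer()
            consumer.ensure_connected()       # eager bind, not lazy
            self._ready.set_result(None)      # unblock start()
        except Exception as e:
            self._ready.set_exception(e)      # start() re-raises this
            return
        # ... receive loop follows here ...

    async def start(self):
        self._ready = asyncio.get_running_loop().create_future()
        self.task = asyncio.create_task(self.run())
        await self._ready                     # returns only once bound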
diff --git a/tests/unit/test_chunking/test_recursive_chunker.py b/tests/unit/test_chunking/test_recursive_chunker.py
index a5ec59c8..d1a5d247 100644
--- a/tests/unit/test_chunking/test_recursive_chunker.py
+++ b/tests/unit/test_chunking/test_recursive_chunker.py
@@ -70,11 +70,12 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         # Mock message and flow
         mock_message = MagicMock()
         mock_consumer = MagicMock()
+        # Flow exposes parameter lookup via __call__: flow("chunk-size")
         mock_flow = MagicMock()
-        mock_flow.parameters.get.side_effect = lambda param: {
+        mock_flow.side_effect = lambda key: {
             "chunk-size": 2000,     # Override chunk size
             "chunk-overlap": None   # Use default chunk overlap
-        }.get(param)
+        }.get(key)
 
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
@@ -105,10 +106,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.parameters.get.side_effect = lambda param: {
+        mock_flow.side_effect = lambda key: {
             "chunk-size": None,     # Use default chunk size
             "chunk-overlap": 200    # Override chunk overlap
-        }.get(param)
+        }.get(key)
 
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
@@ -139,10 +140,10 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.parameters.get.side_effect = lambda param: {
+        mock_flow.side_effect = lambda key: {
             "chunk-size": 1500,     # Override chunk size
             "chunk-overlap": 150    # Override chunk overlap
-        }.get(param)
+        }.get(key)
 
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
@@ -195,15 +196,15 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_consumer = MagicMock()
         mock_producer = AsyncMock()
         mock_triples_producer = AsyncMock()
+        # Flow.__call__ resolves parameters and producers/consumers from the
+        # same dict — merge both kinds here.
         mock_flow = MagicMock()
-        mock_flow.parameters.get.side_effect = lambda param: {
+        mock_flow.side_effect = lambda key: {
             "chunk-size": 1500,
             "chunk-overlap": 150,
-        }.get(param)
-        mock_flow.side_effect = lambda name: {
             "output": mock_producer,
             "triples": mock_triples_producer,
-        }.get(name)
+        }.get(key)
 
         # Act
         await processor.on_message(mock_message, mock_consumer, mock_flow)
@@ -241,7 +242,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.parameters.get.return_value = None  # No overrides
+        mock_flow.side_effect = lambda key: None  # No overrides
 
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
diff --git a/tests/unit/test_chunking/test_token_chunker.py b/tests/unit/test_chunking/test_token_chunker.py
index f3f83904..dba4ca94 100644
--- a/tests/unit/test_chunking/test_token_chunker.py
+++ b/tests/unit/test_chunking/test_token_chunker.py
@@ -70,11 +70,12 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         # Mock message and flow
         mock_message = MagicMock()
         mock_consumer = MagicMock()
+        # Flow exposes parameter lookup via __call__: flow("chunk-size")
         mock_flow = MagicMock()
-        mock_flow.parameters.get.side_effect = lambda param: {
+        mock_flow.side_effect = lambda key: {
             "chunk-size": 400,      # Override chunk size
             "chunk-overlap": None   # Use default chunk overlap
-        }.get(param)
+        }.get(key)
 
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
@@ -105,10 +106,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.parameters.get.side_effect = lambda param: {
+        mock_flow.side_effect = lambda key: {
             "chunk-size": None,    # Use default chunk size
             "chunk-overlap": 25    # Override chunk overlap
-        }.get(param)
+        }.get(key)
 
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
@@ -139,10 +140,10 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.parameters.get.side_effect = lambda param: {
+        mock_flow.side_effect = lambda key: {
             "chunk-size": 350,     # Override chunk size
             "chunk-overlap": 30    # Override chunk overlap
-        }.get(param)
+        }.get(key)
 
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
@@ -195,15 +196,15 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_consumer = MagicMock()
         mock_producer = AsyncMock()
         mock_triples_producer = AsyncMock()
+        # Flow.__call__ resolves parameters and producers/consumers from the
+        # same dict — merge both kinds here.
         mock_flow = MagicMock()
-        mock_flow.parameters.get.side_effect = lambda param: {
+        mock_flow.side_effect = lambda key: {
             "chunk-size": 400,
             "chunk-overlap": 40,
-        }.get(param)
-        mock_flow.side_effect = lambda name: {
             "output": mock_producer,
             "triples": mock_triples_producer,
-        }.get(name)
+        }.get(key)
 
         # Act
         await processor.on_message(mock_message, mock_consumer, mock_flow)
@@ -245,7 +246,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.parameters.get.return_value = None  # No overrides
+        mock_flow.side_effect = lambda key: None  # No overrides
 
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
diff --git a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py
index b651b59e..cbc9a05a 100644
--- a/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py
+++ b/tests/unit/test_extract/test_streaming_triples/test_definitions_batching.py
@@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
 from trustgraph.extract.kg.definitions.extract import (
     Processor, default_triples_batch_size, default_entity_batch_size,
 )
+from trustgraph.base import PromptResult
 from trustgraph.schema import (
     Chunk, Triples, EntityContexts, Triple, Metadata, Term, IRI, LITERAL,
 )
@@ -51,8 +52,12 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
     mock_triples_pub = AsyncMock()
     mock_ecs_pub = AsyncMock()
     mock_prompt_client = AsyncMock()
+    if isinstance(prompt_result, list):
+        wrapped = PromptResult(response_type="jsonl", objects=prompt_result)
+    else:
+        wrapped = PromptResult(response_type="text", text=prompt_result)
     mock_prompt_client.extract_definitions = AsyncMock(
-        return_value=prompt_result
+        return_value=wrapped
    )
 
     def flow(name):
diff --git a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py
index cf3b1fb0..d9861cf3 100644
--- a/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py
+++ b/tests/unit/test_extract/test_streaming_triples/test_relationships_batching.py
@@ -14,6 +14,7 @@ from trustgraph.extract.kg.relationships.extract import (
 from trustgraph.schema import (
     Chunk, Triples, Triple, Metadata, Term, IRI, LITERAL,
 )
+from trustgraph.base import PromptResult
 
 
# ---------------------------------------------------------------------------
@@ -58,7 +59,10 @@ def _make_flow(prompt_result, llm_model="test-llm", ontology_uri="test-onto"):
     mock_triples_pub = AsyncMock()
     mock_prompt_client = AsyncMock()
     mock_prompt_client.extract_relationships = AsyncMock(
-        return_value=PromptResult(
+        return_value=PromptResult(
+            response_type="jsonl",
+            objects=prompt_result,
+        )
     )
 
     def flow(name):
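The chunker hunks above replace flow.parameters.get(...) with calling the flow object itself: one __call__ now resolves parameters, producers, and consumers from a single namespace. A sketch of a stand-in flow object implementing that convention; FlowStub is hypothetical, not the real Flow class.

class FlowStub:
    def __init__(self, bindings):
        # One merged namespace, e.g. {"chunk-size": 2000, "output": producer}
        self.bindings = bindings

    def __call__(self, key):
        # None signals "use the processor's default" for parameters.
        return self.bindings.get(key)

flow = FlowStub({"chunk-size": 2000, "chunk-overlap": None})
assert flow("chunk-size") == 2000
assert flow("missing") is None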
diff --git a/tests/unit/test_gateway/test_core_import_export_roundtrip.py b/tests/unit/test_gateway/test_core_import_export_roundtrip.py
new file mode 100644
index 00000000..843a2b7b
--- /dev/null
+++ b/tests/unit/test_gateway/test_core_import_export_roundtrip.py
@@ -0,0 +1,418 @@
+"""
+Round-trip unit tests for the core msgpack import/export gateway endpoints.
+
+The kg-core export endpoint receives KnowledgeResponse-shaped dicts from
+the responder callback and packs them into msgpack tuples. The kg-core
+import endpoint takes msgpack tuples back
+off the wire and rebuilds
+KnowledgeRequest-shaped dicts which it then hands to KnowledgeRequestor
+(whose translator decodes them into real dataclasses).
+
+Regression coverage: the previous wire format used `"vectors"` (plural)
+in the entity blobs and embedded a stale `"m"` field that referenced the
+removed `Metadata.metadata` triples-list field. The export side hit a
+KeyError on first message; the import side built dicts that the
+KnowledgeRequestTranslator (separately fixed) couldn't decode. These
+tests pin both halves of the wire protocol.
+"""
+
+import msgpack
+import pytest
+from unittest.mock import AsyncMock, Mock, patch
+
+from trustgraph.gateway.dispatch.core_export import CoreExport
+from trustgraph.gateway.dispatch.core_import import CoreImport
+
+
+# ---------------------------------------------------------------------------
+# Helpers — sample translator-shaped dicts (as KnowledgeResponseTranslator
+# would emit). The vector wire key is *singular* on purpose; the export
+# side previously read the wrong key and crashed.
+# ---------------------------------------------------------------------------
+
+
+def _ge_response_dict():
+    return {
+        "graph-embeddings": {
+            "metadata": {
+                "id": "doc-1",
+                "root": "",
+                "user": "alice",
+                "collection": "testcoll",
+            },
+            "entities": [
+                {
+                    "entity": {"t": "i", "i": "http://example.org/alice"},
+                    "vector": [0.1, 0.2, 0.3],
+                },
+                {
+                    "entity": {"t": "i", "i": "http://example.org/bob"},
+                    "vector": [0.4, 0.5, 0.6],
+                },
+            ],
+        }
+    }
+
+
+def _triples_response_dict():
+    return {
+        "triples": {
+            "metadata": {
+                "id": "doc-1",
+                "root": "",
+                "user": "alice",
+                "collection": "testcoll",
+            },
+            "triples": [
+                {
+                    "s": {"t": "i", "i": "http://example.org/alice"},
+                    "p": {"t": "i", "i": "http://example.org/knows"},
+                    "o": {"t": "i", "i": "http://example.org/bob"},
+                },
+            ],
+        }
+    }
+
+
+def _make_request(id_="doc-1", user="alice"):
+    request = Mock()
+    request.query = {"id": id_, "user": user}
+    return request
+
+
+def _make_data_reader(payload: bytes):
+    """Mock the aiohttp StreamReader: returns payload once, then EOF."""
+    chunks = [payload, b""]
+
+    data = Mock()
+
+    async def fake_read(n):
+        return chunks.pop(0) if chunks else b""
+
+    data.read = fake_read
+    return data
+
+
+# ---------------------------------------------------------------------------
+# Export side: translator-shaped dict -> msgpack bytes
+# ---------------------------------------------------------------------------
+
+
+class TestCoreExportWireFormat:
+
+    @pytest.mark.asyncio
+    @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
+    async def test_export_packs_graph_embeddings_with_singular_vector(
+        self, mock_kr_class,
+    ):
+        """The export side must read `ent["vector"]` and emit `v`. The
+        previous bug was reading `ent["vectors"]` which KeyErrored against
+        the translator output."""
+        captured = []
+
+        async def fake_kr_process(req_dict, responder):
+            await responder(_ge_response_dict(), True)
+
+        mock_kr = AsyncMock()
+        mock_kr.start = AsyncMock()
+        mock_kr.stop = AsyncMock()
+        mock_kr.process = fake_kr_process
+        mock_kr_class.return_value = mock_kr
+
+        response = AsyncMock()
+
+        async def fake_write(b):
+            captured.append(b)
+
+        response.write = fake_write
+        response.write_eof = AsyncMock()
+
+        ok = AsyncMock(return_value=response)
+        error = AsyncMock()
+
+        exporter = CoreExport(backend=Mock())
+        await exporter.process(
+            data=Mock(),
+            error=error,
+            ok=ok,
+            request=_make_request(),
+        )
+
+        # Did not raise, did not call error()
+        error.assert_not_called()
+        assert len(captured) == 1
+
+        unpacker = msgpack.Unpacker()
+        unpacker.feed(captured[0])
+        items = list(unpacker)
+
+        assert len(items) == 1
+        msg_type, payload = items[0]
+        assert msg_type == "ge"
+
+        # Metadata envelope: only id/user/collection — no stale `m["m"]`.
+        assert payload["m"] == {
+            "i": "doc-1",
+            "u": "alice",
+            "c": "testcoll",
+        }
+
+        # Entities: each carries the *singular* `v` and the term envelope
+        assert len(payload["e"]) == 2
+        assert payload["e"][0]["v"] == [0.1, 0.2, 0.3]
+        assert payload["e"][1]["v"] == [0.4, 0.5, 0.6]
+        assert payload["e"][0]["e"]["i"] == "http://example.org/alice"
+
+    @pytest.mark.asyncio
+    @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
+    async def test_export_packs_triples(self, mock_kr_class):
+        captured = []
+
+        async def fake_kr_process(req_dict, responder):
+            await responder(_triples_response_dict(), True)
+
+        mock_kr = AsyncMock()
+        mock_kr.start = AsyncMock()
+        mock_kr.stop = AsyncMock()
+        mock_kr.process = fake_kr_process
+        mock_kr_class.return_value = mock_kr
+
+        response = AsyncMock()
+
+        async def fake_write(b):
+            captured.append(b)
+
+        response.write = fake_write
+        response.write_eof = AsyncMock()
+
+        ok = AsyncMock(return_value=response)
+        error = AsyncMock()
+
+        exporter = CoreExport(backend=Mock())
+        await exporter.process(
+            data=Mock(), error=error, ok=ok, request=_make_request(),
+        )
+
+        error.assert_not_called()
+        assert len(captured) == 1
+
+        unpacker = msgpack.Unpacker()
+        unpacker.feed(captured[0])
+        items = list(unpacker)
+        assert len(items) == 1
+
+        msg_type, payload = items[0]
+        assert msg_type == "t"
+        assert payload["m"] == {
+            "i": "doc-1",
+            "u": "alice",
+            "c": "testcoll",
+        }
+        assert len(payload["t"]) == 1
+
+
+# ---------------------------------------------------------------------------
+# Import side: msgpack bytes -> translator-shaped dict
+# ---------------------------------------------------------------------------
+
+
+class TestCoreImportWireFormat:
+
+    @pytest.mark.asyncio
+    @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
+    async def test_import_unpacks_graph_embeddings_to_singular_vector(
+        self, mock_kr_class,
+    ):
+        """The import side must build dicts whose entity blobs have the
+        singular `vector` key — that's what the KnowledgeRequestTranslator
+        decode side reads. Previous bug emitted `vectors`."""
+        captured = []
+
+        async def fake_kr_process(req_dict):
+            captured.append(req_dict)
+
+        mock_kr = AsyncMock()
+        mock_kr.start = AsyncMock()
+        mock_kr.stop = AsyncMock()
+        mock_kr.process = fake_kr_process
+        mock_kr_class.return_value = mock_kr
+
+        # Build a msgpack tuple matching the new wire format
+        payload = msgpack.packb((
+            "ge",
+            {
+                "m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
+                "e": [
+                    {
+                        "e": {"t": "i", "i": "http://example.org/alice"},
+                        "v": [0.1, 0.2, 0.3],
+                    },
+                ],
+            },
+        ))
+
+        ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
+        error = AsyncMock()
+
+        importer = CoreImport(backend=Mock())
+        await importer.process(
+            data=_make_data_reader(payload),
+            error=error,
+            ok=ok,
+            request=_make_request(),
+        )
+
+        error.assert_not_called()
+        assert len(captured) == 1
+
+        req = captured[0]
+        assert req["operation"] == "put-kg-core"
+        assert req["user"] == "alice"
+        assert req["id"] == "doc-1"
+
+        ge = req["graph-embeddings"]
+        # Metadata envelope must NOT contain a stale `metadata` key
+        # referencing the removed Metadata.metadata field.
+        assert "metadata" not in ge["metadata"]
+        assert ge["metadata"] == {
+            "id": "doc-1",
+            "user": "alice",
+            "collection": "default",
+        }
+
+        # Entity blob carries the singular `vector` key
+        assert len(ge["entities"]) == 1
+        ent = ge["entities"][0]
+        assert ent["vector"] == [0.1, 0.2, 0.3]
+        assert "vectors" not in ent
+
+    @pytest.mark.asyncio
+    @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
+    async def test_import_unpacks_triples(self, mock_kr_class):
+        captured = []
+
+        async def fake_kr_process(req_dict):
+            captured.append(req_dict)
+
+        mock_kr = AsyncMock()
+        mock_kr.start = AsyncMock()
+        mock_kr.stop = AsyncMock()
+        mock_kr.process = fake_kr_process
+        mock_kr_class.return_value = mock_kr
+
+        payload = msgpack.packb((
+            "t",
+            {
+                "m": {"i": "doc-1", "u": "alice", "c": "testcoll"},
+                "t": [
+                    {
+                        "s": {"t": "i", "i": "http://example.org/alice"},
+                        "p": {"t": "i", "i": "http://example.org/knows"},
+                        "o": {"t": "i", "i": "http://example.org/bob"},
+                    },
+                ],
+            },
+        ))
+
+        ok = AsyncMock(return_value=AsyncMock(write_eof=AsyncMock()))
+        error = AsyncMock()
+
+        importer = CoreImport(backend=Mock())
+        await importer.process(
+            data=_make_data_reader(payload),
+            error=error,
+            ok=ok,
+            request=_make_request(),
+        )
+
+        error.assert_not_called()
+        assert len(captured) == 1
+
+        req = captured[0]
+        triples = req["triples"]
+        assert "metadata" not in triples["metadata"]  # no stale field
+        assert len(triples["triples"]) == 1
+
+
+# ---------------------------------------------------------------------------
+# Full round-trip: export bytes feed directly into import
+# ---------------------------------------------------------------------------
+
+
+class TestCoreImportExportRoundTrip:
+    """End-to-end: produce bytes via core_export, consume them via
+    core_import, and verify the dict that lands at the import-side
+    translator is structurally equivalent to what went in. This is the
+    test that catches asymmetries between the two halves."""
+
+    @pytest.mark.asyncio
+    @patch("trustgraph.gateway.dispatch.core_import.KnowledgeRequestor")
+    @patch("trustgraph.gateway.dispatch.core_export.KnowledgeRequestor")
+    async def test_graph_embeddings_round_trip(
+        self, mock_export_kr_class, mock_import_kr_class,
+    ):
+        # ----- export side: capture bytes -----
+        export_bytes = []
+
+        async def fake_export_process(req_dict, responder):
+            await responder(_ge_response_dict(), True)
+
+        export_kr = AsyncMock()
+        export_kr.start = AsyncMock()
+        export_kr.stop = AsyncMock()
+        export_kr.process = fake_export_process
+        mock_export_kr_class.return_value = export_kr
+
+        response = AsyncMock()
+
+        async def fake_write(b):
+            export_bytes.append(b)
+
+        response.write = fake_write
+        response.write_eof = AsyncMock()
+
+        exporter = CoreExport(backend=Mock())
+        await exporter.process(
+            data=Mock(),
+            error=AsyncMock(),
+            ok=AsyncMock(return_value=response),
+            request=_make_request(),
+        )
+
+        assert len(export_bytes) == 1
+
+        # ----- import side: feed those bytes back in -----
+        import_captured = []
+
+        async def fake_import_process(req_dict):
+            import_captured.append(req_dict)
+
+        import_kr = AsyncMock()
+        import_kr.start = AsyncMock()
+        import_kr.stop = AsyncMock()
+        import_kr.process = fake_import_process
+        mock_import_kr_class.return_value = import_kr
+
+        importer = CoreImport(backend=Mock())
+        await importer.process(
+            data=_make_data_reader(export_bytes[0]),
+            error=AsyncMock(),
+            ok=AsyncMock(return_value=AsyncMock(write_eof=AsyncMock())),
+            request=_make_request(),
+        )
+
+        # ----- verify the dict the importer would hand to the translator -----
+        assert len(import_captured) == 1
+        req = import_captured[0]
+
+        original = _ge_response_dict()["graph-embeddings"]
+
+        ge = req["graph-embeddings"]
+        # The import side overrides id/user from the URL query (intentional),
+        # so we only round-trip the entity payload itself.
+        assert ge["metadata"]["id"] == original["metadata"]["id"]
+        assert ge["metadata"]["user"] == original["metadata"]["user"]
+
+        assert len(ge["entities"]) == len(original["entities"])
+        for got, want in zip(ge["entities"], original["entities"]):
+            assert got["vector"] == want["vector"]
+            assert got["entity"] == want["entity"]
+""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch + +from trustgraph.gateway.dispatch.entity_contexts_import import EntityContextsImport +from trustgraph.schema import EntityContexts, EntityContext, Metadata + + +@pytest.fixture +def mock_backend(): + return Mock() + + +@pytest.fixture +def mock_running(): + running = Mock() + running.get.return_value = True + running.stop = Mock() + return running + + +@pytest.fixture +def mock_websocket(): + ws = Mock() + ws.close = AsyncMock() + return ws + + +@pytest.fixture +def sample_message(): + """Sample entity-contexts websocket message.""" + return { + "metadata": { + "id": "doc-123", + "user": "testuser", + "collection": "testcollection", + }, + "entities": [ + { + "entity": {"v": "http://example.org/alice", "e": True}, + "context": "Alice is a person.", + }, + { + "entity": {"v": "http://example.org/bob", "e": True}, + "context": "Bob is a person.", + }, + ], + } + + +@pytest.fixture +def empty_entities_message(): + return { + "metadata": { + "id": "doc-empty", + "user": "u", + "collection": "c", + }, + "entities": [], + } + + +class TestEntityContextsImportInitialization: + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + def test_init_creates_publisher_with_correct_params( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, + running=mock_running, + backend=mock_backend, + queue="ec-queue", + ) + + mock_publisher_class.assert_called_once_with( + mock_backend, + topic="ec-queue", + schema=EntityContexts, + ) + assert dispatcher.ws is mock_websocket + assert dispatcher.running is mock_running + assert dispatcher.publisher is instance + + +class TestEntityContextsImportLifecycle: + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_start_calls_publisher_start( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + instance.start = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.start() + instance.start.assert_called_once() + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_destroy_stops_and_closes_properly( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + instance.stop = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.destroy() + + mock_running.stop.assert_called_once() + instance.stop.assert_called_once() + mock_websocket.close.assert_called_once() + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_destroy_handles_none_websocket( + self, mock_publisher_class, mock_backend, mock_running + ): + instance = Mock() + instance.stop = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=None, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.destroy() + + mock_running.stop.assert_called_once() + instance.stop.assert_called_once() + + +class TestEntityContextsImportMessageProcessing: + """Regression 
coverage for receive(): catches Metadata/schema drift.""" + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_receive_constructs_entity_contexts_correctly( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, sample_message, + ): + instance = Mock() + instance.send = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = sample_message + + # If Metadata or EntityContexts gain/lose kwargs, this raises + # TypeError — exactly the regression we want to catch. + await dispatcher.receive(mock_msg) + + instance.send.assert_called_once() + call_args = instance.send.call_args + assert call_args[0][0] is None + + sent = call_args[0][1] + assert isinstance(sent, EntityContexts) + assert isinstance(sent.metadata, Metadata) + assert sent.metadata.id == "doc-123" + assert sent.metadata.user == "testuser" + assert sent.metadata.collection == "testcollection" + + assert len(sent.entities) == 2 + assert all(isinstance(e, EntityContext) for e in sent.entities) + assert sent.entities[0].context == "Alice is a person." + assert sent.entities[1].context == "Bob is a person." + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_receive_handles_empty_entities( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, empty_entities_message, + ): + instance = Mock() + instance.send = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = empty_entities_message + + await dispatcher.receive(mock_msg) + + instance.send.assert_called_once() + sent = instance.send.call_args[0][1] + assert isinstance(sent, EntityContexts) + assert sent.entities == [] + assert sent.metadata.id == "doc-empty" + + @patch('trustgraph.gateway.dispatch.entity_contexts_import.Publisher') + @pytest.mark.asyncio + async def test_receive_propagates_publisher_errors( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, sample_message, + ): + instance = Mock() + instance.send = AsyncMock(side_effect=RuntimeError("publish failed")) + mock_publisher_class.return_value = instance + + dispatcher = EntityContextsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = sample_message + + with pytest.raises(RuntimeError, match="publish failed"): + await dispatcher.receive(mock_msg) diff --git a/tests/unit/test_gateway/test_explain_triples.py b/tests/unit/test_gateway/test_explain_triples.py index 24e77410..42a2f4c5 100644 --- a/tests/unit/test_gateway/test_explain_triples.py +++ b/tests/unit/test_gateway/test_explain_triples.py @@ -158,7 +158,7 @@ class TestAgentExplainTriples: translator = AgentResponseTranslator() response = AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id="urn:trustgraph:agent:session:abc123", explain_graph="urn:graph:retrieval", @@ -179,7 +179,7 @@ class TestAgentExplainTriples: translator = AgentResponseTranslator() response = AgentResponse( - chunk_type="thought", + message_type="thought", content="I need to think...", ) @@ -190,7 +190,7 @@ class TestAgentExplainTriples: 
translator = AgentResponseTranslator() response = AgentResponse( - chunk_type="explain", + message_type="explain", explain_id="urn:trustgraph:agent:session:abc123", explain_triples=sample_triples(), end_of_dialog=False, @@ -203,7 +203,7 @@ class TestAgentExplainTriples: translator = AgentResponseTranslator() response = AgentResponse( - chunk_type="answer", + message_type="answer", content="The answer is...", end_of_dialog=True, ) diff --git a/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py b/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py new file mode 100644 index 00000000..fa277178 --- /dev/null +++ b/tests/unit/test_gateway/test_graph_embeddings_import_dispatcher.py @@ -0,0 +1,247 @@ +""" +Unit tests for graph embeddings import dispatcher. + +Tests the business logic of GraphEmbeddingsImport while mocking the +Publisher and websocket components. + +Regression coverage: a previous version of EntityContextsImport +constructed Metadata(metadata=...) which raised TypeError at runtime as +soon as a message was received. The same shape of bug can occur here, so +these tests exercise receive() end-to-end to catch any future schema or +kwarg drift in Metadata / GraphEmbeddings / EntityEmbeddings construction. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch + +from trustgraph.gateway.dispatch.graph_embeddings_import import GraphEmbeddingsImport +from trustgraph.schema import GraphEmbeddings, EntityEmbeddings, Metadata + + +@pytest.fixture +def mock_backend(): + return Mock() + + +@pytest.fixture +def mock_running(): + running = Mock() + running.get.return_value = True + running.stop = Mock() + return running + + +@pytest.fixture +def mock_websocket(): + ws = Mock() + ws.close = AsyncMock() + return ws + + +@pytest.fixture +def sample_message(): + """Sample graph-embeddings websocket message.""" + return { + "metadata": { + "id": "doc-123", + "user": "testuser", + "collection": "testcollection", + }, + "entities": [ + { + "entity": {"v": "http://example.org/alice", "e": True}, + "vector": [0.1, 0.2, 0.3], + }, + { + "entity": {"v": "http://example.org/bob", "e": True}, + "vector": [0.4, 0.5, 0.6], + }, + ], + } + + +@pytest.fixture +def empty_entities_message(): + return { + "metadata": { + "id": "doc-empty", + "user": "u", + "collection": "c", + }, + "entities": [], + } + + +class TestGraphEmbeddingsImportInitialization: + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + def test_init_creates_publisher_with_correct_params( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, + running=mock_running, + backend=mock_backend, + queue="ge-queue", + ) + + mock_publisher_class.assert_called_once_with( + mock_backend, + topic="ge-queue", + schema=GraphEmbeddings, + ) + assert dispatcher.ws is mock_websocket + assert dispatcher.running is mock_running + assert dispatcher.publisher is instance + + +class TestGraphEmbeddingsImportLifecycle: + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_start_calls_publisher_start( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + instance.start = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, 
queue="q", + ) + await dispatcher.start() + instance.start.assert_called_once() + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_destroy_stops_and_closes_properly( + self, mock_publisher_class, mock_backend, mock_websocket, mock_running + ): + instance = Mock() + instance.stop = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.destroy() + + mock_running.stop.assert_called_once() + instance.stop.assert_called_once() + mock_websocket.close.assert_called_once() + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_destroy_handles_none_websocket( + self, mock_publisher_class, mock_backend, mock_running + ): + instance = Mock() + instance.stop = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=None, running=mock_running, + backend=mock_backend, queue="q", + ) + await dispatcher.destroy() + + mock_running.stop.assert_called_once() + instance.stop.assert_called_once() + + +class TestGraphEmbeddingsImportMessageProcessing: + """Regression coverage for receive(): catches Metadata/schema drift.""" + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_receive_constructs_graph_embeddings_correctly( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, sample_message, + ): + instance = Mock() + instance.send = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = sample_message + + # If Metadata, GraphEmbeddings, or EntityEmbeddings gain/lose + # kwargs, this raises TypeError — exactly the regression we want + # to catch. + await dispatcher.receive(mock_msg) + + instance.send.assert_called_once() + call_args = instance.send.call_args + assert call_args[0][0] is None + + sent = call_args[0][1] + assert isinstance(sent, GraphEmbeddings) + assert isinstance(sent.metadata, Metadata) + assert sent.metadata.id == "doc-123" + assert sent.metadata.user == "testuser" + assert sent.metadata.collection == "testcollection" + + assert len(sent.entities) == 2 + assert all(isinstance(e, EntityEmbeddings) for e in sent.entities) + # Lock in the wire format: incoming "vector" key (singular, + # list[float]) maps to EntityEmbeddings.vector. This mirrors + # serialize_graph_embeddings() on the export side. 
+ assert sent.entities[0].vector == [0.1, 0.2, 0.3] + assert sent.entities[1].vector == [0.4, 0.5, 0.6] + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_receive_handles_empty_entities( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, empty_entities_message, + ): + instance = Mock() + instance.send = AsyncMock() + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = empty_entities_message + + await dispatcher.receive(mock_msg) + + instance.send.assert_called_once() + sent = instance.send.call_args[0][1] + assert isinstance(sent, GraphEmbeddings) + assert sent.entities == [] + assert sent.metadata.id == "doc-empty" + + @patch('trustgraph.gateway.dispatch.graph_embeddings_import.Publisher') + @pytest.mark.asyncio + async def test_receive_propagates_publisher_errors( + self, mock_publisher_class, mock_backend, mock_websocket, + mock_running, sample_message, + ): + instance = Mock() + instance.send = AsyncMock(side_effect=RuntimeError("publish failed")) + mock_publisher_class.return_value = instance + + dispatcher = GraphEmbeddingsImport( + ws=mock_websocket, running=mock_running, + backend=mock_backend, queue="q", + ) + + mock_msg = Mock() + mock_msg.json.return_value = sample_message + + with pytest.raises(RuntimeError, match="publish failed"): + await dispatcher.receive(mock_msg) diff --git a/tests/unit/test_gateway/test_service.py b/tests/unit/test_gateway/test_service.py index 22d9ab04..71428db4 100644 --- a/tests/unit/test_gateway/test_service.py +++ b/tests/unit/test_gateway/test_service.py @@ -171,6 +171,14 @@ class TestApi: patch('aiohttp.web.run_app') as mock_run_app: mock_get_pubsub.return_value = Mock() + # Api.run() passes self.app_factory() — a coroutine — to + # web.run_app, which would normally consume it inside its own + # event loop. Since we mock run_app, close the coroutine here + # so it doesn't leak as an "unawaited coroutine" RuntimeWarning. + def _consume_coro(coro, **kwargs): + coro.close() + mock_run_app.side_effect = _consume_coro + api = Api(port=8080) api.run() diff --git a/tests/unit/test_provenance/test_dag_structure.py b/tests/unit/test_provenance/test_dag_structure.py new file mode 100644 index 00000000..184560f0 --- /dev/null +++ b/tests/unit/test_provenance/test_dag_structure.py @@ -0,0 +1,592 @@ +""" +DAG structure tests for provenance chains. + +Verifies that the wasDerivedFrom chain has the expected shape for each +service. These tests catch structural regressions when new entities are +inserted into the chain (e.g. PatternDecision between session and first +iteration). 
+ +Expected chains: + + GraphRAG: question → grounding → exploration → focus → synthesis + DocumentRAG: question → grounding → exploration → synthesis + Agent React: session → pattern-decision → iteration → (observation → iteration)* → final + Agent Plan: session → pattern-decision → plan → step-result(s) → synthesis + Agent Super: session → pattern-decision → decomposition → (fan-out) → finding(s) → synthesis +""" + +import json +import uuid +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from trustgraph.schema import ( + AgentRequest, AgentResponse, AgentStep, PlanStep, + Triple, Term, IRI, LITERAL, +) +from trustgraph.base import PromptResult + +from trustgraph.provenance.namespaces import ( + RDF_TYPE, PROV_WAS_DERIVED_FROM, GRAPH_RETRIEVAL, + TG_AGENT_QUESTION, TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, + TG_GROUNDING, TG_EXPLORATION, TG_FOCUS, TG_SYNTHESIS, + TG_ANALYSIS, TG_CONCLUSION, TG_PATTERN_DECISION, + TG_PLAN_TYPE, TG_STEP_RESULT, TG_DECOMPOSITION, + TG_OBSERVATION_TYPE, + TG_PATTERN, TG_TASK_TYPE, +) + +from trustgraph.retrieval.graph_rag.graph_rag import GraphRag +from trustgraph.retrieval.document_rag.document_rag import DocumentRag + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _collect_events(events): + """Build a dict of explain_id → {types, derived_from, triples}.""" + result = {} + for ev in events: + eid = ev["explain_id"] + triples = ev["triples"] + types = { + t.o.iri for t in triples + if t.s.iri == eid and t.p.iri == RDF_TYPE + } + parents = [ + t.o.iri for t in triples + if t.s.iri == eid and t.p.iri == PROV_WAS_DERIVED_FROM + ] + result[eid] = { + "types": types, + "derived_from": parents[0] if parents else None, + "triples": triples, + } + return result + + +def _find_by_type(dag, rdf_type): + """Find all event IDs that have the given rdf:type.""" + return [eid for eid, info in dag.items() if rdf_type in info["types"]] + + +def _assert_chain(dag, chain_types): + """Assert that a linear wasDerivedFrom chain exists through the given types.""" + for i in range(1, len(chain_types)): + parent_type = chain_types[i - 1] + child_type = chain_types[i] + parents = _find_by_type(dag, parent_type) + children = _find_by_type(dag, child_type) + assert parents, f"No entity with type {parent_type}" + assert children, f"No entity with type {child_type}" + # At least one child must derive from at least one parent + linked = False + for child_id in children: + derived = dag[child_id]["derived_from"] + if derived in parents: + linked = True + break + assert linked, ( + f"No {child_type} derives from {parent_type}. 
" + f"Children derive from: " + f"{[dag[c]['derived_from'] for c in children]}" + ) + + +# --------------------------------------------------------------------------- +# GraphRAG DAG structure +# --------------------------------------------------------------------------- + +class TestGraphRagDagStructure: + """Verify: question → grounding → exploration → focus → synthesis""" + + @pytest.fixture + def mock_clients(self): + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + graph_embeddings_client = AsyncMock() + triples_client = AsyncMock() + + embeddings_client.embed.return_value = [[0.1, 0.2]] + graph_embeddings_client.query.return_value = [ + MagicMock(entity=Term(type=IRI, iri="http://example.com/e1")), + ] + triples_client.query_stream.return_value = [ + Triple( + s=Term(type=IRI, iri="http://example.com/e1"), + p=Term(type=IRI, iri="http://example.com/p"), + o=Term(type=LITERAL, value="value"), + ) + ] + triples_client.query.return_value = [] + + async def mock_prompt(template_id, variables=None, **kwargs): + if template_id == "extract-concepts": + return PromptResult(response_type="text", text="concept") + elif template_id == "kg-edge-scoring": + edges = variables.get("knowledge", []) + return PromptResult( + response_type="jsonl", + objects=[{"id": e["id"], "score": 10} for e in edges], + ) + elif template_id == "kg-edge-reasoning": + edges = variables.get("knowledge", []) + return PromptResult( + response_type="jsonl", + objects=[{"id": e["id"], "reasoning": "relevant"} for e in edges], + ) + elif template_id == "kg-synthesis": + return PromptResult(response_type="text", text="Answer.") + return PromptResult(response_type="text", text="") + + prompt_client.prompt.side_effect = mock_prompt + return prompt_client, embeddings_client, graph_embeddings_client, triples_client + + @pytest.mark.asyncio + async def test_dag_chain(self, mock_clients): + rag = GraphRag(*mock_clients) + events = [] + + async def explain_cb(triples, explain_id): + events.append({"explain_id": explain_id, "triples": triples}) + + await rag.query( + query="test", explain_callback=explain_cb, edge_score_limit=0, + ) + + dag = _collect_events(events) + assert len(dag) == 5, f"Expected 5 events, got {len(dag)}" + + _assert_chain(dag, [ + TG_GRAPH_RAG_QUESTION, + TG_GROUNDING, + TG_EXPLORATION, + TG_FOCUS, + TG_SYNTHESIS, + ]) + + +# --------------------------------------------------------------------------- +# DocumentRAG DAG structure +# --------------------------------------------------------------------------- + +class TestDocumentRagDagStructure: + """Verify: question → grounding → exploration → synthesis""" + + @pytest.fixture + def mock_clients(self): + from trustgraph.schema import ChunkMatch + + prompt_client = AsyncMock() + embeddings_client = AsyncMock() + doc_embeddings_client = AsyncMock() + fetch_chunk = AsyncMock(return_value="Chunk content.") + + embeddings_client.embed.return_value = [[0.1, 0.2]] + doc_embeddings_client.query.return_value = [ + ChunkMatch(chunk_id="doc/c1", score=0.9), + ] + + async def mock_prompt(template_id, variables=None, **kwargs): + if template_id == "extract-concepts": + return PromptResult(response_type="text", text="concept") + return PromptResult(response_type="text", text="") + + prompt_client.prompt.side_effect = mock_prompt + prompt_client.document_prompt.return_value = PromptResult( + response_type="text", text="Answer.", + ) + + return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk + + @pytest.mark.asyncio + async def 
test_dag_chain(self, mock_clients): + rag = DocumentRag(*mock_clients) + events = [] + + async def explain_cb(triples, explain_id): + events.append({"explain_id": explain_id, "triples": triples}) + + await rag.query( + query="test", explain_callback=explain_cb, + ) + + dag = _collect_events(events) + assert len(dag) == 4, f"Expected 4 events, got {len(dag)}" + + _assert_chain(dag, [ + TG_DOC_RAG_QUESTION, + TG_GROUNDING, + TG_EXPLORATION, + TG_SYNTHESIS, + ]) + + +# --------------------------------------------------------------------------- +# Agent DAG structure — tested via service.agent_request() +# --------------------------------------------------------------------------- + +def _make_processor(tools=None): + processor = MagicMock() + processor.max_iterations = 10 + processor.save_answer_content = AsyncMock() + + def mock_session_uri(sid): + return f"urn:trustgraph:agent:session:{sid}" + processor.provenance_session_uri.side_effect = mock_session_uri + + agent = MagicMock() + agent.tools = tools or {} + agent.additional_context = "" + processor.agent = agent + processor.aggregator = MagicMock() + + return processor + + +def _make_flow(): + producers = {} + + def factory(name): + if name not in producers: + producers[name] = AsyncMock() + return producers[name] + + flow = MagicMock(side_effect=factory) + return flow + + +def _collect_agent_events(respond_mock): + events = [] + for call in respond_mock.call_args_list: + resp = call[0][0] + if isinstance(resp, AgentResponse) and resp.message_type == "explain": + events.append({ + "explain_id": resp.explain_id, + "triples": resp.explain_triples, + }) + return events + + +class TestAgentReactDagStructure: + """ + Via service.agent_request(), full two-iteration react chain: + session → pattern-decision → iteration(1) → observation(1) → final + + Iteration 1: tool call → observation + Iteration 2: final answer + """ + + def _make_service(self): + from trustgraph.agent.orchestrator.service import Processor + from trustgraph.agent.orchestrator.react_pattern import ReactPattern + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + mock_tool = MagicMock() + mock_tool.name = "lookup" + mock_tool.description = "Look things up" + mock_tool.arguments = [] + mock_tool.groups = [] + mock_tool.states = {} + mock_tool_impl = AsyncMock(return_value="42") + mock_tool.implementation = MagicMock(return_value=mock_tool_impl) + + processor = _make_processor(tools={"lookup": mock_tool}) + + service = Processor.__new__(Processor) + service.max_iterations = 10 + service.save_answer_content = AsyncMock() + service.provenance_session_uri = processor.provenance_session_uri + service.agent = processor.agent + service.aggregator = processor.aggregator + + service.react_pattern = ReactPattern(service) + service.plan_pattern = PlanThenExecutePattern(service) + service.supervisor_pattern = SupervisorPattern(service) + service.meta_router = None + + return service + + @pytest.mark.asyncio + async def test_dag_chain(self): + from trustgraph.agent.react.types import Action, Final + + service = self._make_service() + + respond = AsyncMock() + next_fn = AsyncMock() + flow = _make_flow() + session_id = str(uuid.uuid4()) + + # Iteration 1: tool call → returns Action, triggers on_action + tool exec + action = Action( + thought="I need to look this up", + name="lookup", + arguments={"question": "6x7"}, + observation="", + ) + + with patch( + 
"trustgraph.agent.orchestrator.react_pattern.AgentManager" + ) as MockAM: + mock_am = AsyncMock() + MockAM.return_value = mock_am + + async def mock_react_iter1(on_action=None, **kwargs): + if on_action: + await on_action(action) + action.observation = "42" + return action + + mock_am.react.side_effect = mock_react_iter1 + + request1 = AgentRequest( + question="What is 6x7?", + user="testuser", + collection="default", + streaming=False, + session_id=session_id, + pattern="react", + history=[], + ) + + await service.agent_request(request1, respond, next_fn, flow) + + # next_fn should have been called with updated history + assert next_fn.called + + # Iteration 2: final answer + final = Final(thought="The answer is 42", final="42") + next_request = next_fn.call_args[0][0] + + with patch( + "trustgraph.agent.orchestrator.react_pattern.AgentManager" + ) as MockAM: + mock_am = AsyncMock() + MockAM.return_value = mock_am + + async def mock_react_iter2(**kwargs): + return final + + mock_am.react.side_effect = mock_react_iter2 + + await service.agent_request(next_request, respond, next_fn, flow) + + # Collect and verify DAG + events = _collect_agent_events(respond) + dag = _collect_events(events) + + session_ids = _find_by_type(dag, TG_AGENT_QUESTION) + pd_ids = _find_by_type(dag, TG_PATTERN_DECISION) + analysis_ids = _find_by_type(dag, TG_ANALYSIS) + observation_ids = _find_by_type(dag, TG_OBSERVATION_TYPE) + final_ids = _find_by_type(dag, TG_CONCLUSION) + + assert len(session_ids) == 1, f"Expected 1 session, got {len(session_ids)}" + assert len(pd_ids) == 1, f"Expected 1 pattern-decision, got {len(pd_ids)}" + assert len(analysis_ids) >= 1, f"Expected >=1 analysis, got {len(analysis_ids)}" + assert len(observation_ids) >= 1, f"Expected >=1 observation, got {len(observation_ids)}" + assert len(final_ids) == 1, f"Expected 1 final, got {len(final_ids)}" + + # Full chain: + # session → pattern-decision + assert dag[pd_ids[0]]["derived_from"] == session_ids[0] + + # pattern-decision → iteration(1) + assert dag[analysis_ids[0]]["derived_from"] == pd_ids[0] + + # iteration(1) → observation(1) + assert dag[observation_ids[0]]["derived_from"] == analysis_ids[0] + + # observation(1) → final + assert dag[final_ids[0]]["derived_from"] == observation_ids[0] + + +class TestAgentPlanDagStructure: + """ + Via service.agent_request(): + session → pattern-decision → plan → step-result → synthesis + """ + + @pytest.mark.asyncio + async def test_dag_chain(self): + from trustgraph.agent.orchestrator.service import Processor + from trustgraph.agent.orchestrator.react_pattern import ReactPattern + from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern + from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern + + # Mock tool + mock_tool = MagicMock() + mock_tool.name = "knowledge-query" + mock_tool.description = "Query KB" + mock_tool.arguments = [] + mock_tool.groups = [] + mock_tool.states = {} + mock_tool_impl = AsyncMock(return_value="Found it") + mock_tool.implementation = MagicMock(return_value=mock_tool_impl) + + processor = _make_processor(tools={"knowledge-query": mock_tool}) + + service = Processor.__new__(Processor) + service.max_iterations = 10 + service.save_answer_content = AsyncMock() + service.provenance_session_uri = processor.provenance_session_uri + service.agent = processor.agent + service.aggregator = processor.aggregator + + service.react_pattern = ReactPattern(service) + service.plan_pattern = PlanThenExecutePattern(service) + 
service.supervisor_pattern = SupervisorPattern(service) + service.meta_router = None + + respond = AsyncMock() + next_fn = AsyncMock() + flow = _make_flow() + + # Mock prompt client + mock_prompt_client = AsyncMock() + + call_count = 0 + + async def mock_prompt(id, variables=None, **kwargs): + nonlocal call_count + call_count += 1 + if id == "plan-create": + return PromptResult( + response_type="jsonl", + objects=[{"goal": "Find info", "tool_hint": "knowledge-query", "depends_on": []}], + ) + elif id == "plan-step-execute": + return PromptResult( + response_type="json", + object={"tool": "knowledge-query", "arguments": {"question": "test"}}, + ) + elif id == "plan-synthesise": + return PromptResult(response_type="text", text="Final answer.") + return PromptResult(response_type="text", text="") + + mock_prompt_client.prompt.side_effect = mock_prompt + + def flow_factory(name): + if name == "prompt-request": + return mock_prompt_client + return AsyncMock() + flow.side_effect = flow_factory + + session_id = str(uuid.uuid4()) + + # Iteration 1: planning + request1 = AgentRequest( + question="Test?", + user="testuser", + collection="default", + streaming=False, + session_id=session_id, + pattern="plan-then-execute", + history=[], + ) + await service.agent_request(request1, respond, next_fn, flow) + + # Iteration 2: execute step (next_fn was called with updated request) + assert next_fn.called + next_request = next_fn.call_args[0][0] + + # Iteration 3: all steps done → synthesis + # Simulate completed step in history + next_request.history[-1].plan[0].status = "completed" + next_request.history[-1].plan[0].result = "Found it" + + await service.agent_request(next_request, respond, next_fn, flow) + + events = _collect_agent_events(respond) + dag = _collect_events(events) + + session_ids = _find_by_type(dag, TG_AGENT_QUESTION) + pd_ids = _find_by_type(dag, TG_PATTERN_DECISION) + plan_ids = _find_by_type(dag, TG_PLAN_TYPE) + synthesis_ids = _find_by_type(dag, TG_SYNTHESIS) + + assert len(session_ids) == 1 + assert len(pd_ids) == 1 + assert len(plan_ids) == 1 + assert len(synthesis_ids) == 1 + + # Chain: session → pattern-decision → plan → ... 
→ synthesis
+        assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
+        assert dag[plan_ids[0]]["derived_from"] == pd_ids[0]
+
+
+class TestAgentSupervisorDagStructure:
+    """
+    Via service.agent_request():
+        session → pattern-decision → decomposition → (fan-out)
+    """
+
+    @pytest.mark.asyncio
+    async def test_dag_chain(self):
+        from trustgraph.agent.orchestrator.service import Processor
+        from trustgraph.agent.orchestrator.react_pattern import ReactPattern
+        from trustgraph.agent.orchestrator.plan_pattern import PlanThenExecutePattern
+        from trustgraph.agent.orchestrator.supervisor_pattern import SupervisorPattern
+
+        processor = _make_processor()
+
+        service = Processor.__new__(Processor)
+        service.max_iterations = 10
+        service.save_answer_content = AsyncMock()
+        service.provenance_session_uri = processor.provenance_session_uri
+        service.agent = processor.agent
+        service.aggregator = processor.aggregator
+
+        service.react_pattern = ReactPattern(service)
+        service.plan_pattern = PlanThenExecutePattern(service)
+        service.supervisor_pattern = SupervisorPattern(service)
+        service.meta_router = None
+
+        respond = AsyncMock()
+        next_fn = AsyncMock()
+        flow = _make_flow()
+
+        mock_prompt_client = AsyncMock()
+        mock_prompt_client.prompt.return_value = PromptResult(
+            response_type="jsonl",
+            objects=["Goal A", "Goal B"],
+        )
+
+        def flow_factory(name):
+            if name == "prompt-request":
+                return mock_prompt_client
+            return AsyncMock()
+        flow.side_effect = flow_factory
+
+        request = AgentRequest(
+            question="Research quantum computing",
+            user="testuser",
+            collection="default",
+            streaming=False,
+            session_id=str(uuid.uuid4()),
+            pattern="supervisor",
+            history=[],
+        )
+
+        await service.agent_request(request, respond, next_fn, flow)
+
+        events = _collect_agent_events(respond)
+        dag = _collect_events(events)
+
+        session_ids = _find_by_type(dag, TG_AGENT_QUESTION)
+        pd_ids = _find_by_type(dag, TG_PATTERN_DECISION)
+        decomp_ids = _find_by_type(dag, TG_DECOMPOSITION)
+
+        assert len(session_ids) == 1
+        assert len(pd_ids) == 1
+        assert len(decomp_ids) == 1
+
+        # Chain: session → pattern-decision → decomposition
+        assert dag[pd_ids[0]]["derived_from"] == session_ids[0]
+        assert dag[decomp_ids[0]]["derived_from"] == pd_ids[0]
+
+        # Fan-out should have been called
+        assert next_fn.call_count == 2  # One per goal
diff --git a/tests/unit/test_python_api_client.py b/tests/unit/test_python_api_client.py
index 80443a0c..0b6709fb 100644
--- a/tests/unit/test_python_api_client.py
+++ b/tests/unit/test_python_api_client.py
@@ -304,14 +304,14 @@ class TestStreamingTypes:
 
         assert chunk.content == "thinking..."
         assert chunk.end_of_message is False
-        assert chunk.chunk_type == "thought"
+        assert chunk.message_type == "thought"
 
     def test_agent_observation_creation(self):
         """Test creating AgentObservation chunk"""
         chunk = AgentObservation(content="observing...", end_of_message=False)
 
        assert chunk.content == "observing..."
- assert chunk.chunk_type == "observation" + assert chunk.message_type == "observation" def test_agent_answer_creation(self): """Test creating AgentAnswer chunk""" @@ -324,7 +324,7 @@ class TestStreamingTypes: assert chunk.content == "answer" assert chunk.end_of_message is True assert chunk.end_of_dialog is True - assert chunk.chunk_type == "final-answer" + assert chunk.message_type == "final-answer" def test_rag_chunk_creation(self): """Test creating RAGChunk""" diff --git a/tests/unit/test_retrieval/test_document_rag.py b/tests/unit/test_retrieval/test_document_rag.py index 27508ba4..1ff85f5a 100644 --- a/tests/unit/test_retrieval/test_document_rag.py +++ b/tests/unit/test_retrieval/test_document_rag.py @@ -6,6 +6,7 @@ import pytest from unittest.mock import MagicMock, AsyncMock from trustgraph.retrieval.document_rag.document_rag import DocumentRag, Query +from trustgraph.base import PromptResult # Sample chunk content mapping for tests @@ -132,7 +133,7 @@ class TestQuery: mock_rag.prompt_client = mock_prompt_client # Mock the prompt response with concept lines - mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence\ndata patterns" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence\ndata patterns") query = Query( rag=mock_rag, @@ -157,7 +158,7 @@ class TestQuery: mock_rag.prompt_client = mock_prompt_client # Mock empty response - mock_prompt_client.prompt.return_value = "" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="") query = Query( rag=mock_rag, @@ -258,7 +259,7 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() # Mock concept extraction - mock_prompt_client.prompt.return_value = "test concept" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="test concept") # Mock embeddings - one vector per concept test_vectors = [[0.1, 0.2, 0.3]] @@ -273,7 +274,7 @@ class TestQuery: expected_response = "This is the document RAG response" mock_doc_embeddings_client.query.return_value = [mock_match1, mock_match2] - mock_prompt_client.document_prompt.return_value = expected_response + mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=expected_response) document_rag = DocumentRag( prompt_client=mock_prompt_client, @@ -315,7 +316,8 @@ class TestQuery: assert "Relevant document content" in docs assert "Another document" in docs - assert result == expected_response + result_text, usage = result + assert result_text == expected_response @pytest.mark.asyncio async def test_document_rag_query_with_defaults(self, mock_fetch_chunk): @@ -325,7 +327,7 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() # Mock concept extraction fallback (empty → raw query) - mock_prompt_client.prompt.return_value = "" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="") # Mock responses mock_embeddings_client.embed.return_value = [[[0.1, 0.2]]] @@ -333,7 +335,7 @@ class TestQuery: mock_match.chunk_id = "doc/c5" mock_match.score = 0.9 mock_doc_embeddings_client.query.return_value = [mock_match] - mock_prompt_client.document_prompt.return_value = "Default response" + mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Default response") document_rag = DocumentRag( prompt_client=mock_prompt_client, @@ -352,7 +354,8 @@ class TestQuery: collection="default" # Default collection ) - assert result == "Default response" + 
result_text, usage = result + assert result_text == "Default response" @pytest.mark.asyncio async def test_get_docs_with_verbose_output(self): @@ -401,7 +404,7 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() # Mock concept extraction - mock_prompt_client.prompt.return_value = "verbose query test" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="verbose query test") # Mock responses mock_embeddings_client.embed.return_value = [[[0.3, 0.4]]] @@ -409,7 +412,7 @@ class TestQuery: mock_match.chunk_id = "doc/c7" mock_match.score = 0.92 mock_doc_embeddings_client.query.return_value = [mock_match] - mock_prompt_client.document_prompt.return_value = "Verbose RAG response" + mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="Verbose RAG response") document_rag = DocumentRag( prompt_client=mock_prompt_client, @@ -428,7 +431,8 @@ class TestQuery: assert call_args.kwargs["query"] == "verbose query test" assert "Verbose doc content" in call_args.kwargs["documents"] - assert result == "Verbose RAG response" + result_text, usage = result + assert result_text == "Verbose RAG response" @pytest.mark.asyncio async def test_get_docs_with_empty_results(self): @@ -469,11 +473,11 @@ class TestQuery: mock_doc_embeddings_client = AsyncMock() # Mock concept extraction - mock_prompt_client.prompt.return_value = "query with no matching docs" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="query with no matching docs") mock_embeddings_client.embed.return_value = [[[0.5, 0.6]]] mock_doc_embeddings_client.query.return_value = [] - mock_prompt_client.document_prompt.return_value = "No documents found response" + mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text="No documents found response") document_rag = DocumentRag( prompt_client=mock_prompt_client, @@ -490,7 +494,8 @@ class TestQuery: documents=[] ) - assert result == "No documents found response" + result_text, usage = result + assert result_text == "No documents found response" @pytest.mark.asyncio async def test_get_vectors_with_verbose(self): @@ -525,7 +530,7 @@ class TestQuery: final_response = "Machine learning is a field of AI that enables computers to learn and improve from experience without being explicitly programmed." # Mock concept extraction - mock_prompt_client.prompt.return_value = "machine learning\nartificial intelligence" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nartificial intelligence") # Mock embeddings - one vector per concept query_vectors = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]] @@ -541,7 +546,7 @@ class TestQuery: MagicMock(chunk_id="doc/ml3", score=0.82), ] mock_doc_embeddings_client.query.side_effect = [mock_matches_1, mock_matches_2] - mock_prompt_client.document_prompt.return_value = final_response + mock_prompt_client.document_prompt.return_value = PromptResult(response_type="text", text=final_response) document_rag = DocumentRag( prompt_client=mock_prompt_client, @@ -584,7 +589,8 @@ class TestQuery: assert "Common ML techniques include supervised and unsupervised learning..." 
in docs assert len(docs) == 3 # doc/ml2 deduplicated - assert result == final_response + result_text, usage = result + assert result_text == final_response @pytest.mark.asyncio async def test_get_docs_deduplicates_across_concepts(self): diff --git a/tests/unit/test_retrieval/test_document_rag_provenance_integration.py b/tests/unit/test_retrieval/test_document_rag_provenance_integration.py index 74157285..8fa10642 100644 --- a/tests/unit/test_retrieval/test_document_rag_provenance_integration.py +++ b/tests/unit/test_retrieval/test_document_rag_provenance_integration.py @@ -12,6 +12,7 @@ from unittest.mock import AsyncMock from dataclasses import dataclass from trustgraph.retrieval.document_rag.document_rag import DocumentRag +from trustgraph.base import PromptResult from trustgraph.provenance.namespaces import ( RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, @@ -89,8 +90,8 @@ def build_mock_clients(): # 1. Concept extraction async def mock_prompt(template_id, variables=None, **kwargs): if template_id == "extract-concepts": - return "return policy\nrefund" - return "" + return PromptResult(response_type="text", text="return policy\nrefund") + return PromptResult(response_type="text", text="") prompt_client.prompt.side_effect = mock_prompt @@ -113,8 +114,9 @@ def build_mock_clients(): fetch_chunk.side_effect = mock_fetch # 5. Synthesis - prompt_client.document_prompt.return_value = ( - "Items can be returned within 30 days for a full refund." + prompt_client.document_prompt.return_value = PromptResult( + response_type="text", + text="Items can be returned within 30 days for a full refund.", ) return prompt_client, embeddings_client, doc_embeddings_client, fetch_chunk @@ -340,12 +342,12 @@ class TestDocumentRagQueryProvenance: clients = build_mock_clients() rag = DocumentRag(*clients) - result = await rag.query( + result_text, usage = await rag.query( query="What is the return policy?", explain_callback=AsyncMock(), ) - assert result == "Items can be returned within 30 days for a full refund." + assert result_text == "Items can be returned within 30 days for a full refund." @pytest.mark.asyncio async def test_no_explain_callback_still_works(self): @@ -353,8 +355,8 @@ class TestDocumentRagQueryProvenance: clients = build_mock_clients() rag = DocumentRag(*clients) - result = await rag.query(query="What is the return policy?") - assert result == "Items can be returned within 30 days for a full refund." + result_text, usage = await rag.query(query="What is the return policy?") + assert result_text == "Items can be returned within 30 days for a full refund." 
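+        # `usage` is the token-accounting dict ({"in_token": ...,
+        # "out_token": ..., "model": ...}); its shape is covered by the
+        # service-layer tests, so only the text half is asserted here.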
@pytest.mark.asyncio async def test_all_triples_in_retrieval_graph(self): diff --git a/tests/unit/test_retrieval/test_document_rag_service.py b/tests/unit/test_retrieval/test_document_rag_service.py index 05e1bb60..a5d42f3a 100644 --- a/tests/unit/test_retrieval/test_document_rag_service.py +++ b/tests/unit/test_retrieval/test_document_rag_service.py @@ -34,7 +34,7 @@ class TestDocumentRagService: # Setup mock DocumentRag instance mock_rag_instance = AsyncMock() mock_document_rag_class.return_value = mock_rag_instance - mock_rag_instance.query.return_value = "test response" + mock_rag_instance.query.return_value = ("test response", {"in_token": None, "out_token": None, "model": None}) # Setup message with custom user/collection msg = MagicMock() @@ -97,7 +97,7 @@ class TestDocumentRagService: # Setup mock DocumentRag instance mock_rag_instance = AsyncMock() mock_document_rag_class.return_value = mock_rag_instance - mock_rag_instance.query.return_value = "A document about cats." + mock_rag_instance.query.return_value = ("A document about cats.", {"in_token": None, "out_token": None, "model": None}) # Setup message with non-streaming request msg = MagicMock() @@ -130,4 +130,5 @@ class TestDocumentRagService: assert isinstance(sent_response, DocumentRagResponse) assert sent_response.response == "A document about cats." assert sent_response.end_of_stream is True, "Non-streaming response must have end_of_stream=True" + assert sent_response.end_of_session is True assert sent_response.error is None \ No newline at end of file diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index 00d8b72a..00a9551f 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -7,6 +7,7 @@ import unittest.mock from unittest.mock import MagicMock, AsyncMock from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, Query +from trustgraph.base import PromptResult class TestGraphRag: @@ -172,7 +173,7 @@ class TestQuery: mock_prompt_client = AsyncMock() mock_rag.prompt_client = mock_prompt_client - mock_prompt_client.prompt.return_value = "machine learning\nneural networks\n" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="machine learning\nneural networks\n") query = Query( rag=mock_rag, @@ -196,7 +197,7 @@ class TestQuery: mock_prompt_client = AsyncMock() mock_rag.prompt_client = mock_prompt_client - mock_prompt_client.prompt.return_value = "" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="") query = Query( rag=mock_rag, @@ -220,7 +221,7 @@ class TestQuery: mock_rag.graph_embeddings_client = mock_graph_embeddings_client # extract_concepts returns empty -> falls back to [query] - mock_prompt_client.prompt.return_value = "" + mock_prompt_client.prompt.return_value = PromptResult(response_type="text", text="") # embed returns one vector set for the single concept test_vectors = [[0.1, 0.2, 0.3]] @@ -565,14 +566,14 @@ class TestQuery: # Mock prompt responses for the multi-step process async def mock_prompt(prompt_name, variables=None, streaming=False, chunk_callback=None): if prompt_name == "extract-concepts": - return "" # Falls back to raw query + return PromptResult(response_type="text", text="") elif prompt_name == "kg-edge-scoring": - return json.dumps({"id": test_edge_id, "score": 0.9}) + return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "score": 0.9}]) elif prompt_name == "kg-edge-reasoning": - return 
json.dumps({"id": test_edge_id, "reasoning": "relevant"}) + return PromptResult(response_type="jsonl", objects=[{"id": test_edge_id, "reasoning": "relevant"}]) elif prompt_name == "kg-synthesis": - return expected_response - return "" + return PromptResult(response_type="text", text=expected_response) + return PromptResult(response_type="text", text="") mock_prompt_client.prompt = mock_prompt @@ -607,7 +608,8 @@ class TestQuery: explain_callback=collect_provenance ) - assert response == expected_response + response_text, usage = response + assert response_text == expected_response # 5 events: question, grounding, exploration, focus, synthesis assert len(provenance_events) == 5 diff --git a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py index 36536f7d..1eb0dd72 100644 --- a/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py +++ b/tests/unit/test_retrieval/test_graph_rag_provenance_integration.py @@ -13,6 +13,7 @@ from dataclasses import dataclass from trustgraph.retrieval.graph_rag.graph_rag import GraphRag, edge_id from trustgraph.schema import Triple as SchemaTriple, Term, IRI, LITERAL +from trustgraph.base import PromptResult from trustgraph.provenance.namespaces import ( RDF_TYPE, PROV_ENTITY, PROV_WAS_DERIVED_FROM, @@ -136,24 +137,36 @@ def build_mock_clients(): async def mock_prompt(template_id, variables=None, **kwargs): if template_id == "extract-concepts": - return prompt_responses["extract-concepts"] + return PromptResult( + response_type="text", + text=prompt_responses["extract-concepts"], + ) elif template_id == "kg-edge-scoring": # Score all edges highly, using the IDs that GraphRag computed edges = variables.get("knowledge", []) - return [ - {"id": e["id"], "score": 10 - i} - for i, e in enumerate(edges) - ] + return PromptResult( + response_type="jsonl", + objects=[ + {"id": e["id"], "score": 10 - i} + for i, e in enumerate(edges) + ], + ) elif template_id == "kg-edge-reasoning": # Provide reasoning for each edge edges = variables.get("knowledge", []) - return [ - {"id": e["id"], "reasoning": f"Relevant edge {i}"} - for i, e in enumerate(edges) - ] + return PromptResult( + response_type="jsonl", + objects=[ + {"id": e["id"], "reasoning": f"Relevant edge {i}"} + for i, e in enumerate(edges) + ], + ) elif template_id == "kg-synthesis": - return synthesis_answer - return "" + return PromptResult( + response_type="text", + text=synthesis_answer, + ) + return PromptResult(response_type="text", text="") prompt_client.prompt.side_effect = mock_prompt @@ -413,13 +426,13 @@ class TestGraphRagQueryProvenance: async def explain_callback(triples, explain_id): events.append({"triples": triples, "explain_id": explain_id}) - result = await rag.query( + result_text, usage = await rag.query( query="What is quantum computing?", explain_callback=explain_callback, edge_score_limit=0, ) - assert result == "Quantum computing applies physics principles to computation." + assert result_text == "Quantum computing applies physics principles to computation." @pytest.mark.asyncio async def test_parent_uri_links_question_to_parent(self): @@ -450,12 +463,12 @@ class TestGraphRagQueryProvenance: clients = build_mock_clients() rag = GraphRag(*clients) - result = await rag.query( + result_text, usage = await rag.query( query="What is quantum computing?", edge_score_limit=0, ) - assert result == "Quantum computing applies physics principles to computation." 
+ assert result_text == "Quantum computing applies physics principles to computation." @pytest.mark.asyncio async def test_all_triples_in_retrieval_graph(self): diff --git a/tests/unit/test_retrieval/test_graph_rag_service.py b/tests/unit/test_retrieval/test_graph_rag_service.py index 2cd62286..606aa7fe 100644 --- a/tests/unit/test_retrieval/test_graph_rag_service.py +++ b/tests/unit/test_retrieval/test_graph_rag_service.py @@ -44,7 +44,7 @@ class TestGraphRagService: await explain_callback([], "urn:trustgraph:prov:retrieval:test") await explain_callback([], "urn:trustgraph:prov:selection:test") await explain_callback([], "urn:trustgraph:prov:answer:test") - return "A small domesticated mammal." + return "A small domesticated mammal.", {"in_token": None, "out_token": None, "model": None} mock_rag_instance.query.side_effect = mock_query @@ -79,8 +79,8 @@ class TestGraphRagService: # Execute await processor.on_request(msg, consumer, flow) - # Verify: 6 messages sent (4 provenance + 1 chunk + 1 end_of_session) - assert mock_response_producer.send.call_count == 6 + # Verify: 5 messages sent (4 provenance + 1 combined chunk with end_of_session) + assert mock_response_producer.send.call_count == 5 # First 4 messages are explain (emitted in real-time during query) for i in range(4): @@ -88,17 +88,12 @@ class TestGraphRagService: assert prov_msg.message_type == "explain" assert prov_msg.explain_id is not None - # 5th message is chunk with response + # 5th message is chunk with response and end_of_session chunk_msg = mock_response_producer.send.call_args_list[4][0][0] assert chunk_msg.message_type == "chunk" assert chunk_msg.response == "A small domesticated mammal." assert chunk_msg.end_of_stream is True - - # 6th message is empty chunk with end_of_session=True - close_msg = mock_response_producer.send.call_args_list[5][0][0] - assert close_msg.message_type == "chunk" - assert close_msg.response == "" - assert close_msg.end_of_session is True + assert chunk_msg.end_of_session is True # Verify provenance triples were sent to provenance queue assert mock_provenance_producer.send.call_count == 4 @@ -187,7 +182,7 @@ class TestGraphRagService: async def mock_query(**kwargs): # Don't call explain_callback - return "Response text" + return "Response text", {"in_token": None, "out_token": None, "model": None} mock_rag_instance.query.side_effect = mock_query @@ -218,17 +213,12 @@ class TestGraphRagService: # Execute await processor.on_request(msg, consumer, flow) - # Verify: 2 messages (chunk + empty chunk to close) - assert mock_response_producer.send.call_count == 2 + # Verify: 1 combined message (chunk with end_of_session) + assert mock_response_producer.send.call_count == 1 - # First is the response chunk + # Single message has response and end_of_session chunk_msg = mock_response_producer.send.call_args_list[0][0][0] assert chunk_msg.message_type == "chunk" assert chunk_msg.response == "Response text" assert chunk_msg.end_of_stream is True - - # Second is empty chunk to close session - close_msg = mock_response_producer.send.call_args_list[1][0][0] - assert close_msg.message_type == "chunk" - assert close_msg.response == "" - assert close_msg.end_of_session is True + assert chunk_msg.end_of_session is True diff --git a/tests/unit/test_tables/__init__.py b/tests/unit/test_tables/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_tables/test_knowledge_table_store.py b/tests/unit/test_tables/test_knowledge_table_store.py new file mode 100644 index 00000000..5129b01e 
--- /dev/null
+++ b/tests/unit/test_tables/test_knowledge_table_store.py
@@ -0,0 +1,197 @@
+"""
+Unit tests for KnowledgeTableStore row deserialization.
+
+Regression coverage: a previous version of get_graph_embeddings constructed
+EntityEmbeddings(vectors=ent[1]) — the schema field is `vector` (singular),
+so any real Cassandra row would crash on read. These tests bypass the live
+Cassandra connection entirely and exercise the row -> schema conversion
+with hand-built fake rows.
+"""
+
+import pytest
+from unittest.mock import Mock
+
+from trustgraph.tables.knowledge import KnowledgeTableStore
+from trustgraph.schema import (
+    EntityEmbeddings,
+    GraphEmbeddings,
+    Triples,
+    Triple,
+    Metadata,
+    IRI,
+    LITERAL,
+)
+
+
+def _make_store():
+    """
+    Build a KnowledgeTableStore without invoking __init__ (which connects
+    to Cassandra). Tests inject only the attributes the method under test
+    actually touches.
+    """
+    return KnowledgeTableStore.__new__(KnowledgeTableStore)
+
+
+class TestGetGraphEmbeddings:
+
+    @pytest.mark.asyncio
+    async def test_row_converts_to_entity_embeddings_with_singular_vector(self):
+        """
+        Cassandra rows return entities as a list of [entity_tuple, vector]
+        pairs in row[3]. The deserializer must construct EntityEmbeddings
+        with `vector=` (singular) — the schema field name. A previous
+        version used `vectors=` and TypeError'd at runtime.
+        """
+        # Arrange — fake row matching the get_graph_embeddings_stmt
+        # result shape: row[0..2] are unused by the method, row[3] is
+        # the entities blob
+        fake_row = (
+            None, None, None,
+            [
+                # ((value, is_uri), vector)
+                (("http://example.org/alice", True), [0.1, 0.2, 0.3]),
+                (("http://example.org/bob", True), [0.4, 0.5, 0.6]),
+                (("a literal entity", False), [0.7, 0.8, 0.9]),
+            ],
+        )
+
+        store = _make_store()
+        store.cassandra = Mock()
+        store.cassandra.execute = Mock(return_value=[fake_row])
+        store.get_graph_embeddings_stmt = Mock()
+
+        received = []
+
+        async def receiver(msg):
+            received.append(msg)
+
+        # Act
+        await store.get_graph_embeddings(
+            user="alice",
+            document_id="doc-1",
+            receiver=receiver,
+        )
+
+        # Assert
+        store.cassandra.execute.assert_called_once_with(
+            store.get_graph_embeddings_stmt,
+            ("alice", "doc-1"),
+        )
+
+        assert len(received) == 1
+        ge = received[0]
+        assert isinstance(ge, GraphEmbeddings)
+        assert isinstance(ge.metadata, Metadata)
+        assert ge.metadata.id == "doc-1"
+        assert ge.metadata.user == "alice"
+
+        assert len(ge.entities) == 3
+        assert all(isinstance(e, EntityEmbeddings) for e in ge.entities)
+
+        # Vectors land in the singular `vector` field — this is the
+        # explicit regression assertion for the original bug.
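+        # e.g. the first row entry (("http://example.org/alice", True),
+        # [0.1, 0.2, 0.3]) must come out as EntityEmbeddings with
+        # vector=[0.1, 0.2, 0.3] — never `vectors=`.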
+ assert ge.entities[0].vector == [0.1, 0.2, 0.3] + assert ge.entities[1].vector == [0.4, 0.5, 0.6] + assert ge.entities[2].vector == [0.7, 0.8, 0.9] + + # Term type round-trips through tuple_to_term + assert ge.entities[0].entity.type == IRI + assert ge.entities[0].entity.iri == "http://example.org/alice" + assert ge.entities[1].entity.type == IRI + assert ge.entities[1].entity.iri == "http://example.org/bob" + assert ge.entities[2].entity.type == LITERAL + assert ge.entities[2].entity.value == "a literal entity" + + @pytest.mark.asyncio + async def test_empty_entities_blob_yields_empty_list(self): + """row[3] being None / empty must produce a GraphEmbeddings with + no entities, not raise.""" + fake_row = (None, None, None, None) + + store = _make_store() + store.cassandra = Mock() + store.cassandra.execute = Mock(return_value=[fake_row]) + store.get_graph_embeddings_stmt = Mock() + + received = [] + + async def receiver(msg): + received.append(msg) + + await store.get_graph_embeddings("u", "d", receiver) + + assert len(received) == 1 + assert received[0].entities == [] + + @pytest.mark.asyncio + async def test_multiple_rows_each_emit_one_message(self): + fake_rows = [ + (None, None, None, [ + (("http://example.org/a", True), [1.0]), + ]), + (None, None, None, [ + (("http://example.org/b", True), [2.0]), + ]), + ] + + store = _make_store() + store.cassandra = Mock() + store.cassandra.execute = Mock(return_value=fake_rows) + store.get_graph_embeddings_stmt = Mock() + + received = [] + + async def receiver(msg): + received.append(msg) + + await store.get_graph_embeddings("u", "d", receiver) + + assert len(received) == 2 + assert received[0].entities[0].entity.iri == "http://example.org/a" + assert received[0].entities[0].vector == [1.0] + assert received[1].entities[0].entity.iri == "http://example.org/b" + assert received[1].entities[0].vector == [2.0] + + +class TestGetTriples: + """Bonus: the sibling get_triples path uses the same row[3] shape and + the same Metadata construction. 
Cover it for parity.""" + + @pytest.mark.asyncio + async def test_row_converts_to_triples(self): + # row[3] is a list of (s_val, s_uri, p_val, p_uri, o_val, o_uri) + fake_row = ( + None, None, None, + [ + ( + "http://example.org/alice", True, + "http://example.org/knows", True, + "http://example.org/bob", True, + ), + ], + ) + + store = _make_store() + store.cassandra = Mock() + store.cassandra.execute = Mock(return_value=[fake_row]) + store.get_triples_stmt = Mock() + + received = [] + + async def receiver(msg): + received.append(msg) + + await store.get_triples("alice", "doc-1", receiver) + + assert len(received) == 1 + triples_msg = received[0] + assert isinstance(triples_msg, Triples) + assert isinstance(triples_msg.metadata, Metadata) + assert triples_msg.metadata.id == "doc-1" + assert triples_msg.metadata.user == "alice" + + assert len(triples_msg.triples) == 1 + t = triples_msg.triples[0] + assert isinstance(t, Triple) + assert t.s.iri == "http://example.org/alice" + assert t.p.iri == "http://example.org/knows" + assert t.o.iri == "http://example.org/bob" diff --git a/tests/unit/test_translators/__init__.py b/tests/unit/test_translators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py b/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py new file mode 100644 index 00000000..72f4796b --- /dev/null +++ b/tests/unit/test_translators/test_document_embeddings_translator_roundtrip.py @@ -0,0 +1,66 @@ +""" +Round-trip unit tests for DocumentEmbeddingsTranslator. + +Regression coverage: a previous version of the decode side constructed +ChunkEmbeddings(vectors=...) — the schema field is `vector` (singular), +so any real DocumentEmbeddings message would crash on decode. The encode +side already wrote `"vector"`, so encode→decode was asymmetric. 
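+
+These tests lock the wire key as "vector" on both encode and decode, so
+the two sides cannot drift apart again.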
+""" + +import pytest + +from trustgraph.messaging.translators.document_loading import ( + DocumentEmbeddingsTranslator, +) +from trustgraph.schema import ( + DocumentEmbeddings, + ChunkEmbeddings, + Metadata, +) + + +@pytest.fixture +def translator(): + return DocumentEmbeddingsTranslator() + + +@pytest.fixture +def sample(): + return DocumentEmbeddings( + metadata=Metadata( + id="doc-1", + root="", + user="alice", + collection="testcoll", + ), + chunks=[ + ChunkEmbeddings(chunk_id="c1", vector=[0.1, 0.2, 0.3]), + ChunkEmbeddings(chunk_id="c2", vector=[0.4, 0.5, 0.6]), + ], + ) + + +class TestDocumentEmbeddingsTranslator: + + def test_encode_uses_singular_vector_key(self, translator, sample): + encoded = translator.encode(sample) + chunks = encoded["chunks"] + assert all("vector" in c for c in chunks) + assert all("vectors" not in c for c in chunks) + assert chunks[0]["vector"] == [0.1, 0.2, 0.3] + + def test_roundtrip_preserves_document_embeddings(self, translator, sample): + encoded = translator.encode(sample) + decoded = translator.decode(encoded) + + assert isinstance(decoded, DocumentEmbeddings) + assert isinstance(decoded.metadata, Metadata) + assert decoded.metadata.id == "doc-1" + assert decoded.metadata.user == "alice" + assert decoded.metadata.collection == "testcoll" + + assert len(decoded.chunks) == 2 + assert decoded.chunks[0].chunk_id == "c1" + assert decoded.chunks[0].vector == [0.1, 0.2, 0.3] + assert decoded.chunks[1].chunk_id == "c2" + assert decoded.chunks[1].vector == [0.4, 0.5, 0.6] diff --git a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py new file mode 100644 index 00000000..57e7ae17 --- /dev/null +++ b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py @@ -0,0 +1,153 @@ +""" +Round-trip unit tests for KnowledgeRequestTranslator. + +Regression coverage: a previous version of the decode side constructed +EntityEmbeddings(vectors=...) — the schema field is `vector` (singular), +so any real graph-embeddings KnowledgeRequest would crash on first +message. The encode side already wrote `"vector"`, so encode→decode was +asymmetric. + +These tests build a real KnowledgeRequest with graph-embeddings, encode +it, decode the result, and assert the round-trip is lossless. They also +exercise the triples path so any future schema drift in Metadata or +Triples breaks the test. 
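+
+As with the document-embeddings round-trip, the wire key is locked as
+"vector" so encode and decode stay symmetric.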
+""" + +import pytest + +from trustgraph.messaging.translators.knowledge import KnowledgeRequestTranslator +from trustgraph.schema import ( + KnowledgeRequest, + GraphEmbeddings, + EntityEmbeddings, + Triples, + Triple, + Metadata, + Term, + IRI, +) + + +def _term_iri(uri): + return Term(type=IRI, iri=uri) + + +@pytest.fixture +def translator(): + return KnowledgeRequestTranslator() + + +@pytest.fixture +def graph_embeddings_request(): + return KnowledgeRequest( + operation="put-kg-core", + user="alice", + id="doc-1", + flow="default", + collection="testcoll", + graph_embeddings=GraphEmbeddings( + metadata=Metadata( + id="doc-1", + root="", + user="alice", + collection="testcoll", + ), + entities=[ + EntityEmbeddings( + entity=_term_iri("http://example.org/alice"), + vector=[0.1, 0.2, 0.3], + ), + EntityEmbeddings( + entity=_term_iri("http://example.org/bob"), + vector=[0.4, 0.5, 0.6], + ), + ], + ), + ) + + +@pytest.fixture +def triples_request(): + return KnowledgeRequest( + operation="put-kg-core", + user="alice", + id="doc-1", + flow="default", + collection="testcoll", + triples=Triples( + metadata=Metadata( + id="doc-1", + root="", + user="alice", + collection="testcoll", + ), + triples=[ + Triple( + s=_term_iri("http://example.org/alice"), + p=_term_iri("http://example.org/knows"), + o=_term_iri("http://example.org/bob"), + ), + ], + ), + ) + + +class TestKnowledgeRequestTranslatorGraphEmbeddings: + + def test_encode_produces_singular_vector_key( + self, translator, graph_embeddings_request, + ): + """The wire key must be `vector`, never `vectors`.""" + encoded = translator.encode(graph_embeddings_request) + entities = encoded["graph-embeddings"]["entities"] + assert all("vector" in e for e in entities) + assert all("vectors" not in e for e in entities) + assert entities[0]["vector"] == [0.1, 0.2, 0.3] + + def test_roundtrip_preserves_graph_embeddings( + self, translator, graph_embeddings_request, + ): + """encode -> decode must be lossless for the GE branch.""" + encoded = translator.encode(graph_embeddings_request) + decoded = translator.decode(encoded) + + assert isinstance(decoded, KnowledgeRequest) + assert decoded.operation == "put-kg-core" + assert decoded.user == "alice" + assert decoded.id == "doc-1" + assert decoded.flow == "default" + assert decoded.collection == "testcoll" + + assert decoded.graph_embeddings is not None + ge = decoded.graph_embeddings + assert isinstance(ge, GraphEmbeddings) + assert isinstance(ge.metadata, Metadata) + assert ge.metadata.id == "doc-1" + assert ge.metadata.user == "alice" + assert ge.metadata.collection == "testcoll" + + assert len(ge.entities) == 2 + assert ge.entities[0].vector == [0.1, 0.2, 0.3] + assert ge.entities[1].vector == [0.4, 0.5, 0.6] + assert ge.entities[0].entity.iri == "http://example.org/alice" + assert ge.entities[1].entity.iri == "http://example.org/bob" + + +class TestKnowledgeRequestTranslatorTriples: + + def test_roundtrip_preserves_triples(self, translator, triples_request): + encoded = translator.encode(triples_request) + decoded = translator.decode(encoded) + + assert isinstance(decoded, KnowledgeRequest) + assert decoded.triples is not None + assert isinstance(decoded.triples.metadata, Metadata) + assert decoded.triples.metadata.id == "doc-1" + assert decoded.triples.metadata.user == "alice" + assert decoded.triples.metadata.collection == "testcoll" + + assert len(decoded.triples.triples) == 1 + t = decoded.triples.triples[0] + assert t.s.iri == "http://example.org/alice" + assert t.p.iri == 
"http://example.org/knows" + assert t.o.iri == "http://example.org/bob" diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index 985bcbf1..c8c676c9 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -9,7 +9,7 @@ from .streaming_assertions import ( assert_streaming_content_matches, assert_no_empty_chunks, assert_streaming_error_handled, - assert_chunk_types_valid, + assert_message_types_valid, assert_streaming_latency_acceptable, assert_callback_invoked, ) @@ -23,7 +23,7 @@ __all__ = [ "assert_streaming_content_matches", "assert_no_empty_chunks", "assert_streaming_error_handled", - "assert_chunk_types_valid", + "assert_message_types_valid", "assert_streaming_latency_acceptable", "assert_callback_invoked", ] diff --git a/tests/utils/streaming_assertions.py b/tests/utils/streaming_assertions.py index cc9164ed..945bb031 100644 --- a/tests/utils/streaming_assertions.py +++ b/tests/utils/streaming_assertions.py @@ -20,14 +20,14 @@ def assert_streaming_chunks_valid(chunks: List[Any], min_chunks: int = 1): assert all(chunk is not None for chunk in chunks), "All chunks should be non-None" -def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "chunk_type"): +def assert_streaming_sequence(chunks: List[Dict[str, Any]], expected_sequence: List[str], key: str = "message_type"): """ Assert that streaming chunks follow an expected sequence. Args: chunks: List of chunk dictionaries expected_sequence: Expected sequence of chunk types/values - key: Dictionary key to check (default: "chunk_type") + key: Dictionary key to check (default: "message_type") """ actual_sequence = [chunk.get(key) for chunk in chunks if key in chunk] assert actual_sequence == expected_sequence, \ @@ -39,7 +39,7 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]): Assert that agent streaming chunks have valid structure. Validates: - - All chunks have chunk_type field + - All chunks have message_type field - All chunks have content field - All chunks have end_of_message field - All chunks have end_of_dialog field @@ -51,15 +51,15 @@ def assert_agent_streaming_chunks(chunks: List[Dict[str, Any]]): assert len(chunks) > 0, "Expected at least one chunk" for i, chunk in enumerate(chunks): - assert "chunk_type" in chunk, f"Chunk {i} missing chunk_type" + assert "message_type" in chunk, f"Chunk {i} missing message_type" assert "content" in chunk, f"Chunk {i} missing content" assert "end_of_message" in chunk, f"Chunk {i} missing end_of_message" assert "end_of_dialog" in chunk, f"Chunk {i} missing end_of_dialog" - # Validate chunk_type values + # Validate message_type values valid_types = ["thought", "action", "observation", "final-answer"] - assert chunk["chunk_type"] in valid_types, \ - f"Invalid chunk_type '{chunk['chunk_type']}' at index {i}" + assert chunk["message_type"] in valid_types, \ + f"Invalid message_type '{chunk['message_type']}' at index {i}" # Last chunk should signal end of dialog assert chunks[-1]["end_of_dialog"] is True, \ @@ -175,7 +175,7 @@ def assert_streaming_error_handled(chunks: List[Dict[str, Any]], error_flag: str "Error chunk should have completion flag set to True" -def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "chunk_type"): +def assert_message_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str], type_key: str = "message_type"): """ Assert that all chunk types are from a valid set. 
@@ -185,9 +185,9 @@ def assert_chunk_types_valid(chunks: List[Dict[str, Any]], valid_types: List[str type_key: Dictionary key for chunk type """ for i, chunk in enumerate(chunks): - chunk_type = chunk.get(type_key) - assert chunk_type in valid_types, \ - f"Chunk {i} has invalid type '{chunk_type}', expected one of {valid_types}" + message_type = chunk.get(type_key) + assert message_type in valid_types, \ + f"Chunk {i} has invalid type '{message_type}', expected one of {valid_types}" def assert_streaming_latency_acceptable(chunk_timestamps: List[float], max_gap_seconds: float = 5.0): diff --git a/trustgraph-base/pyproject.toml b/trustgraph-base/pyproject.toml index 216ccbd6..4f1bce76 100644 --- a/trustgraph-base/pyproject.toml +++ b/trustgraph-base/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "requests", "python-logging-loki", "pika", + "pyyaml", ] classifiers = [ "Programming Language :: Python :: 3", @@ -24,6 +25,9 @@ classifiers = [ [project.urls] Homepage = "https://github.com/trustgraph-ai/trustgraph" +[project.scripts] +processor-group = "trustgraph.base.processor_group:run" + [tool.setuptools.packages.find] include = ["trustgraph*"] @@ -31,4 +35,4 @@ include = ["trustgraph*"] "trustgraph.i18n.packs" = ["*.json"] [tool.setuptools.dynamic] -version = {attr = "trustgraph.base_version.__version__"} \ No newline at end of file +version = {attr = "trustgraph.base_version.__version__"} diff --git a/trustgraph-base/trustgraph/api/__init__.py b/trustgraph-base/trustgraph/api/__init__.py index 8b703dc7..2f44aad0 100644 --- a/trustgraph-base/trustgraph/api/__init__.py +++ b/trustgraph-base/trustgraph/api/__init__.py @@ -107,6 +107,7 @@ from .types import ( AgentObservation, AgentAnswer, RAGChunk, + TextCompletionResult, ProvenanceEvent, ) @@ -185,6 +186,7 @@ __all__ = [ "AgentObservation", "AgentAnswer", "RAGChunk", + "TextCompletionResult", "ProvenanceEvent", # Exceptions diff --git a/trustgraph-base/trustgraph/api/async_flow.py b/trustgraph-base/trustgraph/api/async_flow.py index 2ff37307..68899341 100644 --- a/trustgraph-base/trustgraph/api/async_flow.py +++ b/trustgraph-base/trustgraph/api/async_flow.py @@ -14,6 +14,8 @@ import aiohttp import json from typing import Optional, Dict, Any, List +from . types import TextCompletionResult + from . exceptions import ProtocolException, ApplicationException @@ -434,12 +436,11 @@ class AsyncFlowInstance: return await self.request("agent", request_data) - async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> str: + async def text_completion(self, system: str, prompt: str, **kwargs: Any) -> TextCompletionResult: """ Generate text completion (non-streaming). Generates a text response from an LLM given a system prompt and user prompt. - Returns the complete response text. Note: This method does not support streaming. For streaming text generation, use AsyncSocketFlowInstance.text_completion() instead. @@ -450,19 +451,19 @@ class AsyncFlowInstance: **kwargs: Additional service-specific parameters Returns: - str: Complete generated text response + TextCompletionResult: Result with text, in_token, out_token, model Example: ```python async_flow = await api.async_flow() flow = async_flow.id("default") - # Generate text - response = await flow.text_completion( + result = await flow.text_completion( system="You are a helpful assistant.", prompt="Explain quantum computing in simple terms." 
) - print(response) + print(result.text) + print(f"Tokens: {result.in_token} in, {result.out_token} out") ``` """ request_data = { @@ -473,7 +474,12 @@ class AsyncFlowInstance: request_data.update(kwargs) result = await self.request("text-completion", request_data) - return result.get("response", "") + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) async def graph_rag(self, query: str, user: str, collection: str, max_subgraph_size: int = 1000, max_subgraph_count: int = 5, diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py index 7a239b07..6e5064ab 100644 --- a/trustgraph-base/trustgraph/api/async_socket_client.py +++ b/trustgraph-base/trustgraph/api/async_socket_client.py @@ -4,7 +4,7 @@ import asyncio import websockets from typing import Optional, Dict, Any, AsyncIterator, Union -from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk +from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, TextCompletionResult from . exceptions import ProtocolException, ApplicationException @@ -178,30 +178,32 @@ class AsyncSocketClient: def _parse_chunk(self, resp: Dict[str, Any]): """Parse response chunk into appropriate type. Returns None for non-content messages.""" - chunk_type = resp.get("chunk_type") message_type = resp.get("message_type") # Handle new GraphRAG message format with message_type if message_type == "provenance": return None - if chunk_type == "thought": + if message_type == "thought": return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) ) - elif chunk_type == "observation": + elif message_type == "observation": return AgentObservation( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) ) - elif chunk_type == "answer" or chunk_type == "final-answer": + elif message_type == "answer" or message_type == "final-answer": return AgentAnswer( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False), - end_of_dialog=resp.get("end_of_dialog", False) + end_of_dialog=resp.get("end_of_dialog", False), + in_token=resp.get("in_token"), + out_token=resp.get("out_token"), + model=resp.get("model"), ) - elif chunk_type == "action": + elif message_type == "action": return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) @@ -211,7 +213,10 @@ class AsyncSocketClient: return RAGChunk( content=content, end_of_stream=resp.get("end_of_stream", False), - error=None + error=None, + in_token=resp.get("in_token"), + out_token=resp.get("out_token"), + model=resp.get("model"), ) async def aclose(self): @@ -269,7 +274,11 @@ class AsyncSocketFlowInstance: return await self.client._send_request("agent", self.flow_id, request) async def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs): - """Text completion with optional streaming""" + """Text completion with optional streaming. + + Non-streaming: returns a TextCompletionResult with text and token counts. + Streaming: returns an async iterator of RAGChunk (with token counts on the final chunk). 
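[Editor's note: consuming both modes of the socket API might look like this. A sketch, assuming an AsyncSocketFlowInstance named flow; note the coroutine must be awaited before iterating the streaming generator:]

```python
# Non-streaming: a single TextCompletionResult.
result = await flow.text_completion(system="Be terse.", prompt="Hi")
print(result.text, result.in_token, result.out_token, result.model)

# Streaming: an async iterator of RAGChunk objects.
stream = await flow.text_completion(
    system="Be terse.", prompt="Hi", streaming=True,
)
async for chunk in stream:
    print(chunk.content, end="")
    if chunk.end_of_stream:
        # Token counts ride on the final chunk.
        print(f"\n[{chunk.in_token} in / {chunk.out_token} out]")
```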
+ """ request = { "system": system, "prompt": prompt, @@ -281,13 +290,18 @@ class AsyncSocketFlowInstance: return self._text_completion_streaming(request) else: result = await self.client._send_request("text-completion", self.flow_id, request) - return result.get("response", "") + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) async def _text_completion_streaming(self, request): - """Helper for streaming text completion""" + """Helper for streaming text completion. Yields RAGChunk objects.""" async for chunk in self.client._send_request_streaming("text-completion", self.flow_id, request): - if hasattr(chunk, 'content'): - yield chunk.content + if isinstance(chunk, RAGChunk): + yield chunk async def graph_rag(self, query: str, user: str, collection: str, max_subgraph_size: int = 1000, max_subgraph_count: int = 5, diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index 0aa55347..7ee32dad 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -11,7 +11,7 @@ import base64 from .. knowledge import hash, Uri, Literal, QuotedTriple from .. schema import IRI, LITERAL, TRIPLE -from . types import Triple +from . types import Triple, TextCompletionResult from . exceptions import ProtocolException @@ -360,16 +360,17 @@ class FlowInstance: prompt: User prompt/question Returns: - str: Generated response text + TextCompletionResult: Result with text, in_token, out_token, model Example: ```python flow = api.flow().id("default") - response = flow.text_completion( + result = flow.text_completion( system="You are a helpful assistant", prompt="What is quantum computing?" ) - print(response) + print(result.text) + print(f"Tokens: {result.in_token} in, {result.out_token} out") ``` """ @@ -379,10 +380,17 @@ class FlowInstance: "prompt": prompt } - return self.request( + result = self.request( "service/text-completion", input - )["response"] + ) + + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def agent(self, question, user="trustgraph", state=None, group=None, history=None): """ @@ -498,10 +506,17 @@ class FlowInstance: "edge-limit": edge_limit, } - return self.request( + result = self.request( "service/graph-rag", input - )["response"] + ) + + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def document_rag( self, query, user="trustgraph", collection="default", @@ -543,10 +558,17 @@ class FlowInstance: "doc-limit": doc_limit, } - return self.request( + result = self.request( "service/document-rag", input - )["response"] + ) + + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def embeddings(self, texts): """ diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index b6ceba00..c590c9b4 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -14,7 +14,7 @@ import websockets from typing import Optional, Dict, Any, Iterator, Union, List from threading import Lock -from . 
types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent +from . types import AgentThought, AgentObservation, AgentAnswer, RAGChunk, StreamingChunk, ProvenanceEvent, TextCompletionResult from . exceptions import ProtocolException, raise_from_error_dict @@ -360,41 +360,36 @@ class SocketClient: def _parse_chunk(self, resp: Dict[str, Any], include_provenance: bool = False) -> Optional[StreamingChunk]: """Parse response chunk into appropriate type. Returns None for non-content messages.""" - chunk_type = resp.get("chunk_type") message_type = resp.get("message_type") - # Handle GraphRAG/DocRAG message format with message_type if message_type == "explain": if include_provenance: return self._build_provenance_event(resp) return None - # Handle Agent message format with chunk_type="explain" - if chunk_type == "explain": - if include_provenance: - return self._build_provenance_event(resp) - return None - - if chunk_type == "thought": + if message_type == "thought": return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False), message_id=resp.get("message_id", ""), ) - elif chunk_type == "observation": + elif message_type == "observation": return AgentObservation( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False), message_id=resp.get("message_id", ""), ) - elif chunk_type == "answer" or chunk_type == "final-answer": + elif message_type == "answer" or message_type == "final-answer": return AgentAnswer( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False), end_of_dialog=resp.get("end_of_dialog", False), message_id=resp.get("message_id", ""), + in_token=resp.get("in_token"), + out_token=resp.get("out_token"), + model=resp.get("model"), ) - elif chunk_type == "action": + elif message_type == "action": return AgentThought( content=resp.get("content", ""), end_of_message=resp.get("end_of_message", False) @@ -404,7 +399,10 @@ class SocketClient: return RAGChunk( content=content, end_of_stream=resp.get("end_of_stream", False), - error=None + error=None, + in_token=resp.get("in_token"), + out_token=resp.get("out_token"), + model=resp.get("model"), ) def _build_provenance_event(self, resp: Dict[str, Any]) -> ProvenanceEvent: @@ -543,8 +541,12 @@ class SocketFlowInstance: streaming=True, include_provenance=True ) - def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[str, Iterator[str]]: - """Execute text completion with optional streaming.""" + def text_completion(self, system: str, prompt: str, streaming: bool = False, **kwargs) -> Union[TextCompletionResult, Iterator[RAGChunk]]: + """Execute text completion with optional streaming. + + Non-streaming: returns a TextCompletionResult with text and token counts. + Streaming: returns an iterator of RAGChunk (with token counts on the final chunk). 
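[Editor's note: the generator helpers now filter with isinstance rather than duck-typing on a content attribute, and yield whole RAGChunk objects instead of bare strings so token metadata survives to the caller. A sketch of the pattern, assuming the chunk types imported above:]

```python
def rag_chunks_only(stream):
    for chunk in stream:
        # An explicit type check keeps the content stream clean and
        # makes the yielded value's shape (content, end_of_stream,
        # in_token, out_token, model) explicit; anything else the
        # parser emits is skipped rather than leaked as content.
        if isinstance(chunk, RAGChunk):
            yield chunk
```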
+ """ request = { "system": system, "prompt": prompt, @@ -557,12 +559,17 @@ class SocketFlowInstance: if streaming: return self._text_completion_generator(result) else: - return result.get("response", "") + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) - def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: + def _text_completion_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]: for chunk in result: - if hasattr(chunk, 'content'): - yield chunk.content + if isinstance(chunk, RAGChunk): + yield chunk def graph_rag( self, @@ -577,8 +584,12 @@ class SocketFlowInstance: edge_limit: int = 25, streaming: bool = False, **kwargs: Any - ) -> Union[str, Iterator[str]]: - """Execute graph-based RAG query with optional streaming.""" + ) -> Union[TextCompletionResult, Iterator[RAGChunk]]: + """Execute graph-based RAG query with optional streaming. + + Non-streaming: returns a TextCompletionResult with text and token counts. + Streaming: returns an iterator of RAGChunk (with token counts on the final chunk). + """ request = { "query": query, "user": user, @@ -598,7 +609,12 @@ class SocketFlowInstance: if streaming: return self._rag_generator(result) else: - return result.get("response", "") + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def graph_rag_explain( self, @@ -642,8 +658,12 @@ class SocketFlowInstance: doc_limit: int = 10, streaming: bool = False, **kwargs: Any - ) -> Union[str, Iterator[str]]: - """Execute document-based RAG query with optional streaming.""" + ) -> Union[TextCompletionResult, Iterator[RAGChunk]]: + """Execute document-based RAG query with optional streaming. + + Non-streaming: returns a TextCompletionResult with text and token counts. + Streaming: returns an iterator of RAGChunk (with token counts on the final chunk). + """ request = { "query": query, "user": user, @@ -658,7 +678,12 @@ class SocketFlowInstance: if streaming: return self._rag_generator(result) else: - return result.get("response", "") + return TextCompletionResult( + text=result.get("response", ""), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def document_rag_explain( self, @@ -684,10 +709,10 @@ class SocketFlowInstance: streaming=True, include_provenance=True ) - def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[str]: + def _rag_generator(self, result: Iterator[StreamingChunk]) -> Iterator[RAGChunk]: for chunk in result: - if hasattr(chunk, 'content'): - yield chunk.content + if isinstance(chunk, RAGChunk): + yield chunk def prompt( self, @@ -695,8 +720,12 @@ class SocketFlowInstance: variables: Dict[str, str], streaming: bool = False, **kwargs: Any - ) -> Union[str, Iterator[str]]: - """Execute a prompt template with optional streaming.""" + ) -> Union[TextCompletionResult, Iterator[RAGChunk]]: + """Execute a prompt template with optional streaming. + + Non-streaming: returns a TextCompletionResult with text and token counts. + Streaming: returns an iterator of RAGChunk (with token counts on the final chunk). 
+ """ request = { "id": id, "variables": variables, @@ -709,7 +738,12 @@ class SocketFlowInstance: if streaming: return self._rag_generator(result) else: - return result.get("response", "") + return TextCompletionResult( + text=result.get("text", result.get("response", "")), + in_token=result.get("in_token"), + out_token=result.get("out_token"), + model=result.get("model"), + ) def graph_embeddings_query( self, diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index 55635584..f5987b0e 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -149,10 +149,10 @@ class AgentThought(StreamingChunk): Attributes: content: Agent's thought text end_of_message: True if this completes the current thought - chunk_type: Always "thought" + message_type: Always "thought" message_id: Provenance URI of the entity being built """ - chunk_type: str = "thought" + message_type: str = "thought" message_id: str = "" @dataclasses.dataclass @@ -166,10 +166,10 @@ class AgentObservation(StreamingChunk): Attributes: content: Observation text describing tool results end_of_message: True if this completes the current observation - chunk_type: Always "observation" + message_type: Always "observation" message_id: Provenance URI of the entity being built """ - chunk_type: str = "observation" + message_type: str = "observation" message_id: str = "" @dataclasses.dataclass @@ -184,11 +184,14 @@ class AgentAnswer(StreamingChunk): content: Answer text end_of_message: True if this completes the current answer segment end_of_dialog: True if this completes the entire agent interaction - chunk_type: Always "final-answer" + message_type: Always "final-answer" """ - chunk_type: str = "final-answer" + message_type: str = "final-answer" end_of_dialog: bool = False message_id: str = "" + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None @dataclasses.dataclass class RAGChunk(StreamingChunk): @@ -202,11 +205,37 @@ class RAGChunk(StreamingChunk): content: Generated text content end_of_stream: True if this is the final chunk of the stream error: Optional error information if an error occurred - chunk_type: Always "rag" + in_token: Input token count (populated on the final chunk, 0 otherwise) + out_token: Output token count (populated on the final chunk, 0 otherwise) + model: Model identifier (populated on the final chunk, empty otherwise) + message_type: Always "rag" """ - chunk_type: str = "rag" + message_type: str = "rag" end_of_stream: bool = False error: Optional[Dict[str, str]] = None + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None + +@dataclasses.dataclass +class TextCompletionResult: + """ + Result from a text completion request. + + Returned by text_completion() in both streaming and non-streaming modes. + In streaming mode, text is None (chunks are delivered via the iterator). + In non-streaming mode, text contains the complete response. 
+ + Attributes: + text: Complete response text (None in streaming mode) + in_token: Input token count (None if not available) + out_token: Output token count (None if not available) + model: Model identifier (None if not available) + """ + text: Optional[str] + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None @dataclasses.dataclass class ProvenanceEvent: diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 24b6c1f0..ce17a585 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -18,8 +18,10 @@ from . librarian_client import LibrarianClient from . chunking_service import ChunkingService from . embeddings_service import EmbeddingsService from . embeddings_client import EmbeddingsClientSpec -from . text_completion_client import TextCompletionClientSpec -from . prompt_client import PromptClientSpec +from . text_completion_client import ( + TextCompletionClientSpec, TextCompletionClient, TextCompletionResult, +) +from . prompt_client import PromptClientSpec, PromptClient, PromptResult from . triples_store_service import TriplesStoreService from . graph_embeddings_store_service import GraphEmbeddingsStoreService from . document_embeddings_store_service import DocumentEmbeddingsStoreService diff --git a/trustgraph-base/trustgraph/base/agent_client.py b/trustgraph-base/trustgraph/base/agent_client.py index d73d03b9..393864fa 100644 --- a/trustgraph-base/trustgraph/base/agent_client.py +++ b/trustgraph-base/trustgraph/base/agent_client.py @@ -30,19 +30,19 @@ class AgentClient(RequestResponse): raise RuntimeError(resp.error.message) # Handle thought chunks - if resp.chunk_type == 'thought': + if resp.message_type == 'thought': if think: await think(resp.content, resp.end_of_message) return False # Continue receiving # Handle observation chunks - if resp.chunk_type == 'observation': + if resp.message_type == 'observation': if observe: await observe(resp.content, resp.end_of_message) return False # Continue receiving # Handle answer chunks - if resp.chunk_type == 'answer': + if resp.message_type == 'answer': if resp.content: accumulated_answer.append(resp.content) if answer_callback: diff --git a/trustgraph-base/trustgraph/base/backend.py b/trustgraph-base/trustgraph/base/backend.py index 9b9a42af..f0d6b397 100644 --- a/trustgraph-base/trustgraph/base/backend.py +++ b/trustgraph-base/trustgraph/base/backend.py @@ -58,6 +58,18 @@ class BackendProducer(Protocol): class BackendConsumer(Protocol): """Protocol for backend-specific consumer.""" + def ensure_connected(self) -> None: + """ + Eagerly establish the underlying connection and bind the queue. + + Backends that lazily connect on first receive() must implement this + so that callers can guarantee the consumer is fully bound — and + therefore able to receive responses — before any related request is + published. Backends that connect at construction time may make this + a no-op. + """ + ... + def receive(self, timeout_millis: int = 2000) -> Message: """ Receive a message from the topic. 
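[Editor's note: a conforming implementation of the new protocol method for a lazily-connecting backend might look like this. A hypothetical sketch; InMemoryConsumer and its broker are illustrative, not part of the codebase:]

```python
import queue

class InMemoryConsumer:
    """Illustrative BackendConsumer: lazy connect, eager bind on demand."""

    def __init__(self, broker, topic):
        self.broker = broker
        self.topic = topic
        self.q = None

    def ensure_connected(self) -> None:
        # Idempotent: bind exactly once, so callers may invoke this
        # freely before publishing any request expecting a reply here.
        if self.q is None:
            self.q = self.broker.bind(self.topic)

    def receive(self, timeout_millis: int = 2000):
        self.ensure_connected()    # the lazy path still works
        try:
            return self.q.get(timeout=timeout_millis / 1000.0)
        except queue.Empty:
            raise TimeoutError("no message within timeout")
```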
diff --git a/trustgraph-base/trustgraph/base/chunking_service.py b/trustgraph-base/trustgraph/base/chunking_service.py index d4bf4cd4..4bd78428 100644 --- a/trustgraph-base/trustgraph/base/chunking_service.py +++ b/trustgraph-base/trustgraph/base/chunking_service.py @@ -88,14 +88,14 @@ class ChunkingService(FlowProcessor): chunk_overlap = default_chunk_overlap try: - cs = flow.parameters.get("chunk-size") + cs = flow("chunk-size") if cs is not None: chunk_size = int(cs) except Exception as e: logger.warning(f"Could not parse chunk-size parameter: {e}") try: - co = flow.parameters.get("chunk-overlap") + co = flow("chunk-overlap") if co is not None: chunk_overlap = int(co) except Exception as e: diff --git a/trustgraph-base/trustgraph/base/logging.py b/trustgraph-base/trustgraph/base/logging.py index 7bab6091..93cd8fa5 100644 --- a/trustgraph-base/trustgraph/base/logging.py +++ b/trustgraph-base/trustgraph/base/logging.py @@ -8,12 +8,51 @@ ensuring consistent log formats, levels, and command-line arguments. Supports dual output to console and Loki for centralized log aggregation. """ +import contextvars import logging import logging.handlers from queue import Queue import os +# The current processor id for this task context. Read by +# _ProcessorIdFilter to stamp every LogRecord with its owning +# processor, and read by logging_loki's emitter via record.tags +# to label log lines in Loki. ContextVar so asyncio subtasks +# inherit their parent supervisor's processor id automatically. +current_processor_id = contextvars.ContextVar( + "current_processor_id", default="unknown" +) + + +def set_processor_id(pid): + """Set the processor id for the current task context. + + All subsequent log records emitted from this task — and any + asyncio tasks spawned from it — will be tagged with this id + in the console format and in Loki labels. + """ + current_processor_id.set(pid) + + +class _ProcessorIdFilter(logging.Filter): + """Stamps every LogRecord with processor_id from the contextvar. + + Attaches two fields to each record: + record.processor_id — used by the console format string + record.tags — merged into Loki labels by logging_loki's + emitter (it reads record.tags and combines + with the handler's static tags) + """ + + def filter(self, record): + pid = current_processor_id.get() + record.processor_id = pid + existing = getattr(record, "tags", None) or {} + record.tags = {**existing, "processor": pid} + return True + + def add_logging_args(parser): """ Add standard logging arguments to an argument parser. @@ -87,12 +126,15 @@ def setup_logging(args): loki_url = args.get('loki_url', 'http://loki:3100/loki/api/v1/push') loki_username = args.get('loki_username') loki_password = args.get('loki_password') - processor_id = args.get('id') # Processor identity (e.g., "config-svc", "text-completion") try: from logging_loki import LokiHandler - # Create Loki handler with optional authentication and processor label + # Create Loki handler with optional authentication. The + # processor label is NOT baked in here — it's stamped onto + # each record by _ProcessorIdFilter reading the task-local + # contextvar, and logging_loki's emitter reads record.tags + # to build per-record Loki labels. 
loki_handler_kwargs = { 'url': loki_url, 'version': "1", @@ -101,10 +143,6 @@ def setup_logging(args): if loki_username and loki_password: loki_handler_kwargs['auth'] = (loki_username, loki_password) - # Add processor label if available (for consistency with Prometheus metrics) - if processor_id: - loki_handler_kwargs['tags'] = {'processor': processor_id} - loki_handler = LokiHandler(**loki_handler_kwargs) # Wrap in QueueHandler for non-blocking operation @@ -133,23 +171,44 @@ def setup_logging(args): print(f"WARNING: Failed to setup Loki logging: {e}") print("Continuing with console-only logging") - # Get processor ID for log formatting (use 'unknown' if not available) - processor_id = args.get('id', 'unknown') - - # Configure logging with all handlers - # Use processor ID as the primary identifier in logs + # Configure logging with all handlers. The processor id comes + # from _ProcessorIdFilter (via contextvar) and is injected into + # each record as record.processor_id. The format string reads + # that attribute on every emit. logging.basicConfig( level=getattr(logging, log_level.upper()), - format=f'%(asctime)s - {processor_id} - %(levelname)s - %(message)s', + format='%(asctime)s - %(processor_id)s - %(levelname)s - %(message)s', handlers=handlers, force=True # Force reconfiguration if already configured ) - # Prevent recursive logging from Loki's HTTP client - if loki_enabled and queue_listener: - # Disable urllib3 logging to prevent infinite loop - logging.getLogger('urllib3').setLevel(logging.WARNING) - logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING) + # Attach the processor-id filter to every handler so all records + # passing through any sink get stamped (console, queue→loki, + # future handlers). Filters on handlers run regardless of which + # logger originated the record, so logs from pika, cassandra, + # processor code, etc. all pass through it. + processor_filter = _ProcessorIdFilter() + for h in handlers: + h.addFilter(processor_filter) + + # Seed the contextvar from --id if one was supplied. In group + # mode --id isn't present; the processor_group supervisor sets + # it per task. In standalone mode AsyncProcessor.launch provides + # it via argparse default. + if args.get('id'): + set_processor_id(args['id']) + + # Silence noisy third-party library loggers. These emit INFO-level + # chatter (connection churn, channel open/close, driver warnings) that + # drowns the useful signal and can't be attributed to a specific + # processor anyway. WARNING and above still propagate. + for noisy in ( + 'pika', + 'cassandra', + 'urllib3', + 'urllib3.connectionpool', + ): + logging.getLogger(noisy).setLevel(logging.WARNING) logger = logging.getLogger(__name__) logger.info(f"Logging configured with level: {log_level}") diff --git a/trustgraph-base/trustgraph/base/processor_group.py b/trustgraph-base/trustgraph/base/processor_group.py new file mode 100644 index 00000000..d27b82c4 --- /dev/null +++ b/trustgraph-base/trustgraph/base/processor_group.py @@ -0,0 +1,204 @@ + +# Multi-processor group runner. Runs multiple AsyncProcessor descendants +# as concurrent tasks inside a single process, sharing one event loop, +# one Prometheus HTTP server, and one pub/sub backend pool. +# +# Intended for dev and resource-constrained deployments. Scale deployments +# should continue to use per-processor endpoints. 
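[Editor's note: the per-task processor id that makes shared-process logging attributable rests on ContextVar inheritance at task-spawn time. A tiny demonstration of that propagation, assuming set_processor_id as defined in the logging changes above (module path assumed to be trustgraph.base.logging):]

```python
import asyncio
import logging

from trustgraph.base.logging import set_processor_id

async def worker():
    # Inherits the id stamped by the spawning task's context.
    logging.getLogger(__name__).info("tagged with parent's processor id")

async def supervisor(pid):
    set_processor_id(pid)            # stamps this task context only
    async with asyncio.TaskGroup() as tg:
        tg.create_task(worker())     # child copies the context at spawn
```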
+# +# Group config is a YAML or JSON file with shape: +# +# processors: +# - class: trustgraph.extract.kg.definitions.extract.Processor +# params: +# id: kg-extract-definitions +# triples_batch_size: 1000 +# - class: trustgraph.chunking.recursive.Processor +# params: +# id: chunker-recursive +# +# Each entry's params are passed directly to the class constructor alongside +# the shared taskgroup. Defaults live inside each processor class. + +import argparse +import asyncio +import importlib +import json +import logging +import time + +from prometheus_client import start_http_server + +from . logging import add_logging_args, setup_logging, set_processor_id + +logger = logging.getLogger(__name__) + + +def _load_config(path): + with open(path) as f: + text = f.read() + if path.endswith((".yaml", ".yml")): + import yaml + return yaml.safe_load(text) + return json.loads(text) + + +def _resolve_class(dotted): + module_path, _, class_name = dotted.rpartition(".") + if not module_path: + raise ValueError( + f"Processor class must be a dotted path, got {dotted!r}" + ) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +RESTART_DELAY_SECONDS = 4 + + +async def _supervise(entry): + """Run one processor with its own nested TaskGroup, restarting on any + failure. Each processor is isolated from its siblings — a crash here + does not propagate to the outer group.""" + + pid = entry["params"]["id"] + class_path = entry["class"] + + # Stamp the contextvar for this supervisor task. Every log + # record emitted from this task — and from any inner TaskGroup + # child created by the processor — inherits this id via + # contextvar propagation. Siblings in the outer group set + # their own id in their own task context and do not interfere. + set_processor_id(pid) + + while True: + + try: + + async with asyncio.TaskGroup() as inner_tg: + + cls = _resolve_class(class_path) + params = dict(entry.get("params", {})) + params["taskgroup"] = inner_tg + + logger.info(f"Starting {class_path} as {pid}") + + p = cls(**params) + await p.start() + inner_tg.create_task(p.run()) + + # Clean exit — processor's run() returned without raising. + # Treat as a transient shutdown and restart, matching the + # behaviour of per-container `restart: on-failure`. + logger.warning( + f"Processor {pid} exited cleanly, will restart" + ) + + except asyncio.CancelledError: + logger.info(f"Processor {pid} cancelled") + raise + + except BaseExceptionGroup as eg: + for e in eg.exceptions: + logger.error( + f"Processor {pid} failure: {type(e).__name__}: {e}", + exc_info=e, + ) + + except Exception as e: + logger.error( + f"Processor {pid} failure: {type(e).__name__}: {e}", + exc_info=True, + ) + + logger.info( + f"Restarting {pid} in {RESTART_DELAY_SECONDS}s..." 
+ ) + await asyncio.sleep(RESTART_DELAY_SECONDS) + + +async def run_group(config): + + entries = config.get("processors", []) + if not entries: + raise RuntimeError("Group config has no processors") + + seen_ids = set() + for entry in entries: + pid = entry.get("params", {}).get("id") + if pid is None: + raise RuntimeError( + f"Entry {entry.get('class')!r} missing params.id — " + f"required for metrics labelling" + ) + if pid in seen_ids: + raise RuntimeError(f"Duplicate processor id {pid!r} in group") + seen_ids.add(pid) + + async with asyncio.TaskGroup() as outer_tg: + for entry in entries: + outer_tg.create_task(_supervise(entry)) + + +def run(): + + parser = argparse.ArgumentParser( + prog="processor-group", + description="Run multiple processors as tasks in one process", + ) + + parser.add_argument( + "-c", "--config", + required=True, + help="Path to group config file (JSON or YAML)", + ) + + parser.add_argument( + "--metrics", + action=argparse.BooleanOptionalAction, + default=True, + help="Metrics enabled (default: true)", + ) + + parser.add_argument( + "-P", "--metrics-port", + type=int, + default=8000, + help="Prometheus metrics port (default: 8000)", + ) + + add_logging_args(parser) + + args = vars(parser.parse_args()) + + setup_logging(args) + + config = _load_config(args["config"]) + + if args["metrics"]: + start_http_server(args["metrics_port"]) + + while True: + + logger.info("Starting group...") + + try: + asyncio.run(run_group(config)) + + except KeyboardInterrupt: + logger.info("Keyboard interrupt.") + return + + except ExceptionGroup as e: + logger.error("Exception group:") + for se in e.exceptions: + logger.error(f" Type: {type(se)}") + logger.error(f" Exception: {se}", exc_info=se) + + except Exception as e: + logger.error(f"Type: {type(e)}") + logger.error(f"Exception: {e}", exc_info=True) + + logger.warning("Will retry...") + time.sleep(4) + logger.info("Retrying...") diff --git a/trustgraph-base/trustgraph/base/prompt_client.py b/trustgraph-base/trustgraph/base/prompt_client.py index 6859a9f0..853e7e66 100644 --- a/trustgraph-base/trustgraph/base/prompt_client.py +++ b/trustgraph-base/trustgraph/base/prompt_client.py @@ -1,10 +1,22 @@ import json import asyncio +from dataclasses import dataclass +from typing import Optional, Any from . request_response_spec import RequestResponse, RequestResponseSpec from .. 
schema import PromptRequest, PromptResponse +@dataclass +class PromptResult: + response_type: str # "text", "json", or "jsonl" + text: Optional[str] = None # populated for "text" + object: Any = None # populated for "json" + objects: Optional[list] = None # populated for "jsonl" + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None + class PromptClient(RequestResponse): async def prompt(self, id, variables, timeout=600, streaming=False, chunk_callback=None): @@ -26,17 +38,40 @@ class PromptClient(RequestResponse): if resp.error: raise RuntimeError(resp.error.message) - if resp.text: return resp.text + if resp.text: + return PromptResult( + response_type="text", + text=resp.text, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, + ) - return json.loads(resp.object) + parsed = json.loads(resp.object) + + if isinstance(parsed, list): + return PromptResult( + response_type="jsonl", + objects=parsed, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, + ) + + return PromptResult( + response_type="json", + object=parsed, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, + ) else: - last_text = "" - last_object = None + last_resp = None async def forward_chunks(resp): - nonlocal last_text, last_object + nonlocal last_resp if resp.error: raise RuntimeError(resp.error.message) @@ -44,14 +79,13 @@ class PromptClient(RequestResponse): end_stream = getattr(resp, 'end_of_stream', False) if resp.text is not None: - last_text = resp.text if chunk_callback: if asyncio.iscoroutinefunction(chunk_callback): await chunk_callback(resp.text, end_stream) else: chunk_callback(resp.text, end_stream) - elif resp.object: - last_object = resp.object + + last_resp = resp return end_stream @@ -70,10 +104,36 @@ class PromptClient(RequestResponse): timeout=timeout ) - if last_text: - return last_text + if last_resp is None: + return PromptResult(response_type="text") - return json.loads(last_object) if last_object else None + if last_resp.object: + parsed = json.loads(last_resp.object) + + if isinstance(parsed, list): + return PromptResult( + response_type="jsonl", + objects=parsed, + in_token=last_resp.in_token, + out_token=last_resp.out_token, + model=last_resp.model, + ) + + return PromptResult( + response_type="json", + object=parsed, + in_token=last_resp.in_token, + out_token=last_resp.out_token, + model=last_resp.model, + ) + + return PromptResult( + response_type="text", + text=last_resp.text, + in_token=last_resp.in_token, + out_token=last_resp.out_token, + model=last_resp.model, + ) async def extract_definitions(self, text, timeout=600): return await self.prompt( @@ -152,4 +212,3 @@ class PromptClientSpec(RequestResponseSpec): response_schema = PromptResponse, impl = PromptClient, ) - diff --git a/trustgraph-base/trustgraph/base/pulsar_backend.py b/trustgraph-base/trustgraph/base/pulsar_backend.py index a567191e..6f125399 100644 --- a/trustgraph-base/trustgraph/base/pulsar_backend.py +++ b/trustgraph-base/trustgraph/base/pulsar_backend.py @@ -72,6 +72,16 @@ class PulsarBackendConsumer: self._consumer = pulsar_consumer self._schema_cls = schema_cls + def ensure_connected(self) -> None: + """No-op for Pulsar. + + PulsarBackend.create_consumer() calls client.subscribe() which is + synchronous and returns a fully-subscribed consumer, so the + consumer is already ready by the time this object is constructed. 
+ Defined for parity with the BackendConsumer protocol used by + Subscriber.start()'s readiness barrier.""" + pass + def receive(self, timeout_millis: int = 2000) -> Message: """Receive a message. Raises TimeoutError if no message available.""" try: diff --git a/trustgraph-base/trustgraph/base/rabbitmq_backend.py b/trustgraph-base/trustgraph/base/rabbitmq_backend.py index 3fafcead..7de51a0a 100644 --- a/trustgraph-base/trustgraph/base/rabbitmq_backend.py +++ b/trustgraph-base/trustgraph/base/rabbitmq_backend.py @@ -214,16 +214,43 @@ class RabbitMQBackendConsumer: and self._channel.is_open ) + def ensure_connected(self) -> None: + """Eagerly declare and bind the queue. + + Without this, the queue is only declared lazily on the first + receive() call. For request/response with ephemeral per-subscriber + response queues that is a race: a request published before the + response queue is bound will have its reply silently dropped by + the broker. Subscriber.start() calls this so callers get a hard + readiness barrier.""" + if not self._is_alive(): + self._connect() + def receive(self, timeout_millis: int = 2000) -> Message: - """Receive a message. Raises TimeoutError if none available.""" + """Receive a message. Raises TimeoutError if none available. + + Loop ordering matters: check _incoming at the TOP of each + iteration, not as the loop condition. process_data_events + may dispatch a message via the _on_message callback during + the pump; we must re-check _incoming on the next iteration + before giving up on the deadline. The previous control + flow (`while deadline: check; pump`) could lose a wakeup if + the pump consumed the remainder of the window — the + `while` check would fail before `_incoming` was re-read, + leaving a just-dispatched message stranded until the next + receive() call one full poll cycle later. + """ if not self._is_alive(): self._connect() timeout_seconds = timeout_millis / 1000.0 deadline = time.monotonic() + timeout_seconds - while time.monotonic() < deadline: - # Check if a message was already delivered + while True: + # Check if a message has been dispatched to our queue. + # This catches both (a) messages dispatched before this + # receive() was called and (b) messages dispatched + # during the previous iteration's process_data_events. try: method, properties, body = self._incoming.get_nowait() return RabbitMQMessage( @@ -232,14 +259,16 @@ class RabbitMQBackendConsumer: except queue.Empty: pass - # Drive pika's I/O — delivers messages and processes heartbeats remaining = deadline - time.monotonic() - if remaining > 0: - self._connection.process_data_events( - time_limit=min(0.1, remaining), - ) + if remaining <= 0: + raise TimeoutError("No message received within timeout") - raise TimeoutError("No message received within timeout") + # Drive pika's I/O. Any messages delivered during this + # call land in _incoming via _on_message; the next + # iteration of this loop catches them at the top. 
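[Editor's note: distilled to its essentials, the corrected control flow is check mailbox, then test deadline, then pump I/O. A generic sketch of the pattern with hypothetical names, assuming a thread-safe mailbox queue and a pump callable that drives the connection:]

```python
import queue
import time

def recv_with_deadline(mailbox, pump, timeout_s):
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            # Checked first on every iteration, so a message dispatched
            # by the previous pump is never stranded past the deadline.
            return mailbox.get_nowait()
        except queue.Empty:
            pass
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError("no message within timeout")
        pump(min(0.1, remaining))    # may enqueue into mailbox
```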
+ self._connection.process_data_events( + time_limit=min(0.1, remaining), + ) def acknowledge(self, message: Message) -> None: if isinstance(message, RabbitMQMessage) and message._method: diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py index 8c68e51c..82ffe8e2 100644 --- a/trustgraph-base/trustgraph/base/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -41,14 +41,55 @@ class Subscriber: self.consumer = None self.executor = None + # Readiness barrier — completed by run() once the underlying + # backend consumer is fully connected and bound. start() awaits + # this so callers know any subsequently published request will + # have a queue ready to receive its response. Without this, + # ephemeral per-subscriber response queues (RabbitMQ auto-delete + # exclusive queues) would race the request and lose the reply. + # A Future is used (rather than an Event) so that a first-attempt + # connection failure can be propagated to start() as an exception. + self._ready = None # created in start() so we have a running loop + def __del__(self): self.running = False async def start(self): + self._ready = asyncio.get_event_loop().create_future() self.task = asyncio.create_task(self.run()) + # Block until run() signals readiness OR exits. The future + # carries the outcome of the first connect attempt: a value on + # success, an exception on first-attempt failure. If run() exits + # without ever signalling (e.g. cancelled, or a code path bug), + # we surface that as a clear RuntimeError rather than hanging + # forever waiting on the future. + ready_wait = asyncio.ensure_future( + asyncio.shield(self._ready) + ) + try: + await asyncio.wait( + {self.task, ready_wait}, + return_when=asyncio.FIRST_COMPLETED, + ) + finally: + ready_wait.cancel() + + if self._ready.done(): + # Re-raise first-attempt connect failure if any. + self._ready.result() + return + + # run() exited before _ready was settled. Propagate its exception + # if it had one, otherwise raise a generic readiness error. + if self.task.done() and self.task.exception() is not None: + raise self.task.exception() + raise RuntimeError( + "Subscriber.run() exited before signalling readiness" + ) + async def stop(self): """Initiate graceful shutdown with draining""" self.running = False @@ -66,6 +107,7 @@ class Subscriber: async def run(self): """Enhanced run method with integrated draining logic""" + first_attempt = True while self.running or self.draining: if self.metrics: @@ -87,10 +129,27 @@ class Subscriber: ), ) + # Eagerly bind the queue. For backends that connect + # lazily on first receive (RabbitMQ), this is what + # closes the request/response setup race — without + # it the response queue is not bound until later and + # any reply published in the meantime is dropped. + await loop.run_in_executor( + self.executor, + lambda: self.consumer.ensure_connected(), + ) + if self.metrics: self.metrics.state("running") logger.info("Subscriber running...") + + # Signal start() that the consumer is ready. This must + # happen AFTER ensure_connected() above so callers can + # safely publish requests immediately after start() returns. 
+ if first_attempt and not self._ready.done(): + self._ready.set_result(None) + first_attempt = False drain_end_time = None while self.running or self.draining: @@ -162,6 +221,16 @@ class Subscriber: except Exception as e: logger.error(f"Subscriber exception: {e}", exc_info=True) + # First-attempt connection failure: propagate to start() + # so the caller can decide what to do (retry, give up). + # Subsequent failures use the existing retry-with-backoff + # path so a long-lived subscriber survives broker blips. + if first_attempt and not self._ready.done(): + self._ready.set_exception(e) + first_attempt = False + # Falls through into finally for cleanup, then the + # outer return below ends run() so start() unblocks. + finally: # Negative acknowledge any pending messages for msg in self.pending_acks.values(): @@ -193,6 +262,11 @@ class Subscriber: if not self.running and not self.draining: return + # If start() has already returned with an exception there is + # nothing more to do — exit run() rather than busy-retry. + if self._ready.done() and self._ready.exception() is not None: + return + # Sleep before retry await asyncio.sleep(1) diff --git a/trustgraph-base/trustgraph/base/text_completion_client.py b/trustgraph-base/trustgraph/base/text_completion_client.py index ae93e22e..876d71df 100644 --- a/trustgraph-base/trustgraph/base/text_completion_client.py +++ b/trustgraph-base/trustgraph/base/text_completion_client.py @@ -1,47 +1,71 @@ +from dataclasses import dataclass +from typing import Optional + from . request_response_spec import RequestResponse, RequestResponseSpec from .. schema import TextCompletionRequest, TextCompletionResponse +@dataclass +class TextCompletionResult: + text: Optional[str] + in_token: Optional[int] = None + out_token: Optional[int] = None + model: Optional[str] = None + class TextCompletionClient(RequestResponse): - async def text_completion(self, system, prompt, streaming=False, timeout=600): - # If not streaming, use original behavior - if not streaming: - resp = await self.request( - TextCompletionRequest( - system = system, prompt = prompt, streaming = False - ), - timeout=timeout - ) - if resp.error: - raise RuntimeError(resp.error.message) + async def text_completion(self, system, prompt, timeout=600): - return resp.response - - # For streaming: collect all chunks and return complete response - full_response = "" - - async def collect_chunks(resp): - nonlocal full_response - - if resp.error: - raise RuntimeError(resp.error.message) - - if resp.response: - full_response += resp.response - - # Return True when end_of_stream is reached - return getattr(resp, 'end_of_stream', False) - - await self.request( + resp = await self.request( TextCompletionRequest( - system = system, prompt = prompt, streaming = True + system = system, prompt = prompt, streaming = False ), - recipient=collect_chunks, timeout=timeout ) - return full_response + if resp.error: + raise RuntimeError(resp.error.message) + + return TextCompletionResult( + text = resp.response, + in_token = resp.in_token, + out_token = resp.out_token, + model = resp.model, + ) + + async def text_completion_stream( + self, system, prompt, handler, timeout=600, + ): + """ + Streaming text completion. `handler` is an async callable invoked + once per chunk with the chunk's TextCompletionResponse. Returns a + TextCompletionResult with text=None and token counts / model taken + from the end_of_stream message. 
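[Editor's note: driving the streaming client might look like this. A sketch, assuming a TextCompletionClient named client and chunk text carried in the response field, as in the non-streaming path:]

```python
async def on_chunk(resp):
    # Called once per TextCompletionResponse chunk, including the
    # final end-of-stream message.
    if resp.response:
        print(resp.response, end="", flush=True)

result = await client.text_completion_stream(
    system="You are terse.",
    prompt="Name three prime numbers.",
    handler=on_chunk,
)
# text is None in streaming mode; counts come from the final message.
print(f"\n{result.in_token} in / {result.out_token} out ({result.model})")
```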
+ """ + + async def on_chunk(resp): + + if resp.error: + raise RuntimeError(resp.error.message) + + await handler(resp) + + return getattr(resp, "end_of_stream", False) + + final = await self.request( + TextCompletionRequest( + system = system, prompt = prompt, streaming = True + ), + recipient=on_chunk, + timeout=timeout, + ) + + return TextCompletionResult( + text = None, + in_token = final.in_token, + out_token = final.out_token, + model = final.model, + ) class TextCompletionClientSpec(RequestResponseSpec): def __init__( @@ -54,4 +78,3 @@ class TextCompletionClientSpec(RequestResponseSpec): response_schema = TextCompletionResponse, impl = TextCompletionClient, ) - diff --git a/trustgraph-base/trustgraph/clients/agent_client.py b/trustgraph-base/trustgraph/clients/agent_client.py index 1cadbdd5..d17ea37a 100644 --- a/trustgraph-base/trustgraph/clients/agent_client.py +++ b/trustgraph-base/trustgraph/clients/agent_client.py @@ -58,23 +58,23 @@ class AgentClient(BaseClient): def inspect(x): # Handle errors - if x.chunk_type == 'error' or x.error: + if x.message_type == 'error' or x.error: if error_callback: error_callback(x.content or (x.error.message if x.error else "")) # Continue to check end_of_dialog # Handle thought chunks - elif x.chunk_type == 'thought': + elif x.message_type == 'thought': if think: think(x.content, x.end_of_message) # Handle observation chunks - elif x.chunk_type == 'observation': + elif x.message_type == 'observation': if observe: observe(x.content, x.end_of_message) # Handle answer chunks - elif x.chunk_type == 'answer': + elif x.message_type == 'answer': if x.content: accumulated_answer.append(x.content) if answer_callback: diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py index 8cf525f5..7df59907 100644 --- a/trustgraph-base/trustgraph/messaging/translators/agent.py +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -60,8 +60,8 @@ class AgentResponseTranslator(MessageTranslator): def encode(self, obj: AgentResponse) -> Dict[str, Any]: result = {} - if obj.chunk_type: - result["chunk_type"] = obj.chunk_type + if obj.message_type: + result["message_type"] = obj.message_type if obj.content: result["content"] = obj.content result["end_of_message"] = getattr(obj, "end_of_message", False) @@ -90,6 +90,13 @@ class AgentResponseTranslator(MessageTranslator): if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "code": obj.error.code} + if obj.in_token is not None: + result["in_token"] = obj.in_token + if obj.out_token is not None: + result["out_token"] = obj.out_token + if obj.model is not None: + result["model"] = obj.model + return result def encode_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]: diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py index 3e7062e2..df2aa3ba 100644 --- a/trustgraph-base/trustgraph/messaging/translators/document_loading.py +++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py @@ -151,7 +151,7 @@ class DocumentEmbeddingsTranslator(SendTranslator): chunks = [ ChunkEmbeddings( chunk_id=chunk["chunk_id"], - vectors=chunk["vectors"] + vector=chunk["vector"] ) for chunk in data.get("chunks", []) ] diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index 
2f11d75a..f819dc9c 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -39,7 +39,7 @@ class KnowledgeRequestTranslator(MessageTranslator): entities=[ EntityEmbeddings( entity=self.value_translator.decode(ent["entity"]), - vectors=ent["vectors"], + vector=ent["vector"], ) for ent in data["graph-embeddings"]["entities"] ] diff --git a/trustgraph-base/trustgraph/messaging/translators/prompt.py b/trustgraph-base/trustgraph/messaging/translators/prompt.py index 4345e6fd..7f76bf4a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/prompt.py +++ b/trustgraph-base/trustgraph/messaging/translators/prompt.py @@ -53,6 +53,13 @@ class PromptResponseTranslator(MessageTranslator): # Always include end_of_stream flag for streaming support result["end_of_stream"] = getattr(obj, "end_of_stream", False) + if obj.in_token is not None: + result["in_token"] = obj.in_token + if obj.out_token is not None: + result["out_token"] = obj.out_token + if obj.model is not None: + result["model"] = obj.model + return result def encode_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]: diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py index 849bee94..e37b76e1 100644 --- a/trustgraph-base/trustgraph/messaging/translators/retrieval.py +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -74,6 +74,13 @@ class DocumentRagResponseTranslator(MessageTranslator): if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "type": obj.error.type} + if obj.in_token is not None: + result["in_token"] = obj.in_token + if obj.out_token is not None: + result["out_token"] = obj.out_token + if obj.model is not None: + result["model"] = obj.model + return result def encode_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]: @@ -163,6 +170,13 @@ class GraphRagResponseTranslator(MessageTranslator): if hasattr(obj, 'error') and obj.error and obj.error.message: result["error"] = {"message": obj.error.message, "type": obj.error.type} + if obj.in_token is not None: + result["in_token"] = obj.in_token + if obj.out_token is not None: + result["out_token"] = obj.out_token + if obj.model is not None: + result["model"] = obj.model + return result def encode_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]: diff --git a/trustgraph-base/trustgraph/messaging/translators/text_completion.py b/trustgraph-base/trustgraph/messaging/translators/text_completion.py index 596ff744..62cc4afb 100644 --- a/trustgraph-base/trustgraph/messaging/translators/text_completion.py +++ b/trustgraph-base/trustgraph/messaging/translators/text_completion.py @@ -29,11 +29,11 @@ class TextCompletionResponseTranslator(MessageTranslator): def encode(self, obj: TextCompletionResponse) -> Dict[str, Any]: result = {"response": obj.response} - if obj.in_token: + if obj.in_token is not None: result["in_token"] = obj.in_token - if obj.out_token: + if obj.out_token is not None: result["out_token"] = obj.out_token - if obj.model: + if obj.model is not None: result["model"] = obj.model # Always include end_of_stream flag for streaming support diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index e6ce0a9e..051efc66 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ 
b/trustgraph-base/trustgraph/provenance/__init__.py @@ -59,6 +59,7 @@ from . uris import ( agent_plan_uri, agent_step_result_uri, agent_synthesis_uri, + agent_pattern_decision_uri, # Document RAG provenance URIs docrag_question_uri, docrag_grounding_uri, @@ -102,6 +103,11 @@ from . namespaces import ( # Agent provenance predicates TG_THOUGHT, TG_ACTION, TG_ARGUMENTS, TG_OBSERVATION, TG_SUBAGENT_GOAL, TG_PLAN_STEP, + TG_TOOL_CANDIDATE, TG_TERMINATION_REASON, + TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE, + TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR, + TG_IN_TOKEN, TG_OUT_TOKEN, + TG_ERROR_TYPE, # Orchestrator entity types TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, # Document reference predicate @@ -141,6 +147,7 @@ from . agent import ( agent_plan_triples, agent_step_result_triples, agent_synthesis_triples, + agent_pattern_decision_triples, ) # Vocabulary bootstrap @@ -182,6 +189,7 @@ __all__ = [ "agent_plan_uri", "agent_step_result_uri", "agent_synthesis_uri", + "agent_pattern_decision_uri", # Document RAG provenance URIs "docrag_question_uri", "docrag_grounding_uri", @@ -218,6 +226,11 @@ __all__ = [ # Agent provenance predicates "TG_THOUGHT", "TG_ACTION", "TG_ARGUMENTS", "TG_OBSERVATION", "TG_SUBAGENT_GOAL", "TG_PLAN_STEP", + "TG_TOOL_CANDIDATE", "TG_TERMINATION_REASON", + "TG_STEP_NUMBER", "TG_PATTERN_DECISION", "TG_PATTERN", "TG_TASK_TYPE", + "TG_LLM_DURATION_MS", "TG_TOOL_DURATION_MS", "TG_TOOL_ERROR", + "TG_IN_TOKEN", "TG_OUT_TOKEN", + "TG_ERROR_TYPE", # Orchestrator entity types "TG_DECOMPOSITION", "TG_FINDING", "TG_PLAN_TYPE", "TG_STEP_RESULT", # Document reference predicate @@ -249,6 +262,7 @@ __all__ = [ "agent_plan_triples", "agent_step_result_triples", "agent_synthesis_triples", + "agent_pattern_decision_triples", # Utility "set_graph", # Vocabulary diff --git a/trustgraph-base/trustgraph/provenance/agent.py b/trustgraph-base/trustgraph/provenance/agent.py index 7203174e..5c4f0b2e 100644 --- a/trustgraph-base/trustgraph/provenance/agent.py +++ b/trustgraph-base/trustgraph/provenance/agent.py @@ -29,6 +29,11 @@ from . namespaces import ( TG_AGENT_QUESTION, TG_DECOMPOSITION, TG_FINDING, TG_PLAN_TYPE, TG_STEP_RESULT, TG_SYNTHESIS, TG_SUBAGENT_GOAL, TG_PLAN_STEP, + TG_TOOL_CANDIDATE, TG_TERMINATION_REASON, + TG_STEP_NUMBER, TG_PATTERN_DECISION, TG_PATTERN, TG_TASK_TYPE, + TG_LLM_DURATION_MS, TG_TOOL_DURATION_MS, TG_TOOL_ERROR, + TG_ERROR_TYPE, + TG_IN_TOKEN, TG_OUT_TOKEN, TG_LLM_MODEL, ) @@ -47,6 +52,17 @@ def _triple(s: str, p: str, o_term: Term) -> Triple: return Triple(s=_iri(s), p=_iri(p), o=o_term) +def _append_token_triples(triples, uri, in_token=None, out_token=None, + model=None): + """Append in_token/out_token/model triples when values are present.""" + if in_token is not None: + triples.append(_triple(uri, TG_IN_TOKEN, _literal(str(in_token)))) + if out_token is not None: + triples.append(_triple(uri, TG_OUT_TOKEN, _literal(str(out_token)))) + if model is not None: + triples.append(_triple(uri, TG_LLM_MODEL, _literal(model))) + + def agent_session_triples( session_uri: str, query: str, @@ -90,6 +106,43 @@ def agent_session_triples( return triples +def agent_pattern_decision_triples( + uri: str, + session_uri: str, + pattern: str, + task_type: str = "", +) -> List[Triple]: + """ + Build triples for a meta-router pattern decision. 
+ + Creates: + - Entity declaration with tg:PatternDecision type + - wasDerivedFrom link to session + - Pattern and task type predicates + + Args: + uri: URI of this decision (from agent_pattern_decision_uri) + session_uri: URI of the parent session + pattern: Selected execution pattern (e.g. "react", "plan-then-execute") + task_type: Identified task type (e.g. "general", "research") + + Returns: + List of Triple objects + """ + triples = [ + _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), + _triple(uri, RDF_TYPE, _iri(TG_PATTERN_DECISION)), + _triple(uri, RDFS_LABEL, _literal(f"Pattern: {pattern}")), + _triple(uri, TG_PATTERN, _literal(pattern)), + _triple(uri, PROV_WAS_DERIVED_FROM, _iri(session_uri)), + ] + + if task_type: + triples.append(_triple(uri, TG_TASK_TYPE, _literal(task_type))) + + return triples + + def agent_iteration_triples( iteration_uri: str, question_uri: Optional[str] = None, @@ -98,6 +151,12 @@ def agent_iteration_triples( arguments: Dict[str, Any] = None, thought_uri: Optional[str] = None, thought_document_id: Optional[str] = None, + tool_candidates: Optional[List[str]] = None, + step_number: Optional[int] = None, + llm_duration_ms: Optional[int] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for one agent iteration (Analysis+ToolUse). @@ -106,6 +165,7 @@ def agent_iteration_triples( - Entity declaration with tg:Analysis and tg:ToolUse types - wasDerivedFrom link to question (if first iteration) or previous - Action and arguments metadata + - Tool candidates (names of tools visible to the LLM) - Thought sub-entity (tg:Reflection, tg:Thought) with librarian document Args: @@ -116,6 +176,7 @@ def agent_iteration_triples( arguments: Arguments passed to the tool (will be JSON-encoded) thought_uri: URI for the thought sub-entity thought_document_id: Document URI for thought in librarian + tool_candidates: List of tool names available to the LLM Returns: List of Triple objects @@ -132,6 +193,23 @@ def agent_iteration_triples( _triple(iteration_uri, TG_ARGUMENTS, _literal(json.dumps(arguments))), ] + if tool_candidates: + for name in tool_candidates: + triples.append( + _triple(iteration_uri, TG_TOOL_CANDIDATE, _literal(name)) + ) + + if step_number is not None: + triples.append( + _triple(iteration_uri, TG_STEP_NUMBER, _literal(str(step_number))) + ) + + if llm_duration_ms is not None: + triples.append( + _triple(iteration_uri, TG_LLM_DURATION_MS, + _literal(str(llm_duration_ms))) + ) + if question_uri: triples.append( _triple(iteration_uri, PROV_WAS_DERIVED_FROM, _iri(question_uri)) @@ -155,6 +233,8 @@ def agent_iteration_triples( _triple(thought_uri, TG_DOCUMENT, _iri(thought_document_id)) ) + _append_token_triples(triples, iteration_uri, in_token, out_token, model) + return triples @@ -162,6 +242,8 @@ def agent_observation_triples( observation_uri: str, iteration_uri: str, document_id: Optional[str] = None, + tool_duration_ms: Optional[int] = None, + tool_error: Optional[str] = None, ) -> List[Triple]: """ Build triples for an agent observation (standalone entity). 
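The pattern-decision helper above is self-contained, so a short caller sketch may help. It is illustrative only: the session id and the session URI form are invented, while `set_graph` and `GRAPH_RETRIEVAL` are re-exported by the same provenance package:

    from trustgraph.provenance import (
        agent_pattern_decision_uri, agent_pattern_decision_triples,
        set_graph, GRAPH_RETRIEVAL,
    )

    session_id = "3f2a9c"                        # hypothetical session id
    uri = agent_pattern_decision_uri(session_id)
    triples = set_graph(
        agent_pattern_decision_triples(
            uri,
            session_uri=f"urn:trustgraph:agent:{session_id}",  # invented form
            pattern="plan-then-execute",
            task_type="research",
        ),
        GRAPH_RETRIEVAL,
    )
    # Produces prov:Entity + tg:PatternDecision typing, an rdfs:label,
    # tg:pattern, prov:wasDerivedFrom and (since task_type is set) tg:taskType.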
@@ -170,11 +252,15 @@ def agent_observation_triples( - Entity declaration with prov:Entity and tg:Observation types - wasDerivedFrom link to the iteration (Analysis+ToolUse) - Document reference to librarian (if provided) + - Tool execution duration (if provided) + - Tool error message (if the tool failed) Args: observation_uri: URI of the observation entity iteration_uri: URI of the iteration this observation derives from document_id: Librarian document ID for the observation content + tool_duration_ms: Tool execution time in milliseconds + tool_error: Error message if the tool failed Returns: List of Triple objects @@ -191,6 +277,20 @@ def agent_observation_triples( _triple(observation_uri, TG_DOCUMENT, _iri(document_id)) ) + if tool_duration_ms is not None: + triples.append( + _triple(observation_uri, TG_TOOL_DURATION_MS, + _literal(str(tool_duration_ms))) + ) + + if tool_error: + triples.append( + _triple(observation_uri, TG_TOOL_ERROR, _literal(tool_error)) + ) + triples.append( + _triple(observation_uri, RDF_TYPE, _iri(TG_ERROR_TYPE)) + ) + return triples @@ -199,6 +299,10 @@ def agent_final_triples( question_uri: Optional[str] = None, previous_uri: Optional[str] = None, document_id: Optional[str] = None, + termination_reason: Optional[str] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for an agent final answer (Conclusion). @@ -208,12 +312,15 @@ def agent_final_triples( - wasGeneratedBy link to question (if no iterations) - wasDerivedFrom link to last iteration (if iterations exist) - Document reference to librarian + - Termination reason (why the agent loop stopped) Args: final_uri: URI of the final answer (from agent_final_uri) question_uri: URI of the question activity (if no iterations) previous_uri: URI of the last iteration (if iterations exist) document_id: Librarian document ID for the answer content + termination_reason: Why the loop stopped, e.g. 
"final-answer", + "max-iterations", "error" Returns: List of Triple objects @@ -237,6 +344,14 @@ def agent_final_triples( if document_id: triples.append(_triple(final_uri, TG_DOCUMENT, _iri(document_id))) + if termination_reason: + triples.append( + _triple(final_uri, TG_TERMINATION_REASON, + _literal(termination_reason)) + ) + + _append_token_triples(triples, final_uri, in_token, out_token, model) + return triples @@ -244,6 +359,9 @@ def agent_decomposition_triples( uri: str, session_uri: str, goals: List[str], + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """Build triples for a supervisor decomposition step.""" triples = [ @@ -255,6 +373,7 @@ def agent_decomposition_triples( ] for goal in goals: triples.append(_triple(uri, TG_SUBAGENT_GOAL, _literal(goal))) + _append_token_triples(triples, uri, in_token, out_token, model) return triples @@ -282,6 +401,9 @@ def agent_plan_triples( uri: str, session_uri: str, steps: List[str], + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """Build triples for a plan-then-execute plan.""" triples = [ @@ -293,6 +415,7 @@ def agent_plan_triples( ] for step in steps: triples.append(_triple(uri, TG_PLAN_STEP, _literal(step))) + _append_token_triples(triples, uri, in_token, out_token, model) return triples @@ -301,6 +424,9 @@ def agent_step_result_triples( plan_uri: str, goal: str, document_id: Optional[str] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """Build triples for a plan step result.""" triples = [ @@ -313,6 +439,7 @@ def agent_step_result_triples( ] if document_id: triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id))) + _append_token_triples(triples, uri, in_token, out_token, model) return triples @@ -320,6 +447,10 @@ def agent_synthesis_triples( uri: str, previous_uris, document_id: Optional[str] = None, + termination_reason: Optional[str] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """Build triples for a synthesis answer. 
@@ -327,6 +458,8 @@ uri: URI of the synthesis entity previous_uris: Single URI string or list of URIs to derive from document_id: Librarian document ID for the answer content + termination_reason: Why the agent loop stopped + in_token/out_token/model: Token usage for the synthesis LLM call """ triples = [ _triple(uri, RDF_TYPE, _iri(PROV_ENTITY)), @@ -342,4 +475,12 @@ if document_id: triples.append(_triple(uri, TG_DOCUMENT, _iri(document_id))) + + if termination_reason: + triples.append( + _triple(uri, TG_TERMINATION_REASON, _literal(termination_reason)) + ) + + _append_token_triples(triples, uri, in_token, out_token, model) + return triples diff --git a/trustgraph-base/trustgraph/provenance/namespaces.py b/trustgraph-base/trustgraph/provenance/namespaces.py index 9e7fbb2d..0b14f1b9 100644 --- a/trustgraph-base/trustgraph/provenance/namespaces.py +++ b/trustgraph-base/trustgraph/provenance/namespaces.py @@ -119,6 +119,18 @@ TG_ARGUMENTS = TG + "arguments" TG_OBSERVATION = TG + "observation" # Links iteration to observation sub-entity TG_SUBAGENT_GOAL = TG + "subagentGoal" # Goal string on Decomposition/Finding TG_PLAN_STEP = TG + "planStep" # Step goal string on Plan/StepResult +TG_TOOL_CANDIDATE = TG + "toolCandidate" # Tool name on Analysis events +TG_TERMINATION_REASON = TG + "terminationReason" # Why the agent loop stopped +TG_STEP_NUMBER = TG + "stepNumber" # Explicit step counter on iteration events +TG_PATTERN_DECISION = TG + "PatternDecision" # Meta-router routing decision entity type +TG_PATTERN = TG + "pattern" # Selected execution pattern +TG_TASK_TYPE = TG + "taskType" # Identified task type +TG_LLM_DURATION_MS = TG + "llmDurationMs" # LLM call duration in milliseconds +TG_TOOL_DURATION_MS = TG + "toolDurationMs" # Tool execution duration in milliseconds +TG_TOOL_ERROR = TG + "toolError" # Error message from a failed tool execution +TG_ERROR_TYPE = TG + "Error" # Mixin type for failure events +TG_IN_TOKEN = TG + "inToken" # Input token count for an LLM call +TG_OUT_TOKEN = TG + "outToken" # Output token count for an LLM call # Named graph URIs for RDF datasets # These separate different types of data while keeping them in the same collection diff --git a/trustgraph-base/trustgraph/provenance/triples.py b/trustgraph-base/trustgraph/provenance/triples.py index 920a3482..8bdfc2cb 100644 --- a/trustgraph-base/trustgraph/provenance/triples.py +++ b/trustgraph-base/trustgraph/provenance/triples.py @@ -34,6 +34,8 @@ from . namespaces import ( TG_ANSWER_TYPE, # Question subtypes TG_GRAPH_RAG_QUESTION, TG_DOC_RAG_QUESTION, + # Token usage + TG_IN_TOKEN, TG_OUT_TOKEN, TG_LLM_MODEL, ) from .
uris import activity_uri, agent_uri, subgraph_uri, edge_selection_uri @@ -74,6 +76,17 @@ def _triple(s: str, p: str, o_term: Term) -> Triple: return Triple(s=_iri(s), p=_iri(p), o=o_term) +def _append_token_triples(triples, uri, in_token=None, out_token=None, + model=None): + """Append in_token/out_token/model triples when values are present.""" + if in_token is not None: + triples.append(_triple(uri, TG_IN_TOKEN, _literal(str(in_token)))) + if out_token is not None: + triples.append(_triple(uri, TG_OUT_TOKEN, _literal(str(out_token)))) + if model is not None: + triples.append(_triple(uri, TG_LLM_MODEL, _literal(model))) + + def document_triples( doc_uri: str, title: Optional[str] = None, @@ -396,6 +409,9 @@ def grounding_triples( grounding_uri: str, question_uri: str, concepts: List[str], + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for a grounding entity (concept decomposition of query). @@ -423,6 +439,8 @@ def grounding_triples( for concept in concepts: triples.append(_triple(grounding_uri, TG_CONCEPT, _literal(concept))) + _append_token_triples(triples, grounding_uri, in_token, out_token, model) + return triples @@ -485,6 +503,9 @@ def focus_triples( exploration_uri: str, selected_edges_with_reasoning: List[dict], session_id: str = "", + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for a focus entity (selected edges with reasoning). @@ -543,6 +564,8 @@ def focus_triples( _triple(edge_sel_uri, TG_REASONING, _literal(reasoning)) ) + _append_token_triples(triples, focus_uri, in_token, out_token, model) + return triples @@ -550,6 +573,9 @@ def synthesis_triples( synthesis_uri: str, focus_uri: str, document_id: Optional[str] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for a synthesis entity (final answer). @@ -578,6 +604,8 @@ def synthesis_triples( if document_id: triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) + _append_token_triples(triples, synthesis_uri, in_token, out_token, model) + return triples @@ -674,6 +702,9 @@ def docrag_synthesis_triples( synthesis_uri: str, exploration_uri: str, document_id: Optional[str] = None, + in_token: Optional[int] = None, + out_token: Optional[int] = None, + model: Optional[str] = None, ) -> List[Triple]: """ Build triples for a document RAG synthesis entity (final answer). 
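The retrieval-side builders in this file gain the same optional usage trio; a hedged sketch of a grounding call (URIs and counts invented):

    triples = grounding_triples(
        grounding_uri="urn:example:grounding",    # invented
        question_uri="urn:example:question",      # invented
        concepts=["fusion", "tokamak"],
        in_token=120, out_token=18, model="example-model",
    )
    # The concept triples are emitted as before; the usage triples are
    # appended only because the three optional values are present.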
@@ -702,4 +733,6 @@ def docrag_synthesis_triples( if document_id: triples.append(_triple(synthesis_uri, TG_DOCUMENT, _iri(document_id))) + _append_token_triples(triples, synthesis_uri, in_token, out_token, model) + return triples diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py index a3aadef6..a26ac867 100644 --- a/trustgraph-base/trustgraph/provenance/uris.py +++ b/trustgraph-base/trustgraph/provenance/uris.py @@ -259,6 +259,11 @@ def agent_synthesis_uri(session_id: str) -> str: return f"urn:trustgraph:agent:{session_id}/synthesis" +def agent_pattern_decision_uri(session_id: str) -> str: + """Generate URI for a meta-router pattern decision.""" + return f"urn:trustgraph:agent:{session_id}/pattern-decision" + + # Document RAG provenance URIs # These URIs use the urn:trustgraph:docrag: namespace to distinguish # document RAG provenance from graph RAG provenance diff --git a/trustgraph-base/trustgraph/schema/services/agent.py b/trustgraph-base/trustgraph/schema/services/agent.py index fbc0101c..cd4a2b45 100644 --- a/trustgraph-base/trustgraph/schema/services/agent.py +++ b/trustgraph-base/trustgraph/schema/services/agent.py @@ -51,8 +51,8 @@ class AgentRequest: @dataclass class AgentResponse: # Streaming-first design - chunk_type: str = "" # "thought", "action", "observation", "answer", "explain", "error" - content: str = "" # The actual content (interpretation depends on chunk_type) + message_type: str = "" # "thought", "action", "observation", "answer", "explain", "error" + content: str = "" # The actual content (interpretation depends on message_type) end_of_message: bool = False # Current chunk type (thought/action/etc.) is complete end_of_dialog: bool = False # Entire agent dialog is complete @@ -66,5 +66,10 @@ class AgentResponse: error: Error | None = None + # Token usage (populated on end_of_dialog message) + in_token: int | None = None + out_token: int | None = None + model: str | None = None + ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/llm.py b/trustgraph-base/trustgraph/schema/services/llm.py index 0fd6ab90..89c0cd54 100644 --- a/trustgraph-base/trustgraph/schema/services/llm.py +++ b/trustgraph-base/trustgraph/schema/services/llm.py @@ -17,9 +17,9 @@ class TextCompletionRequest: class TextCompletionResponse: error: Error | None = None response: str = "" - in_token: int = 0 - out_token: int = 0 - model: str = "" + in_token: int | None = None + out_token: int | None = None + model: str | None = None end_of_stream: bool = False # Indicates final message in stream ############################################################################ diff --git a/trustgraph-base/trustgraph/schema/services/prompt.py b/trustgraph-base/trustgraph/schema/services/prompt.py index f7388102..1696790b 100644 --- a/trustgraph-base/trustgraph/schema/services/prompt.py +++ b/trustgraph-base/trustgraph/schema/services/prompt.py @@ -41,4 +41,9 @@ class PromptResponse: # Indicates final message in stream end_of_stream: bool = False + # Token usage from the underlying text completion + in_token: int | None = None + out_token: int | None = None + model: str | None = None + ############################################################################ \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/retrieval.py b/trustgraph-base/trustgraph/schema/services/retrieval.py index 4b17733d..a1af9170 100644 --- 
a/trustgraph-base/trustgraph/schema/services/retrieval.py +++ b/trustgraph-base/trustgraph/schema/services/retrieval.py @@ -29,6 +29,9 @@ class GraphRagResponse: explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step message_type: str = "" # "chunk" or "explain" end_of_session: bool = False # Entire session complete + in_token: int | None = None + out_token: int | None = None + model: str | None = None ############################################################################ @@ -52,3 +55,6 @@ class DocumentRagResponse: explain_triples: list[Triple] = field(default_factory=list) # Provenance triples for this step message_type: str = "" # "chunk" or "explain" end_of_session: bool = False # Entire session complete + in_token: int | None = None + out_token: int | None = None + model: str | None = None diff --git a/trustgraph-bedrock/pyproject.toml b/trustgraph-bedrock/pyproject.toml index 6dd017c7..2d65461b 100644 --- a/trustgraph-bedrock/pyproject.toml +++ b/trustgraph-bedrock/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.2,<2.3", + "trustgraph-base>=2.3,<2.4", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 2b111cae..0151fef4 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.2,<2.3", + "trustgraph-base>=2.3,<2.4", "requests", "pulsar-client", "aiohttp", diff --git a/trustgraph-cli/trustgraph/cli/invoke_agent.py b/trustgraph-cli/trustgraph/cli/invoke_agent.py index 026286d0..b379c2df 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_agent.py +++ b/trustgraph-cli/trustgraph/cli/invoke_agent.py @@ -126,7 +126,7 @@ def question_explainable( try: # Track last chunk type for formatting - last_chunk_type = None + last_message_type = None current_outputter = None # Stream agent with explainability - process events as they arrive @@ -138,7 +138,7 @@ def question_explainable( group=group, ): if isinstance(item, AgentThought): - if last_chunk_type != "thought": + if last_message_type != "thought": if current_outputter: current_outputter.__exit__(None, None, None) current_outputter = None @@ -146,7 +146,7 @@ def question_explainable( if verbose: current_outputter = Outputter(width=78, prefix="\U0001f914 ") current_outputter.__enter__() - last_chunk_type = "thought" + last_message_type = "thought" if current_outputter: current_outputter.output(item.content) if current_outputter.word_buffer: @@ -155,7 +155,7 @@ def question_explainable( current_outputter.word_buffer = "" elif isinstance(item, AgentObservation): - if last_chunk_type != "observation": + if last_message_type != "observation": if current_outputter: current_outputter.__exit__(None, None, None) current_outputter = None @@ -163,7 +163,7 @@ def question_explainable( if verbose: current_outputter = Outputter(width=78, prefix="\U0001f4a1 ") current_outputter.__enter__() - last_chunk_type = "observation" + last_message_type = "observation" if current_outputter: current_outputter.output(item.content) if current_outputter.word_buffer: @@ -172,12 +172,12 @@ def question_explainable( current_outputter.word_buffer = "" elif isinstance(item, AgentAnswer): - if 
last_chunk_type != "answer": + if last_message_type != "answer": if current_outputter: current_outputter.__exit__(None, None, None) current_outputter = None print() - last_chunk_type = "answer" + last_message_type = "answer" # Print answer content directly print(item.content, end="", flush=True) @@ -261,7 +261,7 @@ def question_explainable( current_outputter = None # Final newline if we ended with answer - if last_chunk_type == "answer": + if last_message_type == "answer": print() finally: @@ -272,7 +272,8 @@ def question( url, question, flow_id, user, collection, plan=None, state=None, group=None, pattern=None, verbose=False, streaming=True, - token=None, explainable=False, debug=False + token=None, explainable=False, debug=False, + show_usage=False ): # Explainable mode uses the API to capture and process provenance events if explainable: @@ -321,15 +322,16 @@ def question( # Handle streaming response if streaming: # Track last chunk type and current outputter for streaming - last_chunk_type = None + last_message_type = None current_outputter = None + last_answer_chunk = None for chunk in response: - chunk_type = chunk.chunk_type + message_type = chunk.message_type content = chunk.content # Check if we're switching to a new message type - if last_chunk_type != chunk_type: + if last_message_type != message_type: # Close previous outputter if exists if current_outputter: current_outputter.__exit__(None, None, None) @@ -337,15 +339,15 @@ def question( print() # Blank line between message types # Create new outputter for new message type - if chunk_type == "thought" and verbose: + if message_type == "thought" and verbose: current_outputter = Outputter(width=78, prefix="\U0001f914 ") current_outputter.__enter__() - elif chunk_type == "observation" and verbose: + elif message_type == "observation" and verbose: current_outputter = Outputter(width=78, prefix="\U0001f4a1 ") current_outputter.__enter__() # For answer, don't use Outputter - just print as-is - last_chunk_type = chunk_type + last_message_type = message_type # Output the chunk if current_outputter: @@ -355,33 +357,42 @@ def question( print(current_outputter.word_buffer, end="", flush=True) current_outputter.column += len(current_outputter.word_buffer) current_outputter.word_buffer = "" - elif chunk_type == "final-answer": + elif message_type == "final-answer": print(content, end="", flush=True) + last_answer_chunk = chunk # Close any remaining outputter if current_outputter: current_outputter.__exit__(None, None, None) current_outputter = None # Add final newline if we were outputting answer - elif last_chunk_type == "final-answer": + elif last_message_type == "final-answer": print() + if show_usage and last_answer_chunk: + print( + f"Input tokens: {last_answer_chunk.in_token} " + f"Output tokens: {last_answer_chunk.out_token} " + f"Model: {last_answer_chunk.model}", + file=sys.stderr, + ) + else: # Non-streaming response - but agents use multipart messaging # so we iterate through the chunks (which are complete messages, not text chunks) for chunk in response: # Display thoughts if verbose - if chunk.chunk_type == "thought" and verbose: + if chunk.message_type == "thought" and verbose: output(wrap(chunk.content), "\U0001f914 ") print() # Display observations if verbose - elif chunk.chunk_type == "observation" and verbose: + elif chunk.message_type == "observation" and verbose: output(wrap(chunk.content), "\U0001f4a1 ") print() # Display answer - elif chunk.chunk_type == "final-answer" or chunk.chunk_type == "answer": + elif 
chunk.message_type == "final-answer" or chunk.message_type == "answer": print(chunk.content) finally: @@ -477,6 +488,12 @@ def main(): help='Show debug output for troubleshooting' ) + parser.add_argument( + '--show-usage', + action='store_true', + help='Show token usage and model on stderr' + ) + args = parser.parse_args() try: @@ -496,6 +513,7 @@ def main(): token = args.token, explainable = args.explainable, debug = args.debug, + show_usage = args.show_usage, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py index 066b92f4..d566f51d 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_document_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_document_rag.py @@ -99,7 +99,8 @@ def question_explainable( def question( url, flow_id, question_text, user, collection, doc_limit, - streaming=True, token=None, explainable=False, debug=False + streaming=True, token=None, explainable=False, debug=False, + show_usage=False ): # Explainable mode uses the API to capture and process provenance events if explainable: @@ -133,22 +134,40 @@ def question( ) # Stream output + last_chunk = None for chunk in response: - print(chunk, end="", flush=True) + print(chunk.content, end="", flush=True) + last_chunk = chunk print() # Final newline + if show_usage and last_chunk: + print( + f"Input tokens: {last_chunk.in_token} " + f"Output tokens: {last_chunk.out_token} " + f"Model: {last_chunk.model}", + file=sys.stderr, + ) + finally: socket.close() else: # Use REST API for non-streaming flow = api.flow().id(flow_id) - resp = flow.document_rag( + result = flow.document_rag( query=question_text, user=user, collection=collection, doc_limit=doc_limit, ) - print(resp) + print(result.text) + + if show_usage: + print( + f"Input tokens: {result.in_token} " + f"Output tokens: {result.out_token} " + f"Model: {result.model}", + file=sys.stderr, + ) def main(): @@ -219,6 +238,12 @@ def main(): help='Show debug output for troubleshooting' ) + parser.add_argument( + '--show-usage', + action='store_true', + help='Show token usage and model on stderr' + ) + args = parser.parse_args() try: @@ -234,6 +259,7 @@ def main(): token=args.token, explainable=args.explainable, debug=args.debug, + show_usage=args.show_usage, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py index 230cc54b..c9efe54d 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py +++ b/trustgraph-cli/trustgraph/cli/invoke_graph_rag.py @@ -753,7 +753,7 @@ def question( url, flow_id, question, user, collection, entity_limit, triple_limit, max_subgraph_size, max_path_length, edge_score_limit=50, edge_limit=25, streaming=True, token=None, - explainable=False, debug=False + explainable=False, debug=False, show_usage=False ): # Explainable mode uses the API to capture and process provenance events @@ -798,16 +798,26 @@ def question( ) # Stream output + last_chunk = None for chunk in response: - print(chunk, end="", flush=True) + print(chunk.content, end="", flush=True) + last_chunk = chunk print() # Final newline + if show_usage and last_chunk: + print( + f"Input tokens: {last_chunk.in_token} " + f"Output tokens: {last_chunk.out_token} " + f"Model: {last_chunk.model}", + file=sys.stderr, + ) + finally: socket.close() else: # Use REST API for non-streaming flow = api.flow().id(flow_id) - resp = flow.graph_rag( + result = flow.graph_rag( query=question, user=user, collection=collection, @@ 
-818,7 +828,15 @@ def question( edge_score_limit=edge_score_limit, edge_limit=edge_limit, ) - print(resp) + print(result.text) + + if show_usage: + print( + f"Input tokens: {result.in_token} " + f"Output tokens: {result.out_token} " + f"Model: {result.model}", + file=sys.stderr, + ) def main(): @@ -923,6 +941,12 @@ def main(): help='Show debug output for troubleshooting' ) + parser.add_argument( + '--show-usage', + action='store_true', + help='Show token usage and model on stderr' + ) + args = parser.parse_args() try: @@ -943,6 +967,7 @@ def main(): token=args.token, explainable=args.explainable, debug=args.debug, + show_usage=args.show_usage, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_llm.py b/trustgraph-cli/trustgraph/cli/invoke_llm.py index a1611625..3bf521f6 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_llm.py +++ b/trustgraph-cli/trustgraph/cli/invoke_llm.py @@ -10,7 +10,8 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def query(url, flow_id, system, prompt, streaming=True, token=None): +def query(url, flow_id, system, prompt, streaming=True, token=None, + show_usage=False): # Create API client api = Api(url=url, token=token) @@ -26,14 +27,29 @@ def query(url, flow_id, system, prompt, streaming=True, token=None): ) if streaming: - # Stream output to stdout without newline + last_chunk = None for chunk in response: - print(chunk, end="", flush=True) - # Add final newline after streaming + print(chunk.content, end="", flush=True) + last_chunk = chunk print() + + if show_usage and last_chunk: + print( + f"Input tokens: {last_chunk.in_token} " + f"Output tokens: {last_chunk.out_token} " + f"Model: {last_chunk.model}", + file=__import__('sys').stderr, + ) else: - # Non-streaming: print complete response - print(response) + print(response.text) + + if show_usage: + print( + f"Input tokens: {response.in_token} " + f"Output tokens: {response.out_token} " + f"Model: {response.model}", + file=__import__('sys').stderr, + ) finally: # Clean up socket connection @@ -82,6 +98,12 @@ def main(): help='Disable streaming (default: streaming enabled)' ) + parser.add_argument( + '--show-usage', + action='store_true', + help='Show token usage and model on stderr' + ) + args = parser.parse_args() try: @@ -93,6 +115,7 @@ def main(): prompt=args.prompt[0], streaming=not args.no_streaming, token=args.token, + show_usage=args.show_usage, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_prompt.py b/trustgraph-cli/trustgraph/cli/invoke_prompt.py index 09cc9043..86f7a024 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_prompt.py +++ b/trustgraph-cli/trustgraph/cli/invoke_prompt.py @@ -15,7 +15,8 @@ from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def query(url, flow_id, template_id, variables, streaming=True, token=None): +def query(url, flow_id, template_id, variables, streaming=True, token=None, + show_usage=False): # Create API client api = Api(url=url, token=token) @@ -31,16 +32,30 @@ def query(url, flow_id, template_id, variables, streaming=True, token=None): ) if streaming: - # Stream output (prompt yields strings directly) + last_chunk = None for chunk in response: - if chunk: - print(chunk, end="", flush=True) - # Add final newline after streaming + if chunk.content: + print(chunk.content, end="", flush=True) + last_chunk = chunk 
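+            # last_chunk ends up holding the final (end_of_stream) message,
+            # which is where in_token/out_token/model are expected to be set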
print() + if show_usage and last_chunk: + print( + f"Input tokens: {last_chunk.in_token} " + f"Output tokens: {last_chunk.out_token} " + f"Model: {last_chunk.model}", + file=__import__('sys').stderr, + ) else: - # Non-streaming: print complete response - print(response) + print(response.text) + + if show_usage: + print( + f"Input tokens: {response.in_token} " + f"Output tokens: {response.out_token} " + f"Model: {response.model}", + file=__import__('sys').stderr, + ) finally: # Clean up socket connection @@ -92,6 +107,12 @@ specified multiple times''', help='Disable streaming (default: streaming enabled for text responses)' ) + parser.add_argument( + '--show-usage', + action='store_true', + help='Show token usage and model on stderr' + ) + args = parser.parse_args() variables = {} @@ -113,6 +134,7 @@ specified multiple times''', variables=variables, streaming=not args.no_streaming, token=args.token, + show_usage=args.show_usage, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py index 82e48456..7b1ae9a6 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py @@ -62,6 +62,11 @@ def sparql_query(url, token, flow_id, query, user, collection, limit, limit=limit, batch_size=batch_size, ): + if "error" in response: + err = response["error"] + msg = err.get("message", err) if isinstance(err, dict) else err + raise RuntimeError(msg) + query_type = response.get("query-type", "select") # ASK queries - just print and return diff --git a/trustgraph-cli/trustgraph/cli/verify_system_status.py b/trustgraph-cli/trustgraph/cli/verify_system_status.py index d7aa1d93..9491deaa 100644 --- a/trustgraph-cli/trustgraph/cli/verify_system_status.py +++ b/trustgraph-cli/trustgraph/cli/verify_system_status.py @@ -403,15 +403,8 @@ def main(): # Phase 1: Infrastructure print(tr.t("cli.verify_system_status.phase_1")) print("-" * 60) - if not checker.run_check( - tr.t("cli.verify_system_status.check_name.pulsar"), - check_pulsar, - args.pulsar_url, - args.check_timeout, - tr, - ): - print(f"\n⚠️ {tr.t('cli.verify_system_status.pulsar_not_responding')}") - print() + # Pulsar check is skipped — not all deployments use Pulsar. + # The API Gateway check covers broker connectivity indirectly. checker.run_check( tr.t("cli.verify_system_status.check_name.api_gateway"), diff --git a/trustgraph-embeddings-hf/pyproject.toml b/trustgraph-embeddings-hf/pyproject.toml index 1bc7bcb4..459f6123 100644 --- a/trustgraph-embeddings-hf/pyproject.toml +++ b/trustgraph-embeddings-hf/pyproject.toml @@ -10,8 +10,8 @@ description = "HuggingFace embeddings support for TrustGraph." 
readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.2,<2.3", - "trustgraph-flow>=2.2,<2.3", + "trustgraph-base>=2.3,<2.4", + "trustgraph-flow>=2.3,<2.4", "torch", "urllib3", "transformers", diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index b2df4a4c..14e919c0 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc readme = "README.md" requires-python = ">=3.8" dependencies = [ - "trustgraph-base>=2.2,<2.3", + "trustgraph-base>=2.3,<2.4", "aiohttp", "anthropic", "scylla-driver", diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py b/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py index c3b1afa6..97b87134 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/meta_router.py @@ -53,7 +53,7 @@ class MetaRouter: "general": {"name": "general", "description": "General queries", "valid_patterns": ["react"], "framing": ""}, } - async def identify_task_type(self, question, context): + async def identify_task_type(self, question, context, usage=None): """ Use the LLM to classify the question into one of the known task types. @@ -71,7 +71,7 @@ class MetaRouter: try: client = context("prompt-request") - response = await client.prompt( + result = await client.prompt( id="task-type-classify", variables={ "question": question, @@ -81,7 +81,9 @@ class MetaRouter: ], }, ) - selected = response.strip().lower().replace('"', '').replace("'", "") + if usage: + usage.track(result) + selected = result.text.strip().lower().replace('"', '').replace("'", "") if selected in self.task_types: framing = self.task_types[selected].get("framing", DEFAULT_FRAMING) @@ -100,7 +102,7 @@ class MetaRouter: ) return DEFAULT_TASK_TYPE, framing - async def select_pattern(self, question, task_type, context): + async def select_pattern(self, question, task_type, context, usage=None): """ Use the LLM to select the best execution pattern for this task type. @@ -120,7 +122,7 @@ class MetaRouter: try: client = context("prompt-request") - response = await client.prompt( + result = await client.prompt( id="pattern-select", variables={ "question": question, @@ -133,7 +135,9 @@ class MetaRouter: ], }, ) - selected = response.strip().lower().replace('"', '').replace("'", "") + if usage: + usage.track(result) + selected = result.text.strip().lower().replace('"', '').replace("'", "") if selected in valid_patterns: logger.info(f"MetaRouter: selected pattern '{selected}'") @@ -148,19 +152,20 @@ class MetaRouter: logger.warning(f"MetaRouter: pattern selection failed: {e}") return valid_patterns[0] if valid_patterns else DEFAULT_PATTERN - async def route(self, question, context): + async def route(self, question, context, usage=None): """ Full routing pipeline: identify task type, then select pattern. Args: question: The user's query. context: UserAwareContext (flow wrapper). + usage: Optional UsageTracker for token counting. Returns: (pattern, task_type, framing) tuple. 
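+
+        Example (a sketch only; `router`, `context` and the tracker are
+        assumed to exist in the caller's scope):
+
+            usage = UsageTracker()
+            pattern, task_type, framing = await router.route(
+                question, context, usage=usage,
+            )
+            # usage now includes tokens from both routing LLM calls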
""" - task_type, framing = await self.identify_task_type(question, context) - pattern = await self.select_pattern(question, task_type, context) + task_type, framing = await self.identify_task_type(question, context, usage=usage) + pattern = await self.select_pattern(question, task_type, context, usage=usage) logger.info( f"MetaRouter: route result — " f"pattern={pattern}, task_type={task_type}, framing={framing!r}" diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py index c18c5bac..88d4ee72 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/pattern_base.py @@ -25,6 +25,7 @@ from trustgraph.provenance import ( agent_plan_uri, agent_step_result_uri, agent_synthesis_uri, + agent_pattern_decision_uri, agent_session_triples, agent_iteration_triples, agent_observation_triples, @@ -34,6 +35,7 @@ from trustgraph.provenance import ( agent_plan_triples, agent_step_result_triples, agent_synthesis_triples, + agent_pattern_decision_triples, set_graph, GRAPH_RETRIEVAL, ) @@ -65,6 +67,37 @@ class UserAwareContext: return client +class UsageTracker: + """Accumulates token usage across multiple prompt calls.""" + + def __init__(self): + self.total_in = 0 + self.total_out = 0 + self.last_model = None + + def track(self, result): + """Track usage from a PromptResult.""" + if result is not None: + if getattr(result, "in_token", None) is not None: + self.total_in += result.in_token + if getattr(result, "out_token", None) is not None: + self.total_out += result.out_token + if getattr(result, "model", None) is not None: + self.last_model = result.model + + @property + def in_token(self): + return self.total_in if self.total_in > 0 else None + + @property + def out_token(self): + return self.total_out if self.total_out > 0 else None + + @property + def model(self): + return self.last_model + + class PatternBase: """ Shared infrastructure for all agent patterns. 
@@ -151,7 +184,7 @@ class PatternBase: logger.debug(f"Think: {x} (is_final={is_final})") if streaming: r = AgentResponse( - chunk_type="thought", + message_type="thought", content=x, end_of_message=is_final, end_of_dialog=False, @@ -159,7 +192,7 @@ class PatternBase: ) else: r = AgentResponse( - chunk_type="thought", + message_type="thought", content=x, end_of_message=True, end_of_dialog=False, @@ -174,7 +207,7 @@ class PatternBase: logger.debug(f"Observe: {x} (is_final={is_final})") if streaming: r = AgentResponse( - chunk_type="observation", + message_type="observation", content=x, end_of_message=is_final, end_of_dialog=False, @@ -182,7 +215,7 @@ class PatternBase: ) else: r = AgentResponse( - chunk_type="observation", + message_type="observation", content=x, end_of_message=True, end_of_dialog=False, @@ -197,7 +230,7 @@ class PatternBase: logger.debug(f"Answer: {x}") if streaming: r = AgentResponse( - chunk_type="answer", + message_type="answer", content=x, end_of_message=False, end_of_dialog=False, @@ -205,7 +238,7 @@ class PatternBase: ) else: r = AgentResponse( - chunk_type="answer", + message_type="answer", content=x, end_of_message=True, end_of_dialog=False, @@ -239,16 +272,43 @@ class PatternBase: logger.debug(f"Emitted session triples for {session_uri}") await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=session_uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) + async def emit_pattern_decision_triples( + self, flow, session_id, session_uri, pattern, task_type, + user, collection, respond, + ): + """Emit provenance triples for a meta-router pattern decision.""" + uri = agent_pattern_decision_uri(session_id) + triples = set_graph( + agent_pattern_decision_triples( + uri, session_uri, pattern, task_type, + ), + GRAPH_RETRIEVAL, + ) + await flow("explainability").send(Triples( + metadata=Metadata(id=uri, user=user, collection=collection), + triples=triples, + )) + await respond(AgentResponse( + message_type="explain", content="", + explain_id=uri, explain_graph=GRAPH_RETRIEVAL, + explain_triples=triples, + )) + return uri + async def emit_iteration_triples(self, flow, session_id, iteration_num, session_uri, act, request, respond, - streaming): + streaming, tool_candidates=None, + step_number=None, + llm_duration_ms=None, + in_token=None, out_token=None, + model=None): """Emit provenance triples for an iteration (Analysis+ToolUse).""" iteration_uri = agent_iteration_uri(session_id, iteration_num) @@ -288,6 +348,12 @@ class PatternBase: arguments=act.arguments, thought_uri=thought_entity_uri if thought_doc_id else None, thought_document_id=thought_doc_id, + tool_candidates=tool_candidates, + step_number=step_number, + llm_duration_ms=llm_duration_ms, + in_token=in_token, + out_token=out_token, + model=model, ), GRAPH_RETRIEVAL, ) @@ -302,7 +368,7 @@ class PatternBase: logger.debug(f"Emitted iteration triples for {iteration_uri}") await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=iteration_uri, explain_graph=GRAPH_RETRIEVAL, @@ -311,7 +377,9 @@ class PatternBase: async def emit_observation_triples(self, flow, session_id, iteration_num, observation_text, request, respond, - context=None): + context=None, + tool_duration_ms=None, + tool_error=None): """Emit provenance triples for a standalone Observation entity.""" iteration_uri = agent_iteration_uri(session_id, iteration_num) observation_entity_uri = agent_observation_uri(session_id, iteration_num) @@ -344,6 +412,8 @@ 
class PatternBase: observation_entity_uri, parent_uri, document_id=observation_doc_id, + tool_duration_ms=tool_duration_ms, + tool_error=tool_error, ), GRAPH_RETRIEVAL, ) @@ -358,7 +428,7 @@ class PatternBase: logger.debug(f"Emitted observation triples for {observation_entity_uri}") await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=observation_entity_uri, explain_graph=GRAPH_RETRIEVAL, @@ -367,7 +437,7 @@ class PatternBase: async def emit_final_triples(self, flow, session_id, iteration_num, session_uri, answer_text, request, respond, - streaming): + streaming, termination_reason=None): """Emit provenance triples for the final answer and save to librarian.""" final_uri = agent_final_uri(session_id) @@ -401,6 +471,7 @@ class PatternBase: question_uri=final_question_uri, previous_uri=final_previous_uri, document_id=answer_doc_id, + termination_reason=termination_reason, ), GRAPH_RETRIEVAL, ) @@ -415,7 +486,7 @@ class PatternBase: logger.debug(f"Emitted final triples for {final_uri}") await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=final_uri, explain_graph=GRAPH_RETRIEVAL, @@ -439,7 +510,7 @@ class PatternBase: triples=triples, )) await respond(AgentResponse( - chunk_type="explain", content="", + message_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) @@ -478,7 +549,7 @@ class PatternBase: triples=triples, )) await respond(AgentResponse( - chunk_type="explain", content="", + message_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) @@ -498,7 +569,7 @@ class PatternBase: triples=triples, )) await respond(AgentResponse( - chunk_type="explain", content="", + message_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) @@ -531,14 +602,14 @@ class PatternBase: triples=triples, )) await respond(AgentResponse( - chunk_type="explain", content="", + message_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) async def emit_synthesis_triples( self, flow, session_id, previous_uris, answer_text, user, collection, - respond, streaming, + respond, streaming, termination_reason=None, ): """Emit provenance for a synthesis answer.""" uri = agent_synthesis_uri(session_id) @@ -555,7 +626,10 @@ class PatternBase: doc_id = None triples = set_graph( - agent_synthesis_triples(uri, previous_uris, doc_id), + agent_synthesis_triples( + uri, previous_uris, doc_id, + termination_reason=termination_reason, + ), GRAPH_RETRIEVAL, ) await flow("explainability").send(Triples( @@ -563,7 +637,7 @@ class PatternBase: triples=triples, )) await respond(AgentResponse( - chunk_type="explain", content="", + message_type="explain", content="", explain_id=uri, explain_graph=GRAPH_RETRIEVAL, explain_triples=triples, )) @@ -571,7 +645,8 @@ class PatternBase: # ---- Response helpers --------------------------------------------------- async def prompt_as_answer(self, client, prompt_id, variables, - respond, streaming, message_id=""): + respond, streaming, message_id="", + usage=None): """Call a prompt template, forwarding chunks as answer AgentResponse messages when streaming is enabled. 
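A sketch of how a pattern might combine the two response helpers below (the template id, variables and `uri` are invented; `client` is a prompt-request client):

    usage = UsageTracker()
    answer = await self.prompt_as_answer(
        client, "synthesis-example",               # invented template id
        {"question": request.question},
        respond, streaming, message_id=uri,
        usage=usage,                               # accumulates tokens
    )
    await self.send_final_response(
        respond, streaming, answer,
        already_streamed=streaming, message_id=uri,
        usage=usage,                               # attached to end_of_dialog
    )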
@@ -584,29 +659,35 @@ class PatternBase: if text: accumulated.append(text) await respond(AgentResponse( - chunk_type="answer", + message_type="answer", content=text, end_of_message=False, end_of_dialog=False, message_id=message_id, )) - await client.prompt( + result = await client.prompt( id=prompt_id, variables=variables, streaming=True, chunk_callback=on_chunk, ) + if usage: + usage.track(result) return "".join(accumulated) else: - return await client.prompt( + result = await client.prompt( id=prompt_id, variables=variables, ) + if usage: + usage.track(result) + return result.text async def send_final_response(self, respond, streaming, answer_text, - already_streamed=False, message_id=""): + already_streamed=False, message_id="", + usage=None): """Send the answer content and end-of-dialog marker. Args: @@ -614,33 +695,44 @@ class PatternBase: via streaming callbacks (e.g. ReactPattern). Only the end-of-dialog marker is emitted. message_id: Provenance URI for the answer entity. + usage: UsageTracker with accumulated token counts. """ + usage_kwargs = {} + if usage: + usage_kwargs = { + "in_token": usage.in_token, + "out_token": usage.out_token, + "model": usage.model, + } + if streaming and not already_streamed: # Answer wasn't streamed yet — send it as a chunk first if answer_text: await respond(AgentResponse( - chunk_type="answer", + message_type="answer", content=answer_text, end_of_message=False, end_of_dialog=False, message_id=message_id, )) if streaming: - # End-of-dialog marker + # End-of-dialog marker with usage await respond(AgentResponse( - chunk_type="answer", + message_type="answer", content="", end_of_message=True, end_of_dialog=True, message_id=message_id, + **usage_kwargs, )) else: await respond(AgentResponse( - chunk_type="answer", + message_type="answer", content=answer_text, end_of_message=True, end_of_dialog=True, message_id=message_id, + **usage_kwargs, )) def build_next_request(self, request, history, session_id, collection, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py index 59d22929..1de31a92 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/plan_pattern.py @@ -18,7 +18,7 @@ from trustgraph.provenance import ( agent_synthesis_uri, ) -from . pattern_base import PatternBase +from . pattern_base import PatternBase, UsageTracker logger = logging.getLogger(__name__) @@ -35,7 +35,11 @@ class PlanThenExecutePattern(PatternBase): Subsequent calls execute the next pending plan step via ReACT. 
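+
+    A dispatch sketch (simplified; the tracker and the decision URI are
+    supplied by the orchestrator service):
+
+        usage = UsageTracker()
+        await pattern.iterate(request, respond, next, flow,
+                              usage=usage,
+                              pattern_decision_uri=decision_uri)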
""" - async def iterate(self, request, respond, next, flow): + async def iterate(self, request, respond, next, flow, usage=None, + pattern_decision_uri=None): + + if usage is None: + usage = UsageTracker() streaming = getattr(request, 'streaming', False) session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) @@ -63,17 +67,19 @@ class PlanThenExecutePattern(PatternBase): # Determine current phase by checking history for a plan step plan = self._extract_plan(request.history) + derive_from_uri = pattern_decision_uri or session_uri + if plan is None: await self._planning_iteration( request, respond, next, flow, - session_id, collection, streaming, session_uri, - iteration_num, + session_id, collection, streaming, derive_from_uri, + iteration_num, usage=usage, ) else: await self._execution_iteration( request, respond, next, flow, - session_id, collection, streaming, session_uri, - iteration_num, plan, + session_id, collection, streaming, derive_from_uri, + iteration_num, plan, usage=usage, ) def _extract_plan(self, history): @@ -98,7 +104,7 @@ class PlanThenExecutePattern(PatternBase): async def _planning_iteration(self, request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num): + session_uri, iteration_num, usage=None): """Ask the LLM to produce a structured plan.""" think = self.make_think_callback(respond, streaming) @@ -113,7 +119,7 @@ class PlanThenExecutePattern(PatternBase): client = context("prompt-request") # Use the plan-create prompt template - plan_steps = await client.prompt( + result = await client.prompt( id="plan-create", variables={ "question": request.question, @@ -124,7 +130,10 @@ class PlanThenExecutePattern(PatternBase): ], }, ) + if usage: + usage.track(result) + plan_steps = result.objects # Validate we got a list if not isinstance(plan_steps, list) or not plan_steps: logger.warning("plan-create returned invalid result, falling back to single step") @@ -187,7 +196,8 @@ class PlanThenExecutePattern(PatternBase): async def _execution_iteration(self, request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num, plan): + session_uri, iteration_num, plan, + usage=None): """Execute the next pending plan step via single-shot tool call.""" pending_idx = self._find_next_pending_step(plan) @@ -198,6 +208,7 @@ class PlanThenExecutePattern(PatternBase): request, respond, next, flow, session_id, collection, streaming, session_uri, iteration_num, plan, + usage=usage, ) return @@ -240,7 +251,7 @@ class PlanThenExecutePattern(PatternBase): client = context("prompt-request") # Single-shot: ask LLM which tool + arguments to use for this goal - tool_call = await client.prompt( + result = await client.prompt( id="plan-step-execute", variables={ "goal": goal, @@ -258,7 +269,10 @@ class PlanThenExecutePattern(PatternBase): ], }, ) + if usage: + usage.track(result) + tool_call = result.object tool_name = tool_call.get("tool", "") tool_arguments = tool_call.get("arguments", {}) @@ -330,7 +344,8 @@ class PlanThenExecutePattern(PatternBase): async def _synthesise(self, request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num, plan): + session_uri, iteration_num, plan, + usage=None): """Synthesise a final answer from all completed plan step results.""" think = self.make_think_callback(respond, streaming) @@ -365,6 +380,7 @@ class PlanThenExecutePattern(PatternBase): respond=respond, streaming=streaming, message_id=synthesis_msg_id, + usage=usage, ) # Emit synthesis provenance 
(links back to last step result) @@ -372,6 +388,7 @@ class PlanThenExecutePattern(PatternBase): await self.emit_synthesis_triples( flow, session_id, last_step_uri, response_text, request.user, collection, respond, streaming, + termination_reason="plan-complete", ) if self.is_subagent(request): @@ -380,4 +397,5 @@ class PlanThenExecutePattern(PatternBase): await self.send_final_response( respond, streaming, response_text, already_streamed=streaming, message_id=synthesis_msg_id, + usage=usage, ) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py index 67ded823..25264c26 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/react_pattern.py @@ -23,7 +23,7 @@ from ..react.agent_manager import AgentManager from ..react.types import Action, Final from ..tool_filter import get_next_state -from . pattern_base import PatternBase +from . pattern_base import PatternBase, UsageTracker logger = logging.getLogger(__name__) @@ -37,7 +37,11 @@ class ReactPattern(PatternBase): result is appended to history and a next-request is emitted. """ - async def iterate(self, request, respond, next, flow): + async def iterate(self, request, respond, next, flow, usage=None, + pattern_decision_uri=None): + + if usage is None: + usage = UsageTracker() streaming = getattr(request, 'streaming', False) session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) @@ -105,11 +109,23 @@ class ReactPattern(PatternBase): session_id, iteration_num, ) + # Tool names available to the LLM for this iteration + tool_candidates = [t.name for t in filtered_tools.values()] + + # Use pattern decision as derivation source if available + derive_from_uri = pattern_decision_uri or session_uri + # Callback: emit Analysis+ToolUse triples before tool executes async def on_action(act): await self.emit_iteration_triples( - flow, session_id, iteration_num, session_uri, + flow, session_id, iteration_num, derive_from_uri, act, request, respond, streaming, + tool_candidates=tool_candidates, + step_number=iteration_num, + llm_duration_ms=getattr(act, 'llm_duration_ms', None), + in_token=getattr(act, 'in_token', None), + out_token=getattr(act, 'out_token', None), + model=getattr(act, 'llm_model', None), ) act = await temp_agent.react( @@ -121,6 +137,7 @@ class ReactPattern(PatternBase): context=context, streaming=streaming, on_action=on_action, + usage=usage, ) logger.debug(f"Action: {act}") @@ -134,8 +151,9 @@ class ReactPattern(PatternBase): # Emit final provenance await self.emit_final_triples( - flow, session_id, iteration_num, session_uri, + flow, session_id, iteration_num, derive_from_uri, f, request, respond, streaming, + termination_reason="final-answer", ) if self.is_subagent(request): @@ -144,6 +162,7 @@ class ReactPattern(PatternBase): await self.send_final_response( respond, streaming, f, already_streamed=streaming, message_id=answer_msg_id, + usage=usage, ) return @@ -152,6 +171,8 @@ class ReactPattern(PatternBase): flow, session_id, iteration_num, act.observation, request, respond, context=context, + tool_duration_ms=getattr(act, 'tool_duration_ms', None), + tool_error=getattr(act, 'tool_error', None), ) history.append(act) diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/service.py b/trustgraph-flow/trustgraph/agent/orchestrator/service.py index 5bf8e2fd..3d08154d 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/service.py +++ 
b/trustgraph-flow/trustgraph/agent/orchestrator/service.py @@ -23,6 +23,7 @@ from ... base import Consumer, Producer from ... base import ConsumerMetrics, ProducerMetrics from ... schema import AgentRequest, AgentResponse, AgentStep, Error +from . pattern_base import UsageTracker from ... schema import Triples, Metadata from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue @@ -493,6 +494,8 @@ class Processor(AgentService): async def agent_request(self, request, respond, next, flow): + usage = UsageTracker() + try: # Intercept subagent completion messages @@ -516,7 +519,7 @@ class Processor(AgentService): if self.meta_router: pattern, task_type, framing = await self.meta_router.route( - request.question, context, + request.question, context, usage=usage, ) else: pattern = "react" @@ -534,19 +537,31 @@ class Processor(AgentService): ) # Dispatch to the selected pattern + selected = self.react_pattern if pattern == "plan-then-execute": - await self.plan_pattern.iterate( - request, respond, next, flow, - ) + selected = self.plan_pattern elif pattern == "supervisor": - await self.supervisor_pattern.iterate( - request, respond, next, flow, - ) - else: - # Default to react - await self.react_pattern.iterate( - request, respond, next, flow, - ) + selected = self.supervisor_pattern + + # Emit pattern decision provenance on first iteration + pattern_decision_uri = None + if not request.history and pattern: + session_id = getattr(request, 'session_id', '') + if session_id: + session_uri = self.provenance_session_uri(session_id) + pattern_decision_uri = \ + await selected.emit_pattern_decision_triples( + flow, session_id, session_uri, + pattern, getattr(request, 'task_type', ''), + request.user, + getattr(request, 'collection', 'default'), + respond, + ) + + await selected.iterate( + request, respond, next, flow, usage=usage, + pattern_decision_uri=pattern_decision_uri, + ) except Exception as e: @@ -562,7 +577,7 @@ class Processor(AgentService): ) r = AgentResponse( - chunk_type="error", + message_type="error", content=str(e), end_of_message=True, end_of_dialog=True, diff --git a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py index d5537876..973a9966 100644 --- a/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py +++ b/trustgraph-flow/trustgraph/agent/orchestrator/supervisor_pattern.py @@ -22,7 +22,7 @@ from trustgraph.provenance import ( agent_synthesis_uri, ) -from . pattern_base import PatternBase +from . 
pattern_base import PatternBase, UsageTracker logger = logging.getLogger(__name__) @@ -38,7 +38,11 @@ class SupervisorPattern(PatternBase): - "synthesise": triggered by aggregator with results in subagent_results """ - async def iterate(self, request, respond, next, flow): + async def iterate(self, request, respond, next, flow, usage=None, + pattern_decision_uri=None): + + if usage is None: + usage = UsageTracker() streaming = getattr(request, 'streaming', False) session_id = getattr(request, 'session_id', '') or str(uuid.uuid4()) @@ -67,22 +71,26 @@ class SupervisorPattern(PatternBase): ) ) + derive_from_uri = pattern_decision_uri or session_uri + if has_results: await self._synthesise( request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num, + derive_from_uri, iteration_num, + usage=usage, ) else: await self._decompose_and_fanout( request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num, + derive_from_uri, iteration_num, + usage=usage, ) async def _decompose_and_fanout(self, request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num): + session_uri, iteration_num, usage=None): """Decompose the question into sub-goals and fan out subagents.""" decompose_msg_id = agent_decomposition_uri(session_id) @@ -100,7 +108,7 @@ class SupervisorPattern(PatternBase): client = context("prompt-request") # Use the supervisor-decompose prompt template - goals = await client.prompt( + result = await client.prompt( id="supervisor-decompose", variables={ "question": request.question, @@ -112,7 +120,10 @@ class SupervisorPattern(PatternBase): ], }, ) + if usage: + usage.track(result) + goals = result.objects # Validate result if not isinstance(goals, list): goals = [] @@ -175,7 +186,7 @@ class SupervisorPattern(PatternBase): async def _synthesise(self, request, respond, next, flow, session_id, collection, streaming, - session_uri, iteration_num): + session_uri, iteration_num, usage=None): """Synthesise final answer from subagent results.""" synthesis_msg_id = agent_synthesis_uri(session_id) @@ -216,6 +227,7 @@ class SupervisorPattern(PatternBase): respond=respond, streaming=streaming, message_id=synthesis_msg_id, + usage=usage, ) # Emit synthesis provenance (links back to all findings) @@ -226,9 +238,11 @@ class SupervisorPattern(PatternBase): await self.emit_synthesis_triples( flow, session_id, finding_uris, response_text, request.user, collection, respond, streaming, + termination_reason="subagents-complete", ) await self.send_final_response( respond, streaming, response_text, already_streamed=streaming, message_id=synthesis_msg_id, + usage=usage, ) diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py index e86a2d6c..82a8f905 100644 --- a/trustgraph-flow/trustgraph/agent/react/agent_manager.py +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -3,6 +3,7 @@ import logging import json import re import asyncio +import time from . 
types import Action, Final @@ -170,7 +171,7 @@ class AgentManager: raise ValueError(f"Could not parse response: {text}") - async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None): + async def reason(self, question, history, context, streaming=False, think=None, observe=None, answer=None, usage=None): logger.debug(f"calling reason: {question}") @@ -255,11 +256,14 @@ class AgentManager: client = context("prompt-request") # Get streaming response - response_text = await client.agent_react( + prompt_result = await client.agent_react( variables=variables, streaming=True, chunk_callback=on_chunk ) + self._last_prompt_result = prompt_result + if usage: + usage.track(prompt_result) # Finalize parser parser.finalize() @@ -267,7 +271,13 @@ class AgentManager: # Get result result = parser.get_result() if result is None: - raise RuntimeError("Parser failed to produce a result") + return Action( + thought="", + name="__parse_error__", + arguments={}, + observation="", + tool_error="LLM response could not be parsed (streaming)", + ) return result @@ -275,10 +285,14 @@ class AgentManager: # Non-streaming path - get complete text and parse client = context("prompt-request") - response_text = await client.agent_react( + prompt_result = await client.agent_react( variables=variables, streaming=False ) + self._last_prompt_result = prompt_result + if usage: + usage.track(prompt_result) + response_text = prompt_result.text logger.debug(f"Response text:\n{response_text}") @@ -289,11 +303,19 @@ class AgentManager: except ValueError as e: logger.error(f"Failed to parse response: {e}") logger.error(f"Response was: {response_text}") - raise RuntimeError(f"Failed to parse agent response: {e}") + return Action( + thought="", + name="__parse_error__", + arguments={}, + observation="", + tool_error=f"LLM parse error: {e}", + ) async def react(self, question, history, think, observe, context, - streaming=False, answer=None, on_action=None): + streaming=False, answer=None, on_action=None, + usage=None): + t0 = time.monotonic() act = await self.reason( question = question, history = history, @@ -302,7 +324,14 @@ class AgentManager: think = think, observe = observe, answer = answer, + usage = usage, ) + act.llm_duration_ms = int((time.monotonic() - t0) * 1000) + pr = getattr(self, '_last_prompt_result', None) + if pr: + act.in_token = pr.in_token + act.out_token = pr.out_token + act.llm_model = pr.model if isinstance(act, Final): @@ -321,24 +350,43 @@ class AgentManager: logger.debug(f"ACTION: {act.name}") + # Notify caller before tool execution (for provenance) + if on_action: + await on_action(act) + + # Handle parse errors — skip tool execution + if act.name == "__parse_error__": + resp = f"Error: {act.tool_error}" + act.tool_duration_ms = 0 + await observe(resp, is_final=True) + act.observation = resp + return act + if act.name in self.tools: action = self.tools[act.name] else: raise RuntimeError(f"No action for {act.name}!") - # Notify caller before tool execution (for provenance) - if on_action: - await on_action(act) + t0 = time.monotonic() + try: + resp = await action.implementation(context).invoke( + **act.arguments + ) - resp = await action.implementation(context).invoke( - **act.arguments - ) + if isinstance(resp, str): + resp = resp.strip() + else: + resp = str(resp) + resp = resp.strip() - if isinstance(resp, str): - resp = resp.strip() - else: - resp = str(resp) - resp = resp.strip() + act.tool_error = None + + except Exception as e: + logger.error(f"Tool 
execution error ({act.name}): {e}") + resp = f"Error: {e}" + act.tool_error = str(e) + + act.tool_duration_ms = int((time.monotonic() - t0) * 1000) await observe(resp, is_final=True) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 2c7423d8..00432181 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -469,7 +469,7 @@ class Processor(AgentService): # Send explain event for session await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=session_uri, explain_graph=GRAPH_RETRIEVAL, @@ -492,7 +492,7 @@ class Processor(AgentService): if streaming: r = AgentResponse( - chunk_type="thought", + message_type="thought", content=x, end_of_message=is_final, end_of_dialog=False, @@ -500,7 +500,7 @@ class Processor(AgentService): ) else: r = AgentResponse( - chunk_type="thought", + message_type="thought", content=x, end_of_message=True, end_of_dialog=False, @@ -515,7 +515,7 @@ class Processor(AgentService): if streaming: r = AgentResponse( - chunk_type="observation", + message_type="observation", content=x, end_of_message=is_final, end_of_dialog=False, @@ -523,7 +523,7 @@ class Processor(AgentService): ) else: r = AgentResponse( - chunk_type="observation", + message_type="observation", content=x, end_of_message=True, end_of_dialog=False, @@ -540,7 +540,7 @@ class Processor(AgentService): if streaming: r = AgentResponse( - chunk_type="answer", + message_type="answer", content=x, end_of_message=False, end_of_dialog=False, @@ -548,7 +548,7 @@ class Processor(AgentService): ) else: r = AgentResponse( - chunk_type="answer", + message_type="answer", content=x, end_of_message=True, end_of_dialog=False, @@ -637,7 +637,7 @@ class Processor(AgentService): logger.debug(f"Emitted iteration triples for {iter_uri}") await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=iter_uri, explain_graph=GRAPH_RETRIEVAL, @@ -715,7 +715,7 @@ class Processor(AgentService): # Send explain event for conclusion await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=final_uri, explain_graph=GRAPH_RETRIEVAL, @@ -725,7 +725,7 @@ class Processor(AgentService): if streaming: # End-of-dialog marker — answer chunks already sent via callback r = AgentResponse( - chunk_type="answer", + message_type="answer", content="", end_of_message=True, end_of_dialog=True, @@ -733,7 +733,7 @@ class Processor(AgentService): ) else: r = AgentResponse( - chunk_type="answer", + message_type="answer", content=f, end_of_message=True, end_of_dialog=True, @@ -792,7 +792,7 @@ class Processor(AgentService): # Send explain event for observation await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=observation_entity_uri, explain_graph=GRAPH_RETRIEVAL, @@ -847,7 +847,7 @@ class Processor(AgentService): streaming = getattr(request, 'streaming', False) if 'request' in locals() else False r = AgentResponse( - chunk_type="error", + message_type="error", content=str(e), end_of_message=True, end_of_dialog=True, diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index 6fd96ade..6674c999 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -42,7 +42,7 @@ class KnowledgeQueryImpl: async def explain_callback(explain_id, 
explain_graph, explain_triples=None): self.context.last_sub_explain_uri = explain_id await respond(AgentResponse( - chunk_type="explain", + message_type="explain", content="", explain_id=explain_id, explain_graph=explain_graph, @@ -78,9 +78,10 @@ class TextCompletionImpl: async def invoke(self, **arguments): client = self.context("prompt-request") logger.debug("Prompt question...") - return await client.question( + result = await client.question( arguments.get("question") ) + return result.text # This tool implementation knows how to do MCP tool invocation. This uses # the mcp-tool service. @@ -227,10 +228,11 @@ class PromptImpl: async def invoke(self, **arguments): client = self.context("prompt-request") logger.debug(f"Prompt template invocation: {self.template_id}...") - return await client.prompt( + result = await client.prompt( id=self.template_id, variables=arguments ) + return result.text # This tool implementation invokes a dynamically configured tool service diff --git a/trustgraph-flow/trustgraph/agent/react/types.py b/trustgraph-flow/trustgraph/agent/react/types.py index 7180db3e..ee0a677f 100644 --- a/trustgraph-flow/trustgraph/agent/react/types.py +++ b/trustgraph-flow/trustgraph/agent/react/types.py @@ -22,9 +22,19 @@ class Action: name : str arguments : dict observation : str - + llm_duration_ms : int = None + tool_duration_ms : int = None + tool_error : str = None + in_token : int = None + out_token : int = None + llm_model : str = None + @dataclasses.dataclass class Final: thought : str final : str + llm_duration_ms : int = None + in_token : int = None + out_token : int = None + llm_model : str = None diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py index 1a03ac9f..a5fee382 100755 --- a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py @@ -4,6 +4,7 @@ Embeddings service, applies an embeddings model using fastembed Input is text, output is embeddings vector. """ +import asyncio import logging from ... base import EmbeddingsService @@ -37,7 +38,13 @@ class Processor(EmbeddingsService): self._load_model(model) def _load_model(self, model_name): - """Load a model, caching it for reuse""" + """Load a model, caching it for reuse. + + Synchronous — CPU and I/O heavy. Callers that run on the + event loop must dispatch via asyncio.to_thread to avoid + freezing the loop (which, in processor-group deployments, + freezes every sibling processor in the same process). + """ if self.cached_model_name != model_name: logger.info(f"Loading FastEmbed model: {model_name}") self.embeddings = TextEmbedding(model_name=model_name) @@ -46,6 +53,11 @@ class Processor(EmbeddingsService): else: logger.debug(f"Using cached model: {model_name}") + def _run_embed(self, texts): + """Synchronous embed call. Runs in a worker thread via + asyncio.to_thread from on_embeddings.""" + return list(self.embeddings.embed(texts)) + async def on_embeddings(self, texts, model=None): if not texts: @@ -53,11 +65,18 @@ class Processor(EmbeddingsService): use_model = model or self.default_model - # Reload model if it has changed - self._load_model(use_model) + # Reload model if it has changed. Model loading is sync + # and can take seconds; push it to a worker thread so the + # event loop (and any sibling processors in group mode) + # stay responsive. 
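(As an aside, the offload pattern this comment describes, and which the diff applies again just below for inference, reduces to a small standalone example. Nothing here is TrustGraph code; blocking_embed merely stands in for FastEmbed's synchronous ONNX inference.)

    import asyncio
    import time

    def blocking_embed(texts):
        # Stand-in for CPU-bound, synchronous inference work
        time.sleep(0.5)
        return [[0.0, 0.0, 0.0] for _ in texts]

    async def heartbeat():
        # Keeps ticking while the embed call runs in a worker
        # thread, showing the event loop is not blocked.
        for _ in range(5):
            await asyncio.sleep(0.1)

    async def main():
        vecs, _ = await asyncio.gather(
            asyncio.to_thread(blocking_embed, ["a", "b"]),
            heartbeat(),
        )
        print(len(vecs))  # 2

    asyncio.run(main())

Swapping the to_thread dispatch for a direct blocking_embed call stalls heartbeat for the full half second, which in group mode would stall every sibling processor sharing the loop.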
+ if self.cached_model_name != use_model: + await asyncio.to_thread(self._load_model, use_model) - # FastEmbed processes the full batch efficiently - vecs = list(self.embeddings.embed(texts)) + # FastEmbed inference is synchronous ONNX runtime work. + # Dispatch to a worker thread so the event loop stays + # responsive for other tasks (important in group mode + # where the loop is shared across many processors). + vecs = await asyncio.to_thread(self._run_embed, texts) # Return list of vectors, one per input text return [v.tolist() for v in vecs] diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 2bb88c8a..9b5bbb79 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -117,10 +117,11 @@ class Processor(FlowProcessor): try: - defs = await flow("prompt-request").extract_definitions( + result = await flow("prompt-request").extract_definitions( text = chunk ) + defs = result.objects logger.debug(f"Definitions response: {defs}") if type(defs) != list: diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index 29808cae..bdb0e6e8 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -376,10 +376,11 @@ class Processor(FlowProcessor): """ try: # Call prompt service with simplified format prompt - extraction_response = await flow("prompt-request").prompt( + result = await flow("prompt-request").prompt( id="extract-with-ontologies", variables=prompt_variables ) + extraction_response = result.object logger.debug(f"Simplified extraction response: {extraction_response}") # Parse response into structured format diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index b557ec32..8068a23d 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -100,10 +100,11 @@ class Processor(FlowProcessor): try: - rels = await flow("prompt-request").extract_relationships( + result = await flow("prompt-request").extract_relationships( text = chunk ) + rels = result.objects logger.debug(f"Prompt response: {rels}") if type(rels) != list: diff --git a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py index 8fd494b0..973bb3d7 100644 --- a/trustgraph-flow/trustgraph/extract/kg/rows/processor.py +++ b/trustgraph-flow/trustgraph/extract/kg/rows/processor.py @@ -148,11 +148,12 @@ class Processor(FlowProcessor): schema_dict = row_schema_translator.encode(schema) # Use prompt client to extract rows based on schema - objects = await flow("prompt-request").extract_objects( + result = await flow("prompt-request").extract_objects( schema=schema_dict, text=text ) - + + objects = result.objects if not isinstance(objects, list): return [] diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py index 62626046..3a37c4e3 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py @@ -40,15 +40,14 @@ class CoreExport: "ge", { "m": { - "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "i": 
data["metadata"]["id"], "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, "e": [ { "e": ent["entity"], - "v": ent["vectors"], + "v": ent["vector"], } for ent in data["entities"] ] @@ -65,8 +64,7 @@ class CoreExport: "t", { "m": { - "i": data["metadata"]["id"], - "m": data["metadata"]["metadata"], + "i": data["metadata"]["id"], "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py index af22a5b0..0ca07319 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py @@ -48,7 +48,6 @@ class CoreImport: "triples": { "metadata": { "id": id, - "metadata": msg["m"]["m"], "user": user, "collection": "default", # Not used? }, @@ -57,7 +56,7 @@ class CoreImport: } await kr.process(msg) - + elif unpacked[0] == "ge": msg = unpacked[1] msg = { @@ -67,14 +66,13 @@ class CoreImport: "graph-embeddings": { "metadata": { "id": id, - "metadata": msg["m"]["m"], "user": user, "collection": "default", # Not used? }, "entities": [ { "entity": ent["e"], - "vectors": ent["v"], + "vector": ent["v"], } for ent in msg["e"] ] diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py index 6e01a5ca..de0fe52d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py @@ -8,7 +8,7 @@ from ... schema import Metadata from ... schema import EntityContexts, EntityContext from ... base import Publisher -from . serialize import to_subgraph, to_value +from . serialize import to_value # Module logger logger = logging.getLogger(__name__) @@ -48,7 +48,6 @@ class EntityContextsImport: elt = EntityContexts( metadata=Metadata( id=data["metadata"]["id"], - metadata=to_subgraph(data["metadata"]["metadata"]), user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py index 8abf5e9c..7c7dc915 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py @@ -8,7 +8,7 @@ from ... schema import Metadata from ... schema import GraphEmbeddings, EntityEmbeddings from ... base import Publisher -from . serialize import to_subgraph, to_value +from . 
serialize import to_value # Module logger logger = logging.getLogger(__name__) @@ -48,14 +48,13 @@ class GraphEmbeddingsImport: elt = GraphEmbeddings( metadata=Metadata( id=data["metadata"]["id"], - metadata=to_subgraph(data["metadata"]["metadata"]), user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), entities=[ EntityEmbeddings( entity=to_value(ent["entity"]), - vectors=ent["vectors"], + vector=ent["vector"], ) for ent in data["entities"] ] diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index 8d1aca9e..4e465bf7 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -9,7 +9,7 @@ from aiohttp import web import logging import os -from trustgraph.base.logging import setup_logging +from trustgraph.base.logging import setup_logging, add_logging_args from trustgraph.base.pubsub import get_pubsub, add_pubsub_args from . auth import Authenticator @@ -195,12 +195,7 @@ def run(): help=f'Secret API token (default: no auth)', ) - parser.add_argument( - '-l', '--log-level', - default='INFO', - choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], - help=f'Log level (default: INFO)' - ) + add_logging_args(parser) parser.add_argument( '--metrics', diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py index 3e0b610c..46460b1f 100644 --- a/trustgraph-flow/trustgraph/metering/counter.py +++ b/trustgraph-flow/trustgraph/metering/counter.py @@ -102,10 +102,10 @@ class Processor(FlowProcessor): __class__.cost_metric.labels(model=modelname, direction="input").inc(cost_in) __class__.cost_metric.labels(model=modelname, direction="output").inc(cost_out) - logger.info(f"Model: {modelname}") - logger.info(f"Input Tokens: {num_in}") - logger.info(f"Output Tokens: {num_out}") - logger.info(f"Cost for call: ${cost_per_call}") + logger.debug( + f"Model: {modelname}, in={num_in}, out={num_out}, " + f"cost=${cost_per_call}" + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/prompt/template/service.py b/trustgraph-flow/trustgraph/prompt/template/service.py index 97298e13..c599ce77 100755 --- a/trustgraph-flow/trustgraph/prompt/template/service.py +++ b/trustgraph-flow/trustgraph/prompt/template/service.py @@ -11,7 +11,6 @@ import logging from ...schema import Definition, Relationship, Triple from ...schema import Topic from ...schema import PromptRequest, PromptResponse, Error -from ...schema import TextCompletionRequest, TextCompletionResponse from ...base import FlowProcessor from ...base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec @@ -124,35 +123,26 @@ class Processor(FlowProcessor): logger.debug(f"System prompt: {system}") logger.debug(f"User prompt: {prompt}") - # Use the text completion client with recipient handler - client = flow("text-completion-request") - async def forward_chunks(resp): - if resp.error: - raise RuntimeError(resp.error.message) - is_final = getattr(resp, 'end_of_stream', False) # Always send a message if there's content OR if it's the final message if resp.response or is_final: - # Forward each chunk immediately r = PromptResponse( text=resp.response if resp.response else "", object=None, error=None, end_of_stream=is_final, + in_token=resp.in_token, + out_token=resp.out_token, + model=resp.model, ) await flow("response").send(r, properties={"id": id}) - # Return True when end_of_stream - return is_final - - await client.request( - 
TextCompletionRequest( - system=system, prompt=prompt, streaming=True - ), - recipient=forward_chunks, - timeout=600 + await flow("text-completion-request").text_completion_stream( + system=system, prompt=prompt, + handler=forward_chunks, + timeout=600, ) # Return empty string since we already sent all chunks @@ -167,17 +157,21 @@ class Processor(FlowProcessor): return # Non-streaming path (original behavior) + usage = {} + async def llm(system, prompt): logger.debug(f"System prompt: {system}") logger.debug(f"User prompt: {prompt}") - resp = await flow("text-completion-request").text_completion( - system = system, prompt = prompt, streaming = False, - ) - try: - return resp + result = await flow("text-completion-request").text_completion( + system = system, prompt = prompt, + ) + usage["in_token"] = result.in_token + usage["out_token"] = result.out_token + usage["model"] = result.model + return result.text except Exception as e: logger.error(f"LLM Exception: {e}", exc_info=True) return None @@ -199,6 +193,9 @@ class Processor(FlowProcessor): object=None, error=None, end_of_stream=True, + in_token=usage.get("in_token", 0), + out_token=usage.get("out_token", 0), + model=usage.get("model", ""), ) await flow("response").send(r, properties={"id": id}) @@ -215,6 +212,9 @@ class Processor(FlowProcessor): object=json.dumps(resp), error=None, end_of_stream=True, + in_token=usage.get("in_token", 0), + out_token=usage.get("out_token", 0), + model=usage.get("model", ""), ) await flow("response").send(r, properties={"id": id}) diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index f928a911..019d5610 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -23,6 +23,7 @@ from .... schema import RowsQueryRequest, RowsQueryResponse, GraphQLError from .... schema import Error, RowSchema, Field as SchemaField from .... base import FlowProcessor, ConsumerSpec, ProducerSpec from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config +from .... tables.cassandra_async import async_execute from ... graphql import GraphQLSchemaBuilder, SortDirection @@ -263,7 +264,7 @@ class Processor(FlowProcessor): query += f" LIMIT {limit}" try: - rows = self.session.execute(query, params) + rows = await async_execute(self.session, query, params) for row in rows: # Convert data map to dict with proper field names row_dict = dict(row.data) if row.data else {} @@ -301,7 +302,7 @@ class Processor(FlowProcessor): params = [collection, schema_name, primary_index] try: - rows = self.session.execute(query, params) + rows = await async_execute(self.session, query, params) for row in rows: row_dict = dict(row.data) if row.data else {} diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index f1f5ba60..905aaaf2 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -4,6 +4,7 @@ Triples query service. Input is a (s, p, o, g) quad pattern, some values may be null. Output is a list of quads. """ +import asyncio import logging import json @@ -200,7 +201,11 @@ class Processor(TriplesQueryService): try: - self.ensure_connection(query.user) + # ensure_connection may construct a fresh + # EntityCentricKnowledgeGraph which does sync schema + # setup against Cassandra. 
Push it to a worker thread + # so the event loop doesn't block on first-use per user. + await asyncio.to_thread(self.ensure_connection, query.user) # Extract values from query s_val = get_term_value(query.s) @@ -218,14 +223,21 @@ class Processor(TriplesQueryService): quads = [] + # All self.tg.get_* calls below are sync wrappers around + # cassandra session.execute. Materialise inside a worker + # thread so iteration never triggers sync paging back on + # the event loop. + # Route to appropriate query method based on which fields are specified if s_val is not None: if p_val is not None: if o_val is not None: # SPO specified - find matching graphs - resp = self.tg.get_spo( - query.collection, s_val, p_val, o_val, g=g_val, - limit=query.limit + resp = await asyncio.to_thread( + lambda: list(self.tg.get_spo( + query.collection, s_val, p_val, o_val, + g=g_val, limit=query.limit, + )) ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -233,9 +245,11 @@ class Processor(TriplesQueryService): quads.append((s_val, p_val, o_val, g, term_type, datatype, language)) else: # SP specified - resp = self.tg.get_sp( - query.collection, s_val, p_val, g=g_val, - limit=query.limit + resp = await asyncio.to_thread( + lambda: list(self.tg.get_sp( + query.collection, s_val, p_val, + g=g_val, limit=query.limit, + )) ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -244,9 +258,11 @@ class Processor(TriplesQueryService): else: if o_val is not None: # SO specified - resp = self.tg.get_os( - query.collection, o_val, s_val, g=g_val, - limit=query.limit + resp = await asyncio.to_thread( + lambda: list(self.tg.get_os( + query.collection, o_val, s_val, + g=g_val, limit=query.limit, + )) ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -254,9 +270,11 @@ class Processor(TriplesQueryService): quads.append((s_val, t.p, o_val, g, term_type, datatype, language)) else: # S only - resp = self.tg.get_s( - query.collection, s_val, g=g_val, - limit=query.limit + resp = await asyncio.to_thread( + lambda: list(self.tg.get_s( + query.collection, s_val, + g=g_val, limit=query.limit, + )) ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -266,9 +284,11 @@ class Processor(TriplesQueryService): if p_val is not None: if o_val is not None: # PO specified - resp = self.tg.get_po( - query.collection, p_val, o_val, g=g_val, - limit=query.limit + resp = await asyncio.to_thread( + lambda: list(self.tg.get_po( + query.collection, p_val, o_val, + g=g_val, limit=query.limit, + )) ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -276,9 +296,11 @@ class Processor(TriplesQueryService): quads.append((t.s, p_val, o_val, g, term_type, datatype, language)) else: # P only - resp = self.tg.get_p( - query.collection, p_val, g=g_val, - limit=query.limit + resp = await asyncio.to_thread( + lambda: list(self.tg.get_p( + query.collection, p_val, + g=g_val, limit=query.limit, + )) ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -287,9 +309,11 @@ class Processor(TriplesQueryService): else: if o_val is not None: # O only - resp = self.tg.get_o( - query.collection, o_val, g=g_val, - limit=query.limit + resp = await asyncio.to_thread( + lambda: list(self.tg.get_o( + query.collection, o_val, + g=g_val, limit=query.limit, + )) ) for t in resp: g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH @@ -297,9 +321,10 @@ class Processor(TriplesQueryService): quads.append((t.s, t.p, o_val, g, term_type, datatype, language)) else: # Nothing specified - get all - resp = self.tg.get_all( 
- query.collection, - limit=query.limit + resp = await asyncio.to_thread( + lambda: list(self.tg.get_all( + query.collection, limit=query.limit, + )) ) for t in resp: # Note: quads_by_collection uses 'd' for graph field @@ -340,7 +365,7 @@ class Processor(TriplesQueryService): Uses Cassandra's paging to fetch results incrementally. """ try: - self.ensure_connection(query.user) + await asyncio.to_thread(self.ensure_connection, query.user) batch_size = query.batch_size if query.batch_size > 0 else 20 limit = query.limit if query.limit > 0 else 10000 @@ -374,9 +399,16 @@ class Processor(TriplesQueryService): yield batch, is_final return - # Create statement with fetch_size for true streaming + # Materialise in a worker thread. We lose true streaming + # paging (the driver fetches all pages eagerly inside the + # thread) but the event loop stays responsive, and result + # sets at this layer are typically small enough that this + # is acceptable. If true async paging is needed later, + # revisit using ResponseFuture page callbacks. statement = SimpleStatement(cql, fetch_size=batch_size) - result_set = self.tg.session.execute(statement, params) + result_set = await asyncio.to_thread( + lambda: list(self.tg.session.execute(statement, params)) + ) batch = [] count = 0 diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py index 730a7226..625b1386 100644 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/document_rag.py @@ -27,24 +27,27 @@ class Query: def __init__( self, rag, user, collection, verbose, - doc_limit=20 + doc_limit=20, track_usage=None, ): self.rag = rag self.user = user self.collection = collection self.verbose = verbose self.doc_limit = doc_limit + self.track_usage = track_usage async def extract_concepts(self, query): """Extract key concepts from query for independent embedding.""" - response = await self.rag.prompt_client.prompt( + result = await self.rag.prompt_client.prompt( "extract-concepts", variables={"query": query} ) + if self.track_usage: + self.track_usage(result) concepts = [] - if isinstance(response, str): - for line in response.strip().split('\n'): + if result.text: + for line in result.text.strip().split('\n'): line = line.strip() if line: concepts.append(line) @@ -53,6 +56,8 @@ class Query: if not concepts: concepts = [query] + self.concepts_usage = result + if self.verbose: logger.debug(f"Extracted concepts: {concepts}") @@ -167,8 +172,23 @@ class DocumentRag: save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian Returns: - str: The synthesized answer text + tuple: (answer_text, usage) where usage is a dict with + in_token, out_token, model """ + total_in = 0 + total_out = 0 + last_model = None + + def track_usage(result): + nonlocal total_in, total_out, last_model + if result is not None: + if result.in_token is not None: + total_in += result.in_token + if result.out_token is not None: + total_out += result.out_token + if result.model is not None: + last_model = result.model + if self.verbose: logger.debug("Constructing prompt...") @@ -191,7 +211,7 @@ class DocumentRag: q = Query( rag=self, user=user, collection=collection, verbose=self.verbose, - doc_limit=doc_limit + doc_limit=doc_limit, track_usage=track_usage, ) # Extract concepts from query (grounding step) @@ -199,8 +219,14 @@ class DocumentRag: # Emit grounding explainability after concept extraction if 
explain_callback: + cu = getattr(q, 'concepts_usage', None) gnd_triples = set_graph( - grounding_triples(gnd_uri, q_uri, concepts), + grounding_triples( + gnd_uri, q_uri, concepts, + in_token=cu.in_token if cu else None, + out_token=cu.out_token if cu else None, + model=cu.model if cu else None, + ), GRAPH_RETRIEVAL ) await explain_callback(gnd_triples, gnd_uri) @@ -228,19 +254,22 @@ class DocumentRag: accumulated_chunks.append(chunk) await chunk_callback(chunk, end_of_stream) - resp = await self.prompt_client.document_prompt( + synthesis_result = await self.prompt_client.document_prompt( query=query, documents=docs, streaming=True, chunk_callback=accumulating_callback ) + track_usage(synthesis_result) # Combine all chunks into full response resp = "".join(accumulated_chunks) else: - resp = await self.prompt_client.document_prompt( + synthesis_result = await self.prompt_client.document_prompt( query=query, documents=docs ) + track_usage(synthesis_result) + resp = synthesis_result.text if self.verbose: logger.debug("Query processing complete") @@ -265,6 +294,9 @@ class DocumentRag: docrag_synthesis_triples( syn_uri, exp_uri, document_id=synthesis_doc_id, + in_token=synthesis_result.in_token if synthesis_result else None, + out_token=synthesis_result.out_token if synthesis_result else None, + model=synthesis_result.model if synthesis_result else None, ), GRAPH_RETRIEVAL ) @@ -273,5 +305,11 @@ class DocumentRag: if self.verbose: logger.debug(f"Emitted explain for session {session_id}") - return resp + usage = { + "in_token": total_in if total_in > 0 else None, + "out_token": total_out if total_out > 0 else None, + "model": last_model, + } + + return resp, usage diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index 3b281fe3..dc7296ad 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -200,7 +200,7 @@ class Processor(FlowProcessor): # Query with streaming enabled # All chunks (including final one with end_of_stream=True) are sent via callback - await self.rag.query( + response, usage = await self.rag.query( v.query, user=v.user, collection=v.collection, @@ -217,12 +217,15 @@ class Processor(FlowProcessor): response=None, end_of_session=True, message_type="end", + in_token=usage.get("in_token"), + out_token=usage.get("out_token"), + model=usage.get("model"), ), properties={"id": id} ) else: - # Non-streaming path (existing behavior) - response = await self.rag.query( + # Non-streaming path - single response with answer and token usage + response, usage = await self.rag.query( v.query, user=v.user, collection=v.collection, @@ -233,11 +236,15 @@ class Processor(FlowProcessor): await flow("response").send( DocumentRagResponse( - response = response, - end_of_stream = True, - error = None + response=response, + end_of_stream=True, + end_of_session=True, + error=None, + in_token=usage.get("in_token"), + out_token=usage.get("out_token"), + model=usage.get("model"), ), - properties = {"id": id} + properties={"id": id} ) logger.info("Request processing complete") diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 5cf7b991..cf9f5c4e 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -121,7 +121,7 @@ class Query: def __init__( self, rag, user, collection, verbose, 
entity_limit=50, triple_limit=30, max_subgraph_size=1000, - max_path_length=2, + max_path_length=2, track_usage=None, ): self.rag = rag self.user = user @@ -131,17 +131,20 @@ class Query: self.triple_limit = triple_limit self.max_subgraph_size = max_subgraph_size self.max_path_length = max_path_length + self.track_usage = track_usage async def extract_concepts(self, query): """Extract key concepts from query for independent embedding.""" - response = await self.rag.prompt_client.prompt( + result = await self.rag.prompt_client.prompt( "extract-concepts", variables={"query": query} ) + if self.track_usage: + self.track_usage(result) concepts = [] - if isinstance(response, str): - for line in response.strip().split('\n'): + if result.text: + for line in result.text.strip().split('\n'): line = line.strip() if line: concepts.append(line) @@ -149,6 +152,8 @@ class Query: if self.verbose: logger.debug(f"Extracted concepts: {concepts}") + self.concepts_usage = result + # Fall back to raw query if extraction returns nothing return concepts if concepts else [query] @@ -609,8 +614,24 @@ class GraphRag: save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian Returns: - str: The synthesized answer text + tuple: (answer_text, usage) where usage is a dict with + in_token, out_token, model """ + # Accumulate token usage across all prompt calls + total_in = 0 + total_out = 0 + last_model = None + + def track_usage(result): + nonlocal total_in, total_out, last_model + if result is not None: + if result.in_token is not None: + total_in += result.in_token + if result.out_token is not None: + total_out += result.out_token + if result.model is not None: + last_model = result.model + if self.verbose: logger.debug("Constructing prompt...") @@ -641,14 +662,21 @@ class GraphRag: triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, max_path_length = max_path_length, + track_usage = track_usage, ) kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query) # Emit grounding explain after concept extraction if explain_callback: + cu = getattr(q, 'concepts_usage', None) gnd_triples = set_graph( - grounding_triples(gnd_uri, q_uri, concepts), + grounding_triples( + gnd_uri, q_uri, concepts, + in_token=cu.in_token if cu else None, + out_token=cu.out_token if cu else None, + model=cu.model if cu else None, + ), GRAPH_RETRIEVAL ) await explain_callback(gnd_triples, gnd_uri) @@ -751,21 +779,22 @@ class GraphRag: logger.debug(f"Built edge map with {len(edge_map)} edges") # Step 1a: Edge Scoring - LLM scores edges for relevance - scoring_response = await self.prompt_client.prompt( + scoring_result = await self.prompt_client.prompt( "kg-edge-scoring", variables={ "query": query, "knowledge": edges_with_ids } ) + track_usage(scoring_result) if self.verbose: - logger.debug(f"Edge scoring response: {scoring_response}") + logger.debug(f"Edge scoring result: {scoring_result}") - # Parse scoring response to get edge IDs with scores + # Parse scoring response (jsonl) to get edge IDs with scores scored_edges = [] - def parse_scored_edge(obj): + for obj in scoring_result.objects or []: if isinstance(obj, dict) and "id" in obj and "score" in obj: try: score = int(obj["score"]) @@ -773,21 +802,6 @@ class GraphRag: score = 0 scored_edges.append({"id": obj["id"], "score": score}) - if isinstance(scoring_response, list): - for obj in scoring_response: - parse_scored_edge(obj) - elif isinstance(scoring_response, str): - for line in scoring_response.strip().split('\n'): - line 
= line.strip() - if not line: - continue - try: - parse_scored_edge(json.loads(line)) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse edge scoring line: {line}" - ) - # Select top N edges by score scored_edges.sort(key=lambda x: x["score"], reverse=True) top_edges = scored_edges[:edge_limit] @@ -821,25 +835,30 @@ class GraphRag: ] # Run reasoning and document tracing concurrently - reasoning_task = self.prompt_client.prompt( - "kg-edge-reasoning", - variables={ - "query": query, - "knowledge": selected_edges_with_ids - } - ) + async def _get_reasoning(): + result = await self.prompt_client.prompt( + "kg-edge-reasoning", + variables={ + "query": query, + "knowledge": selected_edges_with_ids + } + ) + track_usage(result) + return result + + reasoning_task = _get_reasoning() doc_trace_task = q.trace_source_documents(selected_edge_uris) - reasoning_response, source_documents = await asyncio.gather( + reasoning_result, source_documents = await asyncio.gather( reasoning_task, doc_trace_task, return_exceptions=True ) # Handle exceptions from gather - if isinstance(reasoning_response, Exception): + if isinstance(reasoning_result, Exception): logger.warning( - f"Edge reasoning failed: {reasoning_response}" + f"Edge reasoning failed: {reasoning_result}" ) - reasoning_response = "" + reasoning_result = None if isinstance(source_documents, Exception): logger.warning( f"Document tracing failed: {source_documents}" @@ -848,29 +867,15 @@ class GraphRag: if self.verbose: - logger.debug(f"Edge reasoning response: {reasoning_response}") + logger.debug(f"Edge reasoning result: {reasoning_result}") - # Parse reasoning response and build explainability data + # Parse reasoning response (jsonl) and build explainability data reasoning_map = {} - def parse_reasoning(obj): - if isinstance(obj, dict) and "id" in obj: - reasoning_map[obj["id"]] = obj.get("reasoning", "") - - if isinstance(reasoning_response, list): - for obj in reasoning_response: - parse_reasoning(obj) - elif isinstance(reasoning_response, str): - for line in reasoning_response.strip().split('\n'): - line = line.strip() - if not line: - continue - try: - parse_reasoning(json.loads(line)) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse edge reasoning line: {line}" - ) + if reasoning_result is not None: + for obj in reasoning_result.objects or []: + if isinstance(obj, dict) and "id" in obj: + reasoning_map[obj["id"]] = obj.get("reasoning", "") selected_edges_with_reasoning = [] for eid in selected_ids: @@ -886,9 +891,25 @@ class GraphRag: # Emit focus explain after edge selection completes if explain_callback: + # Sum scoring + reasoning token usage for focus event + focus_in = 0 + focus_out = 0 + focus_model = None + for r in [scoring_result, reasoning_result]: + if r is not None: + if r.in_token is not None: + focus_in += r.in_token + if r.out_token is not None: + focus_out += r.out_token + if r.model is not None: + focus_model = r.model + foc_triples = set_graph( focus_triples( - foc_uri, exp_uri, selected_edges_with_reasoning, session_id + foc_uri, exp_uri, selected_edges_with_reasoning, session_id, + in_token=focus_in or None, + out_token=focus_out or None, + model=focus_model, ), GRAPH_RETRIEVAL ) @@ -919,19 +940,22 @@ class GraphRag: accumulated_chunks.append(chunk) await chunk_callback(chunk, end_of_stream) - await self.prompt_client.prompt( + synthesis_result = await self.prompt_client.prompt( "kg-synthesis", variables=synthesis_variables, streaming=True, chunk_callback=accumulating_callback ) + 
track_usage(synthesis_result) # Combine all chunks into full response resp = "".join(accumulated_chunks) else: - resp = await self.prompt_client.prompt( + synthesis_result = await self.prompt_client.prompt( "kg-synthesis", variables=synthesis_variables, ) + track_usage(synthesis_result) + resp = synthesis_result.text if self.verbose: logger.debug("Query processing complete") @@ -956,6 +980,9 @@ class GraphRag: synthesis_triples( syn_uri, foc_uri, document_id=synthesis_doc_id, + in_token=synthesis_result.in_token if synthesis_result else None, + out_token=synthesis_result.out_token if synthesis_result else None, + model=synthesis_result.model if synthesis_result else None, ), GRAPH_RETRIEVAL ) @@ -964,5 +991,11 @@ class GraphRag: if self.verbose: logger.debug(f"Emitted explain for session {session_id}") - return resp + usage = { + "in_token": total_in if total_in > 0 else None, + "out_token": total_out if total_out > 0 else None, + "model": last_model, + } + + return resp, usage diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index abf10e90..15c30ba1 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -332,7 +332,7 @@ class Processor(FlowProcessor): ) # Query with streaming and real-time explain - response = await rag.query( + response, usage = await rag.query( query = v.query, user = v.user, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, @@ -348,7 +348,7 @@ class Processor(FlowProcessor): else: # Non-streaming path with real-time explain - response = await rag.query( + response, usage = await rag.query( query = v.query, user = v.user, collection = v.collection, entity_limit = entity_limit, triple_limit = triple_limit, max_subgraph_size = max_subgraph_size, @@ -360,23 +360,30 @@ class Processor(FlowProcessor): parent_uri = v.parent_uri, ) - # Send chunk with response + # Send single response with answer and token usage await flow("response").send( GraphRagResponse( message_type="chunk", response=response, end_of_stream=True, - error=None, + end_of_session=True, + in_token=usage.get("in_token"), + out_token=usage.get("out_token"), + model=usage.get("model"), ), properties={"id": id} ) + return - # Send final message to close session + # Streaming: send final message to close session with token usage await flow("response").send( GraphRagResponse( message_type="chunk", response="", end_of_session=True, + in_token=usage.get("in_token"), + out_token=usage.get("out_token"), + model=usage.get("model"), ), properties={"id": id} ) diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index 673cba4d..d0eec2e1 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -13,6 +13,7 @@ Uses a single 'rows' table with the schema: Each row is written multiple times - once per indexed field defined in the schema. """ +import asyncio import json import logging import re @@ -26,6 +27,7 @@ from .... schema import RowSchema, Field from .... base import FlowProcessor, ConsumerSpec from .... base import CollectionConfigHandler from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config +from .... 
tables.cassandra_async import async_execute # Module logger logger = logging.getLogger(__name__) @@ -361,11 +363,15 @@ class Processor(CollectionConfigHandler, FlowProcessor): schema_name = obj.schema_name source = getattr(obj.metadata, 'source', '') or '' - # Ensure tables exist - self.ensure_tables(keyspace) + # Ensure tables exist (sync DDL — push to a worker thread + # so the event loop stays responsive when running in a + # processor group sharing the loop with siblings). + await asyncio.to_thread(self.ensure_tables, keyspace) # Register partitions if first time seeing this (collection, schema_name) - self.register_partitions(keyspace, collection, schema_name) + await asyncio.to_thread( + self.register_partitions, keyspace, collection, schema_name + ) safe_keyspace = self.sanitize_name(keyspace) @@ -406,9 +412,10 @@ class Processor(CollectionConfigHandler, FlowProcessor): continue try: - self.session.execute( + await async_execute( + self.session, insert_cql, - (collection, schema_name, index_name, index_value, data_map, source) + (collection, schema_name, index_name, index_value, data_map, source), ) rows_written += 1 except Exception as e: @@ -425,18 +432,18 @@ class Processor(CollectionConfigHandler, FlowProcessor): async def create_collection(self, user: str, collection: str, metadata: dict): """Create/verify collection exists in Cassandra row store""" - # Connect if not already connected - self.connect_cassandra() + # Connect if not already connected (sync, push to thread) + await asyncio.to_thread(self.connect_cassandra) - # Ensure tables exist - self.ensure_tables(user) + # Ensure tables exist (sync DDL, push to thread) + await asyncio.to_thread(self.ensure_tables, user) logger.info(f"Collection {collection} ready for user {user}") async def delete_collection(self, user: str, collection: str): """Delete all data for a specific collection using partition tracking""" # Connect if not already connected - self.connect_cassandra() + await asyncio.to_thread(self.connect_cassandra) safe_keyspace = self.sanitize_name(user) @@ -446,8 +453,10 @@ class Processor(CollectionConfigHandler, FlowProcessor): SELECT keyspace_name FROM system_schema.keyspaces WHERE keyspace_name = %s """ - result = self.session.execute(check_keyspace_cql, (safe_keyspace,)) - if not result.one(): + result = await async_execute( + self.session, check_keyspace_cql, (safe_keyspace,) + ) + if not result: logger.info(f"Keyspace {safe_keyspace} does not exist, nothing to delete") return self.known_keyspaces.add(user) @@ -459,8 +468,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): """ try: - partitions = self.session.execute(select_partitions_cql, (collection,)) - partition_list = list(partitions) + partition_list = await async_execute( + self.session, select_partitions_cql, (collection,) + ) except Exception as e: logger.error(f"Failed to query partitions for collection {collection}: {e}") raise @@ -474,9 +484,10 @@ class Processor(CollectionConfigHandler, FlowProcessor): partitions_deleted = 0 for partition in partition_list: try: - self.session.execute( + await async_execute( + self.session, delete_rows_cql, - (collection, partition.schema_name, partition.index_name) + (collection, partition.schema_name, partition.index_name), ) partitions_deleted += 1 except Exception as e: @@ -493,7 +504,9 @@ class Processor(CollectionConfigHandler, FlowProcessor): """ try: - self.session.execute(delete_partitions_cql, (collection,)) + await async_execute( + self.session, delete_partitions_cql, (collection,) + ) except 
@@ -493,7 +504,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
         """

         try:
-            self.session.execute(delete_partitions_cql, (collection,))
+            await async_execute(
+                self.session, delete_partitions_cql, (collection,)
+            )
         except Exception as e:
             logger.error(f"Failed to clean up row_partitions for {collection}: {e}")
             raise

@@ -512,7 +525,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
     async def delete_collection_schema(self, user: str, collection: str, schema_name: str):
         """Delete all data for a specific collection + schema combination"""
         # Connect if not already connected
-        self.connect_cassandra()
+        await asyncio.to_thread(self.connect_cassandra)

         safe_keyspace = self.sanitize_name(user)

@@ -523,8 +536,9 @@ class Processor(CollectionConfigHandler, FlowProcessor):
         """

         try:
-            partitions = self.session.execute(select_partitions_cql, (collection, schema_name))
-            partition_list = list(partitions)
+            partition_list = await async_execute(
+                self.session, select_partitions_cql, (collection, schema_name)
+            )
         except Exception as e:
             logger.error(
                 f"Failed to query partitions for {collection}/{schema_name}: {e}"

@@ -540,9 +554,10 @@ class Processor(CollectionConfigHandler, FlowProcessor):
         partitions_deleted = 0
         for partition in partition_list:
             try:
-                self.session.execute(
+                await async_execute(
+                    self.session,
                     delete_rows_cql,
-                    (collection, schema_name, partition.index_name)
+                    (collection, schema_name, partition.index_name),
                 )
                 partitions_deleted += 1
             except Exception as e:

@@ -559,7 +574,11 @@ class Processor(CollectionConfigHandler, FlowProcessor):
         """

         try:
-            self.session.execute(delete_partitions_cql, (collection, schema_name))
+            await async_execute(
+                self.session,
+                delete_partitions_cql,
+                (collection, schema_name),
+            )
         except Exception as e:
             logger.error(
                 f"Failed to clean up row_partitions for {collection}/{schema_name}: {e}"
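Both deletion paths are driven entirely by the `row_partitions` tracking table: enumerate the partitions recorded for the collection, delete the data partition by partition, then drop the tracking rows. A compressed sketch of that flow, using `async_execute` end to end; the CQL strings and the data-table name here are assumed stand-ins, the real statements are built inside the processor:

```python
from trustgraph.tables.cassandra_async import async_execute

async def delete_collection(session, collection):
    # 1. Which (schema_name, index_name) partitions hold data?
    parts = await async_execute(
        session,
        "SELECT schema_name, index_name FROM row_partitions"
        " WHERE collection = %s",
        (collection,),
    )
    # 2. Delete the data one partition at a time.
    for p in parts:
        await async_execute(
            session,
            "DELETE FROM objects WHERE collection = %s"    # table name assumed
            " AND schema_name = %s AND index_name = %s",
            (collection, p.schema_name, p.index_name),
        )
    # 3. Finally drop the tracking rows themselves.
    await async_execute(
        session,
        "DELETE FROM row_partitions WHERE collection = %s",
        (collection,),
    )
```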
diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py
index 2a240f0b..01d95c8b 100755
--- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py
+++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py
@@ -3,6 +3,7 @@
 Graph writer. Input is graph edge. Writes edges to Cassandra graph.
 """

+import asyncio
 import base64
 import os
 import argparse
@@ -150,59 +151,71 @@ class Processor(CollectionConfigHandler, TriplesStoreService):

         user = message.metadata.user

-        if self.table is None or self.table != user:
+        # The cassandra-driver work below — connection, schema
+        # setup, and per-triple inserts — is all synchronous.
+        # Wrap the whole batch in a worker thread so the event
+        # loop stays responsive for sibling processors when
+        # running in a processor group.
+        def _do_store():

-            self.tg = None
+            if self.table is None or self.table != user:

-            # Use factory function to select implementation
-            KGClass = EntityCentricKnowledgeGraph
+                self.tg = None

-            try:
-                if self.cassandra_username and self.cassandra_password:
-                    self.tg = KGClass(
-                        hosts=self.cassandra_host,
-                        keyspace=message.metadata.user,
-                        username=self.cassandra_username, password=self.cassandra_password
-                    )
-                else:
-                    self.tg = KGClass(
-                        hosts=self.cassandra_host,
-                        keyspace=message.metadata.user,
-                    )
-            except Exception as e:
-                logger.error(f"Exception: {e}", exc_info=True)
-                time.sleep(1)
-                raise e
+                # Use factory function to select implementation
+                KGClass = EntityCentricKnowledgeGraph

-            self.table = user
+                try:
+                    if self.cassandra_username and self.cassandra_password:
+                        self.tg = KGClass(
+                            hosts=self.cassandra_host,
+                            keyspace=message.metadata.user,
+                            username=self.cassandra_username,
+                            password=self.cassandra_password,
+                        )
+                    else:
+                        self.tg = KGClass(
+                            hosts=self.cassandra_host,
+                            keyspace=message.metadata.user,
+                        )
+                except Exception as e:
+                    logger.error(f"Exception: {e}", exc_info=True)
+                    time.sleep(1)
+                    raise e

-        for t in message.triples:
-            # Extract values from Term objects
-            s_val = get_term_value(t.s)
-            p_val = get_term_value(t.p)
-            o_val = get_term_value(t.o)
-            # t.g is None for default graph, or a graph IRI
-            g_val = t.g if t.g is not None else DEFAULT_GRAPH
+                self.table = user

-            # Extract object type metadata for entity-centric storage
-            otype = get_term_otype(t.o)
-            dtype = get_term_dtype(t.o)
-            lang = get_term_lang(t.o)
+            for t in message.triples:
+                # Extract values from Term objects
+                s_val = get_term_value(t.s)
+                p_val = get_term_value(t.p)
+                o_val = get_term_value(t.o)
+                # t.g is None for default graph, or a graph IRI
+                g_val = t.g if t.g is not None else DEFAULT_GRAPH

-            self.tg.insert(
-                message.metadata.collection,
-                s_val,
-                p_val,
-                o_val,
-                g=g_val,
-                otype=otype,
-                dtype=dtype,
-                lang=lang
-            )
+                # Extract object type metadata for entity-centric storage
+                otype = get_term_otype(t.o)
+                dtype = get_term_dtype(t.o)
+                lang = get_term_lang(t.o)
+
+                self.tg.insert(
+                    message.metadata.collection,
+                    s_val,
+                    p_val,
+                    o_val,
+                    g=g_val,
+                    otype=otype,
+                    dtype=dtype,
+                    lang=lang,
+                )
+
+        await asyncio.to_thread(_do_store)

     async def create_collection(self, user: str, collection: str, metadata: dict):
         """Create a collection in Cassandra triple store via config push"""
-        try:
+
+        def _do_create():
             # Create or reuse connection for this user's keyspace
             if self.table is None or self.table != user:
                 self.tg = None
@@ -216,7 +229,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
                         hosts=self.cassandra_host,
                         keyspace=user,
                         username=self.cassandra_username,
-                        password=self.cassandra_password
+                        password=self.cassandra_password,
                     )
                 else:
                     self.tg = KGClass(
@@ -238,13 +251,16 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
             self.tg.create_collection(collection)
             logger.info(f"Created collection {collection}")

+        try:
+            await asyncio.to_thread(_do_create)
         except Exception as e:
             logger.error(f"Failed to create collection {user}/{collection}: {e}", exc_info=True)
             raise

     async def delete_collection(self, user: str, collection: str):
         """Delete all data for a specific collection from the unified triples table"""
-        try:
+
+        def _do_delete():
             # Create or reuse connection for this user's keyspace
             if self.table is None or self.table != user:
                 self.tg = None
@@ -258,7 +274,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
                         hosts=self.cassandra_host,
                         keyspace=user,
                         username=self.cassandra_username,
-                        password=self.cassandra_password
+                        password=self.cassandra_password,
                     )
                 else:
                     self.tg = KGClass(
@@ -275,6 +291,8 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
             self.tg.delete_collection(collection)
             logger.info(f"Deleted all triples for collection {collection} from keyspace {user}")

+        try:
+            await asyncio.to_thread(_do_delete)
         except Exception as e:
             logger.error(f"Failed to delete collection {user}/{collection}: {e}", exc_info=True)
             raise
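The triples writer pushes a whole synchronous batch into one worker thread per message, rather than hopping threads for every insert. A stripped-down sketch of that closure-plus-`to_thread` shape (hypothetical `Writer`/`store` names, not the processor's real API):

```python
import asyncio

class Writer:
    def __init__(self, store):
        self.store = store            # sync client, e.g. cassandra-driver based

    async def on_message(self, message):

        def _do_store():
            # Everything here runs off the event loop; closing over
            # `message` and `self` is fine because the coroutine
            # awaits completion before touching them again.
            for triple in message.triples:
                self.store.insert(triple)

        # Exceptions raised inside the thread propagate out of
        # this await, so the caller's error handling still works.
        await asyncio.to_thread(_do_store)
```

Keeping the connection setup inside the closure also means the lazy reconnect-on-user-change logic runs off the loop, where its `time.sleep(1)` back-off is harmless.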
diff --git a/trustgraph-flow/trustgraph/tables/cassandra_async.py b/trustgraph-flow/trustgraph/tables/cassandra_async.py
new file mode 100644
index 00000000..2f497748
--- /dev/null
+++ b/trustgraph-flow/trustgraph/tables/cassandra_async.py
@@ -0,0 +1,78 @@
+"""
+Async wrapper for cassandra-driver sessions.
+
+The cassandra driver exposes a callback-based async API via
+session.execute_async, returning a ResponseFuture that fires
+on_result / on_error from the driver's own worker thread.
+This module bridges that into an awaitable interface.
+
+Usage:
+    from ..tables.cassandra_async import async_execute
+
+    rows = await async_execute(self.cassandra, stmt, (param1, param2))
+    for row in rows:
+        ...
+
+Notes:
+  - Rows are materialised into a list inside the driver callback
+    thread before the future is resolved, so subsequent iteration
+    in the caller never triggers a sync page-fetch on the asyncio
+    loop.  This is safe for single-page results (the common case
+    in this codebase); if a query needs pagination, handle it
+    explicitly.
+  - Callbacks fire on a driver worker thread; call_soon_threadsafe
+    is used to hand the result back to the asyncio loop.
+  - Errors from the driver are re-raised in the awaiting coroutine.
+"""
+
+import asyncio
+
+
+async def async_execute(session, query, parameters=None):
+    """Execute a CQL statement asynchronously.
+
+    Args:
+        session: cassandra.cluster.Session (self.cassandra)
+        query: statement string or PreparedStatement
+        parameters: tuple/list of bind params, or None
+
+    Returns:
+        A list of rows (materialised from the first result page).
+    """
+
+    loop = asyncio.get_running_loop()
+    fut = loop.create_future()
+
+    def on_result(rows):
+        # Materialise on the driver thread so the loop thread
+        # never touches a lazy iterator that might trigger
+        # further sync I/O.
+        try:
+            materialised = list(rows) if rows is not None else []
+        except Exception as e:
+            loop.call_soon_threadsafe(
+                _set_exception_if_pending, fut, e
+            )
+            return
+        loop.call_soon_threadsafe(
+            _set_result_if_pending, fut, materialised
+        )
+
+    def on_error(exc):
+        loop.call_soon_threadsafe(
+            _set_exception_if_pending, fut, exc
+        )
+
+    rf = session.execute_async(query, parameters)
+    rf.add_callbacks(on_result, on_error)
+    return await fut
+
+
+def _set_result_if_pending(fut, result):
+    if not fut.done():
+        fut.set_result(result)
+
+
+def _set_exception_if_pending(fut, exc):
+    if not fut.done():
+        fut.set_exception(exc)
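Because the bridge returns a plain awaitable, independent statements can be overlapped on one session. A small usage sketch; the keyspace and table names are assumed, and note that `Cluster.connect` itself still blocks, so in a real processor it would also go through `asyncio.to_thread`:

```python
import asyncio
from cassandra.cluster import Cluster
from trustgraph.tables.cassandra_async import async_execute

async def main():
    session = Cluster(["localhost"]).connect("example_ks")  # assumed keyspace
    # Two statements in flight at once; each future resolves on
    # the loop when the driver's callback fires.
    users, docs = await asyncio.gather(
        async_execute(session, "SELECT * FROM users"),
        async_execute(session, "SELECT * FROM documents"),
    )
    print(len(users), len(docs))

asyncio.run(main())
```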
diff --git a/trustgraph-flow/trustgraph/tables/config.py b/trustgraph-flow/trustgraph/tables/config.py
index fb9ea0a7..d9a8711b 100644
--- a/trustgraph-flow/trustgraph/tables/config.py
+++ b/trustgraph-flow/trustgraph/tables/config.py
@@ -11,6 +11,8 @@
 import time
 import asyncio
 import logging

+from . cassandra_async import async_execute
+
 logger = logging.getLogger(__name__)

 class ConfigTableStore:
@@ -102,21 +104,20 @@ class ConfigTableStore:

     async def inc_version(self):

-        self.cassandra.execute("""
+        await async_execute(self.cassandra, """
             UPDATE version set version = version + 1
             WHERE id = 'version'
         """)

     async def get_version(self):

-        resp = self.cassandra.execute("""
+        rows = await async_execute(self.cassandra, """
             SELECT version FROM version
             WHERE id = 'version'
         """)

-        row = resp.one()
-
-        if row: return row[0]
+        if rows:
+            return rows[0][0]

         return None

@@ -153,150 +154,91 @@ class ConfigTableStore:
         """)

     async def put_config(self, cls, key, value):
-
-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.put_config_stmt,
-                    ( cls, key, value )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.put_config_stmt,
+                (cls, key, value),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

     async def get_value(self, cls, key):
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_value_stmt,
+                (cls, key),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.get_value_stmt,
-                    ( cls, key )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-        for row in resp:
+        for row in rows:
             return row[0]
-
         return None

     async def get_values(self, cls):
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_values_stmt,
+                (cls,),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.get_values_stmt,
-                    ( cls, )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-        return [
-            [row[0], row[1]]
-            for row in resp
-        ]
+        return [[row[0], row[1]] for row in rows]

     async def get_classes(self):
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_classes_stmt,
+                (),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.get_classes_stmt,
-                    ()
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-        return [
-            row[0] for row in resp
-        ]
+        return [row[0] for row in rows]

     async def get_all(self):
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_all_stmt,
+                (),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.get_all_stmt,
-                    ()
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-        return [
-            (row[0], row[1], row[2])
-            for row in resp
-        ]
+        return [(row[0], row[1], row[2]) for row in rows]

     async def get_keys(self, cls):
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_keys_stmt,
+                (cls,),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.get_keys_stmt,
-                    ( cls, )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-        return [
-            row[0] for row in resp
-        ]
+        return [row[0] for row in rows]

     async def delete_key(self, cls, key):
-
-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.delete_key_stmt,
-                    (cls, key)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.delete_key_stmt,
+                (cls, key),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise
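A detail worth noting about the removed `while True` blocks throughout this file: they never actually retried, because the `except` clause re-raises, so each loop body ran at most once and the `break` was the only exit. The replacement states that single attempt directly. For reference, the two shapes are equivalent (`execute` is a stand-in for the driver call):

```python
import logging

def execute():
    return []            # stand-in for a cassandra-driver call

# Before: a loop that can only ever run once -- the except re-raises,
# so there is no path back to the top of the loop.
while True:
    try:
        resp = execute()
        break
    except Exception:
        logging.error("Exception occurred", exc_info=True)
        raise

# After: the same control flow, stated directly.
try:
    resp = execute()
except Exception:
    logging.error("Exception occurred", exc_info=True)
    raise
```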
diff --git a/trustgraph-flow/trustgraph/tables/knowledge.py b/trustgraph-flow/trustgraph/tables/knowledge.py
index 430dc3c9..b06f4862 100644
--- a/trustgraph-flow/trustgraph/tables/knowledge.py
+++ b/trustgraph-flow/trustgraph/tables/knowledge.py
@@ -4,6 +4,8 @@
 from .. schema import Metadata, Term, IRI, LITERAL, GraphEmbeddings
 from cassandra.cluster import Cluster

+from . cassandra_async import async_execute
+
 def term_to_tuple(term):
     """Convert Term to (value, is_uri) tuple for database storage."""

@@ -225,25 +227,19 @@
             for v in m.triples
         ]

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.insert_triples_stmt,
-                    (
-                        uuid.uuid4(), m.metadata.user,
-                        m.metadata.root or m.metadata.id, when,
-                        [], triples,
-                    )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.insert_triples_stmt,
+                (
+                    uuid.uuid4(), m.metadata.user,
+                    m.metadata.root or m.metadata.id, when,
+                    [], triples,
+                ),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

     async def add_graph_embeddings(self, m):

@@ -257,25 +253,19 @@
             for v in m.entities
         ]

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.insert_graph_embeddings_stmt,
-                    (
-                        uuid.uuid4(), m.metadata.user,
-                        m.metadata.root or m.metadata.id, when,
-                        [], entities,
-                    )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.insert_graph_embeddings_stmt,
+                (
+                    uuid.uuid4(), m.metadata.user,
+                    m.metadata.root or m.metadata.id, when,
+                    [], entities,
+                ),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

     async def add_document_embeddings(self, m):

@@ -289,50 +279,35 @@
             for v in m.chunks
         ]

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.insert_document_embeddings_stmt,
-                    (
-                        uuid.uuid4(), m.metadata.user,
-                        m.metadata.root or m.metadata.id, when,
-                        [], chunks,
-                    )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.insert_document_embeddings_stmt,
+                (
+                    uuid.uuid4(), m.metadata.user,
+                    m.metadata.root or m.metadata.id, when,
+                    [], chunks,
+                ),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

     async def list_kg_cores(self, user):

         logger.debug("List kg cores...")

-        while True:
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.list_cores_stmt,
+                (user,),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-            try:
-
-                resp = self.cassandra.execute(
-                    self.list_cores_stmt,
-                    (user,)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-
-        lst = [
-            row[1]
-            for row in resp
-        ]
+        lst = [row[1] for row in rows]

         logger.debug("Done")

@@ -342,56 +317,41 @@

         logger.debug("Delete kg cores...")

-        while True:
+        try:
+            await async_execute(
+                self.cassandra,
+                self.delete_triples_stmt,
+                (user, document_id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-            try:
-
-                resp = self.cassandra.execute(
-                    self.delete_triples_stmt,
-                    (user, document_id)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.delete_graph_embeddings_stmt,
-                    (user, document_id)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.delete_graph_embeddings_stmt,
+                (user, document_id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

     async def get_triples(self, user, document_id, receiver):

         logger.debug("Get triples...")

-        while True:
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_triples_stmt,
+                (user, document_id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-            try:
-
-                resp = self.cassandra.execute(
-                    self.get_triples_stmt,
-                    (user, document_id)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-        for row in resp:
+        for row in rows:

             if row[3]:

                 triples = [
@@ -422,28 +382,23 @@

         logger.debug("Get GE...")

-        while True:
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_graph_embeddings_stmt,
+                (user, document_id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-            try:
-
-                resp = self.cassandra.execute(
-                    self.get_graph_embeddings_stmt,
-                    (user, document_id)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-        for row in resp:
+        for row in rows:

             if row[3]:

                 entities = [
                     EntityEmbeddings(
                         entity = tuple_to_term(ent[0][0], ent[0][1]),
                         vectors = ent[1]
                     )
                     for ent in row[3]
                 ]
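`get_triples` and `get_graph_embeddings` share one shape: fetch the (single-page) row list with `async_execute`, then rebuild schema objects and hand them to an async receiver. Schematically, with simplified field handling and a hypothetical receiver callable:

```python
from trustgraph.tables.cassandra_async import async_execute

async def get_triples(store, user, document_id, receiver):
    rows = await async_execute(
        store.cassandra, store.get_triples_stmt, (user, document_id)
    )
    for row in rows:
        if row[3]:                  # the triples column may be empty
            await receiver(row[3])  # hand each batch to the consumer
```

Because the rows are already materialised into a list before the await resolves, the `for` loop here never does hidden synchronous I/O on the event loop.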
diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py
index 11dd9022..c85ae72a 100644
--- a/trustgraph-flow/trustgraph/tables/library.py
+++ b/trustgraph-flow/trustgraph/tables/library.py
@@ -31,6 +31,8 @@
 import time
 import asyncio
 import logging

+from . cassandra_async import async_execute
+
 logger = logging.getLogger(__name__)

 class LibraryTableStore:
@@ -321,18 +323,13 @@

     async def document_exists(self, user, id):

-        resp = self.cassandra.execute(
+        rows = await async_execute(
+            self.cassandra,
             self.test_document_exists_stmt,
-            ( user, id )
+            (user, id),
         )

-        # If a row exists, document exists.  It's a cursor, can't just
-        # count the length
-
-        for row in resp:
-            return True
-
-        return False
+        return bool(rows)

     async def add_document(self, document, object_id):
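The old comment about the cursor no longer applies: `async_execute` hands back a materialised list, so existence checks collapse to a truthiness test. The same idiom in isolation (the statement text is assumed here; the real prepared statement lives in the store):

```python
from trustgraph.tables.cassandra_async import async_execute

async def document_exists(session, stmt, user, doc_id):
    # stmt would typically be a prepared SELECT ... LIMIT 1
    rows = await async_execute(session, stmt, (user, doc_id))
    return bool(rows)    # empty list -> False, any row -> True
```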
@@ -349,26 +346,20 @@
         parent_id = getattr(document, 'parent_id', '') or ''
         document_type = getattr(document, 'document_type', 'source') or 'source'

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.insert_document_stmt,
-                    (
-                        document.id, document.user, int(document.time * 1000),
-                        document.kind, document.title, document.comments,
-                        metadata, document.tags, object_id,
-                        parent_id, document_type
-                    )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.insert_document_stmt,
+                (
+                    document.id, document.user, int(document.time * 1000),
+                    document.kind, document.title, document.comments,
+                    metadata, document.tags, object_id,
+                    parent_id, document_type
+                ),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         logger.debug("Add complete")

@@ -383,25 +374,19 @@
             for v in document.metadata
         ]

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.update_document_stmt,
-                    (
-                        int(document.time * 1000), document.title,
-                        document.comments, metadata, document.tags,
-                        document.user, document.id
-                    )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.update_document_stmt,
+                (
+                    int(document.time * 1000), document.title,
+                    document.comments, metadata, document.tags,
+                    document.user, document.id
+                ),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         logger.debug("Update complete")

@@ -409,23 +394,15 @@

         logger.info(f"Removing document {document_id}")

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.delete_document_stmt,
-                    (
-                        user, document_id
-                    )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.delete_document_stmt,
+                (user, document_id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         logger.debug("Delete complete")

@@ -433,21 +410,15 @@

         logger.debug("List documents...")

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.list_document_stmt,
-                    (user,)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.list_document_stmt,
+                (user,),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         lst = [
             DocumentMetadata(
@@ -469,7 +440,7 @@
                 parent_id = row[8] if row[8] else "",
                 document_type = row[9] if row[9] else "source",
             )
-            for row in resp
+            for row in rows
         ]

         logger.debug("Done")

@@ -481,20 +452,15 @@

         logger.debug(f"List children for parent {parent_id}")

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.list_children_stmt,
-                    (parent_id,)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.list_children_stmt,
+                (parent_id,),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         lst = [
             DocumentMetadata(
@@ -516,7 +482,7 @@ class LibraryTableStore:
                 parent_id = row[9] if row[9] else "",
                 document_type = row[10] if row[10] else "source",
             )
-            for row in resp
+            for row in rows
         ]

         logger.debug("Done")

@@ -527,23 +493,17 @@

         logger.debug("Get document")

-        while True:
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_document_stmt,
+                (user, id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-            try:
-
-                resp = self.cassandra.execute(
-                    self.get_document_stmt,
-                    (user, id)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-
-        for row in resp:
+        for row in rows:
             doc = DocumentMetadata(
                 id = id,
                 user = user,
@@ -573,23 +533,17 @@

         logger.debug("Get document obj ID")

-        while True:
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_document_stmt,
+                (user, id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-            try:
-
-                resp = self.cassandra.execute(
-                    self.get_document_stmt,
-                    (user, id)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
-
-        for row in resp:
+        for row in rows:

             logger.debug("Done")
             return row[6]
@@ -597,43 +551,32 @@

     async def processing_exists(self, user, id):

-        resp = self.cassandra.execute(
+        rows = await async_execute(
+            self.cassandra,
             self.test_processing_exists_stmt,
-            ( user, id )
+            (user, id),
         )

-        # If a row exists, document exists.  It's a cursor, can't just
-        # count the length
-
-        for row in resp:
-            return True
-
-        return False
+        return bool(rows)

     async def add_processing(self, processing):

         logger.info(f"Adding processing {processing.id}")

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.insert_processing_stmt,
-                    (
-                        processing.id, processing.document_id,
-                        int(processing.time * 1000), processing.flow,
-                        processing.user, processing.collection,
-                        processing.tags
-                    )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.insert_processing_stmt,
+                (
+                    processing.id, processing.document_id,
+                    int(processing.time * 1000), processing.flow,
+                    processing.user, processing.collection,
+                    processing.tags
+                ),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         logger.debug("Add complete")

@@ -641,23 +584,15 @@

         logger.info(f"Removing processing {processing_id}")

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.delete_processing_stmt,
-                    (
-                        user, processing_id
-                    )
-                )
-
-                break
-
-            except Exception as e:
-
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.delete_processing_stmt,
+                (user, processing_id),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         logger.debug("Delete complete")

@@ -665,21 +600,15 @@

         logger.debug("List processing objects")

-        while True:
-
-            try:
-
-                resp = self.cassandra.execute(
-                    self.list_processing_stmt,
-                    (user,)
-                )
-
-                break
-
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
-
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.list_processing_stmt,
+                (user,),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         lst = [
             ProcessingMetadata(
@@ -691,7 +620,7 @@
                 id = row[0],
                 document_id = row[1],
                 time = row[2] / 1000,
                 flow = row[3],
                 collection = row[4],
                 tags = row[5] if row[5] else [],
             )
-            for row in resp
+            for row in rows
         ]

         logger.debug("Done")

@@ -718,20 +647,19 @@

         now = int(time.time() * 1000)

-        while True:
-            try:
-                self.cassandra.execute(
-                    self.insert_upload_session_stmt,
-                    (
-                        upload_id, user, document_id, document_metadata,
-                        s3_upload_id, object_id, total_size, chunk_size,
-                        total_chunks, {}, now, now
-                    )
-                )
-                break
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.insert_upload_session_stmt,
+                (
+                    upload_id, user, document_id, document_metadata,
+                    s3_upload_id, object_id, total_size, chunk_size,
+                    total_chunks, {}, now, now
+                ),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         logger.debug("Upload session created")

@@ -740,18 +668,17 @@

         logger.debug(f"Get upload session {upload_id}")

-        while True:
-            try:
-                resp = self.cassandra.execute(
-                    self.get_upload_session_stmt,
-                    (upload_id,)
-                )
-                break
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.get_upload_session_stmt,
+                (upload_id,),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

-        for row in resp:
+        for row in rows:
             session = {
                 "upload_id": row[0],
                 "user": row[1],
@@ -778,20 +705,19 @@

         now = int(time.time() * 1000)

-        while True:
-            try:
-                self.cassandra.execute(
-                    self.update_upload_session_chunk_stmt,
-                    (
-                        {chunk_index: etag},
-                        now,
-                        upload_id
-                    )
-                )
-                break
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.update_upload_session_chunk_stmt,
+                (
+                    {chunk_index: etag},
+                    now,
+                    upload_id
+                ),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         logger.debug("Chunk recorded")
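The chunk-recording call passes a single-entry dict as its first bind parameter. With a Cassandra map column, the natural statement shape for that (assumed here, since the prepared CQL text is not shown in this diff) merges the entry into the existing map rather than replacing the whole column:

```python
from trustgraph.tables.cassandra_async import async_execute

# Assumed shape for update_upload_session_chunk_stmt; the
# 'chunks_received = chunks_received + ?' form merges the bound
# map into the column instead of overwriting it.
# stmt = session.prepare("""
#     UPDATE upload_sessions
#     SET chunks_received = chunks_received + ?, updated = ?
#     WHERE upload_id = ?
# """)

async def record_chunk(session, stmt, upload_id, chunk_index, etag, now):
    await async_execute(session, stmt, ({chunk_index: etag}, now, upload_id))
```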
@@ -800,16 +726,15 @@

         logger.info(f"Deleting upload session {upload_id}")

-        while True:
-            try:
-                self.cassandra.execute(
-                    self.delete_upload_session_stmt,
-                    (upload_id,)
-                )
-                break
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            await async_execute(
+                self.cassandra,
+                self.delete_upload_session_stmt,
+                (upload_id,),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         logger.debug("Upload session deleted")

@@ -818,19 +743,18 @@

         logger.debug(f"List upload sessions for {user}")

-        while True:
-            try:
-                resp = self.cassandra.execute(
-                    self.list_upload_sessions_stmt,
-                    (user,)
-                )
-                break
-            except Exception as e:
-                logger.error("Exception occurred", exc_info=True)
-                raise e
+        try:
+            rows = await async_execute(
+                self.cassandra,
+                self.list_upload_sessions_stmt,
+                (user,),
+            )
+        except Exception:
+            logger.error("Exception occurred", exc_info=True)
+            raise

         sessions = []
-        for row in resp:
+        for row in rows:
             chunks_received = row[6] if row[6] else {}
             sessions.append({
                 "upload_id": row[0],
diff --git a/trustgraph-ocr/pyproject.toml b/trustgraph-ocr/pyproject.toml
index deab0d5c..cd1d20a1 100644
--- a/trustgraph-ocr/pyproject.toml
+++ b/trustgraph-ocr/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
 readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
-    "trustgraph-base>=2.2,<2.3",
+    "trustgraph-base>=2.3,<2.4",
     "pulsar-client",
     "prometheus-client",
     "boto3",
diff --git a/trustgraph-unstructured/pyproject.toml b/trustgraph-unstructured/pyproject.toml
index 33265edb..d8879329 100644
--- a/trustgraph-unstructured/pyproject.toml
+++ b/trustgraph-unstructured/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
 readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
-    "trustgraph-base>=2.2,<2.3",
+    "trustgraph-base>=2.3,<2.4",
     "pulsar-client",
     "prometheus-client",
     "python-magic",
diff --git a/trustgraph-vertexai/pyproject.toml b/trustgraph-vertexai/pyproject.toml
index 9eb75ed1..45958ef3 100644
--- a/trustgraph-vertexai/pyproject.toml
+++ b/trustgraph-vertexai/pyproject.toml
@@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
 readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
-    "trustgraph-base>=2.2,<2.3",
+    "trustgraph-base>=2.3,<2.4",
     "pulsar-client",
     "google-genai",
     "google-api-core",