diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 1b6dc177..c182578a 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -5,17 +5,17 @@ on: workflow_dispatch: push: tags: - - v0.18.* + - v* permissions: contents: read jobs: - deploy: + python-packages: - name: Build everything - runs-on: ubuntu-latest + name: Release Python packages + runs-on: ubuntu-24.04 permissions: contents: write id-token: write @@ -25,50 +25,73 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 - - - name: Log in to Docker Hub - uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a - with: - username: ${{ vars.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_SECRET }} - - - name: Install build dependencies - run: pip3 install jsonnet + uses: actions/checkout@v4 - name: Get version id: version run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT - - run: echo ${{ steps.version.outputs.VERSION }} - - name: Build packages run: make packages VERSION=${{ steps.version.outputs.VERSION }} - name: Publish release distributions to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - - name: Extract metadata for container - id: meta - uses: docker/metadata-action@v4 - with: - images: trustgraph/trustgraph-flow - tags: | - type=ref,event=branch - type=ref,event=pr - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha + deploy-container-image: - - name: Build and push Docker image - id: push - uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671 - with: - context: . - file: ./Containerfile - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} + name: Release container image + runs-on: ubuntu-24.04 + permissions: + contents: write + id-token: write + environment: + name: release + + steps: + + - name: Checkout + uses: actions/checkout@v4 + + - name: Docker Hub token + run: echo ${{ secrets.DOCKER_SECRET }} > docker-token.txt + + - name: Authenticate with Docker hub + run: make docker-hub-login + + - name: Get version + id: version + run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT + + - name: Put version into package manifests + run: make update-package-versions VERSION=${{ steps.version.outputs.VERSION }} + + - name: Build containers + run: make container VERSION=${{ steps.version.outputs.VERSION }} + + - name: Push containers + run: make push VERSION=${{ steps.version.outputs.VERSION }} + + release-bundle: + + name: Upload release bundle + runs-on: ubuntu-24.04 + permissions: + contents: write + id-token: write + environment: + name: release + + steps: + + - name: Checkout + uses: actions/checkout@v4 + + - name: Install build dependencies + run: pip3 install jsonnet + + - name: Get version + id: version + run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT - name: Create deploy bundle run: templates/generate-all deploy.zip ${{ steps.version.outputs.VERSION }} diff --git a/.gitignore b/.gitignore index 357ecf1e..4d089211 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ trustgraph-base/trustgraph/base_version.py trustgraph-bedrock/trustgraph/bedrock_version.py trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py trustgraph-flow/trustgraph/flow_version.py +trustgraph-ocr/trustgraph/ocr_version.py trustgraph-parquet/trustgraph/parquet_version.py trustgraph-vertexai/trustgraph/vertexai_version.py trustgraph-cli/trustgraph/ diff --git a/Containerfile b/Containerfile index 73c9285f..7283a06b 100644 --- a/Containerfile +++ b/Containerfile @@ -11,15 +11,26 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1 RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp \ python3-rdflib -RUN pip3 install torch --index-url https://download.pytorch.org/whl/cpu +RUN pip3 install torch==2.5.1+cpu \ + --index-url https://download.pytorch.org/whl/cpu -RUN pip3 install anthropic boto3 cohere openai google-cloud-aiplatform ollama google-generativeai \ - langchain langchain-core langchain-huggingface langchain-text-splitters \ - langchain-community pymilvus sentence-transformers transformers \ - huggingface-hub pulsar-client cassandra-driver pyyaml \ +RUN pip3 install \ + anthropic boto3 cohere mistralai openai google-cloud-aiplatform \ + ollama google-generativeai \ + langchain==0.3.13 langchain-core==0.3.28 langchain-huggingface==0.1.2 \ + langchain-text-splitters==0.3.4 \ + langchain-community==0.3.13 \ + sentence-transformers==3.4.0 transformers==4.47.1 \ + huggingface-hub==0.27.0 \ + pymilvus \ + pulsar-client==3.5.0 cassandra-driver pyyaml \ neo4j tiktoken falkordb && \ pip3 cache purge +# Most commonly used embeddings model, just build it into the container +# image +RUN huggingface-cli download sentence-transformers/all-MiniLM-L6-v2 + # ---------------------------------------------------------------------------- # Build a container which contains the built Python packages. The build # creates a bunch of left-over cruft, a separate phase means this is only @@ -34,6 +45,7 @@ COPY trustgraph-vertexai/ /root/build/trustgraph-vertexai/ COPY trustgraph-bedrock/ /root/build/trustgraph-bedrock/ COPY trustgraph-embeddings-hf/ /root/build/trustgraph-embeddings-hf/ COPY trustgraph-cli/ /root/build/trustgraph-cli/ +COPY trustgraph-ocr/ /root/build/trustgraph-ocr/ WORKDIR /root/build/ @@ -43,6 +55,7 @@ RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-vertexai/ RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-bedrock/ RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-embeddings-hf/ RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-cli/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-ocr/ RUN ls /root/wheels @@ -61,6 +74,7 @@ RUN \ pip3 install /root/wheels/trustgraph_bedrock-* && \ pip3 install /root/wheels/trustgraph_embeddings_hf-* && \ pip3 install /root/wheels/trustgraph_cli-* && \ + pip3 install /root/wheels/trustgraph_ocr-* && \ pip3 cache purge && \ rm -rf /root/wheels diff --git a/Makefile b/Makefile index 67094a90..1fae97f6 100644 --- a/Makefile +++ b/Makefile @@ -16,6 +16,7 @@ wheels: pip3 wheel --no-deps --wheel-dir dist trustgraph-bedrock/ pip3 wheel --no-deps --wheel-dir dist trustgraph-embeddings-hf/ pip3 wheel --no-deps --wheel-dir dist trustgraph-cli/ + pip3 wheel --no-deps --wheel-dir dist trustgraph-ocr/ packages: update-package-versions rm -rf dist/ @@ -26,11 +27,12 @@ packages: update-package-versions cd trustgraph-bedrock && python3 setup.py sdist --dist-dir ../dist/ cd trustgraph-embeddings-hf && python3 setup.py sdist --dist-dir ../dist/ cd trustgraph-cli && python3 setup.py sdist --dist-dir ../dist/ + cd trustgraph-ocr && python3 setup.py sdist --dist-dir ../dist/ pypi-upload: twine upload dist/*-${VERSION}.* -CONTAINER=docker.io/trustgraph/trustgraph-flow +CONTAINER_BASE=docker.io/trustgraph update-package-versions: mkdir -p trustgraph-cli/trustgraph @@ -41,14 +43,34 @@ update-package-versions: echo __version__ = \"${VERSION}\" > trustgraph-bedrock/trustgraph/bedrock_version.py echo __version__ = \"${VERSION}\" > trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py echo __version__ = \"${VERSION}\" > trustgraph-cli/trustgraph/cli_version.py + echo __version__ = \"${VERSION}\" > trustgraph-ocr/trustgraph/ocr_version.py echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py container: update-package-versions - ${DOCKER} build -f Containerfile -t ${CONTAINER}:${VERSION} \ - --format docker + ${DOCKER} build -f containers/Containerfile.base \ + -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . + ${DOCKER} build -f containers/Containerfile.flow \ + -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . + ${DOCKER} build -f containers/Containerfile.bedrock \ + -t ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} . + ${DOCKER} build -f containers/Containerfile.vertexai \ + -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . + ${DOCKER} build -f containers/Containerfile.hf \ + -t ${CONTAINER_BASE}/trustgraph-hf:${VERSION} . + ${DOCKER} build -f containers/Containerfile.ocr \ + -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} . + +container.ocr: + ${DOCKER} build -f containers/Containerfile.ocr \ + -t ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} . push: - ${DOCKER} push ${CONTAINER}:${VERSION} + ${DOCKER} push ${CONTAINER_BASE}/trustgraph-base:${VERSION} + ${DOCKER} push ${CONTAINER_BASE}/trustgraph-flow:${VERSION} + ${DOCKER} push ${CONTAINER_BASE}/trustgraph-bedrock:${VERSION} + ${DOCKER} push ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} + ${DOCKER} push ${CONTAINER_BASE}/trustgraph-hf:${VERSION} + ${DOCKER} push ${CONTAINER_BASE}/trustgraph-ocr:${VERSION} clean: rm -rf wheels/ @@ -56,13 +78,13 @@ clean: set-version: echo '"${VERSION}"' > templates/values/version.jsonnet -TEMPLATES=azure bedrock claude cohere mix llamafile ollama openai vertexai \ +TEMPLATES=azure bedrock claude cohere mix llamafile mistral ollama openai vertexai \ openai-neo4j storage DCS=$(foreach template,${TEMPLATES},${template:%=tg-launch-%.yaml}) -MODELS=azure bedrock claude cohere llamafile ollama openai vertexai -GRAPHS=cassandra neo4j falkordb +MODELS=azure bedrock claude cohere llamafile mistral ollama openai vertexai +GRAPHS=cassandra neo4j falkordb memgraph # tg-launch-%.yaml: templates/%.jsonnet templates/components/version.jsonnet # jsonnet -Jtemplates \ @@ -104,5 +126,5 @@ update-dcs: set-version docker-hub-login: cat docker-token.txt | \ - docker login -u trustgraph --password-stdin registry-1.docker.io + ${DOCKER} login -u trustgraph --password-stdin registry-1.docker.io diff --git a/containers/Containerfile.base b/containers/Containerfile.base new file mode 100644 index 00000000..b4f5bbbf --- /dev/null +++ b/containers/Containerfile.base @@ -0,0 +1,48 @@ + +# ---------------------------------------------------------------------------- +# Build an AI container. This does the torch install which is huge, and I +# like to avoid re-doing this. +# ---------------------------------------------------------------------------- + +FROM docker.io/fedora:40 AS base + +ENV PIP_BREAK_SYSTEM_PACKAGES=1 + +RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp && \ + dnf clean all + +RUN pip3 install --no-cache-dir pulsar-client==3.5.0 + +# ---------------------------------------------------------------------------- +# Build a container which contains the built Python packages. The build +# creates a bunch of left-over cruft, a separate phase means this is only +# needed to support package build +# ---------------------------------------------------------------------------- + +FROM base AS build + +COPY trustgraph-base/ /root/build/trustgraph-base/ +COPY trustgraph-cli/ /root/build/trustgraph-cli/ + +WORKDIR /root/build/ + +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-cli/ + +# ---------------------------------------------------------------------------- +# Finally, the target container. Start with base and add the package. +# ---------------------------------------------------------------------------- + +FROM base + +COPY --from=build /root/wheels /root/wheels + +RUN \ + pip3 install --no-cache-dir /root/wheels/trustgraph_base-* && \ + pip3 install --no-cache-dir /root/wheels/trustgraph_cli-* && \ + rm -rf /root/wheels + +WORKDIR / + + + diff --git a/containers/Containerfile.bedrock b/containers/Containerfile.bedrock new file mode 100644 index 00000000..21819973 --- /dev/null +++ b/containers/Containerfile.bedrock @@ -0,0 +1,48 @@ + +# ---------------------------------------------------------------------------- +# Build an AI container. This does the torch install which is huge, and I +# like to avoid re-doing this. +# ---------------------------------------------------------------------------- + +FROM docker.io/fedora:40 AS base + +ENV PIP_BREAK_SYSTEM_PACKAGES=1 + +RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp \ + python3-rdflib + +RUN pip3 install --no-cache-dir boto3 pulsar-client==3.5.0 + +# ---------------------------------------------------------------------------- +# Build a container which contains the built Python packages. The build +# creates a bunch of left-over cruft, a separate phase means this is only +# needed to support package build +# ---------------------------------------------------------------------------- + +FROM base AS build + +COPY trustgraph-base/ /root/build/trustgraph-base/ +COPY trustgraph-bedrock/ /root/build/trustgraph-bedrock/ + +WORKDIR /root/build/ + +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-bedrock/ + +RUN ls /root/wheels + +# ---------------------------------------------------------------------------- +# Finally, the target container. Start with base and add the package. +# ---------------------------------------------------------------------------- + +FROM base + +COPY --from=build /root/wheels /root/wheels + +RUN \ + pip3 install --no-cache-dir /root/wheels/trustgraph_base-* && \ + pip3 install --no-cache-dir /root/wheels/trustgraph_bedrock-* && \ + rm -rf /root/wheels + +WORKDIR / + diff --git a/containers/Containerfile.flow b/containers/Containerfile.flow new file mode 100644 index 00000000..352e5ac5 --- /dev/null +++ b/containers/Containerfile.flow @@ -0,0 +1,60 @@ + +# ---------------------------------------------------------------------------- +# Build an AI container. This does the torch install which is huge, and I +# like to avoid re-doing this. +# ---------------------------------------------------------------------------- + +FROM docker.io/fedora:40 AS base + +ENV PIP_BREAK_SYSTEM_PACKAGES=1 + +RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp \ + python3-rdflib + +RUN pip3 install --no-cache-dir \ + anthropic cohere mistralai openai google-generativeai \ + ollama \ + langchain==0.3.13 langchain-core==0.3.28 \ + langchain-text-splitters==0.3.4 \ + langchain-community==0.3.13 \ + pymilvus \ + pulsar-client==3.5.0 cassandra-driver pyyaml \ + neo4j tiktoken falkordb && \ + pip3 cache purge + +# ---------------------------------------------------------------------------- +# Build a container which contains the built Python packages. The build +# creates a bunch of left-over cruft, a separate phase means this is only +# needed to support package build +# ---------------------------------------------------------------------------- + +FROM base AS build + +COPY trustgraph-base/ /root/build/trustgraph-base/ +COPY trustgraph-flow/ /root/build/trustgraph-flow/ +COPY trustgraph-cli/ /root/build/trustgraph-cli/ + +WORKDIR /root/build/ + +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-flow/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-cli/ + +RUN ls /root/wheels + +# ---------------------------------------------------------------------------- +# Finally, the target container. Start with base and add the package. +# ---------------------------------------------------------------------------- + +FROM base + +COPY --from=build /root/wheels /root/wheels + +RUN \ + pip3 install --no-cache-dir /root/wheels/trustgraph_base-* && \ + pip3 install --no-cache-dir /root/wheels/trustgraph_flow-* && \ + pip3 install --no-cache-dir /root/wheels/trustgraph_cli-* && \ + rm -rf /root/wheels + +WORKDIR / + diff --git a/containers/Containerfile.hf b/containers/Containerfile.hf new file mode 100644 index 00000000..4076db28 --- /dev/null +++ b/containers/Containerfile.hf @@ -0,0 +1,75 @@ + +# ---------------------------------------------------------------------------- +# Build an AI container. This does the torch install which is huge, and I +# like to avoid re-doing this. +# ---------------------------------------------------------------------------- + +FROM docker.io/fedora:40 AS ai + +ENV PIP_BREAK_SYSTEM_PACKAGES=1 + +RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp \ + python3-rdflib + +RUN pip3 install torch==2.5.1+cpu \ + --index-url https://download.pytorch.org/whl/cpu + +RUN pip3 install --no-cache-dir \ + langchain==0.3.13 langchain-core==0.3.28 langchain-huggingface==0.1.2 \ + langchain-community==0.3.13 \ + sentence-transformers==3.4.0 transformers==4.47.1 \ + huggingface-hub==0.27.0 \ + pulsar-client==3.5.0 + +# Most commonly used embeddings model, just build it into the container +# image +RUN huggingface-cli download sentence-transformers/all-MiniLM-L6-v2 + +# ---------------------------------------------------------------------------- +# Build a container which contains the built Python packages. The build +# creates a bunch of left-over cruft, a separate phase means this is only +# needed to support package build +# ---------------------------------------------------------------------------- + +FROM ai AS build + +COPY trustgraph-base/ /root/build/trustgraph-base/ +COPY trustgraph-flow/ /root/build/trustgraph-flow/ +COPY trustgraph-vertexai/ /root/build/trustgraph-vertexai/ +COPY trustgraph-bedrock/ /root/build/trustgraph-bedrock/ +COPY trustgraph-embeddings-hf/ /root/build/trustgraph-embeddings-hf/ +COPY trustgraph-cli/ /root/build/trustgraph-cli/ + +WORKDIR /root/build/ + +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-flow/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-vertexai/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-bedrock/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-embeddings-hf/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-cli/ + +RUN ls /root/wheels + +# ---------------------------------------------------------------------------- +# Finally, the target container. Start with base and add the package. +# ---------------------------------------------------------------------------- + +FROM ai + +COPY --from=build /root/wheels /root/wheels + +RUN \ + pip3 install /root/wheels/trustgraph_base-* && \ + pip3 install /root/wheels/trustgraph_flow-* && \ + pip3 install /root/wheels/trustgraph_vertexai-* && \ + pip3 install /root/wheels/trustgraph_bedrock-* && \ + pip3 install /root/wheels/trustgraph_embeddings_hf-* && \ + pip3 install /root/wheels/trustgraph_cli-* && \ + pip3 cache purge && \ + rm -rf /root/wheels + +WORKDIR / + +CMD sleep 1000000 + diff --git a/containers/Containerfile.ocr b/containers/Containerfile.ocr new file mode 100644 index 00000000..8a454008 --- /dev/null +++ b/containers/Containerfile.ocr @@ -0,0 +1,48 @@ + +# ---------------------------------------------------------------------------- +# Build an AI container. This does the torch install which is huge, and I +# like to avoid re-doing this. +# ---------------------------------------------------------------------------- + +FROM docker.io/fedora:40 AS base + +ENV PIP_BREAK_SYSTEM_PACKAGES=1 + +RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp \ + python3-rdflib tesseract poppler poppler-utils + +RUN pip3 install --no-cache-dir pytesseract pulsar-client==3.5.0 + +# ---------------------------------------------------------------------------- +# Build a container which contains the built Python packages. The build +# creates a bunch of left-over cruft, a separate phase means this is only +# needed to support package build +# ---------------------------------------------------------------------------- + +FROM base AS build + +COPY trustgraph-base/ /root/build/trustgraph-base/ +COPY trustgraph-ocr/ /root/build/trustgraph-ocr/ + +WORKDIR /root/build/ + +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-ocr/ + +RUN ls /root/wheels + +# ---------------------------------------------------------------------------- +# Finally, the target container. Start with base and add the package. +# ---------------------------------------------------------------------------- + +FROM base + +COPY --from=build /root/wheels /root/wheels + +RUN \ + pip3 install --no-cache-dir /root/wheels/trustgraph_base-* && \ + pip3 install --no-cache-dir /root/wheels/trustgraph_ocr-* && \ + rm -rf /root/wheels + +WORKDIR / + diff --git a/containers/Containerfile.vertexai b/containers/Containerfile.vertexai new file mode 100644 index 00000000..72d21bde --- /dev/null +++ b/containers/Containerfile.vertexai @@ -0,0 +1,51 @@ + +# ---------------------------------------------------------------------------- +# Build an AI container. This does the torch install which is huge, and I +# like to avoid re-doing this. +# ---------------------------------------------------------------------------- + +FROM docker.io/fedora:40 AS base + +ENV PIP_BREAK_SYSTEM_PACKAGES=1 + +RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp \ + python3-rdflib + +RUN pip3 install --no-cache-dir \ + google-cloud-aiplatform pulsar-client==3.5.0 + +# ---------------------------------------------------------------------------- +# Build a container which contains the built Python packages. The build +# creates a bunch of left-over cruft, a separate phase means this is only +# needed to support package build +# ---------------------------------------------------------------------------- + +FROM base AS build + +COPY trustgraph-base/ /root/build/trustgraph-base/ +COPY trustgraph-vertexai/ /root/build/trustgraph-vertexai/ + +WORKDIR /root/build/ + +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/ +RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-vertexai/ + +RUN ls /root/wheels + +# ---------------------------------------------------------------------------- +# Finally, the target container. Start with base and add the package. +# ---------------------------------------------------------------------------- + +FROM base + +COPY --from=build /root/wheels /root/wheels + +RUN \ + pip3 install --no-cache-dir /root/wheels/trustgraph_base-* && \ + pip3 install --no-cache-dir /root/wheels/trustgraph_vertexai-* && \ + rm -rf /root/wheels + +WORKDIR / + + + diff --git a/docs/README.quickstart-docker-compose.md b/docs/README.quickstart-docker-compose.md index 12cb8cf0..76f7e1f5 100644 --- a/docs/README.quickstart-docker-compose.md +++ b/docs/README.quickstart-docker-compose.md @@ -113,8 +113,9 @@ Choose one of the `Docker Compose` files that meets your preferred model and gra ### AWS Bedrock API ``` -export AWS_ID_KEY= -export AWS_SECRET_KEY= +export AWS_ACCESS_KEY_ID= +export AWS_SECRET_ACCESS_KEY= +export AWS_DEFAULT_REGION= docker compose -f tg-launch-bedrock-cassandra.yaml up -d # Using Cassandra as the graph store docker compose -f tg-launch-bedrock-neo4j.yaml up -d # Using Neo4j as the graph store ``` diff --git a/grafana/dashboards/dashboard.json b/grafana/dashboards/dashboard.json index 04561863..c484dffa 100644 --- a/grafana/dashboards/dashboard.json +++ b/grafana/dashboards/dashboard.json @@ -577,7 +577,7 @@ "disableTextWrap": false, "editorMode": "builder", "exemplar": false, - "expr": "increase(processing_count_total{status!=\"success\"}[$__rate_interval])", + "expr": "sum by(job) (increase(rate_limit_count_total[$__rate_interval]))", "format": "time_series", "fullMetaSearch": false, "includeNullMetadata": true, @@ -588,7 +588,7 @@ "useBackend": false } ], - "title": "Errors", + "title": "Rate limit events", "type": "timeseries" }, { diff --git a/prometheus/prometheus.yml b/prometheus/prometheus.yml index 24102a23..0fa70314 100644 --- a/prometheus/prometheus.yml +++ b/prometheus/prometheus.yml @@ -20,6 +20,18 @@ scrape_configs: - targets: - 'pulsar:8080' + - job_name: 'bookie' + scrape_interval: 5s + static_configs: + - targets: + - 'bookie:8000' + + - job_name: 'zookeeper' + scrape_interval: 5s + static_configs: + - targets: + - 'zookeeper:8000' + - job_name: 'pdf-decoder' scrape_interval: 5s static_configs: @@ -32,11 +44,17 @@ scrape_configs: - targets: - 'chunker:8000' - - job_name: 'vectorize' + - job_name: 'document-embeddings' scrape_interval: 5s static_configs: - targets: - - 'vectorize:8000' + - 'document-embeddings:8000' + + - job_name: 'graph-embeddings' + scrape_interval: 5s + static_configs: + - targets: + - 'graph-embeddings:8000' - job_name: 'embeddings' scrape_interval: 5s @@ -74,6 +92,12 @@ scrape_configs: - targets: - 'metering-rag:8000' + - job_name: 'store-doc-embeddings' + scrape_interval: 5s + static_configs: + - targets: + - 'store-doc-embeddings:8000' + - job_name: 'store-graph-embeddings' scrape_interval: 5s static_configs: @@ -104,6 +128,12 @@ scrape_configs: - targets: - 'graph-rag:8000' + - job_name: 'document-rag' + scrape_interval: 5s + static_configs: + - targets: + - 'document-rag:8000' + - job_name: 'prompt' scrape_interval: 5s static_configs: @@ -122,6 +152,12 @@ scrape_configs: - targets: - 'query-graph-embeddings:8000' + - job_name: 'query-doc-embeddings' + scrape_interval: 5s + static_configs: + - targets: + - 'query-doc-embeddings:8000' + - job_name: 'query-triples' scrape_interval: 5s static_configs: @@ -145,3 +181,7 @@ scrape_configs: static_configs: - targets: - 'workbench-ui:8000' + +# Cassandra +# qdrant + diff --git a/templates/all-patterns.jsonnet b/templates/all-patterns.jsonnet index f68f307d..3282be53 100644 --- a/templates/all-patterns.jsonnet +++ b/templates/all-patterns.jsonnet @@ -13,6 +13,7 @@ import "patterns/llm-claude.jsonnet", import "patterns/llm-cohere.jsonnet", import "patterns/llm-llamafile.jsonnet", + import "patterns/llm-mistral.jsonnet", import "patterns/llm-ollama.jsonnet", import "patterns/llm-openai.jsonnet", import "patterns/llm-vertexai.jsonnet", diff --git a/templates/components.jsonnet b/templates/components.jsonnet index b14665d6..d0df569f 100644 --- a/templates/components.jsonnet +++ b/templates/components.jsonnet @@ -1,47 +1,80 @@ { + + // Essentials + "trustgraph-base": import "components/trustgraph.jsonnet", + "pulsar": import "components/pulsar.jsonnet", + + // LLMs "azure": import "components/azure.jsonnet", "azure-openai": import "components/azure-openai.jsonnet", "bedrock": import "components/bedrock.jsonnet", "claude": import "components/claude.jsonnet", "cohere": import "components/cohere.jsonnet", - "document-rag": import "components/document-rag.jsonnet", - "embeddings-hf": import "components/embeddings-hf.jsonnet", - "embeddings-ollama": import "components/embeddings-ollama.jsonnet", "googleaistudio": import "components/googleaistudio.jsonnet", - "grafana": import "components/grafana.jsonnet", + "mistral": import "components/mistral.jsonnet", + "ollama": import "components/ollama.jsonnet", + "openai": import "components/openai.jsonnet", + "vertexai": import "components/vertexai.jsonnet", + + // LLMs for RAG + "azure-rag": import "components/azure-rag.jsonnet", + "azure-openai-rag": import "components/azure-openai-rag.jsonnet", + "bedrock-rag": import "components/bedrock-rag.jsonnet", + "claude-rag": import "components/claude-rag.jsonnet", + "cohere-rag": import "components/cohere-rag.jsonnet", + "googleaistudio-rag": import "components/googleaistudio-rag.jsonnet", + "mistral-rag": import "components/mistral-rag.jsonnet", + "ollama-rag": import "components/ollama-rag.jsonnet", + "openai-rag": import "components/openai-rag.jsonnet", + "vertexai-rag": import "components/vertexai-rag.jsonnet", + + // Embeddings + "embeddings-ollama": import "components/embeddings-ollama.jsonnet", + "embeddings-hf": import "components/embeddings-hf.jsonnet", + "embeddings-fastembed": import "components/embeddings-fastembed.jsonnet", + + // Processing pipelines "graph-rag": import "components/graph-rag.jsonnet", + "document-rag": import "components/document-rag.jsonnet", + + // OCR options + "ocr": import "components/ocr.jsonnet", + "mistral-ocr": import "components/mistral-ocr.jsonnet", + + // Librarian - document management + "librarian": import "components/librarian.jsonnet", + + // Vector stores + "vector-store-milvus": import "components/milvus.jsonnet", + "vector-store-qdrant": import "components/qdrant.jsonnet", + "vector-store-pinecone": import "components/pinecone.jsonnet", + + // Triples stores "triple-store-cassandra": import "components/cassandra.jsonnet", "triple-store-neo4j": import "components/neo4j.jsonnet", "triple-store-falkordb": import "components/falkordb.jsonnet", "triple-store-memgraph": import "components/memgraph.jsonnet", + + // Observability support + "grafana": import "components/grafana.jsonnet", + + // Pulsar manager is a UI for Pulsar. Uses a LOT of memory + "pulsar-manager": import "components/pulsar-manager.jsonnet", + "llamafile": import "components/llamafile.jsonnet", - "ollama": import "components/ollama.jsonnet", - "openai": import "components/openai.jsonnet", "override-recursive-chunker": import "components/chunker-recursive.jsonnet", + // The prompt manager "prompt-template": import "components/prompt-template.jsonnet", "prompt-overrides": import "components/prompt-overrides.jsonnet", - "pulsar": import "components/pulsar.jsonnet", - "pulsar-manager": import "components/pulsar-manager.jsonnet", - "trustgraph-base": import "components/trustgraph.jsonnet", - "vector-store-milvus": import "components/milvus.jsonnet", - "vector-store-qdrant": import "components/qdrant.jsonnet", - "vector-store-pinecone": import "components/pinecone.jsonnet", - "vertexai": import "components/vertexai.jsonnet", - "workbench-ui": import "components/workbench-ui.jsonnet", - "null": {}, - + // ReAct agent "agent-manager-react": import "components/agent-manager-react.jsonnet", - // FIXME: Dupes - "cassandra": import "components/cassandra.jsonnet", - "neo4j": import "components/neo4j.jsonnet", - "memgraph": import "components/memgraph.jsonnet", - "qdrant": import "components/qdrant.jsonnet", - "pinecone": import "components/pinecone.jsonnet", - "milvus": import "components/milvus.jsonnet", - "falkordb": import "components/falkordb.jsonnet", - "trustgraph": import "components/trustgraph.jsonnet", + // Optional UI + "workbench-ui": import "components/workbench-ui.jsonnet", + + // Does nothing. But, can be a hack to overwrite parameters + "null": {}, } diff --git a/templates/components/agent-manager-react.jsonnet b/templates/components/agent-manager-react.jsonnet index a995dba5..672a0439 100644 --- a/templates/components/agent-manager-react.jsonnet +++ b/templates/components/agent-manager-react.jsonnet @@ -14,11 +14,15 @@ local default_prompts = import "prompts/default-prompts.jsonnet"; local container = engine.container("agent-manager") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "agent-manager-react", "-p", url.pulsar, + "--prompt-request-queue", + "non-persistent://tg/request/prompt-rag", + "--prompt-response-queue", + "non-persistent://tg/response/prompt-rag", "--tool-type", ] + [ tool.id + "=" + tool.type diff --git a/templates/components/azure-openai-rag.jsonnet b/templates/components/azure-openai-rag.jsonnet new file mode 100644 index 00000000..33355707 --- /dev/null +++ b/templates/components/azure-openai-rag.jsonnet @@ -0,0 +1,61 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["ollama-rag-" + key]:: value, + }, + + "azure-openai-rag-model":: "GPT-3.5-Turbo", + "azure-openai-rag-max-output-tokens":: 4192, + "azure-openai-rag-temperature":: 0.0, + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("azure-openai-credentials") + .with_env_var("AZURE_TOKEN", "azure-token"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-azure", + "-p", + url.pulsar, + "-x", + std.toString($["azure-openai-rag-max-output-tokens"]), + "-t", + "%0.3f" % $["azure-openai-rag-temperature"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/azure-openai.jsonnet b/templates/components/azure-openai.jsonnet index 8afcaf11..3ecbbdac 100644 --- a/templates/components/azure-openai.jsonnet +++ b/templates/components/azure-openai.jsonnet @@ -5,6 +5,11 @@ local prompts = import "prompts/mixtral.jsonnet"; { + with:: function(key, value) + self + { + ["azure-openai-" + key]:: value, + }, + "azure-openai-model":: "GPT-3.5-Turbo", "azure-openai-max-output-tokens":: 4192, "azure-openai-temperature":: 0.0, @@ -18,7 +23,7 @@ local prompts = import "prompts/mixtral.jsonnet"; local container = engine.container("text-completion") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "text-completion-azure-openai", "-p", @@ -34,48 +39,18 @@ local prompts = import "prompts/mixtral.jsonnet"; .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerRag = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-azure", - "-p", - url.pulsar, - "-x", - std.toString($["azure-openai-max-output-tokens"]), - "-t", - "%0.3f" % $["azure-openai-temperature"], - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag", - ]) - .with_env_var_secrets(envSecrets) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - local containerSet = engine.containers( "text-completion", [ container ] ); - local containerSetRag = engine.containers( - "text-completion-rag", [ containerRag ] - ); - local service = engine.internalService(containerSet) .with_port(8000, 8000, "metrics"); - local serviceRag = - engine.internalService(containerSetRag) - .with_port(8000, 8000, "metrics"); - engine.resources([ envSecrets, containerSet, - containerSetRag, service, - serviceRag, ]) }, diff --git a/templates/components/azure-rag.jsonnet b/templates/components/azure-rag.jsonnet new file mode 100644 index 00000000..20b7306e --- /dev/null +++ b/templates/components/azure-rag.jsonnet @@ -0,0 +1,60 @@ +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["azure-rag-" + key]:: value, + }, + + "azure-rag-max-output-tokens":: 4096, + "azure-rag-temperature":: 0.0, + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("azure-credentials") + .with_env_var("AZURE_TOKEN", "azure-token") + .with_env_var("AZURE_ENDPOINT", "azure-endpoint"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-azure", + "-p", + url.pulsar, + "-x", + std.toString($["azure-rag-max-output-tokens"]), + "-t", + "%0.3f" % $["azure-rag-temperature"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + } + +} + prompts + diff --git a/templates/components/azure.jsonnet b/templates/components/azure.jsonnet index cf10dc66..c7746e23 100644 --- a/templates/components/azure.jsonnet +++ b/templates/components/azure.jsonnet @@ -1,10 +1,14 @@ -local base = import "base/base.jsonnet"; local images = import "values/images.jsonnet"; local url = import "values/url.jsonnet"; local prompts = import "prompts/mixtral.jsonnet"; { + with:: function(key, value) + self + { + ["azure-" + key]:: value, + }, + "azure-max-output-tokens":: 4096, "azure-temperature":: 0.0, @@ -18,7 +22,7 @@ local prompts = import "prompts/mixtral.jsonnet"; local container = engine.container("text-completion") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "text-completion-azure", "-p", @@ -32,48 +36,18 @@ local prompts = import "prompts/mixtral.jsonnet"; .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerRag = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-azure", - "-p", - url.pulsar, - "-x", - std.toString($["azure-max-output-tokens"]), - "-t", - "%0.3f" % $["azure-temperature"], - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag", - ]) - .with_env_var_secrets(envSecrets) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - local containerSet = engine.containers( "text-completion", [ container ] ); - local containerSetRag = engine.containers( - "text-completion-rag", [ containerRag ] - ); - local service = engine.internalService(containerSet) .with_port(8000, 8000, "metrics"); - local serviceRag = - engine.internalService(containerSetRag) - .with_port(8000, 8000, "metrics"); - engine.resources([ envSecrets, containerSet, - containerSetRag, service, - serviceRag, ]) } diff --git a/templates/components/bedrock-rag.jsonnet b/templates/components/bedrock-rag.jsonnet new file mode 100644 index 00000000..b265a9f2 --- /dev/null +++ b/templates/components/bedrock-rag.jsonnet @@ -0,0 +1,66 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; +local chunker = import "chunker-recursive.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["bedrock-rag-" + key]:: value, + }, + + "bedrock-rag-max-output-tokens":: 4096, + "bedrock-rag-temperature":: 0.0, + "bedrock-rag-model":: "mistral.mixtral-8x7b-instruct-v0:1", + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("bedrock-credentials") + .with_env_var("AWS_ACCESS_KEY_ID", "aws-id-key") + .with_env_var("AWS_SECRET_ACCESS_KEY", "aws-secret") + .with_env_var("AWS_DEFAULT_REGION", "aws-region"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_bedrock) + .with_command([ + "text-completion-bedrock", + "-p", + url.pulsar, + "-x", + std.toString($["bedrock-rag-max-output-tokens"]), + "-t", + "%0.3f" % $["bedrock-rag-temperature"], + "-m", + $["bedrock-rag-model"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + }, + +} + prompts + chunker + diff --git a/templates/components/bedrock.jsonnet b/templates/components/bedrock.jsonnet index 6ccaa1c5..6b599057 100644 --- a/templates/components/bedrock.jsonnet +++ b/templates/components/bedrock.jsonnet @@ -6,6 +6,11 @@ local chunker = import "chunker-recursive.jsonnet"; { + with:: function(key, value) + self + { + ["bedrock-" + key]:: value, + }, + "bedrock-max-output-tokens":: 4096, "bedrock-temperature":: 0.0, "bedrock-model":: "mistral.mixtral-8x7b-instruct-v0:1", @@ -15,13 +20,13 @@ local chunker = import "chunker-recursive.jsonnet"; create:: function(engine) local envSecrets = engine.envSecrets("bedrock-credentials") - .with_env_var("AWS_ID_KEY", "aws-id-key") - .with_env_var("AWS_SECRET", "aws-secret") - .with_env_var("AWS_REGION", "aws-region"); + .with_env_var("AWS_ACCESS_KEY_ID", "aws-id-key") + .with_env_var("AWS_SECRET_ACCESS_KEY", "aws-secret") + .with_env_var("AWS_DEFAULT_REGION", "aws-region"); local container = engine.container("text-completion") - .with_image(images.trustgraph) + .with_image(images.trustgraph_bedrock) .with_command([ "text-completion-bedrock", "-p", @@ -37,50 +42,18 @@ local chunker = import "chunker-recursive.jsonnet"; .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerRag = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-bedrock", - "-p", - url.pulsar, - "-x", - std.toString($["bedrock-max-output-tokens"]), - "-t", - "%0.3f" % $["bedrock-temperature"], - "-m", - $["bedrock-model"], - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag", - ]) - .with_env_var_secrets(envSecrets) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - local containerSet = engine.containers( "text-completion", [ container ] ); - local containerSetRag = engine.containers( - "text-completion-rag", [ containerRag ] - ); - local service = engine.internalService(containerSet) .with_port(8000, 8000, "metrics"); - local serviceRag = - engine.internalService(containerSetRag) - .with_port(8000, 8000, "metrics"); - engine.resources([ envSecrets, containerSet, - containerSetRag, service, - serviceRag, ]) }, diff --git a/templates/components/cassandra.jsonnet b/templates/components/cassandra.jsonnet index b52d4b04..92ecf69f 100644 --- a/templates/components/cassandra.jsonnet +++ b/templates/components/cassandra.jsonnet @@ -12,7 +12,7 @@ cassandra + { local container = engine.container("store-triples") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "triples-write-cassandra", "-p", @@ -44,7 +44,7 @@ cassandra + { local container = engine.container("query-triples") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "triples-query-cassandra", "-p", diff --git a/templates/components/chunker-recursive.jsonnet b/templates/components/chunker-recursive.jsonnet index 0b64b712..4a174366 100644 --- a/templates/components/chunker-recursive.jsonnet +++ b/templates/components/chunker-recursive.jsonnet @@ -14,7 +14,7 @@ local prompts = import "prompts/mixtral.jsonnet"; local container = engine.container("chunker") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "chunker-recursive", "-p", diff --git a/templates/components/claude-rag.jsonnet b/templates/components/claude-rag.jsonnet new file mode 100644 index 00000000..06d58db2 --- /dev/null +++ b/templates/components/claude-rag.jsonnet @@ -0,0 +1,63 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["claude-rag-" + key]:: value, + }, + + "claude-rag-model":: "claude-3-sonnet-20240229", + "claude-rag-max-output-tokens":: 4096, + "claude-rag-temperature":: 0.0, + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("claude-credentials") + .with_env_var("CLAUDE_KEY", "claude-key"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-claude", + "-p", + url.pulsar, + "-x", + std.toString($["claude-rag-max-output-tokens"]), + "-m", + $["claude-rag-model"], + "-t", + "%0.3f" % $["claude-rag-temperature"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/claude.jsonnet b/templates/components/claude.jsonnet index 00e4ec79..e43e7504 100644 --- a/templates/components/claude.jsonnet +++ b/templates/components/claude.jsonnet @@ -5,6 +5,12 @@ local prompts = import "prompts/mixtral.jsonnet"; { + with:: function(key, value) + self + { + ["claude-" + key]:: value, + }, + + "claude-model":: "claude-3-sonnet-20240229", "claude-max-output-tokens":: 4096, "claude-temperature":: 0.0, @@ -13,17 +19,19 @@ local prompts = import "prompts/mixtral.jsonnet"; create:: function(engine) local envSecrets = engine.envSecrets("claude-credentials") - .with_env_var("CLAUDE_KEY_TOKEN", "claude-key"); + .with_env_var("CLAUDE_KEY", "claude-key"); local container = engine.container("text-completion") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "text-completion-claude", "-p", url.pulsar, "-x", std.toString($["claude-max-output-tokens"]), + "-m", + $["claude-model"], "-t", "%0.3f" % $["claude-temperature"], ]) @@ -31,48 +39,18 @@ local prompts = import "prompts/mixtral.jsonnet"; .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerRag = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-claude", - "-p", - url.pulsar, - "-x", - std.toString($["claude-max-output-tokens"]), - "-t", - "%0.3f" % $["claude-temperature"], - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag", - ]) - .with_env_var_secrets(envSecrets) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - local containerSet = engine.containers( "text-completion", [ container ] ); - local containerSetRag = engine.containers( - "text-completion-rag", [ containerRag ] - ); - local service = engine.internalService(containerSet) .with_port(8000, 8000, "metrics"); - local serviceRag = - engine.internalService(containerSetRag) - .with_port(8000, 8000, "metrics"); - engine.resources([ envSecrets, containerSet, - containerSetRag, service, - serviceRag, ]) }, diff --git a/templates/components/cohere-rag.jsonnet b/templates/components/cohere-rag.jsonnet new file mode 100644 index 00000000..6a142519 --- /dev/null +++ b/templates/components/cohere-rag.jsonnet @@ -0,0 +1,56 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["cohere-rag-" + key]:: value, + }, + + "cohere-rag-temperature":: 0.0, + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("cohere-credentials") + .with_env_var("COHERE_KEY", "cohere-key"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-cohere", + "-p", + url.pulsar, + "-t", + "%0.3f" % $["cohere-rag-temperature"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/cohere.jsonnet b/templates/components/cohere.jsonnet index 5bc9b39c..093436fd 100644 --- a/templates/components/cohere.jsonnet +++ b/templates/components/cohere.jsonnet @@ -5,9 +5,10 @@ local prompts = import "prompts/mixtral.jsonnet"; { - // Override chunking - "chunk-size":: 150, - "chunk-overlap":: 10, + with:: function(key, value) + self + { + ["cohere-" + key]:: value, + }, "cohere-temperature":: 0.0, @@ -20,7 +21,7 @@ local prompts = import "prompts/mixtral.jsonnet"; local container = engine.container("text-completion") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "text-completion-cohere", "-p", @@ -31,45 +32,18 @@ local prompts = import "prompts/mixtral.jsonnet"; .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerRag = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-cohere", - "-p", - url.pulsar, - "-t", - "%0.3f" % $["cohere-temperature"], - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag", - ]) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - local containerSet = engine.containers( "text-completion", [ container ] ); - local containerSetRag = engine.containers( - "text-completion-rag", [ containerRag ] - ); - local service = engine.internalService(containerSet) .with_port(8000, 8000, "metrics"); - local serviceRag = - engine.internalService(containerSetRag) - .with_port(8000, 8000, "metrics"); - engine.resources([ envSecrets, containerSet, - containerSetRag, service, - serviceRag, ]) }, diff --git a/templates/components/document-rag.jsonnet b/templates/components/document-rag.jsonnet index 0a68dd52..2d9dda3d 100644 --- a/templates/components/document-rag.jsonnet +++ b/templates/components/document-rag.jsonnet @@ -5,17 +5,21 @@ local prompts = import "prompts/mixtral.jsonnet"; { + "document-rag-doc-limit":: 20, + "document-rag" +: { create:: function(engine) local container = engine.container("document-rag") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "document-rag", "-p", url.pulsar, + "--doc-limit", + std.toString($["document-rag-doc-limit"]), "--prompt-request-queue", "non-persistent://tg/request/prompt-rag", "--prompt-response-queue", @@ -39,5 +43,35 @@ local prompts = import "prompts/mixtral.jsonnet"; }, + "document-embeddings" +: { + + create:: function(engine) + + local container = + engine.container("document-embeddings") + .with_image(images.trustgraph_flow) + .with_command([ + "document-embeddings", + "-p", + url.pulsar, + ]) + .with_limits("1.0", "512M") + .with_reservations("0.5", "512M"); + + local containerSet = engine.containers( + "document-embeddings", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + containerSet, + service, + ]) + + }, + } diff --git a/templates/components/embeddings-fastembed.jsonnet b/templates/components/embeddings-fastembed.jsonnet new file mode 100644 index 00000000..c1fe35ff --- /dev/null +++ b/templates/components/embeddings-fastembed.jsonnet @@ -0,0 +1,43 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + "embeddings-model":: "sentence-transformers/all-MiniLM-L6-v2", + + embeddings +: { + + create:: function(engine) + + local container = + engine.container("embeddings") + .with_image(images.trustgraph_flow) + .with_command([ + "embeddings-fastembed", + "-p", + url.pulsar, + "-m", + $["embeddings-model"], + ]) + .with_limits("1.0", "400M") + .with_reservations("0.5", "400M"); + + local containerSet = engine.containers( + "embeddings", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + containerSet, + service, + ]) + + }, + +} + diff --git a/templates/components/embeddings-hf.jsonnet b/templates/components/embeddings-hf.jsonnet index b46feac7..29ebbc48 100644 --- a/templates/components/embeddings-hf.jsonnet +++ b/templates/components/embeddings-hf.jsonnet @@ -13,7 +13,7 @@ local prompts = import "prompts/mixtral.jsonnet"; local container = engine.container("embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_hf) .with_command([ "embeddings-hf", "-p", diff --git a/templates/components/embeddings-ollama.jsonnet b/templates/components/embeddings-ollama.jsonnet index 425a1c47..a26ad0ba 100644 --- a/templates/components/embeddings-ollama.jsonnet +++ b/templates/components/embeddings-ollama.jsonnet @@ -13,7 +13,7 @@ local url = import "values/url.jsonnet"; local container = engine.container("embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "embeddings-ollama", "-p", diff --git a/templates/components/falkordb.jsonnet b/templates/components/falkordb.jsonnet index e238cebe..c08896d3 100644 --- a/templates/components/falkordb.jsonnet +++ b/templates/components/falkordb.jsonnet @@ -13,7 +13,7 @@ falkordb + { local container = engine.container("store-triples") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "triples-write-falkordb", "-p", @@ -45,7 +45,7 @@ falkordb + { local container = engine.container("query-triples") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "triples-query-falkordb", "-p", diff --git a/templates/components/googleaistudio-rag.jsonnet b/templates/components/googleaistudio-rag.jsonnet new file mode 100644 index 00000000..332749e8 --- /dev/null +++ b/templates/components/googleaistudio-rag.jsonnet @@ -0,0 +1,65 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["googleaistudio-rag-" + key]:: value, + }, + + "googleaistudio-rag-max-output-tokens":: 4096, + "googleaistudio-rag-temperature":: 0.0, + "googleaistudio-rag-model":: "gemini-1.5-flash-002", + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("googleaistudio-credentials") + .with_env_var("GOOGLE_AI_STUDIO_KEY", "googleaistudio-key"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-googleaistudio", + "-p", + url.pulsar, + "-x", + std.toString( + $["googleaistudio-rag-max-output-tokens"] + ), + "-t", + "%0.3f" % $["googleaistudio-rag-temperature"], + "-m", + $["googleaistudio-rag-model"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/googleaistudio.jsonnet b/templates/components/googleaistudio.jsonnet index c2a40f2c..58c7807d 100644 --- a/templates/components/googleaistudio.jsonnet +++ b/templates/components/googleaistudio.jsonnet @@ -5,6 +5,11 @@ local prompts = import "prompts/mixtral.jsonnet"; { + with:: function(key, value) + self + { + ["googleaistudio-" + key]:: value, + }, + "googleaistudio-max-output-tokens":: 4096, "googleaistudio-temperature":: 0.0, "googleaistudio-model":: "gemini-1.5-flash-002", @@ -13,12 +18,12 @@ local prompts = import "prompts/mixtral.jsonnet"; create:: function(engine) - local envSecrets = engine.envSecrets("googleaistudio-key") + local envSecrets = engine.envSecrets("googleaistudio-credentials") .with_env_var("GOOGLE_AI_STUDIO_KEY", "googleaistudio-key"); local container = engine.container("text-completion") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "text-completion-googleaistudio", "-p", @@ -34,50 +39,18 @@ local prompts = import "prompts/mixtral.jsonnet"; .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerRag = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-googleaistudio", - "-p", - url.pulsar, - "-x", - std.toString($["googleaistudio-max-output-tokens"]), - "-t", - "%0.3f" % $["googleaistudio-temperature"], - "-m", - $["googleaistudio-model"], - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag", - ]) - .with_env_var_secrets(envSecrets) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - local containerSet = engine.containers( "text-completion", [ container ] ); - local containerSetRag = engine.containers( - "text-completion-rag", [ containerRag ] - ); - local service = engine.internalService(containerSet) .with_port(8000, 8000, "metrics"); - local serviceRag = - engine.internalService(containerSetRag) - .with_port(8000, 8000, "metrics"); - engine.resources([ envSecrets, containerSet, - containerSetRag, service, - serviceRag, ]) }, diff --git a/templates/components/graph-rag.jsonnet b/templates/components/graph-rag.jsonnet index 860152c9..8d3e2e38 100644 --- a/templates/components/graph-rag.jsonnet +++ b/templates/components/graph-rag.jsonnet @@ -6,7 +6,8 @@ local url = import "values/url.jsonnet"; "graph-rag-entity-limit":: 50, "graph-rag-triple-limit":: 30, - "graph-rag-max-subgraph-size":: 3000, + "graph-rag-max-subgraph-size":: 400, + "graph-rag-max-path-length":: 2, "kg-extract-definitions" +: { @@ -14,7 +15,7 @@ local url = import "values/url.jsonnet"; local container = engine.container("kg-extract-definitions") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "kg-extract-definitions", "-p", @@ -44,7 +45,7 @@ local url = import "values/url.jsonnet"; local container = engine.container("kg-extract-relationships") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "kg-extract-relationships", "-p", @@ -74,7 +75,7 @@ local url = import "values/url.jsonnet"; local container = engine.container("kg-extract-topics") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "kg-extract-topics", "-p", @@ -104,7 +105,7 @@ local url = import "values/url.jsonnet"; local container = engine.container("graph-rag") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "graph-rag", "-p", @@ -119,6 +120,8 @@ local url = import "values/url.jsonnet"; std.toString($["graph-rag-triple-limit"]), "--max-subgraph-size", std.toString($["graph-rag-max-subgraph-size"]), + "--max-path-length", + std.toString($["graph-rag-max-path-length"]), ]) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); @@ -138,5 +141,35 @@ local url = import "values/url.jsonnet"; }, + "graph-embeddings" +: { + + create:: function(engine) + + local container = + engine.container("graph-embeddings") + .with_image(images.trustgraph_flow) + .with_command([ + "graph-embeddings", + "-p", + url.pulsar, + ]) + .with_limits("1.0", "512M") + .with_reservations("0.5", "512M"); + + local containerSet = engine.containers( + "graph-embeddings", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + containerSet, + service, + ]) + + }, + } diff --git a/templates/components/librarian.jsonnet b/templates/components/librarian.jsonnet new file mode 100644 index 00000000..4df1b692 --- /dev/null +++ b/templates/components/librarian.jsonnet @@ -0,0 +1,43 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local minio = import "stores/minio.jsonnet"; +local cassandra = import "stores/cassandra.jsonnet"; + +{ + + "librarian" +: { + + create:: function(engine) + + local container = + engine.container("librarian") + .with_image(images.trustgraph_flow) + .with_command([ + "librarian", + "-p", + url.pulsar, + ]) + .with_limits("0.5", "256M") + .with_reservations("0.1", "256M"); + + local containerSet = engine.containers( + "librarian", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + containerSet, + service, + ]) + + }, + +} + + // Minio and Cassandra are used by the Librarian + + minio + cassandra + diff --git a/templates/components/llamafile-rag.jsonnet b/templates/components/llamafile-rag.jsonnet new file mode 100644 index 00000000..262f586e --- /dev/null +++ b/templates/components/llamafile-rag.jsonnet @@ -0,0 +1,57 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/slm.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["llamafile-rag-" + key]:: value, + }, + + "llamafile-rag-model":: "LLaMA_CPP", + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("llamafile-credentials") + .with_env_var("LLAMAFILE_URL", "llamafile-url"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-llamafile", + "-p", + url.pulsar, + "-m", + $["llamafile-rag-model"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8080, 8080, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/llamafile.jsonnet b/templates/components/llamafile.jsonnet index bc1a011c..f3e1efd3 100644 --- a/templates/components/llamafile.jsonnet +++ b/templates/components/llamafile.jsonnet @@ -5,6 +5,11 @@ local prompts = import "prompts/slm.jsonnet"; { + with:: function(key, value) + self + { + ["llamafile-" + key]:: value, + }, + "llamafile-model":: "LLaMA_CPP", "text-completion" +: { @@ -16,7 +21,7 @@ local prompts = import "prompts/slm.jsonnet"; local container = engine.container("text-completion") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "text-completion-llamafile", "-p", @@ -28,46 +33,18 @@ local prompts = import "prompts/slm.jsonnet"; .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerRag = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-llamafile", - "-p", - url.pulsar, - "-m", - $["llamafile-model"], - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag", - ]) - .with_env_var_secrets(envSecrets) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - local containerSet = engine.containers( "text-completion", [ container ] ); - local containerSetRag = engine.containers( - "text-completion-rag", [ containerRag ] - ); - local service = engine.internalService(containerSet) .with_port(8080, 8080, "metrics"); - local serviceRag = - engine.internalService(containerSetRag) - .with_port(8080, 8080, "metrics"); - engine.resources([ envSecrets, containerSet, - containerSetRag, service, - serviceRag, ]) }, diff --git a/templates/components/memgraph.jsonnet b/templates/components/memgraph.jsonnet index 609da3a2..21684a61 100644 --- a/templates/components/memgraph.jsonnet +++ b/templates/components/memgraph.jsonnet @@ -14,7 +14,7 @@ memgraph + { local container = engine.container("store-triples") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "triples-write-memgraph", "-p", @@ -48,7 +48,7 @@ memgraph + { local container = engine.container("query-triples") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "triples-query-memgraph", "-p", diff --git a/templates/components/milvus.jsonnet b/templates/components/milvus.jsonnet index b3044f98..27e5e316 100644 --- a/templates/components/milvus.jsonnet +++ b/templates/components/milvus.jsonnet @@ -12,7 +12,7 @@ milvus + { local container = engine.container("store-graph-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "ge-write-milvus", "-p", @@ -44,7 +44,7 @@ milvus + { local container = engine.container("query-graph-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "ge-query-milvus", "-p", @@ -76,7 +76,7 @@ milvus + { local container = engine.container("store-doc-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "de-write-milvus", "-p", @@ -108,7 +108,7 @@ milvus + { local container = engine.container("query-doc-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "de-query-milvus", "-p", diff --git a/templates/components/mistral-ocr.jsonnet b/templates/components/mistral-ocr.jsonnet new file mode 100644 index 00000000..8049c514 --- /dev/null +++ b/templates/components/mistral-ocr.jsonnet @@ -0,0 +1,47 @@ +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["mistral-" + key]:: value, + }, + + "pdf-decoder" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("mistral-credentials") + .with_env_var("MISTRAL_TOKEN", "mistral-token"); + + local container = + engine.container("mistral-ocr") + .with_image(images.trustgraph_flow) + .with_command([ + "pdf-ocr-mistral", + "-p", + url.pulsar, + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSet = engine.containers( + "mistral-ocr", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8080, 8080, "metrics"); + + engine.resources([ + envSecrets, + containerSet, + service, + ]) + + }, + +} + prompts + diff --git a/templates/components/mistral-rag.jsonnet b/templates/components/mistral-rag.jsonnet new file mode 100644 index 00000000..12fbe8a5 --- /dev/null +++ b/templates/components/mistral-rag.jsonnet @@ -0,0 +1,63 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["mistral-rag-" + key]:: value, + }, + + "mistral-rag-max-output-tokens":: 4096, + "mistral-rag-temperature":: 0.0, + "mistral-rag-model":: "ministral-8b-latest", + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("mistral-credentials") + .with_env_var("MISTRAL_TOKEN", "mistral-token"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-mistral", + "-p", + url.pulsar, + "-x", + std.toString($["mistral-rag-max-output-tokens"]), + "-t", + "%0.3f" % $["mistral-rag-temperature"], + "-m", + $["mistral-rag-model"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8080, 8080, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/mistral.jsonnet b/templates/components/mistral.jsonnet new file mode 100644 index 00000000..4de332c9 --- /dev/null +++ b/templates/components/mistral.jsonnet @@ -0,0 +1,59 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["mistral-" + key]:: value, + }, + + "mistral-max-output-tokens":: 4096, + "mistral-temperature":: 0.0, + "mistral-model":: "ministral-8b-latest", + + "text-completion" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("mistral-credentials") + .with_env_var("MISTRAL_TOKEN", "mistral-token"); + + local container = + engine.container("text-completion") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-mistral", + "-p", + url.pulsar, + "-x", + std.toString($["mistral-max-output-tokens"]), + "-t", + "%0.3f" % $["mistral-temperature"], + "-m", + $["mistral-model"], + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSet = engine.containers( + "text-completion", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8080, 8080, "metrics"); + + engine.resources([ + envSecrets, + containerSet, + service, + ]) + + }, + +} + prompts + diff --git a/templates/components/neo4j.jsonnet b/templates/components/neo4j.jsonnet index b70562fe..7cebdc71 100644 --- a/templates/components/neo4j.jsonnet +++ b/templates/components/neo4j.jsonnet @@ -13,7 +13,7 @@ neo4j + { local container = engine.container("store-triples") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "triples-write-neo4j", "-p", @@ -45,7 +45,7 @@ neo4j + { local container = engine.container("query-triples") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "triples-query-neo4j", "-p", diff --git a/templates/components/ocr.jsonnet b/templates/components/ocr.jsonnet new file mode 100644 index 00000000..4353b7f9 --- /dev/null +++ b/templates/components/ocr.jsonnet @@ -0,0 +1,38 @@ +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; + +{ + + "pdf-decoder" +: { + + create:: function(engine) + + local container = + engine.container("pdf-ocr") + .with_image(images.trustgraph_ocr) + .with_command([ + "pdf-ocr", + "-p", + url.pulsar, + ]) + .with_limits("1.0", "512M") + .with_reservations("0.1", "512M"); + + local containerSet = engine.containers( + "pdf-ocr", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8080, 8080, "metrics"); + + engine.resources([ + envSecrets, + containerSet, + service, + ]) + + }, + +} + prompts + diff --git a/templates/components/ollama-rag.jsonnet b/templates/components/ollama-rag.jsonnet new file mode 100644 index 00000000..680adea5 --- /dev/null +++ b/templates/components/ollama-rag.jsonnet @@ -0,0 +1,57 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["ollama-rag-" + key]:: value, + }, + + "ollama-rag-model":: "gemma2:9b", + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("ollama-credentials") + .with_env_var("OLLAMA_HOST", "ollama-host"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-ollama", + "-p", + url.pulsar, + "-m", + $["ollama-rag-model"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8080, 8080, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/ollama.jsonnet b/templates/components/ollama.jsonnet index 8da00848..95f1abf0 100644 --- a/templates/components/ollama.jsonnet +++ b/templates/components/ollama.jsonnet @@ -5,6 +5,11 @@ local prompts = import "prompts/mixtral.jsonnet"; { + with:: function(key, value) + self + { + ["ollama-" + key]:: value, + }, + "ollama-model":: "gemma2:9b", "text-completion" +: { @@ -16,7 +21,7 @@ local prompts = import "prompts/mixtral.jsonnet"; local container = engine.container("text-completion") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "text-completion-ollama", "-p", @@ -28,46 +33,18 @@ local prompts = import "prompts/mixtral.jsonnet"; .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerRag = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-ollama", - "-p", - url.pulsar, - "-m", - $["ollama-model"], - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag", - ]) - .with_env_var_secrets(envSecrets) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - local containerSet = engine.containers( "text-completion", [ container ] ); - local containerSetRag = engine.containers( - "text-completion-rag", [ containerRag ] - ); - local service = engine.internalService(containerSet) .with_port(8080, 8080, "metrics"); - local serviceRag = - engine.internalService(containerSetRag) - .with_port(8080, 8080, "metrics"); - engine.resources([ envSecrets, containerSet, - containerSetRag, service, - serviceRag, ]) }, diff --git a/templates/components/openai-rag.jsonnet b/templates/components/openai-rag.jsonnet new file mode 100644 index 00000000..bfb7dd98 --- /dev/null +++ b/templates/components/openai-rag.jsonnet @@ -0,0 +1,63 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["openai-rag-" + key]:: value, + }, + + "openai-rag-max-output-tokens":: 4096, + "openai-rag-temperature":: 0.0, + "openai-rag-model":: "GPT-3.5-Turbo", + + "text-completion-rag" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("openai-credentials") + .with_env_var("OPENAI_TOKEN", "openai-token"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph_flow) + .with_command([ + "text-completion-openai", + "-p", + url.pulsar, + "-x", + std.toString($["openai-rag-max-output-tokens"]), + "-t", + "%0.3f" % $["openai-rag-temperature"], + "-m", + $["openai-rag-model"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8080, 8080, "metrics"); + + engine.resources([ + envSecrets, + containerSetRag, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/openai.jsonnet b/templates/components/openai.jsonnet index 27725cb6..9e0212d2 100644 --- a/templates/components/openai.jsonnet +++ b/templates/components/openai.jsonnet @@ -5,6 +5,11 @@ local prompts = import "prompts/mixtral.jsonnet"; { + with:: function(key, value) + self + { + ["openai-" + key]:: value, + }, + "openai-max-output-tokens":: 4096, "openai-temperature":: 0.0, "openai-model":: "GPT-3.5-Turbo", @@ -18,7 +23,7 @@ local prompts = import "prompts/mixtral.jsonnet"; local container = engine.container("text-completion") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "text-completion-openai", "-p", @@ -34,50 +39,18 @@ local prompts = import "prompts/mixtral.jsonnet"; .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerRag = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-openai", - "-p", - url.pulsar, - "-x", - std.toString($["openai-max-output-tokens"]), - "-t", - "%0.3f" % $["openai-temperature"], - "-m", - $["openai-model"], - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag", - ]) - .with_env_var_secrets(envSecrets) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - local containerSet = engine.containers( "text-completion", [ container ] ); - local containerSetRag = engine.containers( - "text-completion-rag", [ containerRag ] - ); - local service = engine.internalService(containerSet) .with_port(8080, 8080, "metrics"); - local serviceRag = - engine.internalService(containerSetRag) - .with_port(8080, 8080, "metrics"); - engine.resources([ envSecrets, containerSet, - containerSetRag, service, - serviceRag, ]) }, diff --git a/templates/components/pinecone.jsonnet b/templates/components/pinecone.jsonnet index 3422952a..ede383a5 100644 --- a/templates/components/pinecone.jsonnet +++ b/templates/components/pinecone.jsonnet @@ -17,7 +17,7 @@ local cassandra_hosts = "cassandra"; local container = engine.container("store-graph-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "ge-write-pinecone", "-p", @@ -52,7 +52,7 @@ local cassandra_hosts = "cassandra"; local container = engine.container("query-graph-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "ge-query-pinecone", "-p", @@ -87,7 +87,7 @@ local cassandra_hosts = "cassandra"; local container = engine.container("store-doc-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "de-write-pinecone", "-p", @@ -122,7 +122,7 @@ local cassandra_hosts = "cassandra"; local container = engine.container("query-doc-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "de-query-pinecone", "-p", diff --git a/templates/components/prompt-overrides.jsonnet b/templates/components/prompt-overrides.jsonnet index 648e5b66..852ec09d 100644 --- a/templates/components/prompt-overrides.jsonnet +++ b/templates/components/prompt-overrides.jsonnet @@ -1,7 +1,3 @@ -local base = import "base/base.jsonnet"; -local images = import "values/images.jsonnet"; -local url = import "values/url.jsonnet"; -local prompts = import "prompts/mixtral.jsonnet"; local default_prompts = import "prompts/default-prompts.jsonnet"; { diff --git a/templates/components/prompt-template.jsonnet b/templates/components/prompt-template.jsonnet index 3dadf337..b3187c9b 100644 --- a/templates/components/prompt-template.jsonnet +++ b/templates/components/prompt-template.jsonnet @@ -44,7 +44,7 @@ local default_prompts = import "prompts/default-prompts.jsonnet"; local container = engine.container("prompt") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "prompt-template", "-p", @@ -84,7 +84,7 @@ local default_prompts = import "prompts/default-prompts.jsonnet"; local container = engine.container("prompt-rag") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "prompt-template", "-p", diff --git a/templates/components/pulsar.jsonnet b/templates/components/pulsar.jsonnet index 0342b4d5..d111f616 100644 --- a/templates/components/pulsar.jsonnet +++ b/templates/components/pulsar.jsonnet @@ -2,32 +2,114 @@ local base = import "base/base.jsonnet"; local images = import "values/images.jsonnet"; local url = import "values/url.jsonnet"; +// This is a Pulsar configuration. Non-standalone mode so we deploy +// individual components: bookkeeper, broker and zookeeper. +// +// This also deploys the TrustGraph 'admin' container which initialises +// TrustGraph-specific namespaces etc. + { "pulsar" +: { create:: function(engine) -// local confVolume = engine.volume("pulsar-conf").with_size("2G"); - local dataVolume = engine.volume("pulsar-data").with_size("20G"); + // Zookeeper volume + local zkVolume = engine.volume("zookeeper").with_size("1G"); - local container = + // Zookeeper container + local zkContainer = + engine.container("zookeeper") + .with_image(images.pulsar) + .with_command([ + "bash", + "-c", + "bin/apply-config-from-env.py conf/zookeeper.conf && bin/generate-zookeeper-config.sh conf/zookeeper.conf && exec bin/pulsar zookeeper" + ]) + .with_limits("1", "400M") + .with_reservations("0.05", "400M") + .with_user("0:1000") + .with_volume_mount(zkVolume, "/pulsar/data/zookeeper") + .with_environment({ + "metadataStoreUrl": "zk:zookeeper:2181", + "PULSAR_MEM": "-Xms256m -Xmx256m -XX:MaxDirectMemorySize=256m", + }) + .with_port(2181, 2181, "zookeeper") + .with_port(2888, 2888, "zookeeper2") + .with_port(3888, 3888, "zookeeper3"); + + // Pulsar cluster init container + local initContainer = + engine.container("pulsar-init") + .with_image(images.pulsar) + .with_command([ + "bash", + "-c", + "sleep 10 && bin/pulsar initialize-cluster-metadata --cluster cluster-a --zookeeper zookeeper:2181 --configuration-store zookeeper:2181 --web-service-url http://pulsar:8080 --broker-service-url pulsar://pulsar:6650", + ]) + .with_limits("1", "512M") + .with_reservations("0.05", "512M") + .with_environment({ + "PULSAR_MEM": "-Xms256m -Xmx256m -XX:MaxDirectMemorySize=256m", + }); + + + // Bookkeeper volume + local bookieVolume = engine.volume("bookie").with_size("20G"); + + // Bookkeeper container + local bookieContainer = + engine.container("bookie") + .with_image(images.pulsar) + .with_command([ + "bash", + "-c", + "bin/apply-config-from-env.py conf/bookkeeper.conf && exec bin/pulsar bookie" + // false ^ causes this to be a 'failure' exit. + ]) + .with_limits("1", "800M") + .with_reservations("0.1", "800M") + .with_user("0:1000") + .with_volume_mount(bookieVolume, "/pulsar/data/bookkeeper") + .with_environment({ + "clusterName": "cluster-a", + "zkServers": "zookeeper:2181", + "bookieId": "bookie", + "metadataStoreUri": "metadata-store:zk:zookeeper:2181", + "advertisedAddress": "bookie", + "BOOKIE_MEM": "-Xms512m -Xmx512m -XX:MaxDirectMemorySize=256m", + }) + .with_port(3181, 3181, "bookie"); + + // Pulsar broker, stateless (uses ZK and Bookkeeper for state) + local brokerContainer = engine.container("pulsar") .with_image(images.pulsar) - .with_command(["bin/pulsar", "standalone"]) + .with_command([ + "bash", + "-c", + "bin/apply-config-from-env.py conf/broker.conf && exec bin/pulsar broker" + ]) + .with_limits("1", "800M") + .with_reservations("0.1", "800M") .with_environment({ - "PULSAR_MEM": "-Xms600M -Xmx600M" + "metadataStoreUrl": "zk:zookeeper:2181", + "zookeeperServers": "zookeeper:2181", + "clusterName": "cluster-a", + "managedLedgerDefaultEnsembleSize": "1", + "managedLedgerDefaultWriteQuorum": "1", + "managedLedgerDefaultAckQuorum": "1", + "advertisedAddress": "pulsar", + "advertisedListeners": "external:pulsar://pulsar:6650,localhost:pulsar://localhost:6650", + "PULSAR_MEM": "-Xms512m -Xmx512m -XX:MaxDirectMemorySize=256m", }) - .with_limits("2.0", "1500M") - .with_reservations("1.0", "1500M") -// .with_volume_mount(confVolume, "/pulsar/conf") - .with_volume_mount(dataVolume, "/pulsar/data") - .with_port(6650, 6650, "bookie") - .with_port(8080, 8080, "http"); + .with_port(6650, 6650, "pulsar") + .with_port(8080, 8080, "admin"); + // Trustgraph Pulsar initialisation local adminContainer = - engine.container("init-pulsar") - .with_image(images.trustgraph) + engine.container("init-trustgraph") + .with_image(images.trustgraph_flow) .with_command([ "tg-init-pulsar", "-p", @@ -36,10 +118,32 @@ local url = import "values/url.jsonnet"; .with_limits("1", "128M") .with_reservations("0.1", "128M"); - local containerSet = engine.containers( + // Container sets + local zkContainerSet = engine.containers( + "zookeeper", + [ + zkContainer, + ] + ); + + local initContainerSet = engine.containers( + "init-pulsar", + [ + initContainer, + ] + ); + + local bookieContainerSet = engine.containers( + "bookie", + [ + bookieContainer, + ] + ); + + local brokerContainerSet = engine.containers( "pulsar", [ - container + brokerContainer, ] ); @@ -50,17 +154,35 @@ local url = import "values/url.jsonnet"; ] ); - local service = - engine.service(containerSet) - .with_port(6650, 6650, "bookie") - .with_port(8080, 8080, "http"); + // Zookeeper service + local zkService = + engine.service(zkContainerSet) + .with_port(2181, 2181, "zookeeper") + .with_port(2888, 2888, "zookeeper2") + .with_port(3888, 3888, "zookeeper3"); + + // Bookkeeper service + local bookieService = + engine.service(bookieContainerSet) + .with_port(3181, 3181, "bookie"); + + // Pulsar broker service + local brokerService = + engine.service(brokerContainerSet) + .with_port(6650, 6650, "pulsar") + .with_port(8080, 8080, "admin"); engine.resources([ -// confVolume, - dataVolume, - containerSet, + zkVolume, + bookieVolume, + zkContainerSet, + initContainerSet, + bookieContainerSet, + brokerContainerSet, adminContainerSet, - service, + zkService, + bookieService, + brokerService, ]) } diff --git a/templates/components/qdrant.jsonnet b/templates/components/qdrant.jsonnet index f923e84f..352cb741 100644 --- a/templates/components/qdrant.jsonnet +++ b/templates/components/qdrant.jsonnet @@ -12,7 +12,7 @@ qdrant + { local container = engine.container("store-graph-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "ge-write-qdrant", "-p", @@ -44,7 +44,7 @@ qdrant + { local container = engine.container("query-graph-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "ge-query-qdrant", "-p", @@ -76,7 +76,7 @@ qdrant + { local container = engine.container("store-doc-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "de-write-qdrant", "-p", @@ -108,7 +108,7 @@ qdrant + { local container = engine.container("query-doc-embeddings") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "de-query-qdrant", "-p", diff --git a/templates/components/trustgraph.jsonnet b/templates/components/trustgraph.jsonnet index 31ae420e..833d932b 100644 --- a/templates/components/trustgraph.jsonnet +++ b/templates/components/trustgraph.jsonnet @@ -1,7 +1,6 @@ local base = import "base/base.jsonnet"; local images = import "values/images.jsonnet"; local url = import "values/url.jsonnet"; -local prompt = import "prompt-template.jsonnet"; { @@ -22,7 +21,7 @@ local prompt = import "prompt-template.jsonnet"; local container = engine.container("api-gateway") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "api-gateway", "-p", @@ -61,7 +60,7 @@ local prompt = import "prompt-template.jsonnet"; local container = engine.container("chunker") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "chunker-token", "-p", @@ -95,7 +94,7 @@ local prompt = import "prompt-template.jsonnet"; local container = engine.container("pdf-decoder") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "pdf-decoder", "-p", @@ -119,43 +118,13 @@ local prompt = import "prompt-template.jsonnet"; }, - "vectorize" +: { - - create:: function(engine) - - local container = - engine.container("vectorize") - .with_image(images.trustgraph) - .with_command([ - "embeddings-vectorize", - "-p", - url.pulsar, - ]) - .with_limits("1.0", "512M") - .with_reservations("0.5", "512M"); - - local containerSet = engine.containers( - "vectorize", [ container ] - ); - - local service = - engine.internalService(containerSet) - .with_port(8000, 8000, "metrics"); - - engine.resources([ - containerSet, - service, - ]) - - }, - "metering" +: { create:: function(engine) local container = engine.container("metering") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "metering", "-p", @@ -185,7 +154,7 @@ local prompt = import "prompt-template.jsonnet"; local container = engine.container("metering-rag") - .with_image(images.trustgraph) + .with_image(images.trustgraph_flow) .with_command([ "metering", "-p", @@ -211,5 +180,5 @@ local prompt = import "prompt-template.jsonnet"; }, -} + prompt +} diff --git a/templates/components/vertexai-rag.jsonnet b/templates/components/vertexai-rag.jsonnet new file mode 100644 index 00000000..0b5cf9a3 --- /dev/null +++ b/templates/components/vertexai-rag.jsonnet @@ -0,0 +1,74 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + with:: function(key, value) + self + { + ["vertexai-rag-" + key]:: value, + }, + + "vertexai-rag-model":: "gemini-1.0-pro-001", + "vertexai-rag-private-key":: "/vertexai/private.json", + "vertexai-rag-region":: "us-central1", + "vertexai-rag-max-output-tokens":: 4096, + "vertexai-rag-temperature":: 0.0, + + "text-completion-rag" +: { + + create:: function(engine) + + local cfgVol = engine.secretVolume( + "vertexai-creds", + "./vertexai", + { + "private.json": importstr "vertexai/private.json", + } + ); + + local container = + engine.container("text-completion-rag") + .with_image(images.trustgraph_vertexai) + .with_command([ + "text-completion-vertexai", + "-p", + url.pulsar, + "-k", + $["vertexai-rag-private-key"], + "-r", + $["vertexai-rag-region"], + "-x", + std.toString($["vertexai-rag-max-output-tokens"]), + "-t", + "%0.3f" % $["vertexai-rag-temperature"], + "-m", + $["vertexai-rag-model"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag", + ]) + .with_limits("0.5", "256M") + .with_reservations("0.1", "256M") + .with_volume_mount(cfgVol, "/vertexai"); + + local containerSet = engine.containers( + "text-completion-rag", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + cfgVol, + containerSet, + service, + ]) + + } + +} + prompts + diff --git a/templates/components/vertexai.jsonnet b/templates/components/vertexai.jsonnet index ef193156..0e3550c5 100644 --- a/templates/components/vertexai.jsonnet +++ b/templates/components/vertexai.jsonnet @@ -5,6 +5,11 @@ local prompts = import "prompts/mixtral.jsonnet"; { + with:: function(key, value) + self + { + ["vertexai-" + key]:: value, + }, + "vertexai-model":: "gemini-1.0-pro-001", "vertexai-private-key":: "/vertexai/private.json", "vertexai-region":: "us-central1", @@ -25,7 +30,7 @@ local prompts = import "prompts/mixtral.jsonnet"; local container = engine.container("text-completion") - .with_image(images.trustgraph) + .with_image(images.trustgraph_vertexai) .with_command([ "text-completion-vertexai", "-p", @@ -61,59 +66,5 @@ local prompts = import "prompts/mixtral.jsonnet"; }, - "text-completion-rag" +: { - - create:: function(engine) - - local cfgVol = engine.secretVolume( - "vertexai-creds", - "./vertexai", - { - "private.json": importstr "vertexai/private.json", - } - ); - - local container = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-vertexai", - "-p", - url.pulsar, - "-k", - $["vertexai-private-key"], - "-r", - $["vertexai-region"], - "-x", - std.toString($["vertexai-max-output-tokens"]), - "-t", - "%0.3f" % $["vertexai-temperature"], - "-m", - $["vertexai-model"], - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag", - ]) - .with_limits("0.5", "256M") - .with_reservations("0.1", "256M") - .with_volume_mount(cfgVol, "/vertexai"); - - local containerSet = engine.containers( - "text-completion-rag", [ container ] - ); - - local service = - engine.internalService(containerSet) - .with_port(8000, 8000, "metrics"); - - engine.resources([ - cfgVol, - containerSet, - service, - ]) - - } - } + prompts diff --git a/templates/config-to-aks-k8s.jsonnet b/templates/config-to-aks-k8s.jsonnet new file mode 100644 index 00000000..c603a0d8 --- /dev/null +++ b/templates/config-to-aks-k8s.jsonnet @@ -0,0 +1,16 @@ + +local engine = import "engine/aks-k8s.jsonnet"; +local decode = import "util/decode-config.jsonnet"; +local components = import "components.jsonnet"; + +// Import config +local config = import "config.json"; + +// Produce patterns from config +local patterns = decode(config); + +// Extract resources usnig the engine +local resourceList = engine.package(patterns); + +resourceList + diff --git a/templates/engine/aks-k8s.jsonnet b/templates/engine/aks-k8s.jsonnet new file mode 100644 index 00000000..69bca03f --- /dev/null +++ b/templates/engine/aks-k8s.jsonnet @@ -0,0 +1,45 @@ + +local k8s = import "k8s.jsonnet"; + +local ns = { + apiVersion: "v1", + kind: "Namespace", + metadata: { + name: "trustgraph", + }, + "spec": { + }, +}; + +local sc = { + apiVersion: "storage.k8s.io/v1", + kind: "StorageClass", + metadata: { + name: "tg", + }, + provisioner: "disk.csi.azure.com", + parameters: { + // Standard disks (spinning magnetic), Locally Redundant Storage + // Cheapest, basically + skuName: "Standard_LRS", + }, + reclaimPolicy: "Delete", + volumeBindingMode: "WaitForFirstConsumer", +}; + +k8s + { + + // Extract resources usnig the engine + package:: function(patterns) + local resources = [sc, ns] + std.flattenArrays([ + p.create(self) for p in std.objectValues(patterns) + ]); + local resourceList = { + apiVersion: "v1", + kind: "List", + items: [ns, sc] + resources, + }; + resourceList + +} + diff --git a/templates/engine/docker-compose.jsonnet b/templates/engine/docker-compose.jsonnet index c37f1df0..0be3c3e3 100644 --- a/templates/engine/docker-compose.jsonnet +++ b/templates/engine/docker-compose.jsonnet @@ -22,6 +22,8 @@ with_image:: function(x) self + { image: x }, + with_user:: function(x) self + { user: x }, + with_command:: function(x) self + { command: x }, with_environment:: function(x) self + { @@ -75,6 +77,10 @@ { command: container.command } else {}) + + (if std.objectHas(container, "user") then + { user: container.user } + else {}) + + (if ! std.isEmpty(container.environment) then { environment: container.environment } else {}) + diff --git a/templates/engine/k8s.jsonnet b/templates/engine/k8s.jsonnet index 2fec0d1f..dfd8d11f 100644 --- a/templates/engine/k8s.jsonnet +++ b/templates/engine/k8s.jsonnet @@ -14,6 +14,8 @@ with_image:: function(x) self + { image: x }, + with_user:: function(x) self + { user: x }, + with_command:: function(x) self + { command: x }, with_environment:: function(x) self + { diff --git a/templates/generate b/templates/generate index e8772483..2640a125 100755 --- a/templates/generate +++ b/templates/generate @@ -7,6 +7,8 @@ import logging import os import sys import zipfile +import pathlib +from io import BytesIO logger = logging.getLogger("generate") logging.basicConfig(level=logging.INFO, format='%(message)s') @@ -15,9 +17,13 @@ private_json = "Put your GCP private.json here" class Generator: - def __init__(self, config, base="./templates/", version="0.0.0"): + def __init__( + self, config, templates="./templates/", resources="./resources", + version="0.0.0", + ): - self.jsonnet_base = base + self.templates = pathlib.Path(templates) + self.resources = pathlib.Path(resources) self.config = config self.version = f"\"{version}\"".encode("utf-8") @@ -34,25 +40,30 @@ class Generator: path = os.path.join(".", dir, filename) return str(path), self.config - if filename == "version.jsonnet" and dir == "./templates/values/": + if filename == "version.jsonnet" and dir == "templates/values/": path = os.path.join(".", dir, filename) return str(path), self.version if dir: candidates = [ - os.path.join(".", dir, filename), - os.path.join(".", filename) + self.templates.joinpath(dir, filename), + self.templates.joinpath(filename), + self.resources.joinpath(dir, filename), + self.resources.joinpath(filename), + pathlib.Path(dir).joinpath(filename), ] else: candidates = [ - os.path.join(".", filename) + self.templates.joinpath(filename), + pathlib.Path(dir).joinpath(filename), + pathlib.Path(filename), ] try: if filename == "vertexai/private.json": - return candidates[0], private_json.encode("utf-8") + return str(candidates[0]), private_json.encode("utf-8") for c in candidates: logger.debug("Try: %s", c) @@ -68,73 +79,157 @@ class Generator: except: - path = os.path.join(self.jsonnet_base, filename) + path = os.path.join(self.templates, filename) logger.debug("Try: %s", path) with open(path, "rb") as f: logger.debug("Loaded: %s", path) return str(path), f.read() +class Packager: + + def __init__(self): + self.templates = pathlib.Path("./templates") + self.resources = pathlib.Path("./") + + def process( + self, config, version="0.0.0", platform="docker-compose", + ): + + config = config.encode("utf-8") + + gen = Generator( + config, templates=self.templates, resources=self.resources, + version=version + ) + + path = self.templates.joinpath( + f"config-to-{platform}.jsonnet" + ) + wrapper = path.read_text() + + processed = gen.process(wrapper) + + return processed + + def generate(self, config, version, platform): + + logger.info(f"Generating for platform={platform} version={version}") + + try: + + if platform in set(["docker-compose", "podman-compose"]): + return self.generate_docker_compose( + "docker-compose", version, config + ) + elif platform in set(["minikube-k8s", "gcp-k8s", "aks-k8s"]): + return self.generate_k8s( + platform, version, config + ) + else: + raise RuntimeError("Bad configuration") + + except Exception as e: + logging.error(f"Exception: {e}") + raise e + + def generate_docker_compose(self, platform, version, config): + + processed = self.process( + config, platform=platform, version=version + ) + + y = yaml.dump(processed) + + mem = BytesIO() + + with zipfile.ZipFile(mem, mode='w') as out: + + def output(name, content): + logger.info(f"Adding {name}...") + out.writestr(name, content) + + fname = "docker-compose.yaml" + + output(fname, y) + + # Grafana config + path = self.resources.joinpath( + "grafana/dashboards/dashboard.json" + ) + res = path.read_text() + output("grafana/dashboards/dashboard.json", res) + + path = self.resources.joinpath( + "grafana/provisioning/dashboard.yml" + ) + res = path.read_text() + output("grafana/provisioning/dashboard.yml", res) + + path = self.resources.joinpath( + "grafana/provisioning/datasource.yml" + ) + res = path.read_text() + output("grafana/provisioning/datasource.yml", res) + + # Prometheus config + path = self.resources.joinpath( + "prometheus/prometheus.yml" + ) + res = path.read_text() + output("prometheus/prometheus.yml", res) + + logger.info("Generation complete.") + + return mem.getvalue() + + def generate_k8s(self, platform, version, config): + + processed = self.process( + config, platform=platform, version=version + ) + + y = yaml.dump(processed) + + mem = BytesIO() + + with zipfile.ZipFile(mem, mode='w') as out: + + def output(name, content): + logger.info(f"Adding {name}...") + out.writestr(name, content) + + fname = "resources.yaml" + + output(fname, y) + + logger.info("Generation complete.") + + return mem.getvalue() + def main(): - if len(sys.argv) != 3: + if len(sys.argv) != 4: print() print("Usage:") - print(" generate < input.json") + print(" generate < input.json") print() sys.exit(1) outfile = sys.argv[1] version = sys.argv[2] + platform = sys.argv[3] cfg = sys.stdin.read() - cfg = json.loads(cfg) logger.info(f"Outputting to {outfile}...") - with zipfile.ZipFile(outfile, mode='w') as out: + p = Packager() + resp = p.generate(cfg, version, platform) - def output(name, content): - logger.info(f"Adding {name}...") - out.writestr(name, content) + with open(outfile, "wb") as f: + f.write(resp) - fname = "tg-launch.yaml" - - platform = "docker-compose" - - with open(f"./templates/config-to-{platform}.jsonnet", "r") as f: - wrapper = f.read() - - gen = Generator(json.dumps(cfg).encode("utf-8"), version=version) - - processed = gen.process(wrapper) - - y = yaml.dump(processed) - - output(fname, y) - - # Placeholder for the private.json file. Won't put actual credentials - # here. - output("docker-compose/vertexai/private.json", private_json) - - # Grafana config - with open("grafana/dashboards/dashboard.json") as f: - output( - "docker-compose/grafana/dashboards/dashboard.json", f.read() - ) - - with open("grafana/provisioning/dashboard.yml") as f: - output( - "docker-compose/grafana/provisioning/dashboard.yml", f.read() - ) - - with open("grafana/provisioning/datasource.yml") as f: - output( - "docker-compose/grafana/provisioning/datasource.yml", f.read() - ) - - # Prometheus config - with open("prometheus/prometheus.yml") as f: - output("docker-compose/prometheus/prometheus.yml", f.read()) + return main() diff --git a/templates/generate-all b/templates/generate-all index 1ae71402..fb1fe917 100755 --- a/templates/generate-all +++ b/templates/generate-all @@ -87,8 +87,18 @@ def full_config_object( ): return config_object([ - graph_store, "pulsar", vector_store, embeddings, - "graph-rag", "grafana", "trustgraph", llm, "workbench-ui", + "triple-store-" + graph_store, + "pulsar", + "vector-store-" + vector_store, + embeddings, + "graph-rag", + "grafana", + "trustgraph-base", + llm, + llm + "-rag", + "workbench-ui", + "prompt-template", + "agent-manager-react", ]) def generate_config( @@ -124,7 +134,7 @@ def generate_all(output, version): ]: for model in [ # "azure", "azure-openai", "bedrock", "claude", "cohere", - # "googleaistudio", "llamafile", + # "googleaistudio", "llamafile", "mistral", "ollama", # "openai", "vertexai", ]: diff --git a/templates/patterns/llm-mistral.jsonnet b/templates/patterns/llm-mistral.jsonnet new file mode 100644 index 00000000..11f6de22 --- /dev/null +++ b/templates/patterns/llm-mistral.jsonnet @@ -0,0 +1,32 @@ +{ + pattern: { + name: "mistral", + icon: "🤖💬", + title: "Add Mistral LLM endpoint for text completion", + description: "This pattern integrates a Mistral LLM service for text completion operations. You need a Mistral subscription and have an API key to be able to use this service.", + requires: ["pulsar", "trustgraph"], + features: ["llm"], + args: [ + { + name: "mistral-max-output-tokens", + label: "Maximum output tokens", + type: "integer", + description: "Limit on number tokens to generate", + default: 4096, + required: true, + }, + { + name: "mistral-temperature", + label: "Temperature", + type: "slider", + description: "Controlling predictability / creativity balance", + min: 0, + max: 1, + step: 0.05, + default: 0.5, + }, + ], + category: [ "llm" ], + }, + module: "components/mistral.jsonnet", +} diff --git a/templates/stores/cassandra.jsonnet b/templates/stores/cassandra.jsonnet index 0c90421e..2a9d6d7a 100644 --- a/templates/stores/cassandra.jsonnet +++ b/templates/stores/cassandra.jsonnet @@ -13,7 +13,7 @@ local images = import "values/images.jsonnet"; engine.container("cassandra") .with_image(images.cassandra) .with_environment({ - JVM_OPTS: "-Xms300M -Xmx300M", + JVM_OPTS: "-Xms300M -Xmx300M -Dcassandra.skip_wait_for_gossip_to_settle=0", }) .with_limits("1.0", "1000M") .with_reservations("0.5", "1000M") diff --git a/templates/stores/memgraph.jsonnet b/templates/stores/memgraph.jsonnet index 75faf5f0..70ad127a 100644 --- a/templates/stores/memgraph.jsonnet +++ b/templates/stores/memgraph.jsonnet @@ -7,6 +7,8 @@ local images = import "values/images.jsonnet"; create:: function(engine) + local vol = engine.volume("memgraph").with_size("20G"); + local container = engine.container("memgraph") .with_image(images.memgraph_mage) @@ -16,7 +18,8 @@ local images = import "values/images.jsonnet"; .with_limits("1.0", "1000M") .with_reservations("0.5", "1000M") .with_port(7474, 7474, "api") - .with_port(7687, 7687, "api2"); + .with_port(7687, 7687, "api2") + .with_volume_mount(vol, "/var/lib/memgraph"); local containerSet = engine.containers( "memgraph", [ container ] @@ -28,6 +31,7 @@ local images = import "values/images.jsonnet"; .with_port(7687, 7687, "api2"); engine.resources([ + vol, containerSet, service, ]) @@ -65,4 +69,3 @@ local images = import "values/images.jsonnet"; }, } - diff --git a/templates/stores/milvus.jsonnet b/templates/stores/milvus.jsonnet index cbeb4268..1c3e3734 100644 --- a/templates/stores/milvus.jsonnet +++ b/templates/stores/milvus.jsonnet @@ -1,7 +1,8 @@ local base = import "base/base.jsonnet"; local images = import "values/images.jsonnet"; +local minio = import "stores/minio.jsonnet"; -{ +minio { etcd +: { @@ -47,47 +48,6 @@ local images = import "values/images.jsonnet"; }, - mino +: { - - create:: function(engine) - - local vol = engine.volume("minio-data").with_size("20G"); - - local container = - engine.container("minio") - .with_image(images.minio) - .with_command([ - "minio", - "server", - "/minio_data", - "--console-address", - ":9001", - ]) - .with_environment({ - MINIO_ROOT_USER: "minioadmin", - MINIO_ROOT_PASSWORD: "minioadmin", - }) - .with_limits("0.5", "128M") - .with_reservations("0.25", "128M") - .with_port(9001, 9001, "api") - .with_volume_mount(vol, "/minio_data"); - - local containerSet = engine.containers( - "etcd", [ container ] - ); - - local service = - engine.service(containerSet) - .with_port(9001, 9001, "api"); - - engine.resources([ - vol, - containerSet, - service, - ]) - - }, - milvus +: { create:: function(engine) diff --git a/templates/stores/minio.jsonnet b/templates/stores/minio.jsonnet new file mode 100644 index 00000000..6ef1d96f --- /dev/null +++ b/templates/stores/minio.jsonnet @@ -0,0 +1,49 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; + +{ + + minio +: { + + create:: function(engine) + + local vol = engine.volume("minio-data").with_size("20G"); + + local container = + engine.container("minio") + .with_image(images.minio) + .with_command([ + "minio", + "server", + "/minio_data", + "--console-address", + ":9001", + ]) + .with_environment({ + MINIO_ROOT_USER: "minioadmin", + MINIO_ROOT_PASSWORD: "minioadmin", + }) + .with_limits("0.5", "128M") + .with_reservations("0.25", "128M") + .with_port(9000, 9000, "api") + .with_port(9001, 9001, "console") + .with_volume_mount(vol, "/minio_data"); + + local containerSet = engine.containers( + "etcd", [ container ] + ); + + local service = + engine.service(containerSet) + .with_port(9000, 9000, "api") + .with_port(9001, 9001, "console"); + + engine.resources([ + vol, + containerSet, + service, + ]) + + }, + +} diff --git a/templates/stores/neo4j.jsonnet b/templates/stores/neo4j.jsonnet index 55cccc5f..3a8bb783 100644 --- a/templates/stores/neo4j.jsonnet +++ b/templates/stores/neo4j.jsonnet @@ -14,12 +14,14 @@ local images = import "values/images.jsonnet"; .with_image(images.neo4j) .with_environment({ NEO4J_AUTH: "neo4j/password", + NEO4J_server_memory_pagecache_size: "512m", + NEO4J_server_memory_heap_max__size: "512m", // NEO4J_server_bolt_listen__address: "0.0.0.0:7687", // NEO4J_server_default__listen__address: "0.0.0.0", // NEO4J_server_http_listen__address: "0.0.0.0:7474", }) - .with_limits("1.0", "768M") - .with_reservations("0.5", "768M") + .with_limits("1.0", "1536M") + .with_reservations("0.5", "1536M") .with_port(7474, 7474, "api") .with_port(7687, 7687, "api2") .with_volume_mount(vol, "/data"); diff --git a/templates/values/images.jsonnet b/templates/values/images.jsonnet index 40954289..54dbd016 100644 --- a/templates/values/images.jsonnet +++ b/templates/values/images.jsonnet @@ -1,18 +1,23 @@ local version = import "version.jsonnet"; { cassandra: "docker.io/cassandra:4.1.6", - neo4j: "docker.io/neo4j:5.22.0-community-bullseye", + neo4j: "docker.io/neo4j:5.26.0-community-bullseye", pulsar: "docker.io/apachepulsar/pulsar:3.3.1", pulsar_manager: "docker.io/apachepulsar/pulsar-manager:v0.4.0", etcd: "quay.io/coreos/etcd:v3.5.15", - minio: "docker.io/minio/minio:RELEASE.2024-08-17T01-24-54Z", + minio: "docker.io/minio/minio:RELEASE.2025-02-03T21-03-04Z", milvus: "docker.io/milvusdb/milvus:v2.4.9", prometheus: "docker.io/prom/prometheus:v2.53.2", grafana: "docker.io/grafana/grafana:11.1.4", - trustgraph: "docker.io/trustgraph/trustgraph-flow:" + version, - qdrant: "docker.io/qdrant/qdrant:v1.11.1", + trustgraph_base: "docker.io/trustgraph/trustgraph-base:" + version, + trustgraph_flow: "docker.io/trustgraph/trustgraph-flow:" + version, + trustgraph_ocr: "docker.io/trustgraph/trustgraph-ocr:" + version, + trustgraph_bedrock: "docker.io/trustgraph/trustgraph-bedrock:" + version, + trustgraph_vertexai: "docker.io/trustgraph/trustgraph-vertexai:" + version, + trustgraph_hf: "docker.io/trustgraph/trustgraph-hf:" + version, + qdrant: "docker.io/qdrant/qdrant:v1.13.3", memgraph_mage: "docker.io/memgraph/memgraph-mage:1.22-memgraph-2.22", memgraph_lab: "docker.io/memgraph/lab:2.19.1", falkordb: "docker.io/falkordb/falkordb:latest", - "workbench-ui": "docker.io/trustgraph/workbench-ui:0.1.6", + "workbench-ui": "docker.io/trustgraph/workbench-ui:0.2.4", } diff --git a/trustgraph-base/trustgraph/api/api.py b/trustgraph-base/trustgraph/api/api.py index de96499c..4c72d3ca 100644 --- a/trustgraph-base/trustgraph/api/api.py +++ b/trustgraph-base/trustgraph/api/api.py @@ -102,11 +102,21 @@ class Api: except: raise ProtocolException(f"Response not formatted correctly") - def graph_rag(self, question): + def graph_rag( + self, question, user="trustgraph", collection="default", + entity_limit=50, triple_limit=30, max_subgraph_size=150, + max_path_length=2, + ): # The input consists of a question input = { - "query": question + "query": question, + "user": user, + "collection": collection, + "entity-limit": entity_limit, + "triple-limit": triple_limit, + "max-subgraph-size": max_subgraph_size, + "max-path-length": max_path_length, } url = f"{self.url}graph-rag" @@ -131,6 +141,41 @@ class Api: except: raise ProtocolException(f"Response not formatted correctly") + def document_rag( + self, question, user="trustgraph", collection="default", + doc_limit=10, + ): + + # The input consists of a question + input = { + "query": question, + "user": user, + "collection": collection, + "doc-limit": doc_limit, + } + + url = f"{self.url}document-rag" + + # Invoke the API, input is passed as JSON + resp = requests.post(url, json=input) + + # Should be a 200 status code + if resp.status_code != 200: + raise ProtocolException(f"Status code {resp.status_code}") + + try: + # Parse the response as JSON + object = resp.json() + except: + raise ProtocolException(f"Expected JSON response") + + self.check_error(resp) + + try: + return object["response"] + except: + raise ProtocolException(f"Response not formatted correctly") + def embeddings(self, text): # The input consists of a text block diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index b9dba4fa..3a58d51e 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -3,4 +3,6 @@ from . base_processor import BaseProcessor from . consumer import Consumer from . producer import Producer from . consumer_producer import ConsumerProducer +from . publisher import Publisher +from . subscriber import Subscriber diff --git a/trustgraph-base/trustgraph/base/base_processor.py b/trustgraph-base/trustgraph/base/base_processor.py index f258ff1a..a8374538 100644 --- a/trustgraph-base/trustgraph/base/base_processor.py +++ b/trustgraph-base/trustgraph/base/base_processor.py @@ -1,4 +1,5 @@ +import asyncio import os import argparse import pulsar @@ -11,6 +12,7 @@ from .. log_level import LogLevel class BaseProcessor: default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') + default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) def __init__(self, **params): @@ -28,14 +30,28 @@ class BaseProcessor: }) pulsar_host = params.get("pulsar_host", self.default_pulsar_host) + pulsar_listener = params.get("pulsar_listener", None) + pulsar_api_key = params.get("pulsar_api_key", None) log_level = params.get("log_level", LogLevel.INFO) self.pulsar_host = pulsar_host + self.pulsar_api_key = pulsar_api_key - self.client = pulsar.Client( + if pulsar_api_key: + auth = pulsar.AuthenticationToken(pulsar_api_key) + self.client = pulsar.Client( pulsar_host, + authentication=auth, logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) + ) + else: + self.client = pulsar.Client( + pulsar_host, + listener_name=pulsar_listener, + logger=pulsar.ConsoleLogger(log_level.to_pulsar()) + ) + + self.pulsar_listener = pulsar_listener def __del__(self): @@ -51,6 +67,17 @@ class BaseProcessor: default=__class__.default_pulsar_host, help=f'Pulsar host (default: {__class__.default_pulsar_host})', ) + + parser.add_argument( + '--pulsar-api-key', + default=__class__.default_pulsar_api_key, + help=f'Pulsar API key', + ) + + parser.add_argument( + '--pulsar-listener', + help=f'Pulsar listener (default: none)', + ) parser.add_argument( '-l', '--log-level', @@ -74,11 +101,20 @@ class BaseProcessor: help=f'Pulsar host (default: 8000)', ) - def run(self): + async def start(self): + pass + + async def run(self): raise RuntimeError("Something should have implemented the run method") @classmethod - def start(cls, prog, doc): + async def launch_async(cls, args): + p = cls(**args) + await p.start() + await p.run() + + @classmethod + def launch(cls, prog, doc): parser = argparse.ArgumentParser( prog=prog, @@ -99,8 +135,7 @@ class BaseProcessor: try: - p = cls(**args) - p.run() + asyncio.run(cls.launch_async(args)) except KeyboardInterrupt: print("Keyboard interrupt.") @@ -118,3 +153,4 @@ class BaseProcessor: print("Will retry...", flush=True) time.sleep(4) + diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index eeaf83a1..175f1fd7 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -1,11 +1,16 @@ +import asyncio from pulsar.schema import JsonSchema +import pulsar from prometheus_client import Histogram, Info, Counter, Enum import time from . base_processor import BaseProcessor from .. exceptions import TooManyRequests +default_rate_limit_retry = 10 +default_rate_limit_timeout = 7200 + class Consumer(BaseProcessor): def __init__(self, **params): @@ -21,11 +26,18 @@ class Consumer(BaseProcessor): super(Consumer, self).__init__(**params) - input_queue = params.get("input_queue") - subscriber = params.get("subscriber") - input_schema = params.get("input_schema") + self.input_queue = params.get("input_queue") + self.subscriber = params.get("subscriber") + self.input_schema = params.get("input_schema") - if input_schema == None: + self.rate_limit_retry = params.get( + "rate_limit_retry", default_rate_limit_retry + ) + self.rate_limit_timeout = params.get( + "rate_limit_timeout", default_rate_limit_timeout + ) + + if self.input_schema == None: raise RuntimeError("input_schema must be specified") if not hasattr(__class__, "request_metric"): @@ -43,18 +55,28 @@ class Consumer(BaseProcessor): 'processing_count', 'Processing count', ["status"] ) + if not hasattr(__class__, "rate_limit_metric"): + __class__.rate_limit_metric = Counter( + 'rate_limit_count', 'Rate limit event count', + ) + __class__.pubsub_metric.info({ - "input_queue": input_queue, - "subscriber": subscriber, - "input_schema": input_schema.__name__, + "input_queue": self.input_queue, + "subscriber": self.subscriber, + "input_schema": self.input_schema.__name__, + "rate_limit_retry": str(self.rate_limit_retry), + "rate_limit_timeout": str(self.rate_limit_timeout), }) self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(input_schema), + self.input_queue, self.subscriber, + consumer_type=pulsar.ConsumerType.Shared, + schema=JsonSchema(self.input_schema), ) - def run(self): + print("Initialised consumer.", flush=True) + + async def run(self): __class__.state_metric.state('running') @@ -62,31 +84,61 @@ class Consumer(BaseProcessor): msg = self.consumer.receive() - try: + expiry = time.time() + self.rate_limit_timeout - with __class__.request_metric.time(): - self.handle(msg) + # This loop is for retry on rate-limit / resource limits + while True: - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) + if time.time() > expiry: - __class__.processing_metric.labels(status="success").inc() + print("Gave up waiting for rate-limit retry", flush=True) - except TooManyRequests: - self.consumer.negative_acknowledge(msg) - print("TooManyRequests: will retry") - __class__.processing_metric.labels(status="rate-limit").inc() - time.sleep(5) - continue + # Message failed to be processed, this causes it to + # be retried + self.consumer.negative_acknowledge(msg) + + __class__.processing_metric.labels(status="error").inc() + + # Break out of retry loop, processes next message + break + + try: + + with __class__.request_metric.time(): + await self.handle(msg) + + # Acknowledge successful processing of the message + self.consumer.acknowledge(msg) + + __class__.processing_metric.labels(status="success").inc() + + # Break out of retry loop + break + + except TooManyRequests: + + print("TooManyRequests: will retry...", flush=True) + + __class__.rate_limit_metric.inc() + + # Sleep + time.sleep(self.rate_limit_retry) + + # Contine from retry loop, just causes a reprocessing + continue - except Exception as e: + except Exception as e: - print("Exception:", e, flush=True) + print("Exception:", e, flush=True) - # Message failed to be processed - self.consumer.negative_acknowledge(msg) + # Message failed to be processed, this causes it to + # be retried + self.consumer.negative_acknowledge(msg) - __class__.processing_metric.labels(status="error").inc() + __class__.processing_metric.labels(status="error").inc() + + # Break out of retry loop, processes next message + break @staticmethod def add_args(parser, default_input_queue, default_subscriber): @@ -105,3 +157,17 @@ class Consumer(BaseProcessor): help=f'Queue subscriber name (default: {default_subscriber})' ) + parser.add_argument( + '--rate-limit-retry', + type=int, + default=default_rate_limit_retry, + help=f'Rate limit retry (default: {default_rate_limit_retry})' + ) + + parser.add_argument( + '--rate-limit-timeout', + type=int, + default=default_rate_limit_timeout, + help=f'Rate limit timeout (default: {default_rate_limit_timeout})' + ) + diff --git a/trustgraph-base/trustgraph/base/consumer_producer.py b/trustgraph-base/trustgraph/base/consumer_producer.py index cabb7525..1006f9b5 100644 --- a/trustgraph-base/trustgraph/base/consumer_producer.py +++ b/trustgraph-base/trustgraph/base/consumer_producer.py @@ -1,113 +1,48 @@ from pulsar.schema import JsonSchema +import pulsar from prometheus_client import Histogram, Info, Counter, Enum import time -from . base_processor import BaseProcessor +from . consumer import Consumer from .. exceptions import TooManyRequests -# FIXME: Derive from consumer? And producer? - -class ConsumerProducer(BaseProcessor): +class ConsumerProducer(Consumer): def __init__(self, **params): - if not hasattr(__class__, "state_metric"): - __class__.state_metric = Enum( - 'processor_state', 'Processor state', - states=['starting', 'running', 'stopped'] - ) - __class__.state_metric.state('starting') + super(ConsumerProducer, self).__init__(**params) - __class__.state_metric.state('starting') - - input_queue = params.get("input_queue") - output_queue = params.get("output_queue") - subscriber = params.get("subscriber") - input_schema = params.get("input_schema") - output_schema = params.get("output_schema") - - if not hasattr(__class__, "request_metric"): - __class__.request_metric = Histogram( - 'request_latency', 'Request latency (seconds)' - ) + self.output_queue = params.get("output_queue") + self.output_schema = params.get("output_schema") if not hasattr(__class__, "output_metric"): __class__.output_metric = Counter( 'output_count', 'Output items created' ) - if not hasattr(__class__, "pubsub_metric"): - __class__.pubsub_metric = Info( - 'pubsub', 'Pub/sub configuration' - ) - - if not hasattr(__class__, "processing_metric"): - __class__.processing_metric = Counter( - 'processing_count', 'Processing count', ["status"] - ) - __class__.pubsub_metric.info({ - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": input_schema.__name__, - "output_schema": output_schema.__name__, + "input_queue": self.input_queue, + "output_queue": self.output_queue, + "subscriber": self.subscriber, + "input_schema": self.input_schema.__name__, + "output_schema": self.output_schema.__name__, + "rate_limit_retry": str(self.rate_limit_retry), + "rate_limit_timeout": str(self.rate_limit_timeout), }) - super(ConsumerProducer, self).__init__(**params) - - if input_schema == None: - raise RuntimeError("input_schema must be specified") - - if output_schema == None: + if self.output_schema == None: raise RuntimeError("output_schema must be specified") self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(output_schema), + topic=self.output_queue, + schema=JsonSchema(self.output_schema), + chunking_enabled=True, ) - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(input_schema), - ) + print("Initialised consumer/producer.") - def run(self): - - __class__.state_metric.state('running') - - while True: - - msg = self.consumer.receive() - - try: - - with __class__.request_metric.time(): - resp = self.handle(msg) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - __class__.processing_metric.labels(status="success").inc() - - except TooManyRequests: - self.consumer.negative_acknowledge(msg) - print("TooManyRequests: will retry") - __class__.processing_metric.labels(status="rate-limit").inc() - time.sleep(5) - continue - - except Exception as e: - - print("Exception:", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - __class__.processing_metric.labels(status="error").inc() - - def send(self, msg, properties={}): + async def send(self, msg, properties={}): self.producer.send(msg, properties) __class__.output_metric.inc() @@ -117,19 +52,7 @@ class ConsumerProducer(BaseProcessor): default_output_queue, ): - BaseProcessor.add_args(parser) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) + Consumer.add_args(parser, default_input_queue, default_subscriber) parser.add_argument( '-o', '--output-queue', diff --git a/trustgraph-base/trustgraph/base/producer.py b/trustgraph-base/trustgraph/base/producer.py index 27d693ee..bc2d7791 100644 --- a/trustgraph-base/trustgraph/base/producer.py +++ b/trustgraph-base/trustgraph/base/producer.py @@ -34,9 +34,10 @@ class Producer(BaseProcessor): self.producer = self.client.create_producer( topic=output_queue, schema=JsonSchema(output_schema), + chunking_enabled=True, ) - def send(self, msg, properties={}): + async def send(self, msg, properties={}): self.producer.send(msg, properties) __class__.output_metric.inc() diff --git a/trustgraph-flow/trustgraph/gateway/publisher.py b/trustgraph-base/trustgraph/base/publisher.py similarity index 61% rename from trustgraph-flow/trustgraph/gateway/publisher.py rename to trustgraph-base/trustgraph/base/publisher.py index 89c612ce..2da63331 100644 --- a/trustgraph-flow/trustgraph/gateway/publisher.py +++ b/trustgraph-base/trustgraph/base/publisher.py @@ -6,37 +6,43 @@ import threading class Publisher: - def __init__(self, pulsar_host, topic, schema=None, max_size=10, - chunking_enabled=False): - self.pulsar_host = pulsar_host + def __init__(self, pulsar_client, topic, schema=None, max_size=10, + chunking_enabled=True): + self.client = pulsar_client self.topic = topic self.schema = schema self.q = queue.Queue(maxsize=max_size) self.chunking_enabled = chunking_enabled + self.running = True def start(self): self.task = threading.Thread(target=self.run) self.task.start() + def stop(self): + self.running = False + + def join(self): + self.stop() + self.task.join() + def run(self): - while True: + while self.running: try: - - client = pulsar.Client( - self.pulsar_host, - ) - - producer = client.create_producer( + producer = self.client.create_producer( topic=self.topic, schema=self.schema, chunking_enabled=self.chunking_enabled, ) - while True: + while self.running: - id, item = self.q.get() + try: + id, item = self.q.get(timeout=0.5) + except queue.Empty: + continue if id: producer.send(item, { "id": id }) @@ -51,3 +57,5 @@ class Publisher: def send(self, id, msg): self.q.put((id, msg)) + + diff --git a/trustgraph-flow/trustgraph/gateway/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py similarity index 82% rename from trustgraph-flow/trustgraph/gateway/subscriber.py rename to trustgraph-base/trustgraph/base/subscriber.py index cccfc5b4..30ade3ee 100644 --- a/trustgraph-flow/trustgraph/gateway/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -6,9 +6,9 @@ import time class Subscriber: - def __init__(self, pulsar_host, topic, subscription, consumer_name, + def __init__(self, pulsar_client, topic, subscription, consumer_name, schema=None, max_size=100): - self.pulsar_host = pulsar_host + self.client = pulsar_client self.topic = topic self.subscription = subscription self.consumer_name = consumer_name @@ -17,29 +17,32 @@ class Subscriber: self.full = {} self.max_size = max_size self.lock = threading.Lock() + self.running = True def start(self): self.task = threading.Thread(target=self.run) self.task.start() + def stop(self): + self.running = False + + def join(self): + self.task.join() + def run(self): - while True: + while self.running: try: - client = pulsar.Client( - self.pulsar_host, - ) - - consumer = client.subscribe( + consumer = self.client.subscribe( topic=self.topic, subscription_name=self.subscription, consumer_name=self.consumer_name, schema=self.schema, ) - while True: + while self.running: msg = consumer.receive() @@ -57,12 +60,14 @@ class Subscriber: if id in self.q: try: + # FIXME: Timeout means data goes missing self.q[id].put(value, timeout=0.5) except: pass for q in self.full.values(): try: + # FIXME: Timeout means data goes missing q.put(value, timeout=0.5) except: pass diff --git a/trustgraph-base/trustgraph/clients/agent_client.py b/trustgraph-base/trustgraph/clients/agent_client.py index 2ef69274..b31b4e36 100644 --- a/trustgraph-base/trustgraph/clients/agent_client.py +++ b/trustgraph-base/trustgraph/clients/agent_client.py @@ -20,6 +20,7 @@ class AgentClient(BaseClient): input_queue=None, output_queue=None, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, ): if input_queue is None: input_queue = agent_request_queue @@ -33,6 +34,7 @@ class AgentClient(BaseClient): pulsar_host=pulsar_host, input_schema=AgentRequest, output_schema=AgentResponse, + pulsar_api_key=pulsar_api_key ) def request( diff --git a/trustgraph-base/trustgraph/clients/base.py b/trustgraph-base/trustgraph/clients/base.py index 78116f41..ac809123 100644 --- a/trustgraph-base/trustgraph/clients/base.py +++ b/trustgraph-base/trustgraph/clients/base.py @@ -27,6 +27,7 @@ class BaseClient: input_schema=None, output_schema=None, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, ): if input_queue == None: raise RuntimeError("Need input_queue") @@ -37,10 +38,18 @@ class BaseClient: if subscriber == None: subscriber = str(uuid.uuid4()) - self.client = pulsar.Client( + if pulsar_api_key: + auth = pulsar.AuthenticationToken(pulsar_api_key) + self.client = pulsar.Client( pulsar_host, logger=pulsar.ConsoleLogger(log_level), - ) + authentication=auth, + ) + else: + self.client = pulsar.Client( + pulsar_host, + logger=pulsar.ConsoleLogger(log_level) + ) self.producer = self.client.create_producer( topic=input_queue, diff --git a/trustgraph-base/trustgraph/clients/document_embeddings_client.py b/trustgraph-base/trustgraph/clients/document_embeddings_client.py index d432991d..14547595 100644 --- a/trustgraph-base/trustgraph/clients/document_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/document_embeddings_client.py @@ -20,6 +20,7 @@ class DocumentEmbeddingsClient(BaseClient): input_queue=None, output_queue=None, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, ): if input_queue == None: @@ -34,12 +35,17 @@ class DocumentEmbeddingsClient(BaseClient): input_queue=input_queue, output_queue=output_queue, pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, input_schema=DocumentEmbeddingsRequest, output_schema=DocumentEmbeddingsResponse, ) - def request(self, vectors, limit=10, timeout=300): + def request( + self, vectors, user="trustgraph", collection="default", + limit=10, timeout=300 + ): return self.call( + user=user, collection=collection, vectors=vectors, limit=limit, timeout=timeout ).documents diff --git a/trustgraph-base/trustgraph/clients/document_rag_client.py b/trustgraph-base/trustgraph/clients/document_rag_client.py index 103cbb69..6cbafa9b 100644 --- a/trustgraph-base/trustgraph/clients/document_rag_client.py +++ b/trustgraph-base/trustgraph/clients/document_rag_client.py @@ -20,6 +20,7 @@ class DocumentRagClient(BaseClient): input_queue=None, output_queue=None, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, ): if input_queue == None: @@ -34,6 +35,7 @@ class DocumentRagClient(BaseClient): input_queue=input_queue, output_queue=output_queue, pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, input_schema=DocumentRagQuery, output_schema=DocumentRagResponse, ) diff --git a/trustgraph-base/trustgraph/clients/embeddings_client.py b/trustgraph-base/trustgraph/clients/embeddings_client.py index 8d21bdec..811f6ed2 100644 --- a/trustgraph-base/trustgraph/clients/embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/embeddings_client.py @@ -20,6 +20,7 @@ class EmbeddingsClient(BaseClient): output_queue=None, subscriber=None, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, ): if input_queue == None: @@ -34,6 +35,7 @@ class EmbeddingsClient(BaseClient): input_queue=input_queue, output_queue=output_queue, pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, input_schema=EmbeddingsRequest, output_schema=EmbeddingsResponse, ) diff --git a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py index 401266bc..1a7a9512 100644 --- a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py @@ -20,6 +20,7 @@ class GraphEmbeddingsClient(BaseClient): input_queue=None, output_queue=None, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, ): if input_queue == None: @@ -34,6 +35,7 @@ class GraphEmbeddingsClient(BaseClient): input_queue=input_queue, output_queue=output_queue, pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, input_schema=GraphEmbeddingsRequest, output_schema=GraphEmbeddingsResponse, ) diff --git a/trustgraph-base/trustgraph/clients/graph_rag_client.py b/trustgraph-base/trustgraph/clients/graph_rag_client.py index 9f8eff62..77102e36 100644 --- a/trustgraph-base/trustgraph/clients/graph_rag_client.py +++ b/trustgraph-base/trustgraph/clients/graph_rag_client.py @@ -20,6 +20,7 @@ class GraphRagClient(BaseClient): input_queue=None, output_queue=None, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, ): if input_queue == None: @@ -34,6 +35,7 @@ class GraphRagClient(BaseClient): input_queue=input_queue, output_queue=output_queue, pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, input_schema=GraphRagQuery, output_schema=GraphRagResponse, ) diff --git a/trustgraph-base/trustgraph/clients/llm_client.py b/trustgraph-base/trustgraph/clients/llm_client.py index cfb0e606..a8894c8f 100644 --- a/trustgraph-base/trustgraph/clients/llm_client.py +++ b/trustgraph-base/trustgraph/clients/llm_client.py @@ -20,6 +20,7 @@ class LlmClient(BaseClient): input_queue=None, output_queue=None, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, ): if input_queue is None: input_queue = text_completion_request_queue @@ -31,6 +32,7 @@ class LlmClient(BaseClient): input_queue=input_queue, output_queue=output_queue, pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, input_schema=TextCompletionRequest, output_schema=TextCompletionResponse, ) diff --git a/trustgraph-base/trustgraph/clients/prompt_client.py b/trustgraph-base/trustgraph/clients/prompt_client.py index 4b026cf0..91707670 100644 --- a/trustgraph-base/trustgraph/clients/prompt_client.py +++ b/trustgraph-base/trustgraph/clients/prompt_client.py @@ -39,6 +39,7 @@ class PromptClient(BaseClient): input_queue=None, output_queue=None, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, ): if input_queue == None: @@ -53,6 +54,7 @@ class PromptClient(BaseClient): input_queue=input_queue, output_queue=output_queue, pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, input_schema=PromptRequest, output_schema=PromptResponse, ) diff --git a/trustgraph-base/trustgraph/clients/triples_query_client.py b/trustgraph-base/trustgraph/clients/triples_query_client.py index fc1e4b26..8ed2ebb7 100644 --- a/trustgraph-base/trustgraph/clients/triples_query_client.py +++ b/trustgraph-base/trustgraph/clients/triples_query_client.py @@ -21,6 +21,7 @@ class TriplesQueryClient(BaseClient): input_queue=None, output_queue=None, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, ): if input_queue == None: @@ -34,6 +35,7 @@ class TriplesQueryClient(BaseClient): subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, + pulsar_api_key=pulsar_api_key, pulsar_host=pulsar_host, input_schema=TriplesQueryRequest, output_schema=TriplesQueryResponse, diff --git a/trustgraph-base/trustgraph/exceptions.py b/trustgraph-base/trustgraph/exceptions.py index 16f9956c..09f098df 100644 --- a/trustgraph-base/trustgraph/exceptions.py +++ b/trustgraph-base/trustgraph/exceptions.py @@ -8,7 +8,6 @@ class LlmError(Exception): class ParseError(Exception): pass - - - +class RequestError(Exception): + pass diff --git a/trustgraph-base/trustgraph/schema/__init__.py b/trustgraph-base/trustgraph/schema/__init__.py index be41b670..9c44a743 100644 --- a/trustgraph-base/trustgraph/schema/__init__.py +++ b/trustgraph-base/trustgraph/schema/__init__.py @@ -10,5 +10,6 @@ from . retrieval import * from . metadata import * from . agent import * from . lookup import * +from . library import * diff --git a/trustgraph-base/trustgraph/schema/documents.py b/trustgraph-base/trustgraph/schema/documents.py index 2a3d3d0c..fd0049ee 100644 --- a/trustgraph-base/trustgraph/schema/documents.py +++ b/trustgraph-base/trustgraph/schema/documents.py @@ -35,14 +35,18 @@ chunk_ingest_queue = topic('chunk-load') ############################################################################ -# Chunk embeddings are an embeddings associated with a text chunk +# Document embeddings are embeddings associated with a chunk class ChunkEmbeddings(Record): - metadata = Metadata() - vectors = Array(Array(Double())) chunk = Bytes() + vectors = Array(Array(Double())) -chunk_embeddings_ingest_queue = topic('chunk-embeddings-load') +# This is a 'batching' mechanism for the above data +class DocumentEmbeddings(Record): + metadata = Metadata() + chunks = Array(ChunkEmbeddings()) + +document_embeddings_store_queue = topic('document-embeddings-store') ############################################################################ @@ -51,6 +55,8 @@ chunk_embeddings_ingest_queue = topic('chunk-embeddings-load') class DocumentEmbeddingsRequest(Record): vectors = Array(Array(Double())) limit = Integer() + user = String() + collection = String() class DocumentEmbeddingsResponse(Record): error = Error() @@ -62,3 +68,4 @@ document_embeddings_request_queue = topic( document_embeddings_response_queue = topic( 'doc-embeddings', kind='non-persistent', namespace='response', ) + diff --git a/trustgraph-base/trustgraph/schema/graph.py b/trustgraph-base/trustgraph/schema/graph.py index 78c1a99c..7c304e1d 100644 --- a/trustgraph-base/trustgraph/schema/graph.py +++ b/trustgraph-base/trustgraph/schema/graph.py @@ -7,12 +7,31 @@ from . metadata import Metadata ############################################################################ +# Entity context are an entity associated with textual context + +class EntityContext(Record): + entity = Value() + context = String() + +# This is a 'batching' mechanism for the above data +class EntityContexts(Record): + metadata = Metadata() + entities = Array(EntityContext()) + +entity_contexts_ingest_queue = topic('entity-contexts-load') + +############################################################################ + # Graph embeddings are embeddings associated with a graph entity +class EntityEmbeddings(Record): + entity = Value() + vectors = Array(Array(Double())) + +# This is a 'batching' mechanism for the above data class GraphEmbeddings(Record): metadata = Metadata() - vectors = Array(Array(Double())) - entity = Value() + entities = Array(EntityEmbeddings()) graph_embeddings_store_queue = topic('graph-embeddings-store') diff --git a/trustgraph-base/trustgraph/schema/library.py b/trustgraph-base/trustgraph/schema/library.py new file mode 100644 index 00000000..ed52b2ad --- /dev/null +++ b/trustgraph-base/trustgraph/schema/library.py @@ -0,0 +1,74 @@ + +from pulsar.schema import Record, Bytes, String, Array, Long +from . types import Triple +from . topic import topic +from . types import Error +from . metadata import Metadata +from . documents import Document, TextDocument + +# add +# -> (id, document) +# <- () +# <- (error) + +# list +# -> (user, collection?) +# <- (info) +# <- (error) + +# add(Metadata, Bytes) : error? +# copy(id, user, collection) +# move(id, user, collection) +# delete(id) +# get(id) : Bytes +# reindex(id) +# list(user, collection) : id[] +# info(id[]) : DocumentInfo[] +# search([]) : id[] + +class DocumentPackage(Record): + id = String() + document = Bytes() + kind = String() + user = String() + collection = String() + title = String() + comments = String() + time = Long() + metadata = Array(Triple()) + +class DocumentInfo(Record): + id = String() + kind = String() + user = String() + collection = String() + title = String() + comments = String() + time = Long() + metadata = Array(Triple()) + +class Criteria(Record): + key = String() + value = String() + operator = String() + +class LibrarianRequest(Record): + operation = String() + id = String() + document = DocumentPackage() + user = String() + collection = String() + criteria = Array(Criteria()) + +class LibrarianResponse(Record): + error = Error() + document = DocumentPackage() + info = Array(DocumentInfo()) + +librarian_request_queue = topic( + 'librarian', kind='non-persistent', namespace='request' +) +librarian_response_queue = topic( + 'librarian', kind='non-persistent', namespace='response', +) + diff --git a/trustgraph-base/trustgraph/schema/retrieval.py b/trustgraph-base/trustgraph/schema/retrieval.py index 9c4361a1..caeb8e67 100644 --- a/trustgraph-base/trustgraph/schema/retrieval.py +++ b/trustgraph-base/trustgraph/schema/retrieval.py @@ -11,6 +11,10 @@ class GraphRagQuery(Record): query = String() user = String() collection = String() + entity_limit = Integer() + triple_limit = Integer() + max_subgraph_size = Integer() + max_path_length = Integer() class GraphRagResponse(Record): error = Error() @@ -31,6 +35,7 @@ class DocumentRagQuery(Record): query = String() user = String() collection = String() + doc_limit = Integer() class DocumentRagResponse(Record): error = Error() diff --git a/trustgraph-bedrock/setup.py b/trustgraph-bedrock/setup.py index b8dd36bd..8db4520b 100644 --- a/trustgraph-bedrock/setup.py +++ b/trustgraph-bedrock/setup.py @@ -34,7 +34,7 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base>=0.18,<0.19", + "trustgraph-base>=0.21,<0.22", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py index a9c05cc8..75868b56 100755 --- a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py +++ b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py @@ -8,6 +8,7 @@ import boto3 import json from prometheus_client import Histogram import os +import enum from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue @@ -24,32 +25,163 @@ default_subscriber = module default_model = 'mistral.mistral-large-2407-v1:0' default_temperature = 0.0 default_max_output = 2048 -default_aws_id_key = os.getenv("AWS_ID_KEY", None) -default_aws_secret = os.getenv("AWS_SECRET", None) -default_aws_region = os.getenv("AWS_REGION", 'us-west-2') +default_top_p = 0.99 +default_top_k = 40 + +# Actually, these could all just be None, no need to get environment +# variables, as Boto3 would pick all these up if not passed in as args +default_access_key_id = os.getenv("AWS_ACCESS_KEY_ID", None) +default_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY", None) +default_session_token = os.getenv("AWS_SESSION_TOKEN", None) +default_profile = os.getenv("AWS_PROFILE", None) +default_region = os.getenv("AWS_DEFAULT_REGION", None) + +# Variant API handling depends on the model type + +class ModelHandler: + def __init__(self): + self.temperature = default_temperature + self.max_output = default_max_output + self.top_p = default_top_p + self.top_k = default_top_k + def set_temperature(self, temperature): + self.temperature = temperature + def set_max_output(self, max_output): + self.max_output = max_output + def set_top_p(self, top_p): + self.top_p = top_p + def set_top_k(self, top_k): + self.top_k = top_k + def encode_request(self, system, prompt): + raise RuntimeError("format_request not implemented") + def decode_response(self, response): + raise RuntimeError("format_request not implemented") + +class Mistral(ModelHandler): + def __init__(self): + self.top_p = 0.99 + self.top_k = 40 + def encode_request(self, system, prompt): + return json.dumps({ + "prompt": f"{system}\n\n{prompt}", + "max_tokens": self.max_output, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + }) + def decode_response(self, response): + response_body = json.loads(response.get("body").read()) + return response_body['outputs'][0]['text'] + +# Llama 3 +class Meta(ModelHandler): + def __init__(self): + self.top_p = 0.95 + def encode_request(self, system, prompt): + return json.dumps({ + "prompt": f"{system}\n\n{prompt}", + "max_gen_len": self.max_output, + "temperature": self.temperature, + "top_p": self.top_p, + }) + def decode_response(self, response): + model_response = json.loads(response["body"].read()) + return model_response["generation"] + +class Anthropic(ModelHandler): + def __init__(self): + self.top_p = 0.999 + def encode_request(self, system, prompt): + return json.dumps({ + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": self.max_output, + "temperature": self.temperature, + "top_p": self.top_p, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": f"{system}\n\n{prompt}", + } + ] + } + ] + }) + def decode_response(self, response): + model_response = json.loads(response["body"].read()) + return model_response['content'][0]['text'] + +class Ai21(ModelHandler): + def __init__(self): + self.top_p = 0.9 + def encode_request(self, system, prompt): + return json.dumps({ + "max_tokens": self.max_output, + "temperature": self.temperature, + "top_p": self.top_p, + "messages": [ + { + "role": "user", + "content": f"{system}\n\n{prompt}" + } + ] + }) + def decode_response(self, response): + content = response['body'].read() + content_str = content.decode('utf-8') + content_json = json.loads(content_str) + return content_json['choices'][0]['message']['content'] + +class Cohere(ModelHandler): + def encode_request(self, system, prompt): + return json.dumps({ + "max_tokens": self.max_output, + "temperature": self.temperature, + "message": f"{system}\n\n{prompt}", + }) + def decode_response(self, response): + content = response['body'].read() + content_str = content.decode('utf-8') + content_json = json.loads(content_str) + return content_json['text'] + +Default=Mistral class Processor(ConsumerProducer): def __init__(self, **params): + + print(params) input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) + model = params.get("model", default_model) - aws_id_key = params.get("aws_id_key", default_aws_id_key) - aws_secret = params.get("aws_secret", default_aws_secret) - aws_region = params.get("aws_region", default_aws_region) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) - if aws_id_key is None: - raise RuntimeError("AWS ID not specified") + aws_access_key_id = params.get( + "aws_access_key_id", default_access_key_id + ) - if aws_secret is None: - raise RuntimeError("AWS secret not specified") + aws_secret_access_key = params.get( + "aws_secret_access_key", default_secret_access_key + ) - if aws_region is None: - raise RuntimeError("AWS region not specified") + aws_session_token = params.get( + "aws_session_token", default_session_token + ) + + aws_region = params.get( + "aws_region", default_region + ) + + aws_profile = params.get( + "aws_profile", default_profile + ) super(Processor, self).__init__( **params | { @@ -81,17 +213,51 @@ class Processor(ConsumerProducer): self.temperature = temperature self.max_output = max_output + self.variant = self.determine_variant(self.model)() + self.variant.set_temperature(temperature) + self.variant.set_max_output(max_output) + self.session = boto3.Session( - aws_access_key_id=aws_id_key, - aws_secret_access_key=aws_secret, - region_name=aws_region + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + profile_name=aws_profile, + region_name=aws_region, ) self.bedrock = self.session.client(service_name='bedrock-runtime') print("Initialised", flush=True) - def handle(self, msg): + def determine_variant(self, model): + + # FIXME: Missing, Amazon models, Deepseek + + # This set of conditions deals with normal bedrock on-demand usage + if self.model.startswith("mistral"): + return Mistral + elif self.model.startswith("meta"): + return Meta + elif self.model.startswith("anthropic"): + return Anthropic + elif self.model.startswith("ai21"): + return Ai21 + elif self.model.startswith("cohere"): + return Cohere + + # The inference profiles + if self.model.startswith("us.meta"): + return Meta + elif self.model.startswith("us.anthropic"): + return Anthropic + elif self.model.startswith("eu.meta"): + return Meta + elif self.model.startswith("eu.anthropic"): + return Anthropic + + return Default + + async def handle(self, msg): v = msg.value() @@ -101,130 +267,27 @@ class Processor(ConsumerProducer): print(f"Handling prompt {id}...", flush=True) - prompt = v.system + "\n\n" + v.prompt - try: - # Mistral Input Format - if self.model.startswith("mistral"): - promptbody = json.dumps({ - "prompt": prompt, - "max_tokens": self.max_output, - "temperature": self.temperature, - "top_p": 0.99, - "top_k": 40 - }) - - # Llama 3.1 Input Format - elif self.model.startswith("meta"): - promptbody = json.dumps({ - "prompt": prompt, - "max_gen_len": self.max_output, - "temperature": self.temperature, - "top_p": 0.95, - }) - - # Anthropic Input Format - elif self.model.startswith("anthropic"): - promptbody = json.dumps({ - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": self.max_output, - "temperature": self.temperature, - "top_p": 0.999, - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - } - ] - } - ] - }) - - # Jamba Input Format - elif self.model.startswith("ai21"): - promptbody = json.dumps({ - "max_tokens": self.max_output, - "temperature": self.temperature, - "top_p": 0.9, - "messages": [ - { - "role": "user", - "content": prompt - } - ] - }) - - # Cohere Input Format - elif self.model.startswith("cohere"): - promptbody = json.dumps({ - "max_tokens": self.max_output, - "temperature": self.temperature, - "message": prompt - }) - - # Use Mistral format as defualt - else: - promptbody = json.dumps({ - "prompt": prompt, - "max_tokens": self.max_output, - "temperature": self.temperature, - "top_p": 0.99, - "top_k": 40 - }) + promptbody = self.variant.encode_request(v.system, v.prompt) accept = 'application/json' contentType = 'application/json' - # FIXME: Consider catching request limits and raise TooManyRequests - # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html - with __class__.text_completion_metric.time(): response = self.bedrock.invoke_model( - body=promptbody, modelId=self.model, accept=accept, + body=promptbody, + modelId=self.model, + accept=accept, contentType=contentType ) - # Mistral Response Structure - if self.model.startswith("mistral"): - response_body = json.loads(response.get("body").read()) - outputtext = response_body['outputs'][0]['text'] - - # Claude Response Structure - elif self.model.startswith("anthropic"): - model_response = json.loads(response["body"].read()) - outputtext = model_response['content'][0]['text'] - - # Llama 3.1 Response Structure - elif self.model.startswith("meta"): - model_response = json.loads(response["body"].read()) - outputtext = model_response["generation"] - - # Jamba Response Structure - elif self.model.startswith("ai21"): - content = response['body'].read() - content_str = content.decode('utf-8') - content_json = json.loads(content_str) - outputtext = content_json['choices'][0]['message']['content'] - - # Cohere Input Format - elif self.model.startswith("cohere"): - content = response['body'].read() - content_str = content.decode('utf-8') - content_json = json.loads(content_str) - outputtext = content_json['text'] - - # Use Mistral as default - else: - response_body = json.loads(response.get("body").read()) - outputtext = response_body['outputs'][0]['text'] + # Response structure decode + outputtext = self.variant.decode_response(response) metadata = response['ResponseMetadata']['HTTPHeaders'] inputtokens = int(metadata['x-amzn-bedrock-input-token-count']) - outputtokens = int(metadata['x-amzn-bedrock-output-token-count']) + outputtokens = int(metadata['x-amzn-bedrock-output-token-count']) print(outputtext, flush=True) print(f"Input Tokens: {inputtokens}", flush=True) @@ -243,30 +306,18 @@ class Processor(ConsumerProducer): print("Done.", flush=True) + except self.bedrock.exceptions.ThrottlingException as e: - # FIXME: Wrong exception, don't know what Bedrock throws - # for a rate limit - except TooManyRequests: + print("Hit rate limit:", e, flush=True) - print("Send rate limit response...", flush=True) - - r = TextCompletionResponse( - error=Error( - type = "rate-limit", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - self.producer.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + # Leave rate limit retries to the base handler + raise TooManyRequests() except Exception as e: + # Apart from rate limits, treat all exceptions as unrecoverable + + print(type(e)) print(f"Exception: {e}") print("Send error response...", flush=True) @@ -299,21 +350,27 @@ class Processor(ConsumerProducer): ) parser.add_argument( - '-z', '--aws-id-key', - default=default_aws_id_key, - help=f'AWS ID Key' + '-z', '--aws-access-key-id', + default=default_access_key_id, + help=f'AWS access key ID' ) parser.add_argument( - '-k', '--aws-secret', - default=default_aws_secret, - help=f'AWS Secret Key' + '-k', '--aws-secret-access-key', + default=default_secret_access_key, + help=f'AWS secret access key' ) parser.add_argument( '-r', '--aws-region', - default=default_aws_region, - help=f'AWS Region' + default=default_region, + help=f'AWS region' + ) + + parser.add_argument( + '--aws-profile', '--profile', + default=default_profile, + help=f'AWS profile name' ) parser.add_argument( @@ -332,5 +389,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-cli/scripts/tg-dump-msgpack b/trustgraph-cli/scripts/tg-dump-msgpack index 18819649..f3b24d73 100755 --- a/trustgraph-cli/scripts/tg-dump-msgpack +++ b/trustgraph-cli/scripts/tg-dump-msgpack @@ -9,6 +9,7 @@ diagnostic utility. import msgpack import sys import argparse +import json def dump(input_file, action): @@ -17,7 +18,7 @@ def dump(input_file, action): unpacker = msgpack.Unpacker(f, raw=False) for unpacked in unpacker: - print(unpacked) + print(json.dumps(unpacked)) def summary(input_file, action): diff --git a/trustgraph-cli/scripts/tg-graph-show b/trustgraph-cli/scripts/tg-graph-show index c09266fb..a3d10283 100755 --- a/trustgraph-cli/scripts/tg-graph-show +++ b/trustgraph-cli/scripts/tg-graph-show @@ -6,23 +6,23 @@ Connects to the graph query service and dumps all graph edges. import argparse import os -from trustgraph.clients.triples_query_client import TriplesQueryClient +from trustgraph.api import Api -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' default_collection = 'default' -def show_graph(pulsar, user, collection): +def show_graph(url, user, collection): - tq = TriplesQueryClient(pulsar_host=pulsar) + api = Api(url) - rows = tq.request( - user=user, collection=collection, - s=None, p=None, o=None, limit=10_000_000 + rows = api.triples_query( +# user=user, collection=collection, + s=None, p=None, o=None, limit=10_000, ) for row in rows: - print(row.s.value, row.p.value, row.o.value) + print(row.s, row.p, row.o) def main(): @@ -32,19 +32,19 @@ def main(): ) parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', ) parser.add_argument( - '-u', '--user', + '-U', '--user', default=default_user, help=f'User ID (default: {default_user})' ) parser.add_argument( - '-c', '--collection', + '-C', '--collection', default=default_collection, help=f'Collection ID (default: {default_collection})' ) @@ -54,7 +54,8 @@ def main(): try: show_graph( - pulsar=args.pulsar_host, user=args.user, + url=args.api_url, + user=args.user, collection=args.collection, ) diff --git a/trustgraph-cli/scripts/tg-graph-to-turtle b/trustgraph-cli/scripts/tg-graph-to-turtle index 1d75478e..fc17ddd0 100755 --- a/trustgraph-cli/scripts/tg-graph-to-turtle +++ b/trustgraph-cli/scripts/tg-graph-to-turtle @@ -5,37 +5,45 @@ Connects to the graph query service and dumps all graph edges in Turtle format. """ -import argparse -import os -from trustgraph.clients.triples_query_client import TriplesQueryClient import rdflib import io import sys +import argparse +import os -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +from trustgraph.api import Api, Uri -def show_graph(pulsar): +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_user = 'trustgraph' +default_collection = 'default' - tq = TriplesQueryClient(pulsar_host=pulsar) +def show_graph(url, user, collection): - rows = tq.request(None, None, None, limit=10_000_000) + api = Api(url) + + rows = api.triples_query( + s=None, p=None, o=None, + limit=10_000) +# user=user, collection=collection, g = rdflib.Graph() for row in rows: - sv = rdflib.term.URIRef(row.s.value) - pv = rdflib.term.URIRef(row.p.value) + sv = rdflib.term.URIRef(row.s) + pv = rdflib.term.URIRef(row.p) - if row.o.is_uri: + if isinstance(row.o, Uri): # Skip malformed URLs with spaces in - if " " in row.o.value: + if " " in row.o: continue - ov = rdflib.term.URIRef(row.o.value) + ov = rdflib.term.URIRef(row.o) + else: - ov = rdflib.term.Literal(row.o.value) + + ov = rdflib.term.Literal(row.o) g.add((sv, pv, ov)) @@ -56,16 +64,32 @@ def main(): ) parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-C', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' ) args = parser.parse_args() try: - show_graph(args.pulsar_host) + show_graph( + url=args.api_url, + user=args.user, + collection=args.collection + ) except Exception as e: diff --git a/trustgraph-cli/scripts/tg-invoke-agent b/trustgraph-cli/scripts/tg-invoke-agent index 3f05071c..5e213447 100755 --- a/trustgraph-cli/scripts/tg-invoke-agent +++ b/trustgraph-cli/scripts/tg-invoke-agent @@ -1,16 +1,18 @@ #!/usr/bin/env python3 """ -Uses the GraphRAG service to answer a query +Uses the GraphRAG service to answer a question """ import argparse import os import textwrap +import uuid +import asyncio +import json +from websockets.asyncio.client import connect -from trustgraph.clients.agent_client import AgentClient - -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') default_user = 'trustgraph' default_collection = 'default' @@ -27,15 +29,18 @@ def output(text, prefix="> ", width=78): ) print(out) -def query( - pulsar_host, query, user, collection, +async def question( + url, question, user, collection, plan=None, state=None, verbose=False ): - am = AgentClient(pulsar_host=pulsar_host) + if not url.endswith("/"): + url += "/" + + url = url + "api/v1/socket" if verbose: - output(wrap(query), "\U00002753 ") + output(wrap(question), "\U00002753 ") print() def think(x): @@ -48,11 +53,43 @@ def query( output(wrap(x), "\U0001f4a1 ") print() - resp = am.request( - question=query, think=think, observe=observe, - ) + mid = str(uuid.uuid4()) - print(resp) + async with connect(url) as ws: + + req = json.dumps({ + "id": mid, + "service": "agent", + "request": { + "question": question, + } + + }) + + await ws.send(req) + + while True: + + msg = await ws.recv() + + obj = json.loads(msg) + + if obj["id"] != mid: + print("Ignore message") + continue + + if "thought" in obj["response"]: + think(obj["response"]["thought"]) + + if "observation" in obj["response"]: + observe(obj["response"]["observation"]) + + if "answer" in obj["response"]: + print(obj["response"]["answer"]) + + if obj["complete"]: break + + await ws.close() def main(): @@ -62,25 +99,25 @@ def main(): ) parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', ) parser.add_argument( - '-q', '--query', + '-q', '--question', required=True, - help=f'Query to execute', + help=f'Question to answer', ) parser.add_argument( - '-u', '--user', + '-U', '--user', default=default_user, help=f'User ID (default: {default_user})' ) parser.add_argument( - '-c', '--collection', + '-C', '--collection', default=default_collection, help=f'Collection ID (default: {default_collection})' ) @@ -100,19 +137,27 @@ def main(): action="store_true", help=f'Output thinking/observations' ) + + # parser.add_argument( + # '--pulsar-api-key', + # default=default_pulsar_api_key, + # help=f'Pulsar API key', + # ) args = parser.parse_args() try: - query( - pulsar_host=args.pulsar_host, - query=args.query, - user=args.user, - collection=args.collection, - plan=args.plan, - state=args.state, - verbose=args.verbose, + asyncio.run( + question( + url=args.url, + question=args.question, + user=args.user, + collection=args.collection, + plan=args.plan, + state=args.state, + verbose=args.verbose, + ) ) except Exception as e: diff --git a/trustgraph-cli/scripts/tg-invoke-document-rag b/trustgraph-cli/scripts/tg-invoke-document-rag new file mode 100755 index 00000000..759d4200 --- /dev/null +++ b/trustgraph-cli/scripts/tg-invoke-document-rag @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +""" +Uses the GraphRAG service to answer a question +""" + +import argparse +import os +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_user = 'trustgraph' +default_collection = 'default' +default_doc_limit = 10 + +def question(url, question, user, collection, doc_limit): + + rag = Api(url) + + resp = rag.document_rag( + question=question, user=user, collection=collection, + doc_limit=doc_limit, + ) + + print(resp) + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-document-rag', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + # parser.add_argument( + # '--pulsar-api-key', + # default=default_pulsar_api_key, + # help=f'Pulsar API key', + # ) + + parser.add_argument( + '-q', '--question', + required=True, + help=f'Question to answer', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-C', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + + parser.add_argument( + '-d', '--doc-limit', + default=default_doc_limit, + help=f'Document limit (default: {default_doc_limit})' + ) + + args = parser.parse_args() + + try: + + question( + url=args.url, + question=args.question, + user=args.user, + collection=args.collection, + doc_limit=args.doc_limit, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +main() + diff --git a/trustgraph-cli/scripts/tg-invoke-graph-rag b/trustgraph-cli/scripts/tg-invoke-graph-rag new file mode 100755 index 00000000..5bbe5f59 --- /dev/null +++ b/trustgraph-cli/scripts/tg-invoke-graph-rag @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 + +""" +Uses the GraphRAG service to answer a question +""" + +import argparse +import os +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_user = 'trustgraph' +default_collection = 'default' +default_entity_limit = 50 +default_triple_limit = 30 +default_max_subgraph_size = 150 +default_max_path_length = 2 + +def question( + url, question, user, collection, entity_limit, triple_limit, + max_subgraph_size, max_path_length +): + + rag = Api(url) + + resp = rag.graph_rag( + question=question, user=user, collection=collection, + entity_limit=entity_limit, triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length + ) + + print(resp) + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-graph-rag', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-q', '--question', + required=True, + help=f'Question to answer', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-C', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + + parser.add_argument( + '-e', '--entity-limit', + default=default_entity_limit, + help=f'Entity limit (default: {default_entity_limit})' + ) + + parser.add_argument( + '-t', '--triple-limit', + default=default_triple_limit, + help=f'Triple limit (default: {default_triple_limit})' + ) + + parser.add_argument( + '-s', '--max-subgraph-size', + default=default_max_subgraph_size, + help=f'Max subgraph size (default: {default_max_subgraph_size})' + ) + + parser.add_argument( + '-p', '--max-path-length', + default=default_max_path_length, + help=f'Max path length (default: {default_max_path_length})' + ) + + args = parser.parse_args() + + try: + + question( + url=args.url, + question=args.question, + user=args.user, + collection=args.collection, + entity_limit=args.entity_limit, + triple_limit=args.triple_limit, + max_subgraph_size=args.max_subgraph_size, + max_path_length=args.max_path_length, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +main() + diff --git a/trustgraph-cli/scripts/tg-invoke-llm b/trustgraph-cli/scripts/tg-invoke-llm index d7289b5f..eb469b6e 100755 --- a/trustgraph-cli/scripts/tg-invoke-llm +++ b/trustgraph-cli/scripts/tg-invoke-llm @@ -8,15 +8,15 @@ and user prompt. Both arguments are required. import argparse import os import json -from trustgraph.clients.llm_client import LlmClient +from trustgraph.api import Api -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -def query(pulsar_host, system, prompt): +def query(url, system, prompt): - cli = LlmClient(pulsar_host=pulsar_host) + api = Api(url) - resp = cli.request(system=system, prompt=prompt) + resp = api.text_completion(system=system, prompt=prompt) print(resp) @@ -28,9 +28,9 @@ def main(): ) parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', ) parser.add_argument( @@ -44,13 +44,19 @@ def main(): nargs=1, help='LLM prompt e.g. What is 2 + 2?', ) + + # parser.add_argument( + # '--pulsar-api-key', + # default=default_pulsar_api_key, + # help=f'Pulsar API key', + # ) args = parser.parse_args() try: query( - pulsar_host=args.pulsar_host, + url=args.url, system=args.system[0], prompt=args.prompt[0], ) diff --git a/trustgraph-cli/scripts/tg-invoke-prompt b/trustgraph-cli/scripts/tg-invoke-prompt index 19f30912..426fe1ee 100755 --- a/trustgraph-cli/scripts/tg-invoke-prompt +++ b/trustgraph-cli/scripts/tg-invoke-prompt @@ -12,15 +12,15 @@ using key=value arguments on the command line, and these replace import argparse import os import json -from trustgraph.clients.prompt_client import PromptClient +from trustgraph.api import Api -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') -def query(pulsar_host, template_id, variables): +def query(url, template_id, variables): - cli = PromptClient(pulsar_host=pulsar_host) + api = Api(url) - resp = cli.request(id=template_id, variables=variables) + resp = api.prompt(id=template_id, variables=variables) if isinstance(resp, str): print(resp) @@ -35,9 +35,9 @@ def main(): ) parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', ) parser.add_argument( @@ -54,6 +54,12 @@ def main(): help='''Prompt template terms of the form variable=value, can be specified multiple times''', ) + + # parser.add_argument( + # '--pulsar-api-key', + # default=default_pulsar_api_key, + # help=f'Pulsar API key', + # ) args = parser.parse_args() @@ -70,7 +76,7 @@ specified multiple times''', try: query( - pulsar_host=args.pulsar_host, + url=args.url, template_id=args.id[0], variables=variables, ) diff --git a/trustgraph-cli/scripts/tg-load-doc-embeds b/trustgraph-cli/scripts/tg-load-doc-embeds new file mode 100755 index 00000000..d445ec5a --- /dev/null +++ b/trustgraph-cli/scripts/tg-load-doc-embeds @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 + +"""This utility takes a knowledge core and loads it into a running TrustGraph +through the API. The knowledge core should be in msgpack format, which is the +default format produce by tg-save-kg-core. +""" + +import aiohttp +import asyncio +import msgpack +import json +import sys +import argparse +import os +import signal + +class Running: + def __init__(self): self.running = True + def get(self): return self.running + def stop(self): self.running = False + +de_counts = 0 + +async def load_de(running, queue, url): + + global de_counts + + async with aiohttp.ClientSession() as session: + + async with session.ws_connect(f"{url}load/document-embeddings") as ws: + + while running.get(): + + try: + msg = await asyncio.wait_for(queue.get(), 1) + + # End of load + if msg is None: + break + + except: + # Hopefully it's TimeoutError. Annoying to match since + # it changed in 3.11. + continue + + msg = { + "metadata": { + "id": msg["m"]["i"], + "metadata": msg["m"]["m"], + "user": msg["m"]["u"], + "collection": msg["m"]["c"], + }, + "chunks": [ + { + "chunk": chunk["c"], + "vectors": chunk["v"], + } + for chunk in msg["c"] + ], + } + + try: + await ws.send_json(msg) + except Exception as e: + print(e) + + de_counts += 1 + +async def stats(running): + + global de_counts + + while running.get(): + + await asyncio.sleep(2) + + print( + f"Graph embeddings: {de_counts:10d}" + ) + +async def loader(running, de_queue, path, format, user, collection): + + if format == "json": + + raise RuntimeError("Not implemented") + + else: + + with open(path, "rb") as f: + + unpacker = msgpack.Unpacker(f, raw=False) + + while running.get(): + + try: + unpacked = unpacker.unpack() + except: + break + + if user: + unpacked["metadata"]["user"] = user + + if collection: + unpacked["metadata"]["collection"] = collection + + if unpacked[0] == "de": + qtype = de_queue + + while running.get(): + + try: + await asyncio.wait_for(qtype.put(unpacked[1]), 0.5) + + # Successful put message, move on + break + + except: + # Hopefully it's TimeoutError. Annoying to match since + # it changed in 3.11. + continue + + if not running.get(): break + + # Put 'None' on end of queue to finish + while running.get(): + + try: + await asyncio.wait_for(de_queue.put(None), 1) + + # Successful put message, move on + break + + except: + # Hopefully it's TimeoutError. Annoying to match since + # it changed in 3.11. + continue + +async def run(running, **args): + + # Maxsize on queues reduces back-pressure so tg-load-kg-core doesn't + # grow to eat all memory + de_q = asyncio.Queue(maxsize=10) + + load_task = asyncio.create_task( + loader( + running=running, + de_queue=de_q, + path=args["input_file"], format=args["format"], + user=args["user"], collection=args["collection"], + ) + + ) + + de_task = asyncio.create_task( + load_de( + running=running, + queue=de_q, url=args["url"] + "api/v1/" + ) + ) + + stats_task = asyncio.create_task(stats(running)) + + await de_task + + running.stop() + + await load_task + await stats_task + +async def main(running): + + parser = argparse.ArgumentParser( + prog='tg-load-kg-core', + description=__doc__, + ) + + default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/") + default_user = "trustgraph" + collection = "default" + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'TrustGraph API URL (default: {default_url})', + ) + + parser.add_argument( + '-i', '--input-file', + # Make it mandatory, difficult to over-write an existing file + required=True, + help=f'Output file' + ) + + parser.add_argument( + '--format', + default="msgpack", + choices=["msgpack", "json"], + help=f'Output format (default: msgpack)', + ) + + parser.add_argument( + '--user', + help=f'User ID to load as (default: from input)' + ) + + parser.add_argument( + '--collection', + help=f'Collection ID to load as (default: from input)' + ) + + args = parser.parse_args() + + await run(running, **vars(args)) + +running = Running() + +def interrupt(sig, frame): + running.stop() + print('Interrupt') + +signal.signal(signal.SIGINT, interrupt) + +asyncio.run(main(running)) + diff --git a/trustgraph-cli/scripts/tg-load-kg-core b/trustgraph-cli/scripts/tg-load-kg-core index 4e76e525..b79ec992 100755 --- a/trustgraph-cli/scripts/tg-load-kg-core +++ b/trustgraph-cli/scripts/tg-load-kg-core @@ -51,8 +51,13 @@ async def load_ge(running, queue, url): "user": msg["m"]["u"], "collection": msg["m"]["c"], }, - "vectors": msg["v"], - "entity": msg["e"], + "entities": [ + { + "entity": ent["e"], + "vectors": ent["v"], + } + for ent in msg["e"] + ], } try: diff --git a/trustgraph-cli/scripts/tg-load-pdf b/trustgraph-cli/scripts/tg-load-pdf index a0d2b3bc..3e960c67 100755 --- a/trustgraph-cli/scripts/tg-load-pdf +++ b/trustgraph-cli/scripts/tg-load-pdf @@ -6,21 +6,19 @@ Loads a PDF document into TrustGraph processing. import pulsar from pulsar.schema import JsonSchema -import base64 import hashlib import argparse import os import time import uuid -from trustgraph.schema import Document, document_ingest_queue -from trustgraph.schema import Metadata, Triple, Value -from trustgraph.log_level import LogLevel -from trustgraph.knowledge import hash, to_uri, Uri +from trustgraph.api import Api +from trustgraph.knowledge import hash, to_uri from trustgraph.knowledge import PREF_PUBEV, PREF_DOC, PREF_ORG from trustgraph.knowledge import Organization, PublicationEvent from trustgraph.knowledge import DigitalDocument +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' default_collection = 'default' @@ -28,24 +26,14 @@ class Loader: def __init__( self, - pulsar_host, - output_queue, + url, user, collection, - log_level, metadata, + pulsar_api_key=None, ): - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(Document), - chunking_enabled=True, - ) + self.api = Api(url) self.user = user self.collection = collection @@ -68,49 +56,18 @@ class Loader: id = to_uri(PREF_DOC, id) - triples = [] - - def emit(t): - triples.append(t) - self.metadata.id = id - self.metadata.emit(emit) - r = Document( - metadata=Metadata( - id=id, - metadata=[ - Triple( - s=Value( - value=t["s"], - is_uri=isinstance(t["s"], Uri) - ), - p=Value( - value=t["p"], - is_uri=isinstance(t["p"], Uri) - ), - o=Value( - value=t["o"], - is_uri=isinstance(t["o"], Uri) - ), - ) - for t in triples - ], - user=self.user, - collection=self.collection, - ), - data=base64.b64encode(data), + self.api.load_document( + document=data, id=id, metadata=self.metadata, +# user=self.user, +# collection=self.collection, ) - self.producer.send(r) - print(f"{file}: Loaded successfully.") except Exception as e: print(f"{file}: Failed: {str(e)}", flush=True) - - def __del__(self): - self.client.close() def main(): @@ -119,29 +76,20 @@ def main(): description=__doc__, ) - default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') - default_output_queue = document_ingest_queue - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', ) parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-u', '--user', + '-U', '--user', default=default_user, help=f'User ID (default: {default_user})' ) parser.add_argument( - '-c', '--collection', + '-C', '--collection', default=default_collection, help=f'Collection ID (default: {default_collection})' ) @@ -183,7 +131,7 @@ def main(): ) parser.add_argument( - '--url', help=f'Document URL' + '--document-url', help=f'Document URL' ) parser.add_argument( @@ -194,14 +142,6 @@ def main(): '--identifier', '--id', help=f'Document ID' ) - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.ERROR, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - parser.add_argument( 'files', nargs='+', help=f'File to load' @@ -221,7 +161,7 @@ def main(): copyright_holder=args.copyright_holder, copyright_year=args.copyright_year, license=args.license, - url=args.url, + url=args.document_url, keywords=args.keyword, ) @@ -239,11 +179,9 @@ def main(): ) p = Loader( - pulsar_host=args.pulsar_host, - output_queue=args.output_queue, + url=args.url, user=args.user, collection=args.collection, - log_level=args.log_level, metadata=document, ) diff --git a/trustgraph-cli/scripts/tg-load-text b/trustgraph-cli/scripts/tg-load-text index 51664a1b..0cc221a5 100755 --- a/trustgraph-cli/scripts/tg-load-text +++ b/trustgraph-cli/scripts/tg-load-text @@ -12,14 +12,13 @@ import os import time import uuid -from trustgraph.schema import TextDocument, text_ingest_queue -from trustgraph.schema import Metadata, Triple, Value -from trustgraph.log_level import LogLevel -from trustgraph.knowledge import hash, to_uri, Literal, Uri +from trustgraph.api import Api +from trustgraph.knowledge import hash, to_uri from trustgraph.knowledge import PREF_PUBEV, PREF_DOC, PREF_ORG from trustgraph.knowledge import Organization, PublicationEvent from trustgraph.knowledge import DigitalDocument +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' default_collection = 'default' @@ -27,24 +26,13 @@ class Loader: def __init__( self, - pulsar_host, - output_queue, + url, user, collection, - log_level, metadata, ): - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(TextDocument), - chunking_enabled=True, - ) + self.api = Api(url) self.user = user self.collection = collection @@ -67,49 +55,18 @@ class Loader: id = to_uri(PREF_DOC, id) - triples = [] - - def emit(t): - triples.append(t) - self.metadata.id = id - self.metadata.emit(emit) - r = TextDocument( - metadata=Metadata( - id=id, - metadata=[ - Triple( - s=Value( - value=t["s"], - is_uri=isinstance(t["s"], Uri) - ), - p=Value( - value=t["p"], - is_uri=isinstance(t["p"], Uri) - ), - o=Value( - value=t["o"], - is_uri=isinstance(t["o"], Uri) - ), - ) - for t in triples - ], - user=self.user, - collection=self.collection, - ), - text=data, + self.api.load_text( + text=data, id=id, metadata=self.metadata, +# user=self.user, +# collection=self.collection, ) - self.producer.send(r) - print(f"{file}: Loaded successfully.") except Exception as e: print(f"{file}: Failed: {str(e)}", flush=True) - - def __del__(self): - self.client.close() def main(): @@ -118,29 +75,26 @@ def main(): description=__doc__, ) - default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') - default_output_queue = text_ingest_queue - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', ) + + # parser.add_argument( + # '--pulsar-api-key', + # default=default_pulsar_api_key, + # help=f'Pulsar API key', + # ) parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-u', '--user', + '-U', '--user', default=default_user, help=f'User ID (default: {default_user})' ) parser.add_argument( - '-c', '--collection', + '-C', '--collection', default=default_collection, help=f'Collection ID (default: {default_collection})' ) @@ -182,7 +136,7 @@ def main(): ) parser.add_argument( - '--url', help=f'Document URL' + '--document-url', help=f'Document URL' ) parser.add_argument( @@ -193,14 +147,6 @@ def main(): '--identifier', '--id', help=f'Document ID' ) - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.ERROR, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - parser.add_argument( 'files', nargs='+', help=f'File to load' @@ -220,7 +166,7 @@ def main(): copyright_holder=args.copyright_holder, copyright_year=args.copyright_year, license=args.license, - url=args.url, + url=args.document_url, keywords=args.keyword, ) @@ -238,11 +184,9 @@ def main(): ) p = Loader( - pulsar_host=args.pulsar_host, - output_queue=args.output_queue, + url=args.url, user=args.user, collection=args.collection, - log_level=args.log_level, metadata=document, ) diff --git a/trustgraph-cli/scripts/tg-load-turtle b/trustgraph-cli/scripts/tg-load-turtle index 7c258fcc..3417a87d 100755 --- a/trustgraph-cli/scripts/tg-load-turtle +++ b/trustgraph-cli/scripts/tg-load-turtle @@ -19,6 +19,8 @@ from trustgraph.log_level import LogLevel default_user = 'trustgraph' default_collection = 'default' default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) + default_output_queue = triples_store_queue class Loader: @@ -31,12 +33,21 @@ class Loader: files, user, collection, + pulsar_api_key=None, ): - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) + if pulsar_api_key: + auth = pulsar.AuthenticationToken(pulsar_api_key) + self.client = pulsar.Client( + pulsar_host, + authentication=auth, + logger=pulsar.ConsoleLogger(log_level.to_pulsar()) + ) + else: + self.client = pulsar.Client( + pulsar_host, + logger=pulsar.ConsoleLogger(log_level.to_pulsar()) + ) self.producer = self.client.create_producer( topic=output_queue, @@ -98,6 +109,12 @@ def main(): default=default_pulsar_host, help=f'Pulsar host (default: {default_pulsar_host})', ) + + parser.add_argument( + '--pulsar-api-key', + default=default_pulsar_api_key, + help=f'Pulsar API key', + ) parser.add_argument( '-o', '--output-queue', @@ -137,6 +154,7 @@ def main(): try: p = Loader( pulsar_host=args.pulsar_host, + pulsar_api_key=args.pulsar_api_key, output_queue=args.output_queue, log_level=args.log_level, files=args.files, diff --git a/trustgraph-cli/scripts/tg-query-document-rag b/trustgraph-cli/scripts/tg-query-document-rag deleted file mode 100755 index 8d800629..00000000 --- a/trustgraph-cli/scripts/tg-query-document-rag +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python3 - -""" -Uses the Document RAG service to answer a query -""" - -import argparse -import os -from trustgraph.clients.document_rag_client import DocumentRagClient - -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') -default_user = 'trustgraph' -default_collection = 'default' - -def query(pulsar_host, query, user, collection): - - rag = DocumentRagClient(pulsar_host=pulsar) - resp = rag.request(user=user, collection=collection, query=query) - print(resp) - -def main(): - - parser = argparse.ArgumentParser( - prog='tg-query-document-rag', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-q', '--query', - required=True, - help=f'Query to execute', - ) - - parser.add_argument( - '-u', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - - parser.add_argument( - '-c', '--collection', - default=default_collection, - help=f'Collection ID (default: {default_collection})' - ) - - args = parser.parse_args() - - try: - - query( - pulsar_host=args.pulsar_host, - query=args.query, - user=args.user, - collection=args.collection, - ) - - except Exception as e: - - print("Exception:", e, flush=True) - -main() - diff --git a/trustgraph-cli/scripts/tg-query-graph-rag b/trustgraph-cli/scripts/tg-query-graph-rag deleted file mode 100755 index 8a865eea..00000000 --- a/trustgraph-cli/scripts/tg-query-graph-rag +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python3 - -""" -Uses the GraphRAG service to answer a query -""" - -import argparse -import os -from trustgraph.clients.graph_rag_client import GraphRagClient - -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') -default_user = 'trustgraph' -default_collection = 'default' - -def query(pulsar_host, query, user, collection): - - rag = GraphRagClient(pulsar_host=pulsar_host) - resp = rag.request(user=user, collection=collection, query=query) - print(resp) - -def main(): - - parser = argparse.ArgumentParser( - prog='tg-graph-query-rag', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-q', '--query', - required=True, - help=f'Query to execute', - ) - - parser.add_argument( - '-u', '--user', - default=default_user, - help=f'User ID (default: {default_user})' - ) - - parser.add_argument( - '-c', '--collection', - default=default_collection, - help=f'Collection ID (default: {default_collection})' - ) - - args = parser.parse_args() - - try: - - query( - pulsar_host=args.pulsar_host, - query=args.query, - user=args.user, - collection=args.collection, - ) - - except Exception as e: - - print("Exception:", e, flush=True) - -main() - diff --git a/trustgraph-cli/scripts/tg-save-doc-embeds b/trustgraph-cli/scripts/tg-save-doc-embeds new file mode 100755 index 00000000..95f8b748 --- /dev/null +++ b/trustgraph-cli/scripts/tg-save-doc-embeds @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 + +""" +This utility connects to a running TrustGraph through the API and creates +a knowledge core from the data streaming through the processing queues. +For completeness of data, tg-save-kg-core should be initiated before data +loading takes place. The default output format, msgpack should be used. +JSON output format is also available - msgpack produces a more compact +representation, which is also more performant to load. +""" + +import aiohttp +import asyncio +import msgpack +import json +import sys +import argparse +import os +import signal + +class Running: + def __init__(self): self.running = True + def get(self): return self.running + def stop(self): self.running = False + +async def fetch_de(running, queue, user, collection, url): + + async with aiohttp.ClientSession() as session: + + de_url = f"{url}stream/document-embeddings" + + async with session.ws_connect(de_url) as ws: + + while running.get(): + + try: + msg = await asyncio.wait_for(ws.receive(), 1) + except: + continue + + if msg.type == aiohttp.WSMsgType.TEXT: + + data = msg.json() + + if user: + if data["metadata"]["user"] != user: + continue + + if collection: + if data["metadata"]["collection"] != collection: + continue + + await queue.put([ + "de", + { + "m": { + "i": data["metadata"]["id"], + "m": data["metadata"]["metadata"], + "u": data["metadata"]["user"], + "c": data["metadata"]["collection"], + }, + "c": [ + { + "c": chunk["chunk"], + "v": chunk["vectors"], + } + for chunk in data["chunks"] + ] + } + ]) + if msg.type == aiohttp.WSMsgType.ERROR: + print("Error") + break + +de_counts = 0 + +async def stats(running): + + global t_counts + global de_counts + + while running.get(): + + await asyncio.sleep(2) + + print( + f"Document embeddings: {de_counts:10d}" + ) + +async def output(running, queue, path, format): + + global t_counts + global de_counts + + with open(path, "wb") as f: + + while running.get(): + + try: + msg = await asyncio.wait_for(queue.get(), 0.5) + except: + # Hopefully it's TimeoutError. Annoying to match since + # it changed in 3.11. + continue + + if format == "msgpack": + f.write(msgpack.packb(msg, use_bin_type=True)) + else: + f.write(json.dumps(msg).encode("utf-8")) + + if msg[0] == "de": + de_counts += 1 + + print("Output file closed") + +async def run(running, **args): + + q = asyncio.Queue() + + de_task = asyncio.create_task( + fetch_de( + running=running, + queue=q, user=args["user"], collection=args["collection"], + url=args["url"] + "api/v1/" + ) + ) + + output_task = asyncio.create_task( + output( + running=running, queue=q, + path=args["output_file"], format=args["format"], + ) + + ) + + stats_task = asyncio.create_task(stats(running)) + + await output_task + await de_task + await stats_task + + print("Exiting") + +async def main(running): + + parser = argparse.ArgumentParser( + prog='tg-save-kg-core', + description=__doc__, + ) + + default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/") + default_user = "trustgraph" + collection = "default" + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'TrustGraph API URL (default: {default_url})', + ) + + parser.add_argument( + '-o', '--output-file', + # Make it mandatory, difficult to over-write an existing file + required=True, + help=f'Output file' + ) + + parser.add_argument( + '--format', + default="msgpack", + choices=["msgpack", "json"], + help=f'Output format (default: msgpack)', + ) + + parser.add_argument( + '--user', + help=f'User ID to filter on (default: no filter)' + ) + + parser.add_argument( + '--collection', + help=f'Collection ID to filter on (default: no filter)' + ) + + args = parser.parse_args() + + await run(running, **vars(args)) + +running = Running() + +def interrupt(sig, frame): + running.stop() + print('Interrupt') + +signal.signal(signal.SIGINT, interrupt) + +asyncio.run(main(running)) + diff --git a/trustgraph-cli/scripts/tg-save-kg-core b/trustgraph-cli/scripts/tg-save-kg-core index e52cd7dc..298f2e84 100755 --- a/trustgraph-cli/scripts/tg-save-kg-core +++ b/trustgraph-cli/scripts/tg-save-kg-core @@ -57,8 +57,13 @@ async def fetch_ge(running, queue, user, collection, url): "u": data["metadata"]["user"], "c": data["metadata"]["collection"], }, - "v": data["vectors"], - "e": data["entity"], + "e": [ + { + "e": ent["entity"], + "v": ent["vectors"], + } + for ent in data["entities"] + ] } ]) if msg.type == aiohttp.WSMsgType.ERROR: diff --git a/trustgraph-cli/setup.py b/trustgraph-cli/setup.py index 8217346f..822ab765 100644 --- a/trustgraph-cli/setup.py +++ b/trustgraph-cli/setup.py @@ -34,29 +34,33 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base>=0.18,<0.19", + "trustgraph-base>=0.21,<0.22", "requests", "pulsar-client", + "aiohttp", "rdflib", "tabulate", "msgpack", + "websockets", ], scripts=[ + "scripts/tg-dump-msgpack", "scripts/tg-graph-show", "scripts/tg-graph-to-turtle", + "scripts/tg-init-pulsar", "scripts/tg-init-pulsar-manager", + "scripts/tg-invoke-agent", + "scripts/tg-invoke-document-rag", + "scripts/tg-invoke-graph-rag", + "scripts/tg-invoke-llm", + "scripts/tg-invoke-prompt", + "scripts/tg-load-kg-core", + "scripts/tg-load-doc-embeds", "scripts/tg-load-pdf", "scripts/tg-load-text", "scripts/tg-load-turtle", - "scripts/tg-query-document-rag", - "scripts/tg-query-graph-rag", - "scripts/tg-init-pulsar", "scripts/tg-processor-state", - "scripts/tg-invoke-agent", - "scripts/tg-invoke-prompt", - "scripts/tg-invoke-llm", "scripts/tg-save-kg-core", - "scripts/tg-load-kg-core", - "scripts/tg-dump-msgpack", + "scripts/tg-save-doc-embeds", ] ) diff --git a/trustgraph-embeddings-hf/setup.py b/trustgraph-embeddings-hf/setup.py index 8febd59b..8cf5beb4 100644 --- a/trustgraph-embeddings-hf/setup.py +++ b/trustgraph-embeddings-hf/setup.py @@ -34,8 +34,8 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base>=0.18,<0.19", - "trustgraph-flow>=0.18,<0.19", + "trustgraph-base>=0.21,<0.22", + "trustgraph-flow>=0.21,<0.22", "torch", "urllib3", "transformers", diff --git a/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py b/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py index 4b3b39c1..2e44821e 100755 --- a/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py +++ b/trustgraph-embeddings-hf/trustgraph/embeddings/hf/hf.py @@ -40,7 +40,7 @@ class Processor(ConsumerProducer): self.embeddings = HuggingFaceEmbeddings(model_name=model) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -56,7 +56,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = EmbeddingsResponse(vectors=embeds, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -75,7 +75,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -96,5 +96,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/scripts/document-embeddings b/trustgraph-flow/scripts/document-embeddings new file mode 100755 index 00000000..26bb85b0 --- /dev/null +++ b/trustgraph-flow/scripts/document-embeddings @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.embeddings.document_embeddings import run + +run() + diff --git a/trustgraph-flow/scripts/embeddings-fastembed b/trustgraph-flow/scripts/embeddings-fastembed new file mode 100755 index 00000000..e1322269 --- /dev/null +++ b/trustgraph-flow/scripts/embeddings-fastembed @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.embeddings.fastembed import run + +run() + diff --git a/trustgraph-flow/scripts/embeddings-vectorize b/trustgraph-flow/scripts/embeddings-vectorize deleted file mode 100755 index 3de1e3a9..00000000 --- a/trustgraph-flow/scripts/embeddings-vectorize +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python3 - -from trustgraph.embeddings.vectorize import run - -run() - diff --git a/trustgraph-flow/scripts/graph-embeddings b/trustgraph-flow/scripts/graph-embeddings new file mode 100755 index 00000000..29b1fbf4 --- /dev/null +++ b/trustgraph-flow/scripts/graph-embeddings @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.embeddings.graph_embeddings import run + +run() + diff --git a/trustgraph-flow/scripts/librarian b/trustgraph-flow/scripts/librarian new file mode 100755 index 00000000..9f6458ab --- /dev/null +++ b/trustgraph-flow/scripts/librarian @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.librarian import run + +run() + diff --git a/trustgraph-flow/scripts/pdf-ocr-mistral b/trustgraph-flow/scripts/pdf-ocr-mistral new file mode 100755 index 00000000..fb086767 --- /dev/null +++ b/trustgraph-flow/scripts/pdf-ocr-mistral @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.decoding.mistral_ocr import run + +run() + diff --git a/trustgraph-flow/scripts/text-completion-lmstudio b/trustgraph-flow/scripts/text-completion-lmstudio new file mode 100755 index 00000000..7b9e259e --- /dev/null +++ b/trustgraph-flow/scripts/text-completion-lmstudio @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.model.text_completion.lmstudio import run + +run() + diff --git a/trustgraph-flow/scripts/text-completion-mistral b/trustgraph-flow/scripts/text-completion-mistral new file mode 100755 index 00000000..91ef2279 --- /dev/null +++ b/trustgraph-flow/scripts/text-completion-mistral @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.model.text_completion.mistral import run + +run() + diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py index 30ec0170..4b6179b6 100644 --- a/trustgraph-flow/setup.py +++ b/trustgraph-flow/setup.py @@ -34,62 +34,72 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base>=0.18,<0.19", - "urllib3", - "rdflib", - "pymilvus", - "langchain", - "langchain-core", - "langchain-text-splitters", - "langchain-community", - "requests", - "cassandra-driver", - "pulsar-client", - "pypdf", - "qdrant-client", - "tabulate", + "trustgraph-base>=0.21,<0.22", + "aiohttp", "anthropic", - "pyyaml", - "prometheus-client", + "cassandra-driver", "cohere", - "openai", - "neo4j", - "tiktoken", + "cryptography", + "falkordb", + "fastembed", "google-generativeai", "ibis", "jsonschema", - "aiohttp", + "langchain", + "langchain-community", + "langchain-core", + "langchain-text-splitters", + "minio", + "mistralai", + "neo4j", + "ollama", + "openai", "pinecone[grpc]", - "falkordb", + "prometheus-client", + "pulsar-client", + "pymilvus", + "pypdf", + "mistralai", + "pyyaml", + "qdrant-client", + "rdflib", + "requests", + "tabulate", + "tiktoken", + "urllib3", ], scripts=[ - "scripts/api-gateway", "scripts/agent-manager-react", + "scripts/api-gateway", "scripts/chunker-recursive", "scripts/chunker-token", "scripts/de-query-milvus", - "scripts/de-query-qdrant", "scripts/de-query-pinecone", + "scripts/de-query-qdrant", "scripts/de-write-milvus", - "scripts/de-write-qdrant", "scripts/de-write-pinecone", + "scripts/de-write-qdrant", + "scripts/document-embeddings", "scripts/document-rag", + "scripts/embeddings-fastembed", "scripts/embeddings-ollama", - "scripts/embeddings-vectorize", "scripts/ge-query-milvus", "scripts/ge-query-pinecone", "scripts/ge-query-qdrant", "scripts/ge-write-milvus", "scripts/ge-write-pinecone", "scripts/ge-write-qdrant", + "scripts/graph-embeddings", "scripts/graph-rag", "scripts/kg-extract-definitions", - "scripts/kg-extract-topics", "scripts/kg-extract-relationships", + "scripts/kg-extract-topics", + "scripts/librarian", "scripts/metering", "scripts/object-extract-row", "scripts/oe-write-milvus", "scripts/pdf-decoder", + "scripts/pdf-ocr-mistral", "scripts/prompt-generic", "scripts/prompt-template", "scripts/rows-write-cassandra", @@ -100,16 +110,18 @@ setuptools.setup( "scripts/text-completion-cohere", "scripts/text-completion-googleaistudio", "scripts/text-completion-llamafile", + "scripts/text-completion-lmstudio", + "scripts/text-completion-mistral", "scripts/text-completion-ollama", "scripts/text-completion-openai", "scripts/triples-query-cassandra", - "scripts/triples-query-neo4j", - "scripts/triples-query-memgraph", "scripts/triples-query-falkordb", + "scripts/triples-query-memgraph", + "scripts/triples-query-neo4j", "scripts/triples-write-cassandra", - "scripts/triples-write-neo4j", - "scripts/triples-write-memgraph", "scripts/triples-write-falkordb", + "scripts/triples-write-memgraph", + "scripts/triples-write-neo4j", "scripts/wikipedia-lookup", ] ) diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py index 8799816b..bc045b71 100755 --- a/trustgraph-flow/trustgraph/agent/react/service.py +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -14,8 +14,6 @@ from ... schema import AgentRequest, AgentResponse, AgentStep from ... schema import agent_request_queue, agent_response_queue from ... schema import prompt_request_queue as pr_request_queue from ... schema import prompt_response_queue as pr_response_queue -from ... schema import text_completion_request_queue as tc_request_queue -from ... schema import text_completion_response_queue as tc_response_queue from ... schema import graph_rag_request_queue as gr_request_queue from ... schema import graph_rag_response_queue as gr_response_queue from ... clients.prompt_client import PromptClient @@ -133,12 +131,6 @@ class Processor(ConsumerProducer): prompt_response_queue = params.get( "prompt_response_queue", pr_response_queue ) - text_completion_request_queue = params.get( - "text_completion_request_queue", tc_request_queue - ) - text_completion_response_queue = params.get( - "text_completion_response_queue", tc_response_queue - ) graph_rag_request_queue = params.get( "graph_rag_request_queue", gr_request_queue ) @@ -155,8 +147,6 @@ class Processor(ConsumerProducer): "output_schema": AgentResponse, "prompt_request_queue": prompt_request_queue, "prompt_response_queue": prompt_response_queue, - "text_completion_request_queue": tc_request_queue, - "text_completion_response_queue": tc_response_queue, "graph_rag_request_queue": gr_request_queue, "graph_rag_response_queue": gr_response_queue, } @@ -166,21 +156,16 @@ class Processor(ConsumerProducer): subscriber=subscriber, input_queue=prompt_request_queue, output_queue=prompt_response_queue, - pulsar_host = self.pulsar_host - ) - - self.llm = LlmClient( - subscriber=subscriber, - input_queue=text_completion_request_queue, - output_queue=text_completion_response_queue, - pulsar_host = self.pulsar_host + pulsar_host = self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, ) self.graph_rag = GraphRagClient( subscriber=subscriber, input_queue=graph_rag_request_queue, output_queue=graph_rag_response_queue, - pulsar_host = self.pulsar_host + pulsar_host = self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, ) # Need to be able to feed requests to myself @@ -206,7 +191,7 @@ class Processor(ConsumerProducer): return json.loads(json_str) - def handle(self, msg): + async def handle(self, msg): try: @@ -235,7 +220,7 @@ class Processor(ConsumerProducer): print(f"History: {history}", flush=True) - def think(x): + async def think(x): print(f"Think: {x}", flush=True) @@ -246,9 +231,9 @@ class Processor(ConsumerProducer): observation=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) - def observe(x): + async def observe(x): print(f"Observe: {x}", flush=True) @@ -259,7 +244,7 @@ class Processor(ConsumerProducer): observation=x, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) act = self.agent.react(v.question, history, think, observe) @@ -275,7 +260,7 @@ class Processor(ConsumerProducer): thought=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -318,7 +303,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) @staticmethod def add_args(parser): @@ -340,18 +325,6 @@ class Processor(ConsumerProducer): help=f'Prompt response queue (default: {pr_response_queue})', ) - parser.add_argument( - '--text-completion-request-queue', - default=tc_request_queue, - help=f'Text completion request queue (default: {tc_request_queue})', - ) - - parser.add_argument( - '--text-completion-response-queue', - default=tc_response_queue, - help=f'Text completion response queue (default: {tc_response_queue})', - ) - parser.add_argument( '--graph-rag-request-queue', default=gr_request_queue, @@ -406,5 +379,5 @@ description.''' def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py index d9bc846f..941610be 100644 --- a/trustgraph-flow/trustgraph/agent/react/tools.py +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -14,6 +14,6 @@ class TextCompletionImpl: self.context = context def invoke(self, **arguments): return self.context.prompt.request( - "question", { "question": arguments.get("computation") } + "question", { "question": arguments.get("question") } ) diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index 694ced70..82f333b5 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -52,7 +52,7 @@ class Processor(ConsumerProducer): is_separator_regex=False, ) - def handle(self, msg): + async def handle(self, msg): v = msg.value() print(f"Chunking {v.metadata.id}...", flush=True) @@ -70,7 +70,7 @@ class Processor(ConsumerProducer): __class__.chunk_metric.observe(len(chunk.page_content)) - self.send(r) + await self.send(r) print("Done.", flush=True) @@ -98,5 +98,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index dccd9c89..c625b48c 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -51,7 +51,7 @@ class Processor(ConsumerProducer): chunk_overlap=chunk_overlap, ) - def handle(self, msg): + async def handle(self, msg): v = msg.value() print(f"Chunking {v.metadata.id}...", flush=True) @@ -69,7 +69,7 @@ class Processor(ConsumerProducer): __class__.chunk_metric.observe(len(chunk.page_content)) - self.send(r) + await self.send(r) print("Done.", flush=True) @@ -97,5 +97,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/__init__.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/__init__.py new file mode 100644 index 00000000..9d16af90 --- /dev/null +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/__init__.py @@ -0,0 +1,3 @@ + +from . processor import * + diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/__main__.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/__main__.py new file mode 100755 index 00000000..986c0257 --- /dev/null +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . processor import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py new file mode 100755 index 00000000..f5100244 --- /dev/null +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py @@ -0,0 +1,190 @@ + +""" +Simple decoder, accepts PDF documents on input, outputs pages from the +PDF document as text as separate output objects. +""" + +from pypdf import PdfWriter, PdfReader +from io import BytesIO +import base64 +import uuid +import os + +from mistralai import Mistral +from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk +from mistralai.models import OCRResponse + +from ... schema import Document, TextDocument, Metadata +from ... schema import document_ingest_queue, text_ingest_queue +from ... log_level import LogLevel +from ... base import ConsumerProducer + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = document_ingest_queue +default_output_queue = text_ingest_queue +default_subscriber = module +default_api_key = os.getenv("MISTRAL_TOKEN") + +pages_per_chunk = 5 + +def chunks(lst, n): + "Yield successive n-sized chunks from lst." + for i in range(0, len(lst), n): + yield lst[i:i + n] + +def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str: + """ + Replace image placeholders in markdown with base64-encoded images. + + Args: + markdown_str: Markdown text containing image placeholders + images_dict: Dictionary mapping image IDs to base64 strings + + Returns: + Markdown text with images replaced by base64 data + """ + for img_name, base64_str in images_dict.items(): + markdown_str = markdown_str.replace( + f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})" + ) + return markdown_str + +def get_combined_markdown(ocr_response: OCRResponse) -> str: + """ + Combine OCR text and images into a single markdown document. + + Args: + ocr_response: Response from OCR processing containing text and images + + Returns: + Combined markdown string with embedded images + """ + markdowns: list[str] = [] + # Extract images from page + for page in ocr_response.pages: + image_data = {} + for img in page.images: + image_data[img.id] = img.image_base64 + # Replace image placeholders with actual images + markdowns.append(replace_images_in_markdown(page.markdown, image_data)) + + return "\n\n".join(markdowns) + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + api_key = params.get("api_key", default_api_key) + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": Document, + "output_schema": TextDocument, + } + ) + + if api_key is None: + raise RuntimeError("Mistral API key not specified") + + self.mistral = Mistral(api_key=api_key) + + # Used with Mistral doc upload + self.unique_id = str(uuid.uuid4()) + + print("PDF inited") + + def ocr(self, blob): + + print("Parse PDF...", flush=True) + + pdfbuf = BytesIO(blob) + pdf = PdfReader(pdfbuf) + + for chunk in chunks(pdf.pages, pages_per_chunk): + + print("Get next pages...", flush=True) + + part = PdfWriter() + for page in chunk: + part.add_page(page) + + buf = BytesIO() + part.write_stream(buf) + + print("Upload chunk...", flush=True) + + uploaded_file = self.mistral.files.upload( + file={ + "file_name": self.unique_id, + "content": buf.getvalue(), + }, + purpose="ocr", + ) + + signed_url = self.mistral.files.get_signed_url( + file_id=uploaded_file.id, expiry=1 + ) + + print("OCR...", flush=True) + + processed = self.mistral.ocr.process( + model="mistral-ocr-latest", + include_image_base64=True, + document={ + "type": "document_url", + "document_url": signed_url.url, + } + ) + + print("Extract markdown...", flush=True) + + markdown = get_combined_markdown(processed) + + print("OCR complete.", flush=True) + + return markdown + + async def handle(self, msg): + + print("PDF message received") + + v = msg.value() + + print(f"Decoding {v.metadata.id}...", flush=True) + + markdown = self.ocr(base64.b64decode(v.data)) + + r = TextDocument( + metadata=v.metadata, + text=markdown.encode("utf-8"), + ) + + await self.send(r) + + print("Done.", flush=True) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '-k', '--api-key', + default=default_api_key, + help=f'Mistral API Key' + ) + +def run(): + + Processor.launch(module, __doc__) + diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index 38ac9257..5e5e3612 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -39,7 +39,7 @@ class Processor(ConsumerProducer): print("PDF inited") - def handle(self, msg): + async def handle(self, msg): print("PDF message received") @@ -64,7 +64,7 @@ class Processor(ConsumerProducer): text=page.page_content.encode("utf-8"), ) - self.send(r) + await self.send(r) print("Done.", flush=True) @@ -78,5 +78,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/direct/cassandra.py b/trustgraph-flow/trustgraph/direct/cassandra.py index 568411a9..73f1f33a 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra.py +++ b/trustgraph-flow/trustgraph/direct/cassandra.py @@ -1,12 +1,13 @@ from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider +from ssl import SSLContext, PROTOCOL_TLSv1_2 class TrustGraph: def __init__( self, hosts=None, - keyspace="trustgraph", table="default", + keyspace="trustgraph", table="default", username=None, password=None ): if hosts is None: @@ -14,8 +15,14 @@ class TrustGraph: self.keyspace = keyspace self.table = table + self.username = username - self.cluster = Cluster(hosts) + if username and password: + ssl_context = SSLContext(PROTOCOL_TLSv1_2) + auth_provider = PlainTextAuthProvider(username=username, password=password) + self.cluster = Cluster(hosts, auth_provider=auth_provider, ssl_context=ssl_context) + else: + self.cluster = Cluster(hosts) self.session = self.cluster.connect() self.init() diff --git a/trustgraph-flow/trustgraph/document_rag.py b/trustgraph-flow/trustgraph/document_rag.py index f3c8b158..4fc4850a 100644 --- a/trustgraph-flow/trustgraph/document_rag.py +++ b/trustgraph-flow/trustgraph/document_rag.py @@ -16,11 +16,54 @@ from . schema import document_embeddings_response_queue LABEL="http://www.w3.org/2000/01/rdf-schema#label" DEFINITION="http://www.w3.org/2004/02/skos/core#definition" +class Query: + + def __init__( + self, rag, user, collection, verbose, + doc_limit=20 + ): + self.rag = rag + self.user = user + self.collection = collection + self.verbose = verbose + self.doc_limit = doc_limit + + def get_vector(self, query): + + if self.verbose: + print("Compute embeddings...", flush=True) + + qembeds = self.rag.embeddings.request(query) + + if self.verbose: + print("Done.", flush=True) + + return qembeds + + def get_docs(self, query): + + vectors = self.get_vector(query) + + if self.verbose: + print("Get entities...", flush=True) + + docs = self.rag.de_client.request( + vectors, limit=self.doc_limit + ) + + if self.verbose: + print("Docs:", flush=True) + for doc in docs: + print(doc, flush=True) + + return docs + class DocumentRag: def __init__( self, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, pr_request_queue=None, pr_response_queue=None, emb_request_queue=None, @@ -54,14 +97,12 @@ class DocumentRag: if self.verbose: print("Initialising...", flush=True) - # FIXME: Configurable - self.entity_limit = 20 - self.de_client = DocumentEmbeddingsClient( pulsar_host=pulsar_host, subscriber=module + "-de", input_queue=de_request_queue, output_queue=de_response_queue, + pulsar_api_key=pulsar_api_key, ) self.embeddings = EmbeddingsClient( @@ -69,6 +110,7 @@ class DocumentRag: input_queue=emb_request_queue, output_queue=emb_response_queue, subscriber=module + "-emb", + pulsar_api_key=pulsar_api_key, ) self.lang = PromptClient( @@ -76,47 +118,26 @@ class DocumentRag: input_queue=pr_request_queue, output_queue=pr_response_queue, subscriber=module + "-de-prompt", + pulsar_api_key=pulsar_api_key, ) if self.verbose: print("Initialised", flush=True) - def get_vector(self, query): - - if self.verbose: - print("Compute embeddings...", flush=True) - - qembeds = self.embeddings.request(query) - - if self.verbose: - print("Done.", flush=True) - - return qembeds - - def get_docs(self, query): - - vectors = self.get_vector(query) - - if self.verbose: - print("Get entities...", flush=True) - - docs = self.de_client.request( - vectors, self.entity_limit - ) - - if self.verbose: - print("Docs:", flush=True) - for doc in docs: - print(doc, flush=True) - - return docs - - def query(self, query): + def query( + self, query, user="trustgraph", collection="default", + doc_limit=20, + ): if self.verbose: print("Construct prompt...", flush=True) - docs = self.get_docs(query) + q = Query( + rag=self, user=user, collection=collection, verbose=self.verbose, + doc_limit=doc_limit + ) + + docs = q.get_docs(query) if self.verbose: print("Invoke LLM...", flush=True) diff --git a/trustgraph-flow/trustgraph/embeddings/document_embeddings/__init__.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/__init__.py new file mode 100644 index 00000000..40d505a5 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/__init__.py @@ -0,0 +1,3 @@ + +from . embeddings import * + diff --git a/trustgraph-flow/trustgraph/embeddings/vectorize/__main__.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/__main__.py similarity index 57% rename from trustgraph-flow/trustgraph/embeddings/vectorize/__main__.py rename to trustgraph-flow/trustgraph/embeddings/document_embeddings/__main__.py index a578de8a..a48cc4d0 100755 --- a/trustgraph-flow/trustgraph/embeddings/vectorize/__main__.py +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/__main__.py @@ -1,5 +1,5 @@ -from . vectorize import run +from . embeddings import run if __name__ == '__main__': run() diff --git a/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py similarity index 72% rename from trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py rename to trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py index 4cf2af05..70f53e07 100755 --- a/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py +++ b/trustgraph-flow/trustgraph/embeddings/document_embeddings/embeddings.py @@ -1,11 +1,13 @@ """ -Vectorizer, calls the embeddings service to get embeddings for a chunk. -Input is text chunk, output is chunk and vectors. +Document embeddings, calls the embeddings service to get embeddings for a +chunk of text. Input is chunk of text plus metadata. +Output is chunk plus embedding. """ -from ... schema import Chunk, ChunkEmbeddings -from ... schema import chunk_ingest_queue, chunk_embeddings_ingest_queue +from ... schema import Chunk, ChunkEmbeddings, DocumentEmbeddings +from ... schema import chunk_ingest_queue +from ... schema import document_embeddings_store_queue from ... schema import embeddings_request_queue, embeddings_response_queue from ... clients.embeddings_client import EmbeddingsClient from ... log_level import LogLevel @@ -14,7 +16,7 @@ from ... base import ConsumerProducer module = ".".join(__name__.split(".")[1:-1]) default_input_queue = chunk_ingest_queue -default_output_queue = chunk_embeddings_ingest_queue +default_output_queue = document_embeddings_store_queue default_subscriber = module class Processor(ConsumerProducer): @@ -39,42 +41,47 @@ class Processor(ConsumerProducer): "embeddings_response_queue": emb_response_queue, "subscriber": subscriber, "input_schema": Chunk, - "output_schema": ChunkEmbeddings, + "output_schema": DocumentEmbeddings, } ) self.embeddings = EmbeddingsClient( pulsar_host=self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, input_queue=emb_request_queue, output_queue=emb_response_queue, subscriber=module + "-emb", ) - def emit(self, metadata, chunk, vectors): - - r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors) - self.producer.send(r) - - def handle(self, msg): + async def handle(self, msg): v = msg.value() print(f"Indexing {v.metadata.id}...", flush=True) - chunk = v.chunk.decode("utf-8") - try: - vectors = self.embeddings.request(chunk) + vectors = self.embeddings.request(v.chunk) - self.emit( + embeds = [ + ChunkEmbeddings( + chunk=v.chunk, + vectors=vectors, + ) + ] + + r = DocumentEmbeddings( metadata=v.metadata, - chunk=chunk.encode("utf-8"), - vectors=vectors + chunks=embeds, ) + await self.send(r) + except Exception as e: print("Exception:", e, flush=True) + # Retry + raise e + print("Done.", flush=True) @staticmethod @@ -99,5 +106,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/__init__.py b/trustgraph-flow/trustgraph/embeddings/fastembed/__init__.py new file mode 100644 index 00000000..9d16af90 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/__init__.py @@ -0,0 +1,3 @@ + +from . processor import * + diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/__main__.py b/trustgraph-flow/trustgraph/embeddings/fastembed/__main__.py new file mode 100755 index 00000000..986c0257 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . processor import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py new file mode 100755 index 00000000..bc164fa0 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/fastembed/processor.py @@ -0,0 +1,89 @@ + +""" +Embeddings service, applies an embeddings model selected from HuggingFace. +Input is text, output is embeddings vector. +""" + +from ... schema import EmbeddingsRequest, EmbeddingsResponse +from ... schema import embeddings_request_queue, embeddings_response_queue +from ... log_level import LogLevel +from ... base import ConsumerProducer +from fastembed import TextEmbedding +import os + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = embeddings_request_queue +default_output_queue = embeddings_response_queue +default_subscriber = module +default_model="sentence-transformers/all-MiniLM-L6-v2" + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + + model = params.get("model", default_model) + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": EmbeddingsRequest, + "output_schema": EmbeddingsResponse, + "model": model, + } + ) + + self.embeddings = TextEmbedding(model_name = model) + + async def handle(self, msg): + + v = msg.value() + + # Sender-produced ID + + id = msg.properties()["id"] + + print(f"Handling input {id}...", flush=True) + + text = v.text + vecs = self.embeddings.embed([text]) + + vecs = [ + v.tolist() + for v in vecs + ] + + print("Send response...", flush=True) + r = EmbeddingsResponse( + vectors=list(vecs), + error=None, + ) + + await self.send(r, properties={"id": id}) + + print("Done.", flush=True) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '-m', '--model', + default=default_model, + help=f'Embeddings model (default: {default_model})' + ) + +def run(): + + Processor.launch(module, __doc__) + diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__init__.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__init__.py new file mode 100644 index 00000000..40d505a5 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__init__.py @@ -0,0 +1,3 @@ + +from . embeddings import * + diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__main__.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__main__.py new file mode 100755 index 00000000..a48cc4d0 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/__main__.py @@ -0,0 +1,6 @@ + +from . embeddings import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py new file mode 100755 index 00000000..2cbe9907 --- /dev/null +++ b/trustgraph-flow/trustgraph/embeddings/graph_embeddings/embeddings.py @@ -0,0 +1,113 @@ + +""" +Graph embeddings, calls the embeddings service to get embeddings for a +set of entity contexts. Input is entity plus textual context. +Output is entity plus embedding. +""" + +from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings +from ... schema import entity_contexts_ingest_queue +from ... schema import graph_embeddings_store_queue +from ... schema import embeddings_request_queue, embeddings_response_queue +from ... clients.embeddings_client import EmbeddingsClient +from ... log_level import LogLevel +from ... base import ConsumerProducer + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = entity_contexts_ingest_queue +default_output_queue = graph_embeddings_store_queue +default_subscriber = module + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + emb_request_queue = params.get( + "embeddings_request_queue", embeddings_request_queue + ) + emb_response_queue = params.get( + "embeddings_response_queue", embeddings_response_queue + ) + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "embeddings_request_queue": emb_request_queue, + "embeddings_response_queue": emb_response_queue, + "subscriber": subscriber, + "input_schema": EntityContexts, + "output_schema": GraphEmbeddings, + } + ) + + self.embeddings = EmbeddingsClient( + pulsar_host=self.pulsar_host, + input_queue=emb_request_queue, + output_queue=emb_response_queue, + subscriber=module + "-emb", + ) + + async def handle(self, msg): + + v = msg.value() + print(f"Indexing {v.metadata.id}...", flush=True) + + entities = [] + + try: + + for entity in v.entities: + + vectors = self.embeddings.request(entity.context) + + entities.append( + EntityEmbeddings( + entity=entity.entity, + vectors=vectors + ) + ) + + r = GraphEmbeddings( + metadata=v.metadata, + entities=entities, + ) + + await self.send(r) + + except Exception as e: + print("Exception:", e, flush=True) + + # Retry + raise e + + print("Done.", flush=True) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '--embeddings-request-queue', + default=embeddings_request_queue, + help=f'Embeddings request queue (default: {embeddings_request_queue})', + ) + + parser.add_argument( + '--embeddings-response-queue', + default=embeddings_response_queue, + help=f'Embeddings request queue (default: {embeddings_response_queue})', + ) + +def run(): + + Processor.launch(module, __doc__) + diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index 6682a79f..c441b9c6 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -1,14 +1,15 @@ """ -Embeddings service, applies an embeddings model selected from HuggingFace. +Embeddings service, applies an embeddings model hosted on a local Ollama. Input is text, output is embeddings vector. """ -from langchain_community.embeddings import OllamaEmbeddings from ... schema import EmbeddingsRequest, EmbeddingsResponse from ... schema import embeddings_request_queue, embeddings_response_queue from ... log_level import LogLevel from ... base import ConsumerProducer +from ollama import Client +import os module = ".".join(__name__.split(".")[1:-1]) @@ -16,7 +17,7 @@ default_input_queue = embeddings_request_queue default_output_queue = embeddings_response_queue default_subscriber = module default_model="mxbai-embed-large" -default_ollama = 'http://localhost:11434' +default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434') class Processor(ConsumerProducer): @@ -26,6 +27,9 @@ class Processor(ConsumerProducer): output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) + ollama = params.get("ollama", default_ollama) + model = params.get("model", default_model) + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -33,12 +37,15 @@ class Processor(ConsumerProducer): "subscriber": subscriber, "input_schema": EmbeddingsRequest, "output_schema": EmbeddingsResponse, + "ollama": ollama, + "model": model, } ) - self.embeddings = OllamaEmbeddings(base_url=ollama, model=model) + self.client = Client(host=ollama) + self.model = model - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -49,12 +56,18 @@ class Processor(ConsumerProducer): print(f"Handling input {id}...", flush=True) text = v.text - embeds = self.embeddings.embed_query([text]) + embeds = self.client.embed( + model = self.model, + input = text + ) print("Send response...", flush=True) - r = EmbeddingsResponse(vectors=[embeds]) + r = EmbeddingsResponse( + vectors=embeds.embeddings, + error=None, + ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -80,5 +93,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/embeddings/vectorize/__init__.py b/trustgraph-flow/trustgraph/embeddings/vectorize/__init__.py deleted file mode 100644 index 31596b8c..00000000 --- a/trustgraph-flow/trustgraph/embeddings/vectorize/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ - -from . vectorize import * - diff --git a/trustgraph-flow/trustgraph/external/wikipedia/service.py b/trustgraph-flow/trustgraph/external/wikipedia/service.py index 932e1213..cc002765 100644 --- a/trustgraph-flow/trustgraph/external/wikipedia/service.py +++ b/trustgraph-flow/trustgraph/external/wikipedia/service.py @@ -39,7 +39,7 @@ class Processor(ConsumerProducer): self.url = url - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -60,7 +60,7 @@ class Processor(ConsumerProducer): text=resp ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -75,7 +75,7 @@ class Processor(ConsumerProducer): ), text=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -98,5 +98,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index eed34574..47c99802 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -1,14 +1,17 @@ """ -Simple decoder, accepts embeddings+text chunks input, applies entity analysis to -get entity definitions which are output as graph edges. +Simple decoder, accepts text chunks input, applies entity analysis to +get entity definitions which are output as graph edges along with +entity/context definitions for embedding. """ import urllib.parse -import json +from pulsar.schema import JsonSchema -from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value -from .... schema import chunk_embeddings_ingest_queue, triples_store_queue +from .... schema import Chunk, Triple, Triples, Metadata, Value +from .... schema import EntityContext, EntityContexts +from .... schema import chunk_ingest_queue, triples_store_queue +from .... schema import entity_contexts_ingest_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue from .... log_level import LogLevel @@ -22,8 +25,9 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = chunk_ingest_queue default_output_queue = triples_store_queue +default_entity_context_queue = entity_contexts_ingest_queue default_subscriber = module class Processor(ConsumerProducer): @@ -32,6 +36,10 @@ class Processor(ConsumerProducer): input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) + ec_queue = params.get( + "entity_context_queue", + default_entity_context_queue + ) subscriber = params.get("subscriber", default_subscriber) pr_request_queue = params.get( "prompt_request_queue", prompt_request_queue @@ -45,15 +53,33 @@ class Processor(ConsumerProducer): "input_queue": input_queue, "output_queue": output_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": Chunk, "output_schema": Triples, "prompt_request_queue": pr_request_queue, "prompt_response_queue": pr_response_queue, } ) + self.ec_prod = self.client.create_producer( + topic=ec_queue, + schema=JsonSchema(EntityContexts), + ) + + __class__.pubsub_metric.info({ + "input_queue": input_queue, + "output_queue": output_queue, + "entity_context_queue": ec_queue, + "prompt_request_queue": pr_request_queue, + "prompt_response_queue": pr_response_queue, + "subscriber": subscriber, + "input_schema": Chunk.__name__, + "output_schema": Triples.__name__, + "vector_schema": EntityContexts.__name__, + }) + self.prompt = PromptClient( pulsar_host=self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, input_queue=pr_request_queue, output_queue=pr_response_queue, subscriber = module + "-prompt", @@ -71,15 +97,23 @@ class Processor(ConsumerProducer): return self.prompt.request_definitions(chunk) - def emit_edges(self, metadata, triples): + async def emit_edges(self, metadata, triples): t = Triples( metadata=metadata, triples=triples, ) - self.producer.send(t) + await self.send(t) - def handle(self, msg): + async def emit_ecs(self, metadata, entities): + + t = EntityContexts( + metadata=metadata, + entities=entities, + ) + self.ec_prod.send(t) + + async def handle(self, msg): v = msg.value() print(f"Indexing {v.metadata.id}...", flush=True) @@ -91,6 +125,7 @@ class Processor(ConsumerProducer): defs = self.get_definitions(chunk) triples = [] + entities = [] # FIXME: Putting metadata into triples store is duplicated in # relationships extractor too @@ -129,7 +164,15 @@ class Processor(ConsumerProducer): o=Value(value=v.metadata.id, is_uri=True) )) - self.emit_edges( + ec = EntityContext( + entity=s_value, + context=defn.definition, + ) + + entities.append(ec) + + + await self.emit_edges( Metadata( id=v.metadata.id, metadata=[], @@ -139,6 +182,16 @@ class Processor(ConsumerProducer): triples ) + await self.emit_ecs( + Metadata( + id=v.metadata.id, + metadata=[], + user=v.metadata.user, + collection=v.metadata.collection, + ), + entities + ) + except Exception as e: print("Exception: ", e, flush=True) @@ -152,6 +205,12 @@ class Processor(ConsumerProducer): default_output_queue, ) + parser.add_argument( + '-e', '--entity-context-queue', + default=default_entity_context_queue, + help=f'Entity context queue (default: {default_entity_context_queue})' + ) + parser.add_argument( '--prompt-request-queue', default=prompt_request_queue, @@ -166,5 +225,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index d2dea062..2f293527 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -1,18 +1,15 @@ """ -Simple decoder, accepts vector+text chunks input, applies entity +Simple decoder, accepts text chunks input, applies entity relationship analysis to get entity relationship edges which are output as graph edges. """ import urllib.parse -import os -from pulsar.schema import JsonSchema -from .... schema import ChunkEmbeddings, Triple, Triples, GraphEmbeddings +from .... schema import Chunk, Triple, Triples from .... schema import Metadata, Value -from .... schema import chunk_embeddings_ingest_queue, triples_store_queue -from .... schema import graph_embeddings_store_queue +from .... schema import chunk_ingest_queue, triples_store_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue from .... log_level import LogLevel @@ -25,9 +22,8 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = chunk_ingest_queue default_output_queue = triples_store_queue -default_vector_queue = graph_embeddings_store_queue default_subscriber = module class Processor(ConsumerProducer): @@ -36,7 +32,6 @@ class Processor(ConsumerProducer): input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) - vector_queue = params.get("vector_queue", default_vector_queue) subscriber = params.get("subscriber", default_subscriber) pr_request_queue = params.get( "prompt_request_queue", prompt_request_queue @@ -50,32 +45,16 @@ class Processor(ConsumerProducer): "input_queue": input_queue, "output_queue": output_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": Chunk, "output_schema": Triples, "prompt_request_queue": pr_request_queue, "prompt_response_queue": pr_response_queue, } ) - self.vec_prod = self.client.create_producer( - topic=vector_queue, - schema=JsonSchema(GraphEmbeddings), - ) - - __class__.pubsub_metric.info({ - "input_queue": input_queue, - "output_queue": output_queue, - "vector_queue": vector_queue, - "prompt_request_queue": pr_request_queue, - "prompt_response_queue": pr_response_queue, - "subscriber": subscriber, - "input_schema": ChunkEmbeddings.__name__, - "output_schema": Triples.__name__, - "vector_schema": GraphEmbeddings.__name__, - }) - self.prompt = PromptClient( pulsar_host=self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, input_queue=pr_request_queue, output_queue=pr_response_queue, subscriber = module + "-prompt", @@ -93,20 +72,15 @@ class Processor(ConsumerProducer): return self.prompt.request_relationships(chunk) - def emit_edges(self, metadata, triples): + async def emit_edges(self, metadata, triples): t = Triples( metadata=metadata, triples=triples, ) - self.producer.send(t) + await self.send(t) - def emit_vec(self, metadata, ent, vec): - - r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec) - self.vec_prod.send(r) - - def handle(self, msg): + async def handle(self, msg): v = msg.value() print(f"Indexing {v.metadata.id}...", flush=True) @@ -193,13 +167,7 @@ class Processor(ConsumerProducer): o=Value(value=v.metadata.id, is_uri=True) )) - self.emit_vec(v.metadata, s_value, v.vectors) - self.emit_vec(v.metadata, p_value, v.vectors) - - if rel.o_entity: - self.emit_vec(v.metadata, o_value, v.vectors) - - self.emit_edges( + await self.emit_edges( Metadata( id=v.metadata.id, metadata=[], @@ -222,12 +190,6 @@ class Processor(ConsumerProducer): default_output_queue, ) - parser.add_argument( - '-c', '--vector-queue', - default=default_vector_queue, - help=f'Vector output queue (default: {default_vector_queue})' - ) - parser.add_argument( '--prompt-request-queue', default=prompt_request_queue, @@ -242,5 +204,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py index 8dfc3e6e..7424abe2 100755 --- a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py @@ -1,14 +1,14 @@ """ -Simple decoder, accepts embeddings+text chunks input, applies entity analysis to -get entity definitions which are output as graph edges. +Simple decoder, accepts text chunks input, applies entity analysis to +get topics which are output as graph edges. """ import urllib.parse import json -from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value -from .... schema import chunk_embeddings_ingest_queue, triples_store_queue +from .... schema import Chunk, Triple, Triples, Metadata, Value +from .... schema import chunk_ingest_queue, triples_store_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue from .... log_level import LogLevel @@ -20,7 +20,7 @@ DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True) module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = chunk_ingest_queue default_output_queue = triples_store_queue default_subscriber = module @@ -43,7 +43,7 @@ class Processor(ConsumerProducer): "input_queue": input_queue, "output_queue": output_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": Chunk, "output_schema": Triples, "prompt_request_queue": pr_request_queue, "prompt_response_queue": pr_response_queue, @@ -52,6 +52,7 @@ class Processor(ConsumerProducer): self.prompt = PromptClient( pulsar_host=self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, input_queue=pr_request_queue, output_queue=pr_response_queue, subscriber = module + "-prompt", @@ -69,15 +70,15 @@ class Processor(ConsumerProducer): return self.prompt.request_topics(chunk) - def emit_edge(self, metadata, s, p, o): + async def emit_edge(self, metadata, s, p, o): t = Triples( metadata=metadata, triples=[Triple(s=s, p=p, o=o)], ) - self.producer.send(t) + await self.send(t) - def handle(self, msg): + async def handle(self, msg): v = msg.value() print(f"Indexing {v.metadata.id}...", flush=True) @@ -104,7 +105,9 @@ class Processor(ConsumerProducer): s_value = Value(value=str(s_uri), is_uri=True) o_value = Value(value=str(o), is_uri=False) - self.emit_edge(v. metadata, s_value, DEFINITION_VALUE, o_value) + await self.emit_edge( + v.metadata, s_value, DEFINITION_VALUE, o_value + ) except Exception as e: print("Exception: ", e, flush=True) @@ -133,5 +136,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/extract/object/row/extract.py b/trustgraph-flow/trustgraph/extract/object/row/extract.py index 185a59c3..9ccf3370 100755 --- a/trustgraph-flow/trustgraph/extract/object/row/extract.py +++ b/trustgraph-flow/trustgraph/extract/object/row/extract.py @@ -112,6 +112,7 @@ class Processor(ConsumerProducer): self.prompt = PromptClient( pulsar_host=self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, input_queue=pr_request_queue, output_queue=pr_response_queue, subscriber = module + "-prompt", @@ -129,7 +130,7 @@ class Processor(ConsumerProducer): t = Rows( metadata=metadata, row_schema=self.row_schema, rows=rows ) - self.producer.send(t) + await self.send(t) def emit_vec(self, metadata, name, vec, key_name, key): @@ -138,7 +139,7 @@ class Processor(ConsumerProducer): ) self.vec_prod.send(r) - def handle(self, msg): + async def handle(self, msg): v = msg.value() print(f"Indexing {v.metadata.id}...", flush=True) @@ -216,5 +217,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/gateway/agent.py b/trustgraph-flow/trustgraph/gateway/agent.py index c7af947b..150b970e 100644 --- a/trustgraph-flow/trustgraph/gateway/agent.py +++ b/trustgraph-flow/trustgraph/gateway/agent.py @@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint from . requestor import ServiceRequestor class AgentRequestor(ServiceRequestor): - def __init__(self, pulsar_host, timeout, auth): + def __init__(self, pulsar_client, timeout, auth): super(AgentRequestor, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=agent_request_queue, response_queue=agent_response_queue, request_schema=AgentRequest, diff --git a/trustgraph-flow/trustgraph/gateway/dbpedia.py b/trustgraph-flow/trustgraph/gateway/dbpedia.py index 8ae4f695..4c8f9346 100644 --- a/trustgraph-flow/trustgraph/gateway/dbpedia.py +++ b/trustgraph-flow/trustgraph/gateway/dbpedia.py @@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint from . requestor import ServiceRequestor class DbpediaRequestor(ServiceRequestor): - def __init__(self, pulsar_host, timeout, auth): + def __init__(self, pulsar_client, timeout, auth): super(DbpediaRequestor, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=dbpedia_lookup_request_queue, response_queue=dbpedia_lookup_response_queue, request_schema=LookupRequest, diff --git a/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py b/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py new file mode 100644 index 00000000..6b4b4838 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py @@ -0,0 +1,64 @@ + +import asyncio +from pulsar.schema import JsonSchema +import uuid +from aiohttp import WSMsgType + +from .. schema import Metadata +from .. schema import DocumentEmbeddings, ChunkEmbeddings +from .. schema import document_embeddings_store_queue +from .. base import Publisher + +from . socket import SocketEndpoint +from . serialize import to_subgraph + +class DocumentEmbeddingsLoadEndpoint(SocketEndpoint): + + def __init__( + self, pulsar_client, auth, path="/api/v1/load/document-embeddings", + ): + + super(DocumentEmbeddingsLoadEndpoint, self).__init__( + endpoint_path=path, auth=auth, + ) + + self.pulsar_client=pulsar_client + + self.publisher = Publisher( + self.pulsar_client, document_embeddings_store_queue, + schema=JsonSchema(DocumentEmbeddings) + ) + + async def start(self): + + self.publisher.start() + + async def listener(self, ws, running): + + async for msg in ws: + # On error, finish + if msg.type == WSMsgType.ERROR: + break + else: + + data = msg.json() + + elt = DocumentEmbeddings( + metadata=Metadata( + id=data["metadata"]["id"], + metadata=to_subgraph(data["metadata"]["metadata"]), + user=data["metadata"]["user"], + collection=data["metadata"]["collection"], + ), + chunks=[ + ChunkEmbeddings( + chunk=de["chunk"].encode("utf-8"), + vectors=de["vectors"], + ) + for de in data["chunks"] + ], + ) + + self.publisher.send(None, elt) + + running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py b/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py new file mode 100644 index 00000000..6d7db576 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py @@ -0,0 +1,73 @@ + +import asyncio +import queue +from pulsar.schema import JsonSchema +import uuid + +from .. schema import DocumentEmbeddings +from .. schema import document_embeddings_store_queue +from .. base import Subscriber + +from . socket import SocketEndpoint +from . serialize import serialize_document_embeddings + +class DocumentEmbeddingsStreamEndpoint(SocketEndpoint): + + def __init__( + self, pulsar_client, auth, + path="/api/v1/stream/document-embeddings" + ): + + super(DocumentEmbeddingsStreamEndpoint, self).__init__( + endpoint_path=path, auth=auth, + ) + + self.pulsar_client=pulsar_client + + self.subscriber = Subscriber( + self.pulsar_client, document_embeddings_store_queue, + "api-gateway", "api-gateway", + schema=JsonSchema(DocumentEmbeddings), + ) + + async def listener(self, ws, running): + + worker = asyncio.create_task( + self.async_thread(ws, running) + ) + + await super(DocumentEmbeddingsStreamEndpoint, self).listener( + ws, running + ) + + await worker + + async def start(self): + + self.subscriber.start() + + async def async_thread(self, ws, running): + + id = str(uuid.uuid4()) + + q = self.subscriber.subscribe_all(id) + + while running.get(): + try: + resp = await asyncio.to_thread(q.get, timeout=0.5) + await ws.send_json(serialize_document_embeddings(resp)) + + except TimeoutError: + continue + + except queue.Empty: + continue + + except Exception as e: + print(f"Exception: {str(e)}", flush=True) + break + + self.subscriber.unsubscribe_all(id) + + running.stop() + diff --git a/trustgraph-flow/trustgraph/gateway/document_load.py b/trustgraph-flow/trustgraph/gateway/document_load.py index 0fd9a0df..78cd7930 100644 --- a/trustgraph-flow/trustgraph/gateway/document_load.py +++ b/trustgraph-flow/trustgraph/gateway/document_load.py @@ -1,42 +1,41 @@ import base64 -from .. schema import Document +from .. schema import Document, Metadata from .. schema import document_ingest_queue from . sender import ServiceSender from . serialize import to_subgraph class DocumentLoadSender(ServiceSender): - def __init__(self, pulsar_host): + def __init__(self, pulsar_client): super(DocumentLoadSender, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=document_ingest_queue, request_schema=Document, ) def to_request(self, body): - if "metadata" in data: - metadata = to_subgraph(data["metadata"]) + if "metadata" in body: + metadata = to_subgraph(body["metadata"]) else: metadata = [] # Doing a base64 decoe/encode here to make sure the # content is valid base64 - doc = base64.b64decode(data["data"]) + doc = base64.b64decode(body["data"]) print("Document received") return Document( metadata=Metadata( - id=data.get("id"), + id=body.get("id"), metadata=metadata, - user=data.get("user", "trustgraph"), - collection=data.get("collection", "default"), + user=body.get("user", "trustgraph"), + collection=body.get("collection", "default"), ), data=base64.b64encode(doc).decode("utf-8") ) - diff --git a/trustgraph-flow/trustgraph/gateway/document_rag.py b/trustgraph-flow/trustgraph/gateway/document_rag.py new file mode 100644 index 00000000..94d8f788 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/document_rag.py @@ -0,0 +1,31 @@ + +from .. schema import DocumentRagQuery, DocumentRagResponse +from .. schema import document_rag_request_queue +from .. schema import document_rag_response_queue + +from . endpoint import ServiceEndpoint +from . requestor import ServiceRequestor + +class DocumentRagRequestor(ServiceRequestor): + def __init__(self, pulsar_client, timeout, auth): + + super(DocumentRagRequestor, self).__init__( + pulsar_client=pulsar_client, + request_queue=document_rag_request_queue, + response_queue=document_rag_response_queue, + request_schema=DocumentRagQuery, + response_schema=DocumentRagResponse, + timeout=timeout, + ) + + def to_request(self, body): + return DocumentRagQuery( + query=body["query"], + user=body.get("user", "trustgraph"), + collection=body.get("collection", "default"), + doc_limit=int(body.get("doc-limit", 20)), + ) + + def from_response(self, message): + return { "response": message.response }, True + diff --git a/trustgraph-flow/trustgraph/gateway/embeddings.py b/trustgraph-flow/trustgraph/gateway/embeddings.py index 1efafa76..42ed91a1 100644 --- a/trustgraph-flow/trustgraph/gateway/embeddings.py +++ b/trustgraph-flow/trustgraph/gateway/embeddings.py @@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint from . requestor import ServiceRequestor class EmbeddingsRequestor(ServiceRequestor): - def __init__(self, pulsar_host, timeout, auth): + def __init__(self, pulsar_client, timeout, auth): super(EmbeddingsRequestor, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=embeddings_request_queue, response_queue=embeddings_response_queue, request_schema=EmbeddingsRequest, diff --git a/trustgraph-flow/trustgraph/gateway/encyclopedia.py b/trustgraph-flow/trustgraph/gateway/encyclopedia.py index 3f4dad79..49c1dfcd 100644 --- a/trustgraph-flow/trustgraph/gateway/encyclopedia.py +++ b/trustgraph-flow/trustgraph/gateway/encyclopedia.py @@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint from . requestor import ServiceRequestor class EncyclopediaRequestor(ServiceRequestor): - def __init__(self, pulsar_host, timeout, auth): + def __init__(self, pulsar_client, timeout, auth): super(EncyclopediaRequestor, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=encyclopedia_lookup_request_queue, response_queue=encyclopedia_lookup_response_queue, request_schema=LookupRequest, diff --git a/trustgraph-flow/trustgraph/gateway/endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint.py index 1f38c489..5005463c 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint.py @@ -5,8 +5,8 @@ from aiohttp import web import uuid import logging -from . publisher import Publisher -from . subscriber import Subscriber +from .. base import Publisher +from .. base import Subscriber logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) diff --git a/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py b/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py index 18a2e6fe..c1354ce5 100644 --- a/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py +++ b/trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py @@ -5,27 +5,27 @@ import uuid from aiohttp import WSMsgType from .. schema import Metadata -from .. schema import GraphEmbeddings +from .. schema import GraphEmbeddings, EntityEmbeddings from .. schema import graph_embeddings_store_queue +from .. base import Publisher -from . publisher import Publisher from . socket import SocketEndpoint from . serialize import to_subgraph, to_value class GraphEmbeddingsLoadEndpoint(SocketEndpoint): def __init__( - self, pulsar_host, auth, path="/api/v1/load/graph-embeddings", + self, pulsar_client, auth, path="/api/v1/load/graph-embeddings", ): super(GraphEmbeddingsLoadEndpoint, self).__init__( endpoint_path=path, auth=auth, ) - self.pulsar_host=pulsar_host + self.pulsar_client=pulsar_client self.publisher = Publisher( - self.pulsar_host, graph_embeddings_store_queue, + self.pulsar_client, graph_embeddings_store_queue, schema=JsonSchema(GraphEmbeddings) ) @@ -36,6 +36,7 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint): async def listener(self, ws, running): async for msg in ws: + # On error, finish if msg.type == WSMsgType.ERROR: break @@ -50,11 +51,15 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint): user=data["metadata"]["user"], collection=data["metadata"]["collection"], ), - entity=to_value(data["entity"]), - vectors=data["vectors"], + entities=[ + EntityEmbeddings( + entity=to_value(ent["entity"]), + vectors=ent["vectors"], + ) + for ent in data["entities"] + ] ) self.publisher.send(None, elt) - running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/graph_embeddings_query.py b/trustgraph-flow/trustgraph/gateway/graph_embeddings_query.py index 5e3c0ce9..8df38e97 100644 --- a/trustgraph-flow/trustgraph/gateway/graph_embeddings_query.py +++ b/trustgraph-flow/trustgraph/gateway/graph_embeddings_query.py @@ -8,10 +8,10 @@ from . requestor import ServiceRequestor from . serialize import serialize_value class GraphEmbeddingsQueryRequestor(ServiceRequestor): - def __init__(self, pulsar_host, timeout, auth): + def __init__(self, pulsar_client, timeout, auth): super(GraphEmbeddingsQueryRequestor, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=graph_embeddings_request_queue, response_queue=graph_embeddings_response_queue, request_schema=GraphEmbeddingsRequest, diff --git a/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py b/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py index f0b4dd86..385eb9f4 100644 --- a/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py +++ b/trustgraph-flow/trustgraph/gateway/graph_embeddings_stream.py @@ -6,29 +6,39 @@ import uuid from .. schema import GraphEmbeddings from .. schema import graph_embeddings_store_queue +from .. base import Subscriber -from . subscriber import Subscriber from . socket import SocketEndpoint from . serialize import serialize_graph_embeddings class GraphEmbeddingsStreamEndpoint(SocketEndpoint): def __init__( - self, pulsar_host, auth, path="/api/v1/stream/graph-embeddings" + self, pulsar_client, auth, path="/api/v1/stream/graph-embeddings" ): super(GraphEmbeddingsStreamEndpoint, self).__init__( endpoint_path=path, auth=auth, ) - self.pulsar_host=pulsar_host + self.pulsar_client=pulsar_client self.subscriber = Subscriber( - self.pulsar_host, graph_embeddings_store_queue, + self.pulsar_client, graph_embeddings_store_queue, "api-gateway", "api-gateway", schema=JsonSchema(GraphEmbeddings) ) + async def listener(self, ws, running): + + worker = asyncio.create_task( + self.async_thread(ws, running) + ) + + await super(GraphEmbeddingsStreamEndpoint, self).listener(ws, running) + + await worker + async def start(self): self.subscriber.start() @@ -44,6 +54,9 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint): resp = await asyncio.to_thread(q.get, timeout=0.5) await ws.send_json(serialize_graph_embeddings(resp)) + except TimeoutError: + continue + except queue.Empty: continue diff --git a/trustgraph-flow/trustgraph/gateway/graph_rag.py b/trustgraph-flow/trustgraph/gateway/graph_rag.py index 55fd5d2f..b2b69758 100644 --- a/trustgraph-flow/trustgraph/gateway/graph_rag.py +++ b/trustgraph-flow/trustgraph/gateway/graph_rag.py @@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint from . requestor import ServiceRequestor class GraphRagRequestor(ServiceRequestor): - def __init__(self, pulsar_host, timeout, auth): + def __init__(self, pulsar_client, timeout, auth): super(GraphRagRequestor, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=graph_rag_request_queue, response_queue=graph_rag_response_queue, request_schema=GraphRagQuery, @@ -23,6 +23,10 @@ class GraphRagRequestor(ServiceRequestor): query=body["query"], user=body.get("user", "trustgraph"), collection=body.get("collection", "default"), + entity_limit=int(body.get("entity-limit", 50)), + triple_limit=int(body.get("triple-limit", 30)), + max_subgraph_size=int(body.get("max-subgraph-size", 1000)), + max_path_length=int(body.get("max-path-length", 2)), ) def from_response(self, message): diff --git a/trustgraph-flow/trustgraph/gateway/internet_search.py b/trustgraph-flow/trustgraph/gateway/internet_search.py index 127cd5d1..598a75cf 100644 --- a/trustgraph-flow/trustgraph/gateway/internet_search.py +++ b/trustgraph-flow/trustgraph/gateway/internet_search.py @@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint from . requestor import ServiceRequestor class InternetSearchRequestor(ServiceRequestor): - def __init__(self, pulsar_host, timeout, auth): + def __init__(self, pulsar_client, timeout, auth): super(InternetSearchRequestor, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=internet_search_request_queue, response_queue=internet_search_response_queue, request_schema=LookupRequest, diff --git a/trustgraph-flow/trustgraph/gateway/librarian.py b/trustgraph-flow/trustgraph/gateway/librarian.py new file mode 100644 index 00000000..e6ff7ce3 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/librarian.py @@ -0,0 +1,62 @@ + +from .. schema import LibrarianRequest, LibrarianResponse, Triples +from .. schema import librarian_request_queue +from .. schema import librarian_response_queue + +from . endpoint import ServiceEndpoint +from . requestor import ServiceRequestor +from . serialize import serialize_document_package, serialize_document_info +from . serialize import to_document_package, to_document_info, to_criteria + +class LibrarianRequestor(ServiceRequestor): + def __init__(self, pulsar_client, timeout, auth): + + super(LibrarianRequestor, self).__init__( + pulsar_client=pulsar_client, + request_queue=librarian_request_queue, + response_queue=librarian_response_queue, + request_schema=LibrarianRequest, + response_schema=LibrarianResponse, + timeout=timeout, + ) + + def to_request(self, body): + + print("TRR") + if "document" in body: + dp = to_document_package(body["document"]) + else: + dp = None + + print("GOT") + if "criteria" in body: + criteria = to_criteria(body["criteria"]) + else: + criteria = None + + print("ASLDKJ") + + return LibrarianRequest( + operation = body.get("operation", None), + id = body.get("id", None), + document = dp, + user = body.get("user", None), + collection = body.get("collection", None), + criteria = criteria, + ) + + def from_response(self, message): + + response = {} + + if message.document: + response["document"] = serialize_document_package(message.document) + + if message.info: + response["info"] = [ + serialize_document_info(v) + for v in message.info + ] + + return response, True + diff --git a/trustgraph-flow/trustgraph/gateway/metrics.py b/trustgraph-flow/trustgraph/gateway/metrics.py new file mode 100644 index 00000000..33c1fe3a --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/metrics.py @@ -0,0 +1,73 @@ + +# +# This provides a Prometheus endpoint on the api-gateway. It proxies +# HTTP GET requests to Prometheus. +# + +import aiohttp +from aiohttp import web +import asyncio +from pulsar.schema import JsonSchema +import uuid +import logging + +logger = logging.getLogger("endpoint") +logger.setLevel(logging.INFO) + +class MetricsEndpoint: + + def __init__(self, prometheus_url, endpoint_path, auth): + + self.prometheus_url = prometheus_url + self.path = endpoint_path + self.auth = auth + self.operation = "service" + + async def start(self): + pass + + def add_routes(self, app): + + app.add_routes([ + web.get(self.path + "/{path:.*}", self.handle), + ]) + + async def handle(self, request): + + print(request.path, "...") + + try: + ht = request.headers["Authorization"] + tokens = ht.split(" ", 2) + if tokens[0] != "Bearer": + return web.HTTPUnauthorized() + token = tokens[1] + except: + token = "" + + if not self.auth.permitted(token, self.operation): + return web.HTTPUnauthorized() + + try: + + path = request.match_info["path"] + + async with aiohttp.ClientSession() as session: + + url = ( + self.prometheus_url + "/api/v1/" + path + "?" + + request.query_string + ) + + async with session.get(url) as resp: + return web.Response( + status=resp.status, + text=await resp.text() + ) + + except Exception as e: + + logging.error(f"Exception: {e}") + + raise web.HTTPInternalServerError() + diff --git a/trustgraph-flow/trustgraph/gateway/mux.py b/trustgraph-flow/trustgraph/gateway/mux.py index ae699ae6..23b693ab 100644 --- a/trustgraph-flow/trustgraph/gateway/mux.py +++ b/trustgraph-flow/trustgraph/gateway/mux.py @@ -18,7 +18,7 @@ MAX_QUEUE_SIZE = 10 class MuxEndpoint(SocketEndpoint): def __init__( - self, pulsar_host, auth, + self, pulsar_client, auth, services, path="/api/v1/socket", ): diff --git a/trustgraph-flow/trustgraph/gateway/prompt.py b/trustgraph-flow/trustgraph/gateway/prompt.py index 080d5618..eb50ac73 100644 --- a/trustgraph-flow/trustgraph/gateway/prompt.py +++ b/trustgraph-flow/trustgraph/gateway/prompt.py @@ -9,10 +9,10 @@ from . endpoint import ServiceEndpoint from . requestor import ServiceRequestor class PromptRequestor(ServiceRequestor): - def __init__(self, pulsar_host, timeout, auth): + def __init__(self, pulsar_client, timeout, auth): super(PromptRequestor, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=prompt_request_queue, response_queue=prompt_response_queue, request_schema=PromptRequest, diff --git a/trustgraph-flow/trustgraph/gateway/requestor.py b/trustgraph-flow/trustgraph/gateway/requestor.py index 5f6e2692..dc74667d 100644 --- a/trustgraph-flow/trustgraph/gateway/requestor.py +++ b/trustgraph-flow/trustgraph/gateway/requestor.py @@ -4,8 +4,8 @@ from pulsar.schema import JsonSchema import uuid import logging -from . publisher import Publisher -from . subscriber import Subscriber +from .. base import Publisher +from .. base import Subscriber logger = logging.getLogger("requestor") logger.setLevel(logging.INFO) @@ -14,7 +14,7 @@ class ServiceRequestor: def __init__( self, - pulsar_host, + pulsar_client, request_queue, request_schema, response_queue, response_schema, subscription="api-gateway", consumer_name="api-gateway", @@ -22,12 +22,12 @@ class ServiceRequestor: ): self.pub = Publisher( - pulsar_host, request_queue, - schema=JsonSchema(request_schema) + pulsar_client, request_queue, + schema=JsonSchema(request_schema), ) self.sub = Subscriber( - pulsar_host, response_queue, + pulsar_client, response_queue, subscription, consumer_name, JsonSchema(response_schema) ) @@ -60,12 +60,21 @@ class ServiceRequestor: while True: try: - resp = await asyncio.to_thread(q.get, timeout=self.timeout) + resp = await asyncio.to_thread( + q.get, + timeout=self.timeout + ) except Exception as e: raise RuntimeError("Timeout") if resp.error: - return { "error": resp.error.message } + err = { "error": { + "type": resp.error.type, + "message": resp.error.message, + } } + if responder: + await responder(err, True) + return err resp, fin = self.from_response(resp) @@ -81,7 +90,13 @@ class ServiceRequestor: logging.error(f"Exception: {e}") - return { "error": str(e) } + err = { "error": { + "type": "gateway-error", + "message": str(e), + } } + if responder: + await responder(err, True) + return err finally: self.sub.unsubscribe(id) diff --git a/trustgraph-flow/trustgraph/gateway/sender.py b/trustgraph-flow/trustgraph/gateway/sender.py index 93f1164c..32c586b1 100644 --- a/trustgraph-flow/trustgraph/gateway/sender.py +++ b/trustgraph-flow/trustgraph/gateway/sender.py @@ -6,7 +6,7 @@ from pulsar.schema import JsonSchema import uuid import logging -from . publisher import Publisher +from .. base import Publisher logger = logging.getLogger("sender") logger.setLevel(logging.INFO) @@ -15,13 +15,13 @@ class ServiceSender: def __init__( self, - pulsar_host, + pulsar_client, request_queue, request_schema, ): self.pub = Publisher( - pulsar_host, request_queue, - schema=JsonSchema(request_schema) + pulsar_client, request_queue, + schema=JsonSchema(request_schema), ) async def start(self): @@ -46,5 +46,10 @@ class ServiceSender: logging.error(f"Exception: {e}") - return { "error": str(e) } + err = { "error": str(e) } + + if responder: + await responder(err, True) + + return err diff --git a/trustgraph-flow/trustgraph/gateway/serialize.py b/trustgraph-flow/trustgraph/gateway/serialize.py index 35932382..5cc90a78 100644 --- a/trustgraph-flow/trustgraph/gateway/serialize.py +++ b/trustgraph-flow/trustgraph/gateway/serialize.py @@ -1,4 +1,7 @@ -from .. schema import Value, Triple + +import base64 + +from .. schema import Value, Triple, DocumentPackage, DocumentInfo def to_value(x): return Value(value=x["v"], is_uri=x["e"]) @@ -51,7 +54,118 @@ def serialize_graph_embeddings(message): "user": message.metadata.user, "collection": message.metadata.collection, }, - "vectors": message.vectors, - "entity": serialize_value(message.entity), + "entities": [ + { + "vectors": entity.vectors, + "entity": serialize_value(entity.entity), + } + for entity in message.entities + ], } +def serialize_document_embeddings(message): + return { + "metadata": { + "id": message.metadata.id, + "metadata": serialize_subgraph(message.metadata.metadata), + "user": message.metadata.user, + "collection": message.metadata.collection, + }, + "chunks": [ + { + "vectors": chunk.vectors, + "chunk": chunk.chunk.decode("utf-8"), + } + for chunk in message.chunks + ], + } + +def serialize_document_package(message): + + ret = {} + + if message.id: + ret["id"] = message.id + + if message.metadata: + ret["metadata"] = serialize_subgraph(message.metdata) + + if message.document: + blob = base64.b64encode( + message.document.encode("utf-8") + ).decode("utf-8") + ret["document"] = blob + + if message.kind: + ret["kind"] = message.kind + + if message.user: + ret["user"] = message.user + + if message.collection: + ret["collection"] = message.collection + + return ret + +def serialize_document_info(message): + + ret = {} + + if message.id: + ret["id"] = message.id + + if message.kind: + ret["kind"] = message.kind + + if message.user: + ret["user"] = message.user + + if message.collection: + ret["collection"] = message.collection + + if message.title: + ret["title"] = message.title + + if message.comments: + ret["comments"] = message.comments + + if message.time: + ret["time"] = message.time + + if message.metadata: + ret["metadata"] = serialize_subgraph(message.metadata) + + return ret + +def to_document_package(x): + + return DocumentPackage( + id = x.get("id", None), + kind = x.get("kind", None), + user = x.get("user", None), + collection = x.get("collection", None), + title = x.get("title", None), + comments = x.get("comments", None), + time = x.get("time", None), + document = x.get("document", None), + metadata = to_subgraph(x["metadata"]), + ) + +def to_document_info(x): + + return DocumentInfo( + id = x.get("id", None), + kind = x.get("kind", None), + user = x.get("user", None), + collection = x.get("collection", None), + title = x.get("title", None), + comments = x.get("comments", None), + time = x.get("time", None), + metadata = to_subgraph(x["metadata"]), + ) + +def to_criteria(x): + return [ + Critera(v["key"], v["value"], v["operator"]) + for v in x + ] diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index a260b631..d3122e3b 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -26,11 +26,10 @@ from .. log_level import LogLevel from . serialize import to_subgraph from . running import Running -from . publisher import Publisher -from . subscriber import Subscriber from . text_completion import TextCompletionRequestor from . prompt import PromptRequestor from . graph_rag import GraphRagRequestor +from . document_rag import DocumentRagRequestor from . triples_query import TriplesQueryRequestor from . graph_embeddings_query import GraphEmbeddingsQueryRequestor from . embeddings import EmbeddingsRequestor @@ -38,13 +37,17 @@ from . encyclopedia import EncyclopediaRequestor from . agent import AgentRequestor from . dbpedia import DbpediaRequestor from . internet_search import InternetSearchRequestor +from . librarian import LibrarianRequestor from . triples_stream import TriplesStreamEndpoint from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint +from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint from . triples_load import TriplesLoadEndpoint from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint +from . document_embeddings_load import DocumentEmbeddingsLoadEndpoint from . mux import MuxEndpoint from . document_load import DocumentLoadSender from . text_load import TextLoadSender +from . metrics import MetricsEndpoint from . endpoint import ServiceEndpoint from . auth import Authenticator @@ -53,6 +56,8 @@ logger = logging.getLogger("api") logger.setLevel(logging.INFO) default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650") +default_prometheus_url = os.getenv("PROMETHEUS_URL", "http://prometheus:9090") +default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) default_timeout = 600 default_port = 8088 default_api_token = os.getenv("GATEWAY_SECRET", "") @@ -69,6 +74,27 @@ class Api: self.port = int(config.get("port", default_port)) self.timeout = int(config.get("timeout", default_timeout)) self.pulsar_host = config.get("pulsar_host", default_pulsar_host) + self.pulsar_api_key = config.get( + "pulsar_api_key", default_pulsar_api_key + ) + self.pulsar_listener = config.get("pulsar_listener", None) + + if self.pulsar_api_key: + self.pulsar_client = pulsar.Client( + self.pulsar_host, listener_name=self.pulsar_listener, + authentication=pulsar.AuthenticationToken(self.pulsar_api_key) + ) + else: + self.pulsar_client = pulsar.Client( + self.pulsar_host, listener_name=self.pulsar_listener, + ) + + self.prometheus_url = config.get( + "prometheus_url", default_prometheus_url, + ) + + if not self.prometheus_url.endswith("/"): + self.prometheus_url += "/" api_token = config.get("api_token", default_api_token) @@ -80,50 +106,58 @@ class Api: self.services = { "text-completion": TextCompletionRequestor( - pulsar_host=self.pulsar_host, timeout=self.timeout, + pulsar_client=self.pulsar_client, timeout=self.timeout, auth = self.auth, ), "prompt": PromptRequestor( - pulsar_host=self.pulsar_host, timeout=self.timeout, + pulsar_client=self.pulsar_client, timeout=self.timeout, auth = self.auth, ), "graph-rag": GraphRagRequestor( - pulsar_host=self.pulsar_host, timeout=self.timeout, + pulsar_client=self.pulsar_client, timeout=self.timeout, + auth = self.auth, + ), + "document-rag": DocumentRagRequestor( + pulsar_client=self.pulsar_client, timeout=self.timeout, auth = self.auth, ), "triples-query": TriplesQueryRequestor( - pulsar_host=self.pulsar_host, timeout=self.timeout, + pulsar_client=self.pulsar_client, timeout=self.timeout, auth = self.auth, ), "graph-embeddings-query": GraphEmbeddingsQueryRequestor( - pulsar_host=self.pulsar_host, timeout=self.timeout, + pulsar_client=self.pulsar_client, timeout=self.timeout, auth = self.auth, ), "embeddings": EmbeddingsRequestor( - pulsar_host=self.pulsar_host, timeout=self.timeout, + pulsar_client=self.pulsar_client, timeout=self.timeout, auth = self.auth, ), "agent": AgentRequestor( - pulsar_host=self.pulsar_host, timeout=self.timeout, + pulsar_client=self.pulsar_client, timeout=self.timeout, + auth = self.auth, + ), + "librarian": LibrarianRequestor( + pulsar_client=self.pulsar_client, timeout=self.timeout, auth = self.auth, ), "encyclopedia": EncyclopediaRequestor( - pulsar_host=self.pulsar_host, timeout=self.timeout, + pulsar_client=self.pulsar_client, timeout=self.timeout, auth = self.auth, ), "dbpedia": DbpediaRequestor( - pulsar_host=self.pulsar_host, timeout=self.timeout, + pulsar_client=self.pulsar_client, timeout=self.timeout, auth = self.auth, ), "internet-search": InternetSearchRequestor( - pulsar_host=self.pulsar_host, timeout=self.timeout, + pulsar_client=self.pulsar_client, timeout=self.timeout, auth = self.auth, ), "document-load": DocumentLoadSender( - pulsar_host=self.pulsar_host, + pulsar_client=self.pulsar_client, ), "text-load": TextLoadSender( - pulsar_host=self.pulsar_host, + pulsar_client=self.pulsar_client, ), } @@ -140,6 +174,10 @@ class Api: endpoint_path = "/api/v1/graph-rag", auth=self.auth, requestor = self.services["graph-rag"], ), + ServiceEndpoint( + endpoint_path = "/api/v1/document-rag", auth=self.auth, + requestor = self.services["document-rag"], + ), ServiceEndpoint( endpoint_path = "/api/v1/triples-query", auth=self.auth, requestor = self.services["triples-query"], @@ -157,6 +195,10 @@ class Api: endpoint_path = "/api/v1/agent", auth=self.auth, requestor = self.services["agent"], ), + ServiceEndpoint( + endpoint_path = "/api/v1/librarian", auth=self.auth, + requestor = self.services["librarian"], + ), ServiceEndpoint( endpoint_path = "/api/v1/encyclopedia", auth=self.auth, requestor = self.services["encyclopedia"], @@ -178,26 +220,39 @@ class Api: requestor = self.services["text-load"], ), TriplesStreamEndpoint( - pulsar_host=self.pulsar_host, + pulsar_client=self.pulsar_client, auth = self.auth, ), GraphEmbeddingsStreamEndpoint( - pulsar_host=self.pulsar_host, + pulsar_client=self.pulsar_client, + auth = self.auth, + ), + DocumentEmbeddingsStreamEndpoint( + pulsar_client=self.pulsar_client, auth = self.auth, ), TriplesLoadEndpoint( - pulsar_host=self.pulsar_host, + pulsar_client=self.pulsar_client, auth = self.auth, ), GraphEmbeddingsLoadEndpoint( - pulsar_host=self.pulsar_host, + pulsar_client=self.pulsar_client, + auth = self.auth, + ), + DocumentEmbeddingsLoadEndpoint( + pulsar_client=self.pulsar_client, auth = self.auth, ), MuxEndpoint( - pulsar_host=self.pulsar_host, + pulsar_client=self.pulsar_client, auth = self.auth, services = self.services, ), + MetricsEndpoint( + endpoint_path = "/api/v1/metrics", + prometheus_url = self.prometheus_url, + auth = self.auth, + ), ] for ep in self.endpoints: @@ -225,6 +280,23 @@ def run(): default=default_pulsar_host, help=f'Pulsar host (default: {default_pulsar_host})', ) + + parser.add_argument( + '--pulsar-api-key', + default=default_pulsar_api_key, + help=f'Pulsar API key', + ) + + parser.add_argument( + '--pulsar-listener', + help=f'Pulsar listener (default: none)', + ) + + parser.add_argument( + '-m', '--prometheus-url', + default=default_prometheus_url, + help=f'Prometheus URL (default: {default_prometheus_url})', + ) parser.add_argument( '--port', diff --git a/trustgraph-flow/trustgraph/gateway/socket.py b/trustgraph-flow/trustgraph/gateway/socket.py index fd408d7b..c32a28af 100644 --- a/trustgraph-flow/trustgraph/gateway/socket.py +++ b/trustgraph-flow/trustgraph/gateway/socket.py @@ -19,7 +19,7 @@ class SocketEndpoint: self.operation = "socket" async def listener(self, ws, running): - + async for msg in ws: # On error, finish if msg.type == WSMsgType.TEXT: @@ -44,13 +44,16 @@ class SocketEndpoint: return web.HTTPUnauthorized() running = Running() - ws = web.WebSocketResponse() + + # 50MB max message size + ws = web.WebSocketResponse(max_msg_size=52428800) + await ws.prepare(request) try: await self.listener(ws, running) except Exception as e: - print(e, flush=True) + print("Socket exception:", e, flush=True) running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/text_completion.py b/trustgraph-flow/trustgraph/gateway/text_completion.py index 7291fc88..ec84e5d6 100644 --- a/trustgraph-flow/trustgraph/gateway/text_completion.py +++ b/trustgraph-flow/trustgraph/gateway/text_completion.py @@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint from . requestor import ServiceRequestor class TextCompletionRequestor(ServiceRequestor): - def __init__(self, pulsar_host, timeout, auth): + def __init__(self, pulsar_client, timeout, auth): super(TextCompletionRequestor, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=text_completion_request_queue, response_queue=text_completion_response_queue, request_schema=TextCompletionRequest, diff --git a/trustgraph-flow/trustgraph/gateway/text_load.py b/trustgraph-flow/trustgraph/gateway/text_load.py index ade6b1c1..cc432698 100644 --- a/trustgraph-flow/trustgraph/gateway/text_load.py +++ b/trustgraph-flow/trustgraph/gateway/text_load.py @@ -8,10 +8,10 @@ from . sender import ServiceSender from . serialize import to_subgraph class TextLoadSender(ServiceSender): - def __init__(self, pulsar_host): + def __init__(self, pulsar_client): super(TextLoadSender, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=text_ingest_queue, request_schema=TextDocument, ) @@ -36,7 +36,7 @@ class TextLoadSender(ServiceSender): return TextDocument( metadata=Metadata( id=body.get("id"), - metabody=metadata, + metadata=metadata, user=body.get("user", "trustgraph"), collection=body.get("collection", "default"), ), diff --git a/trustgraph-flow/trustgraph/gateway/triples_load.py b/trustgraph-flow/trustgraph/gateway/triples_load.py index 2689f3ad..bc69975e 100644 --- a/trustgraph-flow/trustgraph/gateway/triples_load.py +++ b/trustgraph-flow/trustgraph/gateway/triples_load.py @@ -7,23 +7,23 @@ from aiohttp import WSMsgType from .. schema import Metadata from .. schema import Triples from .. schema import triples_store_queue +from .. base import Publisher -from . publisher import Publisher from . socket import SocketEndpoint from . serialize import to_subgraph class TriplesLoadEndpoint(SocketEndpoint): - def __init__(self, pulsar_host, auth, path="/api/v1/load/triples"): + def __init__(self, pulsar_client, auth, path="/api/v1/load/triples"): super(TriplesLoadEndpoint, self).__init__( endpoint_path=path, auth=auth, ) - self.pulsar_host=pulsar_host + self.pulsar_client=pulsar_client self.publisher = Publisher( - self.pulsar_host, triples_store_queue, + self.pulsar_client, triples_store_queue, schema=JsonSchema(Triples) ) diff --git a/trustgraph-flow/trustgraph/gateway/triples_query.py b/trustgraph-flow/trustgraph/gateway/triples_query.py index 0ea7cd8d..061bd4d8 100644 --- a/trustgraph-flow/trustgraph/gateway/triples_query.py +++ b/trustgraph-flow/trustgraph/gateway/triples_query.py @@ -8,10 +8,10 @@ from . requestor import ServiceRequestor from . serialize import to_value, serialize_subgraph class TriplesQueryRequestor(ServiceRequestor): - def __init__(self, pulsar_host, timeout, auth): + def __init__(self, pulsar_client, timeout, auth): super(TriplesQueryRequestor, self).__init__( - pulsar_host=pulsar_host, + pulsar_client=pulsar_client, request_queue=triples_request_queue, response_queue=triples_response_queue, request_schema=TriplesQueryRequest, diff --git a/trustgraph-flow/trustgraph/gateway/triples_stream.py b/trustgraph-flow/trustgraph/gateway/triples_stream.py index 92ada132..a5d5ad0a 100644 --- a/trustgraph-flow/trustgraph/gateway/triples_stream.py +++ b/trustgraph-flow/trustgraph/gateway/triples_stream.py @@ -6,27 +6,37 @@ import uuid from .. schema import Triples from .. schema import triples_store_queue +from .. base import Subscriber -from . subscriber import Subscriber from . socket import SocketEndpoint from . serialize import serialize_triples class TriplesStreamEndpoint(SocketEndpoint): - def __init__(self, pulsar_host, auth, path="/api/v1/stream/triples"): + def __init__(self, pulsar_client, auth, path="/api/v1/stream/triples"): super(TriplesStreamEndpoint, self).__init__( endpoint_path=path, auth=auth, ) - self.pulsar_host=pulsar_host + self.pulsar_client=pulsar_client self.subscriber = Subscriber( - self.pulsar_host, triples_store_queue, + self.pulsar_client, triples_store_queue, "api-gateway", "api-gateway", schema=JsonSchema(Triples) ) + async def listener(self, ws, running): + + worker = asyncio.create_task( + self.async_thread(ws, running) + ) + + await super(TriplesStreamEndpoint, self).listener(ws, running) + + await worker + async def start(self): self.subscriber.start() @@ -42,6 +52,9 @@ class TriplesStreamEndpoint(SocketEndpoint): resp = await asyncio.to_thread(q.get, timeout=0.5) await ws.send_json(serialize_triples(resp)) + except TimeoutError: + continue + except queue.Empty: continue diff --git a/trustgraph-flow/trustgraph/graph_rag.py b/trustgraph-flow/trustgraph/graph_rag.py index f69ebeb7..6a4e11c5 100644 --- a/trustgraph-flow/trustgraph/graph_rag.py +++ b/trustgraph-flow/trustgraph/graph_rag.py @@ -20,11 +20,19 @@ DEFINITION="http://www.w3.org/2004/02/skos/core#definition" class Query: - def __init__(self, rag, user, collection, verbose): + def __init__( + self, rag, user, collection, verbose, + entity_limit=50, triple_limit=30, max_subgraph_size=1000, + max_path_length=2, + ): self.rag = rag self.user = user self.collection = collection self.verbose = verbose + self.entity_limit = entity_limit + self.triple_limit = triple_limit + self.max_subgraph_size = max_subgraph_size + self.max_path_length = max_path_length def get_vector(self, query): @@ -47,7 +55,7 @@ class Query: entities = self.rag.ge_client.request( user=self.user, collection=self.collection, - vectors=vectors, limit=self.rag.entity_limit, + vectors=vectors, limit=self.entity_limit, ) entities = [ @@ -79,62 +87,67 @@ class Query: self.rag.label_cache[e] = res[0].o.value return self.rag.label_cache[e] + def follow_edges(self, ent, subgraph, path_length): + + # Not needed? + if path_length <= 0: + return + + # Stop spanning around if the subgraph is already maxed out + if len(subgraph) >= self.max_subgraph_size: + return + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=ent, p=None, o=None, + limit=self.triple_limit + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + if path_length > 1: + self.follow_edges(triple.o.value, subgraph, path_length-1) + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=None, p=ent, o=None, + limit=self.triple_limit + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=None, p=None, o=ent, + limit=self.triple_limit, + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + if path_length > 1: + self.follow_edges(triple.s.value, subgraph, path_length-1) + def get_subgraph(self, query): entities = self.get_entities(query) - subgraph = set() - if self.verbose: print("Get subgraph...", flush=True) - for e in entities: + subgraph = set() - res = self.rag.triples_client.request( - user=self.user, collection=self.collection, - s=e, p=None, o=None, - limit=self.rag.query_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - res = self.rag.triples_client.request( - user=self.user, collection=self.collection, - s=None, p=e, o=None, - limit=self.rag.query_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - res = self.rag.triples_client.request( - user=self.user, collection=self.collection, - s=None, p=None, o=e, - limit=self.rag.query_limit, - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) + for ent in entities: + self.follow_edges(ent, subgraph, self.max_path_length) subgraph = list(subgraph) - subgraph = subgraph[0:self.rag.max_subgraph_size] - - if self.verbose: - print("Subgraph:", flush=True) - for edge in subgraph: - print(" ", str(edge), flush=True) - - if self.verbose: - print("Done.", flush=True) - return subgraph def get_labelgraph(self, query): @@ -154,6 +167,16 @@ class Query: sg2.append((s, p, o)) + sg2 = sg2[0:self.max_subgraph_size] + + if self.verbose: + print("Subgraph:", flush=True) + for edge in sg2: + print(" ", str(edge), flush=True) + + if self.verbose: + print("Done.", flush=True) + return sg2 class GraphRag: @@ -161,6 +184,7 @@ class GraphRag: def __init__( self, pulsar_host="pulsar://pulsar:6650", + pulsar_api_key=None, pr_request_queue=None, pr_response_queue=None, emb_request_queue=None, @@ -170,9 +194,6 @@ class GraphRag: tpl_request_queue=None, tpl_response_queue=None, verbose=False, - entity_limit=50, - triple_limit=30, - max_subgraph_size=3000, module="test", ): @@ -207,6 +228,7 @@ class GraphRag: self.ge_client = GraphEmbeddingsClient( pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, subscriber=module + "-ge", input_queue=ge_request_queue, output_queue=ge_response_queue, @@ -214,6 +236,7 @@ class GraphRag: self.triples_client = TriplesQueryClient( pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, subscriber=module + "-tpl", input_queue=tpl_request_queue, output_queue=tpl_response_queue @@ -221,19 +244,17 @@ class GraphRag: self.embeddings = EmbeddingsClient( pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, input_queue=emb_request_queue, output_queue=emb_response_queue, subscriber=module + "-emb", ) - self.entity_limit=entity_limit - self.query_limit=triple_limit - self.max_subgraph_size=max_subgraph_size - self.label_cache = {} self.prompt = PromptClient( pulsar_host=pulsar_host, + pulsar_api_key=pulsar_api_key, input_queue=pr_request_queue, output_queue=pr_response_queue, subscriber=module + "-prompt", @@ -242,13 +263,20 @@ class GraphRag: if self.verbose: print("Initialised", flush=True) - def query(self, query, user="trustgraph", collection="default"): + def query( + self, query, user="trustgraph", collection="default", + entity_limit=50, triple_limit=30, max_subgraph_size=1000, + max_path_length=2, + ): if self.verbose: print("Construct prompt...", flush=True) q = Query( - rag=self, user=user, collection=collection, verbose=self.verbose + rag=self, user=user, collection=collection, verbose=self.verbose, + entity_limit=entity_limit, triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length, ) kg = q.get_labelgraph(query) diff --git a/trustgraph-flow/trustgraph/librarian/__init__.py b/trustgraph-flow/trustgraph/librarian/__init__.py new file mode 100644 index 00000000..ba844705 --- /dev/null +++ b/trustgraph-flow/trustgraph/librarian/__init__.py @@ -0,0 +1,3 @@ + +from . service import * + diff --git a/trustgraph-flow/trustgraph/librarian/__main__.py b/trustgraph-flow/trustgraph/librarian/__main__.py new file mode 100755 index 00000000..e9136855 --- /dev/null +++ b/trustgraph-flow/trustgraph/librarian/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . service import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/librarian/blob_store.py b/trustgraph-flow/trustgraph/librarian/blob_store.py new file mode 100644 index 00000000..5cffef18 --- /dev/null +++ b/trustgraph-flow/trustgraph/librarian/blob_store.py @@ -0,0 +1,51 @@ +from .. schema import LibrarianRequest, LibrarianResponse, Error +from .. knowledge import hash +from .. exceptions import RequestError + +from minio import Minio +import time +import io + +class BlobStore: + + def __init__( + self, + minio_host, minio_access_key, minio_secret_key, bucket_name, + ): + + + self.minio = Minio( + minio_host, + access_key = minio_access_key, + secret_key = minio_secret_key, + secure = False, + ) + + self.bucket_name = bucket_name + + print("Connected to minio", flush=True) + + self.ensure_bucket() + + def ensure_bucket(self): + + # Make the bucket if it doesn't exist. + found = self.minio.bucket_exists(self.bucket_name) + if not found: + self.minio.make_bucket(self.bucket_name) + print("Created bucket", self.bucket_name, flush=True) + else: + print("Bucket", self.bucket_name, "already exists", flush=True) + + def add(self, object_id, blob, kind): + + # FIXME: Loop retry + self.minio.put_object( + bucket_name = self.bucket_name, + object_name = "doc/" + str(object_id), + length = len(blob), + data = io.BytesIO(blob), + content_type = kind, + ) + + print("Add blob complete", flush=True) diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py new file mode 100644 index 00000000..9bccc37a --- /dev/null +++ b/trustgraph-flow/trustgraph/librarian/librarian.py @@ -0,0 +1,88 @@ +from .. schema import LibrarianRequest, LibrarianResponse, Error, Triple +from .. knowledge import hash +from .. exceptions import RequestError +from . table_store import TableStore +from . blob_store import BlobStore + +import uuid + +class Librarian: + + def __init__( + self, + cassandra_host, cassandra_user, cassandra_password, + minio_host, minio_access_key, minio_secret_key, + bucket_name, keyspace, load_document, load_text, + ): + + self.blob_store = BlobStore( + minio_host, minio_access_key, minio_secret_key, bucket_name + ) + + self.table_store = TableStore( + cassandra_host, cassandra_user, cassandra_password, keyspace + ) + + self.load_document = load_document + self.load_text = load_text + + async def add(self, document): + + if document.kind not in ( + "text/plain", "application/pdf" + ): + raise RequestError("Invalid document kind: " + document.kind) + + # Create object ID as a hash of the document + object_id = uuid.UUID(hash(document.document)) + + self.blob_store.add(object_id, document.document, document.kind) + + self.table_store.add(object_id, document) + + if document.kind == "application/pdf": + await self.load_document(document) + elif document.kind == "text/plain": + await self.load_text(document) + + print("Add complete", flush=True) + + return LibrarianResponse( + error = None, + document = None, + info = None, + ) + + async def list(self, user, collection): + + print("list") + + info = self.table_store.list(user, collection) + + print(">>", info) + + return LibrarianResponse( + error = None, + document = None, + info = info, + ) + + def handle_triples(self, m): + self.table_store.add_triples(m) + + def handle_graph_embeddings(self, m): + self.table_store.add_graph_embeddings(m) + + def handle_document_embeddings(self, m): + self.table_store.add_document_embeddings(m) + + + def handle_triples(self, m): + self.table_store.add_triples(m) + + def handle_graph_embeddings(self, m): + self.table_store.add_graph_embeddings(m) + + def handle_document_embeddings(self, m): + self.table_store.add_document_embeddings(m) + diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py new file mode 100755 index 00000000..b42123a5 --- /dev/null +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -0,0 +1,424 @@ + +""" +Librarian service, manages documents in collections +""" + +from functools import partial +import asyncio +import threading +import queue +import base64 + +from pulsar.schema import JsonSchema + +from .. schema import LibrarianRequest, LibrarianResponse, Error +from .. schema import librarian_request_queue, librarian_response_queue + +from .. schema import GraphEmbeddings +from .. schema import graph_embeddings_store_queue +from .. schema import Triples +from .. schema import triples_store_queue +from .. schema import DocumentEmbeddings +from .. schema import document_embeddings_store_queue + +from .. schema import Document, Metadata +from .. schema import document_ingest_queue +from .. schema import TextDocument, Metadata +from .. schema import text_ingest_queue + +from .. base import Publisher +from .. base import Subscriber + +from .. log_level import LogLevel +from .. base import ConsumerProducer +from .. exceptions import RequestError + +from . librarian import Librarian + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = librarian_request_queue +default_output_queue = librarian_response_queue +default_subscriber = module +default_minio_host = "minio:9000" +default_minio_access_key = "minioadmin" +default_minio_secret_key = "minioadmin" +default_cassandra_host = "cassandra" + +bucket_name = "library" + +# FIXME: How to ensure this doesn't conflict with other usage? +keyspace = "librarian" + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + self.running = True + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + + minio_host = params.get("minio_host", default_minio_host) + minio_access_key = params.get( + "minio_access_key", + default_minio_access_key + ) + minio_secret_key = params.get( + "minio_secret_key", + default_minio_secret_key + ) + + cassandra_host = params.get("cassandra_host", default_cassandra_host) + cassandra_user = params.get("cassandra_user") + cassandra_password = params.get("cassandra_password") + + triples_queue = params.get("triples_queue") + graph_embeddings_queue = params.get("graph_embeddings_queue") + document_embeddings_queue = params.get("document_embeddings_queue") + document_load_queue = params.get("document_load_queue") + text_load_queue = params.get("text_load_queue") + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": LibrarianRequest, + "output_schema": LibrarianResponse, + "minio_host": minio_host, + "minio_access_key": minio_access_key, + "cassandra_host": cassandra_host, + "cassandra_user": cassandra_user, + } + ) + + self.document_load = Publisher( + self.client, document_load_queue, JsonSchema(Document), + ) + + self.text_load = Publisher( + self.client, text_load_queue, JsonSchema(TextDocument), + ) + + self.triples_brk = Subscriber( + self.client, triples_store_queue, + "librarian", "librarian", + schema=JsonSchema(Triples), + ) + self.graph_embeddings_brk = Subscriber( + self.client, graph_embeddings_store_queue, + "librarian", "librarian", + schema=JsonSchema(GraphEmbeddings), + ) + self.document_embeddings_brk = Subscriber( + self.client, document_embeddings_store_queue, + "librarian", "librarian", + schema=JsonSchema(DocumentEmbeddings), + ) + + self.triples_reader = threading.Thread( + target=self.receive_triples + ) + self.graph_embeddings_reader = threading.Thread( + target=self.receive_graph_embeddings + ) + self.document_embeddings_reader = threading.Thread( + target=self.receive_document_embeddings + ) + + self.librarian = Librarian( + cassandra_host = cassandra_host.split(","), + cassandra_user = cassandra_user, + cassandra_password = cassandra_password, + minio_host = minio_host, + minio_access_key = minio_access_key, + minio_secret_key = minio_secret_key, + bucket_name = bucket_name, + keyspace = keyspace, + load_document = self.load_document, + load_text = self.load_text, + ) + + print("Initialised.", flush=True) + + async def start(self): + + self.document_load.start() + self.text_load.start() + + self.triples_brk.start() + self.graph_embeddings_brk.start() + self.document_embeddings_brk.start() + + self.triples_sub = self.triples_brk.subscribe_all("x") + self.graph_embeddings_sub = self.graph_embeddings_brk.subscribe_all("x") + self.document_embeddings_sub = self.document_embeddings_brk.subscribe_all("x") + + self.triples_reader.start() + self.graph_embeddings_reader.start() + self.document_embeddings_reader.start() + + def __del__(self): + + self.running = False + + if hasattr(self, "document_load"): + self.document_load.stop() + self.document_load.join() + + if hasattr(self, "text_load"): + self.text_load.stop() + self.text_load.join() + + if hasattr(self, "triples_sub"): + self.triples_sub.unsubscribe_all("x") + + if hasattr(self, "graph_embeddings_sub"): + self.graph_embeddings_sub.unsubscribe_all("x") + + if hasattr(self, "document_embeddings_sub"): + self.document_embeddings_sub.unsubscribe_all("x") + + if hasattr(self, "triples_brk"): + self.triples_brk.stop() + self.triples_brk.join() + + if hasattr(self, "graph_embeddings_brk"): + self.graph_embeddings_brk.stop() + self.graph_embeddings_brk.join() + + if hasattr(self, "document_embeddings_brk"): + self.document_embeddings_brk.stop() + self.document_embeddings_brk.join() + + def receive_triples(self): + + while self.running: + try: + msg = self.triples_sub.get(timeout=1) + except queue.Empty: + continue + + self.librarian.handle_triples(msg) + + def receive_graph_embeddings(self): + + while self.running: + try: + msg = self.graph_embeddings_sub.get(timeout=1) + except queue.Empty: + continue + + self.librarian.handle_graph_embeddings(msg) + + def receive_document_embeddings(self): + + while self.running: + try: + msg = self.document_embeddings_sub.get(timeout=1) + except queue.Empty: + continue + + self.librarian.handle_document_embeddings(msg) + + async def load_document(self, document): + + doc = Document( + metadata = Metadata( + id = document.id, + metadata = document.metadata, + user = document.user, + collection = document.collection + ), + data = document.document + ) + + self.document_load.send(None, doc) + + async def load_text(self, document): + + text = base64.b64decode(document.document) + text = text.decode("utf-8") + + doc = TextDocument( + metadata = Metadata( + id = document.id, + metadata = document.metadata, + user = document.user, + collection = document.collection + ), + text = text, + ) + + self.text_load.send(None, doc) + + def parse_request(self, v): + + if v.operation is None: + raise RequestError("Null operation") + + print("op", v.operation) + + if v.operation == "add": + if ( + v.document and v.document.id and v.document.metadata and + v.document.document and v.document.kind + ): + return partial( + self.librarian.add, + document = v.document, + ) + else: + raise RequestError("Invalid call") + + if v.operation == "list": + print("list", v) + print(v.user) + if v.user: + return partial( + self.librarian.list, + user = v.user, + collection = v.collection, + ) + else: + print("BROK") + raise RequestError("Invalid call") + + raise RequestError("Invalid operation: " + v.operation) + + async def handle(self, msg): + + v = msg.value() + + # Sender-produced ID + + id = msg.properties()["id"] + + print(f"Handling input {id}...", flush=True) + + try: + func = self.parse_request(v) + except RequestError as e: + resp = LibrarianResponse( + error = Error( + type = "request-error", + message = str(e), + ) + ) + await self.send(resp, properties={"id": id}) + return + + try: + resp = await func() + print("->", resp) + except RequestError as e: + resp = LibrarianResponse( + error = Error( + type = "request-error", + message = str(e), + ) + ) + await self.send(resp, properties={"id": id}) + return + except Exception as e: + print("Exception:", e, flush=True) + resp = LibrarianResponse( + error = Error( + type = "processing-error", + message = "Unhandled error: " + str(e), + ) + ) + await self.send(resp, properties={"id": id}) + return + + print("Send response..!.", flush=True) + + await self.send(resp, properties={"id": id}) + + print("Done.", flush=True) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '--minio-host', + default=default_minio_host, + help=f'Minio hostname (default: {default_minio_host})', + ) + + parser.add_argument( + '--minio-access-key', + default='minioadmin', + help='Minio access key / username ' + f'(default: {default_minio_access_key})', + ) + + parser.add_argument( + '--minio-secret-key', + default='minioadmin', + help='Minio secret key / password ' + f'(default: {default_minio_access_key})', + ) + + parser.add_argument( + '--cassandra-host', + default="cassandra", + help=f'Graph host (default: cassandra)' + ) + + parser.add_argument( + '--cassandra-user', + default=None, + help=f'Cassandra user' + ) + + parser.add_argument( + '--cassandra-password', + default=None, + help=f'Cassandra password' + ) + + parser.add_argument( + '--triples-queue', + default=triples_store_queue, + help=f'Triples queue (default: {triples_store_queue})' + ) + + parser.add_argument( + '--graph-embeddings-queue', + default=graph_embeddings_store_queue, + help=f'Graph embeddings queue (default: {triples_store_queue})' + ) + + parser.add_argument( + '--document-embeddings-queue', + default=document_embeddings_store_queue, + help='Document embeddings queue ' + f'(default: {document_embeddings_store_queue})' + ) + + parser.add_argument( + '--document-load-queue', + default=document_ingest_queue, + help='Document load queue ' + f'(default: {document_ingest_queue})' + ) + + parser.add_argument( + '--text-load-queue', + default=text_ingest_queue, + help='Text ingest queue ' + f'(default: {text_ingest_queue})' + ) + +def run(): + + Processor.launch(module, __doc__) + diff --git a/trustgraph-flow/trustgraph/librarian/table_store.py b/trustgraph-flow/trustgraph/librarian/table_store.py new file mode 100644 index 00000000..1fe47fcf --- /dev/null +++ b/trustgraph-flow/trustgraph/librarian/table_store.py @@ -0,0 +1,448 @@ +from .. schema import LibrarianRequest, LibrarianResponse +from .. schema import DocumentInfo, Error, Triple, Value +from .. knowledge import hash +from .. exceptions import RequestError + +from cassandra.cluster import Cluster +from cassandra.auth import PlainTextAuthProvider +from cassandra.query import BatchStatement +from ssl import SSLContext, PROTOCOL_TLSv1_2 +import uuid +import time + +class TableStore: + + def __init__( + self, + cassandra_host, cassandra_user, cassandra_password, keyspace, + ): + + self.keyspace = keyspace + + print("Connecting to Cassandra...", flush=True) + + if cassandra_user and cassandra_password: + ssl_context = SSLContext(PROTOCOL_TLSv1_2) + auth_provider = PlainTextAuthProvider( + username=cassandra_user, password=cassandra_password + ) + self.cluster = Cluster( + cassandra_host, + auth_provider=auth_provider, + ssl_context=ssl_context + ) + else: + self.cluster = Cluster(cassandra_host) + + self.cassandra = self.cluster.connect() + + print("Connected.", flush=True) + + self.ensure_cassandra_schema() + + self.prepare_statements() + + def ensure_cassandra_schema(self): + + print("Ensure Cassandra schema...", flush=True) + + print("Keyspace...", flush=True) + + # FIXME: Replication factor should be configurable + self.cassandra.execute(f""" + create keyspace if not exists {self.keyspace} + with replication = {{ + 'class' : 'SimpleStrategy', + 'replication_factor' : 1 + }}; + """); + + self.cassandra.set_keyspace(self.keyspace) + + print("document table...", flush=True) + + self.cassandra.execute(""" + CREATE TABLE IF NOT EXISTS document ( + user text, + collection text, + id text, + time timestamp, + title text, + comments text, + kind text, + object_id uuid, + metadata list>, + PRIMARY KEY (user, collection, id) + ); + """); + + print("object index...", flush=True) + + self.cassandra.execute(""" + CREATE INDEX IF NOT EXISTS document_object + ON document (object_id) + """); + + print("triples table...", flush=True) + + self.cassandra.execute(""" + CREATE TABLE IF NOT EXISTS triples ( + user text, + collection text, + document_id text, + id uuid, + time timestamp, + metadata list>, + triples list>, + PRIMARY KEY (user, collection, document_id, id) + ); + """); + + print("graph_embeddings table...", flush=True) + + self.cassandra.execute(""" + create table if not exists graph_embeddings ( + user text, + collection text, + document_id text, + id uuid, + time timestamp, + metadata list>, + entity_embeddings list< + tuple< + tuple, + list> + > + >, + PRIMARY KEY (user, collection, document_id, id) + ); + """); + + print("document_embeddings table...", flush=True) + + self.cassandra.execute(""" + create table if not exists document_embeddings ( + user text, + collection text, + document_id text, + id uuid, + time timestamp, + metadata list>, + chunks list< + tuple< + blob, + list> + > + >, + PRIMARY KEY (user, collection, document_id, id) + ); + """); + + print("Cassandra schema OK.", flush=True) + + def prepare_statements(self): + + self.insert_document_stmt = self.cassandra.prepare(""" + INSERT INTO document + ( + id, user, collection, kind, object_id, time, title, comments, + metadata + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """) + + self.list_document_stmt = self.cassandra.prepare(""" + SELECT + id, kind, user, collection, title, comments, time, metadata + FROM document + WHERE user = ? + """) + + self.list_document_by_collection_stmt = self.cassandra.prepare(""" + SELECT + id, kind, user, collection, title, comments, time, metadata + FROM document + WHERE user = ? AND collection = ? + """) + + self.insert_triples_stmt = self.cassandra.prepare(""" + INSERT INTO triples + ( + id, user, collection, document_id, time, + metadata, triples + ) + VALUES (?, ?, ?, ?, ?, ?, ?) + """) + + self.insert_graph_embeddings_stmt = self.cassandra.prepare(""" + INSERT INTO graph_embeddings + ( + id, user, collection, document_id, time, + metadata, entity_embeddings + ) + VALUES (?, ?, ?, ?, ?, ?, ?) + """) + + self.insert_document_embeddings_stmt = self.cassandra.prepare(""" + INSERT INTO document_embeddings + ( + id, user, collection, document_id, time, + metadata, chunks + ) + VALUES (?, ?, ?, ?, ?, ?, ?) + """) + + def add(self, object_id, document): + + if document.kind not in ( + "text/plain", "application/pdf" + ): + raise RequestError("Invalid document kind: " + document.kind) + + # Create random doc ID + when = int(time.time() * 1000) + + print("Adding", document.id, object_id) + + metadata = [ + ( + v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, + v.o.value, v.o.is_uri + ) + for v in document.metadata + ] + + while True: + + try: + + resp = self.cassandra.execute( + self.insert_document_stmt, + ( + document.id, document.user, document.collection, + document.kind, object_id, when, + document.title, document.comments, + metadata + ) + ) + + break + + except Exception as e: + + print("Exception:", type(e)) + print(f"{e}, retry...", flush=True) + time.sleep(1) + + print("Add complete", flush=True) + + def add_triples(self, m): + + when = int(time.time() * 1000) + + if m.metadata.metadata: + metadata = [ + ( + v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, + v.o.value, v.o.is_uri + ) + for v in m.metadata.metadata + ] + else: + metadata = [] + + triples = [ + ( + v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, + v.o.value, v.o.is_uri + ) + for v in m.triples + ] + + while True: + + try: + + resp = self.cassandra.execute( + self.insert_triples_stmt, + ( + uuid.uuid4(), m.metadata.user, + m.metadata.collection, m.metadata.id, when, + metadata, triples, + ) + ) + + break + + except Exception as e: + + print("Exception:", type(e)) + print(f"{e}, retry...", flush=True) + time.sleep(1) + + def list(self, user, collection=None): + + print("LIST") + while True: + + print("TRY") + + print(self.list_document_stmt) + try: + + if collection: + resp = self.cassandra.execute( + self.list_document_by_collection_stmt, + (user, collection) + ) + else: + resp = self.cassandra.execute( + self.list_document_stmt, + (user,) + ) + break + + print("OK") + + except Exception as e: + print("Exception:", type(e)) + print(f"{e}, retry...", flush=True) + time.sleep(1) + + print("OK2") + + info = [ + DocumentInfo( + id = row[0], + kind = row[1], + user = row[2], + collection = row[3], + title = row[4], + comments = row[5], + time = int(1000 * row[6].timestamp()), + metadata = [ + Triple( + s=Value(value=m[0], is_uri=m[1]), + p=Value(value=m[2], is_uri=m[3]), + o=Value(value=m[4], is_uri=m[5]) + ) + for m in row[7] + ], + ) + for row in resp + ] + + print("OK3") + + print(info[0]) + + print(info[0].user) + print(info[0].time) + print(info[0].kind) + print(info[0].collection) + print(info[0].title) + print(info[0].comments) + print(info[0].metadata) + print(info[0].metadata) + + return info + + def add_graph_embeddings(self, m): + + when = int(time.time() * 1000) + + if m.metadata.metadata: + metadata = [ + ( + v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, + v.o.value, v.o.is_uri + ) + for v in m.metadata.metadata + ] + else: + metadata = [] + + entities = [ + ( + (v.entity.value, v.entity.is_uri), + v.vectors + ) + for v in m.entities + ] + + while True: + + try: + + resp = self.cassandra.execute( + self.insert_graph_embeddings_stmt, + ( + uuid.uuid4(), m.metadata.user, + m.metadata.collection, m.metadata.id, when, + metadata, entities, + ) + ) + + break + + except Exception as e: + + print("Exception:", type(e)) + print(f"{e}, retry...", flush=True) + time.sleep(1) + + def add_document_embeddings(self, m): + + when = int(time.time() * 1000) + + if m.metadata.metadata: + metadata = [ + ( + v.s.value, v.s.is_uri, v.p.value, v.p.is_uri, + v.o.value, v.o.is_uri + ) + for v in m.metadata.metadata + ] + else: + metadata = [] + + chunks = [ + ( + v.chunk, + v.vectors, + ) + for v in m.chunks + ] + + while True: + + try: + + resp = self.cassandra.execute( + self.insert_document_embeddings_stmt, + ( + uuid.uuid4(), m.metadata.user, + m.metadata.collection, m.metadata.id, when, + metadata, chunks, + ) + ) + + break + + except Exception as e: + + print("Exception:", type(e)) + print(f"{e}, retry...", flush=True) + time.sleep(1) + + diff --git a/trustgraph-flow/trustgraph/metering/counter.py b/trustgraph-flow/trustgraph/metering/counter.py index 6e6b829b..68ddf441 100644 --- a/trustgraph-flow/trustgraph/metering/counter.py +++ b/trustgraph-flow/trustgraph/metering/counter.py @@ -57,7 +57,7 @@ class Processor(Consumer): return model["input_price"], model["output_price"] return None, None # Return None if model is not found - def handle(self, msg): + async def handle(self, msg): v = msg.value() modelname = v.model @@ -98,4 +98,4 @@ class Processor(Consumer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/prompt/generic/service.py b/trustgraph-flow/trustgraph/model/prompt/generic/service.py index 96c9be57..b143b759 100755 --- a/trustgraph-flow/trustgraph/model/prompt/generic/service.py +++ b/trustgraph-flow/trustgraph/model/prompt/generic/service.py @@ -63,7 +63,8 @@ class Processor(ConsumerProducer): subscriber=subscriber, input_queue=tc_request_queue, output_queue=tc_response_queue, - pulsar_host = self.pulsar_host + pulsar_host = self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, ) def parse_json(self, text): @@ -77,7 +78,7 @@ class Processor(ConsumerProducer): return json.loads(json_str) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -91,32 +92,32 @@ class Processor(ConsumerProducer): if kind == "extract-definitions": - self.handle_extract_definitions(id, v) + await self.handle_extract_definitions(id, v) return elif kind == "extract-topics": - self.handle_extract_topics(id, v) + await self.handle_extract_topics(id, v) return elif kind == "extract-relationships": - self.handle_extract_relationships(id, v) + await self.handle_extract_relationships(id, v) return elif kind == "extract-rows": - self.handle_extract_rows(id, v) + await self.handle_extract_rows(id, v) return elif kind == "kg-prompt": - self.handle_kg_prompt(id, v) + await self.handle_kg_prompt(id, v) return elif kind == "document-prompt": - self.handle_document_prompt(id, v) + await self.handle_document_prompt(id, v) return else: @@ -124,7 +125,7 @@ class Processor(ConsumerProducer): print("Invalid kind.", flush=True) return - def handle_extract_definitions(self, id, v): + async def handle_extract_definitions(self, id, v): try: @@ -163,7 +164,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = PromptResponse(definitions=output, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -181,9 +182,9 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) - def handle_extract_topics(self, id, v): + async def handle_extract_topics(self, id, v): try: @@ -222,7 +223,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = PromptResponse(topics=output, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -240,9 +241,9 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) - def handle_extract_relationships(self, id, v): + async def handle_extract_relationships(self, id, v): try: @@ -294,7 +295,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = PromptResponse(relationships=output, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -312,9 +313,9 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) - def handle_extract_rows(self, id, v): + async def handle_extract_rows(self, id, v): try: @@ -365,7 +366,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = PromptResponse(rows=output, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -383,9 +384,9 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) - def handle_kg_prompt(self, id, v): + async def handle_kg_prompt(self, id, v): try: @@ -399,7 +400,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = PromptResponse(answer=ans, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -417,9 +418,9 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) - def handle_document_prompt(self, id, v): + async def handle_document_prompt(self, id, v): try: @@ -436,7 +437,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = PromptResponse(answer=ans, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -454,7 +455,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) @staticmethod def add_args(parser): @@ -480,5 +481,5 @@ def run(): raise RuntimeError("NOT IMPLEMENTED") - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/prompt/template/service.py b/trustgraph-flow/trustgraph/model/prompt/template/service.py index 2e5416f4..58657d7d 100755 --- a/trustgraph-flow/trustgraph/model/prompt/template/service.py +++ b/trustgraph-flow/trustgraph/model/prompt/template/service.py @@ -136,7 +136,8 @@ class Processor(ConsumerProducer): subscriber=subscriber, input_queue=tc_request_queue, output_queue=tc_response_queue, - pulsar_host = self.pulsar_host + pulsar_host = self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, ) # System prompt hack @@ -155,7 +156,7 @@ class Processor(ConsumerProducer): config = prompt_configuration, ) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -190,7 +191,7 @@ class Processor(ConsumerProducer): error=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) return @@ -205,7 +206,7 @@ class Processor(ConsumerProducer): error=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) return @@ -223,7 +224,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) except Exception as e: @@ -239,7 +240,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) @staticmethod def add_args(parser): @@ -293,5 +294,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py index 4db7dbf1..33840378 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py @@ -123,7 +123,7 @@ class Processor(ConsumerProducer): return result - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -154,29 +154,19 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) except TooManyRequests: - print("Send rate limit response...", flush=True) + print("Rate limit...") - r = TextCompletionResponse( - error=Error( - type = "rate-limit", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - self.producer.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + # Leave rate limit retries to the base handler + raise TooManyRequests() except Exception as e: + # Apart from rate limits, treat all exceptions as unrecoverable + print(f"Exception: {e}") print("Send error response...", flush=True) @@ -192,7 +182,7 @@ class Processor(ConsumerProducer): model=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -234,4 +224,4 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py index a3edb859..252d58ad 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py @@ -4,10 +4,9 @@ Simple LLM service, performs text prompt completion using the Azure OpenAI endpoit service. Input is prompt, output is response. """ -import requests import json from prometheus_client import Histogram -from openai import AzureOpenAI +from openai import AzureOpenAI, RateLimitError import os from .... schema import TextCompletionRequest, TextCompletionResponse, Error @@ -24,9 +23,10 @@ default_output_queue = text_completion_response_queue default_subscriber = module default_temperature = 0.0 default_max_output = 4192 -default_api = "2024-02-15-preview" -default_endpoint = os.getenv("AZURE_ENDPOINT") -default_token = os.getenv("AZURE_TOKEN") +default_api = "2024-12-01-preview" +default_endpoint = os.getenv("AZURE_ENDPOINT", None) +default_token = os.getenv("AZURE_TOKEN", None) +default_model = os.getenv("AZURE_MODEL", None) class Processor(ConsumerProducer): @@ -35,12 +35,13 @@ class Processor(ConsumerProducer): input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) - endpoint = params.get("endpoint", default_endpoint) - token = params.get("token", default_token) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) - model = params.get("model") + api = params.get("api_version", default_api) + endpoint = params.get("endpoint", default_endpoint) + token = params.get("token", default_token) + model = params.get("model", default_model) if endpoint is None: raise RuntimeError("Azure endpoint not specified") @@ -85,7 +86,7 @@ class Processor(ConsumerProducer): azure_endpoint = endpoint, ) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -126,30 +127,27 @@ class Processor(ConsumerProducer): print(f"Output Tokens: {outputtokens}", flush=True) print("Send response...", flush=True) - r = TextCompletionResponse(response=resp.choices[0].message.content, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model) - self.producer.send(r, properties={"id": id}) + r = TextCompletionResponse( + response=resp.choices[0].message.content, + error=None, + in_token=inputtokens, + out_token=outputtokens, + model=self.model + ) - except TooManyRequests: + await self.send(r, properties={"id": id}) + + except RateLimitError: print("Send rate limit response...", flush=True) - r = TextCompletionResponse( - error=Error( - type = "rate-limit", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - self.producer.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + # Leave rate limit retries to the base handler + raise TooManyRequests() except Exception as e: + # Apart from rate limits, treat all exceptions as unrecoverable + print(f"Exception: {e}") print("Send error response...", flush=True) @@ -165,7 +163,7 @@ class Processor(ConsumerProducer): model=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -181,6 +179,7 @@ class Processor(ConsumerProducer): parser.add_argument( '-e', '--endpoint', + default=default_endpoint, help=f'LLM model endpoint' ) @@ -192,11 +191,13 @@ class Processor(ConsumerProducer): parser.add_argument( '-k', '--token', + default=default_token, help=f'LLM model token' ) parser.add_argument( '-m', '--model', + default=default_model, help=f'LLM model' ) @@ -216,4 +217,4 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py index 01ce837d..195a39e4 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py @@ -73,7 +73,7 @@ class Processor(ConsumerProducer): print("Initialised", flush=True) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -87,8 +87,6 @@ class Processor(ConsumerProducer): try: - # FIXME: Rate limits? - with __class__.text_completion_metric.time(): response = message = self.claude.messages.create( @@ -117,34 +115,26 @@ class Processor(ConsumerProducer): print(f"Output Tokens: {outputtokens}", flush=True) print("Send response...", flush=True) - r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model) + r = TextCompletionResponse( + response=resp, + error=None, + in_token=inputtokens, + out_token=outputtokens, + model=self.model + ) self.send(r, properties={"id": id}) print("Done.", flush=True) - # FIXME: Wrong exception, don't know what this LLM throws - # for a rate limit - except TooManyRequests: + except anthropic.RateLimitError: - print("Send rate limit response...", flush=True) - - r = TextCompletionResponse( - error=Error( - type = "rate-limit", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - self.producer.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + # Leave rate limit retries to the base handler + raise TooManyRequests() except Exception as e: + # Apart from rate limits, treat all exceptions as unrecoverable + print(f"Exception: {e}") print("Send error response...", flush=True) @@ -160,7 +150,7 @@ class Processor(ConsumerProducer): model=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -200,6 +190,6 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py index d03e1554..d5dab142 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py @@ -69,7 +69,7 @@ class Processor(ConsumerProducer): print("Initialised", flush=True) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -106,33 +106,21 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model) - self.send(r, properties={"id": id}) + self.await send(r, properties={"id": id}) print("Done.", flush=True) # FIXME: Wrong exception, don't know what this LLM throws # for a rate limit - except TooManyRequests: + except cohere.TooManyRequestsError: - print("Send rate limit response...", flush=True) - - r = TextCompletionResponse( - error=Error( - type = "rate-limit", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - self.producer.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + # Leave rate limit retries to the base handler + raise TooManyRequests() except Exception as e: + # Apart from rate limits, treat all exceptions as unrecoverable + print(f"Exception: {e}") print("Send error response...", flush=True) @@ -148,7 +136,7 @@ class Processor(ConsumerProducer): model=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -181,6 +169,6 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py index a249998d..98ecaf0e 100644 --- a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py @@ -88,7 +88,8 @@ class Processor(ConsumerProducer): HarmCategory.HARM_CATEGORY_HARASSMENT: block_level, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_level, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: block_level, - # There is a documentation conflict on whether or not CIVIC_INTEGRITY is a valid category + # There is a documentation conflict on whether or not + # CIVIC_INTEGRITY is a valid category # HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: block_level, } @@ -101,7 +102,7 @@ class Processor(ConsumerProducer): print("Initialised", flush=True) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -122,8 +123,6 @@ class Processor(ConsumerProducer): try: - # FIXME: Rate limits? - with __class__.text_completion_metric.time(): chat_session = self.llm.start_chat( @@ -140,35 +139,30 @@ class Processor(ConsumerProducer): print(f"Output Tokens: {outputtokens}", flush=True) print("Send response...", flush=True) - r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model) - self.send(r, properties={"id": id}) + r = TextCompletionResponse( + response=resp, + error=None, + in_token=inputtokens, + out_token=outputtokens, + model=self.model + ) + await self.send(r, properties={"id": id}) print("Done.", flush=True) - # FIXME: Wrong exception, don't know what this LLM throws - # for a rate limit except ResourceExhausted as e: - print("Send rate limit response...", flush=True) + print("Hit rate limit:", e, flush=True) - r = TextCompletionResponse( - error=Error( - type = "rate-limit", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - self.producer.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + # Leave rate limit retries to the default handler + raise TooManyRequests() except Exception as e: - print(f"Exception: {e}") + # Apart from rate limits, treat all exceptions as unrecoverable + + print(type(e), flush=True) + print(f"Exception: {e}", flush=True) print("Send error response...", flush=True) @@ -183,7 +177,7 @@ class Processor(ConsumerProducer): model=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -223,6 +217,6 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py index 274948a8..483412a2 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py @@ -74,7 +74,7 @@ class Processor(ConsumerProducer): print("Initialised", flush=True) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -122,30 +122,11 @@ class Processor(ConsumerProducer): out_token=outputtokens, model="llama.cpp" ) - self.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) - # FIXME: Wrong exception, don't know what this LLM throws - # for a rate limit - except TooManyRequests: - - print("Send rate limit response...", flush=True) - - r = TextCompletionResponse( - error=Error( - type = "rate-limit", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - self.producer.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + # SLM, presumably there aren't rate limits except Exception as e: @@ -164,7 +145,7 @@ class Processor(ConsumerProducer): model=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -204,6 +185,6 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/__init__.py b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/__init__.py new file mode 100644 index 00000000..f2017af8 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/__init__.py @@ -0,0 +1,3 @@ + +from . llm import * + diff --git a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/__main__.py b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/__main__.py new file mode 100755 index 00000000..91342d2d --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . llm import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py new file mode 100755 index 00000000..16ff2df4 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/lmstudio/llm.py @@ -0,0 +1,193 @@ + +""" +Simple LLM service, performs text prompt completion using OpenAI. +Input is prompt, output is response. +""" + +from openai import OpenAI +from prometheus_client import Histogram +import os + +from .... schema import TextCompletionRequest, TextCompletionResponse, Error +from .... schema import text_completion_request_queue +from .... schema import text_completion_response_queue +from .... log_level import LogLevel +from .... base import ConsumerProducer +from .... exceptions import TooManyRequests + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = text_completion_request_queue +default_output_queue = text_completion_response_queue +default_subscriber = module +default_model = 'gemma3:9b' +default_url = os.getenv("LMSTUDIO_URL", "http://localhost:1234/") +default_temperature = 0.0 +default_max_output = 4096 + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + model = params.get("model", default_model) + url = params.get("url", default_url) + temperature = params.get("temperature", default_temperature) + max_output = params.get("max_output", default_max_output) + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": TextCompletionRequest, + "output_schema": TextCompletionResponse, + "model": model, + "temperature": temperature, + "max_output": max_output, + "url" : url, + } + ) + + if not hasattr(__class__, "text_completion_metric"): + __class__.text_completion_metric = Histogram( + 'text_completion_duration', + 'Text completion duration (seconds)', + buckets=[ + 0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, + 30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0, + 120.0 + ] + ) + + self.model = model + self.url = url + "v1/" + self.temperature = temperature + self.max_output = max_output + self.openai = OpenAI( + base_url=self.url, + api_key = "sk-no-key-required", + ) + + print("Initialised", flush=True) + + async def handle(self, msg): + + v = msg.value() + + # Sender-produced ID + + id = msg.properties()["id"] + + print(f"Handling prompt {id}...", flush=True) + + prompt = v.system + "\n\n" + v.prompt + + try: + + # FIXME: Rate limits + + with __class__.text_completion_metric.time(): + + print(prompt) + + resp = self.openai.chat.completions.create( + model=self.model, + messages=[ + {"role": "user", "content": prompt} + ] + #temperature=self.temperature, + #max_tokens=self.max_output, + #top_p=1, + #frequency_penalty=0, + #presence_penalty=0, + #response_format={ + # "type": "text" + #} + ) + + print(resp) + + inputtokens = resp.usage.prompt_tokens + outputtokens = resp.usage.completion_tokens + + print(resp.choices[0].message.content, flush=True) + print(f"Input Tokens: {inputtokens}", flush=True) + print(f"Output Tokens: {outputtokens}", flush=True) + + print("Send response...", flush=True) + r = TextCompletionResponse( + response=resp.choices[0].message.content, + error=None, + in_token=inputtokens, + out_token=outputtokens, + model=self.model, + ) + await self.send(r, properties={"id": id}) + + print("Done.", flush=True) + + # SLM, presumably there aren't rate limits + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + in_token=None, + out_token=None, + model=None, + ) + + await self.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '-m', '--model', + default=default_model, + help=f'LLM model (default: gemma3:9b)' + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'LMStudio URL (default: {default_url})' + ) + + parser.add_argument( + '-t', '--temperature', + type=float, + default=default_temperature, + help=f'LLM temperature parameter (default: {default_temperature})' + ) + + parser.add_argument( + '-x', '--max-output', + type=int, + default=default_max_output, + help=f'LLM max output tokens (default: {default_max_output})' + ) + +def run(): + Processor.launch(module, __doc__) + diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/__init__.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/__init__.py new file mode 100644 index 00000000..f2017af8 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/__init__.py @@ -0,0 +1,3 @@ + +from . llm import * + diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/__main__.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/__main__.py new file mode 100755 index 00000000..91342d2d --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . llm import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py new file mode 100755 index 00000000..8130cf8a --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py @@ -0,0 +1,201 @@ + +""" +Simple LLM service, performs text prompt completion using Mistral. +Input is prompt, output is response. +""" + +from mistralai import Mistral, RateLimitError +from prometheus_client import Histogram +import os + +from .... schema import TextCompletionRequest, TextCompletionResponse, Error +from .... schema import text_completion_request_queue +from .... schema import text_completion_response_queue +from .... log_level import LogLevel +from .... base import ConsumerProducer +from .... exceptions import TooManyRequests + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = text_completion_request_queue +default_output_queue = text_completion_response_queue +default_subscriber = module +default_model = 'ministral-8b-latest' +default_temperature = 0.0 +default_max_output = 4096 +default_api_key = os.getenv("MISTRAL_TOKEN") + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + model = params.get("model", default_model) + api_key = params.get("api_key", default_api_key) + temperature = params.get("temperature", default_temperature) + max_output = params.get("max_output", default_max_output) + + if api_key is None: + raise RuntimeError("Mistral API key not specified") + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": TextCompletionRequest, + "output_schema": TextCompletionResponse, + "model": model, + "temperature": temperature, + "max_output": max_output, + } + ) + + if not hasattr(__class__, "text_completion_metric"): + __class__.text_completion_metric = Histogram( + 'text_completion_duration', + 'Text completion duration (seconds)', + buckets=[ + 0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, + 30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0, + 120.0 + ] + ) + + self.model = model + self.temperature = temperature + self.max_output = max_output + self.mistral = Mistral(api_key=api_key) + + print("Initialised", flush=True) + + async def handle(self, msg): + + v = msg.value() + + # Sender-produced ID + + id = msg.properties()["id"] + + print(f"Handling prompt {id}...", flush=True) + + prompt = v.system + "\n\n" + v.prompt + + try: + + with __class__.text_completion_metric.time(): + + resp = self.mistral.chat.complete( + model=self.model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ], + temperature=self.temperature, + max_tokens=self.max_output, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + response_format={ + "type": "text" + } + ) + + inputtokens = resp.usage.prompt_tokens + outputtokens = resp.usage.completion_tokens + print(resp.choices[0].message.content, flush=True) + print(f"Input Tokens: {inputtokens}", flush=True) + print(f"Output Tokens: {outputtokens}", flush=True) + + print("Send response...", flush=True) + r = TextCompletionResponse( + response=resp.choices[0].message.content, + error=None, + in_token=inputtokens, + out_token=outputtokens, + model=self.model + ) + await self.send(r, properties={"id": id}) + + print("Done.", flush=True) + + # FIXME: Wrong exception, don't know what this LLM throws + # for a rate limit + except Mistral.RateLimitError: + + # Leave rate limit retries to the base handler + raise TooManyRequests() + + except Exception as e: + + # Apart from rate limits, treat all exceptions as unrecoverable + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + in_token=None, + out_token=None, + model=None, + ) + + await self.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '-m', '--model', + default=default_model, + help=f'LLM model (default: ministral-8b-latest)' + ) + + parser.add_argument( + '-k', '--api-key', + default=default_api_key, + help=f'Mistral API Key' + ) + + parser.add_argument( + '-t', '--temperature', + type=float, + default=default_temperature, + help=f'LLM temperature parameter (default: {default_temperature})' + ) + + parser.add_argument( + '-x', '--max-output', + type=int, + default=default_max_output, + help=f'LLM max output tokens (default: {default_max_output})' + ) + +def run(): + + Processor.launch(module, __doc__) + + diff --git a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py index 00d44f6d..6d825bac 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py @@ -71,7 +71,7 @@ class Processor(ConsumerProducer): self.model = model self.llm = Client(host=ollama) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -96,30 +96,11 @@ class Processor(ConsumerProducer): r = TextCompletionResponse(response=response_text, error=None, in_token=inputtokens, out_token=outputtokens, model="ollama") - self.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) - # FIXME: Wrong exception, don't know what this LLM throws - # for a rate limit - except TooManyRequests: - - print("Send rate limit response...", flush=True) - - r = TextCompletionResponse( - error=Error( - type = "rate-limit", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - self.producer.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + # SLM, presumably no rate limits except Exception as e: @@ -138,7 +119,7 @@ class Processor(ConsumerProducer): model=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -164,6 +145,6 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index c874943e..ebfae9ed 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -4,7 +4,7 @@ Simple LLM service, performs text prompt completion using OpenAI. Input is prompt, output is response. """ -from openai import OpenAI +from openai import OpenAI, RateLimitError from prometheus_client import Histogram import os @@ -73,7 +73,7 @@ class Processor(ConsumerProducer): print("Initialised", flush=True) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -87,8 +87,6 @@ class Processor(ConsumerProducer): try: - # FIXME: Rate limits - with __class__.text_completion_metric.time(): resp = self.openai.chat.completions.create( @@ -128,33 +126,21 @@ class Processor(ConsumerProducer): out_token=outputtokens, model=self.model ) - self.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) # FIXME: Wrong exception, don't know what this LLM throws # for a rate limit - except TooManyRequests: + except openai.RateLimitError: - print("Send rate limit response...", flush=True) - - r = TextCompletionResponse( - error=Error( - type = "rate-limit", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - self.producer.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + # Leave rate limit retries to the base handler + raise TooManyRequests() except Exception as e: + # Apart from rate limits, treat all exceptions as unrecoverable + print(f"Exception: {e}") print("Send error response...", flush=True) @@ -170,7 +156,7 @@ class Processor(ConsumerProducer): model=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -210,6 +196,6 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/processing/processing.py b/trustgraph-flow/trustgraph/processing/processing.py index 5e4c7c8a..5352776a 100644 --- a/trustgraph-flow/trustgraph/processing/processing.py +++ b/trustgraph-flow/trustgraph/processing/processing.py @@ -49,11 +49,12 @@ class Processing: pulsar_host, log_level, file, + pulsar_api_key=None, ): self.pulsar_host = pulsar_host self.log_level = log_level self.file = file - + self.pulsar_api_key = pulsar_api_key self.defs = load(open(file, "r"), Loader=Loader) def run(self): @@ -68,6 +69,7 @@ class Processing: params = { "pulsar_host": self.pulsar_host, + "pulsar_api_key": self.pulsar_api_key, "log_level": str(self.log_level), } @@ -125,12 +127,19 @@ def run(): ) default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') + default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) parser.add_argument( '-p', '--pulsar-host', default=default_pulsar_host, help=f'Pulsar host (default: {default_pulsar_host})', ) + + parser.add_argument( + '--pulsar-api-key', + default=default_pulsar_api_key, + help=f'Pulsar API key', + ) parser.add_argument( '-l', '--log-level', @@ -153,6 +162,7 @@ def run(): try: p = Processing( pulsar_host=args.pulsar_host, + pulsar_api_key=args.pulsar_api_key, file=args.file, log_level=args.log_level, ) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py index 8e106e6f..b16399e9 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py @@ -40,7 +40,7 @@ class Processor(ConsumerProducer): self.vecstore = DocVectors(store_uri) - def handle(self, msg): + async def handle(self, msg): try: @@ -64,7 +64,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = DocumentEmbeddingsResponse(documents=chunks, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -82,7 +82,7 @@ class Processor(ConsumerProducer): documents=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -102,5 +102,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py index b8502143..6a88671c 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/pinecone/service.py @@ -56,7 +56,7 @@ class Processor(ConsumerProducer): } ) - def handle(self, msg): + async def handle(self, msg): try: @@ -100,7 +100,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = DocumentEmbeddingsResponse(documents=chunks, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -118,7 +118,7 @@ class Processor(ConsumerProducer): documents=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -143,5 +143,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 7bb5133a..128203ad 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -30,6 +30,8 @@ class Processor(ConsumerProducer): output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) + #optional api key + api_key = params.get("api_key", None) super(Processor, self).__init__( **params | { @@ -39,12 +41,13 @@ class Processor(ConsumerProducer): "input_schema": DocumentEmbeddingsRequest, "output_schema": DocumentEmbeddingsResponse, "store_uri": store_uri, + "api_key": api_key, } ) - self.client = QdrantClient(url=store_uri) + self.client = QdrantClient(url=store_uri, api_key=api_key) - def handle(self, msg): + async def handle(self, msg): try: @@ -78,7 +81,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = DocumentEmbeddingsResponse(documents=chunks, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -96,7 +99,7 @@ class Processor(ConsumerProducer): documents=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -111,10 +114,16 @@ class Processor(ConsumerProducer): parser.add_argument( '-t', '--store-uri', default=default_store_uri, - help=f'Milvus store URI (default: {default_store_uri})' + help=f'Qdrant store URI (default: {default_store_uri})' + ) + + parser.add_argument( + '-k', '--api-key', + default=None, + help=f'API key for qdrant (default: None)' ) def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py index b5f9ae5b..8dd8d04d 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py @@ -46,7 +46,7 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - def handle(self, msg): + async def handle(self, msg): try: @@ -79,7 +79,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = GraphEmbeddingsResponse(entities=entities, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -97,7 +97,7 @@ class Processor(ConsumerProducer): entities=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -117,5 +117,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py index 2534d278..90cfc6de 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/pinecone/service.py @@ -62,7 +62,7 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - def handle(self, msg): + async def handle(self, msg): try: @@ -120,7 +120,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = GraphEmbeddingsResponse(entities=entities, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -138,7 +138,7 @@ class Processor(ConsumerProducer): entities=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -163,5 +163,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index c2dcaa4c..dc3e28f3 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -30,6 +30,7 @@ class Processor(ConsumerProducer): output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) + api_key = params.get("api_key", None) super(Processor, self).__init__( **params | { @@ -39,10 +40,11 @@ class Processor(ConsumerProducer): "input_schema": GraphEmbeddingsRequest, "output_schema": GraphEmbeddingsResponse, "store_uri": store_uri, + "api_key": api_key, } ) - self.client = QdrantClient(url=store_uri) + self.client = QdrantClient(url=store_uri, api_key=api_key) def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): @@ -50,7 +52,7 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - def handle(self, msg): + async def handle(self, msg): try: @@ -104,7 +106,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = GraphEmbeddingsResponse(entities=entities, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -122,7 +124,7 @@ class Processor(ConsumerProducer): entities=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -137,10 +139,16 @@ class Processor(ConsumerProducer): parser.add_argument( '-t', '--store-uri', default=default_store_uri, - help=f'Milvus store URI (default: {default_store_uri})' + help=f'Qdrant store URI (default: {default_store_uri})' + ) + + parser.add_argument( + '-k', '--api-key', + default=None, + help=f'API key for qdrant (default: None)' ) def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index 4245784d..e3687756 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -26,6 +26,8 @@ class Processor(ConsumerProducer): output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) graph_host = params.get("graph_host", default_graph_host) + graph_username = params.get("graph_username", None) + graph_password = params.get("graph_password", None) super(Processor, self).__init__( **params | { @@ -35,10 +37,14 @@ class Processor(ConsumerProducer): "input_schema": TriplesQueryRequest, "output_schema": TriplesQueryResponse, "graph_host": graph_host, + "graph_username": graph_username, + "graph_password": graph_password, } ) self.graph_host = [graph_host] + self.username = graph_username + self.password = graph_password self.table = None def create_value(self, ent): @@ -47,7 +53,7 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - def handle(self, msg): + async def handle(self, msg): try: @@ -56,10 +62,17 @@ class Processor(ConsumerProducer): table = (v.user, v.collection) if table != self.table: - self.tg = TrustGraph( - hosts=self.graph_host, - keyspace=v.user, table=v.collection, - ) + if self.username and self.password: + self.tg = TrustGraph( + hosts=self.graph_host, + keyspace=v.user, table=v.collection, + username=self.username, password=self.password + ) + else: + self.tg = TrustGraph( + hosts=self.graph_host, + keyspace=v.user, table=v.collection, + ) self.table = table # Sender-produced ID @@ -141,7 +154,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = TriplesQueryResponse(triples=triples, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -159,7 +172,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -176,8 +189,21 @@ class Processor(ConsumerProducer): default="localhost", help=f'Graph host (default: localhost)' ) + + parser.add_argument( + '--graph-username', + default=None, + help=f'Cassandra username' + ) + + parser.add_argument( + '--graph-password', + default=None, + help=f'Cassandra password' + ) + def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py index 1d77bb15..56fed6d3 100755 --- a/trustgraph-flow/trustgraph/query/triples/falkordb/service.py +++ b/trustgraph-flow/trustgraph/query/triples/falkordb/service.py @@ -54,7 +54,7 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - def handle(self, msg): + async def handle(self, msg): try: @@ -301,7 +301,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = TriplesQueryResponse(triples=triples, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -319,7 +319,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -345,5 +345,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py index 46dd19e3..f442c4ef 100755 --- a/trustgraph-flow/trustgraph/query/triples/memgraph/service.py +++ b/trustgraph-flow/trustgraph/query/triples/memgraph/service.py @@ -58,7 +58,7 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - def handle(self, msg): + async def handle(self, msg): try: @@ -313,7 +313,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = TriplesQueryResponse(triples=triples, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -331,7 +331,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -369,5 +369,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py index d60bc4f4..49ba0345 100755 --- a/trustgraph-flow/trustgraph/query/triples/neo4j/service.py +++ b/trustgraph-flow/trustgraph/query/triples/neo4j/service.py @@ -58,7 +58,7 @@ class Processor(ConsumerProducer): else: return Value(value=ent, is_uri=False) - def handle(self, msg): + async def handle(self, msg): try: @@ -297,7 +297,7 @@ class Processor(ConsumerProducer): print("Send response...", flush=True) r = TriplesQueryResponse(triples=triples, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -315,7 +315,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -353,5 +353,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index 4310cdbd..bb8b008e 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -50,6 +50,8 @@ class Processor(ConsumerProducer): document_embeddings_response_queue ) + doc_limit = params.get("doc_limit", 10) + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -68,6 +70,7 @@ class Processor(ConsumerProducer): self.rag = DocumentRag( pulsar_host=self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, pr_request_queue=pr_request_queue, pr_response_queue=pr_response_queue, emb_request_queue=emb_request_queue, @@ -78,7 +81,9 @@ class Processor(ConsumerProducer): module=module, ) - def handle(self, msg): + self.doc_limit = doc_limit + + async def handle(self, msg): try: @@ -89,11 +94,16 @@ class Processor(ConsumerProducer): print(f"Handling input {id}...", flush=True) - response = self.rag.query(v.query) + if v.doc_limit: + doc_limit = v.doc_limit + else: + doc_limit = self.doc_limit + + response = self.rag.query(v.query, doc_limit=doc_limit) print("Send response...", flush=True) r = DocumentRagResponse(response = response, error=None) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -111,7 +121,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -123,6 +133,13 @@ class Processor(ConsumerProducer): default_output_queue, ) + parser.add_argument( + '-d', '--doc-limit', + type=int, + default=20, + help=f'Default document fetch limit (default: 10)' + ) + parser.add_argument( '--prompt-request-queue', default=prompt_request_queue, @@ -161,5 +178,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 1219050e..2c45ecd4 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -31,9 +31,7 @@ class Processor(ConsumerProducer): input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) - entity_limit = params.get("entity_limit", 50) - triple_limit = params.get("triple_limit", 30) - max_subgraph_size = params.get("max_subgraph_size", 3000) + pr_request_queue = params.get( "prompt_request_queue", prompt_request_queue ) @@ -59,6 +57,11 @@ class Processor(ConsumerProducer): "triples_response_queue", triples_response_queue ) + entity_limit = params.get("entity_limit", 50) + triple_limit = params.get("triple_limit", 30) + max_subgraph_size = params.get("max_subgraph_size", 150) + max_path_length = params.get("max_path_length", 2) + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -82,6 +85,7 @@ class Processor(ConsumerProducer): self.rag = GraphRag( pulsar_host=self.pulsar_host, + pulsar_api_key=self.pulsar_api_key, pr_request_queue=pr_request_queue, pr_response_queue=pr_response_queue, emb_request_queue=emb_request_queue, @@ -91,13 +95,15 @@ class Processor(ConsumerProducer): tpl_request_queue=triples_request_queue, tpl_response_queue=triples_response_queue, verbose=True, - entity_limit=entity_limit, - triple_limit=triple_limit, - max_subgraph_size=max_subgraph_size, module=module, ) - def handle(self, msg): + self.default_entity_limit = entity_limit + self.default_triple_limit = triple_limit + self.default_max_subgraph_size = max_subgraph_size + self.default_max_path_length = max_path_length + + async def handle(self, msg): try: @@ -105,16 +111,39 @@ class Processor(ConsumerProducer): # Sender-produced ID id = msg.properties()["id"] - + print(f"Handling input {id}...", flush=True) + if v.entity_limit: + entity_limit = v.entity_limit + else: + entity_limit = self.default_entity_limit + + if v.triple_limit: + triple_limit = v.triple_limit + else: + triple_limit = self.default_triple_limit + + if v.max_subgraph_size: + max_subgraph_size = v.max_subgraph_size + else: + max_subgraph_size = self.default_max_subgraph_size + + if v.max_path_length: + max_path_length = v.max_path_length + else: + max_path_length = self.default_max_path_length + response = self.rag.query( - query=v.query, user=v.user, collection=v.collection + query=v.query, user=v.user, collection=v.collection, + entity_limit=entity_limit, triple_limit=triple_limit, + max_subgraph_size=max_subgraph_size, + max_path_length=max_path_length, ) print("Send response...", flush=True) - r = GraphRagResponse(response = response, error=None) - self.producer.send(r, properties={"id": id}) + r = GraphRagResponse(response=response, error=None) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -132,7 +161,7 @@ class Processor(ConsumerProducer): response=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -148,21 +177,28 @@ class Processor(ConsumerProducer): '-e', '--entity-limit', type=int, default=50, - help=f'Entity vector fetch limit (default: 50)' + help=f'Default entity vector fetch limit (default: 50)' ) parser.add_argument( '-t', '--triple-limit', type=int, default=30, - help=f'Triple query limit, per query (default: 30)' + help=f'Default triple query limit, per query (default: 30)' ) parser.add_argument( '-u', '--max-subgraph-size', type=int, - default=3000, - help=f'Max subgraph size (default: 3000)' + default=150, + help=f'Default max subgraph size (default: 150)' + ) + + parser.add_argument( + '-a', '--max-path-length', + type=int, + default=2, + help=f'Default max path length (default: 2)' ) parser.add_argument( @@ -215,5 +251,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py index 00f9d5b5..b4dbc486 100755 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/milvus/write.py @@ -3,15 +3,16 @@ Accepts entity/vector pairs and writes them to a Milvus store. """ -from .... schema import ChunkEmbeddings -from .... schema import chunk_embeddings_ingest_queue -from .... log_level import LogLevel from .... direct.milvus_doc_embeddings import DocVectors + +from .... schema import DocumentEmbeddings +from .... schema import document_embeddings_store_queue +from .... log_level import LogLevel from .... base import Consumer module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = document_embeddings_store_queue default_subscriber = module default_store_uri = 'http://localhost:19530' @@ -27,22 +28,27 @@ class Processor(Consumer): **params | { "input_queue": input_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": DocumentEmbeddings, "store_uri": store_uri, } ) self.vecstore = DocVectors(store_uri) - def handle(self, msg): + async def handle(self, msg): v = msg.value() - chunk = v.chunk.decode("utf-8") + for emb in v.chunks: - if v.chunk != "" and v.chunk is not None: - for vec in v.vectors: - self.vecstore.insert(vec, chunk) + chunk = emb.chunk.decode("utf-8") + if chunk == "" or chunk is None: continue + + for vec in emb.vectors: + + if chunk != "" and v.chunk is not None: + for vec in v.vectors: + self.vecstore.insert(vec, chunk) @staticmethod def add_args(parser): @@ -59,5 +65,5 @@ class Processor(Consumer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py index 24cfcb78..9e91db9a 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/pinecone/write.py @@ -11,14 +11,14 @@ import time import uuid import os -from .... schema import ChunkEmbeddings -from .... schema import chunk_embeddings_ingest_queue +from .... schema import DocumentEmbeddings +from .... schema import document_embeddings_store_queue from .... log_level import LogLevel from .... base import Consumer module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = document_embeddings_store_queue default_subscriber = module default_api_key = os.getenv("PINECONE_API_KEY", "not-specified") default_cloud = "aws" @@ -54,82 +54,85 @@ class Processor(Consumer): **params | { "input_queue": input_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": DocumentEmbeddings, "url": self.url, } ) self.last_index_name = None - def handle(self, msg): + async def handle(self, msg): v = msg.value() - chunk = v.chunk.decode("utf-8") + for emb in v.chunks: - if chunk == "": return + chunk = emb.chunk.decode("utf-8") + if chunk == "" or chunk is None: continue - for vec in v.vectors: + for vec in emb.vectors: - dim = len(vec) - collection = ( - "d-" + v.metadata.user + "-" + str(dim) - ) + for vec in v.vectors: - if index_name != self.last_index_name: + dim = len(vec) + collection = ( + "d-" + v.metadata.user + "-" + str(dim) + ) - if not self.pinecone.has_index(index_name): + if index_name != self.last_index_name: - try: + if not self.pinecone.has_index(index_name): - self.pinecone.create_index( - name = index_name, - dimension = dim, - metric = "cosine", - spec = ServerlessSpec( - cloud = self.cloud, - region = self.region, - ) - ) + try: - for i in range(0, 1000): + self.pinecone.create_index( + name = index_name, + dimension = dim, + metric = "cosine", + spec = ServerlessSpec( + cloud = self.cloud, + region = self.region, + ) + ) - if self.pinecone.describe_index( - index_name - ).status["ready"]: - break + for i in range(0, 1000): - time.sleep(1) + if self.pinecone.describe_index( + index_name + ).status["ready"]: + break - if not self.pinecone.describe_index( - index_name - ).status["ready"]: - raise RuntimeError( - "Gave up waiting for index creation" - ) + time.sleep(1) - except Exception as e: - print("Pinecone index creation failed") - raise e + if not self.pinecone.describe_index( + index_name + ).status["ready"]: + raise RuntimeError( + "Gave up waiting for index creation" + ) - print(f"Index {index_name} created", flush=True) + except Exception as e: + print("Pinecone index creation failed") + raise e - self.last_index_name = index_name + print(f"Index {index_name} created", flush=True) - index = self.pinecone.Index(index_name) + self.last_index_name = index_name - records = [ - { - "id": id, - "values": vec, - "metadata": { "doc": chunk }, - } - ] + index = self.pinecone.Index(index_name) - index.upsert( - vectors = records, - namespace = v.metadata.collection, - ) + records = [ + { + "id": id, + "values": vec, + "metadata": { "doc": chunk }, + } + ] + + index.upsert( + vectors = records, + namespace = v.metadata.collection, + ) @staticmethod def add_args(parser): @@ -163,5 +166,5 @@ class Processor(Consumer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index 813c4f29..810c1931 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -8,14 +8,14 @@ from qdrant_client.models import PointStruct from qdrant_client.models import Distance, VectorParams import uuid -from .... schema import ChunkEmbeddings -from .... schema import chunk_embeddings_ingest_queue +from .... schema import DocumentEmbeddings +from .... schema import document_embeddings_store_queue from .... log_level import LogLevel from .... base import Consumer module = ".".join(__name__.split(".")[1:-1]) -default_input_queue = chunk_embeddings_ingest_queue +default_input_queue = document_embeddings_store_queue default_subscriber = module default_store_uri = 'http://localhost:6333' @@ -26,13 +26,15 @@ class Processor(Consumer): input_queue = params.get("input_queue", default_input_queue) subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) + api_key = params.get("api_key", None) super(Processor, self).__init__( **params | { "input_queue": input_queue, "subscriber": subscriber, - "input_schema": ChunkEmbeddings, + "input_schema": DocumentEmbeddings, "store_uri": store_uri, + "api_key": api_key, } ) @@ -40,51 +42,52 @@ class Processor(Consumer): self.client = QdrantClient(url=store_uri) - def handle(self, msg): + async def handle(self, msg): v = msg.value() - chunk = v.chunk.decode("utf-8") + for emb in v.chunks: - if chunk == "": return + chunk = emb.chunk.decode("utf-8") + if chunk == "": return - for vec in v.vectors: + for vec in emb.vectors: - dim = len(vec) - collection = ( - "d_" + v.metadata.user + "_" + v.metadata.collection + "_" + - str(dim) - ) + dim = len(vec) + collection = ( + "d_" + v.metadata.user + "_" + v.metadata.collection + "_" + + str(dim) + ) - if collection != self.last_collection: + if collection != self.last_collection: - if not self.client.collection_exists(collection): + if not self.client.collection_exists(collection): - try: - self.client.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.DOT - ), + try: + self.client.create_collection( + collection_name=collection, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + except Exception as e: + print("Qdrant collection creation failed") + raise e + + self.last_collection = collection + + self.client.upsert( + collection_name=collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vec, + payload={ + "doc": chunk, + } ) - except Exception as e: - print("Qdrant collection creation failed") - raise e - - self.last_collection = collection - - self.client.upsert( - collection_name=collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vec, - payload={ - "doc": chunk, - } - ) - ] - ) + ] + ) @staticmethod def add_args(parser): @@ -96,10 +99,16 @@ class Processor(Consumer): parser.add_argument( '-t', '--store-uri', default=default_store_uri, - help=f'Qdrant store URI (default: {default_store_uri})' + help=f'Qdrant URI (default: {default_store_uri})' + ) + + parser.add_argument( + '-k', '--api-key', + default=None, + help=f'Qdrant API key (default: None)' ) def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py index 98fe7915..b2d40306 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/milvus/write.py @@ -34,13 +34,15 @@ class Processor(Consumer): self.vecstore = EntityVectors(store_uri) - def handle(self, msg): + async def handle(self, msg): v = msg.value() - if v.entity.value != "": - for vec in v.vectors: - self.vecstore.insert(vec, v.entity.value) + for entity in v.entities: + + if entity.entity.value != "" and entity.entity.value is not None: + for vec in entity.vectors: + self.vecstore.insert(vec, entity.entity.value) @staticmethod def add_args(parser): @@ -57,5 +59,5 @@ class Processor(Consumer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py index b918c10b..83861b54 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/pinecone/write.py @@ -60,76 +60,83 @@ class Processor(Consumer): self.last_index_name = None - def handle(self, msg): + def create_index(self, index_name, dim): + + self.pinecone.create_index( + name = index_name, + dimension = dim, + metric = "cosine", + spec = ServerlessSpec( + cloud = self.cloud, + region = self.region, + ) + ) + + for i in range(0, 1000): + + if self.pinecone.describe_index( + index_name + ).status["ready"]: + break + + time.sleep(1) + + if not self.pinecone.describe_index( + index_name + ).status["ready"]: + raise RuntimeError( + "Gave up waiting for index creation" + ) + + async def handle(self, msg): v = msg.value() id = str(uuid.uuid4()) - if v.entity.value == "" or v.entity.value is None: return + for entity in v.entities: - for vec in v.vectors: + if entity.entity.value == "" or entity.entity.value is None: + continue - dim = len(vec) + for vec in entity.vectors: - index_name = ( - "t-" + v.metadata.user + "-" + str(dim) - ) + dim = len(vec) - if index_name != self.last_index_name: + index_name = ( + "t-" + v.metadata.user + "-" + str(dim) + ) - if not self.pinecone.has_index(index_name): + if index_name != self.last_index_name: - try: + if not self.pinecone.has_index(index_name): - self.pinecone.create_index( - name = index_name, - dimension = dim, - metric = "cosine", - spec = ServerlessSpec( - cloud = self.cloud, - region = self.region, - ) - ) + try: - for i in range(0, 1000): + self.create_index(index_name, dim) - if self.pinecone.describe_index( - index_name - ).status["ready"]: - break + except Exception as e: + print("Pinecone index creation failed") + raise e - time.sleep(1) + print(f"Index {index_name} created", flush=True) - if not self.pinecone.describe_index( - index_name - ).status["ready"]: - raise RuntimeError( - "Gave up waiting for index creation" - ) + self.last_index_name = index_name - except Exception as e: - print("Pinecone index creation failed") - raise e + index = self.pinecone.Index(index_name) - print(f"Index {index_name} created", flush=True) + records = [ + { + "id": id, + "values": vec, + "metadata": { "entity": entity.entity.value }, + } + ] - self.last_index_name = index_name - - index = self.pinecone.Index(index_name) - - records = [ - { - "id": id, - "values": vec, - "metadata": { "entity": v.entity.value }, - } - ] - - index.upsert( - vectors = records, - namespace = v.metadata.collection, - ) + index.upsert( + vectors = records, + namespace = v.metadata.collection, + ) @staticmethod def add_args(parser): @@ -163,5 +170,5 @@ class Processor(Consumer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 47b53979..6b0d7371 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -26,6 +26,7 @@ class Processor(Consumer): input_queue = params.get("input_queue", default_input_queue) subscriber = params.get("subscriber", default_subscriber) store_uri = params.get("store_uri", default_store_uri) + api_key = params.get("api_key", None) super(Processor, self).__init__( **params | { @@ -33,56 +34,67 @@ class Processor(Consumer): "subscriber": subscriber, "input_schema": GraphEmbeddings, "store_uri": store_uri, + "api_key": api_key, } ) self.last_collection = None - self.client = QdrantClient(url=store_uri) + self.client = QdrantClient(url=store_uri, api_key=api_key) - def handle(self, msg): + def get_collection(self, dim, user, collection): + + cname = ( + "t_" + user + "_" + collection + "_" + str(dim) + ) + + if cname != self.last_collection: + + if not self.client.collection_exists(cname): + + try: + self.client.create_collection( + collection_name=cname, + vectors_config=VectorParams( + size=dim, distance=Distance.COSINE + ), + ) + except Exception as e: + print("Qdrant collection creation failed") + raise e + + self.last_collection = cname + + return cname + + async def handle(self, msg): v = msg.value() - if v.entity.value == "" or v.entity.value is None: return + for entity in v.entities: - for vec in v.vectors: + if entity.entity.value == "" or entity.entity.value is None: return - dim = len(vec) - collection = ( - "t_" + v.metadata.user + "_" + v.metadata.collection + "_" + - str(dim) - ) + for vec in entity.vectors: - if collection != self.last_collection: + dim = len(vec) - if not self.client.collection_exists(collection): + collection = self.get_collection( + dim, v.metadata.user, v.metadata.collection + ) - try: - self.client.create_collection( - collection_name=collection, - vectors_config=VectorParams( - size=dim, distance=Distance.COSINE - ), + self.client.upsert( + collection_name=collection, + points=[ + PointStruct( + id=str(uuid.uuid4()), + vector=vec, + payload={ + "entity": entity.entity.value, + } ) - except Exception as e: - print("Qdrant collection creation failed") - raise e - - self.last_collection = collection - - self.client.upsert( - collection_name=collection, - points=[ - PointStruct( - id=str(uuid.uuid4()), - vector=vec, - payload={ - "entity": v.entity.value, - } - ) - ] - ) + ] + ) @staticmethod def add_args(parser): @@ -96,8 +108,14 @@ class Processor(Consumer): default=default_store_uri, help=f'Qdrant store URI (default: {default_store_uri})' ) + + parser.add_argument( + '-k', '--api-key', + default=None, + help=f'Qdrant API key' + ) def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py index 468b357a..5490af97 100755 --- a/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py +++ b/trustgraph-flow/trustgraph/storage/object_embeddings/milvus/write.py @@ -34,7 +34,7 @@ class Processor(Consumer): self.vecstore = ObjectVectors(store_uri) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -57,5 +57,5 @@ class Processor(Consumer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py index d44864fe..e6536e6c 100755 --- a/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/rows/cassandra/write.py @@ -10,6 +10,7 @@ import argparse import time from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider +from ssl import SSLContext, PROTOCOL_TLSv1_2 from .... schema import Rows from .... schema import rows_store_queue @@ -17,6 +18,7 @@ from .... log_level import LogLevel from .... base import Consumer module = ".".join(__name__.split(".")[1:-1]) +ssl_context = SSLContext(PROTOCOL_TLSv1_2) default_input_queue = rows_store_queue default_subscriber = module @@ -29,6 +31,8 @@ class Processor(Consumer): input_queue = params.get("input_queue", default_input_queue) subscriber = params.get("subscriber", default_subscriber) graph_host = params.get("graph_host", default_graph_host) + graph_username = params.get("graph_username", None) + graph_password = params.get("graph_password", None) super(Processor, self).__init__( **params | { @@ -36,10 +40,16 @@ class Processor(Consumer): "subscriber": subscriber, "input_schema": Rows, "graph_host": graph_host, + "graph_username": graph_username, + "graph_password": graph_password, } ) - - self.cluster = Cluster(graph_host.split(",")) + + if graph_username and graph_password: + auth_provider = PlainTextAuthProvider(username=graph_username, password=graph_password) + self.cluster = Cluster(graph_host.split(","), auth_provider=auth_provider, ssl_context=ssl_context) + else: + self.cluster = Cluster(graph_host.split(",")) self.session = self.cluster.connect() self.tables = set() @@ -54,7 +64,7 @@ class Processor(Consumer): self.session.execute("use trustgraph"); - def handle(self, msg): + async def handle(self, msg): try: @@ -120,8 +130,20 @@ class Processor(Consumer): default="localhost", help=f'Graph host (default: localhost)' ) + + parser.add_argument( + '--graph-username', + default=None, + help=f'Cassandra username' + ) + + parser.add_argument( + '--graph-password', + default=None, + help=f'Cassandra password' + ) def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index e7078e08..17b5ae9a 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -28,6 +28,8 @@ class Processor(Consumer): input_queue = params.get("input_queue", default_input_queue) subscriber = params.get("subscriber", default_subscriber) graph_host = params.get("graph_host", default_graph_host) + graph_username = params.get("graph_username", None) + graph_password = params.get("graph_password", None) super(Processor, self).__init__( **params | { @@ -35,13 +37,17 @@ class Processor(Consumer): "subscriber": subscriber, "input_schema": Triples, "graph_host": graph_host, + "graph_username": graph_username, + "graph_password": graph_password, } ) - + self.graph_host = [graph_host] + self.username = graph_username + self.password = graph_password self.table = None - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -52,10 +58,17 @@ class Processor(Consumer): self.tg = None try: - self.tg = TrustGraph( - hosts=self.graph_host, - keyspace=v.metadata.user, table=v.metadata.collection, - ) + if self.username and self.password: + self.tg = TrustGraph( + hosts=self.graph_host, + keyspace=v.metadata.user, table=v.metadata.collection, + username=self.username, password=self.password + ) + else: + self.tg = TrustGraph( + hosts=self.graph_host, + keyspace=v.metadata.user, table=v.metadata.collection, + ) except Exception as e: print("Exception", e, flush=True) time.sleep(1) @@ -82,8 +95,20 @@ class Processor(Consumer): default="localhost", help=f'Graph host (default: localhost)' ) + + parser.add_argument( + '--graph-username', + default=None, + help=f'Cassandra username' + ) + + parser.add_argument( + '--graph-password', + default=None, + help=f'Cassandra password' + ) def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py index 3c7d1660..2d0ae38a 100755 --- a/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/falkordb/write.py @@ -118,7 +118,7 @@ class Processor(Consumer): time=res.run_time_ms )) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -154,5 +154,5 @@ class Processor(Consumer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py index f106170a..620e669e 100755 --- a/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/memgraph/write.py @@ -55,6 +55,14 @@ class Processor(Consumer): def create_indexes(self, session): + # Race condition, index creation failure is ignored. Right thing + # to do if the index already exists. Wrong thing to do if it's + # because the store is not up yet + + # In real-world cases, Memgraph will start up quicker than Pulsar + # and this process will restart several times until Pulsar arrives, + # so should be safe + print("Create indexes...", flush=True) try: @@ -197,7 +205,7 @@ class Processor(Consumer): src=t.s.value, dest=t.o.value, uri=t.p.value, ) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -248,5 +256,5 @@ class Processor(Consumer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index 1aa25aa8..3323f912 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -50,6 +50,50 @@ class Processor(Consumer): self.io = GraphDatabase.driver(graph_host, auth=(username, password)) + with self.io.session(database=self.db) as session: + self.create_indexes(session) + + def create_indexes(self, session): + + # Race condition, index creation failure is ignored. Right thing + # to do if the index already exists. Wrong thing to do if it's + # because the store is not up yet + + # In real-world cases, Neo4j will start up quicker than Pulsar + # and this process will restart several times until Pulsar arrives, + # so should be safe + + print("Create indexes...", flush=True) + + try: + session.run( + "CREATE INDEX Node_uri FOR (n:Node) ON (n.uri)", + ) + except Exception as e: + print(e, flush=True) + # Maybe index already exists + print("Index create failure ignored", flush=True) + + try: + session.run( + "CREATE INDEX Literal_value FOR (n:Literal) ON (n.value)", + ) + except Exception as e: + print(e, flush=True) + # Maybe index already exists + print("Index create failure ignored", flush=True) + + try: + session.run( + "CREATE INDEX Rel_uri FOR ()-[r:Rel]-() ON (r.uri)", + ) + except Exception as e: + print(e, flush=True) + # Maybe index already exists + print("Index create failure ignored", flush=True) + + print("Index creation done", flush=True) + def create_node(self, uri): print("Create node", uri) @@ -114,7 +158,7 @@ class Processor(Consumer): time=summary.result_available_after )) - def handle(self, msg): + async def handle(self, msg): v = msg.value() @@ -162,5 +206,5 @@ class Processor(Consumer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph-ocr/README.md b/trustgraph-ocr/README.md new file mode 100644 index 00000000..7a2ce130 --- /dev/null +++ b/trustgraph-ocr/README.md @@ -0,0 +1 @@ +See https://trustgraph.ai/ diff --git a/trustgraph-ocr/scripts/pdf-ocr b/trustgraph-ocr/scripts/pdf-ocr new file mode 100755 index 00000000..1417351f --- /dev/null +++ b/trustgraph-ocr/scripts/pdf-ocr @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.decoding.ocr import run + +run() + diff --git a/trustgraph-ocr/setup.py b/trustgraph-ocr/setup.py new file mode 100644 index 00000000..43e15061 --- /dev/null +++ b/trustgraph-ocr/setup.py @@ -0,0 +1,47 @@ +import setuptools +import os +import importlib + +with open("README.md", "r") as fh: + long_description = fh.read() + +# Load a version number module +spec = importlib.util.spec_from_file_location( + 'version', 'trustgraph/ocr_version.py' +) +version_module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(version_module) + +version = version_module.__version__ + +setuptools.setup( + name="trustgraph-ocr", + version=version, + author="trustgraph.ai", + author_email="security@trustgraph.ai", + description="TrustGraph provides a means to run a pipeline of flexible AI processing components in a flexible means to achieve a processing pipeline.", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/trustgraph-ai/trustgraph", + packages=setuptools.find_namespace_packages( + where='./', + ), + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)", + "Operating System :: OS Independent", + ], + python_requires='>=3.8', + download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", + install_requires=[ + "trustgraph-base>=0.21,<0.22", + "pulsar-client", + "prometheus-client", + "boto3", + "pdf2image", + "pytesseract", + ], + scripts=[ + "scripts/pdf-ocr", + ] +) diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/__init__.py b/trustgraph-ocr/trustgraph/decoding/ocr/__init__.py new file mode 100644 index 00000000..0d8d9c78 --- /dev/null +++ b/trustgraph-ocr/trustgraph/decoding/ocr/__init__.py @@ -0,0 +1,3 @@ + +from . pdf_decoder import * + diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/__main__.py b/trustgraph-ocr/trustgraph/decoding/ocr/__main__.py new file mode 100755 index 00000000..44dd026d --- /dev/null +++ b/trustgraph-ocr/trustgraph/decoding/ocr/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . pdf_decoder import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py new file mode 100755 index 00000000..f8926589 --- /dev/null +++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py @@ -0,0 +1,83 @@ + +""" +Simple decoder, accepts PDF documents on input, outputs pages from the +PDF document as text as separate output objects. +""" + +import tempfile +import base64 +import pytesseract +from pdf2image import convert_from_bytes + +from ... schema import Document, TextDocument, Metadata +from ... schema import document_ingest_queue, text_ingest_queue +from ... log_level import LogLevel +from ... base import ConsumerProducer + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = document_ingest_queue +default_output_queue = text_ingest_queue +default_subscriber = module + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": Document, + "output_schema": TextDocument, + } + ) + + print("PDF OCR inited") + + async def handle(self, msg): + + print("PDF message received") + + v = msg.value() + + print(f"Decoding {v.metadata.id}...", flush=True) + + blob = base64.b64decode(v.data) + + pages = convert_from_bytes(blob) + + for ix, page in enumerate(pages): + + try: + text = pytesseract.image_to_string(page, lang='eng') + except Exception as e: + print(f"Page did not OCR: {e}") + continue + + r = TextDocument( + metadata=v.metadata, + text=text.encode("utf-8"), + ) + + await self.send(r) + + print("Done.", flush=True) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + +def run(): + + Processor.launch(module, __doc__) + diff --git a/trustgraph-vertexai/setup.py b/trustgraph-vertexai/setup.py index 7f9c2923..1258fea9 100644 --- a/trustgraph-vertexai/setup.py +++ b/trustgraph-vertexai/setup.py @@ -34,7 +34,7 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base>=0.18,<0.19", + "trustgraph-base>=0.21,<0.22", "pulsar-client", "google-cloud-aiplatform", "prometheus-client", diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py index cb817836..4d38c8c0 100755 --- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py +++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py @@ -131,7 +131,7 @@ class Processor(ConsumerProducer): print("Initialisation complete", flush=True) - def handle(self, msg): + async def handle(self, msg): try: @@ -169,7 +169,7 @@ class Processor(ConsumerProducer): model=self.model ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) print("Done.", flush=True) @@ -178,25 +178,15 @@ class Processor(ConsumerProducer): except google.api_core.exceptions.ResourceExhausted as e: - print("Send rate limit response...", flush=True) + print("Hit rate limit:", e, flush=True) - r = TextCompletionResponse( - error=Error( - type = "rate-limit", - message = str(e), - ), - response=None, - in_token=None, - out_token=None, - model=None, - ) - - self.producer.send(r, properties={"id": id}) - - self.consumer.acknowledge(msg) + # Leave rate limit retries to the base handler + raise TooManyRequests() except Exception as e: + # Apart from rate limits, treat all exceptions as unrecoverable + print(f"Exception: {e}") print("Send error response...", flush=True) @@ -212,7 +202,7 @@ class Processor(ConsumerProducer): model=None, ) - self.producer.send(r, properties={"id": id}) + await self.send(r, properties={"id": id}) self.consumer.acknowledge(msg) @@ -258,5 +248,5 @@ class Processor(ConsumerProducer): def run(): - Processor.start(module, __doc__) + Processor.launch(module, __doc__) diff --git a/trustgraph/setup.py b/trustgraph/setup.py index a964ff06..d7185e66 100644 --- a/trustgraph/setup.py +++ b/trustgraph/setup.py @@ -34,12 +34,12 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base>=0.18,<0.19", - "trustgraph-bedrock>=0.18,<0.19", - "trustgraph-cli>=0.18,<0.19", - "trustgraph-embeddings-hf>=0.18,<0.19", - "trustgraph-flow>=0.18,<0.19", - "trustgraph-vertexai>=0.18,<0.19", + "trustgraph-base>=0.21,<0.22", + "trustgraph-bedrock>=0.21,<0.22", + "trustgraph-cli>=0.21,<0.22", + "trustgraph-embeddings-hf>=0.21,<0.22", + "trustgraph-flow>=0.21,<0.22", + "trustgraph-vertexai>=0.21,<0.22", ], scripts=[ ]