diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml new file mode 100644 index 00000000..6080b661 --- /dev/null +++ b/.github/workflows/pull-request.yaml @@ -0,0 +1,20 @@ + +name: Test pull request + +on: + pull_request: + +permissions: + contents: read + +jobs: + + container-push: + + name: Do nothing + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v3 + diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 00000000..5ca3b735 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,86 @@ + +name: Build + +on: + workflow_dispatch: + push: + tags: + - v0.15.* + +permissions: + contents: read + +jobs: + + deploy: + + name: Build everything + runs-on: ubuntu-latest + permissions: + contents: write + id-token: write + environment: + name: release + + steps: + + - name: Checkout + uses: actions/checkout@v3 + + - name: Log in to Docker Hub + uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_SECRET }} + + - name: Install build dependencies + run: pip3 install jsonnet + + - name: Get version + id: version + run: echo VERSION=$(git describe --exact-match --tags | sed 's/^v//') >> $GITHUB_OUTPUT + + - run: echo ${{ steps.version.outputs.VERSION }} + + - name: Build packages + run: make packages VERSION=${{ steps.version.outputs.VERSION }} + + - name: Publish release distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + - name: Create deploy bundle + run: templates/generate-all deploy.zip ${{ steps.version.outputs.VERSION }} + + - uses: ncipollo/release-action@v1 + with: + artifacts: deploy.zip + generateReleaseNotes: true + makeLatest: false + prerelease: true + skipIfReleaseExists: true + + - name: Build container + run: make container VERSION=${{ steps.version.outputs.VERSION }} + + - name: Extract metadata for container + id: meta + uses: docker/metadata-action@v4 + with: + images: trustgraph/trustgraph-flow + tags: | + type=ref,event=branch + type=ref,event=pr + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha + + - name: Build and push Docker image + id: push + uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671 + with: + context: . + file: ./Containerfile + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + diff --git a/Containerfile b/Containerfile index e8daa861..0d6d357b 100644 --- a/Containerfile +++ b/Containerfile @@ -13,7 +13,7 @@ RUN dnf install -y python3 python3-pip python3-wheel python3-aiohttp \ RUN pip3 install torch --index-url https://download.pytorch.org/whl/cpu -RUN pip3 install anthropic boto3 cohere openai google-cloud-aiplatform ollama \ +RUN pip3 install anthropic boto3 cohere openai google-cloud-aiplatform ollama google-generativeai \ langchain langchain-core langchain-huggingface langchain-text-splitters \ langchain-community pymilvus sentence-transformers transformers \ huggingface-hub pulsar-client cassandra-driver pyarrow pyyaml \ diff --git a/Makefile b/Makefile index 6e78df46..0fb4b175 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ # VERSION=$(shell git describe | sed 's/^v//') -VERSION=0.11.20 + +VERSION=0.0.0 DOCKER=podman @@ -35,6 +36,7 @@ CONTAINER=docker.io/trustgraph/trustgraph-flow update-package-versions: mkdir -p trustgraph-cli/trustgraph + mkdir -p trustgraph/trustgraph echo __version__ = \"${VERSION}\" > trustgraph-base/trustgraph/base_version.py echo __version__ = \"${VERSION}\" > trustgraph-flow/trustgraph/flow_version.py echo __version__ = \"${VERSION}\" > trustgraph-vertexai/trustgraph/vertexai_version.py diff --git a/docs/README.agent-demo b/docs/README.agent-demo new file mode 100644 index 00000000..491755c3 --- /dev/null +++ b/docs/README.agent-demo @@ -0,0 +1,18 @@ +podman-compose -f docker-compose.yaml up -d + + +tg-processor-state + +tg-load-text --keyword cats animals home-life --name "Mark's cats" --description "This document describes Mark's cats" --copyright-notice 'Public domain' --publication-organization 'trustgraph.ai' --publication-date 2024-10-23 --copyright-holder 'trustgraph.ai' --copyright-year 2024 --publication-description 'Uploading to Github' --url https://example.com --id TG-000001 ../trustgraph/README.cats + +tg-load-text --keyword nasa challenger space-shuttle shuttle orbiter --name 'Challenger Report Volume 1' --description 'The findings of the Presidential Commission regarding the circumstances surrounding the Challenger accident are reported and recommendations for corrective action are outlined' --copyright-notice 'Work of the US Gov. Public Use Permitted' --publication-organization 'NASA' --publication-date 1986-06-06 --copyright-holder 'US Government' --copyright-year 1986 --publication-description 'The findings of the Commission regarding the circumstances surrounding the Challenger accident are reported' --url https://ntrs.nasa.gov/citations/19860015255 --id AD-A171402 ../trustgraph/README.challenger + + +tg-graph-show + +tg-query-graph-rag -q 'Tell me cat facts' + + +tg-invoke-agent -v -q "How many cats does Mark have? Calculate that number raised to 0.4 power. Is that number lower than the numeric part of the mission identifier of the Space Shuttle Challenger on its last mission? If so, give me an apple pie recipe, otherwise return a poem about cheese." + + diff --git a/docs/README.cats b/docs/README.cats new file mode 100644 index 00000000..9d4f4e91 --- /dev/null +++ b/docs/README.cats @@ -0,0 +1,35 @@ + +My name is Mark. + +I have 2 cats: +- Fred is a big, fat, orange, stripy cat. He is 12 years old and has 4 legs. +- Hope is a small, black cat. She is 7 years old and also has 4 legs. + +Fred has 4 legs. + +Hope has 4 legs. + +Fred and Hope are nice animals, but occasionally they fight. + +Fred is lazy and sleeps a lot. Hope is energetic, runs around a lot and +climbs trees. + +Both cats have tails and whiskers like all cats do. + +Cats have the species name Felis catus. + +The cat (Felis catus), also referred to as domestic cat or house cat, is a +small domesticated carnivorous mammal. It is the only domesticated species of +the family Felidae. Advances in archaeology and genetics have shown that the +domestication of the cat occurred in the Near East around 7500 BC. It is +commonly kept as a pet and farm cat, but also ranges freely as a feral cat +avoiding human contact. Valued by humans for companionship and its ability to +kill vermin, the cat's retractable claws are adapted to killing small prey +like mice and rats. It has a strong, flexible body, quick reflexes, and sharp +teeth, and its night vision and sense of smell are well developed. It is a +social species, but a solitary hunter and a crepuscular predator. Cat +communication includes vocalizations—including meowing, purring, trilling, +hissing, growling, and grunting–as well as body language. It can hear sounds +too faint or too high in frequency for human ears, such as those made by small +mammals. It secretes and perceives pheromones. + diff --git a/docs/README.challenger b/docs/README.challenger new file mode 100644 index 00000000..2e594ff2 --- /dev/null +++ b/docs/README.challenger @@ -0,0 +1,54 @@ + +On January 28, 1986, the Space Shuttle Challenger broke apart 73 seconds into +its flight, killing all seven crew members aboard. The spacecraft +disintegrated 46,000 feet (14 km) above the Atlantic Ocean, off the coast of +Cape Canaveral, Florida, at 11:39 a.m. EST (16:39 UTC). It was the first fatal +accident involving an American spacecraft while in flight. + +The mission, designated STS-51-L, was the 10th flight for the orbiter and the +25th flight of the Space Shuttle fleet. The crew was scheduled to deploy a +communications satellite and study Halley's Comet while they were in orbit, in +addition to taking schoolteacher Christa McAuliffe into space under the +Teacher In Space program. The latter task resulted in a higher-than-usual +media interest in and coverage of the mission; the launch and subsequent +disaster were seen live in many schools across the United States. + +The cause of the disaster was the failure of the primary and secondary O-ring +seals in a joint in the shuttle's right solid rocket booster (SRB). The +record-low temperatures on the morning of the launch had stiffened the rubber +O-rings, reducing their ability to seal the joints. Shortly after liftoff, the +seals were breached, and hot pressurized gas from within the SRB leaked +through the joint and burned through the aft attachment strut connecting it to +the external propellant tank (ET), then into the tank itself. The collapse of +the ET's internal structures and the rotation of the SRB that followed threw +the shuttle stack, traveling at a speed of Mach 1.92, into a direction that +allowed aerodynamic forces to tear the orbiter apart. Both SRBs detached from +the now-destroyed ET and continued to fly uncontrollably until the range +safety officer destroyed them. + +The crew compartment, human remains, and many other fragments from the shuttle +were recovered from the ocean floor after a three-month search-and-recovery +operation. The exact timing of the deaths of the crew is unknown, but several +crew members are thought to have survived the initial breakup of the +spacecraft. The orbiter had no escape system, and the impact of the crew +compartment at terminal velocity with the ocean surface was too violent to be +survivable. + +The disaster resulted in a 32-month hiatus in the Space Shuttle +program. President Ronald Reagan created the Rogers Commission to investigate +the accident. The commission criticized NASA's organizational culture and +decision-making processes that had contributed to the accident. Test data +since 1977 demonstrated a potentially catastrophic flaw in the SRBs' O-rings, +but neither NASA nor SRB manufacturer Morton Thiokol had addressed this known +defect. NASA managers also disregarded engineers' warnings about the dangers +of launching in cold temperatures and did not report these technical concerns +to their superiors. + +As a result of this disaster, NASA established the Office of Safety, +Reliability, and Quality Assurance, and arranged for deployment of commercial +satellites from expendable launch vehicles rather than from a crewed +orbiter. To replace Challenger, the construction of a new Space Shuttle +orbiter, Endeavour, was approved in 1987, and the new orbiter first flew in +1992. Subsequent missions were launched with redesigned SRBs and their crews +wore pressurized suits during ascent and reentry. + diff --git a/docs/README.quickstart-docker-compose.md b/docs/README.quickstart-docker-compose.md index 5e2fa237..a81da9bc 100644 --- a/docs/README.quickstart-docker-compose.md +++ b/docs/README.quickstart-docker-compose.md @@ -186,7 +186,7 @@ To change the `Ollama` model, first make sure the desired model has been pulled ### OpenAI API ``` -export OPENAI_KEY= +export OPENAI_TOKEN= docker compose -f tg-launch-openai-cassandra.yaml up -d # Using Cassandra as the graph store docker compose -f tg-launch-openai-neo4j.yaml up -d # Using Neo4j as the graph store ``` @@ -458,4 +458,4 @@ docker compose -f tg-launch--.yaml down -v > To confirm all Docker volumes have been removed, check that the following list is empty: > ``` > docker volume ls -> ``` \ No newline at end of file +> ``` diff --git a/schema.ttl b/schema.ttl new file mode 100644 index 00000000..91e1a394 --- /dev/null +++ b/schema.ttl @@ -0,0 +1,32 @@ +@prefix ns1: . +@prefix rdf: . +@prefix rdfs: . +@prefix schema: . +@prefix skos: . + +schema:subjectOf rdfs:label "subject of" . +skos:definition rdfs:label "definition" . + +rdf:type rdfs:label "type" . + +schema:DigitalDocument rdfs:label "digital document" . +schema:Organization rdfs:label "organization" . +schema:PublicationEvent rdfs:label "publication event" . + +schema:copyrightNotice rdfs:label "copyright notice" . +schema:copyrightHolder rdfs:label "copyright holder" . +schema:copyrightYear rdfs:label "copyright year" . +schema:license rdfs:label "license" . +schema:publication rdfs:label "publication" . +schema:startDate rdfs:label "start date" . +schema:endDate rdfs:label "end date" . +schema:publishedBy rdfs:label "published by" . +schema:datePublished rdfs:label "date published" . +schema:publication rdfs:label "publication" . +schema:datePublished rdfs:label "date published" . +schema:url rdfs:label "url" . +schema:identifier rdfs:label "identifier" . +schema:keywords rdfs:label "keyword" . + +skos:definition rdfs:label "definition" . + diff --git a/templates/all-patterns.jsonnet b/templates/all-patterns.jsonnet index 29384d1e..47622939 100644 --- a/templates/all-patterns.jsonnet +++ b/templates/all-patterns.jsonnet @@ -7,6 +7,7 @@ import "patterns/triple-store-neo4j.jsonnet", import "patterns/graph-rag.jsonnet", import "patterns/llm-azure.jsonnet", + import "patterns/llm-azure-openai.jsonnet", import "patterns/llm-bedrock.jsonnet", import "patterns/llm-claude.jsonnet", import "patterns/llm-cohere.jsonnet", diff --git a/templates/components.jsonnet b/templates/components.jsonnet index 8ed2da0e..ec7f862b 100644 --- a/templates/components.jsonnet +++ b/templates/components.jsonnet @@ -1,11 +1,13 @@ { "azure": import "components/azure.jsonnet", + "azure-openai": import "components/azure-openai.jsonnet", "bedrock": import "components/bedrock.jsonnet", "claude": import "components/claude.jsonnet", "cohere": import "components/cohere.jsonnet", "document-rag": import "components/document-rag.jsonnet", "embeddings-hf": import "components/embeddings-hf.jsonnet", "embeddings-ollama": import "components/embeddings-ollama.jsonnet", + "googleaistudio": import "components/googleaistudio.jsonnet", "grafana": import "components/grafana.jsonnet", "graph-rag": import "components/graph-rag.jsonnet", "triple-store-cassandra": import "components/cassandra.jsonnet", @@ -14,11 +16,10 @@ "ollama": import "components/ollama.jsonnet", "openai": import "components/openai.jsonnet", "override-recursive-chunker": import "components/chunker-recursive.jsonnet", - "prompt-template-definitions": import "components/null.jsonnet", - "prompt-template-document-query": import "components/null.jsonnet", - "prompt-template-kq-query": import "components/null.jsonnet", - "prompt-template-relationships": import "components/null.jsonnet", - "prompt-template-rows-template": import "components/null.jsonnet", + + "prompt-template": import "components/prompt-template.jsonnet", + "prompt-overrides": import "components/prompt-overrides.jsonnet", + "pulsar": import "components/pulsar.jsonnet", "pulsar-manager": import "components/pulsar-manager.jsonnet", "trustgraph-base": import "components/trustgraph.jsonnet", @@ -27,6 +28,8 @@ "vertexai": import "components/vertexai.jsonnet", "null": {}, + "agent-manager-react": import "components/agent-manager-react.jsonnet", + // FIXME: Dupes "cassandra": import "components/cassandra.jsonnet", "neo4j": import "components/neo4j.jsonnet", diff --git a/templates/components/agent-manager-react.jsonnet b/templates/components/agent-manager-react.jsonnet new file mode 100644 index 00000000..a995dba5 --- /dev/null +++ b/templates/components/agent-manager-react.jsonnet @@ -0,0 +1,60 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; +local default_prompts = import "prompts/default-prompts.jsonnet"; + +{ + + tools:: [], + + "agent-manager" +: { + + create:: function(engine) + + local container = + engine.container("agent-manager") + .with_image(images.trustgraph) + .with_command([ + "agent-manager-react", + "-p", + url.pulsar, + "--tool-type", + ] + [ + tool.id + "=" + tool.type + for tool in $.tools + ] + [ + "--tool-description" + ] + [ + tool.id + "=" + tool.description + for tool in $.tools + ] + [ + "--tool-argument" + ] + [ + "%s=%s:%s:%s" % [ + tool.id, arg.name, arg.type, arg.description + ] + for tool in $.tools + for arg in tool.arguments + ] + ) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSet = engine.containers( + "agent-manager", [ container ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + containerSet, + service, + ]) + + }, + +} + default_prompts + diff --git a/templates/components/azure-openai.jsonnet b/templates/components/azure-openai.jsonnet new file mode 100644 index 00000000..cc3847c0 --- /dev/null +++ b/templates/components/azure-openai.jsonnet @@ -0,0 +1,84 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + "azure-openai-model":: "GPT-3.5-Turbo", + "azure-openai-max-output-tokens":: 4192, + "azure-openai-temperature":: 0.0, + + "text-completion" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("azure-openai-credentials") + .with_env_var("AZURE_TOKEN", "azure-token"); + + local container = + engine.container("text-completion") + .with_image(images.trustgraph) + .with_command([ + "text-completion-azure-openai", + "-p", + url.pulsar, + "-m", + $["azure-openai-model"], + "-x", + std.toString($["azure-openai-max-output-tokens"]), + "-t", + "%0.3f" % $["azure-openai-temperature"], + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph) + .with_command([ + "text-completion-azure", + "-p", + url.pulsar, + "-x", + std.toString($["azure-openai-max-output-tokens"]), + "-t", + "%0.3f" % $["azure-openai-temperature"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag-response", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSet = engine.containers( + "text-completion", [ container ] + ); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8000, 8000, "metrics"); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + envSecrets, + containerSet, + containerSetRag, + service, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/azure.jsonnet b/templates/components/azure.jsonnet index 3ee819ee..82b79133 100644 --- a/templates/components/azure.jsonnet +++ b/templates/components/azure.jsonnet @@ -5,8 +5,6 @@ local prompts = import "prompts/mixtral.jsonnet"; { - "azure-token":: "${AZURE_TOKEN}", - "azure-endpoint":: "${AZURE_ENDPOINT}", "azure-max-output-tokens":: 4096, "azure-temperature":: 0.0, @@ -14,6 +12,10 @@ local prompts = import "prompts/mixtral.jsonnet"; create:: function(engine) + local envSecrets = engine.envSecrets("azure-credentials") + .with_env_var("AZURE_TOKEN", "azure-token") + .with_env_var("AZURE_ENDPOINT", "azure-endpoint"); + local container = engine.container("text-completion") .with_image(images.trustgraph) @@ -21,15 +23,32 @@ local prompts = import "prompts/mixtral.jsonnet"; "text-completion-azure", "-p", url.pulsar, - "-k", - $["azure-token"], - "-e", - $["azure-endpoint"], "-x", std.toString($["azure-max-output-tokens"]), "-t", - std.toString($["azure-temperature"]), + "%0.3f" % $["azure-temperature"], ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph) + .with_command([ + "text-completion-azure", + "-p", + url.pulsar, + "-x", + std.toString($["azure-max-output-tokens"]), + "-t", + "%0.3f" % $["azure-temperature"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag-response", + ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); @@ -37,57 +56,25 @@ local prompts = import "prompts/mixtral.jsonnet"; "text-completion", [ container ] ); - local service = - engine.internalService(containerSet) - .with_port(8000, 8000, "metrics"); - - engine.resources([ - containerSet, - service, - ]) - - }, - - "text-completion-rag" +: { - - create:: function(engine) - - local container = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-azure", - "-p", - url.pulsar, - "-k", - $["azure-token"], - "-e", - $["azure-endpoint"], - "-x", - std.toString($["azure-max-output-tokens"]), - "-t", - std.toString($["azure-temperature"]), - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag-response", - ]) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - - local containerSet = engine.containers( - "text-completion-rag", [ container ] + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] ); local service = engine.internalService(containerSet) .with_port(8000, 8000, "metrics"); - engine.resources([ - containerSet, - service, - ]) + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + engine.resources([ + envSecrets, + containerSet, + containerSetRag, + service, + serviceRag, + ]) } diff --git a/templates/components/bedrock.jsonnet b/templates/components/bedrock.jsonnet index 1c375621..93978a59 100644 --- a/templates/components/bedrock.jsonnet +++ b/templates/components/bedrock.jsonnet @@ -6,9 +6,6 @@ local chunker = import "chunker-recursive.jsonnet"; { - "aws-id-key":: "${AWS_ID_KEY}", - "aws-secret-key":: "${AWS_SECRET_KEY}", - "aws-region":: "us-west-2", "bedrock-max-output-tokens":: 4096, "bedrock-temperature":: 0.0, "bedrock-model":: "mistral.mixtral-8x7b-instruct-v0:1", @@ -17,6 +14,11 @@ local chunker = import "chunker-recursive.jsonnet"; create:: function(engine) + local envSecrets = engine.envSecrets("bedrock-credentials") + .with_env_var("AWS_ID_KEY", "aws-id-key") + .with_env_var("AWS_SECRET", "aws-secret") + .with_env_var("AWS_REGION", "aws-region"); + local container = engine.container("text-completion") .with_image(images.trustgraph) @@ -24,58 +26,28 @@ local chunker = import "chunker-recursive.jsonnet"; "text-completion-bedrock", "-p", url.pulsar, - "-z", - $["aws-id-key"], - "-k", - $["aws-secret-key"], - "-r", - $["aws-region"], "-x", std.toString($["bedrock-max-output-tokens"]), "-t", - std.toString($["bedrock-temperature"]), + "%0.3f" % $["bedrock-temperature"], "-m", $["bedrock-model"], ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerSet = engine.containers( - "text-completion", [ container ] - ); - - local service = - engine.internalService(containerSet) - .with_port(8000, 8000, "metrics"); - - engine.resources([ - containerSet, - service, - ]) - - }, - - "text-completion-rag" +: { - - create:: function(engine) - - local container = + local containerRag = engine.container("text-completion-rag") .with_image(images.trustgraph) .with_command([ "text-completion-bedrock", "-p", url.pulsar, - "-z", - $["aws-id-key"], - "-k", - $["aws-secret-key"], - "-r", - $["aws-region"], "-x", std.toString($["bedrock-max-output-tokens"]), "-t", - std.toString($["bedrock-temperature"]), + "%0.3f" % $["bedrock-temperature"], "-m", $["bedrock-model"], "-i", @@ -83,24 +55,35 @@ local chunker = import "chunker-recursive.jsonnet"; "-o", "non-persistent://tg/response/text-completion-rag-response", ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); local containerSet = engine.containers( - "text-completion-rag", [ container ] + "text-completion", [ container ] + ); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] ); local service = engine.internalService(containerSet) .with_port(8000, 8000, "metrics"); + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + engine.resources([ + envSecrets, containerSet, + containerSetRag, service, + serviceRag, ]) - - } + }, } + prompts + chunker diff --git a/templates/components/cassandra.jsonnet b/templates/components/cassandra.jsonnet index b9345fed..b52d4b04 100644 --- a/templates/components/cassandra.jsonnet +++ b/templates/components/cassandra.jsonnet @@ -24,7 +24,7 @@ cassandra + { .with_reservations("0.1", "128M"); local containerSet = engine.containers( - "stop-triples", [ container ] + "store-triples", [ container ] ); local service = diff --git a/templates/components/claude.jsonnet b/templates/components/claude.jsonnet index 0cd190d4..c6c94e21 100644 --- a/templates/components/claude.jsonnet +++ b/templates/components/claude.jsonnet @@ -5,7 +5,6 @@ local prompts = import "prompts/mixtral.jsonnet"; { - "claude-key":: "${CLAUDE_KEY}", "claude-max-output-tokens":: 4096, "claude-temperature":: 0.0, @@ -13,6 +12,9 @@ local prompts = import "prompts/mixtral.jsonnet"; create:: function(engine) + local envSecrets = engine.envSecrets("claude-credentials") + .with_env_var("CLAUDE_KEY_TOKEN", "claude-key"); + local container = engine.container("text-completion") .with_image(images.trustgraph) @@ -20,13 +22,32 @@ local prompts = import "prompts/mixtral.jsonnet"; "text-completion-claude", "-p", url.pulsar, - "-k", - $["claude-key"], "-x", std.toString($["claude-max-output-tokens"]), "-t", - std.toString($["claude-temperature"]), + "%0.3f" % $["claude-temperature"], ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph) + .with_command([ + "text-completion-claude", + "-p", + url.pulsar, + "-x", + std.toString($["claude-max-output-tokens"]), + "-t", + "%0.3f" % $["claude-temperature"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag-response", + ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); @@ -34,57 +55,27 @@ local prompts = import "prompts/mixtral.jsonnet"; "text-completion", [ container ] ); - local service = - engine.internalService(containerSet) - .with_port(8000, 8000, "metrics"); - - engine.resources([ - containerSet, - service, - ]) - - }, - - "text-completion-rag" +: { - - create:: function(engine) - - local container = - engine.container("text-completion-rag") - .with_image(images.trustgraph) - .with_command([ - "text-completion-claude", - "-p", - url.pulsar, - "-k", - $["claude-key"], - "-x", - std.toString($["claude-max-output-tokens"]), - "-t", - std.toString($["claude-temperature"]), - "-i", - "non-persistent://tg/request/text-completion-rag", - "-o", - "non-persistent://tg/response/text-completion-rag-response", - ]) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - - local containerSet = engine.containers( - "text-completion-rag", [ container ] + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] ); local service = engine.internalService(containerSet) .with_port(8000, 8000, "metrics"); + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + engine.resources([ + envSecrets, containerSet, + containerSetRag, service, + serviceRag, ]) - - } + }, } + prompts diff --git a/templates/components/cohere.jsonnet b/templates/components/cohere.jsonnet index f05cb635..11c30fbd 100644 --- a/templates/components/cohere.jsonnet +++ b/templates/components/cohere.jsonnet @@ -9,13 +9,15 @@ local prompts = import "prompts/mixtral.jsonnet"; "chunk-size":: 150, "chunk-overlap":: 10, - "cohere-key":: "${COHERE_KEY}", "cohere-temperature":: 0.0, "text-completion" +: { create:: function(engine) + local envSecrets = engine.envSecrets("cohere-credentials") + .with_env_var("COHERE_KEY", "cohere-key"); + local container = engine.container("text-completion") .with_image(images.trustgraph) @@ -23,44 +25,21 @@ local prompts = import "prompts/mixtral.jsonnet"; "text-completion-cohere", "-p", url.pulsar, - "-k", - $["cohere-key"], "-t", - std.toString($["cohere-temperature"]), + "%0.3f" % $["cohere-temperature"], ]) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerSet = engine.containers( - "text-completion", [ container ] - ); - - local service = - engine.internalService(containerSet) - .with_port(8000, 8000, "metrics"); - - engine.resources([ - containerSet, - service, - ]) - - }, - - "text-completion-rag" +: { - - create:: function(engine) - - local container = + local containerRag = engine.container("text-completion-rag") .with_image(images.trustgraph) .with_command([ "text-completion-cohere", "-p", url.pulsar, - "-k", - $["cohere-key"], "-t", - std.toString($["cohere-temperature"]), + "%0.3f" % $["cohere-temperature"], "-i", "non-persistent://tg/request/text-completion-rag", "-o", @@ -70,20 +49,30 @@ local prompts = import "prompts/mixtral.jsonnet"; .with_reservations("0.1", "128M"); local containerSet = engine.containers( - "text-completion-rag", [ container ] + "text-completion", [ container ] + ); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] ); local service = engine.internalService(containerSet) .with_port(8000, 8000, "metrics"); + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + engine.resources([ + envSecrets, containerSet, + containerSetRag, service, + serviceRag, ]) - - } + }, } + prompts diff --git a/templates/components/googleaistudio.jsonnet b/templates/components/googleaistudio.jsonnet new file mode 100644 index 00000000..b6ee1d85 --- /dev/null +++ b/templates/components/googleaistudio.jsonnet @@ -0,0 +1,86 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; + +{ + + "googleaistudio-max-output-tokens":: 4096, + "googleaistudio-temperature":: 0.0, + "googleaistudio-model":: "gemini-1.5-flash-002", + + "text-completion" +: { + + create:: function(engine) + + local envSecrets = engine.envSecrets("bedrock-credentials") + .with_env_var("GOOGLE_AI_STUDIO_KEY", "googleaistudio-key"); + + local container = + engine.container("text-completion") + .with_image(images.trustgraph) + .with_command([ + "text-completion-googleaistudio", + "-p", + url.pulsar, + "-x", + std.toString($["googleaistudio-max-output-tokens"]), + "-t", + "%0.3f" % $["googleaistudio-temperature"], + "-m", + $["googleaistudio-model"], + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerRag = + engine.container("text-completion-rag") + .with_image(images.trustgraph) + .with_command([ + "text-completion-googleaistudio", + "-p", + url.pulsar, + "-x", + std.toString($["googleaistudio-max-output-tokens"]), + "-t", + "%0.3f" % $["googleaistudio-temperature"], + "-m", + $["googleaistudio-model"], + "-i", + "non-persistent://tg/request/text-completion-rag", + "-o", + "non-persistent://tg/response/text-completion-rag-response", + ]) + .with_env_var_secrets(envSecrets) + .with_limits("0.5", "128M") + .with_reservations("0.1", "128M"); + + local containerSet = engine.containers( + "text-completion", [ container ] + ); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8000, 8000, "metrics"); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8000, 8000, "metrics"); + + engine.resources([ + envSecrets, + containerSet, + containerSetRag, + service, + serviceRag, + ]) + + }, + +} + prompts + diff --git a/templates/components/llamafile.jsonnet b/templates/components/llamafile.jsonnet index 93163a14..d51cda61 100644 --- a/templates/components/llamafile.jsonnet +++ b/templates/components/llamafile.jsonnet @@ -6,12 +6,14 @@ local prompts = import "prompts/slm.jsonnet"; { "llamafile-model":: "LLaMA_CPP", - "llamafile-url":: "${LLAMAFILE_URL}", "text-completion" +: { create:: function(engine) + local envSecrets = engine.envSecrets("llamafile-credentials") + .with_env_var("LLAMAFILE_URL", "llamafile-url"); + local container = engine.container("text-completion") .with_image(images.trustgraph) @@ -21,27 +23,12 @@ local prompts = import "prompts/slm.jsonnet"; url.pulsar, "-m", $["llamafile-model"], - "-r", - $["llamafile-url"], ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerSet = engine.containers( - "text-completion", [ container ] - ); - - engine.resources([ - containerSet, - ]) - - }, - - "text-completion-rag" +: { - - create:: function(engine) - - local container = + local containerRag = engine.container("text-completion-rag") .with_image(images.trustgraph) .with_command([ @@ -50,26 +37,40 @@ local prompts = import "prompts/slm.jsonnet"; url.pulsar, "-m", $["llamafile-model"], - "-r", - $["llamafile-url"], "-i", "non-persistent://tg/request/text-completion-rag", "-o", "non-persistent://tg/response/text-completion-rag-response", ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); local containerSet = engine.containers( - "text-completion-rag", [ container ] + "text-completion", [ container ] ); + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] + ); + + local service = + engine.internalService(containerSet) + .with_port(8080, 8080, "metrics"); + + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8080, 8080, "metrics"); + engine.resources([ + envSecrets, containerSet, + containerSetRag, + service, + serviceRag, ]) - - } + }, } + prompts diff --git a/templates/components/ollama.jsonnet b/templates/components/ollama.jsonnet index b0507cef..2ae696b4 100644 --- a/templates/components/ollama.jsonnet +++ b/templates/components/ollama.jsonnet @@ -1,17 +1,19 @@ local base = import "base/base.jsonnet"; local images = import "values/images.jsonnet"; local url = import "values/url.jsonnet"; -local prompts = import "prompts/slm.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; { "ollama-model":: "gemma2:9b", - "ollama-url":: "${OLLAMA_HOST}", "text-completion" +: { create:: function(engine) + local envSecrets = engine.envSecrets("ollama-credentials") + .with_env_var("OLLAMA_HOST", "ollama-host"); + local container = engine.container("text-completion") .with_image(images.trustgraph) @@ -21,32 +23,12 @@ local prompts = import "prompts/slm.jsonnet"; url.pulsar, "-m", $["ollama-model"], - "-r", - $["ollama-url"], ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerSet = engine.containers( - "text-completion", [ container ] - ); - - local service = - engine.internalService(containerSet) - .with_port(8080, 8080, "metrics"); - - engine.resources([ - containerSet, - service, - ]) - - }, - - "text-completion-rag" +: { - - create:: function(engine) - - local container = + local containerRag = engine.container("text-completion-rag") .with_image(images.trustgraph) .with_command([ @@ -55,31 +37,40 @@ local prompts = import "prompts/slm.jsonnet"; url.pulsar, "-m", $["ollama-model"], - "-r", - $["ollama-url"], "-i", "non-persistent://tg/request/text-completion-rag", "-o", "non-persistent://tg/response/text-completion-rag-response", ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); local containerSet = engine.containers( - "text-completion-rag", [ container ] + "text-completion", [ container ] + ); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] ); local service = engine.internalService(containerSet) .with_port(8080, 8080, "metrics"); + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8080, 8080, "metrics"); + engine.resources([ + envSecrets, containerSet, + containerSetRag, service, + serviceRag, ]) - - } + }, } + prompts diff --git a/templates/components/openai.jsonnet b/templates/components/openai.jsonnet index 3d1a2b73..83cbd406 100644 --- a/templates/components/openai.jsonnet +++ b/templates/components/openai.jsonnet @@ -5,7 +5,6 @@ local prompts = import "prompts/mixtral.jsonnet"; { - "openai-key":: "${OPENAI_KEY}", "openai-max-output-tokens":: 4096, "openai-temperature":: 0.0, "openai-model":: "GPT-3.5-Turbo", @@ -14,6 +13,9 @@ local prompts = import "prompts/mixtral.jsonnet"; create:: function(engine) + local envSecrets = engine.envSecrets("openai-credentials") + .with_env_var("OPENAI_TOKEN", "openai-token"); + local container = engine.container("text-completion") .with_image(images.trustgraph) @@ -21,50 +23,28 @@ local prompts = import "prompts/mixtral.jsonnet"; "text-completion-openai", "-p", url.pulsar, - "-k", - $["openai-key"], "-x", std.toString($["openai-max-output-tokens"]), "-t", - std.toString($["openai-temperature"]), + "%0.3f" % $["openai-temperature"], "-m", $["openai-model"], ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); - local containerSet = engine.containers( - "text-completion", [ container ] - ); - - local service = - engine.internalService(containerSet) - .with_port(8080, 8080, "metrics"); - - engine.resources([ - containerSet, - service, - ]) - - }, - - "text-completion-rag" +: { - - create:: function(engine) - - local container = + local containerRag = engine.container("text-completion-rag") .with_image(images.trustgraph) .with_command([ "text-completion-openai", "-p", url.pulsar, - "-k", - $["openai-key"], "-x", std.toString($["openai-max-output-tokens"]), "-t", - std.toString($["openai-temperature"]), + "%0.3f" % $["openai-temperature"], "-m", $["openai-model"], "-i", @@ -72,24 +52,35 @@ local prompts = import "prompts/mixtral.jsonnet"; "-o", "non-persistent://tg/response/text-completion-rag-response", ]) + .with_env_var_secrets(envSecrets) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); local containerSet = engine.containers( - "text-completion-rag", [ container ] + "text-completion", [ container ] + ); + + local containerSetRag = engine.containers( + "text-completion-rag", [ containerRag ] ); local service = engine.internalService(containerSet) .with_port(8080, 8080, "metrics"); + local serviceRag = + engine.internalService(containerSetRag) + .with_port(8080, 8080, "metrics"); + engine.resources([ + envSecrets, containerSet, + containerSetRag, service, + serviceRag, ]) - - } + }, } + prompts diff --git a/templates/components/prompt-generic.jsonnet b/templates/components/prompt-generic.jsonnet deleted file mode 100644 index 5d6d7c54..00000000 --- a/templates/components/prompt-generic.jsonnet +++ /dev/null @@ -1,81 +0,0 @@ -local base = import "base/base.jsonnet"; -local images = import "values/images.jsonnet"; -local url = import "values/url.jsonnet"; -local prompts = import "prompts/mixtral.jsonnet"; - -{ - - "prompt" +: { - - create:: function(engine) - - local container = - engine.container("prompt") - .with_image(images.trustgraph) - .with_command([ - "prompt-generic", - "-p", - url.pulsar, - "--text-completion-request-queue", - "non-persistent://tg/request/text-completion", - "--text-completion-response-queue", - "non-persistent://tg/response/text-completion-response", - ]) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - - local containerSet = engine.containers( - "prompt", [ container ] - ); - - local service = - engine.internalService(containerSet) - .with_port(8080, 8080, "metrics"); - - engine.resources([ - containerSet, - service, - ]) - - }, - - "prompt-rag" +: { - - create:: function(engine) - - local container = - engine.container("prompt-rag") - .with_image(images.trustgraph) - .with_command([ - "prompt-generic", - "-p", - url.pulsar, - "-i", - "non-persistent://tg/request/prompt-rag", - "-o", - "non-persistent://tg/response/prompt-rag-response", - "--text-completion-request-queue", - "non-persistent://tg/request/text-completion-rag", - "--text-completion-response-queue", - "non-persistent://tg/response/text-completion-rag-response", - ]) - .with_limits("0.5", "128M") - .with_reservations("0.1", "128M"); - - local containerSet = engine.containers( - "prompt-rag", [ container ] - ); - - local service = - engine.internalService(containerSet) - .with_port(8080, 8080, "metrics"); - - engine.resources([ - containerSet, - service, - ]) - - }, - -} - diff --git a/templates/components/prompt-overrides.jsonnet b/templates/components/prompt-overrides.jsonnet new file mode 100644 index 00000000..648e5b66 --- /dev/null +++ b/templates/components/prompt-overrides.jsonnet @@ -0,0 +1,28 @@ +local base = import "base/base.jsonnet"; +local images = import "values/images.jsonnet"; +local url = import "values/url.jsonnet"; +local prompts = import "prompts/mixtral.jsonnet"; +local default_prompts = import "prompts/default-prompts.jsonnet"; + +{ + + with:: function(key, value) + if (key == "system-template") then + self + { + prompts +:: { + "system-template": value, + } + } + else + self + { + prompts +:: { + templates +:: { + [key] +:: { + prompt: value + } + } + } + }, + +} + default_prompts + diff --git a/templates/components/prompt-template.jsonnet b/templates/components/prompt-template.jsonnet index 8ba0d17f..ac820df6 100644 --- a/templates/components/prompt-template.jsonnet +++ b/templates/components/prompt-template.jsonnet @@ -6,6 +6,38 @@ local default_prompts = import "prompts/default-prompts.jsonnet"; { + prompts:: default_prompts, + + local prompt_template_args = [ "--prompt" ] + [ + p.key + "=" + p.value.prompt, + for p in std.objectKeysValuesAll($.prompts.templates) + ], + + local prompt_response_type_args = [ "--prompt-response-type" ] + [ + p.key + "=" + p.value["response-type"], + for p in std.objectKeysValuesAll($.prompts.templates) + if std.objectHas(p.value, "response-type") + ], + + local prompt_schema_args = [ "--prompt-schema" ] + [ + ( + p.key + "=" + + std.manifestJsonMinified(p.value["schema"]) + ) + for p in std.objectKeysValuesAll($.prompts.templates) + if std.objectHas(p.value, "schema") + ], + + local prompt_term_args = [ "--prompt-term" ] + [ + p.key + "=" + t.key + ":" + t.value + for p in std.objectKeysValuesAll($.prompts.templates) + if std.objectHas(p.value, "terms") + for t in std.objectKeysValuesAll(p.value.terms) + ], + + local prompt_args = prompt_template_args + prompt_response_type_args + + prompt_schema_args + prompt_term_args, + "prompt" +: { create:: function(engine) @@ -17,23 +49,17 @@ local default_prompts = import "prompts/default-prompts.jsonnet"; "prompt-template", "-p", url.pulsar, + "--text-completion-request-queue", "non-persistent://tg/request/text-completion", "--text-completion-response-queue", "non-persistent://tg/response/text-completion-response", - "--definition-template", - $["prompt-definition-template"], - "--relationship-template", - $["prompt-relationship-template"], - "--topic-template", - $["prompt-topic-template"], - "--knowledge-query-template", - $["prompt-knowledge-query-template"], - "--document-query-template", - $["prompt-document-query-template"], - "--rows-template", - $["prompt-rows-template"], - ]) + + "--system-prompt", + $["prompts"]["system-template"], + + ] + prompt_args + ) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); @@ -71,19 +97,12 @@ local default_prompts = import "prompts/default-prompts.jsonnet"; "non-persistent://tg/request/text-completion-rag", "--text-completion-response-queue", "non-persistent://tg/response/text-completion-rag-response", - "--definition-template", - $["prompt-definition-template"], - "--relationship-template", - $["prompt-relationship-template"], - "--topic-template", - $["prompt-topic-template"], - "--knowledge-query-template", - $["prompt-knowledge-query-template"], - "--document-query-template", - $["prompt-document-query-template"], - "--rows-template", - $["prompt-rows-template"], - ]) + + "--system-prompt", + $["prompts"]["system-template"], + + ] + prompt_args + ) .with_limits("0.5", "128M") .with_reservations("0.1", "128M"); diff --git a/templates/components/vertexai.jsonnet b/templates/components/vertexai.jsonnet index 2bc97799..44fe27c6 100644 --- a/templates/components/vertexai.jsonnet +++ b/templates/components/vertexai.jsonnet @@ -37,7 +37,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "-x", std.toString($["vertexai-max-output-tokens"]), "-t", - std.toString($["vertexai-temperature"]), + "%0.3f" % $["vertexai-temperature"], "-m", $["vertexai-model"], ]) @@ -87,7 +87,7 @@ local prompts = import "prompts/mixtral.jsonnet"; "-x", std.toString($["vertexai-max-output-tokens"]), "-t", - std.toString($["vertexai-temperature"]), + "%0.3f" % $["vertexai-temperature"], "-m", $["vertexai-model"], "-i", diff --git a/templates/engine/docker-compose.jsonnet b/templates/engine/docker-compose.jsonnet index 4f837ff2..c37f1df0 100644 --- a/templates/engine/docker-compose.jsonnet +++ b/templates/engine/docker-compose.jsonnet @@ -18,12 +18,15 @@ reservations: {}, ports: [], volumes: [], + environment: {}, with_image:: function(x) self + { image: x }, with_command:: function(x) self + { command: x }, - with_environment:: function(x) self + { environment: x }, + with_environment:: function(x) self + { + environment: super.environment + x, + }, with_limits:: function(c, m) self + { limits: { cpus: c, memory: m } }, @@ -45,6 +48,16 @@ ] }, + with_env_var_secrets:: + function(vars) + std.foldl( + function(obj, x) obj.with_environment( + { [x]: "${" + x + "}" } + ), + vars.variables, + self + ), + add:: function() { services +: { [container.name]: { @@ -62,7 +75,7 @@ { command: container.command } else {}) + - (if std.objectHas(container, "environment") then + (if ! std.isEmpty(container.environment) then { environment: container.environment } else {}) + @@ -170,6 +183,27 @@ }, + envSecrets:: function(name) + { + + local volume = self, + + name: name, + + volid:: name, + + variables:: [], + + with_env_var:: + function(name, key) self + { + variables: super.variables + [name], + }, + + add:: function() { + } + + }, + containers:: function(name, containers) { diff --git a/templates/engine/k8s.jsonnet b/templates/engine/k8s.jsonnet index 69aabfd7..2fec0d1f 100644 --- a/templates/engine/k8s.jsonnet +++ b/templates/engine/k8s.jsonnet @@ -10,12 +10,20 @@ reservations: {}, ports: [], volumes: [], + environment: [], with_image:: function(x) self + { image: x }, with_command:: function(x) self + { command: x }, - with_environment:: function(x) self + { environment: x }, + with_environment:: function(x) self + { + environment: super.environment + [ + { + name: v.key, value: v.value + } + for v in std.objectKeysValues(x) + ], + }, with_limits:: function(c, m) self + { limits: { cpu: c, memory: m } }, @@ -37,6 +45,24 @@ ] }, + with_env_var_secrets:: + function(vars) + std.foldl( + function(obj, x) obj + { + environment: super.environment + [{ + name: x, + valueFrom: { + secretKeyRef: { + name: vars.name, + key: vars.keyMap[x], + } + } + }] + }, + vars.variables, + self + ), + add:: function() [ { @@ -97,16 +123,11 @@ (if std.objectHas(container, "command") then { command: container.command } else {}) + - (if std.objectHas(container, "environment") then - { env: [ { - name: e.key, value: e.value - } - for e in - std.objectKeysValues( - container.environment - ) - ] - } + + (if ! std.isEmpty(container.environment) then + { + env: container.environment, + } else {}) + (if std.length(container.volumes) > 0 then @@ -283,6 +304,34 @@ }, + envSecrets:: function(name) + { + + local volume = self, + + name: name, + + variables: [], + keyMap: {}, + + with_size:: function(size) self + { size: size }, + + add:: function() [ + ], + + volRef:: function() { + name: volume.name, + secret: { secretName: volume.name }, + }, + + with_env_var:: + function(name, key) self + { + variables: super.variables + [name], + keyMap: super.keyMap + { [name]: key }, + }, + + }, + containers:: function(name, containers) { diff --git a/templates/generate b/templates/generate new file mode 100755 index 00000000..e4925cec --- /dev/null +++ b/templates/generate @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 + +import _jsonnet as j +import json +import yaml +import logging +import os +import sys +import zipfile + +logger = logging.getLogger("generate") +logging.basicConfig(level=logging.INFO, format='%(message)s') + +private_json = "Put your GCP private.json here" + +class Generator: + + def __init__(self, config, base="./templates/", version="0.0.0"): + + self.jsonnet_base = base + self.config = config + self.version = f"\"{version}\"".encode("utf-8") + + def process(self, config): + + res = j.evaluate_snippet("config", config, import_callback=self.load) + return json.loads(res) + + def load(self, dir, filename): + + logger.debug("Request jsonnet: %s %s", dir, filename) + + if filename == "config.json" and dir == "": + path = os.path.join(".", dir, filename) + return str(path), self.config + + if filename == "version.jsonnet" and dir == "./templates/values/": + path = os.path.join(".", dir, filename) + return str(path), self.version + + if dir: + candidates = [ + os.path.join(".", dir, filename), + os.path.join(".", filename) + ] + else: + candidates = [ + os.path.join(".", filename) + ] + + try: + + if filename == "vertexai/private.json": + + return candidates[0], private_json.encode("utf-8") + + for c in candidates: + logger.debug("Try: %s", c) + + if os.path.isfile(c): + with open(c, "rb") as f: + logger.debug("Loading: %s", c) + return str(c), f.read() + + raise RuntimeError( + f"Could not load file={filename} dir={dir}" + ) + + except: + + path = os.path.join(self.jsonnet_base, filename) + logger.debug("Try: %s", path) + with open(path, "rb") as f: + logger.debug("Loaded: %s", path) + return str(path), f.read() + +def main(): + + if len(sys.argv) != 3: + print() + print("Usage:") + print(" generate < input.json") + print() + raise RuntimeError("Arg error") + + outfile = sys.argv[1] + version = sys.argv[2] + + cfg = sys.stdin.read() + cfg = json.loads(cfg) + + logger.info(f"Outputting to {outfile}...") + + with zipfile.ZipFile(outfile, mode='w') as out: + + def output(name, content): + logger.info(f"Adding {name}...") + out.writestr(name, content) + + fname = "tg-launch.yaml" + + platform = "docker-compose" + + with open(f"./templates/config-to-{platform}.jsonnet", "r") as f: + wrapper = f.read() + + gen = Generator(json.dumps(cfg).encode("utf-8"), version=version) + + processed = gen.process(wrapper) + + y = yaml.dump(processed) + + output(fname, y) + + # Placeholder for the private.json file. Won't put actual credentials + # here. + output("docker-compose/vertexai/private.json", private_json) + + # Grafana config + with open("grafana/dashboards/dashboard.json") as f: + output( + "docker-compose/grafana/dashboards/dashboard.json", f.read() + ) + + with open("grafana/provisioning/dashboard.yml") as f: + output( + "docker-compose/grafana/provisioning/dashboard.yml", f.read() + ) + + with open("grafana/provisioning/datasource.yml") as f: + output( + "docker-compose/grafana/provisioning/datasource.yml", f.read() + ) + + # Prometheus config + with open("prometheus/prometheus.yml") as f: + output("docker-compose/prometheus/prometheus.yml", f.read()) + +main() + diff --git a/templates/generate-all b/templates/generate-all index 948b811f..0b403620 100755 --- a/templates/generate-all +++ b/templates/generate-all @@ -122,8 +122,8 @@ def generate_all(output, version): "docker-compose", "minikube-k8s", "gcp-k8s" ]: for model in [ - "azure", "bedrock", "claude", "cohere", "llamafile", "ollama", - "openai", "vertexai" + "azure", "azure-openai", "bedrock", "claude", "cohere", + "googleaistudio", "llamafile", "ollama", "openai", "vertexai", ]: for graph in [ "cassandra", "neo4j" ]: diff --git a/templates/patterns/llm-azure-openai.jsonnet b/templates/patterns/llm-azure-openai.jsonnet new file mode 100644 index 00000000..06e1a3f5 --- /dev/null +++ b/templates/patterns/llm-azure-openai.jsonnet @@ -0,0 +1,32 @@ +{ + pattern: { + name: "azure-openai", + icon: "🤖💬", + title: "Add Azure OpenAI LLM endpoint for text completion", + description: "This pattern integrates an Azure OpenAI LLM endpoint hosted in the Azure cloud for text completion operations. You need an Azure subscription to be able to use this service.", + requires: ["pulsar", "trustgraph"], + features: ["llm"], + args: [ + { + name: "azure-openai-max-output-tokens", + label: "Maximum output tokens", + type: "integer", + description: "Limit on number tokens to generate", + default: 4096, + required: true, + }, + { + name: "azure-openai-temperature", + label: "Temperature", + type: "slider", + description: "Controlling predictability / creativity balance", + min: 0, + max: 1, + step: 0.05, + default: 0.5, + }, + ], + category: [ "llm" ], + }, + module: "components/azure.jsonnet", +} diff --git a/templates/patterns/llm-googleaistudio.jsonnet b/templates/patterns/llm-googleaistudio.jsonnet new file mode 100644 index 00000000..aa56d347 --- /dev/null +++ b/templates/patterns/llm-googleaistudio.jsonnet @@ -0,0 +1,32 @@ +{ + pattern: { + name: "googleaistudio", + icon: "🤖💬", + title: "Add GoogleAIStudio for text completion", + description: "This pattern integrates a GoogleAIStudio LLM service for text completion operations. You need a GoogleAISTudio API key to be able to use this service.", + requires: ["pulsar", "trustgraph"], + features: ["llm"], + args: [ + { + name: "googleaistudio-max-output-tokens", + label: "Maximum output tokens", + type: "integer", + description: "Limit on number tokens to generate", + default: 4096, + required: true, + }, + { + name: "googleaistudio-temperature", + label: "Temperature", + type: "slider", + description: "Controlling predictability / creativity balance", + min: 0, + max: 1, + step: 0.05, + default: 0.5, + }, + ], + category: [ "llm" ], + }, + module: "components/googleaistudio.jsonnet", +} diff --git a/templates/prompts/cohere.jsonnet b/templates/prompts/cohere.jsonnet index 2335084b..9541e4c2 100644 --- a/templates/prompts/cohere.jsonnet +++ b/templates/prompts/cohere.jsonnet @@ -1,18 +1,42 @@ - // For Cohere. Not currently overriding prompts -{ +local prompts = import "default-prompts.jsonnet"; -// "prompt-definition-template": "PROMPT GOES HERE", +prompts + { -// "prompt-relationship-template":: "PROMPT GOES HERE", + // "system-template":: "PROMPT GOES HERE.", -// "prompt-topic-template":: "PROMPT GOES HERE", + "templates" +:: { -// "prompt-knowledge-query-template":: "PROMPT GOES HERE", + "question" +:: { + // "prompt": "PROMPT GOES HERE", + }, -// "prompt-document-query-template":: "PROMPT GOES HERE", + "extract-definitions" +:: { + // "prompt": "PROMPT GOES HERE", + }, -// "prompt-rows-template":: "PROMPT GOES HERE", + "extract-relationships" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-topics" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-rows" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "kg-prompt" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "document-prompt" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + } } + diff --git a/templates/prompts/default-prompts.jsonnet b/templates/prompts/default-prompts.jsonnet index 6f8c7b7b..be05b992 100644 --- a/templates/prompts/default-prompts.jsonnet +++ b/templates/prompts/default-prompts.jsonnet @@ -4,16 +4,111 @@ { - "prompt-definition-template":: "\nStudy the following text and derive definitions for any discovered entities.\nDo not provide definitions for entities whose definitions are incomplete\nor unknown.\nOutput relationships in JSON format as an arary of objects with fields:\n- entity: the name of the entity\n- definition: English text which defines the entity\n\n\n\n{text}\n\n\n\nYou will respond only with raw JSON format data. Do not provide\nexplanations. Do not use special characters in the abstract text. The\nabstract will be written as plain text. Do not add markdown formatting\nor headers or prefixes. Do not include null or unknown definitions.\n", + "system-template":: "You are a helpful assistant.", - "prompt-relationship-template":: "\nStudy the following text and derive entity relationships. For each\nrelationship, derive the subject, predicate and object of the relationship.\nOutput relationships in JSON format as an arary of objects with fields:\n- subject: the subject of the relationship\n- predicate: the predicate\n- object: the object of the relationship\n- object-entity: false if the object is a simple data type: name, value or date. true if it is an entity.\n\n\n\n{text}\n\n\n\nYou will respond only with raw JSON format data. Do not provide\nexplanations. Do not use special characters in the abstract text. The\nabstract must be written as plain text. Do not add markdown formatting\nor headers or prefixes.\n", + "templates":: { - "prompt-topic-template":: "You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify topics and their definitions in JSON.\n\nReading Instructions:\n- Ignore document formatting in the provided text.\n- Study the provided text carefully.\n\nHere is the text:\n{text}\n\nResponse Instructions: \n- Do not respond with special characters.\n- Return only topics that are concepts and unique to the provided text.\n- Respond only with well-formed JSON.\n- The JSON response shall be an array of objects with keys \"topic\" and \"definition\". \n- The JSON response shall use the following structure:\n\n```json\n[{{\"topic\": string, \"definition\": string}}]\n```\n\n- Do not write any additional text or explanations.", + "question":: { + "prompt": "{{question}}", + }, - "prompt-knowledge-query-template":: "Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.\n\nHere's the knowledge statements:\n{graph}\n\nUse only the provided knowledge statements to respond to the following:\n{query}\n", + "extract-definitions":: { + "prompt": "\nStudy the following text and derive definitions for any discovered entities.\nDo not provide definitions for entities whose definitions are incomplete\nor unknown.\nOutput relationships in JSON format as an arary of objects with fields:\n- entity: the name of the entity\n- definition: English text which defines the entity\n\n\n\n{{text}}\n\n\n\nYou will respond only with raw JSON format data. Do not provide\nexplanations. Do not use special characters in the abstract text. The\nabstract will be written as plain text. Do not add markdown formatting\nor headers or prefixes. Do not include null or unknown definitions.\n", + "response-type": "json", + "schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "entity": { + "type": "string" + }, + "definition": { + "type": "string" + } + }, + "required": [ + "entity", + "definition" + ] + } + } + }, - "prompt-document-query-template":: "Study the following context. Use only the information provided in the context in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.\n\nHere is the context:\n{documents}\n\nUse only the provided knowledge statements to respond to the following:\n{query}\n", + "extract-relationships":: { + "prompt": "\nStudy the following text and derive entity relationships. For each\nrelationship, derive the subject, predicate and object of the relationship.\nOutput relationships in JSON format as an arary of objects with fields:\n- subject: the subject of the relationship\n- predicate: the predicate\n- object: the object of the relationship\n- object-entity: false if the object is a simple data type: name, value or date. true if it is an entity.\n\n\n\n{{text}}\n\n\n\nYou will respond only with raw JSON format data. Do not provide\nexplanations. Do not use special characters in the abstract text. The\nabstract must be written as plain text. Do not add markdown formatting\nor headers or prefixes.\n", + "response-type": "json", + "schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "subject": { + "type": "string" + }, + "predicate": { + "type": "string" + }, + "object": { + "type": "string" + }, + "object-entity": { + "type": "boolean" + }, + }, + "required": [ + "subject", + "predicate", + "object", + "object-entity" + ] + } + } + }, - "prompt-rows-template":: "\nStudy the following text and derive objects which match the schema provided.\n\nYou must output an array of JSON objects for each object you discover\nwhich matches the schema. For each object, output a JSON object whose fields\ncarry the name field specified in the schema.\n\n\n\n{schema}\n\n\n\n{text}\n\n\n\nYou will respond only with raw JSON format data. Do not provide\nexplanations. Do not add markdown formatting or headers or prefixes.\n", + "extract-topics":: { + "prompt": "You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify topics and their definitions in JSON.\n\nReading Instructions:\n- Ignore document formatting in the provided text.\n- Study the provided text carefully.\n\nHere is the text:\n{{text}}\n\nResponse Instructions: \n- Do not respond with special characters.\n- Return only topics that are concepts and unique to the provided text.\n- Respond only with well-formed JSON.\n- The JSON response shall be an array of objects with keys \"topic\" and \"definition\". \n- The JSON response shall use the following structure:\n\n```json\n[{\"topic\": string, \"definition\": string}]\n```\n\n- Do not write any additional text or explanations.", + "response-type": "json", + "schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "topic": { + "type": "string" + }, + "definition": { + "type": "string" + } + }, + "required": [ + "topic", + "definition" + ] + } + } + }, + + "extract-rows":: { + "prompt": "\nStudy the following text and derive objects which match the schema provided.\n\nYou must output an array of JSON objects for each object you discover\nwhich matches the schema. For each object, output a JSON object whose fields\ncarry the name field specified in the schema.\n\n\n\n{{schema}}\n\n\n\n{{text}}\n\n\n\nYou will respond only with raw JSON format data. Do not provide\nexplanations. Do not add markdown formatting or headers or prefixes.\n", + "response-type": "json", + }, + + "kg-prompt":: { + "prompt": "Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.\n\nHere's the knowledge statements:\n{% for edge in knowledge %}({{edge.s}})-[{{edge.p}}]->({{edge.o}})\n{%endfor%}\n\nUse only the provided knowledge statements to respond to the following:\n{{query}}\n", + "response-type": "text", + }, + + "document-prompt":: { + "prompt": "Study the following context. Use only the information provided in the context in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.\n\nHere is the context:\n{{documents}}\n\nUse only the provided knowledge statements to respond to the following:\n{{query}}\n", + "response-type": "text", + }, + + "agent-react":: { + "prompt": "Answer the following questions as best you can. You have\naccess to the following functions:\n\n{% for tool in tools %}{\n \"function\": \"{{ tool.name }}\",\n \"description\": \"{{ tool.description }}\",\n \"arguments\": [\n{% for arg in tool.arguments %} {\n \"name\": \"{{ arg.name }}\",\n \"type\": \"{{ arg.type }}\",\n \"description\": \"{{ arg.description }}\",\n }\n{% endfor %}\n ]\n}\n{% endfor %}\n\nYou can either choose to call a function to get more information, or\nreturn a final answer.\n \nTo call a function, respond with a JSON object of the following format:\n\n{\n \"thought\": \"your thought about what to do\",\n \"action\": \"the action to take, should be one of [{{tool_names}}]\",\n \"arguments\": {\n \"argument1\": \"argument_value\",\n \"argument2\": \"argument_value\"\n }\n}\n\nTo provide a final answer, response a JSON object of the following format:\n\n{\n \"thought\": \"I now know the final answer\",\n \"final-answer\": \"the final answer to the original input question\"\n}\n\nPrevious steps are included in the input. Each step has the following\nformat in your output:\n\n{\n \"thought\": \"your thought about what to do\",\n \"action\": \"the action taken\",\n \"arguments\": {\n \"argument1\": action argument,\n \"argument2\": action argument2\n },\n \"observation\": \"the result of the action\",\n}\n\nRespond by describing either one single thought/action/arguments or\nthe final-answer. Pause after providing one action or final-answer.\n\n{% if context %}Additional context has been provided:\n{{context}}{% endif %}\n\nQuestion: {{question}}\n\nInput:\n \n{% for h in history %}\n{\n \"action\": \"{{h.action}}\",\n \"arguments\": [\n{% for k, v in h.arguments.items() %} {\n \"{{k}}\": \"{{v}}\",\n{%endfor%} }\n ],\n \"observation\": \"{{h.observation}}\"\n}\n{% endfor %}", + "response-type": "json" + } + } + +} -} \ No newline at end of file diff --git a/templates/prompts/gemini.jsonnet b/templates/prompts/gemini.jsonnet index 12905c7a..b9a1e0c0 100644 --- a/templates/prompts/gemini.jsonnet +++ b/templates/prompts/gemini.jsonnet @@ -1,17 +1,42 @@ - // For VertexAI Gemini. Not currently overriding prompts -{ -// "prompt-definition-template": "PROMPT GOES HERE", +local prompts = import "default-prompts.jsonnet"; -// "prompt-relationship-template":: "PROMPT GOES HERE", +prompts + { -// "prompt-topic-template":: "PROMPT GOES HERE", + // "system-template":: "PROMPT GOES HERE.", -// "prompt-knowledge-query-template":: "PROMPT GOES HERE", + "templates" +:: { -// "prompt-document-query-template":: "PROMPT GOES HERE", + "question" +:: { + // "prompt": "PROMPT GOES HERE", + }, -// "prompt-rows-template":: "PROMPT GOES HERE", + "extract-definitions" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-relationships" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-topics" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-rows" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "kg-prompt" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "document-prompt" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + } } + diff --git a/templates/prompts/mixtral.jsonnet b/templates/prompts/mixtral.jsonnet index c5e70477..cd56e7ef 100644 --- a/templates/prompts/mixtral.jsonnet +++ b/templates/prompts/mixtral.jsonnet @@ -1,18 +1,42 @@ - // For Mixtral. Not currently overriding prompts -{ +local prompts = import "default-prompts.jsonnet"; -// "prompt-definition-template": "PROMPT GOES HERE", +prompts + { -// "prompt-relationship-template":: "PROMPT GOES HERE", + // "system-template":: "PROMPT GOES HERE.", -// "prompt-topic-template":: "PROMPT GOES HERE", + "templates" +:: { -// "prompt-knowledge-query-template":: "PROMPT GOES HERE", + "question" +:: { + // "prompt": "PROMPT GOES HERE", + }, -// "prompt-document-query-template":: "PROMPT GOES HERE", + "extract-definitions" +:: { + // "prompt": "PROMPT GOES HERE", + }, -// "prompt-rows-template":: "PROMPT GOES HERE", + "extract-relationships" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-topics" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-rows" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "kg-prompt" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "document-prompt" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + } } + diff --git a/templates/prompts/openai.jsonnet b/templates/prompts/openai.jsonnet index 0715525e..5d232337 100644 --- a/templates/prompts/openai.jsonnet +++ b/templates/prompts/openai.jsonnet @@ -1,23 +1,42 @@ +// For OpenAI LLMs. Not currently overriding prompts -// For OpenAI LLMs +local prompts = import "default-prompts.jsonnet"; -local base = import "base/base.jsonnet"; -local images = import "values/images.jsonnet"; -local url = import "values/url.jsonnet"; +prompts + { -{ + // "system-template":: "PROMPT GOES HERE.", -// "prompt-definition-template": "PROMPT GOES HERE", + "templates" +:: { -// "prompt-relationship-template":: "PROMPT GOES HERE", + "question" +:: { + // "prompt": "PROMPT GOES HERE", + }, -// "prompt-topic-template":: "PROMPT GOES HERE", + "extract-definitions" +:: { + // "prompt": "PROMPT GOES HERE", + }, -// "prompt-knowledge-query-template":: "PROMPT GOES HERE", + "extract-relationships" +:: { + // "prompt": "PROMPT GOES HERE", + }, -// "prompt-document-query-template":: "PROMPT GOES HERE", + "extract-topics" +:: { + // "prompt": "PROMPT GOES HERE", + }, -// "prompt-rows-template":: "PROMPT GOES HERE", + "extract-rows" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "kg-prompt" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "document-prompt" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + } } diff --git a/templates/prompts/slm.jsonnet b/templates/prompts/slm.jsonnet index bd0cbff3..48eb96d0 100644 --- a/templates/prompts/slm.jsonnet +++ b/templates/prompts/slm.jsonnet @@ -1,7 +1,44 @@ +// For SLM. Not currently overriding prompts -// For basic SLMs, use prompt-generic +local prompts = import "default-prompts.jsonnet"; + +prompts + { + + // "system-template":: "PROMPT GOES HERE.", + + "templates" +:: { + + "question" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-definitions" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-relationships" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-topics" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "extract-rows" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "kg-prompt" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + "document-prompt" +:: { + // "prompt": "PROMPT GOES HERE", + }, + + } + +} -local prompts = import "components/prompt-generic.jsonnet"; -prompts diff --git a/templates/stores/cassandra.jsonnet b/templates/stores/cassandra.jsonnet index c501e1f9..0c90421e 100644 --- a/templates/stores/cassandra.jsonnet +++ b/templates/stores/cassandra.jsonnet @@ -13,10 +13,10 @@ local images = import "values/images.jsonnet"; engine.container("cassandra") .with_image(images.cassandra) .with_environment({ - JVM_OPTS: "-Xms256M -Xmx256M", + JVM_OPTS: "-Xms300M -Xmx300M", }) - .with_limits("1.0", "800M") - .with_reservations("0.5", "800M") + .with_limits("1.0", "1000M") + .with_reservations("0.5", "1000M") .with_port(9042, 9042, "cassandra") .with_volume_mount(vol, "/var/lib/cassandra"); diff --git a/templates/stores/milvus.jsonnet b/templates/stores/milvus.jsonnet index 888a83a9..cbeb4268 100644 --- a/templates/stores/milvus.jsonnet +++ b/templates/stores/milvus.jsonnet @@ -37,7 +37,7 @@ local images = import "values/images.jsonnet"; local service = engine.service(containerSet) - .with_port(2379, 2379, 30379, "api"); + .with_port(2379, 2379, "api"); engine.resources([ vol, @@ -117,7 +117,7 @@ local images = import "values/images.jsonnet"; local service = engine.service(containerSet) .with_port(9091, 9091, "api") - .with_port(19530, 19530, "api2); + .with_port(19530, 19530, "api2"); engine.resources([ vol, diff --git a/templates/stores/qdrant.jsonnet b/templates/stores/qdrant.jsonnet index e8443b73..9e807632 100644 --- a/templates/stores/qdrant.jsonnet +++ b/templates/stores/qdrant.jsonnet @@ -12,8 +12,8 @@ local images = import "values/images.jsonnet"; local container = engine.container("qdrant") .with_image(images.qdrant) - .with_limits("1.0", "256M") - .with_reservations("0.5", "256M") + .with_limits("1.0", "1024M") + .with_reservations("0.5", "1024M") .with_port(6333, 6333, "api") .with_port(6334, 6334, "api2") .with_volume_mount(vol, "/qdrant/storage"); diff --git a/templates/util/decode-config.jsonnet b/templates/util/decode-config.jsonnet index a1bd146e..503b5b6b 100644 --- a/templates/util/decode-config.jsonnet +++ b/templates/util/decode-config.jsonnet @@ -3,9 +3,7 @@ local components = import "components.jsonnet"; local apply = function(p, components) - local component = components[p.name]; - - (component + { + local base = { with:: function(k, v) self + { [k]:: v @@ -18,7 +16,11 @@ local apply = function(p, components) self ), - }).with_params(p.parameters); + }; + + local component = base + components[p.name]; + + component.with_params(p.parameters); local decode = function(config) local add = function(state, c) state + apply(c, components); diff --git a/tests/README.prompts b/tests/README.prompts new file mode 100644 index 00000000..7a17affe --- /dev/null +++ b/tests/README.prompts @@ -0,0 +1,27 @@ + +test-prompt-... is tested with this prompt set... + +prompt-template \ + -p pulsar://localhost:6650 \ + --system-prompt 'You are a {{attitude}}, you are called {{name}}' \ + --global-term \ + 'name=Craig' \ + 'attitude=LOUD, SHOUTY ANNOYING BOT' \ + --prompt \ + 'question={{question}}' \ + 'french-question={{question}}' \ + "analyze=Find the name and age in this text, and output a JSON structure containing just the name and age fields: {{description}}. Don't add markup, just output the raw JSON object." \ + "graph-query=Study the following knowledge graph, and then answer the question.\\n\nGraph:\\n{% for edge in knowledge %}({{edge.0}})-[{{edge.1}}]->({{edge.2}})\\n{%endfor%}\\nQuestion:\\n{{question}}" \ + "extract-definition=Analyse the text provided, and then return a list of terms and definitions. The output should be a JSON array, each item in the array is an object with fields 'term' and 'definition'.Don't add markup, just output the raw JSON object. Here is the text:\\n{{text}}" \ + --prompt-response-type \ + 'question=text' \ + 'analyze=json' \ + 'graph-query=text' \ + 'extract-definition=json' \ + --prompt-term \ + 'question=name:Bonny' \ + 'french-question=attitude:French-speaking bot' \ + --prompt-schema \ + 'analyze={ "type" : "object", "properties" : { "age": { "type" : "number" }, "name": { "type" : "string" } } }' \ + 'extract-definition={ "type": "array", "items": { "type": "object", "properties": { "term": { "type": "string" }, "definition": { "type": "string" } }, "required": [ "term", "definition" ] } }' + diff --git a/tests/test-agent b/tests/test-agent new file mode 100755 index 00000000..4782bbae --- /dev/null +++ b/tests/test-agent @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +import json +import textwrap +from trustgraph.clients.agent_client import AgentClient + +def wrap(text, width=75): + + if text is None: text = "n/a" + + out = textwrap.wrap( + text, width=width + ) + return "\n".join(out) + +def output(text, prefix="> ", width=78): + + out = textwrap.indent( + text, prefix=prefix + ) + print(out) + +p = AgentClient(pulsar_host="pulsar://localhost:6650") + +q = "How many cats does Mark have? Calculate that number raised to 0.4 power. Is that number lower than the numeric part of the mission identifier of the Space Shuttle Challenger on its last mission? If so, give me an apple pie recipe, otherwise return a poem about cheese." + +output(wrap(q), "\U00002753 ") +print() + +def think(x): + output(wrap(x), "\U0001f914 ") + print() + +def observe(x): + output(wrap(x), "\U0001f4a1 ") + print() + +resp = p.request( + question=q, think=think, observe=observe, +) + +output(resp, "\U0001f4ac ") +print() + diff --git a/tests/test-lang-definition b/tests/test-lang-definition index c6e593fd..67342779 100755 --- a/tests/test-lang-definition +++ b/tests/test-lang-definition @@ -7,7 +7,13 @@ p = PromptClient(pulsar_host="pulsar://localhost:6650") chunk = """I noticed a cat in my garden. It is a four-legged animal which is a mammal and can be tame or wild. I wonder if it will be friends -with me. I think the cat's name is Fred and it has 4 legs""" +with me. I think the cat's name is Fred and it has 4 legs. + +A cat is a small mammal. + +A grapefruit is a citrus fruit. + +""" resp = p.request_definitions( chunk=chunk, diff --git a/tests/test-lang-topics b/tests/test-lang-topics new file mode 100755 index 00000000..2b668524 --- /dev/null +++ b/tests/test-lang-topics @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +import pulsar +from trustgraph.clients.prompt_client import PromptClient + +p = PromptClient(pulsar_host="pulsar://localhost:6650") + +chunk = """I noticed a cat in my garden. It is a four-legged animal +which is a mammal and can be tame or wild. I wonder if it will be friends +with me. I think the cat's name is Fred and it has 4 legs""" + +resp = p.request_topics( + chunk=chunk, +) + +for d in resp: + print(d.topic) + print(" ", d.definition) + diff --git a/tests/test-llm b/tests/test-llm index 7e2c271d..4e86387a 100755 --- a/tests/test-llm +++ b/tests/test-llm @@ -5,9 +5,10 @@ from trustgraph.clients.llm_client import LlmClient llm = LlmClient(pulsar_host="pulsar://localhost:6650") +system = "You are a lovely assistant." prompt="Write a funny limerick about a llama" -resp = llm.request(prompt) +resp = llm.request(system, prompt) print(resp) diff --git a/tests/test-prompt-analyze b/tests/test-prompt-analyze new file mode 100755 index 00000000..53c1d76f --- /dev/null +++ b/tests/test-prompt-analyze @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +import json +from trustgraph.clients.prompt_client import PromptClient + +p = PromptClient(pulsar_host="pulsar://localhost:6650") + +description = """Fred is a 4-legged cat who is 12 years old""" + +resp = p.request( + id="analyze", + terms = { + "description": description, + } +) + +print(json.dumps(resp, indent=4)) + diff --git a/tests/test-prompt-extraction b/tests/test-prompt-extraction new file mode 100755 index 00000000..c73bd2e2 --- /dev/null +++ b/tests/test-prompt-extraction @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +import json +from trustgraph.clients.prompt_client import PromptClient + +p = PromptClient(pulsar_host="pulsar://localhost:6650") + +chunk=""" + The Space Shuttle was a reusable spacecraft that transported astronauts and cargo to and from Earth's orbit. It was designed to launch like a rocket, maneuver in orbit like a spacecraft, and land like an airplane. The Space Shuttle was NASA's space transportation system and was used for many purposes, including: + + Carrying astronauts + The Space Shuttle could carry up to seven astronauts at a time. + +Launching, recovering, and repairing satellites +The Space Shuttle could launch satellites into orbit, recover them, and repair them. +Building the International Space Station +The Space Shuttle carried large parts into space to build the International Space Station. +Conducting research +Astronauts conducted experiments in the Space Shuttle, which was like a science lab in space. + +The Space Shuttle was retired in 2011 after the Columbia accident in 2003. The Columbia Accident Investigation Board report found that the Space Shuttle was unsafe and expensive to make safe. +Here are some other facts about the Space Shuttle: + + The Space Shuttle was 184 ft tall and had a diameter of 29 ft. + +The Space Shuttle had a mass of 4,480,000 lb. +The Space Shuttle's first flight was on April 12, 1981. +The Space Shuttle's last mission was in 2011. +""" + +q = "Tell me some facts in the knowledge graph" + +resp = p.request( + id="extract-definition", + terms = { + "text": chunk, + } +) + +print(resp) + +for fact in resp: + print(fact["term"], "::") + print(fact["definition"]) + print() + diff --git a/tests/test-prompt-french-question b/tests/test-prompt-french-question new file mode 100755 index 00000000..4417cf41 --- /dev/null +++ b/tests/test-prompt-french-question @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +import pulsar +from trustgraph.clients.prompt_client import PromptClient + +p = PromptClient(pulsar_host="pulsar://localhost:6650") + +question = """What is the square root of 16?""" + +resp = p.request( + id="french-question", + terms = { + "question": question + } +) + +print(resp) + diff --git a/tests/test-prompt-knowledge b/tests/test-prompt-knowledge new file mode 100755 index 00000000..b1b94983 --- /dev/null +++ b/tests/test-prompt-knowledge @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +import json +from trustgraph.clients.prompt_client import PromptClient + +p = PromptClient(pulsar_host="pulsar://localhost:6650") + +knowledge = [ + ("accident", "evoked", "a wide range of deeply felt public responses"), + ("Space Shuttle concept", "had", "genesis"), + ("Commission", "had", "a mandate to develop recommendations for corrective or other action based upon the Commission's findings and determinations"), + ("Commission", "established", "teams of persons"), + ("Space Shuttle Challenger", "http://www.w3.org/2004/02/skos/core#definition", "A space shuttle that was destroyed in an accident during mission 51-L."), + ("The mid fuselage", "contains", "the payload bay"), + ("Volume I", "contains", "Chapter IX"), + ("accident", "resulted in", "firm national resolve that those men and women be forever enshrined in the annals of American heroes"), + ("Volume I", "contains", "Chapter VII"), + ("Volume I", "contains", "Chapter II"), + ("Volume I", "contains", "Chapter V"), + ("Commission", "believes", "its investigation and report have been responsive to the request of the President and hopes that they will serve the best interests of the nation in restoring the United States space program to its preeminent position in the world"), + ("Commission", "construe", "mandate"), + ("accident", "became", "a milestone on the way to achieving the full potential that space offers to mankind"), + ("Volume I", "contains", "The Commission"), + ("Commission", "http://www.w3.org/2004/02/skos/core#definition", "A group established to investigate the space shuttle accident"), + ("Volume I", "contains", "Appendix D"), + ("Commission", "had", "a mandate to review the circumstances surrounding the accident to establish the probable cause or causes of the accident"), + ("Volume I", "contains", "Recommendations") +] + +q = "Tell me some facts in the knowledge graph" + +resp = p.request( + id="graph-query", + terms = { + "name": "Jayney", + "knowledge": knowledge, + "question": q + } +) + +print(resp) + + + diff --git a/tests/test-prompt-question b/tests/test-prompt-question new file mode 100755 index 00000000..50660965 --- /dev/null +++ b/tests/test-prompt-question @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 + +import pulsar +from trustgraph.clients.prompt_client import PromptClient + +p = PromptClient(pulsar_host="pulsar://localhost:6650") + +question = """What is the square root of 16?""" + +resp = p.request( + id="question", + terms = { + "question": question + } +) + +print(resp) + diff --git a/tests/test-prompt-spanish-question b/tests/test-prompt-spanish-question new file mode 100755 index 00000000..e55a174b --- /dev/null +++ b/tests/test-prompt-spanish-question @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +import pulsar +from trustgraph.clients.prompt_client import PromptClient + +p = PromptClient(pulsar_host="pulsar://localhost:6650") + +question = """What is the square root of 16?""" + +resp = p.request( + id="question", + terms = { + "question": question, + "attitude": "Spanish-speaking bot" + } +) + +print(resp) + diff --git a/trustgraph-base/trustgraph/base/base_processor.py b/trustgraph-base/trustgraph/base/base_processor.py index 9b2e29ac..f258ff1a 100644 --- a/trustgraph-base/trustgraph/base/base_processor.py +++ b/trustgraph-base/trustgraph/base/base_processor.py @@ -39,8 +39,9 @@ class BaseProcessor: def __del__(self): - if self.client: - self.client.close() + if hasattr(self, "client"): + if self.client: + self.client.close() @staticmethod def add_args(parser): diff --git a/trustgraph-base/trustgraph/clients/agent_client.py b/trustgraph-base/trustgraph/clients/agent_client.py new file mode 100644 index 00000000..2ef69274 --- /dev/null +++ b/trustgraph-base/trustgraph/clients/agent_client.py @@ -0,0 +1,64 @@ + +import _pulsar + +from .. schema import AgentRequest, AgentResponse +from .. schema import agent_request_queue +from .. schema import agent_response_queue +from . base import BaseClient + +# Ugly +ERROR=_pulsar.LoggerLevel.Error +WARN=_pulsar.LoggerLevel.Warn +INFO=_pulsar.LoggerLevel.Info +DEBUG=_pulsar.LoggerLevel.Debug + +class AgentClient(BaseClient): + + def __init__( + self, log_level=ERROR, + subscriber=None, + input_queue=None, + output_queue=None, + pulsar_host="pulsar://pulsar:6650", + ): + + if input_queue is None: input_queue = agent_request_queue + if output_queue is None: output_queue = agent_response_queue + + super(AgentClient, self).__init__( + log_level=log_level, + subscriber=subscriber, + input_queue=input_queue, + output_queue=output_queue, + pulsar_host=pulsar_host, + input_schema=AgentRequest, + output_schema=AgentResponse, + ) + + def request( + self, + question, + think=None, + observe=None, + timeout=300 + ): + + def inspect(x): + + if x.thought and think: + think(x.thought) + return + + if x.observation and observe: + observe(x.observation) + return + + if x.answer: + return True + + return False + + return self.call( + question=question, inspect=inspect, timeout=timeout + ).answer + diff --git a/trustgraph-base/trustgraph/clients/base.py b/trustgraph-base/trustgraph/clients/base.py index 726b57df..78116f41 100644 --- a/trustgraph-base/trustgraph/clients/base.py +++ b/trustgraph-base/trustgraph/clients/base.py @@ -59,10 +59,14 @@ class BaseClient: def call(self, **args): timeout = args.get("timeout", DEFAULT_TIMEOUT) + inspect = args.get("inspect", lambda x: True) if "timeout" in args: del args["timeout"] + if "inspect" in args: + del args["inspect"] + id = str(uuid.uuid4()) r = self.input_schema(**args) @@ -103,6 +107,10 @@ class BaseClient: f"{value.error.type}: {value.error.message}" ) + complete = inspect(value) + + if not complete: continue + resp = msg.value() self.consumer.acknowledge(msg) return resp diff --git a/trustgraph-base/trustgraph/clients/document_rag_client.py b/trustgraph-base/trustgraph/clients/document_rag_client.py index 1e8a706d..103cbb69 100644 --- a/trustgraph-base/trustgraph/clients/document_rag_client.py +++ b/trustgraph-base/trustgraph/clients/document_rag_client.py @@ -38,7 +38,7 @@ class DocumentRagClient(BaseClient): output_schema=DocumentRagResponse, ) - def request(self, query, timeout=500): + def request(self, query, timeout=300): return self.call( query=query, timeout=timeout diff --git a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py index bb1358fc..401266bc 100644 --- a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py @@ -38,8 +38,12 @@ class GraphEmbeddingsClient(BaseClient): output_schema=GraphEmbeddingsResponse, ) - def request(self, vectors, limit=10, timeout=300): + def request( + self, vectors, user="trustgraph", collection="default", + limit=10, timeout=300 + ): return self.call( + user=user, collection=collection, vectors=vectors, limit=limit, timeout=timeout ).entities diff --git a/trustgraph-base/trustgraph/clients/graph_rag_client.py b/trustgraph-base/trustgraph/clients/graph_rag_client.py index e546aaed..9f8eff62 100644 --- a/trustgraph-base/trustgraph/clients/graph_rag_client.py +++ b/trustgraph-base/trustgraph/clients/graph_rag_client.py @@ -38,9 +38,12 @@ class GraphRagClient(BaseClient): output_schema=GraphRagResponse, ) - def request(self, query, timeout=500): + def request( + self, query, user="trustgraph", collection="default", + timeout=500 + ): return self.call( - query=query, timeout=timeout + user=user, collection=collection, query=query, timeout=timeout ).response diff --git a/trustgraph-base/trustgraph/clients/llm_client.py b/trustgraph-base/trustgraph/clients/llm_client.py index 3b52cb16..cfb0e606 100644 --- a/trustgraph-base/trustgraph/clients/llm_client.py +++ b/trustgraph-base/trustgraph/clients/llm_client.py @@ -35,6 +35,8 @@ class LlmClient(BaseClient): output_schema=TextCompletionResponse, ) - def request(self, prompt, timeout=300): - return self.call(prompt=prompt, timeout=timeout).response + def request(self, system, prompt, timeout=300): + return self.call( + system=system, prompt=prompt, timeout=timeout + ).response diff --git a/trustgraph-base/trustgraph/clients/prompt_client.py b/trustgraph-base/trustgraph/clients/prompt_client.py index f7f5a3ef..4b026cf0 100644 --- a/trustgraph-base/trustgraph/clients/prompt_client.py +++ b/trustgraph-base/trustgraph/clients/prompt_client.py @@ -1,7 +1,9 @@ import _pulsar +import json +import dataclasses -from .. schema import PromptRequest, PromptResponse, Fact, RowSchema, Field +from .. schema import PromptRequest, PromptResponse from .. schema import prompt_request_queue from .. schema import prompt_response_queue from . base import BaseClient @@ -12,6 +14,23 @@ WARN=_pulsar.LoggerLevel.Warn INFO=_pulsar.LoggerLevel.Info DEBUG=_pulsar.LoggerLevel.Debug +@dataclasses.dataclass +class Definition: + name: str + definition: str + +@dataclasses.dataclass +class Relationship: + s: str + p: str + o: str + o_entity: str + +@dataclasses.dataclass +class Topic: + name: str + definition: str + class PromptClient(BaseClient): def __init__( @@ -38,63 +57,116 @@ class PromptClient(BaseClient): output_schema=PromptResponse, ) + def request(self, id, variables, timeout=300): + + resp = self.call( + id=id, + terms={ + k: json.dumps(v) + for k, v in variables.items() + }, + timeout=timeout + ) + + if resp.text: return resp.text + + return json.loads(resp.object) + def request_definitions(self, chunk, timeout=300): - return self.call( - kind="extract-definitions", chunk=chunk, + defs = self.request( + id="extract-definitions", + variables={ + "text": chunk + }, timeout=timeout - ).definitions - - def request_topics(self, chunk, timeout=300): + ) - return self.call( - kind="extract-topics", chunk=chunk, - timeout=timeout - ).topics + return [ + Definition(name=d["entity"], definition=d["definition"]) + for d in defs + ] def request_relationships(self, chunk, timeout=300): - return self.call( - kind="extract-relationships", chunk=chunk, + rels = self.request( + id="extract-relationships", + variables={ + "text": chunk + }, timeout=timeout - ).relationships + ) + + return [ + Relationship( + s=d["subject"], + p=d["predicate"], + o=d["object"], + o_entity=d["object-entity"] + ) + for d in rels + ] + + def request_topics(self, chunk, timeout=300): + + topics = self.request( + id="extract-topics", + variables={ + "text": chunk + }, + timeout=timeout + ) + + return [ + Topic(name=d["topic"], definition=d["definition"]) + for d in topics + ] def request_rows(self, schema, chunk, timeout=300): - return self.call( - kind="extract-rows", chunk=chunk, - row_schema=RowSchema( - name=schema.name, - description=schema.description, - fields=[ - Field( - name=f.name, type=str(f.type), size=f.size, - primary=f.primary, description=f.description, - ) - for f in schema.fields - ] - ), + return self.request( + id="extract-rows", + variables={ + "chunk": chunk, + "row-schema": { + "name": schema.name, + "description": schema.description, + "fields": [ + { + "name": f.name, "type": str(f.type), + "size": f.size, "primary": f.primary, + "description": f.description, + } + for f in schema.fields + ] + } + }, timeout=timeout - ).rows + ) def request_kg_prompt(self, query, kg, timeout=300): - return self.call( - kind="kg-prompt", - query=query, - kg=[ - Fact(s=v[0], p=v[1], o=v[2]) - for v in kg - ], + return self.request( + id="kg-prompt", + variables={ + "query": query, + "knowledge": [ + { "s": v[0], "p": v[1], "o": v[2] } + for v in kg + ] + }, timeout=timeout - ).answer + ) def request_document_prompt(self, query, documents, timeout=300): - return self.call( - kind="document-prompt", - query=query, - documents=documents, + return self.request( + id="document-prompt", + variables={ + "query": query, + "documents": documents, + }, timeout=timeout - ).answer + ) + diff --git a/trustgraph-base/trustgraph/clients/triples_query_client.py b/trustgraph-base/trustgraph/clients/triples_query_client.py index 14b75151..fc1e4b26 100644 --- a/trustgraph-base/trustgraph/clients/triples_query_client.py +++ b/trustgraph-base/trustgraph/clients/triples_query_client.py @@ -48,11 +48,18 @@ class TriplesQueryClient(BaseClient): return Value(value=ent, is_uri=False) - def request(self, s, p, o, limit=10, timeout=60): + def request( + self, + s, p, o, + user="trustgraph", collection="default", + limit=10, timeout=120, + ): return self.call( s=self.create_value(s), p=self.create_value(p), o=self.create_value(o), + user=user, + collection=collection, limit=limit, timeout=timeout, ).triples diff --git a/trustgraph-base/trustgraph/knowledge/__init__.py b/trustgraph-base/trustgraph/knowledge/__init__.py new file mode 100644 index 00000000..0ab6b5db --- /dev/null +++ b/trustgraph-base/trustgraph/knowledge/__init__.py @@ -0,0 +1,6 @@ + +from . identifier import * +from . publication import * +from . document import * +from . organization import * + diff --git a/trustgraph-base/trustgraph/knowledge/defs.py b/trustgraph-base/trustgraph/knowledge/defs.py new file mode 100644 index 00000000..b95863c6 --- /dev/null +++ b/trustgraph-base/trustgraph/knowledge/defs.py @@ -0,0 +1,25 @@ + +IS_A = 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' +LABEL = 'http://www.w3.org/2000/01/rdf-schema#label' + +DIGITAL_DOCUMENT = 'https://schema.org/DigitalDocument' +PUBLICATION_EVENT = 'https://schema.org/PublicationEvent' +ORGANIZATION = 'https://schema.org/Organization' + +NAME = 'https://schema.org/name' +DESCRIPTION = 'https://schema.org/description' +COPYRIGHT_NOTICE = 'https://schema.org/copyrightNotice' +COPYRIGHT_HOLDER = 'https://schema.org/copyrightHolder' +COPYRIGHT_YEAR = 'https://schema.org/copyrightYear' +LICENSE = 'https://schema.org/license' +PUBLICATION = 'https://schema.org/publication' +START_DATE = 'https://schema.org/startDate' +END_DATE = 'https://schema.org/endDate' +PUBLISHED_BY = 'https://schema.org/publishedBy' +DATE_PUBLISHED = 'https://schema.org/datePublished' +PUBLICATION = 'https://schema.org/publication' +DATE_PUBLISHED = 'https://schema.org/datePublished' +URL = 'https://schema.org/url' +IDENTIFIER = 'https://schema.org/identifier' +KEYWORD = 'https://schema.org/keywords' + diff --git a/trustgraph-base/trustgraph/knowledge/document.py b/trustgraph-base/trustgraph/knowledge/document.py new file mode 100644 index 00000000..dc2f43e3 --- /dev/null +++ b/trustgraph-base/trustgraph/knowledge/document.py @@ -0,0 +1,120 @@ + +from . defs import * +from .. schema import Triple, Value + +class DigitalDocument: + + def __init__( + self, id, name=None, description=None, copyright_notice=None, + copyright_holder=None, copyright_year=None, license=None, + identifier=None, + publication=None, url=None, keywords=[] + ): + + self.id = id + self.name = name + self.description = description + self.copyright_notice = copyright_notice + self.copyright_holder = copyright_holder + self.copyright_year = copyright_year + self.license = license + self.publication = publication + self.url = url + self.identifier = identifier + self.keywords = keywords + + def emit(self, emit): + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=IS_A, is_uri=True), + o=Value(value=DIGITAL_DOCUMENT, is_uri=True) + )) + + if self.name: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=LABEL, is_uri=True), + o=Value(value=self.name, is_uri=False) + )) + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=NAME, is_uri=True), + o=Value(value=self.name, is_uri=False) + )) + + if self.identifier: + + emit(Triple( + s=Value(value=id, is_uri=True), + p=Value(value=IDENTIFIER, is_uri=True), + o=Value(value=self.identifier, is_uri=False) + )) + + if self.description: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=DESCRIPTION, is_uri=True), + o=Value(value=self.description, is_uri=False) + )) + + if self.copyright_notice: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=COPYRIGHT_NOTICE, is_uri=True), + o=Value(value=self.copyright_notice, is_uri=False) + )) + + if self.copyright_holder: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=COPYRIGHT_HOLDER, is_uri=True), + o=Value(value=self.copyright_holder, is_uri=False) + )) + + if self.copyright_year: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=COPYRIGHT_YEAR, is_uri=True), + o=Value(value=self.copyright_year, is_uri=False) + )) + + if self.license: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=LICENSE, is_uri=True), + o=Value(value=self.license, is_uri=False) + )) + + if self.keywords: + for k in self.keywords: + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=KEYWORD, is_uri=True), + o=Value(value=k, is_uri=False) + )) + + if self.publication: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=PUBLICATION, is_uri=True), + o=Value(value=self.publication.id, is_uri=True) + )) + + self.publication.emit(emit) + + if self.url: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=URL, is_uri=True), + o=Value(value=self.url, is_uri=True) + )) diff --git a/trustgraph-base/trustgraph/knowledge/identifier.py b/trustgraph-base/trustgraph/knowledge/identifier.py new file mode 100644 index 00000000..e0052fce --- /dev/null +++ b/trustgraph-base/trustgraph/knowledge/identifier.py @@ -0,0 +1,23 @@ + +import uuid +import hashlib + +def hash(data): + + if isinstance(data, str): + data = data.encode("utf-8") + + # Create a SHA256 hash from the data + id = hashlib.sha256(data).hexdigest() + + # Convert into a UUID, 64-byte hash becomes 32-byte UUID + id = str(uuid.UUID(id[::2])) + + return id + +def to_uri(pref, id): + return f"https://trustgraph.ai/{pref}/{id}" + +PREF_PUBEV = "pubev" +PREF_ORG = "org" +PREF_DOC = "doc" diff --git a/trustgraph-base/trustgraph/knowledge/organization.py b/trustgraph-base/trustgraph/knowledge/organization.py new file mode 100644 index 00000000..1129dd6c --- /dev/null +++ b/trustgraph-base/trustgraph/knowledge/organization.py @@ -0,0 +1,40 @@ + +from . defs import * +from .. schema import Triple, Value + +class Organization: + def __init__(self, id, name=None, description=None): + self.id = id + self.name = name + self.description = description + + def emit(self, emit): + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=IS_A, is_uri=True), + o=Value(value=ORGANIZATION, is_uri=True) + )) + + if self.name: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=LABEL, is_uri=True), + o=Value(value=self.name, is_uri=False) + )) + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=NAME, is_uri=True), + o=Value(value=self.name, is_uri=False) + )) + + if self.description: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=DESCRIPTION, is_uri=True), + o=Value(value=self.description, is_uri=False) + )) + diff --git a/trustgraph-base/trustgraph/knowledge/publication.py b/trustgraph-base/trustgraph/knowledge/publication.py new file mode 100644 index 00000000..3c9d41c8 --- /dev/null +++ b/trustgraph-base/trustgraph/knowledge/publication.py @@ -0,0 +1,69 @@ + +from . defs import * +from .. schema import Triple, Value + +class PublicationEvent: + def __init__( + self, id, organization=None, name=None, description=None, + start_date=None, end_date=None, + ): + self.id = id + self.organization = organization + self.name = name + self.description = description + self.start_date = start_date + self.end_date = end_date + + def emit(self, emit): + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=IS_A, is_uri=True), + o=Value(value=PUBLICATION_EVENT, is_uri=True))) + + if self.name: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=LABEL, is_uri=True), + o=Value(value=self.name, is_uri=False) + )) + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=NAME, is_uri=True), + o=Value(value=self.name, is_uri=False) + )) + + if self.description: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=DESCRIPTION, is_uri=True), + o=Value(value=self.description, is_uri=False) + )) + + if self.organization: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=PUBLISHED_BY, is_uri=True), + o=Value(value=self.organization.id, is_uri=True) + )) + + self.organization.emit(emit) + + if self.start_date: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=START_DATE, is_uri=True), + o=Value(value=self.start_date, is_uri=False) + )) + + if self.end_date: + + emit(Triple( + s=Value(value=self.id, is_uri=True), + p=Value(value=END_DATE, is_uri=True), + o=Value(value=self.end_date, is_uri=False))) diff --git a/trustgraph-base/trustgraph/rdf.py b/trustgraph-base/trustgraph/rdf.py index b65d9c29..ef1da183 100644 --- a/trustgraph-base/trustgraph/rdf.py +++ b/trustgraph-base/trustgraph/rdf.py @@ -1,6 +1,7 @@ RDF_LABEL = "http://www.w3.org/2000/01/rdf-schema#label" DEFINITION = "http://www.w3.org/2004/02/skos/core#definition" +SUBJECT_OF = "https://schema.org/subjectOf" TRUSTGRAPH_ENTITIES = "http://trustgraph.ai/e/" diff --git a/trustgraph-base/trustgraph/schema/__init__.py b/trustgraph-base/trustgraph/schema/__init__.py index 0cd5a370..3196691b 100644 --- a/trustgraph-base/trustgraph/schema/__init__.py +++ b/trustgraph-base/trustgraph/schema/__init__.py @@ -7,6 +7,6 @@ from . object import * from . topic import * from . graph import * from . retrieval import * - - +from . metadata import * +from . agent import * diff --git a/trustgraph-base/trustgraph/schema/agent.py b/trustgraph-base/trustgraph/schema/agent.py new file mode 100644 index 00000000..9bcdde51 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/agent.py @@ -0,0 +1,37 @@ + +from pulsar.schema import Record, String, Array, Map + +from . topic import topic +from . types import Error + +############################################################################ + +# Prompt services, abstract the prompt generation + +class AgentStep(Record): + thought = String() + action = String() + arguments = Map(String()) + observation = String() + +class AgentRequest(Record): + question = String() + plan = String() + state = String() + history = Array(AgentStep()) + +class AgentResponse(Record): + answer = String() + error = Error() + thought = String() + observation = String() + +agent_request_queue = topic( + 'agent', kind='non-persistent', namespace='request' +) +agent_response_queue = topic( + 'agent', kind='non-persistent', namespace='response' +) + +############################################################################ + diff --git a/trustgraph-base/trustgraph/schema/documents.py b/trustgraph-base/trustgraph/schema/documents.py index d80ff38f..59aba287 100644 --- a/trustgraph-base/trustgraph/schema/documents.py +++ b/trustgraph-base/trustgraph/schema/documents.py @@ -2,17 +2,13 @@ from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double from . topic import topic from . types import Error - -class Source(Record): - source = String() - id = String() - title = String() +from . metadata import Metadata ############################################################################ # PDF docs etc. class Document(Record): - source = Source() + metadata = Metadata() data = Bytes() document_ingest_queue = topic('document-load') @@ -22,7 +18,7 @@ document_ingest_queue = topic('document-load') # Text documents / text from PDF class TextDocument(Record): - source = Source() + metadata = Metadata() text = Bytes() text_ingest_queue = topic('text-document-load') @@ -32,7 +28,7 @@ text_ingest_queue = topic('text-document-load') # Chunks of text class Chunk(Record): - source = Source() + metadata = Metadata() chunk = Bytes() chunk_ingest_queue = topic('chunk-load') @@ -42,7 +38,7 @@ chunk_ingest_queue = topic('chunk-load') # Chunk embeddings are an embeddings associated with a text chunk class ChunkEmbeddings(Record): - source = Source() + metadata = Metadata() vectors = Array(Array(Double())) chunk = Bytes() diff --git a/trustgraph-base/trustgraph/schema/graph.py b/trustgraph-base/trustgraph/schema/graph.py index 234a0bed..2d108a30 100644 --- a/trustgraph-base/trustgraph/schema/graph.py +++ b/trustgraph-base/trustgraph/schema/graph.py @@ -1,16 +1,16 @@ from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double -from . documents import Source -from . types import Error, Value +from . types import Error, Value, Triple from . topic import topic +from . metadata import Metadata ############################################################################ # Graph embeddings are embeddings associated with a graph entity class GraphEmbeddings(Record): - source = Source() + metadata = Metadata() vectors = Array(Array(Double())) entity = Value() @@ -23,6 +23,8 @@ graph_embeddings_store_queue = topic('graph-embeddings-store') class GraphEmbeddingsRequest(Record): vectors = Array(Array(Double())) limit = Integer() + user = String() + collection = String() class GraphEmbeddingsResponse(Record): error = Error() @@ -39,11 +41,9 @@ graph_embeddings_response_queue = topic( # Graph triples -class Triple(Record): - source = Source() - s = Value() - p = Value() - o = Value() +class Triples(Record): + metadata = Metadata() + triples = Array(Triple()) triples_store_queue = topic('triples-store') @@ -56,6 +56,8 @@ class TriplesQueryRequest(Record): p = Value() o = Value() limit = Integer() + user = String() + collection = String() class TriplesQueryResponse(Record): error = Error() diff --git a/trustgraph-base/trustgraph/schema/metadata.py b/trustgraph-base/trustgraph/schema/metadata.py new file mode 100644 index 00000000..5922db26 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/metadata.py @@ -0,0 +1,16 @@ + +from pulsar.schema import Record, String, Array +from . types import Triple + +class Metadata(Record): + + # Source identifier + id = String() + + # Subgraph + metadata = Array(Triple()) + + # Collection management + user = String() + collection = String() + diff --git a/trustgraph-base/trustgraph/schema/models.py b/trustgraph-base/trustgraph/schema/models.py index 2196a3d2..70cb2c8f 100644 --- a/trustgraph-base/trustgraph/schema/models.py +++ b/trustgraph-base/trustgraph/schema/models.py @@ -9,6 +9,7 @@ from . types import Error # LLM text completion class TextCompletionRequest(Record): + system = String() prompt = String() class TextCompletionResponse(Record): diff --git a/trustgraph-base/trustgraph/schema/object.py b/trustgraph-base/trustgraph/schema/object.py index 3377e2df..60c2bdc3 100644 --- a/trustgraph-base/trustgraph/schema/object.py +++ b/trustgraph-base/trustgraph/schema/object.py @@ -2,7 +2,7 @@ from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array from pulsar.schema import Double, Map -from . documents import Source +from . metadata import Metadata from . types import Value, RowSchema from . topic import topic @@ -12,7 +12,7 @@ from . topic import topic # object class ObjectEmbeddings(Record): - source = Source() + metadata = Metadata() vectors = Array(Array(Double())) name = String() key_name = String() @@ -25,7 +25,7 @@ object_embeddings_store_queue = topic('object-embeddings-store') # Stores rows of information class Rows(Record): - source = Source() + metadata = Metadata() row_schema = RowSchema() rows = Array(Map(String())) diff --git a/trustgraph-base/trustgraph/schema/prompt.py b/trustgraph-base/trustgraph/schema/prompt.py index c7dbfd43..9bcdf117 100644 --- a/trustgraph-base/trustgraph/schema/prompt.py +++ b/trustgraph-base/trustgraph/schema/prompt.py @@ -39,20 +39,21 @@ class Fact(Record): # schema, chunk -> rows class PromptRequest(Record): - kind = String() - chunk = String() - query = String() - kg = Array(Fact()) - documents = Array(Bytes()) - row_schema = RowSchema() + id = String() + + # JSON encoded values + terms = Map(String()) class PromptResponse(Record): + + # Error case error = Error() - answer = String() - definitions = Array(Definition()) - topics = Array(Topic()) - relationships = Array(Relationship()) - rows = Array(Map(String())) + + # Just plain text + text = String() + + # JSON encoded + object = String() prompt_request_queue = topic( 'prompt', kind='non-persistent', namespace='request' diff --git a/trustgraph-base/trustgraph/schema/retrieval.py b/trustgraph-base/trustgraph/schema/retrieval.py index fa0288fc..ad860c3c 100644 --- a/trustgraph-base/trustgraph/schema/retrieval.py +++ b/trustgraph-base/trustgraph/schema/retrieval.py @@ -9,6 +9,8 @@ from . types import Error, Value class GraphRagQuery(Record): query = String() + user = String() + collection = String() class GraphRagResponse(Record): error = Error() @@ -27,6 +29,8 @@ graph_rag_response_queue = topic( class DocumentRagQuery(Record): query = String() + user = String() + collection = String() class DocumentRagResponse(Record): error = Error() diff --git a/trustgraph-base/trustgraph/schema/types.py b/trustgraph-base/trustgraph/schema/types.py index 4cad70ac..b75a0884 100644 --- a/trustgraph-base/trustgraph/schema/types.py +++ b/trustgraph-base/trustgraph/schema/types.py @@ -10,6 +10,11 @@ class Value(Record): is_uri = Boolean() type = String() +class Triple(Record): + s = Value() + p = Value() + o = Value() + class Field(Record): name = String() # int, string, long, bool, float, double diff --git a/trustgraph-bedrock/setup.py b/trustgraph-bedrock/setup.py index 787cbad4..80cee09c 100644 --- a/trustgraph-bedrock/setup.py +++ b/trustgraph-bedrock/setup.py @@ -34,7 +34,7 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base<0.12", + "trustgraph-base>=0.15,<0.16", "pulsar-client", "prometheus-client", "boto3", diff --git a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py index 0d050261..a9c05cc8 100755 --- a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py +++ b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py @@ -7,6 +7,7 @@ Input is prompt, output is response. Mistral is default. import boto3 import json from prometheus_client import Histogram +import os from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue @@ -21,10 +22,11 @@ default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue default_subscriber = module default_model = 'mistral.mistral-large-2407-v1:0' -default_region = 'us-west-2' default_temperature = 0.0 default_max_output = 2048 - +default_aws_id_key = os.getenv("AWS_ID_KEY", None) +default_aws_secret = os.getenv("AWS_SECRET", None) +default_aws_region = os.getenv("AWS_REGION", 'us-west-2') class Processor(ConsumerProducer): @@ -34,12 +36,21 @@ class Processor(ConsumerProducer): output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) model = params.get("model", default_model) - aws_id = params.get("aws_id_key") - aws_secret = params.get("aws_secret") - aws_region = params.get("aws_region", default_region) + aws_id_key = params.get("aws_id_key", default_aws_id_key) + aws_secret = params.get("aws_secret", default_aws_secret) + aws_region = params.get("aws_region", default_aws_region) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) + if aws_id_key is None: + raise RuntimeError("AWS ID not specified") + + if aws_secret is None: + raise RuntimeError("AWS secret not specified") + + if aws_region is None: + raise RuntimeError("AWS region not specified") + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -71,7 +82,7 @@ class Processor(ConsumerProducer): self.max_output = max_output self.session = boto3.Session( - aws_access_key_id=aws_id, + aws_access_key_id=aws_id_key, aws_secret_access_key=aws_secret, region_name=aws_region ) @@ -90,7 +101,7 @@ class Processor(ConsumerProducer): print(f"Handling prompt {id}...", flush=True) - prompt = v.prompt + prompt = v.system + "\n\n" + v.prompt try: @@ -289,17 +300,20 @@ class Processor(ConsumerProducer): parser.add_argument( '-z', '--aws-id-key', + default=default_aws_id_key, help=f'AWS ID Key' ) parser.add_argument( '-k', '--aws-secret', + default=default_aws_secret, help=f'AWS Secret Key' ) parser.add_argument( '-r', '--aws-region', - help=f'AWS Region (default: us-west-2)' + default=default_aws_region, + help=f'AWS Region' ) parser.add_argument( @@ -320,4 +334,3 @@ def run(): Processor.start(module, __doc__) - diff --git a/trustgraph-cli/scripts/tg-graph-show b/trustgraph-cli/scripts/tg-graph-show index a737c97b..c09266fb 100755 --- a/trustgraph-cli/scripts/tg-graph-show +++ b/trustgraph-cli/scripts/tg-graph-show @@ -9,12 +9,17 @@ import os from trustgraph.clients.triples_query_client import TriplesQueryClient default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_user = 'trustgraph' +default_collection = 'default' -def show_graph(pulsar): +def show_graph(pulsar, user, collection): tq = TriplesQueryClient(pulsar_host=pulsar) - rows = tq.request(None, None, None, limit=10_000_000) + rows = tq.request( + user=user, collection=collection, + s=None, p=None, o=None, limit=10_000_000 + ) for row in rows: print(row.s.value, row.p.value, row.o.value) @@ -22,7 +27,7 @@ def show_graph(pulsar): def main(): parser = argparse.ArgumentParser( - prog='graph-show', + prog='tg-graph-show', description=__doc__, ) @@ -32,11 +37,26 @@ def main(): help=f'Pulsar host (default: {default_pulsar_host})', ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + args = parser.parse_args() try: - show_graph(args.pulsar_host) + show_graph( + pulsar=args.pulsar_host, user=args.user, + collection=args.collection, + ) except Exception as e: diff --git a/trustgraph-cli/scripts/tg-graph-to-turtle b/trustgraph-cli/scripts/tg-graph-to-turtle index bff03fc6..1d75478e 100755 --- a/trustgraph-cli/scripts/tg-graph-to-turtle +++ b/trustgraph-cli/scripts/tg-graph-to-turtle @@ -1,7 +1,8 @@ #!/usr/bin/env python3 """ -Connects to the graph query service and dumps all graph edges. +Connects to the graph query service and dumps all graph edges in Turtle +format. """ import argparse @@ -50,7 +51,7 @@ def show_graph(pulsar): def main(): parser = argparse.ArgumentParser( - prog='graph-show', + prog='tg-graph-to-turtle', description=__doc__, ) diff --git a/trustgraph-cli/scripts/tg-init-pulsar b/trustgraph-cli/scripts/tg-init-pulsar index 0113a7f0..07fd31eb 100755 --- a/trustgraph-cli/scripts/tg-init-pulsar +++ b/trustgraph-cli/scripts/tg-init-pulsar @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """ -Initialises Pulsar with Trustgraph tenant / namespaces & policy +Initialises Pulsar with Trustgraph tenant / namespaces & policy. """ import requests diff --git a/trustgraph-cli/scripts/tg-invoke-agent b/trustgraph-cli/scripts/tg-invoke-agent new file mode 100755 index 00000000..3f05071c --- /dev/null +++ b/trustgraph-cli/scripts/tg-invoke-agent @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 + +""" +Uses the GraphRAG service to answer a query +""" + +import argparse +import os +import textwrap + +from trustgraph.clients.agent_client import AgentClient + +default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_user = 'trustgraph' +default_collection = 'default' + +def wrap(text, width=75): + if text is None: text = "n/a" + out = textwrap.wrap( + text, width=width + ) + return "\n".join(out) + +def output(text, prefix="> ", width=78): + out = textwrap.indent( + text, prefix=prefix + ) + print(out) + +def query( + pulsar_host, query, user, collection, + plan=None, state=None, verbose=False +): + + am = AgentClient(pulsar_host=pulsar_host) + + if verbose: + output(wrap(query), "\U00002753 ") + print() + + def think(x): + if verbose: + output(wrap(x), "\U0001f914 ") + print() + + def observe(x): + if verbose: + output(wrap(x), "\U0001f4a1 ") + print() + + resp = am.request( + question=query, think=think, observe=observe, + ) + + print(resp) + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-agent', + description=__doc__, + ) + + parser.add_argument( + '-p', '--pulsar-host', + default=default_pulsar_host, + help=f'Pulsar host (default: {default_pulsar_host})', + ) + + parser.add_argument( + '-q', '--query', + required=True, + help=f'Query to execute', + ) + + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + + parser.add_argument( + '-l', '--plan', + help=f'Agent plan (default: unspecified)' + ) + + parser.add_argument( + '-s', '--state', + help=f'Agent initial state (default: unspecified)' + ) + + parser.add_argument( + '-v', '--verbose', + action="store_true", + help=f'Output thinking/observations' + ) + + args = parser.parse_args() + + try: + + query( + pulsar_host=args.pulsar_host, + query=args.query, + user=args.user, + collection=args.collection, + plan=args.plan, + state=args.state, + verbose=args.verbose, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +main() + diff --git a/trustgraph-cli/scripts/tg-invoke-llm b/trustgraph-cli/scripts/tg-invoke-llm new file mode 100755 index 00000000..d7289b5f --- /dev/null +++ b/trustgraph-cli/scripts/tg-invoke-llm @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +""" +Invokes the text completion service by specifying an LLM system prompt +and user prompt. Both arguments are required. +""" + +import argparse +import os +import json +from trustgraph.clients.llm_client import LlmClient + +default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') + +def query(pulsar_host, system, prompt): + + cli = LlmClient(pulsar_host=pulsar_host) + + resp = cli.request(system=system, prompt=prompt) + + print(resp) + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-llm', + description=__doc__, + ) + + parser.add_argument( + '-p', '--pulsar-host', + default=default_pulsar_host, + help=f'Pulsar host (default: {default_pulsar_host})', + ) + + parser.add_argument( + 'system', + nargs=1, + help='LLM system prompt e.g. You are a helpful assistant', + ) + + parser.add_argument( + 'prompt', + nargs=1, + help='LLM prompt e.g. What is 2 + 2?', + ) + + args = parser.parse_args() + + try: + + query( + pulsar_host=args.pulsar_host, + system=args.system[0], + prompt=args.prompt[0], + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +main() + diff --git a/trustgraph-cli/scripts/tg-invoke-prompt b/trustgraph-cli/scripts/tg-invoke-prompt new file mode 100755 index 00000000..19f30912 --- /dev/null +++ b/trustgraph-cli/scripts/tg-invoke-prompt @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 + +""" +Invokes the LLM prompt service by specifying the prompt template to use +and values for the variables in the prompt template. The +prompt template is identified by its template identifier e.g. +question, extract-definitions. Template variable values are specified +using key=value arguments on the command line, and these replace +{{key}} placeholders in the template. +""" + +import argparse +import os +import json +from trustgraph.clients.prompt_client import PromptClient + +default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') + +def query(pulsar_host, template_id, variables): + + cli = PromptClient(pulsar_host=pulsar_host) + + resp = cli.request(id=template_id, variables=variables) + + if isinstance(resp, str): + print(resp) + else: + print(json.dumps(resp, indent=4)) + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-prompt', + description=__doc__, + ) + + parser.add_argument( + '-p', '--pulsar-host', + default=default_pulsar_host, + help=f'Pulsar host (default: {default_pulsar_host})', + ) + + parser.add_argument( + 'id', + metavar='template-id', + nargs=1, + help=f'Prompt identifier e.g. question, extract-definitions', + ) + + parser.add_argument( + 'variable', + nargs='*', + metavar="variable=value", + help='''Prompt template terms of the form variable=value, can be +specified multiple times''', + ) + + args = parser.parse_args() + + variables = {} + + for variable in args.variable: + + toks = variable.split("=", 1) + if len(toks) != 2: + raise RuntimeError(f"Malformed variable: {variable}") + + variables[toks[0]] = toks[1] + + try: + + query( + pulsar_host=args.pulsar_host, + template_id=args.id[0], + variables=variables, + ) + + except Exception as e: + + print("Exception:", e, flush=True) + +main() + diff --git a/trustgraph-cli/scripts/tg-load-pdf b/trustgraph-cli/scripts/tg-load-pdf index 5d54da93..18ac57cb 100755 --- a/trustgraph-cli/scripts/tg-load-pdf +++ b/trustgraph-cli/scripts/tg-load-pdf @@ -6,14 +6,23 @@ Loads a PDF document into TrustGraph processing. import pulsar from pulsar.schema import JsonSchema -from trustgraph.schema import Document, Source, document_ingest_queue import base64 import hashlib import argparse import os import time +import uuid +from trustgraph.schema import Document, document_ingest_queue +from trustgraph.schema import Metadata from trustgraph.log_level import LogLevel +from trustgraph.knowledge import hash, to_uri +from trustgraph.knowledge import PREF_PUBEV, PREF_DOC, PREF_ORG +from trustgraph.knowledge import Organization, PublicationEvent +from trustgraph.knowledge import DigitalDocument + +default_user = 'trustgraph' +default_collection = 'default' class Loader: @@ -21,7 +30,10 @@ class Loader: self, pulsar_host, output_queue, + user, + collection, log_level, + metadata, ): self.client = pulsar.Client( @@ -35,6 +47,10 @@ class Loader: chunking_enabled=True, ) + self.user = user + self.collection = collection + self.metadata = metadata + def load(self, files): for file in files: @@ -47,13 +63,25 @@ class Loader: path = file data = open(path, "rb").read() - id = hashlib.sha256(path.encode("utf-8")).hexdigest()[0:8] + # Create a SHA256 hash from the data + id = hash(data) + + id = to_uri(PREF_DOC, id) + + triples = [] + + def emit(t): + triples.append(t) + + self.metadata.id = id + self.metadata.emit(emit) r = Document( - source=Source( - source=path, - title=path, + metadata=Metadata( id=id, + metadata=triples, + user=self.user, + collection=self.collection, ), data=base64.b64encode(data), ) @@ -71,7 +99,7 @@ class Loader: def main(): parser = argparse.ArgumentParser( - prog='loader', + prog='tg-load-pdf', description=__doc__, ) @@ -90,6 +118,66 @@ def main(): help=f'Output queue (default: {default_output_queue})' ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + + parser.add_argument( + '--name', help=f'Document name' + ) + + parser.add_argument( + '--description', help=f'Document description' + ) + + parser.add_argument( + '--copyright-notice', help=f'Copyright notice' + ) + + parser.add_argument( + '--copyright-holder', help=f'Copyright holder' + ) + + parser.add_argument( + '--copyright-year', help=f'Copyright year' + ) + + parser.add_argument( + '--license', help=f'Copyright license' + ) + + parser.add_argument( + '--publication-organization', help=f'Publication organization' + ) + + parser.add_argument( + '--publication-description', help=f'Publication description' + ) + + parser.add_argument( + '--publication-date', help=f'Publication date' + ) + + parser.add_argument( + '--url', help=f'Document URL' + ) + + parser.add_argument( + '--keyword', nargs='+', help=f'Keyword' + ) + + parser.add_argument( + '--identifier', '--id', help=f'Document ID' + ) + parser.add_argument( '-l', '--log-level', type=LogLevel, @@ -109,10 +197,38 @@ def main(): try: + document = DigitalDocument( + id, + name=args.name, + description=args.description, + copyright_notice=args.copyright_notice, + copyright_holder=args.copyright_holder, + copyright_year=args.copyright_year, + license=args.license, + url=args.url, + keywords=args.keyword, + ) + + if args.publication_organization: + org = Organization( + id=to_uri(PREF_ORG, hash(args.publication_organization)), + name=args.publication_organization, + ) + document.publication = PublicationEvent( + id = to_uri(PREF_PUBEV, str(uuid.uuid4())), + organization=org, + description=args.publication_description, + start_date=args.publication_date, + end_date=args.publication_date, + ) + p = Loader( pulsar_host=args.pulsar_host, output_queue=args.output_queue, + user=args.user, + collection=args.collection, log_level=args.log_level, + metadata=document, ) p.load(args.files) diff --git a/trustgraph-cli/scripts/tg-load-text b/trustgraph-cli/scripts/tg-load-text index 8137006c..88dc8e17 100755 --- a/trustgraph-cli/scripts/tg-load-text +++ b/trustgraph-cli/scripts/tg-load-text @@ -6,14 +6,23 @@ Loads a text document into TrustGraph processing. import pulsar from pulsar.schema import JsonSchema -from trustgraph.schema import TextDocument, Source, text_ingest_queue import base64 import hashlib import argparse import os import time +import uuid +from trustgraph.schema import TextDocument, text_ingest_queue +from trustgraph.schema import Metadata from trustgraph.log_level import LogLevel +from trustgraph.knowledge import hash, to_uri +from trustgraph.knowledge import PREF_PUBEV, PREF_DOC, PREF_ORG +from trustgraph.knowledge import Organization, PublicationEvent +from trustgraph.knowledge import DigitalDocument + +default_user = 'trustgraph' +default_collection = 'default' class Loader: @@ -21,7 +30,10 @@ class Loader: self, pulsar_host, output_queue, + user, + collection, log_level, + metadata, ): self.client = pulsar.Client( @@ -35,6 +47,10 @@ class Loader: chunking_enabled=True, ) + self.user = user + self.collection = collection + self.metadata = metadata + def load(self, files): for file in files: @@ -47,13 +63,25 @@ class Loader: path = file data = open(path, "rb").read() - id = hashlib.sha256(path.encode("utf-8")).hexdigest()[0:8] + # Create a SHA256 hash from the data + id = hash(data) + + id = to_uri(PREF_DOC, id) + + triples = [] + + def emit(t): + triples.append(t) + + self.metadata.id = id + self.metadata.emit(emit) r = TextDocument( - source=Source( - source=path, - title=path, + metadata=Metadata( id=id, + metadata=triples, + user=self.user, + collection=self.collection, ), text=data, ) @@ -71,7 +99,7 @@ class Loader: def main(): parser = argparse.ArgumentParser( - prog='loader', + prog='tg-load-text', description=__doc__, ) @@ -90,6 +118,66 @@ def main(): help=f'Output queue (default: {default_output_queue})' ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + + parser.add_argument( + '--name', help=f'Document name' + ) + + parser.add_argument( + '--description', help=f'Document description' + ) + + parser.add_argument( + '--copyright-notice', help=f'Copyright notice' + ) + + parser.add_argument( + '--copyright-holder', help=f'Copyright holder' + ) + + parser.add_argument( + '--copyright-year', help=f'Copyright year' + ) + + parser.add_argument( + '--license', help=f'Copyright license' + ) + + parser.add_argument( + '--publication-organization', help=f'Publication organization' + ) + + parser.add_argument( + '--publication-description', help=f'Publication description' + ) + + parser.add_argument( + '--publication-date', help=f'Publication date' + ) + + parser.add_argument( + '--url', help=f'Document URL' + ) + + parser.add_argument( + '--keyword', nargs='+', help=f'Keyword' + ) + + parser.add_argument( + '--identifier', '--id', help=f'Document ID' + ) + parser.add_argument( '-l', '--log-level', type=LogLevel, @@ -109,10 +197,38 @@ def main(): try: + document = DigitalDocument( + id, + name=args.name, + description=args.description, + copyright_notice=args.copyright_notice, + copyright_holder=args.copyright_holder, + copyright_year=args.copyright_year, + license=args.license, + url=args.url, + keywords=args.keyword, + ) + + if args.publication_organization: + org = Organization( + id=to_uri(PREF_ORG, hash(args.publication_organization)), + name=args.publication_organization, + ) + document.publication = PublicationEvent( + id = to_uri(PREF_PUBEV, str(uuid.uuid4())), + organization=org, + description=args.publication_description, + start_date=args.publication_date, + end_date=args.publication_date, + ) + p = Loader( pulsar_host=args.pulsar_host, output_queue=args.output_queue, + user=args.user, + collection=args.collection, log_level=args.log_level, + metadata=document, ) p.load(args.files) diff --git a/trustgraph-cli/scripts/tg-load-turtle b/trustgraph-cli/scripts/tg-load-turtle new file mode 100755 index 00000000..7c258fcc --- /dev/null +++ b/trustgraph-cli/scripts/tg-load-turtle @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 + +""" +Loads Graph embeddings into TrustGraph processing. +""" + +import pulsar +from pulsar.schema import JsonSchema +from trustgraph.schema import Triples, Triple, Value, Metadata +from trustgraph.schema import triples_store_queue +import argparse +import os +import time +import pyarrow as pa +import rdflib + +from trustgraph.log_level import LogLevel + +default_user = 'trustgraph' +default_collection = 'default' +default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_output_queue = triples_store_queue + +class Loader: + + def __init__( + self, + pulsar_host, + output_queue, + log_level, + files, + user, + collection, + ): + + self.client = pulsar.Client( + pulsar_host, + logger=pulsar.ConsoleLogger(log_level.to_pulsar()) + ) + + self.producer = self.client.create_producer( + topic=output_queue, + schema=JsonSchema(Triples), + chunking_enabled=True, + ) + + self.files = files + self.user = user + self.collection = collection + + def run(self): + + try: + + for file in self.files: + self.load_file(file) + + except Exception as e: + print(e, flush=True) + + def load_file(self, file): + + g = rdflib.Graph() + g.parse(file, format="turtle") + + for e in g: + s = Value(value=str(e[0]), is_uri=True) + p = Value(value=str(e[1]), is_uri=True) + if type(e[2]) == rdflib.term.URIRef: + o = Value(value=str(e[2]), is_uri=True) + else: + o = Value(value=str(e[2]), is_uri=False) + + r = Triples( + metadata=Metadata( + id=None, + metadata=[], + user=self.user, + collection=self.collection, + ), + triples=[ Triple(s=s, p=p, o=o) ] + ) + + self.producer.send(r) + + def __del__(self): + self.client.close() + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-load-turtle', + description=__doc__, + ) + + parser.add_argument( + '-p', '--pulsar-host', + default=default_pulsar_host, + help=f'Pulsar host (default: {default_pulsar_host})', + ) + + parser.add_argument( + '-o', '--output-queue', + default=default_output_queue, + help=f'Output queue (default: {default_output_queue})' + ) + + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + + parser.add_argument( + '-l', '--log-level', + type=LogLevel, + default=LogLevel.ERROR, + choices=list(LogLevel), + help=f'Output queue (default: info)' + ) + + parser.add_argument( + 'files', nargs='+', + help=f'Turtle files to load' + ) + + args = parser.parse_args() + + while True: + + try: + p = Loader( + pulsar_host=args.pulsar_host, + output_queue=args.output_queue, + log_level=args.log_level, + files=args.files, + user=args.user, + collection=args.collection, + ) + + p.run() + + print("File loaded.") + break + + except Exception as e: + + print("Exception:", e, flush=True) + print("Will retry...", flush=True) + + time.sleep(10) + +main() + diff --git a/trustgraph-cli/scripts/tg-query-document-rag b/trustgraph-cli/scripts/tg-query-document-rag index 948dcd2f..8d800629 100755 --- a/trustgraph-cli/scripts/tg-query-document-rag +++ b/trustgraph-cli/scripts/tg-query-document-rag @@ -9,17 +9,19 @@ import os from trustgraph.clients.document_rag_client import DocumentRagClient default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_user = 'trustgraph' +default_collection = 'default' -def query(pulsar, query): +def query(pulsar_host, query, user, collection): rag = DocumentRagClient(pulsar_host=pulsar) - resp = rag.request(query) + resp = rag.request(user=user, collection=collection, query=query) print(resp) def main(): parser = argparse.ArgumentParser( - prog='graph-show', + prog='tg-query-document-rag', description=__doc__, ) @@ -35,11 +37,28 @@ def main(): help=f'Query to execute', ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + args = parser.parse_args() try: - query(args.pulsar_host, args.query) + query( + pulsar_host=args.pulsar_host, + query=args.query, + user=args.user, + collection=args.collection, + ) except Exception as e: diff --git a/trustgraph-cli/scripts/tg-query-graph-rag b/trustgraph-cli/scripts/tg-query-graph-rag index 5250bf15..8a865eea 100755 --- a/trustgraph-cli/scripts/tg-query-graph-rag +++ b/trustgraph-cli/scripts/tg-query-graph-rag @@ -9,17 +9,19 @@ import os from trustgraph.clients.graph_rag_client import GraphRagClient default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_user = 'trustgraph' +default_collection = 'default' -def query(pulsar, query): +def query(pulsar_host, query, user, collection): - rag = GraphRagClient(pulsar_host=pulsar) - resp = rag.request(query) + rag = GraphRagClient(pulsar_host=pulsar_host) + resp = rag.request(user=user, collection=collection, query=query) print(resp) def main(): parser = argparse.ArgumentParser( - prog='graph-show', + prog='tg-graph-query-rag', description=__doc__, ) @@ -35,11 +37,28 @@ def main(): help=f'Query to execute', ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + args = parser.parse_args() try: - query(args.pulsar_host, args.query) + query( + pulsar_host=args.pulsar_host, + query=args.query, + user=args.user, + collection=args.collection, + ) except Exception as e: diff --git a/trustgraph-cli/setup.py b/trustgraph-cli/setup.py index 061234a6..651fdc27 100644 --- a/trustgraph-cli/setup.py +++ b/trustgraph-cli/setup.py @@ -34,7 +34,7 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base<0.12", + "trustgraph-base>=0.15,<0.16", "requests", "pulsar-client", "rdflib", @@ -46,9 +46,13 @@ setuptools.setup( "scripts/tg-init-pulsar-manager", "scripts/tg-load-pdf", "scripts/tg-load-text", + "scripts/tg-load-turtle", "scripts/tg-query-document-rag", "scripts/tg-query-graph-rag", "scripts/tg-init-pulsar", "scripts/tg-processor-state", + "scripts/tg-invoke-agent", + "scripts/tg-invoke-prompt", + "scripts/tg-invoke-llm", ] ) diff --git a/trustgraph-embeddings-hf/setup.py b/trustgraph-embeddings-hf/setup.py index 7f1aafa4..ad01667f 100644 --- a/trustgraph-embeddings-hf/setup.py +++ b/trustgraph-embeddings-hf/setup.py @@ -34,8 +34,8 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base<0.12", - "trustgraph-flow<0.12", + "trustgraph-base>=0.15,<0.16", + "trustgraph-flow>=0.15,<0.16", "torch", "urllib3", "transformers", diff --git a/trustgraph-flow/scripts/agent-manager-react b/trustgraph-flow/scripts/agent-manager-react new file mode 100644 index 00000000..b5e060c7 --- /dev/null +++ b/trustgraph-flow/scripts/agent-manager-react @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.agent.react import run + +run() + diff --git a/trustgraph-flow/scripts/text-completion-azure-openai b/trustgraph-flow/scripts/text-completion-azure-openai new file mode 100755 index 00000000..f989d4b7 --- /dev/null +++ b/trustgraph-flow/scripts/text-completion-azure-openai @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.model.text_completion.azure_openai import run + +run() + diff --git a/trustgraph-flow/scripts/text-completion-googleaistudio b/trustgraph-flow/scripts/text-completion-googleaistudio new file mode 100755 index 00000000..4d2b0784 --- /dev/null +++ b/trustgraph-flow/scripts/text-completion-googleaistudio @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.model.text_completion.googleaistudio import run + +run() + diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py index 1c86ed77..8b46b2d2 100644 --- a/trustgraph-flow/setup.py +++ b/trustgraph-flow/setup.py @@ -34,7 +34,7 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base<0.12", + "trustgraph-base>=0.15,<0.16", "urllib3", "rdflib", "pymilvus", @@ -55,8 +55,12 @@ setuptools.setup( "openai", "neo4j", "tiktoken", + "google-generativeai", + "ibis", + "jsonschema", ], scripts=[ + "scripts/agent-manager-react", "scripts/chunker-recursive", "scripts/chunker-token", "scripts/de-query-milvus", @@ -83,8 +87,10 @@ setuptools.setup( "scripts/rows-write-cassandra", "scripts/run-processing", "scripts/text-completion-azure", + "scripts/text-completion-azure-openai", "scripts/text-completion-claude", "scripts/text-completion-cohere", + "scripts/text-completion-googleaistudio", "scripts/text-completion-llamafile", "scripts/text-completion-ollama", "scripts/text-completion-openai", diff --git a/trustgraph-flow/trustgraph/agent/__init__.py b/trustgraph-flow/trustgraph/agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trustgraph-flow/trustgraph/agent/react/README.md b/trustgraph-flow/trustgraph/agent/react/README.md new file mode 100644 index 00000000..dd5cbea5 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/react/README.md @@ -0,0 +1,19 @@ + +agent-manager-react \ + -p pulsar://localhost:6650 \ + --tool-type \ + shuttle=knowledge-query:query \ + cats=knowledge-query:query \ + compute=text-completion:computation \ + --tool-description \ + shuttle="Query a knowledge base with information about the space shuttle. The query should be a simple natural language question" \ + cats="Query a knowledge base with information about Mark's cats. The query should be a simple natural language question" \ + compute="A computation engine which can answer questions about maths and computation" \ + --tool-argument \ + cats="query:string:The search query string" \ + shuttle="query:string:The search query string" \ + compute="computation:string:The computation to solve" + + + --context 'The space shuttle challenger final mission was 58-L' + diff --git a/trustgraph-flow/trustgraph/agent/react/__init__.py b/trustgraph-flow/trustgraph/agent/react/__init__.py new file mode 100644 index 00000000..ba844705 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/react/__init__.py @@ -0,0 +1,3 @@ + +from . service import * + diff --git a/trustgraph-flow/trustgraph/agent/react/__main__.py b/trustgraph-flow/trustgraph/agent/react/__main__.py new file mode 100755 index 00000000..e9136855 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/react/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . service import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/agent/react/agent_manager.py b/trustgraph-flow/trustgraph/agent/react/agent_manager.py new file mode 100644 index 00000000..5d071e30 --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/react/agent_manager.py @@ -0,0 +1,120 @@ + +import logging +import json + +from . types import Action, Final + +logger = logging.getLogger(__name__) + +class AgentManager: + + def __init__(self, context, tools, additional_context=None): + self.context = context + self.tools = tools + self.additional_context = additional_context + + def reason(self, question, history): + + tools = self.tools + + tool_names = ",".join([ + t for t in self.tools.keys() + ]) + + variables = { + "question": question, + "tools": [ + { + "name": tool.name, + "description": tool.description, + "arguments": [ + { + "name": arg.name, + "type": arg.type, + "description": arg.description + } + for arg in tool.arguments.values() + ] + } + for tool in self.tools.values() + ], + "context": self.additional_context, + "question": question, + "tool_names": tool_names, + "history": [ + { + "thought": h.thought, + "action": h.name, + "arguments": h.arguments, + "observation": h.observation, + } + for h in history + ] + } + + print(json.dumps(variables, indent=4), flush=True) + + logger.info(f"prompt: {variables}") + + obj = self.context.prompt.request( + "agent-react", + variables + ) + + print(json.dumps(obj, indent=4), flush=True) + + logger.info(f"response: {obj}") + + if obj.get("final-answer"): + + a = Final( + thought = obj.get("thought"), + final = obj.get("final-answer"), + ) + + return a + + else: + + a = Action( + thought = obj.get("thought"), + name = obj.get("action"), + arguments = obj.get("arguments"), + observation = "" + ) + + return a + + def react(self, question, history, think, observe): + + act = self.reason(question, history) + logger.info(f"act: {act}") + + if isinstance(act, Final): + + think(act.thought) + return act + + else: + + think(act.thought) + + if act.name in self.tools: + action = self.tools[act.name] + else: + raise RuntimeError(f"No action for {act.name}!") + + resp = action.implementation.invoke(**act.arguments) + + resp = resp.strip() + + logger.info(f"resp: {resp}") + + observe(resp) + + act.observation = resp + + logger.info(f"iter: {act}") + + return act + diff --git a/trustgraph-flow/trustgraph/agent/react/service.py b/trustgraph-flow/trustgraph/agent/react/service.py new file mode 100755 index 00000000..8799816b --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/react/service.py @@ -0,0 +1,410 @@ +""" +Simple agent infrastructure broadly implements the ReAct flow. +""" + +import json +import re +import sys + +from pulsar.schema import JsonSchema + +from ... base import ConsumerProducer +from ... schema import Error +from ... schema import AgentRequest, AgentResponse, AgentStep +from ... schema import agent_request_queue, agent_response_queue +from ... schema import prompt_request_queue as pr_request_queue +from ... schema import prompt_response_queue as pr_response_queue +from ... schema import text_completion_request_queue as tc_request_queue +from ... schema import text_completion_response_queue as tc_response_queue +from ... schema import graph_rag_request_queue as gr_request_queue +from ... schema import graph_rag_response_queue as gr_response_queue +from ... clients.prompt_client import PromptClient +from ... clients.llm_client import LlmClient +from ... clients.graph_rag_client import GraphRagClient + +from . tools import KnowledgeQueryImpl, TextCompletionImpl +from . agent_manager import AgentManager + +from . types import Final, Action, Tool, Argument + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = agent_request_queue +default_output_queue = agent_response_queue +default_subscriber = module +default_max_iterations = 15 + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + additional = params.get("context", None) + + self.max_iterations = int(params.get("max_iterations", default_max_iterations)) + + tools = {} + + # Parsing the prompt information to the prompt configuration + # structure + tool_type_arg = params.get("tool_type", []) + if tool_type_arg: + for t in tool_type_arg: + toks = t.split("=", 1) + if len(toks) < 2: + raise RuntimeError( + f"Tool-type string not well-formed: {t}" + ) + ttoks = toks[1].split(":", 1) + if len(ttoks) < 1: + raise RuntimeError( + f"Tool-type string not well-formed: {t}" + ) + + if ttoks[0] == "knowledge-query": + impl = KnowledgeQueryImpl(self) + elif ttoks[0] == "text-completion": + impl = TextCompletionImpl(self) + else: + raise RuntimeError( + f"Tool-kind {ttoks[0]} not known" + ) + + if len(ttoks) == 1: + + tools[toks[0]] = Tool( + name = toks[0], + description = "", + implementation = impl, + config = { "input": "query" }, + arguments = {}, + ) + else: + tools[toks[0]] = Tool( + name = toks[0], + description = "", + implementation = impl, + config = { "input": ttoks[1] }, + arguments = {}, + ) + + # parsing the prompt information to the prompt configuration + # structure + tool_desc_arg = params.get("tool_description", []) + if tool_desc_arg: + for t in tool_desc_arg: + toks = t.split("=", 1) + if len(toks) < 2: + raise runtimeerror( + f"tool-type string not well-formed: {t}" + ) + if toks[0] not in tools: + raise runtimeerror(f"description, tool {toks[0]} not known") + tools[toks[0]].description = toks[1] + + # Parsing the prompt information to the prompt configuration + # structure + tool_arg_arg = params.get("tool_argument", []) + if tool_arg_arg: + for t in tool_arg_arg: + toks = t.split("=", 1) + if len(toks) < 2: + raise RuntimeError( + f"Tool-type string not well-formed: {t}" + ) + ttoks = toks[1].split(":", 2) + if len(ttoks) != 3: + raise RuntimeError( + f"Tool argument string not well-formed: {t}" + ) + if toks[0] not in tools: + raise RuntimeError(f"Description, tool {toks[0]} not known") + tools[toks[0]].arguments[ttoks[0]] = Argument( + name = ttoks[0], + type = ttoks[1], + description = ttoks[2] + ) + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + prompt_request_queue = params.get( + "prompt_request_queue", pr_request_queue + ) + prompt_response_queue = params.get( + "prompt_response_queue", pr_response_queue + ) + text_completion_request_queue = params.get( + "text_completion_request_queue", tc_request_queue + ) + text_completion_response_queue = params.get( + "text_completion_response_queue", tc_response_queue + ) + graph_rag_request_queue = params.get( + "graph_rag_request_queue", gr_request_queue + ) + graph_rag_response_queue = params.get( + "graph_rag_response_queue", gr_response_queue + ) + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": AgentRequest, + "output_schema": AgentResponse, + "prompt_request_queue": prompt_request_queue, + "prompt_response_queue": prompt_response_queue, + "text_completion_request_queue": tc_request_queue, + "text_completion_response_queue": tc_response_queue, + "graph_rag_request_queue": gr_request_queue, + "graph_rag_response_queue": gr_response_queue, + } + ) + + self.prompt = PromptClient( + subscriber=subscriber, + input_queue=prompt_request_queue, + output_queue=prompt_response_queue, + pulsar_host = self.pulsar_host + ) + + self.llm = LlmClient( + subscriber=subscriber, + input_queue=text_completion_request_queue, + output_queue=text_completion_response_queue, + pulsar_host = self.pulsar_host + ) + + self.graph_rag = GraphRagClient( + subscriber=subscriber, + input_queue=graph_rag_request_queue, + output_queue=graph_rag_response_queue, + pulsar_host = self.pulsar_host + ) + + # Need to be able to feed requests to myself + self.recursive_input = self.client.create_producer( + topic=input_queue, + schema=JsonSchema(AgentRequest), + ) + + self.agent = AgentManager( + context=self, + tools=tools, + additional_context=additional + ) + + def parse_json(self, text): + json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL) + + if json_match: + json_str = json_match.group(1).strip() + else: + # If no delimiters, assume the entire output is JSON + json_str = text.strip() + + return json.loads(json_str) + + def handle(self, msg): + + try: + + v = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + if v.history: + history = [ + Action( + thought=h.thought, + name=h.action, + arguments=h.arguments, + observation=h.observation + ) + for h in v.history + ] + else: + history = [] + + print(f"Question: {v.question}", flush=True) + + if len(history) >= self.max_iterations: + raise RuntimeError("Too many agent iterations") + + print(f"History: {history}", flush=True) + + def think(x): + + print(f"Think: {x}", flush=True) + + r = AgentResponse( + answer=None, + error=None, + thought=x, + observation=None, + ) + + self.producer.send(r, properties={"id": id}) + + def observe(x): + + print(f"Observe: {x}", flush=True) + + r = AgentResponse( + answer=None, + error=None, + thought=None, + observation=x, + ) + + self.producer.send(r, properties={"id": id}) + + act = self.agent.react(v.question, history, think, observe) + + print(f"Action: {act}", flush=True) + + print("Send response...", flush=True) + + if type(act) == Final: + + r = AgentResponse( + answer=act.final, + error=None, + thought=None, + ) + + self.producer.send(r, properties={"id": id}) + + print("Done.", flush=True) + + return + + history.append(act) + + r = AgentRequest( + question=v.question, + plan=v.plan, + state=v.state, + history=[ + AgentStep( + thought=h.thought, + action=h.name, + arguments=h.arguments, + observation=h.observation + ) + for h in history + ] + ) + + self.recursive_input.send(r, properties={"id": id}) + + print("Done.", flush=True) + + return + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = AgentResponse( + error=Error( + type = "agent-error", + message = str(e), + ), + response=None, + ) + + self.producer.send(r, properties={"id": id}) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '--prompt-request-queue', + default=pr_request_queue, + help=f'Prompt request queue (default: {pr_request_queue})', + ) + + parser.add_argument( + '--prompt-response-queue', + default=pr_response_queue, + help=f'Prompt response queue (default: {pr_response_queue})', + ) + + parser.add_argument( + '--text-completion-request-queue', + default=tc_request_queue, + help=f'Text completion request queue (default: {tc_request_queue})', + ) + + parser.add_argument( + '--text-completion-response-queue', + default=tc_response_queue, + help=f'Text completion response queue (default: {tc_response_queue})', + ) + + parser.add_argument( + '--graph-rag-request-queue', + default=gr_request_queue, + help=f'Graph RAG request queue (default: {gr_request_queue})', + ) + + parser.add_argument( + '--graph-rag-response-queue', + default=gr_response_queue, + help=f'Graph RAG response queue (default: {gr_response_queue})', + ) + + parser.add_argument( + '--tool-type', nargs='*', + help=f'''Specifies the type of an agent tool. Takes the form +=. is the name of the tool. is one of +knowledge-query, text-completion. Additional parameters are specified +for different tools which are tool-specific. e.g. knowledge-query: +which specifies the name of the arg whose content is fed into the knowledge +query as a question. text-completion: specifies the name of the arg +whose content is fed into the text-completion service as a prompt''' + ) + + parser.add_argument( + '--tool-description', nargs='*', + help=f'''Specifies the textual description of a tool. Takes +the form =. The description is important, it teaches the +LLM how to use the tool. It should describe what it does and how to +use the arguments. This is specified in natural language.''' + ) + + parser.add_argument( + '--tool-argument', nargs='*', + help=f'''Specifies argument usage for a tool. Takes +the form =::. The description is important, +it is read by the LLM and used to determine how to use the argument. + can be specified multiple times to give a tool multiple arguments. + is one of string, number. is a natural language +description.''' + ) + + parser.add_argument( + '--context', + help=f'Optional, specifies additional context text for the LLM.' + ) + + parser.add_argument( + '--max-iterations', + default=default_max_iterations, + help=f'Maximum number of react iterations (default: {default_max_iterations})', + ) + +def run(): + + Processor.start(module, __doc__) + diff --git a/trustgraph-flow/trustgraph/agent/react/tools.py b/trustgraph-flow/trustgraph/agent/react/tools.py new file mode 100644 index 00000000..d9bc846f --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/react/tools.py @@ -0,0 +1,19 @@ + +# This tool implementation knows how to put a question to the graph RAG +# service +class KnowledgeQueryImpl: + def __init__(self, context): + self.context = context + def invoke(self, **arguments): + return self.context.graph_rag.request(arguments.get("query")) + +# This tool implementation knows how to do text completion. This uses +# the prompt service, rather than talking to TextCompletion directly. +class TextCompletionImpl: + def __init__(self, context): + self.context = context + def invoke(self, **arguments): + return self.context.prompt.request( + "question", { "question": arguments.get("computation") } + ) + diff --git a/trustgraph-flow/trustgraph/agent/react/types.py b/trustgraph-flow/trustgraph/agent/react/types.py new file mode 100644 index 00000000..7180db3e --- /dev/null +++ b/trustgraph-flow/trustgraph/agent/react/types.py @@ -0,0 +1,30 @@ + +import dataclasses +from typing import Any, Dict + +@dataclasses.dataclass +class Argument: + name : str + type : str + description : str + +@dataclasses.dataclass +class Tool: + name : str + description : str + arguments : list[Argument] + implementation : Any + config : Dict[str, str] + +@dataclasses.dataclass +class Action: + thought : str + name : str + arguments : dict + observation : str + +@dataclasses.dataclass +class Final: + thought : str + final : str + diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index fe1a0cee..694ced70 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -7,7 +7,7 @@ as text as separate output objects. from langchain_text_splitters import RecursiveCharacterTextSplitter from prometheus_client import Histogram -from ... schema import TextDocument, Chunk, Source +from ... schema import TextDocument, Chunk, Metadata from ... schema import text_ingest_queue, chunk_ingest_queue from ... log_level import LogLevel from ... base import ConsumerProducer @@ -55,7 +55,7 @@ class Processor(ConsumerProducer): def handle(self, msg): v = msg.value() - print(f"Chunking {v.source.id}...", flush=True) + print(f"Chunking {v.metadata.id}...", flush=True) texts = self.text_splitter.create_documents( [v.text.decode("utf-8")] @@ -63,14 +63,8 @@ class Processor(ConsumerProducer): for ix, chunk in enumerate(texts): - id = v.source.id + "-c" + str(ix) - r = Chunk( - source=Source( - source=v.source.source, - id=id, - title=v.source.title - ), + metadata=v.metadata, chunk=chunk.page_content.encode("utf-8"), ) diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index c152b0fd..dccd9c89 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -7,7 +7,7 @@ as text as separate output objects. from langchain_text_splitters import TokenTextSplitter from prometheus_client import Histogram -from ... schema import TextDocument, Chunk, Source +from ... schema import TextDocument, Chunk, Metadata from ... schema import text_ingest_queue, chunk_ingest_queue from ... log_level import LogLevel from ... base import ConsumerProducer @@ -54,7 +54,7 @@ class Processor(ConsumerProducer): def handle(self, msg): v = msg.value() - print(f"Chunking {v.source.id}...", flush=True) + print(f"Chunking {v.metadata.id}...", flush=True) texts = self.text_splitter.create_documents( [v.text.decode("utf-8")] @@ -62,14 +62,8 @@ class Processor(ConsumerProducer): for ix, chunk in enumerate(texts): - id = v.source.id + "-c" + str(ix) - r = Chunk( - source=Source( - source=v.source.source, - id=id, - title=v.source.title - ), + metadata=v.metadata, chunk=chunk.page_content.encode("utf-8"), ) diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index fffcaee0..38ac9257 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -8,7 +8,7 @@ import tempfile import base64 from langchain_community.document_loaders import PyPDFLoader -from ... schema import Document, TextDocument, Source +from ... schema import Document, TextDocument, Metadata from ... schema import document_ingest_queue, text_ingest_queue from ... log_level import LogLevel from ... base import ConsumerProducer @@ -45,7 +45,7 @@ class Processor(ConsumerProducer): v = msg.value() - print(f"Decoding {v.source.id}...", flush=True) + print(f"Decoding {v.metadata.id}...", flush=True) with tempfile.NamedTemporaryFile(delete_on_close=False) as fp: @@ -59,13 +59,8 @@ class Processor(ConsumerProducer): for ix, page in enumerate(pages): - id = v.source.id + "-p" + str(ix) r = TextDocument( - source=Source( - source=v.source.source, - title=v.source.title, - id=id, - ), + metadata=v.metadata, text=page.page_content.encode("utf-8"), ) diff --git a/trustgraph-flow/trustgraph/direct/cassandra.py b/trustgraph-flow/trustgraph/direct/cassandra.py index 1754e090..2b577df1 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra.py +++ b/trustgraph-flow/trustgraph/direct/cassandra.py @@ -4,10 +4,16 @@ from cassandra.auth import PlainTextAuthProvider class TrustGraph: - def __init__(self, hosts=None): + def __init__( + self, hosts=None, + keyspace="trustgraph", table="default", + ): if hosts is None: hosts = ["localhost"] + + self.keyspace = keyspace + self.table = table self.cluster = Cluster(hosts) self.session = self.cluster.connect() @@ -16,26 +22,26 @@ class TrustGraph: def clear(self): - self.session.execute(""" - drop keyspace if exists trustgraph; + self.session.execute(f""" + drop keyspace if exists {self.keyspace}; """); self.init() def init(self): - self.session.execute(""" - create keyspace if not exists trustgraph - with replication = { + self.session.execute(f""" + create keyspace if not exists {self.keyspace} + with replication = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 - }; + }}; """); - self.session.set_keyspace('trustgraph') + self.session.set_keyspace(self.keyspace) - self.session.execute(""" - create table if not exists triples ( + self.session.execute(f""" + create table if not exists {self.table} ( s text, p text, o text, @@ -43,66 +49,66 @@ class TrustGraph: ); """); - self.session.execute(""" - create index if not exists triples_p - ON triples (p); + self.session.execute(f""" + create index if not exists {self.table}_p + ON {self.table} (p); """); - self.session.execute(""" - create index if not exists triples_o - ON triples (o); + self.session.execute(f""" + create index if not exists {self.table}_o + ON {self.table} (o); """); def insert(self, s, p, o): self.session.execute( - "insert into triples (s, p, o) values (%s, %s, %s)", + f"insert into {self.table} (s, p, o) values (%s, %s, %s)", (s, p, o) ) def get_all(self, limit=50): return self.session.execute( - f"select s, p, o from triples limit {limit}" + f"select s, p, o from {self.table} limit {limit}" ) def get_s(self, s, limit=10): return self.session.execute( - f"select p, o from triples where s = %s limit {limit}", + f"select p, o from {self.table} where s = %s limit {limit}", (s,) ) def get_p(self, p, limit=10): return self.session.execute( - f"select s, o from triples where p = %s limit {limit}", + f"select s, o from {self.table} where p = %s limit {limit}", (p,) ) def get_o(self, o, limit=10): return self.session.execute( - f"select s, p from triples where o = %s limit {limit}", + f"select s, p from {self.table} where o = %s limit {limit}", (o,) ) def get_sp(self, s, p, limit=10): return self.session.execute( - f"select o from triples where s = %s and p = %s limit {limit}", + f"select o from {self.table} where s = %s and p = %s limit {limit}", (s, p) ) def get_po(self, p, o, limit=10): return self.session.execute( - f"select s from triples where p = %s and o = %s allow filtering limit {limit}", + f"select s from {self.table} where p = %s and o = %s allow filtering limit {limit}", (p, o) ) def get_os(self, o, s, limit=10): return self.session.execute( - f"select p from triples where o = %s and s = %s limit {limit}", + f"select p from {self.table} where o = %s and s = %s limit {limit}", (o, s) ) def get_spo(self, s, p, o, limit=10): return self.session.execute( - f"""select s as x from triples where s = %s and p = %s and o = %s limit {limit}""", + f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""", (s, p, o) ) diff --git a/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py b/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py index 3770fee2..4cf2af05 100755 --- a/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py +++ b/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py @@ -50,15 +50,15 @@ class Processor(ConsumerProducer): subscriber=module + "-emb", ) - def emit(self, source, chunk, vectors): + def emit(self, metadata, chunk, vectors): - r = ChunkEmbeddings(source=source, chunk=chunk, vectors=vectors) + r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors) self.producer.send(r) def handle(self, msg): v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) + print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") @@ -67,7 +67,7 @@ class Processor(ConsumerProducer): vectors = self.embeddings.request(chunk) self.emit( - source=v.source, + metadata=v.metadata, chunk=chunk.encode("utf-8"), vectors=vectors ) diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 06ba8e68..eed34574 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -7,16 +7,18 @@ get entity definitions which are output as graph edges. import urllib.parse import json -from .... schema import ChunkEmbeddings, Triple, Source, Value +from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value from .... schema import chunk_embeddings_ingest_queue, triples_store_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue from .... log_level import LogLevel from .... clients.prompt_client import PromptClient -from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION +from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF from .... base import ConsumerProducer DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True) +RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) +SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) module = ".".join(__name__.split(".")[1:-1]) @@ -44,7 +46,7 @@ class Processor(ConsumerProducer): "output_queue": output_queue, "subscriber": subscriber, "input_schema": ChunkEmbeddings, - "output_schema": Triple, + "output_schema": Triples, "prompt_request_queue": pr_request_queue, "prompt_response_queue": pr_response_queue, } @@ -69,15 +71,18 @@ class Processor(ConsumerProducer): return self.prompt.request_definitions(chunk) - def emit_edge(self, s, p, o): + def emit_edges(self, metadata, triples): - t = Triple(s=s, p=p, o=o) + t = Triples( + metadata=metadata, + triples=triples, + ) self.producer.send(t) def handle(self, msg): v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) + print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") @@ -85,6 +90,13 @@ class Processor(ConsumerProducer): defs = self.get_definitions(chunk) + triples = [] + + # FIXME: Putting metadata into triples store is duplicated in + # relationships extractor too + for t in v.metadata.metadata: + triples.append(t) + for defn in defs: s = defn.name @@ -101,7 +113,31 @@ class Processor(ConsumerProducer): s_value = Value(value=str(s_uri), is_uri=True) o_value = Value(value=str(o), is_uri=False) - self.emit_edge(s_value, DEFINITION_VALUE, o_value) + triples.append(Triple( + s=s_value, + p=RDF_LABEL_VALUE, + o=Value(value=s, is_uri=False), + )) + + triples.append(Triple( + s=s_value, p=DEFINITION_VALUE, o=o_value + )) + + triples.append(Triple( + s=s_value, + p=SUBJECT_OF_VALUE, + o=Value(value=v.metadata.id, is_uri=True) + )) + + self.emit_edges( + Metadata( + id=v.metadata.id, + metadata=[], + user=v.metadata.user, + collection=v.metadata.collection, + ), + triples + ) except Exception as e: print("Exception: ", e, flush=True) diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index 49ef9072..d2dea062 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -9,17 +9,19 @@ import urllib.parse import os from pulsar.schema import JsonSchema -from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value +from .... schema import ChunkEmbeddings, Triple, Triples, GraphEmbeddings +from .... schema import Metadata, Value from .... schema import chunk_embeddings_ingest_queue, triples_store_queue from .... schema import graph_embeddings_store_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue from .... log_level import LogLevel from .... clients.prompt_client import PromptClient -from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES +from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF from .... base import ConsumerProducer RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) +SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) module = ".".join(__name__.split(".")[1:-1]) @@ -49,7 +51,7 @@ class Processor(ConsumerProducer): "output_queue": output_queue, "subscriber": subscriber, "input_schema": ChunkEmbeddings, - "output_schema": Triple, + "output_schema": Triples, "prompt_request_queue": pr_request_queue, "prompt_response_queue": pr_response_queue, } @@ -68,7 +70,7 @@ class Processor(ConsumerProducer): "prompt_response_queue": pr_response_queue, "subscriber": subscriber, "input_schema": ChunkEmbeddings.__name__, - "output_schema": Triple.__name__, + "output_schema": Triples.__name__, "vector_schema": GraphEmbeddings.__name__, }) @@ -91,20 +93,23 @@ class Processor(ConsumerProducer): return self.prompt.request_relationships(chunk) - def emit_edge(self, s, p, o): + def emit_edges(self, metadata, triples): - t = Triple(s=s, p=p, o=o) + t = Triples( + metadata=metadata, + triples=triples, + ) self.producer.send(t) - def emit_vec(self, ent, vec): + def emit_vec(self, metadata, ent, vec): - r = GraphEmbeddings(entity=ent, vectors=vec) + r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec) self.vec_prod.send(r) def handle(self, msg): v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) + print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") @@ -112,6 +117,13 @@ class Processor(ConsumerProducer): rels = self.get_relationships(chunk) + triples = [] + + # FIXME: Putting metadata into triples store is duplicated in + # relationships extractor too + for t in v.metadata.metadata: + triples.append(t) + for rel in rels: s = rel.s @@ -138,38 +150,64 @@ class Processor(ConsumerProducer): else: o_value = Value(value=str(o), is_uri=False) - self.emit_edge( - s_value, - p_value, - o_value - ) + triples.append(Triple( + s=s_value, + p=p_value, + o=o_value + )) # Label for s - self.emit_edge( - s_value, - RDF_LABEL_VALUE, - Value(value=str(s), is_uri=False) - ) + triples.append(Triple( + s=s_value, + p=RDF_LABEL_VALUE, + o=Value(value=str(s), is_uri=False) + )) # Label for p - self.emit_edge( - p_value, - RDF_LABEL_VALUE, - Value(value=str(p), is_uri=False) - ) + triples.append(Triple( + s=p_value, + p=RDF_LABEL_VALUE, + o=Value(value=str(p), is_uri=False) + )) if rel.o_entity: # Label for o - self.emit_edge( - o_value, - RDF_LABEL_VALUE, - Value(value=str(o), is_uri=False) - ) + triples.append(Triple( + s=o_value, + p=RDF_LABEL_VALUE, + o=Value(value=str(o), is_uri=False) + )) + + # 'Subject of' for s + triples.append(Triple( + s=s_value, + p=SUBJECT_OF_VALUE, + o=Value(value=v.metadata.id, is_uri=True) + )) - self.emit_vec(s_value, v.vectors) - self.emit_vec(p_value, v.vectors) if rel.o_entity: - self.emit_vec(o_value, v.vectors) + # 'Subject of' for o + triples.append(Triple( + s=o_value, + p=SUBJECT_OF_VALUE, + o=Value(value=v.metadata.id, is_uri=True) + )) + + self.emit_vec(v.metadata, s_value, v.vectors) + self.emit_vec(v.metadata, p_value, v.vectors) + + if rel.o_entity: + self.emit_vec(v.metadata, o_value, v.vectors) + + self.emit_edges( + Metadata( + id=v.metadata.id, + metadata=[], + user=v.metadata.user, + collection=v.metadata.collection, + ), + triples + ) except Exception as e: print("Exception: ", e, flush=True) diff --git a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py index e2ebe5b0..8dfc3e6e 100755 --- a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py @@ -7,7 +7,7 @@ get entity definitions which are output as graph edges. import urllib.parse import json -from .... schema import ChunkEmbeddings, Triple, Source, Value +from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value from .... schema import chunk_embeddings_ingest_queue, triples_store_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue @@ -44,7 +44,7 @@ class Processor(ConsumerProducer): "output_queue": output_queue, "subscriber": subscriber, "input_schema": ChunkEmbeddings, - "output_schema": Triple, + "output_schema": Triples, "prompt_request_queue": pr_request_queue, "prompt_response_queue": pr_response_queue, } @@ -69,15 +69,18 @@ class Processor(ConsumerProducer): return self.prompt.request_topics(chunk) - def emit_edge(self, s, p, o): + def emit_edge(self, metadata, s, p, o): - t = Triple(s=s, p=p, o=o) + t = Triples( + metadata=metadata, + triples=[Triple(s=s, p=p, o=o)], + ) self.producer.send(t) def handle(self, msg): v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) + print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") @@ -101,7 +104,7 @@ class Processor(ConsumerProducer): s_value = Value(value=str(s_uri), is_uri=True) o_value = Value(value=str(o), is_uri=False) - self.emit_edge(s_value, DEFINITION_VALUE, o_value) + self.emit_edge(v. metadata, s_value, DEFINITION_VALUE, o_value) except Exception as e: print("Exception: ", e, flush=True) diff --git a/trustgraph-flow/trustgraph/extract/object/row/extract.py b/trustgraph-flow/trustgraph/extract/object/row/extract.py index aa53f2a6..185a59c3 100755 --- a/trustgraph-flow/trustgraph/extract/object/row/extract.py +++ b/trustgraph-flow/trustgraph/extract/object/row/extract.py @@ -8,7 +8,7 @@ import urllib.parse import os from pulsar.schema import JsonSchema -from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source +from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Metadata from .... schema import RowSchema, Field from .... schema import chunk_embeddings_ingest_queue, rows_store_queue from .... schema import object_embeddings_store_queue @@ -124,24 +124,24 @@ class Processor(ConsumerProducer): def get_rows(self, chunk): return self.prompt.request_rows(self.schema, chunk) - def emit_rows(self, source, rows): + def emit_rows(self, metadata, rows): t = Rows( - source=source, row_schema=self.row_schema, rows=rows + metadata=metadata, row_schema=self.row_schema, rows=rows ) self.producer.send(t) - def emit_vec(self, source, name, vec, key_name, key): + def emit_vec(self, metadata, name, vec, key_name, key): r = ObjectEmbeddings( - source=source, vectors=vec, name=name, key_name=key_name, id=key + metadata=metadata, vectors=vec, name=name, key_name=key_name, id=key ) self.vec_prod.send(r) def handle(self, msg): v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) + print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") @@ -150,13 +150,13 @@ class Processor(ConsumerProducer): rows = self.get_rows(chunk) self.emit_rows( - source=v.source, + metadata=v.metadata, rows=rows ) for row in rows: self.emit_vec( - source=v.source, vec=v.vectors, + metadata=v.metadata, vec=v.vectors, name=self.schema.name, key_name=self.primary.name, key=row[self.primary.name] ) diff --git a/trustgraph-flow/trustgraph/graph_rag.py b/trustgraph-flow/trustgraph/graph_rag.py index 15acb609..f69ebeb7 100644 --- a/trustgraph-flow/trustgraph/graph_rag.py +++ b/trustgraph-flow/trustgraph/graph_rag.py @@ -18,6 +18,144 @@ from . schema import triples_response_queue LABEL="http://www.w3.org/2000/01/rdf-schema#label" DEFINITION="http://www.w3.org/2004/02/skos/core#definition" +class Query: + + def __init__(self, rag, user, collection, verbose): + self.rag = rag + self.user = user + self.collection = collection + self.verbose = verbose + + def get_vector(self, query): + + if self.verbose: + print("Compute embeddings...", flush=True) + + qembeds = self.rag.embeddings.request(query) + + if self.verbose: + print("Done.", flush=True) + + return qembeds + + def get_entities(self, query): + + vectors = self.get_vector(query) + + if self.verbose: + print("Get entities...", flush=True) + + entities = self.rag.ge_client.request( + user=self.user, collection=self.collection, + vectors=vectors, limit=self.rag.entity_limit, + ) + + entities = [ + e.value + for e in entities + ] + + if self.verbose: + print("Entities:", flush=True) + for ent in entities: + print(" ", ent, flush=True) + + return entities + + def maybe_label(self, e): + + if e in self.rag.label_cache: + return self.rag.label_cache[e] + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=e, p=LABEL, o=None, limit=1, + ) + + if len(res) == 0: + self.rag.label_cache[e] = e + return e + + self.rag.label_cache[e] = res[0].o.value + return self.rag.label_cache[e] + + def get_subgraph(self, query): + + entities = self.get_entities(query) + + subgraph = set() + + if self.verbose: + print("Get subgraph...", flush=True) + + for e in entities: + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=e, p=None, o=None, + limit=self.rag.query_limit + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=None, p=e, o=None, + limit=self.rag.query_limit + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=None, p=None, o=e, + limit=self.rag.query_limit, + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + + subgraph = list(subgraph) + + subgraph = subgraph[0:self.rag.max_subgraph_size] + + if self.verbose: + print("Subgraph:", flush=True) + for edge in subgraph: + print(" ", str(edge), flush=True) + + if self.verbose: + print("Done.", flush=True) + + return subgraph + + def get_labelgraph(self, query): + + subgraph = self.get_subgraph(query) + + sg2 = [] + + for edge in subgraph: + + if edge[1] == LABEL: + continue + + s = self.maybe_label(edge[0]) + p = self.maybe_label(edge[1]) + o = self.maybe_label(edge[2]) + + sg2.append((s, p, o)) + + return sg2 + class GraphRag: def __init__( @@ -94,7 +232,7 @@ class GraphRag: self.label_cache = {} - self.lang = PromptClient( + self.prompt = PromptClient( pulsar_host=pulsar_host, input_queue=pr_request_queue, output_queue=pr_response_queue, @@ -104,144 +242,23 @@ class GraphRag: if self.verbose: print("Initialised", flush=True) - def get_vector(self, query): - - if self.verbose: - print("Compute embeddings...", flush=True) - - qembeds = self.embeddings.request(query) - - if self.verbose: - print("Done.", flush=True) - - return qembeds - - def get_entities(self, query): - - vectors = self.get_vector(query) - - if self.verbose: - print("Get entities...", flush=True) - - entities = self.ge_client.request( - vectors, self.entity_limit - ) - - entities = [ - e.value - for e in entities - ] - - if self.verbose: - print("Entities:", flush=True) - for ent in entities: - print(" ", ent, flush=True) - - return entities - - def maybe_label(self, e): - - if e in self.label_cache: - return self.label_cache[e] - - res = self.triples_client.request( - e, LABEL, None, limit=1 - ) - - if len(res) == 0: - self.label_cache[e] = e - return e - - self.label_cache[e] = res[0].o.value - return self.label_cache[e] - - def get_subgraph(self, query): - - entities = self.get_entities(query) - - subgraph = set() - - if self.verbose: - print("Get subgraph...", flush=True) - - for e in entities: - - res = self.triples_client.request( - e, None, None, - limit=self.query_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - res = self.triples_client.request( - None, e, None, - limit=self.query_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - res = self.triples_client.request( - None, None, e, - limit=self.query_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - subgraph = list(subgraph) - - subgraph = subgraph[0:self.max_subgraph_size] - - if self.verbose: - print("Subgraph:", flush=True) - for edge in subgraph: - print(" ", str(edge), flush=True) - - if self.verbose: - print("Done.", flush=True) - - return subgraph - - def get_labelgraph(self, query): - - subgraph = self.get_subgraph(query) - - sg2 = [] - - for edge in subgraph: - - if edge[1] == LABEL: - continue - - s = self.maybe_label(edge[0]) - p = self.maybe_label(edge[1]) - o = self.maybe_label(edge[2]) - - sg2.append((s, p, o)) - - return sg2 - - def query(self, query): + def query(self, query, user="trustgraph", collection="default"): if self.verbose: print("Construct prompt...", flush=True) - kg = self.get_labelgraph(query) + q = Query( + rag=self, user=user, collection=collection, verbose=self.verbose + ) + + kg = q.get_labelgraph(query) if self.verbose: print("Invoke LLM...", flush=True) print(kg) print(query) - resp = self.lang.request_kg_prompt(query, kg) + resp = self.prompt.request_kg_prompt(query, kg) if self.verbose: print("Done", flush=True) diff --git a/trustgraph-flow/trustgraph/model/prompt/generic/service.py b/trustgraph-flow/trustgraph/model/prompt/generic/service.py index 16986980..96c9be57 100755 --- a/trustgraph-flow/trustgraph/model/prompt/generic/service.py +++ b/trustgraph-flow/trustgraph/model/prompt/generic/service.py @@ -2,6 +2,15 @@ Language service abstracts prompt engineering from LLM. """ +# +# FIXME: This module is broken, it doesn't conform to the prompt API change +# made in 0.14, nor the prompt template support. +# +# It could be made to conform by using prompt-template as a starting +# point, and hard-coding all the information. +# + + import json import re @@ -469,5 +478,7 @@ class Processor(ConsumerProducer): def run(): + raise RuntimeError("NOT IMPLEMENTED") + Processor.start(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/prompt/template/README.md b/trustgraph-flow/trustgraph/model/prompt/template/README.md new file mode 100644 index 00000000..0b98e906 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/prompt/template/README.md @@ -0,0 +1,25 @@ + +prompt-template \ + -p pulsar://localhost:6650 \ + --system-prompt 'You are a {{attitude}}, you are called {{name}}' \ + --global-term \ + 'name=Craig' \ + 'attitude=LOUD, SHOUTY ANNOYING BOT' \ + --prompt \ + 'question={{question}}' \ + 'french-question={{question}}' \ + "analyze=Find the name and age in this text, and output a JSON structure containing just the name and age fields: {{description}}. Don't add markup, just output the raw JSON object." \ + "graph-query=Study the following knowledge graph, and then answer the question.\\n\nGraph:\\n{% for edge in knowledge %}({{edge.0}})-[{{edge.1}}]->({{edge.2}})\\n{%endfor%}\\nQuestion:\\n{{question}}" \ + "extract-definition=Analyse the text provided, and then return a list of terms and definitions. The output should be a JSON array, each item in the array is an object with fields 'term' and 'definition'.Don't add markup, just output the raw JSON object. Here is the text:\\n{{text}}" \ + --prompt-response-type \ + 'question=text' \ + 'analyze=json' \ + 'graph-query=text' \ + 'extract-definition=json' \ + --prompt-term \ + 'question=name:Bonny' \ + 'french-question=attitude:French-speaking bot' \ + --prompt-schema \ + 'analyze={ "type" : "object", "properties" : { "age": { "type" : "number" }, "name": { "type" : "string" } } }' \ + 'extract-definition={ "type": "array", "items": { "type": "object", "properties": { "term": { "type": "string" }, "definition": { "type": "string" } }, "required": [ "term", "definition" ] } }' + diff --git a/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py b/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py new file mode 100644 index 00000000..d8a032ca --- /dev/null +++ b/trustgraph-flow/trustgraph/model/prompt/template/prompt_manager.py @@ -0,0 +1,95 @@ + +import ibis +import json +from jsonschema import validate +import re + +from trustgraph.clients.llm_client import LlmClient + +class PromptConfiguration: + def __init__(self, system_template, global_terms={}, prompts={}): + self.system_template = system_template + self.global_terms = global_terms + self.prompts = prompts + +class Prompt: + def __init__(self, template, response_type = "text", terms=None, schema=None): + self.template = template + self.response_type = response_type + self.terms = terms + self.schema = schema + +class PromptManager: + + def __init__(self, llm, config): + self.llm = llm + self.config = config + self.terms = config.global_terms + + self.prompts = config.prompts + + try: + self.system_template = ibis.Template(config.system_template) + except: + raise RuntimeError("Error in system template") + + self.templates = {} + for k, v in self.prompts.items(): + try: + self.templates[k] = ibis.Template(v.template) + except: + raise RuntimeError(f"Error in template: {k}") + + if v.terms is None: + v.terms = {} + + def parse_json(self, text): + json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL) + + if json_match: + json_str = json_match.group(1).strip() + else: + # If no delimiters, assume the entire output is JSON + json_str = text.strip() + + return json.loads(json_str) + + def invoke(self, id, input): + + if id not in self.prompts: + raise RuntimeError("ID invalid") + + terms = self.terms | self.prompts[id].terms | input + + resp_type = self.prompts[id].response_type + + prompt = { + "system": self.system_template.render(terms), + "prompt": self.templates[id].render(terms) + } + + resp = self.llm.request(**prompt) + + print(resp, flush=True) + + if resp_type == "text": + return resp + + if resp_type != "json": + raise RuntimeError(f"Response type {resp_type} not known") + + try: + obj = self.parse_json(resp) + except: + raise RuntimeError("JSON parse fail") + + print(obj, flush=True) + if self.prompts[id].schema: + try: + print(self.prompts[id].schema) + validate(instance=obj, schema=self.prompts[id].schema) + except Exception as e: + raise RuntimeError(f"Schema validation fail: {e}") + + return obj + diff --git a/trustgraph-flow/trustgraph/model/prompt/template/prompts.py b/trustgraph-flow/trustgraph/model/prompt/template/prompts.py deleted file mode 100644 index e3148157..00000000 --- a/trustgraph-flow/trustgraph/model/prompt/template/prompts.py +++ /dev/null @@ -1,47 +0,0 @@ - -def to_relationships(template, text): - return template.format(text=text) - -def to_definitions(template, text): - return template.format(text=text) - -def to_topics(template, text): - return template.format(text=text) - -def to_rows(template, schema, text): - - field_schema = [ - f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}" - for f in schema.fields - ] - - field_schema = "\n".join(field_schema) - - return template.format(schema=schema, text=text) - - schema = f"""Object name: {schema.name} -Description: {schema.description} - -Fields: -{schema}""" - - prompt = f"""""" - - return prompt - -def get_cypher(kg): - sg2 = [] - for f in kg: - sg2.append(f"({f.s})-[{f.p}]->({f.o})") - kg = "\n".join(sg2) - kg = kg.replace("\\", "-") - return kg - -def to_kg_query(template, query, kg): - cypher = get_cypher(kg) - return template.format(query=query, graph=cypher) - -def to_document_query(template, query, docs): - docs = "\n\n".join(docs) - return template.format(query=query, documents=docs) - diff --git a/trustgraph-flow/trustgraph/model/prompt/template/service.py b/trustgraph-flow/trustgraph/model/prompt/template/service.py index 14b65d5a..2e5416f4 100755 --- a/trustgraph-flow/trustgraph/model/prompt/template/service.py +++ b/trustgraph-flow/trustgraph/model/prompt/template/service.py @@ -16,8 +16,7 @@ from .... schema import prompt_request_queue, prompt_response_queue from .... base import ConsumerProducer from .... clients.llm_client import LlmClient -from . prompts import to_definitions, to_relationships, to_rows -from . prompts import to_kg_query, to_document_query, to_topics +from . prompt_manager import PromptConfiguration, Prompt, PromptManager module = ".".join(__name__.split(".")[1:-1]) @@ -29,6 +28,82 @@ class Processor(ConsumerProducer): def __init__(self, **params): + prompt_base = {} + + # Parsing the prompt information to the prompt configuration + # structure + prompt_arg = params.get("prompt", []) + if prompt_arg: + for p in prompt_arg: + toks = p.split("=", 1) + if len(toks) < 2: + raise RuntimeError(f"Prompt string not well-formed: {p}") + prompt_base[toks[0]] = { + "template": toks[1] + } + + prompt_response_type_arg = params.get("prompt_response_type", []) + if prompt_response_type_arg: + for p in prompt_response_type_arg: + toks = p.split("=", 1) + if len(toks) < 2: + raise RuntimeError(f"Response type not well-formed: {p}") + if toks[0] not in prompt_base: + raise RuntimeError(f"Response-type, {toks[0]} not known") + prompt_base[toks[0]]["response_type"] = toks[1] + + prompt_schema_arg = params.get("prompt_schema", []) + if prompt_schema_arg: + for p in prompt_schema_arg: + toks = p.split("=", 1) + if len(toks) < 2: + raise RuntimeError(f"Schema arg not well-formed: {p}") + if toks[0] not in prompt_base: + raise RuntimeError(f"Schema, {toks[0]} not known") + try: + prompt_base[toks[0]]["schema"] = json.loads(toks[1]) + except: + raise RuntimeError(f"Failed to parse JSON schema: {p}") + + prompt_term_arg = params.get("prompt_term", []) + if prompt_term_arg: + for p in prompt_term_arg: + toks = p.split("=", 1) + if len(toks) < 2: + raise RuntimeError(f"Term arg not well-formed: {p}") + if toks[0] not in prompt_base: + raise RuntimeError(f"Term, {toks[0]} not known") + kvtoks = toks[1].split(":", 1) + if len(kvtoks) < 2: + raise RuntimeError(f"Term not well-formed: {toks[1]}") + k, v = kvtoks + if "terms" not in prompt_base[toks[0]]: + prompt_base[toks[0]]["terms"] = {} + prompt_base[toks[0]]["terms"][k] = v + + global_terms = {} + + global_term_arg = params.get("global_term", []) + if global_term_arg: + for t in global_term_arg: + toks = t.split("=", 1) + if len(toks) < 2: + raise RuntimeError(f"Global term arg not well-formed: {t}") + global_terms[toks[0]] = toks[1] + + print(global_terms) + + prompts = { + k: Prompt(**v) + for k, v in prompt_base.items() + } + + prompt_configuration = PromptConfiguration( + system_template = params.get("system_prompt", ""), + global_terms = global_terms, + prompts = prompts + ) + input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) @@ -64,23 +139,21 @@ class Processor(ConsumerProducer): pulsar_host = self.pulsar_host ) - self.definition_template = definition_template - self.topic_template = topic_template - self.relationship_template = relationship_template - self.rows_template = rows_template - self.knowledge_query_template = knowledge_query_template - self.document_query_template = document_query_template + # System prompt hack + class Llm: + def __init__(self, llm): + self.llm = llm + def request(self, system, prompt): + print(system) + print(prompt, flush=True) + return self.llm.request(system, prompt) - def parse_json(self, text): - json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL) - - if json_match: - json_str = json_match.group(1).strip() - else: - # If no delimiters, assume the entire output is JSON - json_str = text.strip() + self.llm = Llm(self.llm) - return json.loads(json_str) + self.manager = PromptManager( + llm = self.llm, + config = prompt_configuration, + ) def handle(self, msg): @@ -90,88 +163,52 @@ class Processor(ConsumerProducer): id = msg.properties()["id"] - kind = v.kind - - print(f"Handling kind {kind}...", flush=True) - - if kind == "extract-definitions": - - self.handle_extract_definitions(id, v) - return - - elif kind == "extract-topics": - - self.handle_extract_topics(id, v) - return - - elif kind == "extract-relationships": - - self.handle_extract_relationships(id, v) - return - - elif kind == "extract-rows": - - self.handle_extract_rows(id, v) - return - - elif kind == "kg-prompt": - - self.handle_kg_prompt(id, v) - return - - elif kind == "document-prompt": - - self.handle_document_prompt(id, v) - return - - else: - - print("Invalid kind.", flush=True) - return - - def handle_extract_definitions(self, id, v): + kind = v.id try: - prompt = to_definitions(self.definition_template, v.chunk) + print(v.terms) - ans = self.llm.request(prompt) + input = { + k: json.loads(v) + for k, v in v.terms.items() + } + + print(f"Handling kind {kind}...", flush=True) + print(input, flush=True) - # Silently ignore JSON parse error - try: - defs = self.parse_json(ans) - except: - print("JSON parse error, ignored", flush=True) - defs = [] + resp = self.manager.invoke(kind, input) - output = [] + if isinstance(resp, str): - for defn in defs: + print("Send text response...", flush=True) + print(resp, flush=True) - try: - e = defn["entity"] - d = defn["definition"] + r = PromptResponse( + text=resp, + object=None, + error=None, + ) - if e == "": continue - if e is None: continue - if d == "": continue - if d is None: continue + self.producer.send(r, properties={"id": id}) - output.append( - Definition( - name=e, definition=d - ) - ) + return - except: - print("definition fields missing, ignored", flush=True) + else: - print("Send response...", flush=True) - r = PromptResponse(definitions=output, error=None) - self.producer.send(r, properties={"id": id}) + print("Send object response...", flush=True) + print(json.dumps(resp, indent=4), flush=True) - print("Done.", flush=True) - + r = PromptResponse( + text=None, + object=json.dumps(resp), + error=None, + ) + + self.producer.send(r, properties={"id": id}) + + return + except Exception as e: print(f"Exception: {e}") @@ -188,122 +225,6 @@ class Processor(ConsumerProducer): self.producer.send(r, properties={"id": id}) - def handle_extract_topics(self, id, v): - - try: - - prompt = to_topics(self.topic_template, v.chunk) - - ans = self.llm.request(prompt) - - # Silently ignore JSON parse error - try: - defs = self.parse_json(ans) - except: - print("JSON parse error, ignored", flush=True) - defs = [] - - output = [] - - for defn in defs: - - try: - e = defn["topic"] - d = defn["definition"] - - if e == "": continue - if e is None: continue - if d == "": continue - if d is None: continue - - output.append( - Topic( - name=e, definition=d - ) - ) - - except: - print("definition fields missing, ignored", flush=True) - - print("Send response...", flush=True) - r = PromptResponse(topics=output, error=None) - self.producer.send(r, properties={"id": id}) - - print("Done.", flush=True) - - except Exception as e: - - print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = PromptResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - response=None, - ) - - self.producer.send(r, properties={"id": id}) - - - def handle_extract_relationships(self, id, v): - - try: - - prompt = to_relationships(self.relationship_template, v.chunk) - - ans = self.llm.request(prompt) - - # Silently ignore JSON parse error - try: - defs = self.parse_json(ans) - except: - print("JSON parse error, ignored", flush=True) - defs = [] - - output = [] - - for defn in defs: - - try: - - s = defn["subject"] - p = defn["predicate"] - o = defn["object"] - o_entity = defn["object-entity"] - - if s == "": continue - if s is None: continue - - if p == "": continue - if p is None: continue - - if o == "": continue - if o is None: continue - - if o_entity == "" or o_entity is None: - o_entity = False - - output.append( - Relationship( - s = s, - p = p, - o = o, - o_entity = o_entity, - ) - ) - - except Exception as e: - print("relationship fields missing, ignored", flush=True) - - print("Send response...", flush=True) - r = PromptResponse(relationships=output, error=None) - self.producer.send(r, properties={"id": id}) - - print("Done.", flush=True) - except Exception as e: print(f"Exception: {e}") @@ -320,147 +241,6 @@ class Processor(ConsumerProducer): self.producer.send(r, properties={"id": id}) - def handle_extract_rows(self, id, v): - - try: - - fields = v.row_schema.fields - - prompt = to_rows(self.rows_template, v.row_schema, v.chunk) - - print(prompt) - - ans = self.llm.request(prompt) - - print(ans) - - # Silently ignore JSON parse error - try: - objs = self.parse_json(ans) - except: - print("JSON parse error, ignored", flush=True) - objs = [] - - output = [] - - for obj in objs: - - try: - - row = {} - - for f in fields: - - if f.name not in obj: - print(f"Object ignored, missing field {f.name}") - row = {} - break - - row[f.name] = obj[f.name] - - if row == {}: - continue - - output.append(row) - - except Exception as e: - print("row fields missing, ignored", flush=True) - - for row in output: - print(row) - - print("Send response...", flush=True) - r = PromptResponse(rows=output, error=None) - self.producer.send(r, properties={"id": id}) - - print("Done.", flush=True) - - except Exception as e: - - print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = PromptResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - response=None, - ) - - self.producer.send(r, properties={"id": id}) - - def handle_kg_prompt(self, id, v): - - try: - - prompt = to_kg_query(self.knowledge_query_template, v.query, v.kg) - - print(prompt) - - ans = self.llm.request(prompt) - - print(ans) - - print("Send response...", flush=True) - r = PromptResponse(answer=ans, error=None) - self.producer.send(r, properties={"id": id}) - - print("Done.", flush=True) - - except Exception as e: - - print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = PromptResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - response=None, - ) - - self.producer.send(r, properties={"id": id}) - - def handle_document_prompt(self, id, v): - - try: - - prompt = to_document_query( - self.document_query_template, v.query, v.documents - ) - - print(prompt) - - ans = self.llm.request(prompt) - - print(ans) - - print("Send response...", flush=True) - r = PromptResponse(answer=ans, error=None) - self.producer.send(r, properties={"id": id}) - - print("Done.", flush=True) - - except Exception as e: - - print(f"Exception: {e}") - - print("Send error response...", flush=True) - - r = PromptResponse( - error=Error( - type = "llm-error", - message = str(e), - ), - response=None, - ) - - self.producer.send(r, properties={"id": id}) - @staticmethod def add_args(parser): @@ -482,39 +262,33 @@ class Processor(ConsumerProducer): ) parser.add_argument( - '--definition-template', - required=True, - help=f'Definition extraction template', + '--prompt', nargs='*', + help=f'Prompt template form id=template', ) parser.add_argument( - '--topic-template', - required=True, - help=f'Topic extraction template', + '--prompt-response-type', nargs='*', + help=f'Prompt response type, form id=json|text', ) parser.add_argument( - '--rows-template', - required=True, - help=f'Rows extraction template', + '--prompt-term', nargs='*', + help=f'Prompt response type, form id=key:value', ) parser.add_argument( - '--relationship-template', - required=True, - help=f'Relationship extraction template', + '--prompt-schema', nargs='*', + help=f'Prompt response schema, form id=schema', ) parser.add_argument( - '--knowledge-query-template', - required=True, - help=f'Knowledge query template', + '--system-prompt', + help=f'System prompt template', ) parser.add_argument( - '--document-query-template', - required=True, - help=f'Document query template', + '--global-term', nargs='+', + help=f'Global term, form key:value' ) def run(): diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py index ff97f644..4db7dbf1 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py @@ -7,6 +7,7 @@ serverless endpoint service. Input is prompt, output is response. import requests import json from prometheus_client import Histogram +import os from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue @@ -23,6 +24,8 @@ default_subscriber = module default_temperature = 0.0 default_max_output = 4192 default_model = "AzureAI" +default_endpoint = os.getenv("AZURE_ENDPOINT") +default_token = os.getenv("AZURE_TOKEN") class Processor(ConsumerProducer): @@ -31,12 +34,18 @@ class Processor(ConsumerProducer): input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) - endpoint = params.get("endpoint") - token = params.get("token") + endpoint = params.get("endpoint", default_endpoint) + token = params.get("token", default_token) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) model = default_model + if endpoint is None: + raise RuntimeError("Azure endpoint not specified") + + if token is None: + raise RuntimeError("Azure token not specified") + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -127,7 +136,7 @@ class Processor(ConsumerProducer): try: prompt = self.build_prompt( - "You are a helpful chatbot", + v.system, v.prompt ) @@ -199,11 +208,13 @@ class Processor(ConsumerProducer): parser.add_argument( '-e', '--endpoint', + default=default_endpoint, help=f'LLM model endpoint' ) parser.add_argument( '-k', '--token', + default=default_token, help=f'LLM model token' ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/__init__.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/__init__.py new file mode 100644 index 00000000..f2017af8 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/__init__.py @@ -0,0 +1,3 @@ + +from . llm import * + diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/__main__.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/__main__.py new file mode 100755 index 00000000..91342d2d --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . llm import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py new file mode 100755 index 00000000..a3edb859 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py @@ -0,0 +1,219 @@ + +""" +Simple LLM service, performs text prompt completion using the Azure +OpenAI endpoit service. Input is prompt, output is response. +""" + +import requests +import json +from prometheus_client import Histogram +from openai import AzureOpenAI +import os + +from .... schema import TextCompletionRequest, TextCompletionResponse, Error +from .... schema import text_completion_request_queue +from .... schema import text_completion_response_queue +from .... log_level import LogLevel +from .... base import ConsumerProducer +from .... exceptions import TooManyRequests + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = text_completion_request_queue +default_output_queue = text_completion_response_queue +default_subscriber = module +default_temperature = 0.0 +default_max_output = 4192 +default_api = "2024-02-15-preview" +default_endpoint = os.getenv("AZURE_ENDPOINT") +default_token = os.getenv("AZURE_TOKEN") + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + endpoint = params.get("endpoint", default_endpoint) + token = params.get("token", default_token) + temperature = params.get("temperature", default_temperature) + max_output = params.get("max_output", default_max_output) + model = params.get("model") + api = params.get("api_version", default_api) + + if endpoint is None: + raise RuntimeError("Azure endpoint not specified") + + if token is None: + raise RuntimeError("Azure token not specified") + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": TextCompletionRequest, + "output_schema": TextCompletionResponse, + "temperature": temperature, + "max_output": max_output, + "model": model, + "api": api, + } + ) + + if not hasattr(__class__, "text_completion_metric"): + __class__.text_completion_metric = Histogram( + 'text_completion_duration', + 'Text completion duration (seconds)', + buckets=[ + 0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, + 30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0, + 120.0 + ] + ) + + self.temperature = temperature + self.max_output = max_output + self.model = model + + self.openai = AzureOpenAI( + api_key=token, + api_version=api, + azure_endpoint = endpoint, + ) + + def handle(self, msg): + + v = msg.value() + + # Sender-produced ID + + id = msg.properties()["id"] + + print(f"Handling prompt {id}...", flush=True) + + prompt = v.system + "\n\n" + v.prompt + + + try: + + with __class__.text_completion_metric.time(): + resp = self.openai.chat.completions.create( + model=self.model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + } + ] + } + ], + temperature=self.temperature, + max_tokens=self.max_output, + top_p=1, + ) + + inputtokens = resp.usage.prompt_tokens + outputtokens = resp.usage.completion_tokens + print(resp.choices[0].message.content, flush=True) + print(f"Input Tokens: {inputtokens}", flush=True) + print(f"Output Tokens: {outputtokens}", flush=True) + print("Send response...", flush=True) + + r = TextCompletionResponse(response=resp.choices[0].message.content, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model) + self.producer.send(r, properties={"id": id}) + + except TooManyRequests: + + print("Send rate limit response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "rate-limit", + message = str(e), + ), + response=None, + in_token=None, + out_token=None, + model=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + in_token=None, + out_token=None, + model=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + print("Done.", flush=True) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '-e', '--endpoint', + help=f'LLM model endpoint' + ) + + parser.add_argument( + '-a', '--api-version', + default=default_api, + help=f'API version (default: {default_api})' + ) + + parser.add_argument( + '-k', '--token', + help=f'LLM model token' + ) + + parser.add_argument( + '-m', '--model', + help=f'LLM model' + ) + + parser.add_argument( + '-t', '--temperature', + type=float, + default=default_temperature, + help=f'LLM temperature parameter (default: {default_temperature})' + ) + + parser.add_argument( + '-x', '--max-output', + type=int, + default=default_max_output, + help=f'LLM max output tokens (default: {default_max_output})' + ) + +def run(): + + Processor.start(module, __doc__) diff --git a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py index ad949b02..01ce837d 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py @@ -6,6 +6,7 @@ Input is prompt, output is response. import anthropic from prometheus_client import Histogram +import os from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue @@ -22,6 +23,7 @@ default_subscriber = module default_model = 'claude-3-5-sonnet-20240620' default_temperature = 0.0 default_max_output = 8192 +default_api_key = os.getenv("CLAUDE_KEY") class Processor(ConsumerProducer): @@ -31,10 +33,13 @@ class Processor(ConsumerProducer): output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) model = params.get("model", default_model) - api_key = params.get("api_key") + api_key = params.get("api_key", default_api_key) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) + if api_key is None: + raise RuntimeError("Claude API key not specified") + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -90,7 +95,7 @@ class Processor(ConsumerProducer): model=self.model, max_tokens=self.max_output, temperature=self.temperature, - system = "You are a helpful chatbot.", + system = v.system, messages=[ { "role": "user", @@ -175,6 +180,7 @@ class Processor(ConsumerProducer): parser.add_argument( '-k', '--api-key', + default=default_api_key, help=f'Claude API key' ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py index 4c64e8b6..d03e1554 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py @@ -6,6 +6,7 @@ Input is prompt, output is response. import cohere from prometheus_client import Histogram +import os from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue @@ -21,6 +22,7 @@ default_output_queue = text_completion_response_queue default_subscriber = module default_model = 'c4ai-aya-23-8b' default_temperature = 0.0 +default_api_key = os.getenv("COHERE_KEY") class Processor(ConsumerProducer): @@ -30,9 +32,12 @@ class Processor(ConsumerProducer): output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) model = params.get("model", default_model) - api_key = params.get("api_key") + api_key = params.get("api_key", default_api_key) temperature = params.get("temperature", default_temperature) + if api_key is None: + raise RuntimeError("Cohere API key not specified") + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -74,6 +79,7 @@ class Processor(ConsumerProducer): print(f"Handling prompt {id}...", flush=True) + system = v.system prompt = v.prompt try: @@ -83,7 +89,7 @@ class Processor(ConsumerProducer): output = self.cohere.chat( model=self.model, message=prompt, - preamble = "You are a helpful AI-assistant.", + preamble = system, temperature=self.temperature, chat_history=[], prompt_truncation='auto', @@ -162,6 +168,7 @@ class Processor(ConsumerProducer): parser.add_argument( '-k', '--api-key', + default=default_api_key, help=f'Cohere API key' ) diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/__init__.py b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/__init__.py new file mode 100644 index 00000000..f2017af8 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/__init__.py @@ -0,0 +1,3 @@ + +from . llm import * + diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/__main__.py b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/__main__.py new file mode 100755 index 00000000..91342d2d --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . llm import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py new file mode 100644 index 00000000..a249998d --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py @@ -0,0 +1,228 @@ + +""" +Simple LLM service, performs text prompt completion using GoogleAIStudio. +Input is prompt, output is response. +""" + +import google.generativeai as genai +from google.generativeai.types import HarmCategory, HarmBlockThreshold +from google.api_core.exceptions import ResourceExhausted +from prometheus_client import Histogram +import os + +from .... schema import TextCompletionRequest, TextCompletionResponse, Error +from .... schema import text_completion_request_queue +from .... schema import text_completion_response_queue +from .... log_level import LogLevel +from .... base import ConsumerProducer +from .... exceptions import TooManyRequests + +module = ".".join(__name__.split(".")[1:-1]) + +default_input_queue = text_completion_request_queue +default_output_queue = text_completion_response_queue +default_subscriber = module +default_model = 'gemini-1.5-flash-002' +default_temperature = 0.0 +default_max_output = 8192 +default_api_key = os.getenv("GOOGLE_AI_STUDIO_KEY") + +class Processor(ConsumerProducer): + + def __init__(self, **params): + + input_queue = params.get("input_queue", default_input_queue) + output_queue = params.get("output_queue", default_output_queue) + subscriber = params.get("subscriber", default_subscriber) + model = params.get("model", default_model) + api_key = params.get("api_key", default_api_key) + temperature = params.get("temperature", default_temperature) + max_output = params.get("max_output", default_max_output) + + if api_key is None: + raise RuntimeError("Google AI Studio API key not specified") + + super(Processor, self).__init__( + **params | { + "input_queue": input_queue, + "output_queue": output_queue, + "subscriber": subscriber, + "input_schema": TextCompletionRequest, + "output_schema": TextCompletionResponse, + "model": model, + "temperature": temperature, + "max_output": max_output, + } + ) + + if not hasattr(__class__, "text_completion_metric"): + __class__.text_completion_metric = Histogram( + 'text_completion_duration', + 'Text completion duration (seconds)', + buckets=[ + 0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, + 30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0, + 120.0 + ] + ) + + genai.configure(api_key=api_key) + self.model = model + self.temperature = temperature + self.max_output = max_output + + self.generation_config = { + "temperature": temperature, + "top_p": 1, + "top_k": 40, + "max_output_tokens": max_output, + "response_mime_type": "text/plain", + } + + block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH + + self.safety_settings={ + HarmCategory.HARM_CATEGORY_HATE_SPEECH: block_level, + HarmCategory.HARM_CATEGORY_HARASSMENT: block_level, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_level, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: block_level, + # There is a documentation conflict on whether or not CIVIC_INTEGRITY is a valid category + # HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: block_level, + } + + self.llm = genai.GenerativeModel( + model_name=model, + generation_config=self.generation_config, + safety_settings=self.safety_settings, + system_instruction="You are a helpful AI assistant.", + ) + + print("Initialised", flush=True) + + def handle(self, msg): + + v = msg.value() + + # Sender-produced ID + + id = msg.properties()["id"] + + print(f"Handling prompt {id}...", flush=True) + + # FIXME: There's a system prompt above. Maybe if system changes, + # then reset self.llm? It shouldn't do, because system prompt + # is set system wide? + + # Or... could keep different LLM structures for different system + # prompts? + + prompt = v.system + "\n\n" + v.prompt + + try: + + # FIXME: Rate limits? + + with __class__.text_completion_metric.time(): + + chat_session = self.llm.start_chat( + history=[ + ] + ) + response = chat_session.send_message(prompt) + + resp = response.text + inputtokens = int(response.usage_metadata.prompt_token_count) + outputtokens = int(response.usage_metadata.candidates_token_count) + print(resp, flush=True) + print(f"Input Tokens: {inputtokens}", flush=True) + print(f"Output Tokens: {outputtokens}", flush=True) + + print("Send response...", flush=True) + r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model) + self.send(r, properties={"id": id}) + + print("Done.", flush=True) + + # FIXME: Wrong exception, don't know what this LLM throws + # for a rate limit + except ResourceExhausted as e: + + print("Send rate limit response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "rate-limit", + message = str(e), + ), + response=None, + in_token=None, + out_token=None, + model=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + except Exception as e: + + print(f"Exception: {e}") + + print("Send error response...", flush=True) + + r = TextCompletionResponse( + error=Error( + type = "llm-error", + message = str(e), + ), + response=None, + in_token=None, + out_token=None, + model=None, + ) + + self.producer.send(r, properties={"id": id}) + + self.consumer.acknowledge(msg) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '-m', '--model', + default=default_model, + help=f'LLM model (default: {default_model})' + ) + + parser.add_argument( + '-k', '--api-key', + default=default_api_key, + help=f'GoogleAIStudio API key' + ) + + parser.add_argument( + '-t', '--temperature', + type=float, + default=default_temperature, + help=f'LLM temperature parameter (default: {default_temperature})' + ) + + parser.add_argument( + '-x', '--max-output', + type=int, + default=default_max_output, + help=f'LLM max output tokens (default: {default_max_output})' + ) + +def run(): + + Processor.start(module, __doc__) + + diff --git a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py index 86427167..274948a8 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py @@ -20,7 +20,7 @@ default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue default_subscriber = module default_model = 'LLaMA_CPP' -default_llamafile = 'http://localhost:8080/v1' +default_llamafile = os.getenv("LLAMAFILE_URL", "http://localhost:8080/v1") default_temperature = 0.0 default_max_output = 4096 @@ -84,7 +84,7 @@ class Processor(ConsumerProducer): print(f"Handling prompt {id}...", flush=True) - prompt = v.prompt + prompt = v.system + "\n\n" + v.prompt try: diff --git a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py index b506b3cd..00d44f6d 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py @@ -6,6 +6,7 @@ Input is prompt, output is response. from ollama import Client from prometheus_client import Histogram, Info +import os from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue @@ -19,8 +20,8 @@ module = ".".join(__name__.split(".")[1:-1]) default_input_queue = text_completion_request_queue default_output_queue = text_completion_response_queue default_subscriber = module -default_model = 'gemma2' -default_ollama = 'http://localhost:11434' +default_model = 'gemma2:9b' +default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434') class Processor(ConsumerProducer): @@ -79,7 +80,7 @@ class Processor(ConsumerProducer): print(f"Handling prompt {id}...", flush=True) - prompt = v.prompt + prompt = v.system + "\n\n" + v.prompt try: @@ -152,7 +153,7 @@ class Processor(ConsumerProducer): parser.add_argument( '-m', '--model', default="gemma2", - help=f'LLM model (default: gemma2)' + help=f'LLM model (default: {default_model})' ) parser.add_argument( diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index 5d259e7e..c874943e 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -6,6 +6,7 @@ Input is prompt, output is response. from openai import OpenAI from prometheus_client import Histogram +import os from .... schema import TextCompletionRequest, TextCompletionResponse, Error from .... schema import text_completion_request_queue @@ -22,6 +23,7 @@ default_subscriber = module default_model = 'gpt-3.5-turbo' default_temperature = 0.0 default_max_output = 4096 +default_api_key = os.getenv("OPENAI_TOKEN") class Processor(ConsumerProducer): @@ -31,10 +33,13 @@ class Processor(ConsumerProducer): output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) model = params.get("model", default_model) - api_key = params.get("api_key") + api_key = params.get("api_key", default_api_key) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) + if api_key is None: + raise RuntimeError("OpenAI API key not specified") + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -78,7 +83,7 @@ class Processor(ConsumerProducer): print(f"Handling prompt {id}...", flush=True) - prompt = v.prompt + prompt = v.system + "\n\n" + v.prompt try: @@ -185,6 +190,7 @@ class Processor(ConsumerProducer): parser.add_argument( '-k', '--api-key', + default=default_api_key, help=f'OpenAI API key' ) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 5cc41437..7bb5133a 100755 --- a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -60,7 +60,10 @@ class Processor(ConsumerProducer): for vec in v.vectors: dim = len(vec) - collection = "doc_" + str(dim) + collection = ( + "d_" + v.user + "_" + v.collection + "_" + + str(dim) + ) search_result = self.client.query_points( collection_name=collection, diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index e61a00a7..8991f9ea 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -66,7 +66,10 @@ class Processor(ConsumerProducer): for vec in v.vectors: dim = len(vec) - collection = "triples_" + str(dim) + collection = ( + "t_" + v.user + "_" + v.collection + "_" + + str(dim) + ) search_result = self.client.query_points( collection_name=collection, diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index 5e1e0e3e..4245784d 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -38,7 +38,8 @@ class Processor(ConsumerProducer): } ) - self.tg = TrustGraph([graph_host]) + self.graph_host = [graph_host] + self.table = None def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): @@ -52,6 +53,15 @@ class Processor(ConsumerProducer): v = msg.value() + table = (v.user, v.collection) + + if table != self.table: + self.tg = TrustGraph( + hosts=self.graph_host, + keyspace=v.user, table=v.collection, + ) + self.table = table + # Sender-produced ID id = msg.properties()["id"] diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 87f8d24f..1219050e 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -108,7 +108,9 @@ class Processor(ConsumerProducer): print(f"Handling input {id}...", flush=True) - response = self.rag.query(v.query) + response = self.rag.query( + query=v.query, user=v.user, collection=v.collection + ) print("Send response...", flush=True) r = GraphRagResponse(response = response, error=None) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index f22ae74a..813c4f29 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -37,7 +37,6 @@ class Processor(Consumer): ) self.last_collection = None - self.last_dim = None self.client = QdrantClient(url=store_uri) @@ -52,9 +51,12 @@ class Processor(Consumer): for vec in v.vectors: dim = len(vec) - collection = "doc_" + str(dim) + collection = ( + "d_" + v.metadata.user + "_" + v.metadata.collection + "_" + + str(dim) + ) - if dim != self.last_dim: + if collection != self.last_collection: if not self.client.collection_exists(collection): @@ -70,7 +72,6 @@ class Processor(Consumer): raise e self.last_collection = collection - self.last_dim = dim self.client.upsert( collection_name=collection, diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 95448750..e27c2516 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -37,7 +37,6 @@ class Processor(Consumer): ) self.last_collection = None - self.last_dim = None self.client = QdrantClient(url=store_uri) @@ -50,9 +49,12 @@ class Processor(Consumer): for vec in v.vectors: dim = len(vec) - collection = "triples_" + str(dim) + collection = ( + "t_" + v.metadata.user + "_" + v.metadata.collection + "_" + + str(dim) + ) - if dim != self.last_dim: + if collection != self.last_collection: if not self.client.collection_exists(collection): @@ -68,7 +70,6 @@ class Processor(Consumer): raise e self.last_collection = collection - self.last_dim = dim self.client.upsert( collection_name=collection, diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 84c002ff..e7078e08 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -10,7 +10,7 @@ import argparse import time from .... direct.cassandra import TrustGraph -from .... schema import Triple +from .... schema import Triples from .... schema import triples_store_queue from .... log_level import LogLevel from .... base import Consumer @@ -33,22 +33,42 @@ class Processor(Consumer): **params | { "input_queue": input_queue, "subscriber": subscriber, - "input_schema": Triple, + "input_schema": Triples, "graph_host": graph_host, } ) - self.tg = TrustGraph([graph_host]) + self.graph_host = [graph_host] + self.table = None def handle(self, msg): v = msg.value() - self.tg.insert( - v.s.value, - v.p.value, - v.o.value - ) + table = (v.metadata.user, v.metadata.collection) + + if self.table is None or self.table != table: + + self.tg = None + + try: + self.tg = TrustGraph( + hosts=self.graph_host, + keyspace=v.metadata.user, table=v.metadata.collection, + ) + except Exception as e: + print("Exception", e, flush=True) + time.sleep(1) + raise e + + self.table = table + + for t in v.triples: + self.tg.insert( + t.s.value, + t.p.value, + t.o.value + ) @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py index d20ecdd2..82302e96 100755 --- a/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/neo4j/write.py @@ -11,7 +11,7 @@ import time from neo4j import GraphDatabase -from .... schema import Triple +from .... schema import Triples from .... schema import triples_store_queue from .... log_level import LogLevel from .... base import Consumer @@ -39,7 +39,7 @@ class Processor(Consumer): **params | { "input_queue": input_queue, "subscriber": subscriber, - "input_schema": Triple, + "input_schema": Triples, "graph_host": graph_host, } ) @@ -116,14 +116,16 @@ class Processor(Consumer): v = msg.value() - self.create_node(v.s.value) + for t in v.triples: - if v.o.is_uri: - self.create_node(v.o.value) - self.relate_node(v.s.value, v.p.value, v.o.value) - else: - self.create_literal(v.o.value) - self.relate_literal(v.s.value, v.p.value, v.o.value) + self.create_node(t.s.value) + + if t.o.is_uri: + self.create_node(t.o.value) + self.relate_node(t.s.value, t.p.value, t.o.value) + else: + self.create_literal(t.o.value) + self.relate_literal(t.s.value, t.p.value, t.o.value) @staticmethod def add_args(parser): diff --git a/trustgraph-parquet/scripts/load-graph-embeddings b/trustgraph-parquet/scripts/load-graph-embeddings index 2dc3c06f..0e6ecf93 100755 --- a/trustgraph-parquet/scripts/load-graph-embeddings +++ b/trustgraph-parquet/scripts/load-graph-embeddings @@ -6,7 +6,7 @@ Loads Graph embeddings into TrustGraph processing. import pulsar from pulsar.schema import JsonSchema -from trustgraph.schema import GraphEmbeddings, Value +from trustgraph.schema import GraphEmbeddings, Value, Metadata from trustgraph.schema import graph_embeddings_store_queue import argparse import os @@ -24,6 +24,8 @@ class Loader: output_queue, log_level, file, + user, + collection, ): self.client = pulsar.Client( @@ -38,6 +40,8 @@ class Loader: ) self.file = file + self.user = user + self.collection = collection def run(self): @@ -66,11 +70,16 @@ class Loader: n = ent.as_py() r = GraphEmbeddings( + metadata=Metadata( + metadata=[], + user=self.user, + collection=self.collection, + ), vectors=b, entity=Value( value=n, is_uri=n.startswith("https:") - ) + ), ) self.producer.send(r) @@ -90,6 +99,8 @@ def main(): default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') default_output_queue = graph_embeddings_store_queue + default_user = 'trustgraph' + default_collection = 'default' parser.add_argument( '-p', '--pulsar-host', @@ -103,6 +114,18 @@ def main(): help=f'Output queue (default: {default_output_queue})' ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + parser.add_argument( '-l', '--log-level', type=LogLevel, @@ -127,6 +150,8 @@ def main(): output_queue=args.output_queue, log_level=args.log_level, file=args.file, + user=args.user, + collection=args.collection, ) p.run() diff --git a/trustgraph-parquet/scripts/load-triples b/trustgraph-parquet/scripts/load-triples index e03c065b..e6bb0ff7 100755 --- a/trustgraph-parquet/scripts/load-triples +++ b/trustgraph-parquet/scripts/load-triples @@ -6,7 +6,7 @@ Loads Graph embeddings into TrustGraph processing. import pulsar from pulsar.schema import JsonSchema -from trustgraph.schema import Triple, Value +from trustgraph.schema import Triples, Triple, Value, Metadata from trustgraph.schema import triples_store_queue import argparse import os @@ -24,6 +24,8 @@ class Loader: output_queue, log_level, file, + user, + collection, ): self.client = pulsar.Client( @@ -33,11 +35,13 @@ class Loader: self.producer = self.client.create_producer( topic=output_queue, - schema=JsonSchema(Triple), + schema=JsonSchema(Triples), chunking_enabled=True, ) self.file = file + self.user = user + self.collection = collection def run(self): @@ -66,10 +70,26 @@ class Loader: for s, p, o in zip(sc, pc, oc): - r = Triple( - s=Value(value=s.as_py(), is_uri=True), - p=Value(value=p.as_py(), is_uri=True), - o=Value(value=o.as_py(), is_uri=o.as_py().startswith("https:")) + r = Triples( + metadata=Metadata( + metadata=[], + user=self.user, + collection=self.collection, + ), + triples=[ + Triple( + s=Value( + value=s.as_py(), is_uri=True + ), + p=Value( + value=p.as_py(), is_uri=True + ), + o=Value( + value=o.as_py(), + is_uri=o.as_py().startswith("https:") + ) + ) + ] ) self.producer.send(r) @@ -89,6 +109,8 @@ def main(): default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') default_output_queue = triples_store_queue + default_user = 'trustgraph' + default_collection = 'default' parser.add_argument( '-p', '--pulsar-host', @@ -102,6 +124,18 @@ def main(): help=f'Output queue (default: {default_output_queue})' ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + parser.add_argument( '-l', '--log-level', type=LogLevel, @@ -126,6 +160,8 @@ def main(): output_queue=args.output_queue, log_level=args.log_level, file=args.file, + user=args.user, + collection=args.collection, ) p.run() diff --git a/trustgraph-parquet/setup.py b/trustgraph-parquet/setup.py index ee0c7ce4..668cde1c 100644 --- a/trustgraph-parquet/setup.py +++ b/trustgraph-parquet/setup.py @@ -34,15 +34,18 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base<0.12", + "trustgraph-base>=0.15,<0.16", "pulsar-client", "prometheus-client", "pyarrow", + "pandas", ], scripts=[ "scripts/concat-parquet", "scripts/dump-parquet", "scripts/ge-dump-parquet", "scripts/triples-dump-parquet", + "scripts/load-graph-embeddings", + "scripts/load-triples", ] ) diff --git a/trustgraph-parquet/trustgraph/dump/triples/parquet/processor.py b/trustgraph-parquet/trustgraph/dump/triples/parquet/processor.py index 553a62b8..dc15d8a9 100755 --- a/trustgraph-parquet/trustgraph/dump/triples/parquet/processor.py +++ b/trustgraph-parquet/trustgraph/dump/triples/parquet/processor.py @@ -9,7 +9,7 @@ import os import argparse import time -from .... schema import Triple +from .... schema import Triples from .... schema import triples_store_queue from .... base import Consumer @@ -38,7 +38,7 @@ class Processor(Consumer): **params | { "input_queue": input_queue, "subscriber": subscriber, - "input_schema": Triple, + "input_schema": Triples, } ) @@ -51,7 +51,9 @@ class Processor(Consumer): def handle(self, msg): v = msg.value() - self.writer.write(v.s.value, v.p.value, v.o.value) + + for t in v.triples: + self.writer.write(t.s.value, t.p.value, t.o.value) @staticmethod def add_args(parser): diff --git a/trustgraph-vertexai/setup.py b/trustgraph-vertexai/setup.py index 3ef59da9..0cdc3a97 100644 --- a/trustgraph-vertexai/setup.py +++ b/trustgraph-vertexai/setup.py @@ -34,7 +34,7 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base<0.12", + "trustgraph-base>=0.15,<0.16", "pulsar-client", "google-cloud-aiplatform", "prometheus-client", diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py index c57b9fb0..cb817836 100755 --- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py +++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py @@ -7,6 +7,7 @@ Google Cloud. Input is prompt, output is response. import vertexai import time from prometheus_client import Histogram +import os from google.oauth2 import service_account import google @@ -38,6 +39,7 @@ default_model = 'gemini-1.0-pro-001' default_region = 'us-central1' default_temperature = 0.0 default_max_output = 8192 +default_private_key = "private.json" class Processor(ConsumerProducer): @@ -48,10 +50,13 @@ class Processor(ConsumerProducer): subscriber = params.get("subscriber", default_subscriber) region = params.get("region", default_region) model = params.get("model", default_model) - private_key = params.get("private_key") + private_key = params.get("private_key", default_private_key) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) + if private_key is None: + raise RuntimeError("Private key file not specified") + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -138,7 +143,7 @@ class Processor(ConsumerProducer): print(f"Handling prompt {id}...", flush=True) - prompt = v.prompt + prompt = v.system + "\n\n" + v.prompt with __class__.text_completion_metric.time(): diff --git a/trustgraph/README.md b/trustgraph/README.md new file mode 100644 index 00000000..7a2ce130 --- /dev/null +++ b/trustgraph/README.md @@ -0,0 +1 @@ +See https://trustgraph.ai/ diff --git a/trustgraph/setup.py b/trustgraph/setup.py index f840943e..8e50aed5 100644 --- a/trustgraph/setup.py +++ b/trustgraph/setup.py @@ -34,13 +34,13 @@ setuptools.setup( python_requires='>=3.8', download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", install_requires=[ - "trustgraph-base<0.12", - "trustgraph-bedrock<0.12", - "trustgraph-cli<0.12", - "trustgraph-embeddings-hf<0.12", - "trustgraph-flow<0.12", - "trustgraph-parquet<0.12", - "trustgraph-vertexai<0.12", + "trustgraph-base>=0.15,<0.16", + "trustgraph-bedrock>=0.15,<0.16", + "trustgraph-cli>=0.15,<0.16", + "trustgraph-embeddings-hf>=0.15,<0.16", + "trustgraph-flow>=0.15,<0.16", + "trustgraph-parquet>=0.15,<0.16", + "trustgraph-vertexai>=0.15,<0.16", ], scripts=[ ]