Merge branch 'release/v0.17'

This commit is contained in:
Cyber MacGeddon 2024-12-10 22:37:08 +00:00
commit 6c8a8d7932
132 changed files with 4848 additions and 929 deletions

View file

@ -5,7 +5,7 @@ on:
workflow_dispatch: workflow_dispatch:
push: push:
tags: tags:
- v0.15.* - v0.17.*
permissions: permissions:
contents: read contents: read
@ -48,20 +48,6 @@ jobs:
- name: Publish release distributions to PyPI - name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1 uses: pypa/gh-action-pypi-publish@release/v1
- name: Create deploy bundle
run: templates/generate-all deploy.zip ${{ steps.version.outputs.VERSION }}
- uses: ncipollo/release-action@v1
with:
artifacts: deploy.zip
generateReleaseNotes: true
makeLatest: false
prerelease: true
skipIfReleaseExists: true
- name: Build container
run: make container VERSION=${{ steps.version.outputs.VERSION }}
- name: Extract metadata for container - name: Extract metadata for container
id: meta id: meta
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
@ -84,3 +70,13 @@ jobs:
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
- name: Create deploy bundle
run: templates/generate-all deploy.zip ${{ steps.version.outputs.VERSION }}
- uses: ncipollo/release-action@v1
with:
artifacts: deploy.zip
generateReleaseNotes: true
makeLatest: false
prerelease: true
skipIfReleaseExists: true

View file

@ -16,7 +16,7 @@ RUN pip3 install torch --index-url https://download.pytorch.org/whl/cpu
RUN pip3 install anthropic boto3 cohere openai google-cloud-aiplatform ollama google-generativeai \ RUN pip3 install anthropic boto3 cohere openai google-cloud-aiplatform ollama google-generativeai \
langchain langchain-core langchain-huggingface langchain-text-splitters \ langchain langchain-core langchain-huggingface langchain-text-splitters \
langchain-community pymilvus sentence-transformers transformers \ langchain-community pymilvus sentence-transformers transformers \
huggingface-hub pulsar-client cassandra-driver pyarrow pyyaml \ huggingface-hub pulsar-client cassandra-driver pyyaml \
neo4j tiktoken && \ neo4j tiktoken && \
pip3 cache purge pip3 cache purge
@ -32,7 +32,6 @@ COPY trustgraph-base/ /root/build/trustgraph-base/
COPY trustgraph-flow/ /root/build/trustgraph-flow/ COPY trustgraph-flow/ /root/build/trustgraph-flow/
COPY trustgraph-vertexai/ /root/build/trustgraph-vertexai/ COPY trustgraph-vertexai/ /root/build/trustgraph-vertexai/
COPY trustgraph-bedrock/ /root/build/trustgraph-bedrock/ COPY trustgraph-bedrock/ /root/build/trustgraph-bedrock/
COPY trustgraph-parquet/ /root/build/trustgraph-parquet/
COPY trustgraph-embeddings-hf/ /root/build/trustgraph-embeddings-hf/ COPY trustgraph-embeddings-hf/ /root/build/trustgraph-embeddings-hf/
COPY trustgraph-cli/ /root/build/trustgraph-cli/ COPY trustgraph-cli/ /root/build/trustgraph-cli/
@ -42,7 +41,6 @@ RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-base/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-flow/ RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-flow/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-vertexai/ RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-vertexai/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-bedrock/ RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-bedrock/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-parquet/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-embeddings-hf/ RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-embeddings-hf/
RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-cli/ RUN pip3 wheel -w /root/wheels/ --no-deps ./trustgraph-cli/
@ -61,7 +59,6 @@ RUN \
pip3 install /root/wheels/trustgraph_flow-* && \ pip3 install /root/wheels/trustgraph_flow-* && \
pip3 install /root/wheels/trustgraph_vertexai-* && \ pip3 install /root/wheels/trustgraph_vertexai-* && \
pip3 install /root/wheels/trustgraph_bedrock-* && \ pip3 install /root/wheels/trustgraph_bedrock-* && \
pip3 install /root/wheels/trustgraph_parquet-* && \
pip3 install /root/wheels/trustgraph_embeddings_hf-* && \ pip3 install /root/wheels/trustgraph_embeddings_hf-* && \
pip3 install /root/wheels/trustgraph_cli-* && \ pip3 install /root/wheels/trustgraph_cli-* && \
pip3 cache purge && \ pip3 cache purge && \

View file

@ -14,7 +14,6 @@ wheels:
pip3 wheel --no-deps --wheel-dir dist trustgraph-flow/ pip3 wheel --no-deps --wheel-dir dist trustgraph-flow/
pip3 wheel --no-deps --wheel-dir dist trustgraph-vertexai/ pip3 wheel --no-deps --wheel-dir dist trustgraph-vertexai/
pip3 wheel --no-deps --wheel-dir dist trustgraph-bedrock/ pip3 wheel --no-deps --wheel-dir dist trustgraph-bedrock/
pip3 wheel --no-deps --wheel-dir dist trustgraph-parquet/
pip3 wheel --no-deps --wheel-dir dist trustgraph-embeddings-hf/ pip3 wheel --no-deps --wheel-dir dist trustgraph-embeddings-hf/
pip3 wheel --no-deps --wheel-dir dist trustgraph-cli/ pip3 wheel --no-deps --wheel-dir dist trustgraph-cli/
@ -25,7 +24,6 @@ packages: update-package-versions
cd trustgraph-flow && python3 setup.py sdist --dist-dir ../dist/ cd trustgraph-flow && python3 setup.py sdist --dist-dir ../dist/
cd trustgraph-vertexai && python3 setup.py sdist --dist-dir ../dist/ cd trustgraph-vertexai && python3 setup.py sdist --dist-dir ../dist/
cd trustgraph-bedrock && python3 setup.py sdist --dist-dir ../dist/ cd trustgraph-bedrock && python3 setup.py sdist --dist-dir ../dist/
cd trustgraph-parquet && python3 setup.py sdist --dist-dir ../dist/
cd trustgraph-embeddings-hf && python3 setup.py sdist --dist-dir ../dist/ cd trustgraph-embeddings-hf && python3 setup.py sdist --dist-dir ../dist/
cd trustgraph-cli && python3 setup.py sdist --dist-dir ../dist/ cd trustgraph-cli && python3 setup.py sdist --dist-dir ../dist/
@ -41,7 +39,6 @@ update-package-versions:
echo __version__ = \"${VERSION}\" > trustgraph-flow/trustgraph/flow_version.py echo __version__ = \"${VERSION}\" > trustgraph-flow/trustgraph/flow_version.py
echo __version__ = \"${VERSION}\" > trustgraph-vertexai/trustgraph/vertexai_version.py echo __version__ = \"${VERSION}\" > trustgraph-vertexai/trustgraph/vertexai_version.py
echo __version__ = \"${VERSION}\" > trustgraph-bedrock/trustgraph/bedrock_version.py echo __version__ = \"${VERSION}\" > trustgraph-bedrock/trustgraph/bedrock_version.py
echo __version__ = \"${VERSION}\" > trustgraph-parquet/trustgraph/parquet_version.py
echo __version__ = \"${VERSION}\" > trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py echo __version__ = \"${VERSION}\" > trustgraph-embeddings-hf/trustgraph/embeddings_hf_version.py
echo __version__ = \"${VERSION}\" > trustgraph-cli/trustgraph/cli_version.py echo __version__ = \"${VERSION}\" > trustgraph-cli/trustgraph/cli_version.py
echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py echo __version__ = \"${VERSION}\" > trustgraph/trustgraph/trustgraph_version.py

View file

@ -12,6 +12,7 @@
"graph-rag": import "components/graph-rag.jsonnet", "graph-rag": import "components/graph-rag.jsonnet",
"triple-store-cassandra": import "components/cassandra.jsonnet", "triple-store-cassandra": import "components/cassandra.jsonnet",
"triple-store-neo4j": import "components/neo4j.jsonnet", "triple-store-neo4j": import "components/neo4j.jsonnet",
"triple-store-memgraph": import "components/memgraph.jsonnet",
"llamafile": import "components/llamafile.jsonnet", "llamafile": import "components/llamafile.jsonnet",
"ollama": import "components/ollama.jsonnet", "ollama": import "components/ollama.jsonnet",
"openai": import "components/openai.jsonnet", "openai": import "components/openai.jsonnet",
@ -25,6 +26,7 @@
"trustgraph-base": import "components/trustgraph.jsonnet", "trustgraph-base": import "components/trustgraph.jsonnet",
"vector-store-milvus": import "components/milvus.jsonnet", "vector-store-milvus": import "components/milvus.jsonnet",
"vector-store-qdrant": import "components/qdrant.jsonnet", "vector-store-qdrant": import "components/qdrant.jsonnet",
"vector-store-pinecone": import "components/pinecone.jsonnet",
"vertexai": import "components/vertexai.jsonnet", "vertexai": import "components/vertexai.jsonnet",
"null": {}, "null": {},
@ -33,7 +35,9 @@
// FIXME: Dupes // FIXME: Dupes
"cassandra": import "components/cassandra.jsonnet", "cassandra": import "components/cassandra.jsonnet",
"neo4j": import "components/neo4j.jsonnet", "neo4j": import "components/neo4j.jsonnet",
"memgraph": import "components/memgraph.jsonnet",
"qdrant": import "components/qdrant.jsonnet", "qdrant": import "components/qdrant.jsonnet",
"pinecone": import "components/pinecone.jsonnet",
"milvus": import "components/milvus.jsonnet", "milvus": import "components/milvus.jsonnet",
"trustgraph": import "components/trustgraph.jsonnet", "trustgraph": import "components/trustgraph.jsonnet",

View file

@ -48,7 +48,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -46,7 +46,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -53,7 +53,7 @@ local chunker = import "chunker-recursive.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -45,7 +45,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -43,7 +43,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_limits("0.5", "128M") .with_limits("0.5", "128M")
.with_reservations("0.1", "128M"); .with_reservations("0.1", "128M");

View file

@ -19,7 +19,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"--prompt-request-queue", "--prompt-request-queue",
"non-persistent://tg/request/prompt-rag", "non-persistent://tg/request/prompt-rag",
"--prompt-response-queue", "--prompt-response-queue",
"non-persistent://tg/response/prompt-rag-response", "non-persistent://tg/response/prompt-rag",
]) ])
.with_limits("0.5", "128M") .with_limits("0.5", "128M")
.with_reservations("0.1", "128M"); .with_reservations("0.1", "128M");

View file

@ -13,7 +13,7 @@ local prompts = import "prompts/mixtral.jsonnet";
create:: function(engine) create:: function(engine)
local envSecrets = engine.envSecrets("bedrock-credentials") local envSecrets = engine.envSecrets("googleaistudio-key")
.with_env_var("GOOGLE_AI_STUDIO_KEY", "googleaistudio-key"); .with_env_var("GOOGLE_AI_STUDIO_KEY", "googleaistudio-key");
local container = local container =
@ -50,7 +50,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -112,7 +112,7 @@ local url = import "values/url.jsonnet";
"--prompt-request-queue", "--prompt-request-queue",
"non-persistent://tg/request/prompt-rag", "non-persistent://tg/request/prompt-rag",
"--prompt-response-queue", "--prompt-response-queue",
"non-persistent://tg/response/prompt-rag-response", "non-persistent://tg/response/prompt-rag",
"--entity-limit", "--entity-limit",
std.toString($["graph-rag-entity-limit"]), std.toString($["graph-rag-entity-limit"]),
"--triple-limit", "--triple-limit",

View file

@ -40,7 +40,7 @@ local prompts = import "prompts/slm.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -0,0 +1,81 @@
// Deployment template wiring Memgraph in as the triple store.
// Extends the base memgraph store definition with two Pulsar-connected
// processors: a triple writer and a triple query service.
local base = import "base/base.jsonnet";
local images = import "values/images.jsonnet";
local url = import "values/url.jsonnet";
local memgraph = import "stores/memgraph.jsonnet";
memgraph + {
// Bolt endpoint and database name used by both processors below.
"memgraph-url":: "bolt://memgraph:7687",
"memgraph-database":: "memgraph",
// Writer: consumes triples from Pulsar and writes them to Memgraph.
"store-triples" +: {
create:: function(engine)
local container =
engine.container("store-triples")
.with_image(images.trustgraph)
.with_command([
"triples-write-memgraph",
"-p",
url.pulsar,
"-g",
$["memgraph-url"],
"--database",
$["memgraph-database"],
])
.with_limits("0.5", "128M")
.with_reservations("0.1", "128M");
local containerSet = engine.containers(
"store-triples", [ container ]
);
local service =
engine.internalService(containerSet)
// Port 8080 exposes Prometheus metrics only.
.with_port(8080, 8080, "metrics");
engine.resources([
containerSet,
service,
])
},
// Reader: serves triple queries against Memgraph over Pulsar.
"query-triples" +: {
create:: function(engine)
local container =
engine.container("query-triples")
.with_image(images.trustgraph)
.with_command([
"triples-query-memgraph",
"-p",
url.pulsar,
"-g",
$["memgraph-url"],
"--database",
$["memgraph-database"],
])
.with_limits("0.5", "128M")
.with_reservations("0.1", "128M");
local containerSet = engine.containers(
"query-triples", [ container ]
);
local service =
engine.internalService(containerSet)
// Port 8080 exposes Prometheus metrics only.
.with_port(8080, 8080, "metrics");
engine.resources([
containerSet,
service,
])
}
}

View file

@ -40,7 +40,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -50,7 +50,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -0,0 +1,153 @@
// Deployment template wiring Pinecone in as the vector store.
// Defines four Pulsar-connected processors: writers and query services
// for graph embeddings (ge-*) and document embeddings (de-*).
// All four authenticate via the "pinecone-api-key" secret.
local base = import "base/base.jsonnet";
local images = import "values/images.jsonnet";
local url = import "values/url.jsonnet";
// NOTE(review): cassandra_hosts appears unused in this file -- confirm
// whether it can be removed.
local cassandra_hosts = "cassandra";
{
// Pinecone serverless placement defaults.
"pinecone-cloud":: "aws",
"pinecone-region":: "us-east-1",
// Writer: stores graph embeddings into Pinecone.
"store-graph-embeddings" +: {
create:: function(engine)
local envSecrets = engine.envSecrets("pinecone-api-key")
.with_env_var("PINECONE_API_KEY", "pinecone-api-key");
local container =
engine.container("store-graph-embeddings")
.with_image(images.trustgraph)
.with_command([
"ge-write-pinecone",
"-p",
url.pulsar,
])
.with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M")
.with_reservations("0.1", "128M");
local containerSet = engine.containers(
"store-graph-embeddings", [ container ]
);
local service =
engine.internalService(containerSet)
// Port 8080 exposes Prometheus metrics only.
.with_port(8080, 8080, "metrics");
engine.resources([
envSecrets,
containerSet,
service,
])
},
// Reader: serves graph-embedding similarity queries.
"query-graph-embeddings" +: {
create:: function(engine)
local envSecrets = engine.envSecrets("pinecone-api-key")
.with_env_var("PINECONE_API_KEY", "pinecone-api-key");
local container =
engine.container("query-graph-embeddings")
.with_image(images.trustgraph)
.with_command([
"ge-query-pinecone",
"-p",
url.pulsar,
])
.with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M")
.with_reservations("0.1", "128M");
local containerSet = engine.containers(
"query-graph-embeddings", [ container ]
);
local service =
engine.internalService(containerSet)
.with_port(8080, 8080, "metrics");
engine.resources([
envSecrets,
containerSet,
service,
])
},
// Writer: stores document embeddings into Pinecone.
"store-doc-embeddings" +: {
create:: function(engine)
local envSecrets = engine.envSecrets("pinecone-api-key")
.with_env_var("PINECONE_API_KEY", "pinecone-api-key");
local container =
engine.container("store-doc-embeddings")
.with_image(images.trustgraph)
.with_command([
"de-write-pinecone",
"-p",
url.pulsar,
])
.with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M")
.with_reservations("0.1", "128M");
local containerSet = engine.containers(
"store-doc-embeddings", [ container ]
);
local service =
engine.internalService(containerSet)
.with_port(8080, 8080, "metrics");
engine.resources([
envSecrets,
containerSet,
service,
])
},
// Reader: serves document-embedding similarity queries.
"query-doc-embeddings" +: {
create:: function(engine)
local envSecrets = engine.envSecrets("pinecone-api-key")
.with_env_var("PINECONE_API_KEY", "pinecone-api-key");
local container =
engine.container("query-doc-embeddings")
.with_image(images.trustgraph)
.with_command([
"de-query-pinecone",
"-p",
url.pulsar,
])
.with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M")
.with_reservations("0.1", "128M");
local containerSet = engine.containers(
"query-doc-embeddings", [ container ]
);
local service =
engine.internalService(containerSet)
.with_port(8080, 8080, "metrics");
engine.resources([
envSecrets,
containerSet,
service,
])
}
}

View file

@ -53,7 +53,7 @@ local default_prompts = import "prompts/default-prompts.jsonnet";
"--text-completion-request-queue", "--text-completion-request-queue",
"non-persistent://tg/request/text-completion", "non-persistent://tg/request/text-completion",
"--text-completion-response-queue", "--text-completion-response-queue",
"non-persistent://tg/response/text-completion-response", "non-persistent://tg/response/text-completion",
"--system-prompt", "--system-prompt",
$["prompts"]["system-template"], $["prompts"]["system-template"],
@ -92,11 +92,11 @@ local default_prompts = import "prompts/default-prompts.jsonnet";
"-i", "-i",
"non-persistent://tg/request/prompt-rag", "non-persistent://tg/request/prompt-rag",
"-o", "-o",
"non-persistent://tg/response/prompt-rag-response", "non-persistent://tg/response/prompt-rag",
"--text-completion-request-queue", "--text-completion-request-queue",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"--text-completion-response-queue", "--text-completion-response-queue",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
"--system-prompt", "--system-prompt",
$["prompts"]["system-template"], $["prompts"]["system-template"],

View file

@ -5,9 +5,56 @@ local prompt = import "prompt-template.jsonnet";
{ {
"api-gateway-port":: 8088,
"api-gateway-timeout":: 600,
"chunk-size":: 250, "chunk-size":: 250,
"chunk-overlap":: 15, "chunk-overlap":: 15,
"api-gateway" +: {
create:: function(engine)
local envSecrets = engine.envSecrets("gateway-secret")
.with_env_var("GATEWAY_SECRET", "gateway-secret");
local port = $["api-gateway-port"];
local container =
engine.container("api-gateway")
.with_image(images.trustgraph)
.with_command([
"api-gateway",
"-p",
url.pulsar,
"--timeout",
std.toString($["api-gateway-timeout"]),
"--port",
std.toString(port),
])
.with_env_var_secrets(envSecrets)
.with_limits("0.5", "256M")
.with_reservations("0.1", "256M")
.with_port(8000, 8000, "metrics")
.with_port(port, port, "api");
local containerSet = engine.containers(
"api-gateway", [ container ]
);
local service =
engine.internalService(containerSet)
.with_port(8000, 8000, "metrics")
.with_port(port, port, "api");
engine.resources([
envSecrets,
containerSet,
service,
])
},
"chunker" +: { "chunker" +: {
create:: function(engine) create:: function(engine)
@ -144,7 +191,7 @@ local prompt = import "prompt-template.jsonnet";
"-p", "-p",
url.pulsar, url.pulsar,
"-i", "-i",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_limits("0.5", "128M") .with_limits("0.5", "128M")
.with_reservations("0.1", "128M"); .with_reservations("0.1", "128M");

View file

@ -93,7 +93,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_limits("0.5", "256M") .with_limits("0.5", "256M")
.with_reservations("0.1", "256M") .with_reservations("0.1", "256M")

View file

@ -0,0 +1,68 @@
// Deployment template for the Memgraph database itself, plus the
// Memgraph Lab web UI for browsing the graph.
local base = import "base/base.jsonnet";
local images = import "values/images.jsonnet";
{
// The Memgraph server (MAGE image), with edge properties/metadata
// enabled so triples can carry data on edges.
"memgraph" +: {
create:: function(engine)
local container =
engine.container("memgraph")
.with_image(images.memgraph_mage)
.with_environment({
MEMGRAPH: "--storage-properties-on-edges=true --storage-enable-edges-metadata=true"
})
.with_limits("1.0", "1000M")
.with_reservations("0.5", "1000M")
// 7687 is the bolt protocol port used by the store/query services.
.with_port(7474, 7474, "api")
.with_port(7687, 7687, "api2");
local containerSet = engine.containers(
"memgraph", [ container ]
);
local service =
engine.service(containerSet)
.with_port(7474, 7474, "api")
.with_port(7687, 7687, "api2");
engine.resources([
containerSet,
service,
])
},
// Memgraph Lab web UI, pre-configured to connect to the server above.
"memgraph-lab" +: {
create:: function(engine)
local container =
engine.container("lab")
.with_image(images.memgraph_lab)
.with_environment({
QUICK_CONNECT_MG_HOST: "memgraph",
QUICK_CONNECT_MG_PORT: "7687",
})
.with_limits("1.0", "512M")
.with_reservations("0.5", "512M")
// NOTE(review): the container maps 3010 -> 3000 while the service
// below maps 3010 -> 3010 -- confirm the intended port pairing.
.with_port(3010, 3000, "http");
local containerSet = engine.containers(
"lab", [ container ]
);
local service =
engine.service(containerSet)
.with_port(3010, 3010, "http");
engine.resources([
containerSet,
service,
])
},
}

View file

@ -10,5 +10,7 @@ local version = import "version.jsonnet";
prometheus: "docker.io/prom/prometheus:v2.53.2", prometheus: "docker.io/prom/prometheus:v2.53.2",
grafana: "docker.io/grafana/grafana:11.1.4", grafana: "docker.io/grafana/grafana:11.1.4",
trustgraph: "docker.io/trustgraph/trustgraph-flow:" + version, trustgraph: "docker.io/trustgraph/trustgraph-flow:" + version,
qdrant: "docker.io/qdrant/qdrant:v1.11.1" qdrant: "docker.io/qdrant/qdrant:v1.11.1",
memgraph_mage: "docker.io/memgraph/memgraph-mage:1.22-memgraph-2.22",
memgraph_lab: "docker.io/memgraph/lab:2.19.1",
} }

28
test-api/test-agent-api Executable file
View file

@ -0,0 +1,28 @@
#!/usr/bin/env python3

"""Smoke-test the agent API endpoint.

Posts a sample question to the local API gateway and prints the
agent's answer; exits non-zero on any error.
"""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Request payload for the agent service.
payload = {
    "question": "What is the highest risk aspect of running a space shuttle program? Provide 5 detailed reasons to justify our answer.",
}

resp = requests.post(
    f"{url}agent",
    json=payload,
)

# Fail fast on transport-level errors before attempting to decode JSON.
if resp.status_code != 200:
    print(f"Error: status code {resp.status_code}", file=sys.stderr)
    sys.exit(1)

resp = resp.json()

# The gateway reports application-level failures in an "error" field.
if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["answer"])

28
test-api/test-agent2-api Executable file
View file

@ -0,0 +1,28 @@
#!/usr/bin/env python3

"""Smoke-test the agent API endpoint with a simple arithmetic question."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Request payload for the agent service.
payload = {
    "question": "What is 14 plus 12. Justify your answer.",
}

resp = requests.post(
    f"{url}agent",
    json=payload,
)

# Fail fast on transport-level errors before attempting to decode JSON.
if resp.status_code != 200:
    print(f"Error: status code {resp.status_code}", file=sys.stderr)
    sys.exit(1)

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["answer"])

30
test-api/test-dbpedia Executable file
View file

@ -0,0 +1,30 @@
#!/usr/bin/env python3

"""Smoke-test the dbpedia lookup endpoint with a sample term."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Lookup term for the dbpedia service.
payload = {
    "term": "Cornwall",
}

resp = requests.post(
    f"{url}dbpedia",
    json=payload,
)

# Fail fast on transport-level errors before attempting to decode JSON.
if resp.status_code != 200:
    print(f"Error: status code {resp.status_code}", file=sys.stderr)
    sys.exit(1)

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["text"])

sys.exit(0)

############################################################################

28
test-api/test-embeddings-api Executable file
View file

@ -0,0 +1,28 @@
#!/usr/bin/env python3

"""Smoke-test the embeddings endpoint: prints the embedding vectors."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Text to embed.
payload = {
    "text": "What is the highest risk aspect of running a space shuttle program? Provide 5 detailed reasons to justify our answer.",
}

resp = requests.post(
    f"{url}embeddings",
    json=payload,
)

# Fail fast on transport-level errors before attempting to decode JSON.
if resp.status_code != 200:
    print(f"Error: status code {resp.status_code}", file=sys.stderr)
    sys.exit(1)

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["vectors"])

30
test-api/test-encyclopedia Executable file
View file

@ -0,0 +1,30 @@
#!/usr/bin/env python3

"""Smoke-test the encyclopedia lookup endpoint with a sample term."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Lookup term for the encyclopedia service.
payload = {
    "term": "Cornwall",
}

resp = requests.post(
    f"{url}encyclopedia",
    json=payload,
)

# Fail fast on transport-level errors before attempting to decode JSON.
if resp.status_code != 200:
    print(f"Error: status code {resp.status_code}", file=sys.stderr)
    sys.exit(1)

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["text"])

sys.exit(0)

############################################################################

31
test-api/test-graph-rag-api Executable file
View file

@ -0,0 +1,31 @@
#!/usr/bin/env python3

"""Smoke-test the graph-rag endpoint with a simple query."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Query for the graph-RAG service.
payload = {
    "query": "Give me 10 facts",
}

resp = requests.post(
    f"{url}graph-rag",
    json=payload,
)

# Fail fast on transport-level errors before attempting to decode JSON.
if resp.status_code != 200:
    print(f"Error: status code {resp.status_code}", file=sys.stderr)
    sys.exit(1)

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["response"])

sys.exit(0)

############################################################################

30
test-api/test-internet-search Executable file
View file

@ -0,0 +1,30 @@
#!/usr/bin/env python3

"""Smoke-test the internet-search endpoint with a sample term."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Search term for the internet-search service.
payload = {
    "term": "Cornwall",
}

resp = requests.post(
    f"{url}internet-search",
    json=payload,
)

# Fail fast on transport-level errors before attempting to decode JSON.
if resp.status_code != 200:
    print(f"Error: status code {resp.status_code}", file=sys.stderr)
    sys.exit(1)

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["text"])

sys.exit(0)

############################################################################

34
test-api/test-llm-api Executable file
View file

@ -0,0 +1,34 @@
#!/usr/bin/env python3

"""Smoke-test the text-completion (LLM) endpoint."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# System + user prompt pair sent to the text-completion service.
request_body = {
    "system": "Respond in French. Use long word, form of numbers, no digits",
    # "prompt": "Add 2 and 12"
    "prompt": "Add 12 and 14, and then make a poem about llamas which incorporates that number. Then write a joke about llamas"
}

resp = requests.post(f"{url}text-completion", json=request_body)

# Any non-200 status is fatal.
if resp.status_code != 200:
    raise RuntimeError(f"Status code: {resp.status_code}")

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["response"])

############################################################################

37
test-api/test-prompt-api Executable file
View file

@ -0,0 +1,37 @@
#!/usr/bin/env python3

"""Smoke-test the prompt endpoint with the "question" template.

Expects a text response; an object response is treated as unexpected.
"""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Template id plus the variables it interpolates.
payload = {
    "id": "question",
    "variables": {
        "question": "Write a joke about llamas."
    }
}

resp = requests.post(
    f"{url}prompt",
    json=payload,
)

# Fail fast on transport-level errors before attempting to decode JSON.
if resp.status_code != 200:
    print(f"Error: status code {resp.status_code}", file=sys.stderr)
    sys.exit(1)

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

# An object response is unexpected for this template.
if "object" in resp:
    print(f"Object: {resp['object']}")
    sys.exit(1)

print(resp["text"])

sys.exit(0)

############################################################################

38
test-api/test-prompt2-api Executable file
View file

@ -0,0 +1,38 @@
#!/usr/bin/env python3

"""Smoke-test the prompt endpoint with the "extract-definitions" template.

If an object response is returned it is pretty-printed; otherwise the
plain text response is printed.
"""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Template id plus the variables it interpolates.
payload = {
    "id": "extract-definitions",
    "variables": {
        "text": "A cat is a large mammal."
    }
}

resp = requests.post(
    f"{url}prompt",
    json=payload,
)

# Fail fast on transport-level errors before attempting to decode JSON.
if resp.status_code != 200:
    print(f"Error: status code {resp.status_code}", file=sys.stderr)
    sys.exit(1)

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

if "object" in resp:
    # The object payload is itself JSON-encoded; pretty-print it.
    # NOTE(review): the original exits 1 here even though an object is
    # the expected result for this template -- preserved; confirm intent.
    parsed = json.loads(resp["object"])
    print(json.dumps(parsed, indent=4))
    sys.exit(1)

print(resp["text"])

sys.exit(0)

############################################################################

38
test-api/test-triples-query-api Executable file
View file

@ -0,0 +1,38 @@
#!/usr/bin/env python3

"""Smoke-test the triples-query endpoint.

Queries for up to 10 triples whose predicate is rdfs:label and prints
the raw result.
"""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Predicate constraint: value "v" plus entity flag "e".
payload = {
    "p": {
        "v": "http://www.w3.org/2000/01/rdf-schema#label",
        "e": True,
    },
    "limit": 10
}

resp = requests.post(
    f"{url}triples-query",
    json=payload,
)

# Fail fast on transport-level errors before attempting to decode JSON.
if resp.status_code != 200:
    print(f"Error: status code {resp.status_code}", file=sys.stderr)
    sys.exit(1)

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

# NOTE(review): original printed the raw body twice before this;
# debug output removed. "response" assumed to hold the result set.
print(resp["response"])

sys.exit(0)

############################################################################

View file

@ -0,0 +1,3 @@
from . api import *

View file

@ -0,0 +1,339 @@
import requests
import json
import dataclasses
import base64
from trustgraph.knowledge import hash, Uri, Literal
class ProtocolException(Exception):
    """Raised when the HTTP transport or response framing is invalid."""


class ApplicationException(Exception):
    """Raised when the service reports an application-level error."""


@dataclasses.dataclass
class Triple:
    """A single RDF-style triple: subject, predicate, object."""
    s: str
    p: str
    o: str
class Api:
def __init__(self, url="http://localhost:8088/"):
self.url = url
if not url.endswith("/"):
self.url += "/"
self.url += "api/v1/"
def check_error(self, response):
if "error" in response:
try:
msg = response["error"]["message"]
tp = response["error"]["message"]
except:
raise ApplicationException(
"Error, but the error object is broken"
)
raise ApplicationException(f"{tp}: {msg}")
def text_completion(self, system, prompt):
    """Invoke the text-completion service.

    :param system: system prompt for the LLM
    :param prompt: user prompt
    :return: completion text
    :raises ProtocolException: on transport or response-format problems
    :raises ApplicationException: on a service-reported error
    """
    request = {
        "system": system,
        "prompt": prompt,
    }

    url = f"{self.url}text-completion"

    # Invoke the API; input is passed as JSON.
    resp = requests.post(url, json=request)

    if resp.status_code != 200:
        raise ProtocolException(f"Status code {resp.status_code}")

    try:
        body = resp.json()
    except ValueError:
        raise ProtocolException("Expected JSON response")

    # Bug fix: check the decoded JSON body, not the raw Response
    # object -- check_error's dict lookups never matched a Response,
    # so service errors went undetected.
    self.check_error(body)

    try:
        return body["response"]
    except (KeyError, TypeError):
        raise ProtocolException("Response not formatted correctly")
def agent(self, question):
    """Ask the agent service a natural-language question.

    :param question: question text
    :return: the agent's answer text
    :raises ProtocolException: on transport or response-format problems
    :raises ApplicationException: on a service-reported error
    """
    request = {
        "question": question
    }

    url = f"{self.url}agent"

    # Invoke the API; input is passed as JSON.
    resp = requests.post(url, json=request)

    if resp.status_code != 200:
        raise ProtocolException(f"Status code {resp.status_code}")

    try:
        body = resp.json()
    except ValueError:
        raise ProtocolException("Expected JSON response")

    # Bug fix: check the decoded JSON body, not the raw Response.
    self.check_error(body)

    try:
        return body["answer"]
    except (KeyError, TypeError):
        raise ProtocolException("Response not formatted correctly")
def graph_rag(self, question):
# The input consists of a question
input = {
"query": question
}
url = f"{self.url}graph-rag"
# Invoke the API, input is passed as JSON
resp = requests.post(url, json=input)
# Should be a 200 status code
if resp.status_code != 200:
raise ProtocolException(f"Status code {resp.status_code}")
try:
# Parse the response as JSON
object = resp.json()
except:
raise ProtocolException(f"Expected JSON response")
self.check_error(resp)
try:
return object["response"]
except:
raise ProtocolException(f"Response not formatted correctly")
def embeddings(self, text):
# The input consists of a text block
input = {
"text": text
}
url = f"{self.url}embeddings"
# Invoke the API, input is passed as JSON
resp = requests.post(url, json=input)
# Should be a 200 status code
if resp.status_code != 200:
raise ProtocolException(f"Status code {resp.status_code}")
try:
# Parse the response as JSON
object = resp.json()
except:
raise ProtocolException(f"Expected JSON response")
self.check_error(resp)
try:
return object["vectors"]
except:
raise ProtocolException(f"Response not formatted correctly")
def prompt(self, id, variables):
# The input consists of system and prompt strings
input = {
"id": id,
"variables": variables
}
url = f"{self.url}prompt"
# Invoke the API, input is passed as JSON
resp = requests.post(url, json=input)
# Should be a 200 status code
if resp.status_code != 200:
raise ProtocolException(f"Status code {resp.status_code}")
try:
# Parse the response as JSON
object = resp.json()
except:
raise ProtocolException("Expected JSON response")
self.check_error(resp)
if "text" in object:
return object["text"]
if "object" in object:
try:
return json.loads(object["object"])
except Exception as e:
raise ProtocolException(
"Returned object not well-formed JSON"
)
raise ProtocolException("Response not formatted correctly")
def triples_query(self, s=None, p=None, o=None, limit=10000):
# The input consists of system and prompt strings
input = {
"limit": limit
}
if s:
if not isinstance(s, Uri):
raise RuntimeError("s must be Uri")
input["s"] = { "v": str(s), "e": isinstance(s, Uri), }
if p:
if not isinstance(p, Uri):
raise RuntimeError("p must be Uri")
input["p"] = { "v": str(p), "e": isinstance(p, Uri), }
if o:
if not isinstance(o, Uri) and not isinstance(o, Literal):
raise RuntimeError("o must be Uri or Literal")
input["o"] = { "v": str(o), "e": isinstance(o, Uri), }
url = f"{self.url}triples-query"
# Invoke the API, input is passed as JSON
resp = requests.post(url, json=input)
# Should be a 200 status code
if resp.status_code != 200:
raise ProtocolException(f"Status code {resp.status_code}")
try:
# Parse the response as JSON
object = resp.json()
except:
raise ProtocolException("Expected JSON response")
self.check_error(resp)
if "response" not in object:
raise ProtocolException("Response not formatted correctly")
def to_value(x):
if x["e"]: return Uri(x["v"])
return Literal(x["v"])
return [
Triple(
s=to_value(t["s"]),
p=to_value(t["p"]),
o=to_value(t["o"])
)
for t in object["response"]
]
return object["response"]
def load_document(self, document, id=None, metadata=None):
if id is None:
if metadata is not None:
# Situation makes no sense. What can the metadata possibly
# mean if the caller doesn't know the document ID.
# Metadata should relate to the document by ID
raise RuntimeError("Can't specify metadata without id")
id = hash(document)
triples = []
def emit(t):
triples.append(t)
if metadata:
metadata.emit(
lambda t: triples.append({
"s": { "v": t["s"], "e": isinstance(t["s"], Uri) },
"p": { "v": t["p"], "e": isinstance(t["p"], Uri) },
"o": { "v": t["o"], "e": isinstance(t["o"], Uri) }
})
)
input = {
"id": id,
"metadata": triples,
"data": base64.b64encode(document).decode("utf-8"),
}
url = f"{self.url}load/document"
# Invoke the API, input is passed as JSON
resp = requests.post(url, json=input)
# Should be a 200 status code
if resp.status_code != 200:
raise ProtocolException(f"Status code {resp.status_code}")
def load_text(self, text, id=None, metadata=None, charset="utf-8"):
if id is None:
if metadata is not None:
# Situation makes no sense. What can the metadata possibly
# mean if the caller doesn't know the document ID.
# Metadata should relate to the document by ID
raise RuntimeError("Can't specify metadata without id")
id = hash(text)
triples = []
if metadata:
metadata.emit(
lambda t: triples.append({
"s": { "v": t["s"], "e": isinstance(t["s"], Uri) },
"p": { "v": t["p"], "e": isinstance(t["p"], Uri) },
"o": { "v": t["o"], "e": isinstance(t["o"], Uri) }
})
)
input = {
"id": id,
"metadata": triples,
"charset": charset,
"text": base64.b64encode(text).decode("utf-8"),
}
url = f"{self.url}load/text"
# Invoke the API, input is passed as JSON
resp = requests.post(url, json=input)
# Should be a 200 status code
if resp.status_code != 200:
raise ProtocolException(f"Status code {resp.status_code}")

View file

@ -1,4 +1,5 @@
from . defs import *
from . identifier import * from . identifier import *
from . publication import * from . publication import *
from . document import * from . document import *

View file

@ -23,3 +23,11 @@ URL = 'https://schema.org/url'
IDENTIFIER = 'https://schema.org/identifier' IDENTIFIER = 'https://schema.org/identifier'
KEYWORD = 'https://schema.org/keywords' KEYWORD = 'https://schema.org/keywords'
class Uri(str):
    """A string subtype marking an RDF URI (entity) value."""

    def is_uri(self):
        return True

    def is_literal(self):
        return False
class Literal(str):
    """A string subtype marking an RDF literal value."""

    def is_uri(self):
        return False

    def is_literal(self):
        return True

View file

@ -1,6 +1,16 @@
from . defs import * from . defs import *
from .. schema import Triple, Value
def Value(value, is_uri):
if is_uri:
return Uri(value)
else:
return Literal(value)
def Triple(s, p, o):
return {
"s": s, "p": p, "o": o,
}
class DigitalDocument: class DigitalDocument:

View file

@ -1,6 +1,16 @@
from . defs import * from . defs import *
from .. schema import Triple, Value
def Value(value, is_uri):
if is_uri:
return Uri(value)
else:
return Literal(value)
def Triple(s, p, o):
return {
"s": s, "p": p, "o": o,
}
class Organization: class Organization:
def __init__(self, id, name=None, description=None): def __init__(self, id, name=None, description=None):

View file

@ -1,6 +1,16 @@
from . defs import * from . defs import *
from .. schema import Triple, Value
def Value(value, is_uri):
if is_uri:
return Uri(value)
else:
return Literal(value)
def Triple(s, p, o):
return {
"s": s, "p": p, "o": o,
}
class PublicationEvent: class PublicationEvent:
def __init__( def __init__(

View file

@ -9,4 +9,6 @@ from . graph import *
from . retrieval import * from . retrieval import *
from . metadata import * from . metadata import *
from . agent import * from . agent import *
from . lookup import *

View file

@ -60,5 +60,5 @@ document_embeddings_request_queue = topic(
'doc-embeddings', kind='non-persistent', namespace='request' 'doc-embeddings', kind='non-persistent', namespace='request'
) )
document_embeddings_response_queue = topic( document_embeddings_response_queue = topic(
'doc-embeddings-response', kind='non-persistent', namespace='response', 'doc-embeddings', kind='non-persistent', namespace='response',
) )

View file

@ -34,7 +34,7 @@ graph_embeddings_request_queue = topic(
'graph-embeddings', kind='non-persistent', namespace='request' 'graph-embeddings', kind='non-persistent', namespace='request'
) )
graph_embeddings_response_queue = topic( graph_embeddings_response_queue = topic(
'graph-embeddings-response', kind='non-persistent', namespace='response', 'graph-embeddings', kind='non-persistent', namespace='response'
) )
############################################################################ ############################################################################
@ -67,5 +67,5 @@ triples_request_queue = topic(
'triples', kind='non-persistent', namespace='request' 'triples', kind='non-persistent', namespace='request'
) )
triples_response_queue = topic( triples_response_queue = topic(
'triples-response', kind='non-persistent', namespace='response', 'triples', kind='non-persistent', namespace='response'
) )

View file

@ -0,0 +1,42 @@
from pulsar.schema import Record, String
from . types import Error, Value, Triple
from . topic import topic
from . metadata import Metadata
############################################################################

# Lookups

class LookupRequest(Record):
    # Kind of lookup requested -- semantics depend on the consuming
    # service; presumably distinguishes lookup backends -- TODO confirm
    kind = String()
    # The term to look up
    term = String()

class LookupResponse(Record):
    # Text/extract found for the term; unset on error
    text = String()
    # Error details, populated when the lookup failed
    error = Error()

# Encyclopedia lookup request/response queues
encyclopedia_lookup_request_queue = topic(
    'encyclopedia', kind='non-persistent', namespace='request'
)
encyclopedia_lookup_response_queue = topic(
    'encyclopedia', kind='non-persistent', namespace='response',
)

# DBpedia lookup request/response queues
dbpedia_lookup_request_queue = topic(
    'dbpedia', kind='non-persistent', namespace='request'
)
dbpedia_lookup_response_queue = topic(
    'dbpedia', kind='non-persistent', namespace='response',
)

# Internet search request/response queues
internet_search_request_queue = topic(
    'internet-search', kind='non-persistent', namespace='request'
)
internet_search_response_queue = topic(
    'internet-search', kind='non-persistent', namespace='response',
)
############################################################################

View file

@ -23,7 +23,7 @@ text_completion_request_queue = topic(
'text-completion', kind='non-persistent', namespace='request' 'text-completion', kind='non-persistent', namespace='request'
) )
text_completion_response_queue = topic( text_completion_response_queue = topic(
'text-completion-response', kind='non-persistent', namespace='response', 'text-completion', kind='non-persistent', namespace='response'
) )
############################################################################ ############################################################################
@ -41,5 +41,5 @@ embeddings_request_queue = topic(
'embeddings', kind='non-persistent', namespace='request' 'embeddings', kind='non-persistent', namespace='request'
) )
embeddings_response_queue = topic( embeddings_response_queue = topic(
'embeddings-response', kind='non-persistent', namespace='response' 'embeddings', kind='non-persistent', namespace='response'
) )

View file

@ -59,7 +59,7 @@ prompt_request_queue = topic(
'prompt', kind='non-persistent', namespace='request' 'prompt', kind='non-persistent', namespace='request'
) )
prompt_response_queue = topic( prompt_response_queue = topic(
'prompt-response', kind='non-persistent', namespace='response' 'prompt', kind='non-persistent', namespace='response'
) )
############################################################################ ############################################################################

View file

@ -20,7 +20,7 @@ graph_rag_request_queue = topic(
'graph-rag', kind='non-persistent', namespace='request' 'graph-rag', kind='non-persistent', namespace='request'
) )
graph_rag_response_queue = topic( graph_rag_response_queue = topic(
'graph-rag-response', kind='non-persistent', namespace='response' 'graph-rag', kind='non-persistent', namespace='response'
) )
############################################################################ ############################################################################
@ -40,5 +40,5 @@ document_rag_request_queue = topic(
'doc-rag', kind='non-persistent', namespace='request' 'doc-rag', kind='non-persistent', namespace='request'
) )
document_rag_response_queue = topic( document_rag_response_queue = topic(
'doc-rag-response', kind='non-persistent', namespace='response' 'doc-rag', kind='non-persistent', namespace='response'
) )

View file

@ -34,7 +34,7 @@ setuptools.setup(
python_requires='>=3.8', python_requires='>=3.8',
download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz",
install_requires=[ install_requires=[
"trustgraph-base>=0.15,<0.16", "trustgraph-base>=0.17,<0.18",
"pulsar-client", "pulsar-client",
"prometheus-client", "prometheus-client",
"boto3", "boto3",

View file

@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""
This utility reads a knowledge core in msgpack format and outputs its
contents in JSON form to standard output. This is useful only as a
diagnostic utility.
"""
import msgpack
import sys
import argparse
def dump(input_file, action):
    """Print every msgpack record in input_file to standard output.

    The action argument is unused; it exists so dump() and summary()
    share a call signature for **vars(args) dispatch.
    """
    with open(input_file, "rb") as f:
        for record in msgpack.Unpacker(f, raw=False):
            print(record)
def summary(input_file, action):
    """Print a short summary of a knowledge core: the graph-embedding
    vector dimension and the rdfs:label values found in record metadata.

    The action argument is unused; it mirrors dump()'s signature.
    Scanning stops after roughly a million records.
    """
    label_uri = "http://www.w3.org/2000/01/rdf-schema#label"
    max_records = 1000000

    vector_dim = None
    triples = set()

    with open(input_file, "rb") as f:
        for count, msg in enumerate(msgpack.Unpacker(f, raw=False)):
            tag, body = msg[0], msg[1]
            if tag == "ge":
                # Dimension of the first embedding vector in the record
                vector_dim = len(body["v"][0])
            if tag == "t":
                # Collect metadata triples as (s, p, o) value tuples
                for elt in body["m"]["m"]:
                    triples.add((
                        elt["s"]["v"],
                        elt["p"]["v"],
                        elt["o"]["v"],
                    ))
            if count > max_records:
                break

    print("Vector dimension:", vector_dim)
    for trip in triples:
        if trip[1] == label_uri:
            print("-", trip[2])
def main():
    """Parse command-line arguments and dispatch to summary() or dump().

    -s selects the summary view; -r (or no flag) dumps every record.
    """
    parser = argparse.ArgumentParser(
        prog='tg-dump-msgpack',
        description=__doc__,
    )

    parser.add_argument(
        '-i', '--input-file',
        required=True,
        # Fix: these help strings had pointless f-string prefixes
        help='Input file'
    )

    parser.add_argument(
        '-s', '--summary', action="store_const", const="summary",
        dest="action",
        help='Show a summary'
    )

    parser.add_argument(
        '-r', '--records', action="store_const", const="records",
        dest="action",
        help='Dump individual records'
    )

    args = parser.parse_args()

    if args.action == "summary":
        summary(**vars(args))
    else:
        # Default action (including -r) is a full record dump
        dump(**vars(args))


# Guard so importing this script doesn't parse argv and run the dump
if __name__ == "__main__":
    main()

View file

@ -0,0 +1,286 @@
#!/usr/bin/env python3
"""This utility takes a knowledge core and loads it into a running TrustGraph
through the API. The knowledge core should be in msgpack format, which is the
default format produce by tg-save-kg-core.
"""
import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os
import signal
class Running:
    """Mutable run flag shared between tasks for cooperative shutdown."""

    def __init__(self):
        self.running = True

    def get(self):
        """Return True while the process should keep running."""
        return self.running

    def stop(self):
        """Request shutdown; get() returns False from now on."""
        self.running = False
# Global progress counters, reported periodically by stats()
ge_counts = 0
t_counts = 0

async def load_ge(running, queue, url):
    # Consume graph-embedding records from the queue and push them to the
    # websocket load endpoint, until shutdown or a None sentinel arrives.
    global ge_counts
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(f"{url}load/graph-embeddings") as ws:
            while running.get():
                try:
                    msg = await asyncio.wait_for(queue.get(), 1)
                    # End of load
                    if msg is None:
                        break
                except:
                    # Hopefully it's TimeoutError.  Annoying to match since
                    # it changed in 3.11.
                    continue
                # Re-shape the compact record (m/i/u/c keys) into the
                # API's JSON message format
                msg = {
                    "metadata": {
                        "id": msg["m"]["i"],
                        "metadata": msg["m"]["m"],
                        "user": msg["m"]["u"],
                        "collection": msg["m"]["c"],
                    },
                    "vectors": msg["v"],
                    "entity": msg["e"],
                }
                try:
                    await ws.send_json(msg)
                except Exception as e:
                    print(e)
                # NOTE(review): counter increments even when send_json
                # failed -- confirm whether failures should be counted
                ge_counts += 1
async def load_triples(running, queue, url):
    # Consume triple records from the queue and push them to the websocket
    # load endpoint, until shutdown or a None sentinel arrives.
    global t_counts
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(f"{url}load/triples") as ws:
            while running.get():
                try:
                    msg = await asyncio.wait_for(queue.get(), 1)
                    # End of load
                    if msg is None:
                        break
                except:
                    # Hopefully it's TimeoutError.  Annoying to match since
                    # it changed in 3.11.
                    continue
                # Re-shape the compact record (m/i/u/c keys) into the
                # API's JSON message format
                msg = {
                    "metadata": {
                        "id": msg["m"]["i"],
                        "metadata": msg["m"]["m"],
                        "user": msg["m"]["u"],
                        "collection": msg["m"]["c"],
                    },
                    "triples": msg["t"],
                }
                try:
                    await ws.send_json(msg)
                except Exception as e:
                    print(e)
                # NOTE(review): counter increments even when send_json
                # failed -- confirm whether failures should be counted
                t_counts += 1
async def stats(running):
    # Print progress counters every couple of seconds until shutdown
    global t_counts
    global ge_counts
    while running.get():
        await asyncio.sleep(2)
        print(
            f"Graph embeddings: {ge_counts:10d} Triples: {t_counts:10d}"
        )
async def _put_with_retry(running, queue, item, timeout=0.5):
    """Keep trying to enqueue item until it succeeds or shutdown is
    requested.  Returns True on success, False on shutdown."""
    while running.get():
        try:
            await asyncio.wait_for(queue.put(item), timeout)
            return True
        except:
            # Hopefully it's TimeoutError.  Annoying to match since
            # it changed in 3.11.
            continue
    return False


async def loader(running, ge_queue, t_queue, path, format, user, collection):
    """Read records from the msgpack file at path and distribute them to
    the triples / graph-embeddings queues, then enqueue a None sentinel
    on each queue to signal end-of-load.

    user / collection, when given, override the user/collection stored
    in each record's metadata.
    """
    if format == "json":
        raise RuntimeError("Not implemented")

    with open(path, "rb") as f:
        unpacker = msgpack.Unpacker(f, raw=False)
        while running.get():
            try:
                unpacked = unpacker.unpack()
            except:
                # End of file (or broken stream): stop reading
                break
            # Records are [tag, body]; body["m"] holds metadata with
            # compact keys: i=id, m=metadata, u=user, c=collection.
            # Bug fix: the original wrote unpacked["metadata"]["user"],
            # but unpacked is a [tag, body] sequence, so any --user /
            # --collection override raised TypeError.
            if user:
                unpacked[1]["m"]["u"] = user
            if collection:
                unpacked[1]["m"]["c"] = collection
            if unpacked[0] == "t":
                qtype = t_queue
            elif unpacked[0] == "ge":
                qtype = ge_queue
            else:
                # Fix: unknown tag previously left qtype unbound (or
                # re-used the previous queue); skip such records instead
                continue
            if not await _put_with_retry(running, qtype, unpacked[1]):
                break

    # Put 'None' on the end of each queue to signal completion
    await _put_with_retry(running, t_queue, None, timeout=1)
    await _put_with_retry(running, ge_queue, None, timeout=1)
async def run(running, **args):
    # Maxsize on queues reduces back-pressure so tg-load-kg-core doesn't
    # grow to eat all memory
    ge_q = asyncio.Queue(maxsize=10)
    t_q = asyncio.Queue(maxsize=10)

    # Reader task: parses the input file and feeds the two queues
    load_task = asyncio.create_task(
        loader(
            running=running,
            ge_queue=ge_q, t_queue=t_q,
            path=args["input_file"], format=args["format"],
            user=args["user"], collection=args["collection"],
        )
    )

    # Writer tasks: push queued records to the websocket load endpoints
    ge_task = asyncio.create_task(
        load_ge(
            running=running,
            queue=ge_q, url=args["url"] + "api/v1/"
        )
    )

    triples_task = asyncio.create_task(
        load_triples(
            running=running,
            queue=t_q, url=args["url"] + "api/v1/"
        )
    )

    # Progress reporter
    stats_task = asyncio.create_task(stats(running))

    # Wait for both consumers to drain (they exit on the None sentinel),
    # then stop the remaining tasks
    await triples_task
    await ge_task

    running.stop()

    await load_task
    await stats_task
async def main(running):
    """Parse command-line arguments and run the load pipeline."""
    parser = argparse.ArgumentParser(
        prog='tg-load-kg-core',
        description=__doc__,
    )

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-i', '--input-file',
        required=True,
        # Bug fix: help text said 'Output file' (copied from
        # tg-save-kg-core); this is the input knowledge core
        help='Input file'
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        # Bug fix: help text said 'Output format'; this selects the
        # input format
        help='Input format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to load as (default: from input)'
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to load as (default: from input)'
    )

    args = parser.parse_args()

    await run(running, **vars(args))


# Shared shutdown flag, flipped by the SIGINT handler below
running = Running()

def interrupt(sig, frame):
    running.stop()
    print('Interrupt')

signal.signal(signal.SIGINT, interrupt)

asyncio.run(main(running))

View file

@ -14,9 +14,9 @@ import time
import uuid import uuid
from trustgraph.schema import Document, document_ingest_queue from trustgraph.schema import Document, document_ingest_queue
from trustgraph.schema import Metadata from trustgraph.schema import Metadata, Triple, Value
from trustgraph.log_level import LogLevel from trustgraph.log_level import LogLevel
from trustgraph.knowledge import hash, to_uri from trustgraph.knowledge import hash, to_uri, Uri
from trustgraph.knowledge import PREF_PUBEV, PREF_DOC, PREF_ORG from trustgraph.knowledge import PREF_PUBEV, PREF_DOC, PREF_ORG
from trustgraph.knowledge import Organization, PublicationEvent from trustgraph.knowledge import Organization, PublicationEvent
from trustgraph.knowledge import DigitalDocument from trustgraph.knowledge import DigitalDocument
@ -79,7 +79,23 @@ class Loader:
r = Document( r = Document(
metadata=Metadata( metadata=Metadata(
id=id, id=id,
metadata=triples, metadata=[
Triple(
s=Value(
value=t["s"],
is_uri=isinstance(t["s"], Uri)
),
p=Value(
value=t["p"],
is_uri=isinstance(t["p"], Uri)
),
o=Value(
value=t["o"],
is_uri=isinstance(t["o"], Uri)
),
)
for t in triples
],
user=self.user, user=self.user,
collection=self.collection, collection=self.collection,
), ),

View file

@ -6,7 +6,6 @@ Loads a text document into TrustGraph processing.
import pulsar import pulsar
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
import base64
import hashlib import hashlib
import argparse import argparse
import os import os
@ -14,9 +13,9 @@ import time
import uuid import uuid
from trustgraph.schema import TextDocument, text_ingest_queue from trustgraph.schema import TextDocument, text_ingest_queue
from trustgraph.schema import Metadata from trustgraph.schema import Metadata, Triple, Value
from trustgraph.log_level import LogLevel from trustgraph.log_level import LogLevel
from trustgraph.knowledge import hash, to_uri from trustgraph.knowledge import hash, to_uri, Literal, Uri
from trustgraph.knowledge import PREF_PUBEV, PREF_DOC, PREF_ORG from trustgraph.knowledge import PREF_PUBEV, PREF_DOC, PREF_ORG
from trustgraph.knowledge import Organization, PublicationEvent from trustgraph.knowledge import Organization, PublicationEvent
from trustgraph.knowledge import DigitalDocument from trustgraph.knowledge import DigitalDocument
@ -79,7 +78,23 @@ class Loader:
r = TextDocument( r = TextDocument(
metadata=Metadata( metadata=Metadata(
id=id, id=id,
metadata=triples, metadata=[
Triple(
s=Value(
value=t["s"],
is_uri=isinstance(t["s"], Uri)
),
p=Value(
value=t["p"],
is_uri=isinstance(t["p"], Uri)
),
o=Value(
value=t["o"],
is_uri=isinstance(t["o"], Uri)
),
)
for t in triples
],
user=self.user, user=self.user,
collection=self.collection, collection=self.collection,
), ),

View file

@ -0,0 +1,245 @@
#!/usr/bin/env python3
"""
This utility connects to a running TrustGraph through the API and creates
a knowledge core from the data streaming through the processing queues.
For completeness of data, tg-save-kg-core should be initiated before data
loading takes place. The default output format, msgpack should be used.
JSON output format is also available - msgpack produces a more compact
representation, which is also more performant to load.
"""
import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os
import signal
class Running:
    """Mutable run flag shared between tasks for cooperative shutdown."""

    def __init__(self):
        self.running = True

    def get(self):
        """Return True while the process should keep running."""
        return self.running

    def stop(self):
        """Request shutdown; get() returns False from now on."""
        self.running = False
async def fetch_ge(running, queue, user, collection, url):
    # Stream graph-embedding messages from the websocket API, optionally
    # filter by user/collection, and enqueue them in the compact
    # ["ge", {...}] record form written by output().
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(f"{url}stream/graph-embeddings") as ws:
            while running.get():
                try:
                    msg = await asyncio.wait_for(ws.receive(), 1)
                except:
                    # Timeout: loop again so shutdown is noticed promptly
                    continue
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = msg.json()
                    if user:
                        if data["metadata"]["user"] != user:
                            continue
                    if collection:
                        if data["metadata"]["collection"] != collection:
                            continue
                    # Compact keys: i=id, m=metadata, u=user, c=collection
                    await queue.put([
                        "ge",
                        {
                            "m": {
                                "i": data["metadata"]["id"],
                                "m": data["metadata"]["metadata"],
                                "u": data["metadata"]["user"],
                                "c": data["metadata"]["collection"],
                            },
                            "v": data["vectors"],
                            "e": data["entity"],
                        }
                    ])
                if msg.type == aiohttp.WSMsgType.ERROR:
                    print("Error")
                    break
async def fetch_triples(running, queue, user, collection, url):
    # Stream triple messages from the websocket API, optionally filter by
    # user/collection, and enqueue them in the compact ("t", {...}) record
    # form written by output().
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(f"{url}stream/triples") as ws:
            while running.get():
                try:
                    msg = await asyncio.wait_for(ws.receive(), 1)
                except:
                    # Timeout: loop again so shutdown is noticed promptly
                    continue
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = msg.json()
                    if user:
                        if data["metadata"]["user"] != user:
                            continue
                    if collection:
                        if data["metadata"]["collection"] != collection:
                            continue
                    # Compact keys: i=id, m=metadata, u=user, c=collection
                    await queue.put((
                        "t",
                        {
                            "m": {
                                "i": data["metadata"]["id"],
                                "m": data["metadata"]["metadata"],
                                "u": data["metadata"]["user"],
                                "c": data["metadata"]["collection"],
                            },
                            "t": data["triples"],
                        }
                    ))
                if msg.type == aiohttp.WSMsgType.ERROR:
                    print("Error")
                    break
# Global progress counters, updated by output() and reported by stats()
ge_counts = 0
t_counts = 0

async def stats(running):
    # Print progress counters every couple of seconds until shutdown
    global t_counts
    global ge_counts
    while running.get():
        await asyncio.sleep(2)
        print(
            f"Graph embeddings: {ge_counts:10d} Triples: {t_counts:10d}"
        )
async def output(running, queue, path, format):
    """Drain records from queue and write them to the file at path.

    Records are written as msgpack (default) or JSON.  In JSON mode each
    record goes on its own line (JSON-lines): the original wrote records
    back-to-back with no separator, producing an unparseable file.
    Updates the global progress counters as it writes.
    """
    global t_counts
    global ge_counts
    with open(path, "wb") as f:
        while running.get():
            try:
                msg = await asyncio.wait_for(queue.get(), 0.5)
            except:
                # Hopefully it's TimeoutError.  Annoying to match since
                # it changed in 3.11.
                continue
            if format == "msgpack":
                f.write(msgpack.packb(msg, use_bin_type=True))
            else:
                # JSON-lines: newline-delimited records
                f.write(json.dumps(msg).encode("utf-8") + b"\n")
            if msg[0] == "t":
                t_counts += 1
            elif msg[0] == "ge":
                ge_counts += 1
    print("Output file closed")
async def run(running, **args):
    # A single queue carries both record types to the writer
    q = asyncio.Queue()

    # Stream readers: one per record type
    ge_task = asyncio.create_task(
        fetch_ge(
            running=running,
            queue=q, user=args["user"], collection=args["collection"],
            url=args["url"] + "api/v1/"
        )
    )

    triples_task = asyncio.create_task(
        fetch_triples(
            running=running, queue=q,
            user=args["user"], collection=args["collection"],
            url=args["url"] + "api/v1/"
        )
    )

    # Writer: serialises queued records to the output file
    output_task = asyncio.create_task(
        output(
            running=running, queue=q,
            path=args["output_file"], format=args["format"],
        )
    )

    # Progress reporter
    stats_task = asyncio.create_task(stats(running))

    # All tasks run until the shared running flag is cleared (SIGINT)
    await output_task
    await triples_task
    await ge_task
    await stats_task

    print("Exiting")
async def main(running):
    """Parse command-line arguments and run the capture pipeline."""
    parser = argparse.ArgumentParser(
        prog='tg-save-kg-core',
        description=__doc__,
    )

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    # NOTE(review): these two values are never used below; the
    # --user/--collection options default to None (no filter)
    default_user = "trustgraph"
    collection = "default"

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-o', '--output-file',
        # Make it mandatory, difficult to over-write an existing file
        required=True,
        help=f'Output file'
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help=f'Output format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help=f'User ID to filter on (default: no filter)'
    )

    parser.add_argument(
        '--collection',
        help=f'Collection ID to filter on (default: no filter)'
    )

    args = parser.parse_args()

    await run(running, **vars(args))


# Shared shutdown flag, flipped by the SIGINT handler below
running = Running()

def interrupt(sig, frame):
    running.stop()
    print('Interrupt')

signal.signal(signal.SIGINT, interrupt)

asyncio.run(main(running))

View file

@ -34,11 +34,12 @@ setuptools.setup(
python_requires='>=3.8', python_requires='>=3.8',
download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz",
install_requires=[ install_requires=[
"trustgraph-base>=0.15,<0.16", "trustgraph-base>=0.17,<0.18",
"requests", "requests",
"pulsar-client", "pulsar-client",
"rdflib", "rdflib",
"tabulate", "tabulate",
"msgpack",
], ],
scripts=[ scripts=[
"scripts/tg-graph-show", "scripts/tg-graph-show",
@ -54,5 +55,8 @@ setuptools.setup(
"scripts/tg-invoke-agent", "scripts/tg-invoke-agent",
"scripts/tg-invoke-prompt", "scripts/tg-invoke-prompt",
"scripts/tg-invoke-llm", "scripts/tg-invoke-llm",
"scripts/tg-save-kg-core",
"scripts/tg-load-kg-core",
"scripts/tg-dump-msgpack",
] ]
) )

View file

@ -34,8 +34,8 @@ setuptools.setup(
python_requires='>=3.8', python_requires='>=3.8',
download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz",
install_requires=[ install_requires=[
"trustgraph-base>=0.15,<0.16", "trustgraph-base>=0.17,<0.18",
"trustgraph-flow>=0.15,<0.16", "trustgraph-flow>=0.17,<0.18",
"torch", "torch",
"urllib3", "urllib3",
"transformers", "transformers",

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.gateway import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.query.doc_embeddings.pinecone import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.storage.doc_embeddings.pinecone import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.query.graph_embeddings.pinecone import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.storage.graph_embeddings.pinecone import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.query.triples.memgraph import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.storage.triples.memgraph import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.external.wikipedia import run
run()

View file

@ -34,7 +34,7 @@ setuptools.setup(
python_requires='>=3.8', python_requires='>=3.8',
download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz", download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz",
install_requires=[ install_requires=[
"trustgraph-base>=0.15,<0.16", "trustgraph-base>=0.17,<0.18",
"urllib3", "urllib3",
"rdflib", "rdflib",
"pymilvus", "pymilvus",
@ -58,21 +58,28 @@ setuptools.setup(
"google-generativeai", "google-generativeai",
"ibis", "ibis",
"jsonschema", "jsonschema",
"aiohttp",
"pinecone[grpc]",
], ],
scripts=[ scripts=[
"scripts/api-gateway",
"scripts/agent-manager-react", "scripts/agent-manager-react",
"scripts/chunker-recursive", "scripts/chunker-recursive",
"scripts/chunker-token", "scripts/chunker-token",
"scripts/de-query-milvus", "scripts/de-query-milvus",
"scripts/de-query-qdrant", "scripts/de-query-qdrant",
"scripts/de-query-pinecone",
"scripts/de-write-milvus", "scripts/de-write-milvus",
"scripts/de-write-qdrant", "scripts/de-write-qdrant",
"scripts/de-write-pinecone",
"scripts/document-rag", "scripts/document-rag",
"scripts/embeddings-ollama", "scripts/embeddings-ollama",
"scripts/embeddings-vectorize", "scripts/embeddings-vectorize",
"scripts/ge-query-milvus", "scripts/ge-query-milvus",
"scripts/ge-query-pinecone",
"scripts/ge-query-qdrant", "scripts/ge-query-qdrant",
"scripts/ge-write-milvus", "scripts/ge-write-milvus",
"scripts/ge-write-pinecone",
"scripts/ge-write-qdrant", "scripts/ge-write-qdrant",
"scripts/graph-rag", "scripts/graph-rag",
"scripts/kg-extract-definitions", "scripts/kg-extract-definitions",
@ -96,7 +103,10 @@ setuptools.setup(
"scripts/text-completion-openai", "scripts/text-completion-openai",
"scripts/triples-query-cassandra", "scripts/triples-query-cassandra",
"scripts/triples-query-neo4j", "scripts/triples-query-neo4j",
"scripts/triples-query-memgraph",
"scripts/triples-write-cassandra", "scripts/triples-write-cassandra",
"scripts/triples-write-neo4j", "scripts/triples-write-neo4j",
"scripts/triples-write-memgraph",
"scripts/wikipedia-lookup",
] ]
) )

View file

@ -97,7 +97,7 @@ class TrustGraph:
def get_po(self, p, o, limit=10): def get_po(self, p, o, limit=10):
return self.session.execute( return self.session.execute(
f"select s from {self.table} where p = %s and o = %s allow filtering limit {limit}", f"select s from {self.table} where p = %s and o = %s limit {limit} allow filtering",
(p, o) (p, o)
) )

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . service import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,102 @@
"""
Wikipedia lookup service. Fetchs an extract from the Wikipedia page
using the API.
"""
from trustgraph.schema import LookupRequest, LookupResponse, Error
from trustgraph.schema import encyclopedia_lookup_request_queue
from trustgraph.schema import encyclopedia_lookup_response_queue
from trustgraph.log_level import LogLevel
from trustgraph.base import ConsumerProducer
import requests
# Module name derived from the package path; used as the default
# Pulsar subscriber id
module = ".".join(__name__.split(".")[1:-1])

# Default request/response queues for the encyclopedia lookup service
default_input_queue = encyclopedia_lookup_request_queue
default_output_queue = encyclopedia_lookup_response_queue
default_subscriber = module

# Base URL of the Wikipedia instance to query
default_url="https://en.wikipedia.org/"
class Processor(ConsumerProducer):
    """Pulsar consumer/producer that answers lookup requests with a
    summary extract fetched from the Wikipedia REST API."""

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        url = params.get("url", default_url)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": LookupRequest,
                "output_schema": LookupResponse,
            }
        )

        # Fix: normalise the base URL so endpoint construction below
        # doesn't produce '//' (default_url ends with a slash)
        self.url = url.rstrip("/")

    def handle(self, msg):
        """Process one lookup request, replying on the output queue with
        either the page extract or an error object."""

        from urllib.parse import quote

        v = msg.value()

        # Sender-produced ID correlates the response with the request
        id = msg.properties()["id"]

        print(f"Handling {v.kind} / {v.term}...", flush=True)

        try:

            # Bug fix: the original read 'resp = Result = requests.get(...)'
            # (a stray chained assignment).  The term is quoted so titles
            # containing spaces or reserved characters form a valid URL.
            url = f"{self.url}/api/rest_v1/page/summary/{quote(v.term)}"
            resp = requests.get(url).json()
            text = resp["extract"]

            r = LookupResponse(
                error=None,
                text=text
            )

            self.producer.send(r, properties={"id": id})
            self.consumer.acknowledge(msg)
            return

        except Exception as e:

            # Lookup failed (network error, missing page, no extract):
            # report the error back to the requestor
            r = LookupResponse(
                error=Error(
                    type = "lookup-error",
                    message = str(e),
                ),
                text=None,
            )

            self.producer.send(r, properties={"id": id})
            self.consumer.acknowledge(msg)
            return

    @staticmethod
    def add_args(parser):
        """Register command-line options on parser."""

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-u', '--url',
            default=default_url,
            # Bug fix: help text previously said 'LLM model' (copy-paste)
            help=f'Wikipedia URL (default: {default_url})'
        )
def run():
    # Start the processor; module name and docstring feed the CLI
    # prog/description.
    Processor.start(module, __doc__)

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
# Thin entry point: allows running this service directly as a script/module.
from . service import run
if __name__ == '__main__':
    run()

View file

@ -0,0 +1,42 @@
from .. schema import AgentRequest, AgentResponse
from .. schema import agent_request_queue
from .. schema import agent_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class AgentRequestor(ServiceRequestor):
    """Requestor for the agent service: forwards a question over the agent
    queues and relays thought/observation/answer updates as they arrive."""

    def __init__(self, pulsar_host, timeout, auth):
        super(AgentRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=agent_request_queue,
            response_queue=agent_response_queue,
            request_schema=AgentRequest,
            response_schema=AgentResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        # Only the question is carried on the request
        return AgentRequest(question=body["question"])

    def from_response(self, message):
        # Include only the fields that are present on this update
        resp = {
            key: value
            for key, value in (
                ("answer", message.answer),
                ("thought", message.thought),
                ("observation", message.observation),
            )
            if value
        }
        # The 2nd element signals completion: done once an answer arrives
        return resp, (message.answer is not None)

View file

@ -0,0 +1,22 @@
class Authenticator:
    """Simple bearer-token authenticator.

    Either a non-empty token must be supplied, or allow_all=True must be
    set explicitly to run the gateway open (no auth).
    """

    def __init__(self, token=None, allow_all=False):
        # Fix: single falsy check covers both None and "" (previously two
        # separate checks)
        if not allow_all and not token:
            raise RuntimeError("Need a token")
        self.token = token
        self.allow_all = allow_all

    def permitted(self, token, roles):
        """Return True if `token` authorizes the request.

        `roles` is accepted for interface compatibility but not yet used.
        """
        if self.allow_all:
            return True
        if token is None:
            # Fix: previously `!= None` comparison; compare_digest would
            # raise on a None token
            return False
        import hmac
        # Fix: constant-time comparison avoids a timing side channel
        return hmac.compare_digest(self.token, token)

View file

@ -0,0 +1,29 @@
from .. schema import LookupRequest, LookupResponse
from .. schema import dbpedia_lookup_request_queue
from .. schema import dbpedia_lookup_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class DbpediaRequestor(ServiceRequestor):
    """Requestor bridging the DBpedia lookup request/response queues."""

    def __init__(self, pulsar_host, timeout, auth):
        super(DbpediaRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=dbpedia_lookup_request_queue,
            response_queue=dbpedia_lookup_response_queue,
            request_schema=LookupRequest,
            response_schema=LookupResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        # kind is optional in the incoming body
        kind = body.get("kind")
        return LookupRequest(term=body["term"], kind=kind)

    def from_response(self, message):
        # Lookup responses arrive in one piece, so always final
        return {"text": message.text}, True

View file

@ -0,0 +1,29 @@
from .. schema import EmbeddingsRequest, EmbeddingsResponse
from .. schema import embeddings_request_queue
from .. schema import embeddings_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class EmbeddingsRequestor(ServiceRequestor):
    """Requestor bridging the embeddings request/response queues."""

    def __init__(self, pulsar_host, timeout, auth):
        super(EmbeddingsRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=embeddings_request_queue,
            response_queue=embeddings_response_queue,
            request_schema=EmbeddingsRequest,
            response_schema=EmbeddingsResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        text = body["text"]
        return EmbeddingsRequest(text=text)

    def from_response(self, message):
        # Single-shot response: always final
        result = {"vectors": message.vectors}
        return result, True

View file

@ -0,0 +1,29 @@
from .. schema import LookupRequest, LookupResponse
from .. schema import encyclopedia_lookup_request_queue
from .. schema import encyclopedia_lookup_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class EncyclopediaRequestor(ServiceRequestor):
    """Requestor bridging the encyclopedia lookup request/response queues."""

    def __init__(self, pulsar_host, timeout, auth):
        super(EncyclopediaRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=encyclopedia_lookup_request_queue,
            response_queue=encyclopedia_lookup_response_queue,
            request_schema=LookupRequest,
            response_schema=LookupResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        # kind is optional in the incoming body
        kind = body.get("kind")
        return LookupRequest(term=body["term"], kind=kind)

    def from_response(self, message):
        # Lookup responses arrive in one piece, so always final
        return {"text": message.text}, True

View file

@ -0,0 +1,69 @@
import asyncio
from pulsar.schema import JsonSchema
from aiohttp import web
import uuid
import logging
from . publisher import Publisher
from . subscriber import Subscriber
logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO)
class ServiceEndpoint:
    """HTTP POST endpoint that bridges a JSON request body to a
    ServiceRequestor and returns the final response as JSON."""

    def __init__(self, endpoint_path, auth, requestor):
        self.path = endpoint_path
        self.auth = auth
        self.operation = "service"
        self.requestor = requestor

    async def start(self):
        await self.requestor.start()

    def add_routes(self, app):
        app.add_routes([
            web.post(self.path, self.handle),
        ])

    async def handle(self, request):
        """Authenticate via Bearer token, then dispatch to the requestor."""
        print(request.path, "...")
        try:
            ht = request.headers["Authorization"]
            tokens = ht.split(" ", 2)
            if tokens[0] != "Bearer":
                return web.HTTPUnauthorized()
            token = tokens[1]
        except Exception:
            # Missing/malformed header: treat as an empty token
            token = ""
        if not self.auth.permitted(token, self.operation):
            return web.HTTPUnauthorized()
        try:
            data = await request.json()
            print(data)
            # Fix: must be an async function — ServiceRequestor awaits the
            # responder; a plain def made every request fail with
            # "object NoneType can't be used in 'await' expression"
            async def responder(x, fin):
                print(x)
            # Fix: process() returns a single response dict, not a
            # (resp, fin) tuple — unpacking it was incorrect
            resp = await self.requestor.process(data, responder)
            return web.json_response(resp)
        except Exception as e:
            logging.error(f"Exception: {e}")
            return web.json_response(
                { "error": str(e) }
            )

View file

@ -0,0 +1,60 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import GraphEmbeddings
from .. schema import graph_embeddings_store_queue
from . publisher import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph, to_value
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
    """Websocket endpoint that accepts graph-embeddings JSON messages and
    publishes them onto the graph-embeddings store queue."""

    def __init__(
            self, pulsar_host, auth, path="/api/v1/load/graph-embeddings",
    ):
        super(GraphEmbeddingsLoadEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )
        self.pulsar_host=pulsar_host
        # Background publisher thread feeding the store queue
        self.publisher = Publisher(
            self.pulsar_host, graph_embeddings_store_queue,
            schema=JsonSchema(GraphEmbeddings)
        )

    async def start(self):
        self.publisher.start()

    async def listener(self, ws, running):
        # Consume websocket messages until error/close
        async for msg in ws:
            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break
            else:
                data = msg.json()
                # Rebuild the schema object from its JSON wire form
                elt = GraphEmbeddings(
                    metadata=Metadata(
                        id=data["metadata"]["id"],
                        metadata=to_subgraph(data["metadata"]["metadata"]),
                        user=data["metadata"]["user"],
                        collection=data["metadata"]["collection"],
                    ),
                    entity=to_value(data["entity"]),
                    vectors=data["vectors"],
                )
                # Fire-and-forget load: no correlation id needed
                self.publisher.send(None, elt)
        running.stop()

View file

@ -0,0 +1,57 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from .. schema import GraphEmbeddings
from .. schema import graph_embeddings_store_queue
from . subscriber import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_graph_embeddings
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
    """Websocket endpoint streaming every graph-embeddings store message
    to the connected client."""

    def __init__(
            self, pulsar_host, auth, path="/api/v1/stream/graph-embeddings"
    ):
        super(GraphEmbeddingsStreamEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )
        self.pulsar_host=pulsar_host
        self.subscriber = Subscriber(
            self.pulsar_host, graph_embeddings_store_queue,
            "api-gateway", "api-gateway",
            schema=JsonSchema(GraphEmbeddings)
        )

    async def start(self):
        self.subscriber.start()

    async def async_thread(self, ws, running):
        # One firehose subscription per websocket connection
        id = str(uuid.uuid4())
        q = self.subscriber.subscribe_all(id)
        while running.get():
            try:
                # Short timeout so the running flag is re-checked regularly
                resp = await asyncio.to_thread(q.get, timeout=0.5)
                await ws.send_json(serialize_graph_embeddings(resp))
            except queue.Empty:
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break
        self.subscriber.unsubscribe_all(id)
        running.stop()

View file

@ -0,0 +1,30 @@
from .. schema import GraphRagQuery, GraphRagResponse
from .. schema import graph_rag_request_queue
from .. schema import graph_rag_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class GraphRagRequestor(ServiceRequestor):
    """Requestor bridging the graph-RAG request/response queues."""

    def __init__(self, pulsar_host, timeout, auth):
        super(GraphRagRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=graph_rag_request_queue,
            response_queue=graph_rag_response_queue,
            request_schema=GraphRagQuery,
            response_schema=GraphRagResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        # user/collection default when not supplied by the caller
        user = body.get("user", "trustgraph")
        collection = body.get("collection", "default")
        return GraphRagQuery(
            query=body["query"],
            user=user,
            collection=collection,
        )

    def from_response(self, message):
        # Single-shot response: always final
        return {"response": message.response}, True

View file

@ -0,0 +1,29 @@
from .. schema import LookupRequest, LookupResponse
from .. schema import internet_search_request_queue
from .. schema import internet_search_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class InternetSearchRequestor(ServiceRequestor):
    """Requestor bridging the internet-search request/response queues."""

    def __init__(self, pulsar_host, timeout, auth):
        super(InternetSearchRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=internet_search_request_queue,
            response_queue=internet_search_response_queue,
            request_schema=LookupRequest,
            response_schema=LookupResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        # kind is optional in the incoming body
        kind = body.get("kind")
        return LookupRequest(term=body["term"], kind=kind)

    def from_response(self, message):
        # Lookup responses arrive in one piece, so always final
        return {"text": message.text}, True

View file

@ -0,0 +1,94 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from aiohttp import web, WSMsgType
from . socket import SocketEndpoint
from . text_completion import TextCompletionRequestor
class MuxEndpoint(SocketEndpoint):
    """Websocket endpoint multiplexing several services over one socket.

    Incoming messages carry {id, service, request}; responses are sent
    back tagged with the same id.
    """

    def __init__(
            self, pulsar_host, auth,
            services,
            path="/api/v1/mux",
    ):
        super(MuxEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )
        # Bounded queue decouples the socket reader from request dispatch
        self.q = asyncio.Queue(maxsize=10)
        self.services = services

    async def start(self):
        pass

    async def async_thread(self, ws, running):
        """Dispatch loop: pull queued requests and run them against the
        named service, streaming responses back over the websocket."""
        while running.get():
            try:
                # Short wait so the running flag is re-checked regularly
                id, svc, request = await asyncio.wait_for(self.q.get(), 1)
            except asyncio.TimeoutError:
                # Fix: catch asyncio.TimeoutError (not the builtin, which
                # differs pre-3.11), and don't reference the unbound id here
                continue
            except Exception as e:
                await ws.send_json({"error": str(e)})
                # Fix: nothing was dequeued, so skip dispatch (previously
                # fell through with svc/request undefined or stale)
                continue
            try:
                print(svc, request)
                requestor = self.services[svc]
                async def responder(resp, fin):
                    await ws.send_json({
                        "id": id,
                        "response": resp,
                        "complete": fin,
                    })
                resp = await requestor.process(request, responder)
            except Exception as e:
                # id is bound here, so tag the error with it
                await ws.send_json({"id": id, "error": str(e)})
        running.stop()

    async def listener(self, ws, running):
        """Reader loop: validate incoming envelopes and queue them."""
        async for msg in ws:
            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break
            else:
                try:
                    data = msg.json()
                    # Validate the envelope before queueing
                    if data["service"] not in self.services:
                        raise RuntimeError("Bad service")
                    if "request" not in data:
                        raise RuntimeError("Bad message")
                    if "id" not in data:
                        raise RuntimeError("Bad message")
                    await self.q.put(
                        (data["id"], data["service"], data["request"])
                    )
                except Exception as e:
                    await ws.send_json({"error": str(e)})
                    continue
        running.stop()

View file

@ -0,0 +1,41 @@
import json
from .. schema import PromptRequest, PromptResponse
from .. schema import prompt_request_queue
from .. schema import prompt_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class PromptRequestor(ServiceRequestor):
    """Requestor bridging the prompt request/response queues."""

    def __init__(self, pulsar_host, timeout, auth):
        super(PromptRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=prompt_request_queue,
            response_queue=prompt_response_queue,
            request_schema=PromptRequest,
            response_schema=PromptResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        # Each template variable value is JSON-encoded for transport
        terms = {
            key: json.dumps(value)
            for key, value in body["variables"].items()
        }
        return PromptRequest(id=body["id"], terms=terms)

    def from_response(self, message):
        # A structured object takes precedence over plain text; either
        # way the response is final
        if message.object:
            return {"object": message.object}, True
        return {"text": message.text}, True

View file

@ -0,0 +1,53 @@
import queue
import time
import pulsar
import threading
class Publisher:
    """Background thread that publishes queued messages to a Pulsar topic,
    reconnecting with a delay whenever the connection fails."""

    def __init__(self, pulsar_host, topic, schema=None, max_size=10,
                 chunking_enabled=False):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.schema = schema
        # Bounded queue: send() blocks if the publisher falls behind
        self.q = queue.Queue(maxsize=max_size)
        self.chunking_enabled = chunking_enabled

    def start(self):
        """Start the background publishing thread."""
        self.task = threading.Thread(target=self.run)
        self.task.start()

    def run(self):
        """Thread body: connect, drain the queue, retry forever on failure."""
        while True:
            client = None
            try:
                client = pulsar.Client(
                    self.pulsar_host,
                )
                producer = client.create_producer(
                    topic=self.topic,
                    schema=self.schema,
                    chunking_enabled=self.chunking_enabled,
                )
                while True:
                    id, item = self.q.get()
                    if id:
                        # Correlation id travels as a message property
                        producer.send(item, { "id": id })
                    else:
                        producer.send(item)
            except Exception as e:
                print("Exception:", e, flush=True)
            finally:
                # Fix: close the client before retrying so repeated failures
                # don't leak connections
                if client is not None:
                    try:
                        client.close()
                    except Exception:
                        pass
            # If handler drops out, sleep a retry
            time.sleep(2)

    def send(self, id, msg):
        """Queue (id, msg) for publishing; blocks if the queue is full."""
        self.q.put((id, msg))

View file

@ -0,0 +1,88 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
import logging
from . publisher import Publisher
from . subscriber import Subscriber
logger = logging.getLogger("requestor")
logger.setLevel(logging.INFO)
class ServiceRequestor:
    """Request/response bridge over a pair of Pulsar queues.

    Publishes a request tagged with a fresh correlation id, then waits on a
    per-id subscription for one or more responses until a final one arrives.
    """

    def __init__(
        self,
        pulsar_host,
        request_queue, request_schema,
        response_queue, response_schema,
        subscription="api-gateway", consumer_name="api-gateway",
        timeout=600,
    ):
        self.pub = Publisher(
            pulsar_host, request_queue,
            schema=JsonSchema(request_schema)
        )
        self.sub = Subscriber(
            pulsar_host, response_queue,
            subscription, consumer_name,
            JsonSchema(response_schema)
        )
        # Seconds to wait for each response before giving up
        self.timeout = timeout

    async def start(self):
        """Start the underlying publisher/subscriber threads."""
        self.pub.start()
        self.sub.start()

    def to_request(self, request):
        """Translate a JSON body into the request schema. Subclasses override."""
        raise RuntimeError("Not defined")

    def from_response(self, response):
        """Translate a response message into (JSON body, finished).
        Subclasses override."""
        raise RuntimeError("Not defined")

    async def process(self, request, responder=None):
        """Send a request; stream intermediate responses to `responder`
        (an async callable) and return the final response body, or an
        error body of the form {"error": ...}."""
        id = str(uuid.uuid4())
        try:
            q = self.sub.subscribe(id)
            # Publisher.send blocks on a bounded queue: run off the loop
            await asyncio.to_thread(
                self.pub.send, id, self.to_request(request)
            )
            while True:
                try:
                    resp = await asyncio.to_thread(q.get, timeout=self.timeout)
                except Exception as e:
                    # Fix: chain the cause so the original error isn't lost
                    raise RuntimeError("Timeout") from e
                if resp.error:
                    return { "error": resp.error.message }
                resp, fin = self.from_response(resp)
                print(resp, fin)
                if responder:
                    await responder(resp, fin)
                if fin:
                    return resp
        except Exception as e:
            # Fix: use the module logger for consistency, not the root logger
            logger.error(f"Exception: {e}")
            return { "error": str(e) }
        finally:
            self.sub.unsubscribe(id)

View file

@ -0,0 +1,5 @@
class Running:
    """Mutable boolean flag shared between coroutines/threads to signal
    that a connection or loop should keep going (or stop)."""

    def __init__(self):
        self.running = True

    def get(self):
        """Return True while still running."""
        return self.running

    def stop(self):
        """Mark as stopped; subsequent get() calls return False."""
        self.running = False

View file

@ -0,0 +1,57 @@
from .. schema import Value, Triple
def to_value(x):
    """Build a Value from its wire form: "v" holds the value, "e" the
    is-uri flag."""
    return Value(is_uri=x["e"], value=x["v"])
def to_subgraph(x):
    """Deserialize a list of {"s","p","o"} wire-form triples into Triple
    objects."""
    def build(t):
        # Each corner of the triple is a wire-form value
        return Triple(
            s=to_value(t["s"]),
            p=to_value(t["p"]),
            o=to_value(t["o"]),
        )
    return [build(t) for t in x]
def serialize_value(v):
    """Wire form of a Value: {"v": value, "e": is-uri flag}."""
    return dict(v=v.value, e=v.is_uri)
def serialize_triple(t):
    """Wire form of a triple: each of s/p/o serialized as a value."""
    return {
        key: serialize_value(part)
        for key, part in (("s", t.s), ("p", t.p), ("o", t.o))
    }
def serialize_subgraph(sg):
    """Wire form of a list of triples."""
    return list(map(serialize_triple, sg))
def serialize_triples(message):
    """Wire form of a Triples message: metadata block plus triples list."""
    md = message.metadata
    return {
        "metadata": {
            "id": md.id,
            "metadata": serialize_subgraph(md.metadata),
            "user": md.user,
            "collection": md.collection,
        },
        "triples": serialize_subgraph(message.triples),
    }
def serialize_graph_embeddings(message):
    """Wire form of a GraphEmbeddings message: metadata block, vectors,
    and the entity as a serialized value."""
    md = message.metadata
    return {
        "metadata": {
            "id": md.id,
            "metadata": serialize_subgraph(md.metadata),
            "user": md.user,
            "collection": md.collection,
        },
        "vectors": message.vectors,
        "entity": serialize_value(message.entity),
    }

View file

@ -0,0 +1,364 @@
"""
API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus.
"""
module = ".".join(__name__.split(".")[1:-1])
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners
# FIXME: Connection errors in publishers / subscribers cause those threads
# to fail and are not failed or retried
import asyncio
import argparse
from aiohttp import web
import logging
import os
import base64
import pulsar
from pulsar.schema import JsonSchema
from prometheus_client import start_http_server
from .. log_level import LogLevel
from .. schema import Metadata, Document, TextDocument
from .. schema import document_ingest_queue, text_ingest_queue
from . serialize import to_subgraph
from . running import Running
from . publisher import Publisher
from . subscriber import Subscriber
from . text_completion import TextCompletionRequestor
from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
from . triples_query import TriplesQueryRequestor
from . embeddings import EmbeddingsRequestor
from . encyclopedia import EncyclopediaRequestor
from . agent import AgentRequestor
from . dbpedia import DbpediaRequestor
from . internet_search import InternetSearchRequestor
from . triples_stream import TriplesStreamEndpoint
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
from . triples_load import TriplesLoadEndpoint
from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
from . mux import MuxEndpoint
from . endpoint import ServiceEndpoint
from . auth import Authenticator
logger = logging.getLogger("api")
logger.setLevel(logging.INFO)
default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
default_timeout = 600
default_port = 8088
default_api_token = os.getenv("GATEWAY_SECRET", "")
class Api:
    """API gateway application.

    Wires HTTP and websocket endpoints to Pulsar-backed service requestors,
    and exposes document/text ingest endpoints.
    """

    def __init__(self, **config):
        # Large max body size so big document uploads are accepted
        self.app = web.Application(
            middlewares=[],
            client_max_size=256 * 1024 * 1024
        )
        self.port = int(config.get("port", default_port))
        self.timeout = int(config.get("timeout", default_timeout))
        self.pulsar_host = config.get("pulsar_host", default_pulsar_host)
        api_token = config.get("api_token", default_api_token)
        # Token not set, or token equal empty string means no auth
        if api_token:
            self.auth = Authenticator(token=api_token)
        else:
            self.auth = Authenticator(allow_all=True)
        # One requestor per service; each bridges a request/response queue
        # pair on the Pulsar bus
        self.services = {
            "text-completion": TextCompletionRequestor(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth = self.auth,
            ),
            "prompt": PromptRequestor(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth = self.auth,
            ),
            "graph-rag": GraphRagRequestor(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth = self.auth,
            ),
            "triples-query": TriplesQueryRequestor(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth = self.auth,
            ),
            "embeddings": EmbeddingsRequestor(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth = self.auth,
            ),
            "agent": AgentRequestor(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth = self.auth,
            ),
            "encyclopedia": EncyclopediaRequestor(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth = self.auth,
            ),
            "dbpedia": DbpediaRequestor(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth = self.auth,
            ),
            "internet-search": InternetSearchRequestor(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth = self.auth,
            ),
        }
        # REST endpoints wrap the requestors; stream/load endpoints are
        # websocket based; mux multiplexes several services on one socket
        self.endpoints = [
            ServiceEndpoint(
                endpoint_path = "/api/v1/text-completion", auth=self.auth,
                requestor = self.services["text-completion"],
            ),
            ServiceEndpoint(
                endpoint_path = "/api/v1/prompt", auth=self.auth,
                requestor = self.services["prompt"],
            ),
            ServiceEndpoint(
                endpoint_path = "/api/v1/graph-rag", auth=self.auth,
                requestor = self.services["graph-rag"],
            ),
            ServiceEndpoint(
                endpoint_path = "/api/v1/triples-query", auth=self.auth,
                requestor = self.services["triples-query"],
            ),
            ServiceEndpoint(
                endpoint_path = "/api/v1/embeddings", auth=self.auth,
                requestor = self.services["embeddings"],
            ),
            ServiceEndpoint(
                endpoint_path = "/api/v1/agent", auth=self.auth,
                requestor = self.services["agent"],
            ),
            ServiceEndpoint(
                endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
                requestor = self.services["encyclopedia"],
            ),
            ServiceEndpoint(
                endpoint_path = "/api/v1/dbpedia", auth=self.auth,
                requestor = self.services["dbpedia"],
            ),
            ServiceEndpoint(
                endpoint_path = "/api/v1/internet-search", auth=self.auth,
                requestor = self.services["internet-search"],
            ),
            TriplesStreamEndpoint(
                pulsar_host=self.pulsar_host,
                auth = self.auth,
            ),
            GraphEmbeddingsStreamEndpoint(
                pulsar_host=self.pulsar_host,
                auth = self.auth,
            ),
            TriplesLoadEndpoint(
                pulsar_host=self.pulsar_host,
                auth = self.auth,
            ),
            GraphEmbeddingsLoadEndpoint(
                pulsar_host=self.pulsar_host,
                auth = self.auth,
            ),
            MuxEndpoint(
                pulsar_host=self.pulsar_host,
                auth = self.auth,
                services = self.services,
            ),
        ]
        # Publishers for direct ingest; chunking enabled for large payloads
        self.document_out = Publisher(
            self.pulsar_host, document_ingest_queue,
            schema=JsonSchema(Document),
            chunking_enabled=True,
        )
        self.text_out = Publisher(
            self.pulsar_host, text_ingest_queue,
            schema=JsonSchema(TextDocument),
            chunking_enabled=True,
        )
        for ep in self.endpoints:
            ep.add_routes(self.app)
        self.app.add_routes([
            web.post("/api/v1/load/document", self.load_document),
            web.post("/api/v1/load/text", self.load_text),
        ])

    async def load_document(self, request):
        """Ingest a base64-encoded binary document onto the document queue."""
        try:
            data = await request.json()
            if "metadata" in data:
                metadata = to_subgraph(data["metadata"])
            else:
                metadata = []
            # Doing a base64 decode/encode here to make sure the
            # content is valid base64
            doc = base64.b64decode(data["data"])
            # send blocks on a bounded queue, so run it off the event loop
            resp = await asyncio.to_thread(
                self.document_out.send,
                None,
                Document(
                    metadata=Metadata(
                        id=data.get("id"),
                        metadata=metadata,
                        user=data.get("user", "trustgraph"),
                        collection=data.get("collection", "default"),
                    ),
                    data=base64.b64encode(doc).decode("utf-8")
                )
            )
            print("Document loaded.")
            return web.json_response(
                { }
            )
        except Exception as e:
            logging.error(f"Exception: {e}")
            return web.json_response(
                { "error": str(e) }
            )

    async def load_text(self, request):
        """Ingest a base64-encoded text document onto the text queue."""
        try:
            data = await request.json()
            if "metadata" in data:
                metadata = to_subgraph(data["metadata"])
            else:
                metadata = []
            if "charset" in data:
                charset = data["charset"]
            else:
                charset = "utf-8"
            # Text is base64 encoded
            text = base64.b64decode(data["text"]).decode(charset)
            resp = await asyncio.to_thread(
                self.text_out.send,
                None,
                TextDocument(
                    metadata=Metadata(
                        id=data.get("id"),
                        metadata=metadata,
                        user=data.get("user", "trustgraph"),
                        collection=data.get("collection", "default"),
                    ),
                    text=text,
                )
            )
            print("Text document loaded.")
            return web.json_response(
                { }
            )
        except Exception as e:
            logging.error(f"Exception: {e}")
            return web.json_response(
                { "error": str(e) }
            )

    async def app_factory(self):
        """Start endpoints and publishers, then hand the app to aiohttp."""
        for ep in self.endpoints:
            await ep.start()
        self.document_out.start()
        self.text_out.start()
        return self.app

    def run(self):
        """Run the web application (blocks until shutdown)."""
        web.run_app(self.app_factory(), port=self.port)
def run():
    """Parse command-line arguments and launch the API gateway."""
    parser = argparse.ArgumentParser(
        prog="api-gateway",
        description=__doc__
    )
    parser.add_argument(
        '-p', '--pulsar-host',
        default=default_pulsar_host,
        help=f'Pulsar host (default: {default_pulsar_host})',
    )
    parser.add_argument(
        '--port',
        type=int,
        default=default_port,
        help=f'Port number to listen on (default: {default_port})',
    )
    parser.add_argument(
        '--timeout',
        type=int,
        default=default_timeout,
        help=f'API request timeout in seconds (default: {default_timeout})',
    )
    parser.add_argument(
        '--api-token',
        default=default_api_token,
        help=f'Secret API token (default: no auth)',
    )
    parser.add_argument(
        '-l', '--log-level',
        type=LogLevel,
        default=LogLevel.INFO,
        choices=list(LogLevel),
        # Fix: help text previously said 'Output queue' (copy-paste error)
        help=f'Log level (default: info)'
    )
    parser.add_argument(
        '--metrics',
        action=argparse.BooleanOptionalAction,
        default=True,
        help=f'Metrics enabled (default: true)',
    )
    parser.add_argument(
        '-P', '--metrics-port',
        type=int,
        default=8000,
        help=f'Prometheus metrics port (default: 8000)',
    )
    args = parser.parse_args()
    args = vars(args)
    # Prometheus metrics server is optional
    if args["metrics"]:
        start_http_server(args["metrics_port"])
    a = Api(**args)
    a.run()

View file

@ -0,0 +1,84 @@
import asyncio
from aiohttp import web, WSMsgType
import logging
from . running import Running
logger = logging.getLogger("socket")
logger.setLevel(logging.INFO)
class SocketEndpoint:
    """Base class for websocket endpoints: handles auth, connection setup,
    and the listener / background-task lifecycle."""

    def __init__(
            self, endpoint_path, auth,
    ):
        self.path = endpoint_path
        self.auth = auth
        # Operation label passed to the authenticator
        self.operation = "socket"

    async def listener(self, ws, running):
        # Default listener: drain incoming messages, stop on socket error.
        # Subclasses override to act on incoming messages.
        async for msg in ws:
            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break
            else:
                # Ignore incoming messages
                pass
        running.stop()

    async def async_thread(self, ws, running):
        # Default background task: idle until the connection stops.
        # Subclasses override to push data to the client.
        while running.get():
            try:
                await asyncio.sleep(1)
            except TimeoutError:
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)

    async def handle(self, request):
        """Authenticate, upgrade to websocket, and run listener plus
        background task until either stops."""
        # Token arrives as a query parameter for websockets
        try:
            token = request.query['token']
        except:
            token = ""
        if not self.auth.permitted(token, self.operation):
            return web.HTTPUnauthorized()
        running = Running()
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        # Background sender runs concurrently with the listener
        task = asyncio.create_task(self.async_thread(ws, running))
        try:
            await self.listener(ws, running)
        except Exception as e:
            print(e, flush=True)
        # Signal the background task, close the socket, then reap the task
        running.stop()
        await ws.close()
        await task
        return ws

    async def start(self):
        # Subclasses start publishers/subscribers here
        pass

    def add_routes(self, app):
        app.add_routes([
            web.get(self.path, self.handle),
        ])

View file

@ -0,0 +1,109 @@
import queue
import pulsar
import threading
import time
class Subscriber:
    """Background thread consuming a Pulsar topic and fanning messages out
    to per-correlation-id queues and to "firehose" (subscribe-all) queues."""

    def __init__(self, pulsar_host, topic, subscription, consumer_name,
                 schema=None, max_size=100):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name
        self.schema = schema
        # id -> queue for correlated responses
        self.q = {}
        # id -> queue for subscribers that want every message
        self.full = {}
        self.max_size = max_size
        # Guards q/full against concurrent (un)subscribe
        self.lock = threading.Lock()

    def start(self):
        """Start the background consuming thread."""
        self.task = threading.Thread(target=self.run)
        self.task.start()

    def run(self):
        """Thread body: consume forever, retrying the connection on failure."""
        while True:
            client = None
            try:
                client = pulsar.Client(
                    self.pulsar_host,
                )
                consumer = client.subscribe(
                    topic=self.topic,
                    subscription_name=self.subscription,
                    consumer_name=self.consumer_name,
                    schema=self.schema,
                )
                while True:
                    msg = consumer.receive()
                    # Acknowledge successful reception of the message
                    consumer.acknowledge(msg)
                    try:
                        id = msg.properties()["id"]
                    except Exception:
                        id = None
                    value = msg.value()
                    with self.lock:
                        # Deliver to the correlated subscriber, if any;
                        # drop on timeout so a stuck reader can't block us
                        if id in self.q:
                            try:
                                self.q[id].put(value, timeout=0.5)
                            except Exception:
                                pass
                        for q in self.full.values():
                            try:
                                q.put(value, timeout=0.5)
                            except Exception:
                                pass
            except Exception as e:
                print("Exception:", e, flush=True)
            finally:
                # Fix: close the client before retrying so repeated failures
                # don't leak connections
                if client is not None:
                    try:
                        client.close()
                    except Exception:
                        pass
            # If handler drops out, sleep a retry
            time.sleep(2)

    def subscribe(self, id):
        """Register and return a bounded queue receiving messages with this id."""
        with self.lock:
            q = queue.Queue(maxsize=self.max_size)
            self.q[id] = q
        return q

    def unsubscribe(self, id):
        """Remove the per-id queue; no-op if absent."""
        with self.lock:
            if id in self.q:
                # self.q[id].shutdown(immediate=True)
                del self.q[id]

    def subscribe_all(self, id):
        """Register and return a bounded queue receiving every message."""
        with self.lock:
            q = queue.Queue(maxsize=self.max_size)
            self.full[id] = q
        return q

    def unsubscribe_all(self, id):
        """Remove the firehose queue; no-op if absent."""
        with self.lock:
            if id in self.full:
                # self.full[id].shutdown(immediate=True)
                del self.full[id]

View file

@ -0,0 +1,29 @@
from .. schema import TextCompletionRequest, TextCompletionResponse
from .. schema import text_completion_request_queue
from .. schema import text_completion_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class TextCompletionRequestor(ServiceRequestor):
    """Requestor bridging the text-completion request/response queues."""

    def __init__(self, pulsar_host, timeout, auth):
        super(TextCompletionRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=text_completion_request_queue,
            response_queue=text_completion_response_queue,
            request_schema=TextCompletionRequest,
            response_schema=TextCompletionResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        return TextCompletionRequest(
            prompt=body["prompt"],
            system=body["system"],
        )

    def from_response(self, message):
        # Single-shot response: always final
        return {"response": message.response}, True

View file

@ -0,0 +1,57 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import Triples
from .. schema import triples_store_queue
from . publisher import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph
class TriplesLoadEndpoint(SocketEndpoint):
    """Websocket endpoint that accepts triples JSON messages and publishes
    them onto the triples store queue."""

    def __init__(self, pulsar_host, auth, path="/api/v1/load/triples"):
        super(TriplesLoadEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )
        self.pulsar_host=pulsar_host
        # Background publisher thread feeding the store queue
        self.publisher = Publisher(
            self.pulsar_host, triples_store_queue,
            schema=JsonSchema(Triples)
        )

    async def start(self):
        self.publisher.start()

    async def listener(self, ws, running):
        # Consume websocket messages until error/close
        async for msg in ws:
            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break
            else:
                data = msg.json()
                # Rebuild the Triples schema object from its JSON wire form
                elt = Triples(
                    metadata=Metadata(
                        id=data["metadata"]["id"],
                        metadata=to_subgraph(data["metadata"]["metadata"]),
                        user=data["metadata"]["user"],
                        collection=data["metadata"]["collection"],
                    ),
                    triples=to_subgraph(data["triples"]),
                )
                # Fire-and-forget load: no correlation id needed
                self.publisher.send(None, elt)
        running.stop()

View file

@ -0,0 +1,53 @@
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Triples
from .. schema import triples_request_queue
from .. schema import triples_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
from . serialize import to_value, serialize_subgraph
class TriplesQueryRequestor(ServiceRequestor):
    """Requestor bridging the triples-query request/response queues."""

    def __init__(self, pulsar_host, timeout, auth):
        super(TriplesQueryRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=triples_request_queue,
            response_queue=triples_response_queue,
            request_schema=TriplesQueryRequest,
            response_schema=TriplesQueryResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        def optional_value(field):
            # Each of s/p/o is optional: convert if present, else leave unset
            return to_value(body[field]) if field in body else None
        return TriplesQueryRequest(
            s = optional_value("s"),
            p = optional_value("p"),
            o = optional_value("o"),
            limit = int(body.get("limit", 10000)),
            user = body.get("user", "trustgraph"),
            collection = body.get("collection", "default"),
        )

    def from_response(self, message):
        print(message)
        # Single-shot response: the matched subgraph, always final
        return {
            "response": serialize_subgraph(message.triples)
        }, True

View file

@ -0,0 +1,55 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from .. schema import Triples
from .. schema import triples_store_queue
from . subscriber import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_triples
class TriplesStreamEndpoint(SocketEndpoint):
    """Websocket endpoint streaming every triples store message to the
    connected client."""

    def __init__(self, pulsar_host, auth, path="/api/v1/stream/triples"):
        super(TriplesStreamEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )
        self.pulsar_host=pulsar_host
        self.subscriber = Subscriber(
            self.pulsar_host, triples_store_queue,
            "api-gateway", "api-gateway",
            schema=JsonSchema(Triples)
        )

    async def start(self):
        self.subscriber.start()

    async def async_thread(self, ws, running):
        # One firehose subscription per websocket connection
        id = str(uuid.uuid4())
        q = self.subscriber.subscribe_all(id)
        while running.get():
            try:
                # Short timeout so the running flag is re-checked regularly
                resp = await asyncio.to_thread(q.get, timeout=0.5)
                await ws.send_json(serialize_triples(resp))
            except queue.Empty:
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break
        self.subscriber.unsubscribe_all(id)
        running.stop()

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
# Thin entry point: allows running this service directly as a script/module.
from . hf import run
if __name__ == '__main__':
    run()

View file

@ -0,0 +1,147 @@
"""
Document embeddings query service. Input is vector, output is an array
of chunks. Pinecone implementation.
"""
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
import uuid
import os
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .... schema import Error, Value
from .... schema import document_embeddings_request_queue
from .... schema import document_embeddings_response_queue
from .... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_embeddings_request_queue
default_output_queue = document_embeddings_response_queue
default_subscriber = module
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
class Processor(ConsumerProducer):
    """Pinecone-backed document-embeddings query processor.

    Consumes DocumentEmbeddingsRequest messages (one or more query vectors)
    and replies with the matching document chunks.
    """

    def __init__(self, **params):
        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        self.url = params.get("url", None)
        self.api_key = params.get("api_key", default_api_key)
        # Explicit URL selects the gRPC (self-hosted) client; otherwise
        # serverless Pinecone is used
        if self.url:
            self.pinecone = PineconeGRPC(
                api_key = self.api_key,
                host = self.url
            )
        else:
            self.pinecone = Pinecone(api_key = self.api_key)
        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": DocumentEmbeddingsRequest,
                "output_schema": DocumentEmbeddingsResponse,
                "url": self.url,
            }
        )

    def handle(self, msg):
        """Query the per-user/dimension index for each vector and reply
        with the union of matched document chunks."""
        try:
            v = msg.value()
            # Sender-produced ID
            id = msg.properties()["id"]
            print(f"Handling input {id}...", flush=True)
            chunks = []
            for vec in v.vectors:
                dim = len(vec)
                # Index name encodes the user and the vector dimension
                index_name = (
                    "d-" + v.user + "-" + str(dim)
                )
                index = self.pinecone.Index(index_name)
                results = index.query(
                    namespace=v.collection,
                    vector=vec,
                    top_k=v.limit,
                    include_values=False,
                    include_metadata=True
                )
                # Fix: removed leftover Qdrant query code that referenced
                # undefined self.client / collection and crashed here
                for r in results.matches:
                    doc = r.metadata["doc"]
                    # Fix: chunks is a list — add() raised AttributeError
                    chunks.append(doc)
            print("Send response...", flush=True)
            r = DocumentEmbeddingsResponse(documents=chunks, error=None)
            self.producer.send(r, properties={"id": id})
            print("Done.", flush=True)
        except Exception as e:
            print(f"Exception: {e}")
            print("Send error response...", flush=True)
            r = DocumentEmbeddingsResponse(
                error=Error(
                    type = "llm-error",
                    message = str(e),
                ),
                documents=None,
            )
            # NOTE(review): id may be unbound if msg.properties() itself
            # failed above — pre-existing behavior, left unchanged
            self.producer.send(r, properties={"id": id})
        # Acknowledge on both success and error paths
        self.consumer.acknowledge(msg)

    @staticmethod
    def add_args(parser):
        """Register CLI arguments on top of the ConsumerProducer ones."""
        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )
        parser.add_argument(
            '-a', '--api-key',
            default=default_api_key,
            help='Pinecone API key. (default from PINECONE_API_KEY)'
        )
        parser.add_argument(
            '-u', '--url',
            help='Pinecone URL. If unspecified, serverless is used'
        )
def run():
    """Module entry point: start the processor under this module's name,
    using the module docstring as the service description."""
    Processor.start(module, __doc__)

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3

# Script entry point: delegates to the sibling module's run() function.
# NOTE(review): this imports run from ".hf", but the neighbouring files are
# Pinecone services — looks like a copy-paste from the HuggingFace service;
# confirm the intended module name.
from . hf import run

if __name__ == '__main__':
    run()

View file

@ -0,0 +1,156 @@
"""
Graph embeddings query service. Input is vector, output is list of
entities. Pinecone implementation.
"""
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
import uuid
import os
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .... schema import Error, Value
from .... schema import graph_embeddings_request_queue
from .... schema import graph_embeddings_response_queue
from .... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_request_queue
default_output_queue = graph_embeddings_response_queue
default_subscriber = module
default_api_key = os.getenv("PINECONE_API_KEY", "not-specified")
class Processor(ConsumerProducer):
    """
    Consumer/producer answering graph-embeddings queries against
    Pinecone: each request vector is searched in the matching
    per-user/per-dimension index, and the entities found are returned.
    """

    def __init__(self, **params):
        """
        Accepts ConsumerProducer params plus:
          url      -- Pinecone host; if set, a gRPC client is used
          api_key  -- Pinecone API key (default from PINECONE_API_KEY)
        """

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)

        self.url = params.get("url", None)
        self.api_key = params.get("api_key", default_api_key)

        # A URL selects a dedicated host over gRPC; otherwise the
        # serverless client is used
        if self.url:
            self.pinecone = PineconeGRPC(
                api_key = self.api_key,
                host = self.url
            )
        else:
            self.pinecone = Pinecone(api_key = self.api_key)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": GraphEmbeddingsRequest,
                "output_schema": GraphEmbeddingsResponse,
                "url": self.url,
            }
        )

    def create_value(self, ent):
        """Wrap an entity string in a Value, marking http(s) strings as URIs."""
        if ent.startswith(("http://", "https://")):
            return Value(value=ent, is_uri=True)
        else:
            return Value(value=ent, is_uri=False)

    def handle(self, msg):
        """
        Handle one GraphEmbeddingsRequest message; sends a
        GraphEmbeddingsResponse (entities or error) and acknowledges the
        message.
        """

        v = msg.value()

        # Sender-produced ID, echoed back so the requester can correlate
        # the response.  Extracted before the try block so the error path
        # below can always reference it (previously a failure before this
        # point raised NameError inside the except handler).
        id = msg.properties()["id"]

        try:

            print(f"Handling input {id}...", flush=True)

            # A set de-duplicates entities found across request vectors
            entities = set()

            for vec in v.vectors:

                dim = len(vec)

                # Index name encodes kind ("t" = triples/graph), user and
                # vector dimension
                index_name = (
                    "t-" + v.user + "-" + str(dim)
                )

                index = self.pinecone.Index(index_name)

                results = index.query(
                    namespace=v.collection,
                    vector=vec,
                    top_k=v.limit,
                    include_values=False,
                    include_metadata=True
                )

                for r in results.matches:
                    entities.add(r.metadata["entity"])

            # Convert to Value objects for the response schema
            entities = [self.create_value(ent) for ent in entities]

            print("Send response...", flush=True)

            r = GraphEmbeddingsResponse(entities=entities, error=None)
            self.producer.send(r, properties={"id": id})

            print("Done.", flush=True)

        except Exception as e:

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            r = GraphEmbeddingsResponse(
                error=Error(
                    type = "llm-error",
                    message = str(e),
                ),
                entities=None,
            )

            self.producer.send(r, properties={"id": id})

        # Acknowledge on both success and error paths
        self.consumer.acknowledge(msg)

    @staticmethod
    def add_args(parser):
        """Register CLI arguments on top of the ConsumerProducer set."""

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-a', '--api-key',
            default=default_api_key,
            help='Pinecone API key. (default from PINECONE_API_KEY)'
        )

        parser.add_argument(
            '-u', '--url',
            help='Pinecone URL. If unspecified, serverless is used'
        )
def run():
    """Module entry point: start the processor under this module's name,
    using the module docstring as the service description."""
    Processor.start(module, __doc__)

Some files were not shown because too many files have changed in this diff Show more