Feature/subpackages (#80)

* Renaming what will become the core package

* Tweaking to get  package build working

* Fix metering merge

* Rename to core directory

* Bump version.  Use namespace searching for packaging trustgraph-core

* Change references to trustgraph-core

* Forming embeddings-hf package

* Reference modules in core package.

* Build both packages to one container, bump version

* Update YAMLs
This commit is contained in:
cybermaggedon 2024-09-30 14:00:29 +01:00 committed by GitHub
parent 14d79ef9f1
commit f081933217
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
303 changed files with 681 additions and 624 deletions

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.chunking.recursive import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.chunking.token import run
run()

View file

@ -0,0 +1,45 @@
#!/usr/bin/env python3
"""
Concatenates multiple parquet files into a single parquet output
"""
import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd
import sys
import argparse
parser = argparse.ArgumentParser(
prog="combine-parquet",
description=__doc__
)
parser.add_argument(
'-i', '--input',
nargs='*',
help=f'Input files'
)
parser.add_argument(
'-o', '--output',
help=f'Output files'
)
args = parser.parse_args()
df = None
for file in args.input:
part = pq.read_table(file).to_pandas()
if df is None:
df = part
else:
df = pd.concat([df, part], ignore_index=True)
if df is not None:
table = pa.Table.from_pandas(df)
pq.write_table(table, args.output)

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.query.doc_embeddings.milvus import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.query.doc_embeddings.qdrant import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.storage.doc_embeddings.milvus import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.storage.doc_embeddings.qdrant import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.retrieval.document_rag import run
run()

View file

@ -0,0 +1,24 @@
#!/usr/bin/env python3
import pyarrow as pa
import pyarrow.csv as pc
import pyarrow.parquet as pq
import pandas as pd
import sys
df = None
for file in sys.argv[1:]:
part = pq.read_table(file).to_pandas()
if df is None:
df = part
else:
df = pd.concat([df, part], ignore_index=True)
if df is not None:
table = pa.Table.from_pandas(df)
pc.write_csv(table, sys.stdout.buffer)

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.embeddings.ollama import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.embeddings.vectorize import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.dump.graph_embeddings.parquet import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.query.graph_embeddings.milvus import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.query.graph_embeddings.qdrant import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.storage.graph_embeddings.milvus import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.storage.graph_embeddings.qdrant import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.retrieval.graph_rag import run
run()

View file

@ -0,0 +1,46 @@
#!/usr/bin/env python3
"""
Connects to the graph query service and dumps all graph edges.
"""
import argparse
import os
from trustgraph.core.clients.triples_query_client import TriplesQueryClient
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
def show_graph(pulsar):
tq = TriplesQueryClient(pulsar_host=pulsar)
rows = tq.request(None, None, None, limit=10_000_000)
for row in rows:
print(row.s.value, row.p.value, row.o.value)
def main():
parser = argparse.ArgumentParser(
prog='graph-show',
description=__doc__,
)
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
args = parser.parse_args()
try:
show_graph(args.pulsar_host)
except Exception as e:
print("Exception:", e, flush=True)
main()

View file

@ -0,0 +1,74 @@
#!/usr/bin/env python3
"""
Connects to the graph query service and dumps all graph edges.
"""
import argparse
import os
from trustgraph.core.clients.triples_query_client import TriplesQueryClient
import rdflib
import io
import sys
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
def show_graph(pulsar):
tq = TriplesQueryClient(pulsar_host=pulsar)
rows = tq.request(None, None, None, limit=10_000_000)
g = rdflib.Graph()
for row in rows:
sv = rdflib.term.URIRef(row.s.value)
pv = rdflib.term.URIRef(row.p.value)
if row.o.is_uri:
# Skip malformed URLs with spaces in
if " " in row.o.value:
continue
ov = rdflib.term.URIRef(row.o.value)
else:
ov = rdflib.term.Literal(row.o.value)
g.add((sv, pv, ov))
g.serialize(destination="output.ttl", format="turtle")
buf = io.BytesIO()
g.serialize(destination=buf, format="turtle")
sys.stdout.write(buf.getvalue().decode("utf-8"))
def main():
parser = argparse.ArgumentParser(
prog='graph-show',
description=__doc__,
)
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
args = parser.parse_args()
try:
show_graph(args.pulsar_host)
except Exception as e:
print("Exception:", e, flush=True)
main()

View file

@ -0,0 +1,11 @@
#!/usr/bin/env bash
CSRF_TOKEN=$(curl http://localhost:7750/pulsar-manager/csrf-token)
curl \
-H "X-XSRF-TOKEN: $CSRF_TOKEN" \
-H "Cookie: XSRF-TOKEN=$CSRF_TOKEN;" \
-H 'Content-Type: application/json' \
-X PUT \
http://localhost:7750/pulsar-manager/users/superuser \
-d '{"name": "admin", "password": "apachepulsar", "description": "test", "email": "username@test.org"}'

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.extract.kg.definitions import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.extract.kg.relationships import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.extract.kg.topics import run
run()

View file

@ -0,0 +1,145 @@
#!/usr/bin/env python3
"""
Loads Graph embeddings into TrustGraph processing.
"""
import pulsar
from pulsar.schema import JsonSchema
from trustgraph.core.schema import GraphEmbeddings, Value
from trustgraph.core.schema import graph_embeddings_store_queue
import argparse
import os
import time
import pyarrow as pa
import pyarrow.parquet as pq
from trustgraph.core.log_level import LogLevel
class Loader:
def __init__(
self,
pulsar_host,
output_queue,
log_level,
file,
):
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(GraphEmbeddings),
chunking_enabled=True,
)
self.file = file
def run(self):
try:
path = self.file
print("Reading file...")
table = pq.read_table(path)
print("Loaded.")
names = set(table.column_names)
if "embeddings" not in names:
print("No 'embeddings' column")
if "entity" not in names:
print("No 'entity' column")
embc = table.column("embeddings")
entc = table.column("entity")
for emb, ent in zip(embc, entc):
b = emb.as_py()
n = ent.as_py()
r = GraphEmbeddings(
vectors=b,
entity=Value(
value=n,
is_uri=n.startswith("https:")
)
)
self.producer.send(r)
except Exception as e:
print(e, flush=True)
def __del__(self):
self.client.close()
def main():
parser = argparse.ArgumentParser(
prog='loader',
description=__doc__,
)
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
default_output_queue = graph_embeddings_store_queue
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.ERROR,
choices=list(LogLevel),
help=f'Output queue (default: info)'
)
parser.add_argument(
'-f', '--file',
required=True,
help=f'File to load'
)
args = parser.parse_args()
while True:
try:
p = Loader(
pulsar_host=args.pulsar_host,
output_queue=args.output_queue,
log_level=args.log_level,
file=args.file,
)
p.run()
print("File loaded.")
break
except Exception as e:
print("Exception:", e, flush=True)
print("Will retry...", flush=True)
time.sleep(10)
main()

128
trustgraph-core/scripts/load-pdf Executable file
View file

@ -0,0 +1,128 @@
#!/usr/bin/env python3
"""
Loads a PDF document into TrustGraph processing.
"""
import pulsar
from pulsar.schema import JsonSchema
from trustgraph.core.schema import Document, Source, document_ingest_queue
import base64
import hashlib
import argparse
import os
import time
from trustgraph.core.log_level import LogLevel
class Loader:
def __init__(
self,
pulsar_host,
output_queue,
log_level,
file,
):
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(Document),
chunking_enabled=True,
)
self.file = file
def run(self):
try:
path = self.file
data = open(path, "rb").read()
id = hashlib.sha256(path.encode("utf-8")).hexdigest()[0:8]
r = Document(
source=Source(
source=path,
title=path,
id=id,
),
data=base64.b64encode(data),
)
self.producer.send(r)
except Exception as e:
print(e, flush=True)
def __del__(self):
self.client.close()
def main():
parser = argparse.ArgumentParser(
prog='loader',
description=__doc__,
)
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
default_output_queue = document_ingest_queue
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.ERROR,
choices=list(LogLevel),
help=f'Output queue (default: info)'
)
parser.add_argument(
'-f', '--file',
required=True,
help=f'File to load'
)
args = parser.parse_args()
while True:
try:
p = Loader(
pulsar_host=args.pulsar_host,
output_queue=args.output_queue,
log_level=args.log_level,
file=args.file,
)
p.run()
print("File loaded.")
break
except Exception as e:
print("Exception:", e, flush=True)
print("Will retry...", flush=True)
time.sleep(10)
main()

128
trustgraph-core/scripts/load-text Executable file
View file

@ -0,0 +1,128 @@
#!/usr/bin/env python3
"""
Loads a text document into TrustGraph processing.
"""
import pulsar
from pulsar.schema import JsonSchema
from trustgraph.core.schema import TextDocument, Source, text_ingest_queue
import base64
import hashlib
import argparse
import os
import time
from trustgraph.core.log_level import LogLevel
class Loader:
def __init__(
self,
pulsar_host,
output_queue,
log_level,
file,
):
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(TextDocument),
chunking_enabled=True,
)
self.file = file
def run(self):
try:
path = self.file
data = open(path, "rb").read()
id = hashlib.sha256(path.encode("utf-8")).hexdigest()[0:8]
r = TextDocument(
source=Source(
source=path,
title=path,
id=id,
),
text=data,
)
self.producer.send(r)
except Exception as e:
print(e, flush=True)
def __del__(self):
self.client.close()
def main():
parser = argparse.ArgumentParser(
prog='loader',
description=__doc__,
)
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
default_output_queue = text_ingest_queue
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.ERROR,
choices=list(LogLevel),
help=f'Output queue (default: info)'
)
parser.add_argument(
'-f', '--file',
required=True,
help=f'File to load'
)
args = parser.parse_args()
while True:
try:
p = Loader(
pulsar_host=args.pulsar_host,
output_queue=args.output_queue,
log_level=args.log_level,
file=args.file,
)
p.run()
print("File loaded.")
break
except Exception as e:
print("Exception:", e, flush=True)
print("Will retry...", flush=True)
time.sleep(10)
main()

View file

@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""
Loads Graph embeddings into TrustGraph processing.
"""
import pulsar
from pulsar.schema import JsonSchema
from trustgraph.core.schema import Triple, Value
from trustgraph.core.schema import triples_store_queue
import argparse
import os
import time
import pyarrow as pa
import pyarrow.parquet as pq
from trustgraph.core.log_level import LogLevel
class Loader:
def __init__(
self,
pulsar_host,
output_queue,
log_level,
file,
):
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(Triple),
chunking_enabled=True,
)
self.file = file
def run(self):
try:
path = self.file
print("Reading file...")
table = pq.read_table(path)
print("Loaded.")
names = set(table.column_names)
if "s" not in names:
print("No 's' column")
if "p" not in names:
print("No 'p' column")
if "o" not in names:
print("No 'o' column")
sc = table.column("s")
pc = table.column("p")
oc = table.column("o")
for s, p, o in zip(sc, pc, oc):
r = Triple(
s=Value(value=s.as_py(), is_uri=True),
p=Value(value=p.as_py(), is_uri=True),
o=Value(value=o.as_py(), is_uri=o.as_py().startswith("https:"))
)
self.producer.send(r)
except Exception as e:
print(e, flush=True)
def __del__(self):
self.client.close()
def main():
parser = argparse.ArgumentParser(
prog='loader',
description=__doc__,
)
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
default_output_queue = triples_store_queue
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.ERROR,
choices=list(LogLevel),
help=f'Output queue (default: info)'
)
parser.add_argument(
'-f', '--file',
required=True,
help=f'File to load'
)
args = parser.parse_args()
while True:
try:
p = Loader(
pulsar_host=args.pulsar_host,
output_queue=args.output_queue,
log_level=args.log_level,
file=args.file,
)
p.run()
print("File loaded.")
break
except Exception as e:
print("Exception:", e, flush=True)
print("Will retry...", flush=True)
time.sleep(10)
main()

View file

@ -0,0 +1,5 @@
#!/usr/bin/env python3
from trustgraph.core.metering import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.extract.object.row import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.storage.object_embeddings.milvus import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.decoding.pdf import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.model.prompt.generic import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.model.prompt.template import run
run()

View file

@ -0,0 +1,49 @@
#!/usr/bin/env python3
"""
Uses the Document RAG service to answer a query
"""
import argparse
import os
from trustgraph.core.clients.document_rag_client import DocumentRagClient
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
def query(pulsar, query):
rag = DocumentRagClient(pulsar_host=pulsar)
resp = rag.request(query)
print(resp)
def main():
parser = argparse.ArgumentParser(
prog='graph-show',
description=__doc__,
)
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'-q', '--query',
required=True,
help=f'Query to execute',
)
args = parser.parse_args()
try:
query(args.pulsar_host, args.query)
except Exception as e:
print("Exception:", e, flush=True)
main()

View file

@ -0,0 +1,49 @@
#!/usr/bin/env python3
"""
Uses the GraphRAG service to answer a query
"""
import argparse
import os
from trustgraph.core.clients.graph_rag_client import GraphRagClient
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
def query(pulsar, query):
rag = GraphRagClient(pulsar_host=pulsar)
resp = rag.request(query)
print(resp)
def main():
parser = argparse.ArgumentParser(
prog='graph-show',
description=__doc__,
)
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'-q', '--query',
required=True,
help=f'Query to execute',
)
args = parser.parse_args()
try:
query(args.pulsar_host, args.query)
except Exception as e:
print("Exception:", e, flush=True)
main()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.storage.rows.cassandra import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.processing import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.model.text_completion.azure import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.model.text_completion.bedrock import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.model.text_completion.claude import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.model.text_completion.cohere import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.model.text_completion.llamafile import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.model.text_completion.ollama import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.model.text_completion.openai import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.model.text_completion.vertexai import run
run()

View file

@ -0,0 +1,119 @@
#!/usr/bin/env python3
"""
Initialises Pulsar with Trustgraph tenant / namespaces & policy
"""
import requests
import time
import argparse
default_pulsar_admin_url = "http://pulsar:8080"
def get_clusters(url):
print("Get clusters...", flush=True)
resp = requests.get(f"{url}/admin/v2/clusters")
if resp.status_code != 200: raise RuntimeError("Could not fetch clusters")
return resp.json()
def ensure_tenant(url, tenant, clusters):
resp = requests.get(f"{url}/admin/v2/tenants/{tenant}")
if resp.status_code == 200:
print(f"Tenant {tenant} already exists.", flush=True)
return
resp = requests.put(
f"{url}/admin/v2/tenants/{tenant}",
json={
"adminRoles": [],
"allowedClusters": clusters,
}
)
if resp.status_code != 204:
print(resp.text, flush=True)
raise RuntimeError("Tenant creation failed.")
print(f"Tenant {tenant} created.", flush=True)
def ensure_namespace(url, tenant, namespace, config):
resp = requests.get(f"{url}/admin/v2/namespaces/{tenant}/{namespace}")
if resp.status_code == 200:
print(f"Namespace {tenant}/{namespace} already exists.", flush=True)
return
resp = requests.put(
f"{url}/admin/v2/namespaces/{tenant}/{namespace}",
json=config,
)
if resp.status_code != 204:
print(resp.status_code, flush=True)
print(resp.text, flush=True)
raise RuntimeError(f"Namespace {tenant}/{namespace} creation failed.")
print(f"Namespace {tenant}/{namespace} created.", flush=True)
def init(url, tenant="tg"):
clusters = get_clusters(url)
ensure_tenant(url, tenant, clusters)
ensure_namespace(url, tenant, "flow", {})
ensure_namespace(url, tenant, "request", {})
ensure_namespace(url, tenant, "response", {
"retention_policies": {
"retentionSizeInMB": -1,
"retentionTimeInMinutes": 3,
}
})
def main():
parser = argparse.ArgumentParser(
prog='tg-init-pulsar',
description=__doc__,
)
parser.add_argument(
'-p', '--pulsar-admin-url',
default=default_pulsar_admin_url,
help=f'Pulsar admin URL (default: {default_pulsar_admin_url})',
)
args = parser.parse_args()
while True:
try:
print(flush=True)
print(
f"Initialising with Pulsar {args.pulsar_admin_url}...",
flush=True
)
init(args.pulsar_admin_url, "tg")
print("Initialisation complete.", flush=True)
break
except Exception as e:
print("Exception:", e, flush=True)
print("Sleeping...", flush=True)
time.sleep(2)
print("Will retry...", flush=True)
main()

View file

@ -0,0 +1,24 @@
#!/usr/bin/env python3
import requests
import tabulate
url = 'http://localhost:9090/api/v1/query?query=processor_state%7Bprocessor_state%3D%22running%22%7D'
resp = requests.get(url)
obj = resp.json()
tbl = [
[
m["metric"]["job"],
"running" if int(m["value"][1]) > 0 else "down"
]
for m in obj["data"]["result"]
]
print(tabulate.tabulate(
tbl, tablefmt="pretty", headers=["processor", "state"],
stralign="left"
))

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.dump.triples.parquet import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.query.triples.cassandra import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.query.triples.neo4j import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.storage.triples.cassandra import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.core.storage.triples.neo4j import run
run()

108
trustgraph-core/setup.py Normal file
View file

@ -0,0 +1,108 @@
import setuptools
import os
with open("README.md", "r") as fh:
long_description = fh.read()
version = "0.11.6"
setuptools.setup(
name="trustgraph-core",
version=version,
author="trustgraph.ai",
author_email="security@trustgraph.ai",
description="TrustGraph provides a means to run a pipeline of flexible AI processing components in a flexible means to achieve a processing pipeline.",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/trustgraph-ai/trustgraph",
packages=setuptools.find_namespace_packages(
where='./',
# include=['trustgraph.core']
),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)",
"Operating System :: OS Independent",
],
python_requires='>=3.8',
download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz",
install_requires=[
"urllib3",
"rdflib",
"pymilvus",
"langchain",
"langchain-core",
"langchain-text-splitters",
"langchain-community",
"requests",
"cassandra-driver",
"pulsar-client",
"pypdf",
"qdrant-client",
"tabulate",
"anthropic",
"google-cloud-aiplatform",
"pyyaml",
"prometheus-client",
"pyarrow",
"cohere",
"boto3",
"openai",
"neo4j",
"tiktoken",
],
scripts=[
"scripts/chunker-recursive",
"scripts/chunker-token",
"scripts/concat-parquet",
"scripts/de-query-milvus",
"scripts/de-query-qdrant",
"scripts/de-write-milvus",
"scripts/de-write-qdrant",
"scripts/document-rag",
"scripts/dump-parquet",
"scripts/embeddings-ollama",
"scripts/embeddings-vectorize",
"scripts/ge-dump-parquet",
"scripts/ge-query-milvus",
"scripts/ge-query-qdrant",
"scripts/ge-write-milvus",
"scripts/ge-write-qdrant",
"scripts/graph-rag",
"scripts/graph-show",
"scripts/graph-to-turtle",
"scripts/init-pulsar-manager",
"scripts/kg-extract-definitions",
"scripts/kg-extract-topics",
"scripts/kg-extract-relationships",
"scripts/load-graph-embeddings",
"scripts/load-pdf",
"scripts/load-text",
"scripts/load-triples",
"scripts/metering",
"scripts/object-extract-row",
"scripts/oe-write-milvus",
"scripts/pdf-decoder",
"scripts/prompt-generic",
"scripts/prompt-template",
"scripts/query-document-rag",
"scripts/query-graph-rag",
"scripts/rows-write-cassandra",
"scripts/run-processing",
"scripts/text-completion-azure",
"scripts/text-completion-bedrock",
"scripts/text-completion-claude",
"scripts/text-completion-cohere",
"scripts/text-completion-llamafile",
"scripts/text-completion-ollama",
"scripts/text-completion-openai",
"scripts/text-completion-vertexai",
"scripts/tg-init-pulsar",
"scripts/tg-processor-state",
"scripts/triples-dump-parquet",
"scripts/triples-query-cassandra",
"scripts/triples-query-neo4j",
"scripts/triples-write-cassandra",
"scripts/triples-write-neo4j",
]
)

View file

@ -0,0 +1,6 @@
from . base_processor import BaseProcessor
from . consumer import Consumer
from . producer import Producer
from . consumer_producer import ConsumerProducer

View file

@ -0,0 +1,119 @@
import os
import argparse
import pulsar
import _pulsar
import time
from prometheus_client import start_http_server, Info
from .. log_level import LogLevel
class BaseProcessor:
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
def __init__(self, **params):
self.client = None
if not hasattr(__class__, "params_metric"):
__class__.params_metric = Info(
'params', 'Parameters configuration'
)
# FIXME: Maybe outputs information it should not
__class__.params_metric.info({
k: str(params[k])
for k in params
})
pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
log_level = params.get("log_level", LogLevel.INFO)
self.pulsar_host = pulsar_host
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
def __del__(self):
if self.client:
self.client.close()
@staticmethod
def add_args(parser):
parser.add_argument(
'-p', '--pulsar-host',
default=__class__.default_pulsar_host,
help=f'Pulsar host (default: {__class__.default_pulsar_host})',
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.INFO,
choices=list(LogLevel),
help=f'Output queue (default: info)'
)
parser.add_argument(
'--metrics',
action=argparse.BooleanOptionalAction,
default=True,
help=f'Metrics enabled (default: true)',
)
parser.add_argument(
'-P', '--metrics-port',
type=int,
default=8000,
help=f'Pulsar host (default: 8000)',
)
def run(self):
raise RuntimeError("Something should have implemented the run method")
@classmethod
def start(cls, prog, doc):
parser = argparse.ArgumentParser(
prog=prog,
description=doc
)
cls.add_args(parser)
args = parser.parse_args()
args = vars(args)
print(args)
if args["metrics"]:
start_http_server(args["metrics_port"])
while True:
try:
p = cls(**args)
p.run()
except KeyboardInterrupt:
print("Keyboard interrupt.")
return
except _pulsar.Interrupted:
print("Pulsar Interrupted.")
return
except Exception as e:
print(type(e))
print("Exception:", e, flush=True)
print("Will retry...", flush=True)
time.sleep(4)

View file

@ -0,0 +1,107 @@
from pulsar.schema import JsonSchema
from prometheus_client import Histogram, Info, Counter, Enum
import time
from . base_processor import BaseProcessor
from .. exceptions import TooManyRequests
class Consumer(BaseProcessor):
def __init__(self, **params):
if not hasattr(__class__, "state_metric"):
__class__.state_metric = Enum(
'processor_state', 'Processor state',
states=['starting', 'running', 'stopped']
)
__class__.state_metric.state('starting')
__class__.state_metric.state('starting')
super(Consumer, self).__init__(**params)
input_queue = params.get("input_queue")
subscriber = params.get("subscriber")
input_schema = params.get("input_schema")
if input_schema == None:
raise RuntimeError("input_schema must be specified")
if not hasattr(__class__, "request_metric"):
__class__.request_metric = Histogram(
'request_latency', 'Request latency (seconds)'
)
if not hasattr(__class__, "pubsub_metric"):
__class__.pubsub_metric = Info(
'pubsub', 'Pub/sub configuration'
)
if not hasattr(__class__, "processing_metric"):
__class__.processing_metric = Counter(
'processing_count', 'Processing count', ["status"]
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": input_schema.__name__,
})
self.consumer = self.client.subscribe(
input_queue, subscriber,
schema=JsonSchema(input_schema),
)
def run(self):
__class__.state_metric.state('running')
while True:
msg = self.consumer.receive()
try:
with __class__.request_metric.time():
self.handle(msg)
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
__class__.processing_metric.labels(status="success").inc()
except TooManyRequests:
self.consumer.negative_acknowledge(msg)
print("TooManyRequests: will retry")
__class__.processing_metric.labels(status="rate-limit").inc()
time.sleep(5)
continue
except Exception as e:
print("Exception:", e, flush=True)
# Message failed to be processed
self.consumer.negative_acknowledge(msg)
__class__.processing_metric.labels(status="error").inc()
@staticmethod
def add_args(parser, default_input_queue, default_subscriber):
BaseProcessor.add_args(parser)
parser.add_argument(
'-i', '--input-queue',
default=default_input_queue,
help=f'Input queue (default: {default_input_queue})'
)
parser.add_argument(
'-s', '--subscriber',
default=default_subscriber,
help=f'Queue subscriber name (default: {default_subscriber})'
)

View file

@ -0,0 +1,139 @@
from pulsar.schema import JsonSchema
from prometheus_client import Histogram, Info, Counter, Enum
import time
from . base_processor import BaseProcessor
from .. exceptions import TooManyRequests
# FIXME: Derive from consumer? And producer?
class ConsumerProducer(BaseProcessor):
def __init__(self, **params):
if not hasattr(__class__, "state_metric"):
__class__.state_metric = Enum(
'processor_state', 'Processor state',
states=['starting', 'running', 'stopped']
)
__class__.state_metric.state('starting')
__class__.state_metric.state('starting')
input_queue = params.get("input_queue")
output_queue = params.get("output_queue")
subscriber = params.get("subscriber")
input_schema = params.get("input_schema")
output_schema = params.get("output_schema")
if not hasattr(__class__, "request_metric"):
__class__.request_metric = Histogram(
'request_latency', 'Request latency (seconds)'
)
if not hasattr(__class__, "output_metric"):
__class__.output_metric = Counter(
'output_count', 'Output items created'
)
if not hasattr(__class__, "pubsub_metric"):
__class__.pubsub_metric = Info(
'pubsub', 'Pub/sub configuration'
)
if not hasattr(__class__, "processing_metric"):
__class__.processing_metric = Counter(
'processing_count', 'Processing count', ["status"]
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": input_schema.__name__,
"output_schema": output_schema.__name__,
})
super(ConsumerProducer, self).__init__(**params)
if input_schema == None:
raise RuntimeError("input_schema must be specified")
if output_schema == None:
raise RuntimeError("output_schema must be specified")
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(output_schema),
)
self.consumer = self.client.subscribe(
input_queue, subscriber,
schema=JsonSchema(input_schema),
)
def run(self):
__class__.state_metric.state('running')
while True:
msg = self.consumer.receive()
try:
with __class__.request_metric.time():
resp = self.handle(msg)
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
__class__.processing_metric.labels(status="success").inc()
except TooManyRequests:
self.consumer.negative_acknowledge(msg)
print("TooManyRequests: will retry")
__class__.processing_metric.labels(status="rate-limit").inc()
time.sleep(5)
continue
except Exception as e:
print("Exception:", e, flush=True)
# Message failed to be processed
self.consumer.negative_acknowledge(msg)
__class__.processing_metric.labels(status="error").inc()
def send(self, msg, properties={}):
self.producer.send(msg, properties)
__class__.output_metric.inc()
@staticmethod
def add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
):
BaseProcessor.add_args(parser)
parser.add_argument(
'-i', '--input-queue',
default=default_input_queue,
help=f'Input queue (default: {default_input_queue})'
)
parser.add_argument(
'-s', '--subscriber',
default=default_subscriber,
help=f'Queue subscriber name (default: {default_subscriber})'
)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)

View file

@ -0,0 +1,55 @@
from pulsar.schema import JsonSchema
from prometheus_client import Info, Counter
from . base_processor import BaseProcessor
class Producer(BaseProcessor):
def __init__(self, **params):
output_queue = params.get("output_queue")
output_schema = params.get("output_schema")
if not hasattr(__class__, "output_metric"):
__class__.output_metric = Counter(
'output_count', 'Output items created'
)
if not hasattr(__class__, "pubsub_metric"):
__class__.pubsub_metric = Info(
'pubsub', 'Pub/sub configuration'
)
__class__.pubsub_metric.info({
"output_queue": output_queue,
"output_schema": output_schema.__name__,
})
super(Producer, self).__init__(**params)
if output_schema == None:
raise RuntimeError("output_schema must be specified")
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(output_schema),
)
def send(self, msg, properties={}):
self.producer.send(msg, properties)
__class__.output_metric.inc()
@staticmethod
def add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
):
BaseProcessor.add_args(parser)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)

View file

@ -0,0 +1,3 @@
from . chunker import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . chunker import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,108 @@
"""
Simple decoder, accepts text documents on input, outputs chunks from the
as text as separate output objects.
"""
from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Source
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_ingest_queue
default_output_queue = chunk_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
chunk_size = params.get("chunk_size", 2000)
chunk_overlap = params.get("chunk_overlap", 100)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextDocument,
"output_schema": Chunk,
}
)
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000]
)
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
is_separator_regex=False,
)
def handle(self, msg):
v = msg.value()
print(f"Chunking {v.source.id}...", flush=True)
texts = self.text_splitter.create_documents(
[v.text.decode("utf-8")]
)
for ix, chunk in enumerate(texts):
id = v.source.id + "-c" + str(ix)
r = Chunk(
source=Source(
source=v.source.source,
id=id,
title=v.source.title
),
chunk=chunk.page_content.encode("utf-8"),
)
__class__.chunk_metric.observe(len(chunk.page_content))
self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-z', '--chunk-size',
type=int,
default=2000,
help=f'Chunk size (default: 2000)'
)
parser.add_argument(
'-v', '--chunk-overlap',
type=int,
default=100,
help=f'Chunk overlap (default: 100)'
)
def run():
Processor.start(module, __doc__)

View file

@ -0,0 +1,3 @@
from . chunker import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . chunker import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,107 @@
"""
Simple decoder, accepts text documents on input, outputs chunks from the
as text as separate output objects.
"""
from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Source
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_ingest_queue
default_output_queue = chunk_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
chunk_size = params.get("chunk_size", 250)
chunk_overlap = params.get("chunk_overlap", 15)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextDocument,
"output_schema": Chunk,
}
)
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000]
)
self.text_splitter = TokenTextSplitter(
encoding_name="cl100k_base",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
def handle(self, msg):
v = msg.value()
print(f"Chunking {v.source.id}...", flush=True)
texts = self.text_splitter.create_documents(
[v.text.decode("utf-8")]
)
for ix, chunk in enumerate(texts):
id = v.source.id + "-c" + str(ix)
r = Chunk(
source=Source(
source=v.source.source,
id=id,
title=v.source.title
),
chunk=chunk.page_content.encode("utf-8"),
)
__class__.chunk_metric.observe(len(chunk.page_content))
self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-z', '--chunk-size',
type=int,
default=250,
help=f'Chunk size (default: 250)'
)
parser.add_argument(
'-v', '--chunk-overlap',
type=int,
default=15,
help=f'Chunk overlap (default: 15)'
)
def run():
Processor.start(module, __doc__)

View file

@ -0,0 +1,125 @@
import pulsar
import _pulsar
import hashlib
import uuid
import time
from pulsar.schema import JsonSchema
from .. exceptions import *
# Default timeout for a request/response. In seconds.
DEFAULT_TIMEOUT=300
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class BaseClient:
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
input_schema=None,
output_schema=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None: raise RuntimeError("Need input_queue")
if output_queue == None: raise RuntimeError("Need output_queue")
if input_schema == None: raise RuntimeError("Need input_schema")
if output_schema == None: raise RuntimeError("Need output_schema")
if subscriber == None:
subscriber = str(uuid.uuid4())
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level),
)
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(input_schema),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(output_schema),
)
self.input_schema = input_schema
self.output_schema = output_schema
def call(self, **args):
timeout = args.get("timeout", DEFAULT_TIMEOUT)
if "timeout" in args:
del args["timeout"]
id = str(uuid.uuid4())
r = self.input_schema(**args)
end_time = time.time() + timeout
self.producer.send(r, properties={ "id": id })
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=2500)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
value = msg.value()
if value.error:
self.consumer.acknowledge(msg)
if value.error.type == "llm-error":
raise LlmError(value.error.message)
elif value.error.type == "too-many-requests":
raise TooManyRequests(value.error.message)
elif value.error.type == "ParseError":
raise ParseError(value.error.message)
else:
raise RuntimeError(
f"{value.error.type}: {value.error.message}"
)
resp = msg.value()
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -0,0 +1,45 @@
import _pulsar
from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .. schema import document_embeddings_request_queue
from .. schema import document_embeddings_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class DocumentEmbeddingsClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = document_embeddings_request_queue
if output_queue == None:
output_queue = document_embeddings_response_queue
super(DocumentEmbeddingsClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=DocumentEmbeddingsRequest,
output_schema=DocumentEmbeddingsResponse,
)
def request(self, vectors, limit=10, timeout=300):
return self.call(
vectors=vectors, limit=limit, timeout=timeout
).documents

View file

@ -0,0 +1,46 @@
import _pulsar
from .. schema import DocumentRagQuery, DocumentRagResponse
from .. schema import document_rag_request_queue, document_rag_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class DocumentRagClient(BaseClient):
def __init__(
self,
log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = document_rag_request_queue
if output_queue == None:
output_queue = document_rag_response_queue
super(DocumentRagClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=DocumentRagQuery,
output_schema=DocumentRagResponse,
)
def request(self, query, timeout=500):
return self.call(
query=query, timeout=timeout
).response

View file

@ -0,0 +1,44 @@
from pulsar.schema import JsonSchema
from .. schema import EmbeddingsRequest, EmbeddingsResponse
from .. schema import embeddings_request_queue, embeddings_response_queue
from . base import BaseClient
import _pulsar
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class EmbeddingsClient(BaseClient):
def __init__(
self, log_level=ERROR,
input_queue=None,
output_queue=None,
subscriber=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue=embeddings_request_queue
if output_queue == None:
output_queue=embeddings_response_queue
super(EmbeddingsClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=EmbeddingsRequest,
output_schema=EmbeddingsResponse,
)
def request(self, text, timeout=300):
return self.call(text=text, timeout=timeout).vectors

View file

@ -0,0 +1,45 @@
import _pulsar
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. schema import graph_embeddings_request_queue
from .. schema import graph_embeddings_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class GraphEmbeddingsClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = graph_embeddings_request_queue
if output_queue == None:
output_queue = graph_embeddings_response_queue
super(GraphEmbeddingsClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=GraphEmbeddingsRequest,
output_schema=GraphEmbeddingsResponse,
)
def request(self, vectors, limit=10, timeout=300):
return self.call(
vectors=vectors, limit=limit, timeout=timeout
).entities

View file

@ -0,0 +1,46 @@
import _pulsar
from .. schema import GraphRagQuery, GraphRagResponse
from .. schema import graph_rag_request_queue, graph_rag_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class GraphRagClient(BaseClient):
def __init__(
self,
log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = graph_rag_request_queue
if output_queue == None:
output_queue = graph_rag_response_queue
super(GraphRagClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=GraphRagQuery,
output_schema=GraphRagResponse,
)
def request(self, query, timeout=500):
return self.call(
query=query, timeout=timeout
).response

View file

@ -0,0 +1,40 @@
import _pulsar
from .. schema import TextCompletionRequest, TextCompletionResponse
from .. schema import text_completion_request_queue
from .. schema import text_completion_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class LlmClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue is None: input_queue = text_completion_request_queue
if output_queue is None: output_queue = text_completion_response_queue
super(LlmClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=TextCompletionRequest,
output_schema=TextCompletionResponse,
)
def request(self, prompt, timeout=300):
return self.call(prompt=prompt, timeout=timeout).response

View file

@ -0,0 +1,100 @@
import _pulsar
from .. schema import PromptRequest, PromptResponse, Fact, RowSchema, Field
from .. schema import prompt_request_queue
from .. schema import prompt_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class PromptClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = prompt_request_queue
if output_queue == None:
output_queue = prompt_response_queue
super(PromptClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=PromptRequest,
output_schema=PromptResponse,
)
def request_definitions(self, chunk, timeout=300):
return self.call(
kind="extract-definitions", chunk=chunk,
timeout=timeout
).definitions
def request_topics(self, chunk, timeout=300):
return self.call(
kind="extract-topics", chunk=chunk,
timeout=timeout
).topics
def request_relationships(self, chunk, timeout=300):
return self.call(
kind="extract-relationships", chunk=chunk,
timeout=timeout
).relationships
def request_rows(self, schema, chunk, timeout=300):
return self.call(
kind="extract-rows", chunk=chunk,
row_schema=RowSchema(
name=schema.name,
description=schema.description,
fields=[
Field(
name=f.name, type=str(f.type), size=f.size,
primary=f.primary, description=f.description,
)
for f in schema.fields
]
),
timeout=timeout
).rows
def request_kg_prompt(self, query, kg, timeout=300):
return self.call(
kind="kg-prompt",
query=query,
kg=[
Fact(s=v[0], p=v[1], o=v[2])
for v in kg
],
timeout=timeout
).answer
def request_document_prompt(self, query, documents, timeout=300):
return self.call(
kind="document-prompt",
query=query,
documents=documents,
timeout=timeout
).answer

View file

@ -0,0 +1,59 @@
#!/usr/bin/env python3
import _pulsar
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
from .. schema import triples_request_queue
from .. schema import triples_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class TriplesQueryClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = triples_request_queue
if output_queue == None:
output_queue = triples_response_queue
super(TriplesQueryClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=TriplesQueryRequest,
output_schema=TriplesQueryResponse,
)
def create_value(self, ent):
if ent == None: return None
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
return Value(value=ent, is_uri=False)
def request(self, s, p, o, limit=10, timeout=60):
return self.call(
s=self.create_value(s),
p=self.create_value(p),
o=self.create_value(o),
limit=limit,
timeout=timeout,
).triples

View file

@ -0,0 +1,3 @@
from . pdf_decoder import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . pdf_decoder import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,87 @@
"""
Simple decoder, accepts PDF documents on input, outputs pages from the
PDF document as text as separate output objects.
"""
import tempfile
import base64
from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Source
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_ingest_queue
default_output_queue = text_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Document,
"output_schema": TextDocument,
}
)
print("PDF inited")
def handle(self, msg):
print("PDF message received")
v = msg.value()
print(f"Decoding {v.source.id}...", flush=True)
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
fp.write(base64.b64decode(v.data))
fp.close()
with open(fp.name, mode='rb') as f:
loader = PyPDFLoader(fp.name)
pages = loader.load()
for ix, page in enumerate(pages):
id = v.source.id + "-p" + str(ix)
r = TextDocument(
source=Source(
source=v.source.source,
title=v.source.title,
id=id,
),
text=page.page_content.encode("utf-8"),
)
self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
def run():
Processor.start(module, __doc__)

View file

@ -0,0 +1,108 @@
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
class TrustGraph:
def __init__(self, hosts=None):
if hosts is None:
hosts = ["localhost"]
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
self.init()
def clear(self):
self.session.execute("""
drop keyspace if exists trustgraph;
""");
self.init()
def init(self):
self.session.execute("""
create keyspace if not exists trustgraph
with replication = {
'class' : 'SimpleStrategy',
'replication_factor' : 1
};
""");
self.session.set_keyspace('trustgraph')
self.session.execute("""
create table if not exists triples (
s text,
p text,
o text,
PRIMARY KEY (s, p, o)
);
""");
self.session.execute("""
create index if not exists triples_p
ON triples (p);
""");
self.session.execute("""
create index if not exists triples_o
ON triples (o);
""");
def insert(self, s, p, o):
self.session.execute(
"insert into triples (s, p, o) values (%s, %s, %s)",
(s, p, o)
)
def get_all(self, limit=50):
return self.session.execute(
f"select s, p, o from triples limit {limit}"
)
def get_s(self, s, limit=10):
return self.session.execute(
f"select p, o from triples where s = %s limit {limit}",
(s,)
)
def get_p(self, p, limit=10):
return self.session.execute(
f"select s, o from triples where p = %s limit {limit}",
(p,)
)
def get_o(self, o, limit=10):
return self.session.execute(
f"select s, p from triples where o = %s limit {limit}",
(o,)
)
def get_sp(self, s, p, limit=10):
return self.session.execute(
f"select o from triples where s = %s and p = %s limit {limit}",
(s, p)
)
def get_po(self, p, o, limit=10):
return self.session.execute(
f"select s from triples where p = %s and o = %s allow filtering limit {limit}",
(p, o)
)
def get_os(self, o, s, limit=10):
return self.session.execute(
f"select p from triples where o = %s and s = %s limit {limit}",
(o, s)
)
def get_spo(self, s, p, o, limit=10):
return self.session.execute(
f"""select s as x from triples where s = %s and p = %s and o = %s limit {limit}""",
(s, p, o)
)

View file

@ -0,0 +1,138 @@
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
class DocVectors:
def __init__(self, uri="http://localhost:19530", prefix='doc'):
self.client = MilvusClient(uri=uri)
# Strategy is to create collections per dimension. Probably only
# going to be using 1 anyway, but that means we don't need to
# hard-code the dimension anywhere, and no big deal if more than
# one are created.
self.collections = {}
self.prefix = prefix
# Time between reloads
self.reload_time = 90
# Next time to reload - this forces a reload at next window
self.next_reload = time.time() + self.reload_time
print("Reload at", self.next_reload)
def init_collection(self, dimension):
collection_name = self.prefix + "_" + str(dimension)
pkey_field = FieldSchema(
name="id",
dtype=DataType.INT64,
is_primary=True,
auto_id=True,
)
vec_field = FieldSchema(
name="vector",
dtype=DataType.FLOAT_VECTOR,
dim=dimension,
)
doc_field = FieldSchema(
name="doc",
dtype=DataType.VARCHAR,
max_length=65535,
)
schema = CollectionSchema(
fields = [pkey_field, vec_field, doc_field],
description = "Document embedding schema",
)
self.client.create_collection(
collection_name=collection_name,
schema=schema,
metric_type="COSINE",
)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
field_name="vector",
metric_type="COSINE",
index_type="IVF_SQ8",
index_name="vector_index",
params={ "nlist": 128 }
)
self.client.create_index(
collection_name=collection_name,
index_params=index_params
)
self.collections[dimension] = collection_name
def insert(self, embeds, doc):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
data = [
{
"vector": embeds,
"doc": doc,
}
]
self.client.insert(
collection_name=self.collections[dim],
data=data
)
def search(self, embeds, fields=["doc"], limit=10):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
coll = self.collections[dim]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
print("Loading...")
self.client.load_collection(
collection_name=coll,
)
print("Searching...")
res = self.client.search(
collection_name=coll,
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
)[0]
# If reload time has passed, unload collection
if time.time() > self.next_reload:
print("Unloading, reload at", self.next_reload)
self.client.release_collection(
collection_name=coll,
)
self.next_reload = time.time() + self.reload_time
return res

View file

@ -0,0 +1,138 @@
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
class EntityVectors:
def __init__(self, uri="http://localhost:19530", prefix='entity'):
self.client = MilvusClient(uri=uri)
# Strategy is to create collections per dimension. Probably only
# going to be using 1 anyway, but that means we don't need to
# hard-code the dimension anywhere, and no big deal if more than
# one are created.
self.collections = {}
self.prefix = prefix
# Time between reloads
self.reload_time = 90
# Next time to reload - this forces a reload at next window
self.next_reload = time.time() + self.reload_time
print("Reload at", self.next_reload)
def init_collection(self, dimension):
collection_name = self.prefix + "_" + str(dimension)
pkey_field = FieldSchema(
name="id",
dtype=DataType.INT64,
is_primary=True,
auto_id=True,
)
vec_field = FieldSchema(
name="vector",
dtype=DataType.FLOAT_VECTOR,
dim=dimension,
)
entity_field = FieldSchema(
name="entity",
dtype=DataType.VARCHAR,
max_length=65535,
)
schema = CollectionSchema(
fields = [pkey_field, vec_field, entity_field],
description = "Graph embedding schema",
)
self.client.create_collection(
collection_name=collection_name,
schema=schema,
metric_type="COSINE",
)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
field_name="vector",
metric_type="COSINE",
index_type="IVF_SQ8",
index_name="vector_index",
params={ "nlist": 128 }
)
self.client.create_index(
collection_name=collection_name,
index_params=index_params
)
self.collections[dimension] = collection_name
def insert(self, embeds, entity):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
data = [
{
"vector": embeds,
"entity": entity,
}
]
self.client.insert(
collection_name=self.collections[dim],
data=data
)
def search(self, embeds, fields=["entity"], limit=10):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
coll = self.collections[dim]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
print("Loading...")
self.client.load_collection(
collection_name=coll,
)
print("Searching...")
res = self.client.search(
collection_name=coll,
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
)[0]
# If reload time has passed, unload collection
if time.time() > self.next_reload:
print("Unloading, reload at", self.next_reload)
self.client.release_collection(
collection_name=coll,
)
self.next_reload = time.time() + self.reload_time
return res

View file

@ -0,0 +1,154 @@
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
class ObjectVectors:
def __init__(self, uri="http://localhost:19530", prefix='obj'):
self.client = MilvusClient(uri=uri)
# Strategy is to create collections per dimension. Probably only
# going to be using 1 anyway, but that means we don't need to
# hard-code the dimension anywhere, and no big deal if more than
# one are created.
self.collections = {}
self.prefix = prefix
# Time between reloads
self.reload_time = 90
# Next time to reload - this forces a reload at next window
self.next_reload = time.time() + self.reload_time
print("Reload at", self.next_reload)
def init_collection(self, dimension, name):
collection_name = self.prefix + "_" + name + "_" + str(dimension)
pkey_field = FieldSchema(
name="id",
dtype=DataType.INT64,
is_primary=True,
auto_id=True,
)
vec_field = FieldSchema(
name="vector",
dtype=DataType.FLOAT_VECTOR,
dim=dimension,
)
name_field = FieldSchema(
name="name",
dtype=DataType.VARCHAR,
max_length=65535,
)
key_name_field = FieldSchema(
name="key_name",
dtype=DataType.VARCHAR,
max_length=65535,
)
key_field = FieldSchema(
name="key",
dtype=DataType.VARCHAR,
max_length=65535,
)
schema = CollectionSchema(
fields = [
pkey_field, vec_field, name_field, key_name_field, key_field
],
description = "Object embedding schema",
)
self.client.create_collection(
collection_name=collection_name,
schema=schema,
metric_type="COSINE",
)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
field_name="vector",
metric_type="COSINE",
index_type="IVF_SQ8",
index_name="vector_index",
params={ "nlist": 128 }
)
self.client.create_index(
collection_name=collection_name,
index_params=index_params
)
self.collections[(dimension, name)] = collection_name
def insert(self, embeds, name, key_name, key):
dim = len(embeds)
if (dim, name) not in self.collections:
self.init_collection(dim, name)
data = [
{
"vector": embeds,
"name": name,
"key_name": key_name,
"key": key,
}
]
self.client.insert(
collection_name=self.collections[(dim, name)],
data=data
)
def search(self, embeds, name, fields=["key_name", "name"], limit=10):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim, name)
coll = self.collections[(dim, name)]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
print("Loading...")
self.client.load_collection(
collection_name=coll,
)
print("Searching...")
res = self.client.search(
collection_name=coll,
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
)[0]
# If reload time has passed, unload collection
if time.time() > self.next_reload:
print("Unloading, reload at", self.next_reload)
self.client.release_collection(
collection_name=coll,
)
self.next_reload = time.time() + self.reload_time
return res

View file

@ -0,0 +1,132 @@
from . clients.document_embeddings_client import DocumentEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient
from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import document_embeddings_request_queue
from . schema import document_embeddings_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class DocumentRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
de_request_queue=None,
de_response_queue=None,
verbose=False,
module="test",
):
self.verbose=verbose
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if de_request_queue is None:
de_request_queue = document_embeddings_request_queue
if de_response_queue is None:
de_response_queue = document_embeddings_response_queue
if self.verbose:
print("Initialising...", flush=True)
# FIXME: Configurable
self.entity_limit = 20
self.de_client = DocumentEmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-de",
input_queue=de_request_queue,
output_queue=de_response_queue,
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
self.lang = PromptClient(
pulsar_host=pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-de-prompt",
)
if self.verbose:
print("Initialised", flush=True)
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_docs(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
docs = self.de_client.request(
vectors, self.entity_limit
)
if self.verbose:
print("Docs:", flush=True)
for doc in docs:
print(doc, flush=True)
return docs
def query(self, query):
if self.verbose:
print("Construct prompt...", flush=True)
docs = self.get_docs(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(docs)
print(query)
resp = self.lang.request_document_prompt(query, docs)
if self.verbose:
print("Done", flush=True)
return resp

View file

@ -0,0 +1,3 @@
from . processor import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . write import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,85 @@
"""
Write graph embeddings to parquet files in a directory.
"""
import pulsar
import base64
import os
import argparse
import time
from .... schema import GraphEmbeddings
from .... schema import graph_embeddings_store_queue
from .... base import Consumer
from . writer import ParquetWriter
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_store_queue
default_subscriber = module
default_graph_host='localhost'
default_directory = "."
default_file_template = "graph-embeds-{id}.parquet"
default_rotation_time = 60
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
directory = params.get("directory", default_directory)
file_template = params.get("file_template", default_file_template)
rotation_time = params.get("rotation_time", default_rotation_time)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddings,
}
)
self.writer = ParquetWriter(directory, file_template, rotation_time)
def __del__(self):
if hasattr(self, "writer"):
del self.writer
def handle(self, msg):
v = msg.value()
self.writer.write(v.vectors, v.entity.value)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-d', '--directory',
default=default_directory,
help=f'Directory to write to (default: {default_directory})'
)
parser.add_argument(
'-f', '--file-template',
default=default_file_template,
help=f'Directory to write to (default: {default_file_template})'
)
parser.add_argument(
'-t', '--rotation-time',
type=int,
default=default_rotation_time,
help=f'Rotation time / seconds (default: {default_rotation_time})'
)
def run():
Processor.start(module, __doc__)

View file

@ -0,0 +1,94 @@
import threading
import queue
import time
import uuid
import pyarrow as pa
import pyarrow.parquet as pq
class ParquetWriter:
def __init__(self, directory, file_template, rotation_time):
self.directory = directory
self.file_template = file_template
self.rotation_time = rotation_time
self.q = queue.Queue()
self.running = True
self.thread = threading.Thread(target=(self.writer_thread))
self.thread.start()
def writer_thread(self):
items = []
timeout = None
while self.running:
try:
item = self.q.get(timeout=1)
if timeout == None:
timeout = time.time() + self.rotation_time
items.append(item)
except queue.Empty:
pass
if timeout:
if time.time() > timeout:
self.write_file(items)
timeout = None
items = []
def write_file(self, items):
try:
schema = pa.schema([
pa.field('embeddings', pa.list_(pa.list_(pa.float64()))),
pa.field('entity', pa.string()),
])
fname = self.file_template.format(id=str(uuid.uuid4()))
path = f"{self.directory}/{fname}"
writer = pq.ParquetWriter(path, schema)
batch = pa.record_batch(
[
[i[0] for i in items],
[i[1] for i in items],
],
names=['embeddings', 'entity']
)
writer.write_batch(batch)
writer.close()
print(f"Wrote {path}.")
except Exception as e:
print("Parquet write:", e)
def write(self, embeds, ent):
self.q.put((embeds, ent))
def __del__(self):
self.running = False
if hasattr(self, "q"):
self.thread.join()

View file

@ -0,0 +1,3 @@
from . processor import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . write import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,85 @@
"""
Write graphs triples to parquet files in a directory.
"""
import pulsar
import base64
import os
import argparse
import time
from .... schema import Triple
from .... schema import triples_store_queue
from .... base import Consumer
from . writer import ParquetWriter
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_store_queue
default_subscriber = module
default_graph_host='localhost'
default_directory = "."
default_file_template = "triples-{id}.parquet"
default_rotation_time = 60
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
directory = params.get("directory", default_directory)
file_template = params.get("file_template", default_file_template)
rotation_time = params.get("rotation_time", default_rotation_time)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Triple,
}
)
self.writer = ParquetWriter(directory, file_template, rotation_time)
def __del__(self):
if hasattr(self, "writer"):
del self.writer
def handle(self, msg):
v = msg.value()
self.writer.write(v.s.value, v.p.value, v.o.value)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-d', '--directory',
default=default_directory,
help=f'Directory to write to (default: {default_directory})'
)
parser.add_argument(
'-f', '--file-template',
default=default_file_template,
help=f'Directory to write to (default: {default_file_template})'
)
parser.add_argument(
'-t', '--rotation-time',
type=int,
default=default_rotation_time,
help=f'Rotation time / seconds (default: {default_rotation_time})'
)
def run():
Processor.start(module, __doc__)

View file

@ -0,0 +1,96 @@
import threading
import queue
import time
import uuid
import pyarrow as pa
import pyarrow.parquet as pq
class ParquetWriter:
def __init__(self, directory, file_template, rotation_time):
self.directory = directory
self.file_template = file_template
self.rotation_time = rotation_time
self.q = queue.Queue()
self.running = True
self.thread = threading.Thread(target=(self.writer_thread))
self.thread.start()
def writer_thread(self):
triples = []
timeout = None
while self.running:
try:
item = self.q.get(timeout=1)
if timeout == None:
timeout = time.time() + self.rotation_time
triples.append(item)
except queue.Empty:
pass
if timeout:
if time.time() > timeout:
self.write_file(triples)
timeout = None
triples = []
def write_file(self, triples):
try:
schema = pa.schema([
pa.field('s', pa.string()),
pa.field('p', pa.string()),
pa.field('o', pa.string()),
])
fname = self.file_template.format(id=str(uuid.uuid4()))
path = f"{self.directory}/{fname}"
writer = pq.ParquetWriter(path, schema)
batch = pa.record_batch(
[
[tpl[0] for tpl in triples],
[tpl[1] for tpl in triples],
[tpl[2] for tpl in triples],
],
names=['s', 'p', 'o']
)
writer.write_batch(batch)
writer.close()
print(f"Wrote {path}.")
except Exception as e:
print("Parquet write:", e)
def write(self, s, p, o):
self.q.put((s, p, o))
def __del__(self):
self.running = False
if hasattr(self, "q"):
self.thread.join()

View file

@ -0,0 +1,3 @@
from . processor import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . processor import run
if __name__ == '__main__':
run()

Some files were not shown because too many files have changed in this diff Show more