Feature/knowledge load (#372)

* Switch off retry in Cassandra until we can differentiate retryable errors

* Fix config getvalues

* Loading knowledge cores works
This commit is contained in:
cybermaggedon 2025-05-08 00:41:45 +01:00 committed by GitHub
parent fdd9a9a9ae
commit 31b7ade44d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 356 additions and 548 deletions

View file

@ -1,303 +1,80 @@
#!/usr/bin/env python3

"""
This utility takes a knowledge core and loads it into a running TrustGraph
through the API. The knowledge core should be in msgpack format, which is
the default format produced by tg-save-kg-core.
"""

import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os
import signal
import tabulate

from trustgraph.api import Api
class Running:
    """Shared shutdown flag, polled by the worker tasks and flipped by
    the SIGINT handler to request a clean stop."""

    def __init__(self):
        # True until stop() is called.
        self.running = True

    def get(self):
        """Return True while the process should keep working."""
        return self.running

    def stop(self):
        """Request shutdown; get() returns False from now on."""
        self.running = False
# Base API URL; overridable via the TRUSTGRAPH_URL environment variable.
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
# Default flow and collection identifiers offered as CLI argument defaults.
default_flow = "0000"
default_collection = "default"
# Running totals of records streamed so far; updated by the loader tasks
# and printed periodically by stats().
ge_counts = 0
t_counts = 0
def load_kg_core(url, user, id, flow, collection):
    """
    Ask the knowledge service to start loading an already-stored
    knowledge core into the given flow and collection via the API.
    """
    api = Api(url).knowledge()
    api.load_kg_core(user=user, id=id, flow=flow, collection=collection)

async def load_ge(running, queue, url):
    """
    Pull graph-embedding records from 'queue' and stream them to the
    websocket import endpoint at 'url'.  Stops when the running flag is
    cleared or a None sentinel is read from the queue.  Increments the
    module-level ge_counts for each record sent.
    """
    global ge_counts

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(url) as ws:
            while running.get():
                try:
                    msg = await asyncio.wait_for(queue.get(), 1)
                    # None marks end of load
                    if msg is None:
                        break
                except asyncio.TimeoutError:
                    # Nothing queued yet: re-check the running flag
                    continue

                # Expand the compact on-disk record into the API's
                # import message shape.
                msg = {
                    "metadata": {
                        "id": msg["m"]["i"],
                        "metadata": msg["m"]["m"],
                        "user": msg["m"]["u"],
                        "collection": msg["m"]["c"],
                    },
                    "entities": [
                        {
                            "entity": ent["e"],
                            "vectors": ent["v"],
                        }
                        for ent in msg["e"]
                    ],
                }

                try:
                    await ws.send_json(msg)
                except Exception as e:
                    # Best-effort: report and keep streaming
                    print(e)

                ge_counts += 1
async def load_triples(running, queue, url):
    """
    Pull triple records from 'queue' and stream them to the websocket
    import endpoint at 'url'.  Stops when the running flag is cleared
    or a None sentinel is read from the queue.  Increments the
    module-level t_counts for each record sent.
    """
    global t_counts

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(url) as ws:
            while running.get():
                try:
                    msg = await asyncio.wait_for(queue.get(), 1)
                    # None marks end of load
                    if msg is None:
                        break
                except asyncio.TimeoutError:
                    # Nothing queued yet: re-check the running flag
                    continue

                # Expand the compact on-disk record into the API's
                # import message shape.
                msg = {
                    "metadata": {
                        "id": msg["m"]["i"],
                        "metadata": msg["m"]["m"],
                        "user": msg["m"]["u"],
                        "collection": msg["m"]["c"],
                    },
                    "triples": msg["t"],
                }

                try:
                    await ws.send_json(msg)
                except Exception as e:
                    # Best-effort: report and keep streaming
                    print(e)

                t_counts += 1
async def stats(running):
    """Print the running record counters every couple of seconds until
    the shutdown flag is set."""
    global t_counts
    global ge_counts

    while running.get():
        await asyncio.sleep(2)
        line = (
            f"Graph embeddings: {ge_counts:10d} Triples: {t_counts:10d}"
        )
        print(line)
async def loader(running, ge_queue, t_queue, path, format, user, collection):
    """
    Read records from the msgpack file at 'path' and push each payload
    onto the queue matching its kind: "t" records to t_queue, "ge"
    records to ge_queue.  When the file is exhausted, a None sentinel
    is pushed to both queues so the consumers can finish.

    Records are (kind, payload) pairs; the payload metadata uses the
    compact keys "u" (user) and "c" (collection) that load_ge /
    load_triples consume, so the optional user/collection overrides are
    applied there.  (The original wrote unpacked["metadata"][...],
    which raises TypeError on a sequence record.)
    """
    if format == "json":
        raise RuntimeError("Not implemented")

    with open(path, "rb") as f:

        unpacker = msgpack.Unpacker(f, raw=False)

        while running.get():

            try:
                unpacked = unpacker.unpack()
            except Exception:
                # End of file (msgpack OutOfData) or unreadable data:
                # stop reading and flush the sentinels below
                break

            if user:
                unpacked[1]["m"]["u"] = user
            if collection:
                unpacked[1]["m"]["c"] = collection

            if unpacked[0] == "t":
                qtype = t_queue
            elif unpacked[0] == "ge":
                qtype = ge_queue
            else:
                # Unknown record kind: skip it.  (The original left
                # qtype unbound here.)
                continue

            while running.get():
                try:
                    await asyncio.wait_for(qtype.put(unpacked[1]), 0.5)
                    # Successful put, move on to the next record
                    break
                except asyncio.TimeoutError:
                    # Queue full: re-check the running flag and retry
                    continue

            if not running.get():
                break

    # Put None on the end of both queues so the consumers terminate
    for q in (t_queue, ge_queue):
        while running.get():
            try:
                await asyncio.wait_for(q.put(None), 1)
                break
            except asyncio.TimeoutError:
                continue
async def run(running, **args):
    """
    Wire up the load pipeline: one task reads the input file, two tasks
    stream records to the import websockets, and one prints progress.
    """

    # Bounded queues give back-pressure so the file reader cannot
    # out-run the websocket writers and eat all memory.
    ge_q = asyncio.Queue(maxsize=10)
    t_q = asyncio.Queue(maxsize=10)

    flow_id = args["flow_id"]
    url = args["url"]

    ge_endpoint = f"{url}api/v1/flow/{flow_id}/import/graph-embeddings"
    t_endpoint = f"{url}api/v1/flow/{flow_id}/import/triples"

    load_task = asyncio.create_task(loader(
        running=running,
        ge_queue=ge_q, t_queue=t_q,
        path=args["input_file"], format=args["format"],
        user=args["user"], collection=args["collection"],
    ))

    ge_task = asyncio.create_task(load_ge(
        running=running, queue=ge_q, url=ge_endpoint,
    ))

    triples_task = asyncio.create_task(load_triples(
        running=running, queue=t_q, url=t_endpoint,
    ))

    stats_task = asyncio.create_task(stats(running))

    # The writers exit once they consume the None sentinel
    await triples_task
    await ge_task

    # Stop the remaining tasks and wait for them to wind down
    running.stop()
    await load_task
    await stats_task
async def main(running):
    """
    Parse the command line and run the load pipeline.

    :param running: Running flag shared with the SIGINT handler.
    """

    parser = argparse.ArgumentParser(
        prog='tg-load-kg-core',
        description=__doc__,
    )

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-i', '--input-file',
        # Mandatory: there is no sensible default input to read
        required=True,
        help='Input file',
    )

    parser.add_argument(
        '-f', '--flow-id',
        default=default_flow,
        help=f'Flow ID (default: {default_flow})',
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Input format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to load as (default: from input)',
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to load as (default: from input)',
    )

    args = parser.parse_args()

    await run(running, **vars(args))
def interrupt(sig, frame):
    """SIGINT handler: flip the shared flag so the tasks wind down."""
    running.stop()
    print('Interrupt')

running = Running()

signal.signal(signal.SIGINT, interrupt)

asyncio.run(main(running))