#!/usr/bin/env python3

"""This utility takes a knowledge core and loads it into a running TrustGraph
through the API.  The knowledge core should be in msgpack format, which is the
default format produced by tg-save-kg-core.
"""

import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os


def _expand_metadata(meta):
    """Map the compact metadata keys of a core record to the API's keys."""
    return {
        "id": meta["i"],
        "metadata": meta["m"],
        "user": meta["u"],
        "collection": meta["c"],
    }


async def _pump(queue, endpoint, convert):
    """Forward every message from *queue* to the websocket at *endpoint*.

    Each message is passed through *convert* before being sent as JSON.
    The loop never returns on its own; it ends when the task is cancelled
    or when connecting/sending raises.
    """
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(endpoint) as ws:
            while True:
                msg = await queue.get()
                await ws.send_json(convert(msg))
                # Only acknowledge after a successful send, so that
                # queue.join() means "everything queued has been sent".
                queue.task_done()


async def load_ge(queue: asyncio.Queue, url: str) -> None:
    """Send graph-embedding records from *queue* to the API.

    Args:
        queue: Queue of compact records with keys "m" (metadata), "v"
            (vectors) and "e" (entity).
        url: API base URL; "load/graph-embeddings" is appended to it.
    """

    def convert(msg):
        return {
            "metadata": _expand_metadata(msg["m"]),
            "vectors": msg["v"],
            "entity": msg["e"],
        }

    await _pump(queue, f"{url}load/graph-embeddings", convert)


async def load_triples(queue: asyncio.Queue, url: str) -> None:
    """Send triples records from *queue* to the API.

    Args:
        queue: Queue of compact records with keys "m" (metadata) and "t"
            (triples).
        url: API base URL; "load/triples" is appended to it.
    """

    def convert(msg):
        return {
            "metadata": _expand_metadata(msg["m"]),
            "triples": msg["t"],
        }

    await _pump(queue, f"{url}load/triples", convert)


# Running totals of records queued by loader(), reported by stats().
ge_counts = 0
t_counts = 0


def _print_stats():
    """Print the current record counters on one line."""
    print(
        f"Graph embeddings: {ge_counts:10d} Triples: {t_counts:10d}"
    )


async def stats() -> None:
    """Print the record counters every five seconds until cancelled."""
    while True:
        await asyncio.sleep(5)
        _print_stats()


async def loader(ge_queue, t_queue, path, format, user, collection):
    """Read a knowledge core file and queue its records for sending.

    Each unpacked item is indexed as ``item[0]`` (record kind, "t" or "ge")
    and ``item[1]`` (the record).  Items of any other kind are ignored.

    Args:
        ge_queue: Queue receiving "ge" (graph embedding) records.
        t_queue: Queue receiving "t" (triples) records.
        path: Path of the msgpack file to read.
        format: Input format; only the msgpack path is implemented.
        user: If truthy, replaces the user in every record's metadata.
        collection: If truthy, replaces the collection in every record's
            metadata.

    Raises:
        RuntimeError: If *format* is "json".
    """
    global t_counts
    global ge_counts

    if format == "json":
        raise RuntimeError("Not implemented")

    with open(path, "rb") as f:

        unpacker = msgpack.Unpacker(f, raw=False)

        for unpacked in unpacker:

            kind = unpacked[0]
            if kind not in ("t", "ge"):
                continue

            record = unpacked[1]

            # Overrides go into the compact metadata of the record itself
            # ("u" = user, "c" = collection); the unpacked item is a
            # sequence, so it cannot be indexed by "metadata".
            if user:
                record["m"]["u"] = user

            if collection:
                record["m"]["c"] = collection

            if kind == "t":
                await t_queue.put(record)
                t_counts += 1
            else:
                await ge_queue.put(record)
                ge_counts += 1


async def run(**args):
    """Load the input file into TrustGraph and return once it is all sent.

    Expected keyword arguments: ``url``, ``input_file``, ``format``,
    ``user`` and ``collection``.

    Raises:
        Exception: Whatever the loader or one of the sender tasks raised.
    """

    # Bounded queues apply back-pressure to the file reader, so
    # tg-load-kg-core doesn't grow to eat all memory
    ge_q = asyncio.Queue(maxsize=500)
    t_q = asyncio.Queue(maxsize=500)

    # Tolerate a base URL given without its trailing slash.
    base_url = args["url"]
    if not base_url.endswith("/"):
        base_url += "/"
    api_url = base_url + "api/v1/"

    load_task = asyncio.create_task(
        loader(
            ge_queue=ge_q,
            t_queue=t_q,
            path=args["input_file"],
            format=args["format"],
            user=args["user"],
            collection=args["collection"],
        )
    )

    ge_task = asyncio.create_task(load_ge(queue=ge_q, url=api_url))
    triples_task = asyncio.create_task(load_triples(queue=t_q, url=api_url))
    stats_task = asyncio.create_task(stats())

    async def loaded_and_drained():
        # Finished means: file fully read AND every queued record sent.
        await load_task
        await ge_q.join()
        await t_q.join()

    drain_task = asyncio.create_task(loaded_and_drained())

    tasks = [drain_task, load_task, ge_task, triples_task, stats_task]

    try:
        # The senders loop forever, so they only show up as "done" when
        # they failed; watching them too avoids hanging on a dead
        # connection while the loader is blocked on a full queue.
        done, _ = await asyncio.wait(
            {drain_task, ge_task, triples_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
    finally:
        for task in tasks:
            task.cancel()
        # Let cancelled tasks close their websockets/sessions, and collect
        # their exceptions so none is reported as "never retrieved".
        await asyncio.gather(*tasks, return_exceptions=True)

    # Re-raise the failure, if any, of whichever task ended the wait.
    for task in done:
        task.result()

    # Final totals: a short load may finish before stats() ever prints.
    _print_stats()


async def main():
    """Parse the command line and run the load."""

    parser = argparse.ArgumentParser(
        prog='tg-load-kg-core',
        description=__doc__,
    )

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-i', '--input-file',
        required=True,
        help='Input file'
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Input format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to load as (default: from input)'
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to load as (default: from input)'
    )

    args = parser.parse_args()

    await run(**vars(args))


if __name__ == "__main__":
    asyncio.run(main())