#!/usr/bin/env python3

"""This utility takes a knowledge core and loads it into a running TrustGraph
through the API.  The knowledge core should be in msgpack format, which is the
default format produced by tg-save-kg-core.
"""

import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os
import signal


class Running:
    """Shared run flag; cleared to ask every task to wind down."""

    def __init__(self):
        self.running = True

    def get(self):
        """Return True while the tasks should keep going."""
        return self.running

    def stop(self):
        """Clear the flag so polling loops exit on their next check."""
        self.running = False


# Number of messages successfully sent on each websocket, reported by stats()
ge_counts = 0
t_counts = 0


async def _next_message(running, queue):
    """Return the next item from *queue*, polling the run flag once a second.

    Returns None either when the end-of-load sentinel (None) is received or
    when *running* is cleared before an item arrives; callers stop in both
    cases.
    """
    while running.get():
        try:
            return await asyncio.wait_for(queue.get(), 1)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is an alias of the builtin TimeoutError
            # from Python 3.11, so this matches on old and new versions
            # without also swallowing task cancellation.
            continue
    return None


async def _put(running, queue, item, timeout):
    """Put *item* on *queue*, retrying every *timeout* seconds while running.

    Returns True once the item is queued, False if *running* was cleared
    first.
    """
    while running.get():
        try:
            await asyncio.wait_for(queue.put(item), timeout)
        except asyncio.TimeoutError:
            # Queue still full; re-check the run flag and try again
            continue
        return True
    return False


def _expand_metadata(msg):
    """Map the short metadata keys of a stored record to the API names."""
    meta = msg["m"]
    return {
        "id": meta["i"],
        "metadata": meta["m"],
        "user": meta["u"],
        "collection": meta["c"],
    }


async def load_ge(running, queue, url):
    """Send graph-embeddings records from *queue* to the load websocket.

    Runs until the None sentinel is received or *running* is cleared.  A
    failed send is printed and the record dropped; ge_counts counts only
    records that were sent.
    """

    global ge_counts

    async with aiohttp.ClientSession() as session:

        async with session.ws_connect(f"{url}load/graph-embeddings") as ws:

            while True:

                msg = await _next_message(running, queue)

                # End of load, or asked to stop
                if msg is None:
                    break

                out = {
                    "metadata": _expand_metadata(msg),
                    "entities": [
                        {
                            "entity": ent["e"],
                            "vectors": ent["v"],
                        }
                        for ent in msg["e"]
                    ],
                }

                try:
                    await ws.send_json(out)
                except Exception as e:
                    # Best effort: report and carry on with the next record
                    print(e)
                else:
                    ge_counts += 1


async def load_triples(running, queue, url):
    """Send triples records from *queue* to the load websocket.

    Runs until the None sentinel is received or *running* is cleared.  A
    failed send is printed and the record dropped; t_counts counts only
    records that were sent.
    """

    global t_counts

    async with aiohttp.ClientSession() as session:

        async with session.ws_connect(f"{url}load/triples") as ws:

            while True:

                msg = await _next_message(running, queue)

                # End of load, or asked to stop
                if msg is None:
                    break

                out = {
                    "metadata": _expand_metadata(msg),
                    "triples": msg["t"],
                }

                try:
                    await ws.send_json(out)
                except Exception as e:
                    # Best effort: report and carry on with the next record
                    print(e)
                else:
                    t_counts += 1


async def stats(running):
    """Print the running totals every two seconds until *running* clears."""

    global t_counts
    global ge_counts

    while running.get():

        await asyncio.sleep(2)

        print(
            f"Graph embeddings: {ge_counts:10d}  Triples: {t_counts:10d}"
        )


async def loader(running, ge_queue, t_queue, path, format, user, collection):
    """Read records from the msgpack file at *path* and queue them.

    Each record is a two-element sequence: a type tag ("t" for triples,
    "ge" for graph embeddings) followed by the payload.  Payloads are put on
    *t_queue* or *ge_queue* accordingly; records with any other tag are
    reported and skipped.  When *user* or *collection* is given it replaces
    the corresponding value in the payload metadata.  After the last record
    a None sentinel is put on both queues.

    Raises:
        RuntimeError: if *format* is "json", which is not implemented.
    """

    if format == "json":
        raise RuntimeError("Not implemented")

    with open(path, "rb") as f:

        unpacker = msgpack.Unpacker(f, raw=False)

        while running.get():

            try:
                unpacked = unpacker.unpack()
            except msgpack.OutOfData:
                # Normal end of input
                break
            except Exception as e:
                # Corrupt or truncated input: report it rather than
                # silently treating it as end of file.
                print(f"Failed to read {path}: {e}")
                break

            kind, payload = unpacked[0], unpacked[1]

            if kind == "t":
                target = t_queue
            elif kind == "ge":
                target = ge_queue
            else:
                print(f"Ignoring record of unknown type: {kind!r}")
                continue

            # Overrides go into the payload's short-key metadata, which is
            # what load_ge / load_triples read back out.
            if user:
                payload["m"]["u"] = user

            if collection:
                payload["m"]["c"] = collection

            if not await _put(running, target, payload, 0.5):
                break

    # Put 'None' on end of each queue to finish
    await _put(running, t_queue, None, 1)
    await _put(running, ge_queue, None, 1)


async def run(running, **args):
    """Start the reader, the two senders and the stats printer, then wait.

    *args* holds the parsed command-line options: "url", "input_file",
    "format", "user" and "collection".

    Raises:
        RuntimeError: if the requested format is "json" (not implemented).
    """

    # Fail before opening any connection: if the reader task died on an
    # unsupported format the senders would wait forever for a sentinel.
    if args["format"] == "json":
        raise RuntimeError("Not implemented")

    # Maxsize on queues provides back-pressure so tg-load-kg-core doesn't
    # grow to eat all memory
    ge_q = asyncio.Queue(maxsize=10)
    t_q = asyncio.Queue(maxsize=10)

    api_url = args["url"] + "api/v1/"

    load_task = asyncio.create_task(
        loader(
            running=running,
            ge_queue=ge_q,
            t_queue=t_q,
            path=args["input_file"],
            format=args["format"],
            user=args["user"],
            collection=args["collection"],
        )
    )

    ge_task = asyncio.create_task(
        load_ge(
            running=running,
            queue=ge_q,
            url=api_url,
        )
    )

    triples_task = asyncio.create_task(
        load_triples(
            running=running,
            queue=t_q,
            url=api_url,
        )
    )

    stats_task = asyncio.create_task(stats(running))

    # The senders finish once they have seen the end-of-load sentinel
    await triples_task
    await ge_task

    # Everything is sent; release the reader and the stats printer
    running.stop()

    await load_task
    await stats_task


async def main(running):
    """Parse the command line and run the load."""

    parser = argparse.ArgumentParser(
        prog='tg-load-kg-core',
        description=__doc__,
    )

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-i', '--input-file',
        # Mandatory: there is no sensible default file to load from
        required=True,
        help='Input file',
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Input format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to load as (default: from input)',
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to load as (default: from input)',
    )

    args = parser.parse_args()

    await run(running, **vars(args))


running = Running()


def interrupt(sig, frame):
    """SIGINT handler: ask all tasks to stop instead of killing the loop."""
    running.stop()
    print('Interrupt')


signal.signal(signal.SIGINT, interrupt)

asyncio.run(main(running))