#!/usr/bin/env python3

"""
This utility takes a knowledge core and loads it into a running TrustGraph
through the API.  The knowledge core should be in msgpack format, which is the
default format produced by tg-save-kg-core.
"""

import argparse
import asyncio
import json
import os
import signal
import sys

import aiohttp
import msgpack


class Running:
    """Shared run/stop flag polled by every task so they can exit cleanly."""

    def __init__(self):
        self.running = True

    def get(self):
        """Return True while the tasks should keep working."""
        return self.running

    def stop(self):
        """Ask all tasks to finish at their next poll."""
        self.running = False


# Number of document-embeddings messages successfully sent so far.
de_counts = 0


async def load_de(running, queue, url):
    """Forward document-embeddings records from `queue` to the API websocket.

    Args:
        running: Running flag; the loop exits once it reports False.
        queue: asyncio.Queue of records in the compact on-disk form
            (keys "m" -> {"i", "m", "u", "c"} and "c" -> [{"c", "v"}, ...]).
            A `None` item marks the end of the load.
        url: API base URL ending in "/"; "load/document-embeddings" is
            appended to it.
    """

    global de_counts

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(f"{url}load/document-embeddings") as ws:

            while running.get():

                # Poll with a timeout so a stop request is noticed even when
                # the queue stays empty.  asyncio.TimeoutError is what
                # wait_for raises before 3.11 and is an alias of the builtin
                # TimeoutError from 3.11 on, so this matches on all versions.
                try:
                    msg = await asyncio.wait_for(queue.get(), 1)
                except asyncio.TimeoutError:
                    continue

                # End of load
                if msg is None:
                    break

                # Expand the compact on-disk keys into the API message shape.
                payload = {
                    "metadata": {
                        "id": msg["m"]["i"],
                        "metadata": msg["m"]["m"],
                        "user": msg["m"]["u"],
                        "collection": msg["m"]["c"],
                    },
                    "chunks": [
                        {
                            "chunk": chunk["c"],
                            "vectors": chunk["v"],
                        }
                        for chunk in msg["c"]
                    ],
                }

                try:
                    await ws.send_json(payload)
                except Exception as e:
                    # Best effort: report the failure and carry on with the
                    # next record.
                    print(e)
                else:
                    # Only count records which were actually sent.
                    de_counts += 1


async def stats(running):
    """Print the running count of loaded records every 2 seconds."""

    global de_counts

    while running.get():

        await asyncio.sleep(2)

        print(
            f"Document embeddings: {de_counts:10d}"
        )


async def _put_while_running(running, queue, item, timeout):
    """Put `item` on `queue`, retrying on timeout until stopped.

    Returns True if the item was queued, False if `running` was stopped
    before there was room for it.
    """

    while running.get():
        try:
            await asyncio.wait_for(queue.put(item), timeout)
        except asyncio.TimeoutError:
            # Queue still full; go round again so a stop request is noticed.
            continue
        return True

    return False


async def loader(running, de_queue, path, format, user, collection):
    """Read records from the file at `path` and queue document embeddings.

    Args:
        running: Running flag; reading stops once it reports False.
        de_queue: Queue receiving the payload of every "de" record, followed
            by a final `None` sentinel.
        path: Path of the msgpack input file.
        format: "msgpack" or "json".
        user: If set, replaces the user ID stored in each record.
        collection: If set, replaces the collection ID stored in each record.

    Raises:
        RuntimeError: if `format` is "json" (not implemented).
    """

    if format == "json":
        raise RuntimeError("Not implemented")

    with open(path, "rb") as f:

        unpacker = msgpack.Unpacker(f, raw=False)

        while running.get():

            try:
                unpacked = unpacker.unpack()
            except msgpack.OutOfData:
                # Normal end of input.
                break
            except Exception as e:
                # Corrupt or truncated input: report it and stop reading, but
                # still let what has been queued so far be delivered.
                print(f"Failed to decode input: {e}")
                break

            # Each record is a (type, payload) pair.  Only document
            # embeddings are handled here; anything else is skipped.
            if unpacked[0] != "de":
                continue

            record = unpacked[1]

            # Overrides are written to the same compact metadata keys which
            # load_de reads ("u" = user, "c" = collection).
            if user:
                record["m"]["u"] = user

            if collection:
                record["m"]["c"] = collection

            if not await _put_while_running(running, de_queue, record, 0.5):
                break

    # Put 'None' on end of queue to finish
    await _put_while_running(running, de_queue, None, 1)


async def run(running, **args):
    """Run the file reader, the websocket sender and the stats printer.

    Expects the keyword arguments "url", "input_file", "format", "user" and
    "collection".  Returns once the load has finished or been stopped, and
    re-raises the first exception raised by any of the tasks.
    """

    # Maxsize on queues provides back-pressure so tg-load-kg-core doesn't
    # grow to eat all memory
    de_q = asyncio.Queue(maxsize=10)

    # Accept the base URL with or without a trailing slash.
    api_url = args["url"].rstrip("/") + "/api/v1/"

    load_task = asyncio.create_task(
        loader(
            running=running,
            de_queue=de_q,
            path=args["input_file"],
            format=args["format"],
            user=args["user"],
            collection=args["collection"],
        )
    )

    de_task = asyncio.create_task(
        load_de(
            running=running,
            queue=de_q,
            url=api_url,
        )
    )

    stats_task = asyncio.create_task(stats(running))

    # Wait until both workers are done, or until either one fails.  Waiting
    # on the sender alone would hang forever if the reader failed before
    # queueing its end-of-load sentinel.
    try:
        await asyncio.wait(
            {load_task, de_task},
            return_when=asyncio.FIRST_EXCEPTION,
        )
    finally:
        running.stop()

    # Every task polls the running flag, so they all wind down promptly now.
    results = await asyncio.gather(
        de_task, load_task, stats_task,
        return_exceptions=True,
    )

    for result in results:
        if isinstance(result, BaseException):
            raise result


async def main(running):
    """Parse the command line and run the load."""

    parser = argparse.ArgumentParser(
        prog='tg-load-kg-core',
        description=__doc__,
    )

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-i', '--input-file',
        # Mandatory: there is no sensible default file to load from
        required=True,
        help='Input file'
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Input format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to load as (default: from input)'
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to load as (default: from input)'
    )

    args = parser.parse_args()

    await run(running, **vars(args))


running = Running()


def interrupt(sig, frame):
    """SIGINT handler: request a clean shutdown rather than aborting."""
    running.stop()
    print('Interrupt')


if __name__ == "__main__":
    signal.signal(signal.SIGINT, interrupt)
    asyncio.run(main(running))