#!/usr/bin/env python3

"""
This utility connects to a running TrustGraph through the API and creates
a knowledge core from the data streaming through the processing queues.
For completeness of data, tg-save-kg-core should be initiated before data
loading takes place. The default output format, msgpack should be used.
JSON output format is also available - msgpack produces a more compact
representation, which is also more performant to load.
"""

import aiohttp
import asyncio
import msgpack
import json
import sys
import argparse
import os
import signal


class Running:
    """Shared run flag polled by every task to know when to shut down."""

    def __init__(self):
        self.running = True

    def get(self):
        """Return True while the tasks should keep going."""
        return self.running

    def stop(self):
        """Ask all tasks to finish at their next poll of the flag."""
        self.running = False


async def fetch_de(running, queue, user, collection, url):
    """Stream document embeddings from the API websocket into ``queue``.

    Connects to ``{url}stream/document-embeddings`` and, for every text
    message, puts a ``["de", {...}]`` record on ``queue``.  The record
    carries the message metadata under ``"m"`` (keys ``i``, ``m``, ``u``,
    ``c``) and the chunks under ``"c"`` (each with ``c`` and ``v``).

    Args:
        running: Run flag; the loop ends once ``running.get()`` is false.
        queue: ``asyncio.Queue`` receiving the records.
        user: If truthy, only messages whose metadata user matches are kept.
        collection: If truthy, only messages whose metadata collection
            matches are kept.
        url: API base URL; expected to end with a slash.

    The run flag is always cleared when this coroutine exits (error,
    closed socket or exception), because no further data can arrive and
    the other tasks would otherwise wait forever.
    """
    de_url = f"{url}stream/document-embeddings"

    # Message types after which the socket will never deliver data again.
    # Without this check a closed socket keeps yielding close messages
    # immediately and the loop would spin.
    closed_types = (
        aiohttp.WSMsgType.CLOSE,
        aiohttp.WSMsgType.CLOSING,
        aiohttp.WSMsgType.CLOSED,
    )

    try:
        async with aiohttp.ClientSession() as session:
            async with session.ws_connect(de_url) as ws:

                while running.get():

                    # Poll with a timeout so the run flag is re-checked
                    # regularly.  asyncio.TimeoutError is what wait_for
                    # raises on every Python version (since 3.11 it is an
                    # alias of the builtin TimeoutError).
                    try:
                        msg = await asyncio.wait_for(ws.receive(), 1)
                    except asyncio.TimeoutError:
                        continue

                    if msg.type == aiohttp.WSMsgType.ERROR:
                        print("Error")
                        break

                    if msg.type in closed_types:
                        print("Websocket closed")
                        break

                    if msg.type != aiohttp.WSMsgType.TEXT:
                        continue

                    data = msg.json()
                    metadata = data["metadata"]

                    # Optional filters: skip messages for other
                    # users/collections.
                    if user and metadata["user"] != user:
                        continue

                    if collection and metadata["collection"] != collection:
                        continue

                    await queue.put([
                        "de",
                        {
                            "m": {
                                "i": metadata["id"],
                                "m": metadata["metadata"],
                                "u": metadata["user"],
                                "c": metadata["collection"],
                            },
                            "c": [
                                {
                                    "c": chunk["chunk"],
                                    "v": chunk["vectors"],
                                }
                                for chunk in data["chunks"]
                            ],
                        },
                    ])
    finally:
        running.stop()


# Number of document-embedding records written to the output file so far.
de_counts = 0


async def stats(running):
    """Print the running count of saved records every two seconds."""
    while running.get():
        await asyncio.sleep(2)
        print(
            f"Document embeddings: {de_counts:10d}"
        )


async def output(running, queue, path, format):
    """Write records taken from ``queue`` to the file at ``path``.

    Args:
        running: Run flag; once cleared, the remaining queued records are
            still written before the file is closed.
        queue: ``asyncio.Queue`` of records produced by the fetch task.
        path: Output file path; opened in binary write mode (truncates).
        format: ``"msgpack"`` to write msgpack; any other value writes
            UTF-8 encoded JSON.

    If writing fails, the run flag is cleared so the other tasks stop
    rather than streaming into a queue nobody reads.
    """
    global de_counts

    try:
        with open(path, "wb") as f:

            # Keep going after a stop request until the queue is empty so
            # that records already received are not thrown away.
            while running.get() or not queue.empty():

                try:
                    msg = await asyncio.wait_for(queue.get(), 0.5)
                except asyncio.TimeoutError:
                    # Nothing queued; go round again to re-check the flag.
                    continue

                if format == "msgpack":
                    f.write(msgpack.packb(msg, use_bin_type=True))
                else:
                    # NOTE(review): JSON records are written back-to-back
                    # with no separator - confirm the loader copes with a
                    # concatenated stream before relying on this format.
                    f.write(json.dumps(msg).encode("utf-8"))

                if msg[0] == "de":
                    de_counts += 1
    finally:
        running.stop()

    print("Output file closed")


async def run(running, **args):
    """Start the fetch, output and stats tasks and wait for them to end.

    Args:
        running: Shared run flag handed to every task.
        **args: Parsed command-line options; uses ``url``, ``user``,
            ``collection``, ``output_file`` and ``format``.
    """
    q = asyncio.Queue()

    # Tolerate a base URL given without its trailing slash, which would
    # otherwise be glued straight onto "api/v1/".
    base_url = args["url"]
    if not base_url.endswith("/"):
        base_url += "/"

    de_task = asyncio.create_task(
        fetch_de(
            running=running,
            queue=q,
            user=args["user"],
            collection=args["collection"],
            url=base_url + "api/v1/",
        )
    )

    output_task = asyncio.create_task(
        output(
            running=running,
            queue=q,
            path=args["output_file"],
            format=args["format"],
        )
    )

    stats_task = asyncio.create_task(stats(running))

    await output_task
    await de_task
    await stats_task

    print("Exiting")


async def main(running):
    """Parse the command line and run the save operation.

    Args:
        running: Shared run flag, cleared by the SIGINT handler.
    """
    parser = argparse.ArgumentParser(
        prog='tg-save-kg-core',
        description=__doc__,
    )

    default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/")

    parser.add_argument(
        '-u', '--url',
        default=default_url,
        help=f'TrustGraph API URL (default: {default_url})',
    )

    parser.add_argument(
        '-o', '--output-file',
        # Make it mandatory, difficult to over-write an existing file
        required=True,
        help='Output file'
    )

    parser.add_argument(
        '--format',
        default="msgpack",
        choices=["msgpack", "json"],
        help='Output format (default: msgpack)',
    )

    parser.add_argument(
        '--user',
        help='User ID to filter on (default: no filter)'
    )

    parser.add_argument(
        '--collection',
        help='Collection ID to filter on (default: no filter)'
    )

    args = parser.parse_args()

    await run(running, **vars(args))


running = Running()


def interrupt(sig, frame):
    """SIGINT handler: request a clean shutdown instead of aborting."""
    running.stop()
    print('Interrupt')


if __name__ == "__main__":
    signal.signal(signal.SIGINT, interrupt)
    asyncio.run(main(running))