diff --git a/trustgraph-cli/scripts/tg-load-doc-embeds b/trustgraph-cli/scripts/tg-load-doc-embeds new file mode 100755 index 00000000..d445ec5a --- /dev/null +++ b/trustgraph-cli/scripts/tg-load-doc-embeds @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 + +"""This utility takes a knowledge core and loads it into a running TrustGraph +through the API. The knowledge core should be in msgpack format, which is the +default format produce by tg-save-kg-core. +""" + +import aiohttp +import asyncio +import msgpack +import json +import sys +import argparse +import os +import signal + +class Running: + def __init__(self): self.running = True + def get(self): return self.running + def stop(self): self.running = False + +de_counts = 0 + +async def load_de(running, queue, url): + + global de_counts + + async with aiohttp.ClientSession() as session: + + async with session.ws_connect(f"{url}load/document-embeddings") as ws: + + while running.get(): + + try: + msg = await asyncio.wait_for(queue.get(), 1) + + # End of load + if msg is None: + break + + except: + # Hopefully it's TimeoutError. Annoying to match since + # it changed in 3.11. + continue + + msg = { + "metadata": { + "id": msg["m"]["i"], + "metadata": msg["m"]["m"], + "user": msg["m"]["u"], + "collection": msg["m"]["c"], + }, + "chunks": [ + { + "chunk": chunk["c"], + "vectors": chunk["v"], + } + for chunk in msg["c"] + ], + } + + try: + await ws.send_json(msg) + except Exception as e: + print(e) + + de_counts += 1 + +async def stats(running): + + global de_counts + + while running.get(): + + await asyncio.sleep(2) + + print( + f"Graph embeddings: {de_counts:10d}" + ) + +async def loader(running, de_queue, path, format, user, collection): + + if format == "json": + + raise RuntimeError("Not implemented") + + else: + + with open(path, "rb") as f: + + unpacker = msgpack.Unpacker(f, raw=False) + + while running.get(): + + try: + unpacked = unpacker.unpack() + except: + break + + if user: + unpacked["metadata"]["user"] = user + + if collection: + unpacked["metadata"]["collection"] = collection + + if unpacked[0] == "de": + qtype = de_queue + + while running.get(): + + try: + await asyncio.wait_for(qtype.put(unpacked[1]), 0.5) + + # Successful put message, move on + break + + except: + # Hopefully it's TimeoutError. Annoying to match since + # it changed in 3.11. + continue + + if not running.get(): break + + # Put 'None' on end of queue to finish + while running.get(): + + try: + await asyncio.wait_for(de_queue.put(None), 1) + + # Successful put message, move on + break + + except: + # Hopefully it's TimeoutError. Annoying to match since + # it changed in 3.11. + continue + +async def run(running, **args): + + # Maxsize on queues reduces back-pressure so tg-load-kg-core doesn't + # grow to eat all memory + de_q = asyncio.Queue(maxsize=10) + + load_task = asyncio.create_task( + loader( + running=running, + de_queue=de_q, + path=args["input_file"], format=args["format"], + user=args["user"], collection=args["collection"], + ) + + ) + + de_task = asyncio.create_task( + load_de( + running=running, + queue=de_q, url=args["url"] + "api/v1/" + ) + ) + + stats_task = asyncio.create_task(stats(running)) + + await de_task + + running.stop() + + await load_task + await stats_task + +async def main(running): + + parser = argparse.ArgumentParser( + prog='tg-load-kg-core', + description=__doc__, + ) + + default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/") + default_user = "trustgraph" + collection = "default" + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'TrustGraph API URL (default: {default_url})', + ) + + parser.add_argument( + '-i', '--input-file', + # Make it mandatory, difficult to over-write an existing file + required=True, + help=f'Output file' + ) + + parser.add_argument( + '--format', + default="msgpack", + choices=["msgpack", "json"], + help=f'Output format (default: msgpack)', + ) + + parser.add_argument( + '--user', + help=f'User ID to load as (default: from input)' + ) + + parser.add_argument( + '--collection', + help=f'Collection ID to load as (default: from input)' + ) + + args = parser.parse_args() + + await run(running, **vars(args)) + +running = Running() + +def interrupt(sig, frame): + running.stop() + print('Interrupt') + +signal.signal(signal.SIGINT, interrupt) + +asyncio.run(main(running)) + diff --git a/trustgraph-cli/scripts/tg-save-doc-embeds b/trustgraph-cli/scripts/tg-save-doc-embeds new file mode 100755 index 00000000..95f8b748 --- /dev/null +++ b/trustgraph-cli/scripts/tg-save-doc-embeds @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 + +""" +This utility connects to a running TrustGraph through the API and creates +a knowledge core from the data streaming through the processing queues. +For completeness of data, tg-save-kg-core should be initiated before data +loading takes place. The default output format, msgpack should be used. +JSON output format is also available - msgpack produces a more compact +representation, which is also more performant to load. +""" + +import aiohttp +import asyncio +import msgpack +import json +import sys +import argparse +import os +import signal + +class Running: + def __init__(self): self.running = True + def get(self): return self.running + def stop(self): self.running = False + +async def fetch_de(running, queue, user, collection, url): + + async with aiohttp.ClientSession() as session: + + de_url = f"{url}stream/document-embeddings" + + async with session.ws_connect(de_url) as ws: + + while running.get(): + + try: + msg = await asyncio.wait_for(ws.receive(), 1) + except: + continue + + if msg.type == aiohttp.WSMsgType.TEXT: + + data = msg.json() + + if user: + if data["metadata"]["user"] != user: + continue + + if collection: + if data["metadata"]["collection"] != collection: + continue + + await queue.put([ + "de", + { + "m": { + "i": data["metadata"]["id"], + "m": data["metadata"]["metadata"], + "u": data["metadata"]["user"], + "c": data["metadata"]["collection"], + }, + "c": [ + { + "c": chunk["chunk"], + "v": chunk["vectors"], + } + for chunk in data["chunks"] + ] + } + ]) + if msg.type == aiohttp.WSMsgType.ERROR: + print("Error") + break + +de_counts = 0 + +async def stats(running): + + global t_counts + global de_counts + + while running.get(): + + await asyncio.sleep(2) + + print( + f"Document embeddings: {de_counts:10d}" + ) + +async def output(running, queue, path, format): + + global t_counts + global de_counts + + with open(path, "wb") as f: + + while running.get(): + + try: + msg = await asyncio.wait_for(queue.get(), 0.5) + except: + # Hopefully it's TimeoutError. Annoying to match since + # it changed in 3.11. + continue + + if format == "msgpack": + f.write(msgpack.packb(msg, use_bin_type=True)) + else: + f.write(json.dumps(msg).encode("utf-8")) + + if msg[0] == "de": + de_counts += 1 + + print("Output file closed") + +async def run(running, **args): + + q = asyncio.Queue() + + de_task = asyncio.create_task( + fetch_de( + running=running, + queue=q, user=args["user"], collection=args["collection"], + url=args["url"] + "api/v1/" + ) + ) + + output_task = asyncio.create_task( + output( + running=running, queue=q, + path=args["output_file"], format=args["format"], + ) + + ) + + stats_task = asyncio.create_task(stats(running)) + + await output_task + await de_task + await stats_task + + print("Exiting") + +async def main(running): + + parser = argparse.ArgumentParser( + prog='tg-save-kg-core', + description=__doc__, + ) + + default_url = os.getenv("TRUSTGRAPH_API", "http://localhost:8088/") + default_user = "trustgraph" + collection = "default" + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'TrustGraph API URL (default: {default_url})', + ) + + parser.add_argument( + '-o', '--output-file', + # Make it mandatory, difficult to over-write an existing file + required=True, + help=f'Output file' + ) + + parser.add_argument( + '--format', + default="msgpack", + choices=["msgpack", "json"], + help=f'Output format (default: msgpack)', + ) + + parser.add_argument( + '--user', + help=f'User ID to filter on (default: no filter)' + ) + + parser.add_argument( + '--collection', + help=f'Collection ID to filter on (default: no filter)' + ) + + args = parser.parse_args() + + await run(running, **vars(args)) + +running = Running() + +def interrupt(sig, frame): + running.stop() + print('Interrupt') + +signal.signal(signal.SIGINT, interrupt) + +asyncio.run(main(running)) + diff --git a/trustgraph-cli/setup.py b/trustgraph-cli/setup.py index 67c70158..7edffd4b 100644 --- a/trustgraph-cli/setup.py +++ b/trustgraph-cli/setup.py @@ -54,10 +54,12 @@ setuptools.setup( "scripts/tg-invoke-llm", "scripts/tg-invoke-prompt", "scripts/tg-load-kg-core", + "scripts/tg-load-doc-embeds", "scripts/tg-load-pdf", "scripts/tg-load-text", "scripts/tg-load-turtle", "scripts/tg-processor-state", "scripts/tg-save-kg-core", + "scripts/tg-save-doc-embeds", ] ) diff --git a/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py b/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py new file mode 100644 index 00000000..1a7f635d --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/document_embeddings_load.py @@ -0,0 +1,64 @@ + +import asyncio +from pulsar.schema import JsonSchema +import uuid +from aiohttp import WSMsgType + +from .. schema import Metadata +from .. schema import DocumentEmbeddings, ChunkEmbeddings +from .. schema import document_embeddings_store_queue + +from . publisher import Publisher +from . socket import SocketEndpoint +from . serialize import to_subgraph + +class DocumentEmbeddingsLoadEndpoint(SocketEndpoint): + + def __init__( + self, pulsar_host, auth, path="/api/v1/load/document-embeddings", + ): + + super(DocumentEmbeddingsLoadEndpoint, self).__init__( + endpoint_path=path, auth=auth, + ) + + self.pulsar_host=pulsar_host + + self.publisher = Publisher( + self.pulsar_host, document_embeddings_store_queue, + schema=JsonSchema(DocumentEmbeddings) + ) + + async def start(self): + + self.publisher.start() + + async def listener(self, ws, running): + + async for msg in ws: + # On error, finish + if msg.type == WSMsgType.ERROR: + break + else: + + data = msg.json() + + elt = DocumentEmbeddings( + metadata=Metadata( + id=data["metadata"]["id"], + metadata=to_subgraph(data["metadata"]["metadata"]), + user=data["metadata"]["user"], + collection=data["metadata"]["collection"], + ), + chunks=[ + ChunkEmbeddings( + chunk=de["chunk"].encode("utf-8"), + vectors=de["vectors"], + ) + for de in data["chunks"] + ], + ) + + self.publisher.send(None, elt) + + running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py b/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py new file mode 100644 index 00000000..99cfb0a9 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/document_embeddings_stream.py @@ -0,0 +1,72 @@ + +import asyncio +import queue +from pulsar.schema import JsonSchema +import uuid + +from .. schema import DocumentEmbeddings +from .. schema import document_embeddings_store_queue + +from . subscriber import Subscriber +from . socket import SocketEndpoint +from . serialize import serialize_document_embeddings + +class DocumentEmbeddingsStreamEndpoint(SocketEndpoint): + + def __init__( + self, pulsar_host, auth, path="/api/v1/stream/document-embeddings" + ): + + super(DocumentEmbeddingsStreamEndpoint, self).__init__( + endpoint_path=path, auth=auth, + ) + + self.pulsar_host=pulsar_host + + self.subscriber = Subscriber( + self.pulsar_host, document_embeddings_store_queue, + "api-gateway", "api-gateway", + schema=JsonSchema(DocumentEmbeddings) + ) + + async def listener(self, ws, running): + + worker = asyncio.create_task( + self.async_thread(ws, running) + ) + + await super(DocumentEmbeddingsStreamEndpoint, self).listener( + ws, running + ) + + await worker + + async def start(self): + + self.subscriber.start() + + async def async_thread(self, ws, running): + + id = str(uuid.uuid4()) + + q = self.subscriber.subscribe_all(id) + + while running.get(): + try: + resp = await asyncio.to_thread(q.get, timeout=0.5) + await ws.send_json(serialize_document_embeddings(resp)) + + except TimeoutError: + continue + + except queue.Empty: + continue + + except Exception as e: + print(f"Exception: {str(e)}", flush=True) + break + + self.subscriber.unsubscribe_all(id) + + running.stop() + diff --git a/trustgraph-flow/trustgraph/gateway/serialize.py b/trustgraph-flow/trustgraph/gateway/serialize.py index 5f9930ad..40b6efc5 100644 --- a/trustgraph-flow/trustgraph/gateway/serialize.py +++ b/trustgraph-flow/trustgraph/gateway/serialize.py @@ -60,3 +60,20 @@ def serialize_graph_embeddings(message): ], } +def serialize_document_embeddings(message): + return { + "metadata": { + "id": message.metadata.id, + "metadata": serialize_subgraph(message.metadata.metadata), + "user": message.metadata.user, + "collection": message.metadata.collection, + }, + "chunks": [ + { + "vectors": chunk.vectors, + "chunk": chunk.chunk.decode("utf-8"), + } + for chunk in message.chunks + ], + } + diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index b329660f..644731e2 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -41,8 +41,10 @@ from . dbpedia import DbpediaRequestor from . internet_search import InternetSearchRequestor from . triples_stream import TriplesStreamEndpoint from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint +from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint from . triples_load import TriplesLoadEndpoint from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint +from . document_embeddings_load import DocumentEmbeddingsLoadEndpoint from . mux import MuxEndpoint from . document_load import DocumentLoadSender from . text_load import TextLoadSender @@ -203,6 +205,10 @@ class Api: pulsar_host=self.pulsar_host, auth = self.auth, ), + DocumentEmbeddingsStreamEndpoint( + pulsar_host=self.pulsar_host, + auth = self.auth, + ), TriplesLoadEndpoint( pulsar_host=self.pulsar_host, auth = self.auth, @@ -211,6 +217,10 @@ class Api: pulsar_host=self.pulsar_host, auth = self.auth, ), + DocumentEmbeddingsLoadEndpoint( + pulsar_host=self.pulsar_host, + auth = self.auth, + ), MuxEndpoint( pulsar_host=self.pulsar_host, auth = self.auth,