From a0def33dba210f94127754fb1ce9d2f54eed4486 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Wed, 17 Jul 2024 16:56:47 +0100 Subject: [PATCH] Base classes (#2) Simplify code using base classes --- trustgraph/base/__init__.py | 3 + trustgraph/base/processor.py | 266 ++++++++++++++++++ trustgraph/chunker/recursive/chunker.py | 190 ++++--------- trustgraph/decoder/pdf/pdf_decoder.py | 169 +++-------- trustgraph/embeddings/hf/hf.py | 156 +++------- trustgraph/embeddings/ollama/processor.py | 167 +++-------- trustgraph/embeddings/vectorize/vectorize.py | 154 +++------- trustgraph/graph/cassandra_write/write.py | 133 ++------- trustgraph/kg/extract_definitions/extract.py | 173 +++--------- .../kg/extract_relationships/extract.py | 260 ++++++----------- trustgraph/llm/ollama_text/llm.py | 164 +++-------- trustgraph/rag/graph/rag.py | 202 ++++--------- trustgraph/vector/milvus_write/write.py | 132 ++------- 13 files changed, 746 insertions(+), 1423 deletions(-) create mode 100644 trustgraph/base/__init__.py create mode 100644 trustgraph/base/processor.py diff --git a/trustgraph/base/__init__.py b/trustgraph/base/__init__.py new file mode 100644 index 00000000..9d16af90 --- /dev/null +++ b/trustgraph/base/__init__.py @@ -0,0 +1,3 @@ + +from . processor import * + diff --git a/trustgraph/base/processor.py b/trustgraph/base/processor.py new file mode 100644 index 00000000..e214b320 --- /dev/null +++ b/trustgraph/base/processor.py @@ -0,0 +1,266 @@ + +import os +import argparse +import pulsar +import time +from pulsar.schema import JsonSchema + +from .. log_level import LogLevel + +class BaseProcessor: + + default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') + + def __init__( + self, + pulsar_host=default_pulsar_host, + log_level=LogLevel.INFO, + ): + + self.client = None + + if pulsar_host == None: + pulsar_host = default_pulsar_host + + self.pulsar_host = pulsar_host + + self.client = pulsar.Client( + pulsar_host, + logger=pulsar.ConsoleLogger(log_level.to_pulsar()) + ) + + def __del__(self): + + if self.client: + self.client.close() + + @staticmethod + def add_args(parser): + + parser.add_argument( + '-p', '--pulsar-host', + default=__class__.default_pulsar_host, + help=f'Pulsar host (default: {__class__.default_pulsar_host})', + ) + + parser.add_argument( + '-l', '--log-level', + type=LogLevel, + default=LogLevel.INFO, + choices=list(LogLevel), + help=f'Output queue (default: info)' + ) + + def run(self): + raise RuntimeError("Something should have implemented the run method") + + @classmethod + def start(cls, prog, doc): + + parser = argparse.ArgumentParser( + prog=prog, + description=doc + ) + + cls.add_args(parser) + + args = parser.parse_args() + args = vars(args) + + try: + + p = cls(**args) + p.run() + + except Exception as e: + + print("Exception:", e, flush=True) + print("Will retry...", flush=True) + + time.sleep(10) + +class Consumer(BaseProcessor): + + def __init__( + self, + pulsar_host=None, + log_level=LogLevel.INFO, + input_queue="input", + subscriber="subscriber", + input_schema=None, + ): + + super(Consumer, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + ) + + if input_schema == None: + raise RuntimeError("input_schema must be specified") + + self.consumer = self.client.subscribe( + input_queue, subscriber, + schema=JsonSchema(input_schema), + ) + + def run(self): + + while True: + + msg = self.consumer.receive() + + try: + + self.handle(msg) + + # Acknowledge successful processing of the message + self.consumer.acknowledge(msg) + + except Exception as e: + + print("Exception:", e, flush=True) + + # Message failed to be processed + self.consumer.negative_acknowledge(msg) + + @staticmethod + def add_args(parser, default_input_queue, default_subscriber): + + BaseProcessor.add_args(parser) + + parser.add_argument( + '-i', '--input-queue', + default=default_input_queue, + help=f'Input queue (default: {default_input_queue})' + ) + + parser.add_argument( + '-s', '--subscriber', + default=default_subscriber, + help=f'Queue subscriber name (default: {default_subscriber})' + ) + +class ConsumerProducer(BaseProcessor): + + def __init__( + self, + pulsar_host=None, + log_level=LogLevel.INFO, + input_queue="input", + output_queue="output", + subscriber="subscriber", + input_schema=None, + output_schema=None, + ): + + super(ConsumerProducer, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + ) + + if input_schema == None: + raise RuntimeError("input_schema must be specified") + + if output_schema == None: + raise RuntimeError("output_schema must be specified") + + self.consumer = self.client.subscribe( + input_queue, subscriber, + schema=JsonSchema(input_schema), + ) + + self.producer = self.client.create_producer( + topic=output_queue, + schema=JsonSchema(output_schema), + ) + + def run(self): + + while True: + + msg = self.consumer.receive() + + try: + + resp = self.handle(msg) + + # Acknowledge successful processing of the message + self.consumer.acknowledge(msg) + + except Exception as e: + + print("Exception:", e, flush=True) + + # Message failed to be processed + self.consumer.negative_acknowledge(msg) + + def send(self, msg, properties={}): + + self.producer.send(msg, properties) + + @staticmethod + def add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ): + + BaseProcessor.add_args(parser) + + parser.add_argument( + '-i', '--input-queue', + default=default_input_queue, + help=f'Input queue (default: {default_input_queue})' + ) + + parser.add_argument( + '-s', '--subscriber', + default=default_subscriber, + help=f'Queue subscriber name (default: {default_subscriber})' + ) + + parser.add_argument( + '-o', '--output-queue', + default=default_output_queue, + help=f'Output queue (default: {default_output_queue})' + ) + +class Producer(BaseProcessor): + + def __init__( + self, + pulsar_host=None, + log_level=LogLevel.INFO, + output_queue="output", + output_schema=None, + ): + + super(Producer, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + ) + + if output_schema == None: + raise RuntimeError("output_schema must be specified") + + self.producer = self.client.create_producer( + topic=output_queue, + schema=JsonSchema(output_schema), + ) + + def send(self, msg, properties={}): + + self.producer.send(msg, properties) + + @staticmethod + def add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ): + + BaseProcessor.add_args(parser) + + parser.add_argument( + '-o', '--output-queue', + default=default_output_queue, + help=f'Output queue (default: {default_output_queue})' + ) diff --git a/trustgraph/chunker/recursive/chunker.py b/trustgraph/chunker/recursive/chunker.py index 16e3c992..b116eca5 100755 --- a/trustgraph/chunker/recursive/chunker.py +++ b/trustgraph/chunker/recursive/chunker.py @@ -4,28 +4,22 @@ Simple decoder, accepts text documents on input, outputs chunks from the as text as separate output objects. """ -import pulsar -from pulsar.schema import JsonSchema -import tempfile -import base64 -import os -import argparse from langchain_text_splitters import RecursiveCharacterTextSplitter -import time + from ... schema import TextDocument, Chunk, Source from ... log_level import LogLevel +from ... base import ConsumerProducer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'text-doc-load' default_output_queue = 'chunk-load' default_subscriber = 'chunker-recursive' -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, output_queue=default_output_queue, subscriber=default_subscriber, @@ -34,21 +28,14 @@ class Processor: chunk_overlap=100, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(TextDocument), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(Chunk), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=TextDocument, + output_schema=Chunk, ) self.text_splitter = RecursiveCharacterTextSplitter( @@ -58,134 +45,55 @@ class Processor: is_separator_regex=False, ) - print("Chunker inited") + def handle(self, msg): - def run(self): + v = msg.value() + print(f"Chunking {v.source.id}...", flush=True) - print("Chunker running") + texts = self.text_splitter.create_documents( + [v.text.decode("utf-8")] + ) - while True: + for ix, chunk in enumerate(texts): - msg = self.consumer.receive() - print("Chunker message received") + id = v.source.id + "-c" + str(ix) - try: + r = Chunk( + source=Source( + source=v.source.source, + id=id, + title=v.source.title + ), + chunk=chunk.page_content.encode("utf-8"), + ) - v = msg.value() - print(f"Chunking {v.source.id}...", flush=True) + self.send(r) - texts = self.text_splitter.create_documents( - [v.text.decode("utf-8")] - ) + print("Done.", flush=True) - for ix, chunk in enumerate(texts): + @staticmethod + def add_args(parser): - id = v.source.id + "-c" + str(ix) + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) - r = Chunk( - source=Source( - source=v.source.source, - id=id, - title=v.source.title - ), - chunk=chunk.page_content.encode("utf-8"), - ) + parser.add_argument( + '-z', '--chunk-size', + type=int, + default=2000, + help=f'Chunk size (default: 2000)' + ) - self.producer.send(r) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - print("Done.", flush=True) - - except Exception as e: - print(e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() + parser.add_argument( + '-v', '--chunk-overlap', + type=int, + default=100, + help=f'Chunk overlap (default: 100)' + ) def run(): - parser = argparse.ArgumentParser( - prog='pdf-decoder', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '-z', '--chunk-size', - type=int, - default=2000, - help=f'Chunk size (default: 2000)' - ) - - parser.add_argument( - '-v', '--chunk-overlap', - type=int, - default=100, - help=f'Chunk overlap (default: 100)' - ) - - args = parser.parse_args() - - - while True: - - try: - - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - subscriber=args.subscriber, - log_level=args.log_level, - chunk_size=args.chunk_size, - chunk_overlap=args.chunk_overlap, - ) - - p.run() - - except Exception as e: - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(10) - + Processor.start('chunker', __doc__) diff --git a/trustgraph/decoder/pdf/pdf_decoder.py b/trustgraph/decoder/pdf/pdf_decoder.py index f2ff56a2..e87f2905 100755 --- a/trustgraph/decoder/pdf/pdf_decoder.py +++ b/trustgraph/decoder/pdf/pdf_decoder.py @@ -4,171 +4,84 @@ Simple decoder, accepts PDF documents on input, outputs pages from the PDF document as text as separate output objects. """ -import pulsar -from pulsar.schema import JsonSchema -from langchain_community.document_loaders import PyPDFLoader import tempfile import base64 -import os -import argparse -import time +from langchain_community.document_loaders import PyPDFLoader from ... schema import Document, TextDocument, Source from ... log_level import LogLevel +from ... base import ConsumerProducer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'document-load' default_output_queue = 'text-doc-load' default_subscriber = 'pdf-decoder' -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, output_queue=default_output_queue, subscriber=default_subscriber, log_level=LogLevel.INFO, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(Document), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(TextDocument), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=Document, + output_schema=TextDocument, ) print("PDF inited") - print("Pulsar", pulsar_host) - print("Input", input_queue) - print("Output", output_queue) - print("Subscriber", subscriber) + def handle(self, msg): - def run(self): + print("PDF message received") - print("PDF running") + v = msg.value() - while True: + print(f"Decoding {v.source.id}...", flush=True) - msg = self.consumer.receive() + with tempfile.NamedTemporaryFile(delete_on_close=False) as fp: - print("PDF message received") + fp.write(base64.b64decode(v.data)) + fp.close() - try: + with open(fp.name, mode='rb') as f: - v = msg.value() - print(f"Decoding {v.source.id}...", flush=True) + loader = PyPDFLoader(fp.name) + pages = loader.load() - with tempfile.NamedTemporaryFile(delete_on_close=False) as fp: + for ix, page in enumerate(pages): - fp.write(base64.b64decode(v.data)) - fp.close() + id = v.source.id + "-p" + str(ix) + r = TextDocument( + source=Source( + source=v.source.source, + title=v.source.title, + id=id, + ), + text=page.page_content.encode("utf-8"), + ) - with open(fp.name, mode='rb') as f: + self.send(r) - loader = PyPDFLoader(fp.name) - pages = loader.load() + print("Done.", flush=True) - for ix, page in enumerate(pages): + @staticmethod + def add_args(parser): - id = v.source.id + "-p" + str(ix) - r = TextDocument( - source=Source( - source=v.source.source, - title=v.source.title, - id=id, - ), - text=page.page_content.encode("utf-8"), - ) - - self.producer.send(r) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - print("Done.", flush=True) - - except Exception as e: - print(e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) def run(): - parser = argparse.ArgumentParser( - prog='pdf-decoder', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - args = parser.parse_args() - - while True: - - try: - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - subscriber=args.subscriber, - log_level=args.log_level, - ) - - p.run() - - except Exception as e: - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(10) + Processor.start("pdf-decoder", __doc__) diff --git a/trustgraph/embeddings/hf/hf.py b/trustgraph/embeddings/hf/hf.py index 96ee8e8b..82273310 100755 --- a/trustgraph/embeddings/hf/hf.py +++ b/trustgraph/embeddings/hf/hf.py @@ -4,29 +4,22 @@ Embeddings service, applies an embeddings model selected from HuggingFace. Input is text, output is embeddings vector. """ -import pulsar -from pulsar.schema import JsonSchema -import tempfile -import base64 -import os -import argparse from langchain_huggingface import HuggingFaceEmbeddings -import time from ... schema import EmbeddingsRequest, EmbeddingsResponse from ... log_level import LogLevel +from ... base import ConsumerProducer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'embeddings' default_output_queue = 'embeddings-response' default_subscriber = 'embeddings-hf' default_model="all-MiniLM-L6-v2" -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, output_queue=default_output_queue, subscriber=default_subscriber, @@ -34,132 +27,51 @@ class Processor: model=default_model, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(EmbeddingsRequest), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(EmbeddingsResponse), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=EmbeddingsRequest, + output_schema=EmbeddingsResponse, ) self.embeddings = HuggingFaceEmbeddings(model_name=model) - def run(self): + def handle(self, msg): - while True: + v = msg.value() - msg = self.consumer.receive() + # Sender-produced ID + id = msg.properties()["id"] - try: + print(f"Handling input {id}...", flush=True) - v = msg.value() + text = v.text + embeds = self.embeddings.embed_documents([text]) - # Sender-produced ID + print("Send response...", flush=True) + r = EmbeddingsResponse(vectors=embeds) + self.producer.send(r, properties={"id": id}) - id = msg.properties()["id"] + print("Done.", flush=True) - print(f"Handling input {id}...", flush=True) + @staticmethod + def add_args(parser): - text = v.text - embeds = self.embeddings.embed_documents([text]) + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) - print("Send response...", flush=True) - r = EmbeddingsResponse(vectors=embeds) - self.producer.send(r, properties={"id": id}) - - print("Done.", flush=True) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - except Exception as e: - - print("Exception:", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() + parser.add_argument( + '-m', '--model', + default="all-MiniLM-L6-v2", + help=f'LLM model (default: all-MiniLM-L6-v2)' + ) def run(): - parser = argparse.ArgumentParser( - prog='llm-ollama-text', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '-m', '--model', - default="all-MiniLM-L6-v2", - help=f'LLM model (default: all-MiniLM-L6-v2)' - ) - - args = parser.parse_args() - - - while True: - - try: - - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - subscriber=args.subscriber, - log_level=args.log_level, - model=args.model, - ) - - p.run() - - except Exception as e: - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(10) + Processor.start("embeddings-hf", __doc__) diff --git a/trustgraph/embeddings/ollama/processor.py b/trustgraph/embeddings/ollama/processor.py index e5ef34a8..5c36f86a 100755 --- a/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph/embeddings/ollama/processor.py @@ -3,31 +3,23 @@ Embeddings service, applies an embeddings model selected from HuggingFace. Input is text, output is embeddings vector. """ - -import pulsar -from pulsar.schema import JsonSchema -import tempfile -import base64 -import os -import argparse from langchain_community.embeddings import OllamaEmbeddings -import time from ... schema import EmbeddingsRequest, EmbeddingsResponse from ... log_level import LogLevel +from ... base import ConsumerProducer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'embeddings' default_output_queue = 'embeddings-response' default_subscriber = 'embeddings-ollama' default_model="mxbai-embed-large" default_ollama = 'http://localhost:11434' -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, output_queue=default_output_queue, subscriber=default_subscriber, @@ -36,140 +28,59 @@ class Processor: ollama=default_ollama, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(EmbeddingsRequest), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(EmbeddingsResponse), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=EmbeddingsRequest, + output_schema=EmbeddingsResponse, ) self.embeddings = OllamaEmbeddings(base_url=ollama, model=model) - def run(self): + def handle(self, msg): - while True: + v = msg.value() - msg = self.consumer.receive() + # Sender-produced ID - try: + id = msg.properties()["id"] - v = msg.value() + print(f"Handling input {id}...", flush=True) - # Sender-produced ID + text = v.text + embeds = self.embeddings.embed_query([text]) - id = msg.properties()["id"] + print("Send response...", flush=True) + r = EmbeddingsResponse(vectors=[embeds]) - print(f"Handling input {id}...", flush=True) + self.producer.send(r, properties={"id": id}) - text = v.text - embeds = self.embeddings.embed_query([text]) + print("Done.", flush=True) - print("Send response...", flush=True) - r = EmbeddingsResponse(vectors=[embeds]) + @staticmethod + def add_args(parser): - self.producer.send(r, properties={"id": id}) + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) - print("Done.", flush=True) + parser.add_argument( + '-m', '--model', + default=default_model, + help=f'Embeddings model (default: {default_model})' + ) - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - except Exception as e: - - print("Exception:", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() + parser.add_argument( + '-r', '--ollama', + default=default_ollama, + help=f'ollama (default: {default_ollama})' + ) def run(): - parser = argparse.ArgumentParser( - prog='embeddings-ollama', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '-m', '--model', - default=default_model, - help=f'Embeddings model (default: {default_model})' - ) - - parser.add_argument( - '-r', '--ollama', - default=default_ollama, - help=f'ollama (default: {default_ollama})' - ) - - args = parser.parse_args() - - - while True: - - try: - - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - subscriber=args.subscriber, - log_level=args.log_level, - model=args.model, - ollama=args.ollama, - ) - - p.run() - - except Exception as e: - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(10) + Processor.start('embeddings-ollama', __doc__) diff --git a/trustgraph/embeddings/vectorize/vectorize.py b/trustgraph/embeddings/vectorize/vectorize.py index deeb280e..e779c6ed 100755 --- a/trustgraph/embeddings/vectorize/vectorize.py +++ b/trustgraph/embeddings/vectorize/vectorize.py @@ -4,49 +4,34 @@ Vectorizer, calls the embeddings service to get embeddings for a chunk. Input is text chunk, output is chunk and vectors. """ -import pulsar -from pulsar.schema import JsonSchema -import tempfile -import base64 -import os -import argparse -import time - from ... schema import Chunk, VectorsChunk from ... embeddings_client import EmbeddingsClient from ... log_level import LogLevel +from ... base import ConsumerProducer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'chunk-load' default_output_queue = 'vectors-chunk-load' default_subscriber = 'embeddings-vectorizer' -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, output_queue=default_output_queue, subscriber=default_subscriber, log_level=LogLevel.INFO, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(Chunk), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(VectorsChunk), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=Chunk, + output_schema=VectorsChunk, ) self.embeddings = EmbeddingsClient(pulsar_host=pulsar_host) @@ -56,108 +41,37 @@ class Processor: r = VectorsChunk(source=source, chunk=chunk, vectors=vectors) self.producer.send(r) - def run(self): + def handle(self, msg): - while True: + v = msg.value() + print(f"Indexing {v.source.id}...", flush=True) - msg = self.consumer.receive() - - try: - - v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) - - chunk = v.chunk.decode("utf-8") - - try: - - vectors = self.embeddings.request(chunk) - - self.emit( - source=v.source, - chunk=chunk.encode("utf-8"), - vectors=vectors - ) - - except Exception as e: - print("Exception:", e, flush=True) - - print("Done.", flush=True) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - except Exception as e: - - print("Exception:", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() - -def run(): - - parser = argparse.ArgumentParser( - prog='embeddings-vectorizer', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - args = parser.parse_args() - - while True: + chunk = v.chunk.decode("utf-8") try: - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - subscriber=args.subscriber, - log_level=args.log_level, + vectors = self.embeddings.request(chunk) + + self.emit( + source=v.source, + chunk=chunk.encode("utf-8"), + vectors=vectors ) - p.run() - except Exception as e: - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - time.sleep(10) + print("Done.", flush=True) + + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + +def run(): + + Processor.start("embeddings-vectorize", __doc__) diff --git a/trustgraph/graph/cassandra_write/write.py b/trustgraph/graph/cassandra_write/write.py index 01ca72f1..df317c2a 100755 --- a/trustgraph/graph/cassandra_write/write.py +++ b/trustgraph/graph/cassandra_write/write.py @@ -4,8 +4,6 @@ Graph writer. Input is graph edge. Writes edges to Cassandra graph. """ import pulsar -from pulsar.schema import JsonSchema -import tempfile import base64 import os import argparse @@ -14,135 +12,64 @@ import time from ... trustgraph import TrustGraph from ... schema import Triple from ... log_level import LogLevel +from ... base import Consumer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'graph-load' default_subscriber = 'graph-write-cassandra' default_graph_host='localhost' -class Processor: +class Processor(Consumer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, subscriber=default_subscriber, graph_host=default_graph_host, log_level=LogLevel.INFO, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(Triple), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + subscriber=subscriber, + input_schema=Triple, ) self.tg = TrustGraph([graph_host]) self.count = 0 - def run(self): + def handle(self, msg): - while True: + v = msg.value() - msg = self.consumer.receive() + self.tg.insert( + v.s.value, + v.p.value, + v.o.value + ) - try: + self.count += 1 - v = msg.value() + if (self.count % 1000) == 0: + print(self.count, "...", flush=True) - self.tg.insert( - v.s.value, - v.p.value, - v.o.value - ) + @staticmethod + def add_args(parser): - self.count += 1 + Consumer.add_args( + parser, default_input_queue, default_subscriber, + ) - if (self.count % 1000) == 0: - print(self.count, "...", flush=True) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - except Exception as e: - - print("Exception:", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() + parser.add_argument( + '-g', '--graph-host', + default="localhost", + help=f'Graph host (default: localhost)' + ) def run(): - parser = argparse.ArgumentParser( - prog='graph-write-cassandra', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '-g', '--graph-host', - default="localhost", - help=f'Output queue (default: localhost)' - ) - - args = parser.parse_args() - - while True: - - try: - - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - subscriber=args.subscriber, - log_level=args.log_level, - graph_host=args.graph_host, - ) - - p.run() - - except Exception as e: - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(10) - + Processor.start("graph-write-cassandra", __doc__) diff --git a/trustgraph/kg/extract_definitions/extract.py b/trustgraph/kg/extract_definitions/extract.py index 7ff569dc..42cbacc9 100755 --- a/trustgraph/kg/extract_definitions/extract.py +++ b/trustgraph/kg/extract_definitions/extract.py @@ -4,57 +4,41 @@ Simple decoder, accepts vector+text chunks input, applies entity analysis to get entity definitions which are output as graph edges. """ -import pulsar -from pulsar.schema import JsonSchema -from langchain_community.document_loaders import PyPDFLoader -import tempfile -import base64 -import os -import argparse -import rdflib -import json import urllib.parse -import time +import json from ... schema import VectorsChunk, Triple, Source, Value from ... log_level import LogLevel from ... llm_client import LlmClient from ... prompts import to_definitions from ... rdf import TRUSTGRAPH_ENTITIES, DEFINITION +from ... base import ConsumerProducer DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True) -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'vectors-chunk-load' default_output_queue = 'graph-load' default_subscriber = 'kg-extract-definitions' -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, output_queue=default_output_queue, subscriber=default_subscriber, log_level=LogLevel.INFO, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(VectorsChunk), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(Triple), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=VectorsChunk, + output_schema=Triple, ) self.llm = LlmClient(pulsar_host=pulsar_host) @@ -81,117 +65,44 @@ class Processor: t = Triple(s=s, p=p, o=o) self.producer.send(t) - def run(self): + def handle(self, msg): - while True: + v = msg.value() + print(f"Indexing {v.source.id}...", flush=True) - msg = self.consumer.receive() - - try: - - v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) - - chunk = v.chunk.decode("utf-8") - - g = rdflib.Graph() - - try: - - defs = self.get_definitions(chunk) - print(json.dumps(defs, indent=4), flush=True) - - for defn in defs: - - s = defn["entity"] - s_uri = self.to_uri(s) - - o = defn["definition"] - - s_value = Value(value=str(s_uri), is_uri=True) - o_value = Value(value=str(o), is_uri=False) - - self.emit_edge(s_value, DEFINITION_VALUE, o_value) - - except Exception as e: - print("Exception: ", e, flush=True) - - print("Done.", flush=True) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - except Exception as e: - - print("Exception: ", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() - -def run(): - - parser = argparse.ArgumentParser( - prog='pdf-decoder', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - args = parser.parse_args() - - while True: + chunk = v.chunk.decode("utf-8") try: - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - subscriber=args.subscriber, - log_level=args.log_level, - ) + defs = self.get_definitions(chunk) + print(json.dumps(defs, indent=4), flush=True) - p.run() + for defn in defs: + + s = defn["entity"] + s_uri = self.to_uri(s) + + o = defn["definition"] + + s_value = Value(value=str(s_uri), is_uri=True) + o_value = Value(value=str(o), is_uri=False) + + self.emit_edge(s_value, DEFINITION_VALUE, o_value) except Exception as e: + print("Exception: ", e, flush=True) - print("Exception:", e, flush=True) - print("Will retry...", flush=True) + print("Done.", flush=True) - time.sleep(10) + @staticmethod + def add_args(parser): + + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + +def run(): + + Processor.start("kg-extract-definitions", __doc__) diff --git a/trustgraph/kg/extract_relationships/extract.py b/trustgraph/kg/extract_relationships/extract.py index 14f837df..adc0b71a 100755 --- a/trustgraph/kg/extract_relationships/extract.py +++ b/trustgraph/kg/extract_relationships/extract.py @@ -5,37 +5,29 @@ relationship analysis to get entity relationship edges which are output as graph edges. """ -import pulsar -from pulsar.schema import JsonSchema -from langchain_community.document_loaders import PyPDFLoader -import tempfile -import base64 -import os -import argparse -import rdflib -import json import urllib.parse -import time +import json +from pulsar.schema import JsonSchema from ... schema import VectorsChunk, Triple, VectorsAssociation, Source, Value from ... log_level import LogLevel from ... llm_client import LlmClient from ... prompts import to_relationships from ... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES +from ... base import ConsumerProducer RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'vectors-chunk-load' default_output_queue = 'graph-load' default_subscriber = 'kg-extract-relationships' default_vector_queue='vectors-load' -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, vector_queue=default_vector_queue, output_queue=default_output_queue, @@ -43,19 +35,14 @@ class Processor: log_level=LogLevel.INFO, ): - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(VectorsChunk), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(Triple), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=VectorsChunk, + output_schema=Triple, ) self.vec_prod = self.client.create_producer( @@ -92,162 +79,89 @@ class Processor: r = VectorsAssociation(entity=ent, vectors=vec) self.vec_prod.send(r) - def run(self): + def handle(self, msg): - while True: + v = msg.value() + print(f"Indexing {v.source.id}...", flush=True) - msg = self.consumer.receive() - - try: - - v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) - - chunk = v.chunk.decode("utf-8") - - g = rdflib.Graph() - - try: - - rels = self.get_relationships(chunk) - print(json.dumps(rels, indent=4), flush=True) - - for rel in rels: - - s = rel["subject"] - p = rel["predicate"] - o = rel["object"] - - s_uri = self.to_uri(s) - s_value = Value(value=str(s_uri), is_uri=True) - - p_uri = self.to_uri(p) - p_value = Value(value=str(p_uri), is_uri=True) - - if rel["object-entity"]: - o_uri = self.to_uri(o) - o_value = Value(value=str(o_uri), is_uri=True) - else: - o_value = Value(value=str(o), is_uri=False) - - self.emit_edge( - s_value, - p_value, - o_value - ) - - # Label for s - self.emit_edge( - s_value, - RDF_LABEL_VALUE, - Value(value=str(s), is_uri=False) - ) - - # Label for p - self.emit_edge( - p_value, - RDF_LABEL_VALUE, - Value(value=str(p), is_uri=False) - ) - - if rel["object-entity"]: - # Label for o - self.emit_edge( - o_value, - RDF_LABEL_VALUE, - Value(value=str(o), is_uri=False) - ) - - self.emit_vec(s_value, v.vectors) - self.emit_vec(p_value, v.vectors) - if rel["object-entity"]: - self.emit_vec(o_value, v.vectors) - - except Exception as e: - print("Exception: ", e, flush=True) - - print("Done.", flush=True) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - except Exception as e: - - print("Exception: ", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - self.client.close() - -def run(): - - parser = argparse.ArgumentParser( - prog='kg-extract-relationships', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '-c', '--vector-queue', - default=default_vector_queue, - help=f'Vector output queue (default: {default_vector_queue})' - ) - - args = parser.parse_args() - - while True: + chunk = v.chunk.decode("utf-8") try: - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - vector_queue=args.vector_queue, - subscriber=args.subscriber, - log_level=args.log_level, - ) + rels = self.get_relationships(chunk) + print(json.dumps(rels, indent=4), flush=True) - p.run() + for rel in rels: + + s = rel["subject"] + p = rel["predicate"] + o = rel["object"] + + s_uri = self.to_uri(s) + s_value = Value(value=str(s_uri), is_uri=True) + + p_uri = self.to_uri(p) + p_value = Value(value=str(p_uri), is_uri=True) + + if rel["object-entity"]: + o_uri = self.to_uri(o) + o_value = Value(value=str(o_uri), is_uri=True) + else: + o_value = Value(value=str(o), is_uri=False) + + self.emit_edge( + s_value, + p_value, + o_value + ) + + # Label for s + self.emit_edge( + s_value, + RDF_LABEL_VALUE, + Value(value=str(s), is_uri=False) + ) + + # Label for p + self.emit_edge( + p_value, + RDF_LABEL_VALUE, + Value(value=str(p), is_uri=False) + ) + + if rel["object-entity"]: + # Label for o + self.emit_edge( + o_value, + RDF_LABEL_VALUE, + Value(value=str(o), is_uri=False) + ) + + self.emit_vec(s_value, v.vectors) + self.emit_vec(p_value, v.vectors) + if rel["object-entity"]: + self.emit_vec(o_value, v.vectors) except Exception as e: + print("Exception: ", e, flush=True) - print("Exception:", e, flush=True) - print("Will retry...", flush=True) + print("Done.", flush=True) - time.sleep(10) + @staticmethod + def add_args(parser): + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) + + parser.add_argument( + '-c', '--vector-queue', + default=default_vector_queue, + help=f'Vector output queue (default: {default_vector_queue})' + ) + +def run(): + + Processor.start("kg-extract-relationships", __doc__) diff --git a/trustgraph/llm/ollama_text/llm.py b/trustgraph/llm/ollama_text/llm.py index b7a5fda9..b522e225 100755 --- a/trustgraph/llm/ollama_text/llm.py +++ b/trustgraph/llm/ollama_text/llm.py @@ -4,30 +4,23 @@ Simple LLM service, performs text prompt completion using an Ollama service. Input is prompt, output is response. """ -import pulsar -from pulsar.schema import JsonSchema -import tempfile -import base64 -import os -import argparse from langchain_community.llms import Ollama -import time from ... schema import TextCompletionRequest, TextCompletionResponse from ... log_level import LogLevel +from ... base import ConsumerProducer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'llm-complete-text' default_output_queue = 'llm-complete-text-response' default_subscriber = 'llm-ollama-text' default_model = 'gemma2' default_ollama = 'http://localhost:11434' -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, output_queue=default_output_queue, subscriber=default_subscriber, @@ -36,139 +29,60 @@ class Processor: ollama=default_ollama, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(TextCompletionRequest), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(TextCompletionResponse), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=TextCompletionRequest, + output_schema=TextCompletionResponse, ) self.llm = Ollama(base_url=ollama, model=model) - def run(self): + def handle(self, msg): - while True: + v = msg.value() - msg = self.consumer.receive() + # Sender-produced ID + id = msg.properties()["id"] - try: + print(f"Handling prompt {id}...", flush=True) - v = msg.value() + prompt = v.prompt + response = self.llm.invoke(prompt) - # Sender-produced ID + print("Send response...", flush=True) - id = msg.properties()["id"] + r = TextCompletionResponse(response=response) - print(f"Handling prompt {id}...", flush=True) + self.send(r, properties={"id": id}) - prompt = v.prompt - response = self.llm.invoke(prompt) + print("Done.", flush=True) - print("Send response...", flush=True) - r = TextCompletionResponse(response=response) - self.producer.send(r, properties={"id": id}) + @staticmethod + def add_args(parser): - print("Done.", flush=True) + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) + parser.add_argument( + '-m', '--model', + default="gemma2", + help=f'LLM model (default: gemma2)' + ) - except Exception as e: - - print("Exception:", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() + parser.add_argument( + '-r', '--ollama', + default=default_ollama, + help=f'ollama (default: {default_ollama})' + ) def run(): - parser = argparse.ArgumentParser( - prog='llm-ollama-text', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '-m', '--model', - default="gemma2", - help=f'LLM model (default: gemma2)' - ) - - parser.add_argument( - '-r', '--ollama', - default=default_ollama, - help=f'ollama (default: {default_ollama})' - ) - - args = parser.parse_args() + Processor.start("llm-ollama-text", __doc__) - while True: - - try: - - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - subscriber=args.subscriber, - log_level=args.log_level, - model=args.model, - ollama=args.ollama, - ) - - p.run() - - except Exception as e: - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(10) - diff --git a/trustgraph/rag/graph/rag.py b/trustgraph/rag/graph/rag.py index 0ac7b12e..8a7a484d 100755 --- a/trustgraph/rag/graph/rag.py +++ b/trustgraph/rag/graph/rag.py @@ -4,30 +4,22 @@ Simple RAG service, performs query using graph RAG an LLM. Input is query, output is response. """ -import pulsar -from pulsar.schema import JsonSchema -import tempfile -import base64 -import os -import argparse -import time - from ... schema import GraphRagQuery, GraphRagResponse from ... log_level import LogLevel from ... graph_rag import GraphRag +from ... base import ConsumerProducer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'graph-rag-query' default_output_queue = 'graph-rag-response' default_subscriber = 'graph-rag' default_graph_hosts = [ 'localhost' ] default_vector_store = 'http://localhost:19530' -class Processor: +class Processor(ConsumerProducer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, output_queue=default_output_queue, subscriber=default_subscriber, @@ -39,21 +31,14 @@ class Processor: max_sg_size=3000, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(GraphRagQuery), - ) - - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(GraphRagResponse), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + output_queue=output_queue, + subscriber=subscriber, + input_schema=GraphRagQuery, + output_schema=GraphRagResponse, ) self.rag = GraphRag( @@ -66,142 +51,67 @@ class Processor: max_sg_size=max_sg_size, ) - def run(self): + def handle(self, msg): - while True: + v = msg.value() - msg = self.consumer.receive() + # Sender-produced ID - try: + id = msg.properties()["id"] - v = msg.value() + print(f"Handling input {id}...", flush=True) - # Sender-produced ID + response = self.rag.query(v.query) - id = msg.properties()["id"] + print("Send response...", flush=True) + r = GraphRagResponse(response = response) + self.producer.send(r, properties={"id": id}) - print(f"Handling input {id}...", flush=True) + print("Done.", flush=True) - response = self.rag.query(v.query) + @staticmethod + def add_args(parser): - print("Send response...", flush=True) - r = GraphRagResponse(response = response) - self.producer.send(r, properties={"id": id}) + ConsumerProducer.add_args( + parser, default_input_queue, default_subscriber, + default_output_queue, + ) - print("Done.", flush=True) + parser.add_argument( + '-g', '--graph-hosts', + default='cassandra', + help=f'Graph hosts, comma separated (default: cassandra)' + ) - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) + parser.add_argument( + '-v', '--vector-store', + default='http://milvus:19530', + help=f'Vector host (default: http://milvus:19530)' + ) - except Exception as e: + parser.add_argument( + '-e', '--entity-limit', + type=int, + default=50, + help=f'Entity vector fetch limit (default: 50)' + ) - print("Exception:", e, flush=True) + parser.add_argument( + '-t', '--triple-limit', + type=int, + default=30, + help=f'Triple query limit, per query (default: 30)' + ) - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() + parser.add_argument( + '-u', '--max-subgraph-size', + type=int, + default=3000, + help=f'Max subgraph size (default: 3000)' + ) def run(): - parser = argparse.ArgumentParser( - prog='graph-rag', - description=__doc__, - ) + Processor.start('graph-rag', __doc__) - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '-g', '--graph-hosts', - default='cassandra', - help=f'Graph hosts, comma separated (default: cassandra)' - ) - - parser.add_argument( - '-v', '--vector-store', - default='http://milvus:19530', - help=f'Vector host (default: http://milvus:19530)' - ) - - parser.add_argument( - '-e', '--entity-limit', - type=int, - default=50, - help=f'Entity vector fetch limit (default: 50)' - ) - - parser.add_argument( - '-t', '--triple-limit', - type=int, - default=30, - help=f'Triple query limit, per query (default: 30)' - ) - - parser.add_argument( - '-u', '--max-subgraph-size', - type=int, - default=3000, - help=f'Max subgraph size (default: 3000)' - ) - - args = parser.parse_args() - while True: - - try: - - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - output_queue=args.output_queue, - subscriber=args.subscriber, - log_level=args.log_level, - graph_hosts=args.graph_hosts.split(","), - vector_store=args.vector_store, - entity_limit=args.entity_limit, - triple_limit=args.triple_limit, - max_sg_size=args.max_subgraph_size, - ) - - p.run() - - except Exception as e: - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(10) - diff --git a/trustgraph/vector/milvus_write/write.py b/trustgraph/vector/milvus_write/write.py index 579c576d..cf2a5f76 100755 --- a/trustgraph/vector/milvus_write/write.py +++ b/trustgraph/vector/milvus_write/write.py @@ -3,138 +3,58 @@ Accepts entity/vector pairs and writes them to a Milvus store. """ -import pulsar -from pulsar.schema import JsonSchema -from langchain_community.document_loaders import PyPDFLoader -import tempfile -import base64 -import os -import argparse -import time - from ... schema import VectorsAssociation from ... log_level import LogLevel from ... triple_vectors import TripleVectors +from ... base import Consumer -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') default_input_queue = 'vectors-load' default_subscriber = 'vector-write-milvus' default_store_uri = 'http://localhost:19530' -class Processor: +class Processor(Consumer): def __init__( self, - pulsar_host=default_pulsar_host, + pulsar_host=None, input_queue=default_input_queue, subscriber=default_subscriber, store_uri=default_store_uri, log_level=LogLevel.INFO, ): - self.client = None - - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - - self.consumer = self.client.subscribe( - input_queue, subscriber, - schema=JsonSchema(VectorsAssociation), + super(Processor, self).__init__( + pulsar_host=pulsar_host, + log_level=log_level, + input_queue=input_queue, + subscriber=subscriber, + input_schema=VectorsAssociation, ) self.vecstore = TripleVectors(store_uri) - def run(self): + def handle(self, msg): - while True: + v = msg.value() - msg = self.consumer.receive() + if v.entity.value != "": + for vec in v.vectors: + self.vecstore.insert(vec, v.entity.value) + @staticmethod + def add_args(parser): - try: + Consumer.add_args( + parser, default_input_queue, default_subscriber, + ) - v = msg.value() - - if v.entity.value != "": - for vec in v.vectors: - self.vecstore.insert(vec, v.entity.value) - - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) - - except Exception as e: - - print("Exception:", e, flush=True) - - # Message failed to be processed - self.consumer.negative_acknowledge(msg) - - def __del__(self): - - if self.client: - self.client.close() + parser.add_argument( + '-t', '--store-uri', + default="http://milvus:19530", + help=f'Milvus store URI (default: http://milvus:19530)' + ) def run(): - parser = argparse.ArgumentParser( - prog='pdf-decoder', - description=__doc__, - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '-i', '--input-queue', - default=default_input_queue, - help=f'Input queue (default: {default_input_queue})' - ) - - parser.add_argument( - '-s', '--subscriber', - default=default_subscriber, - help=f'Queue subscriber name (default: {default_subscriber})' - ) - - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.INFO, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - - parser.add_argument( - '-t', '--store-uri', - default="http://milvus:19530", - help=f'Milvus store URI (default: http://milvus:19530)' - ) - - args = parser.parse_args() - - while True: - - try: - - p = Processor( - pulsar_host=args.pulsar_host, - input_queue=args.input_queue, - subscriber=args.subscriber, - store_uri=args.store_uri, - log_level=args.log_level, - ) - - p.run() - - except Exception as e: - - print("Exception:", e, flush=True) - print("Will retry...", flush=True) - - time.sleep(10) - + Processor.start("vector-write-milvus", __doc__) +