diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py index 1f459081..e486f613 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py @@ -22,6 +22,9 @@ class DocumentEmbeddingsImport: pulsar_client, topic = queue, schema = DocumentEmbeddings ) + async def start(self): + await self.publisher.start() + async def destroy(self): self.running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py new file mode 100644 index 00000000..e388003b --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_export.py @@ -0,0 +1,67 @@ + +import asyncio +import queue +import uuid + +from ... schema import EntityContexts +from ... base import Subscriber + +from . serialize import serialize_entity_contexts + +class EntityContextsExport: + + def __init__( + self, ws, running, pulsar_client, queue, consumer, subscriber + ): + + self.ws = ws + self.running = running + self.pulsar_client = pulsar_client + self.queue = queue + self.consumer = consumer + self.subscriber = subscriber + + async def destroy(self): + self.running.stop() + await self.ws.close() + + async def receive(self, msg): + # Ignore incoming info from websocket + pass + + async def run(self): + + subs = Subscriber( + client = self.pulsar_client, topic = self.queue, + consumer_name = self.consumer, subscription = self.subscriber, + schema = EntityContexts + ) + + await subs.start() + + id = str(uuid.uuid4()) + q = await subs.subscribe_all(id) + + while self.running.get(): + try: + + resp = await asyncio.wait_for(q.get(), timeout=0.5) + await self.ws.send_json(serialize_entity_contexts(resp)) + + except TimeoutError: + continue + + except queue.Empty: + continue + + except Exception as e: + print(f"Exception: {str(e)}", flush=True) + break + + await subs.unsubscribe_all(id) + + await subs.stop() + + await self.ws.close() + self.running.stop() + diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py new file mode 100644 index 00000000..22d18904 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/entity_contexts_import.py @@ -0,0 +1,67 @@ + +import asyncio +import uuid +from aiohttp import WSMsgType + +from ... schema import Metadata +from ... schema import EntityContexts, EntityContext +from ... base import Publisher + +from . serialize import to_subgraph, to_value + +class EntityContextsImport: + + def __init__( + self, ws, running, pulsar_client, queue + ): + + self.ws = ws + self.running = running + + self.publisher = Publisher( + pulsar_client, topic = queue, schema = EntityContexts + ) + + async def start(self): + await self.publisher.start() + + async def destroy(self): + self.running.stop() + + if self.ws: + await self.ws.close() + + await self.publisher.stop() + + async def receive(self, msg): + + data = msg.json() + + elt = EntityContexts( + metadata=Metadata( + id=data["metadata"]["id"], + metadata=to_subgraph(data["metadata"]["metadata"]), + user=data["metadata"]["user"], + collection=data["metadata"]["collection"], + ), + entities=[ + EntityContext( + entity=to_value(ent["entity"]), + context=ent["context"], + ) + for ent in data["entities"] + ] + ) + + await self.publisher.send(None, elt) + + async def run(self): + + while self.running.get(): + await asyncio.sleep(0.5) + + if self.ws: + await self.ws.close() + + self.ws = None + diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py index 70e78c87..85174460 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_import.py @@ -22,6 +22,9 @@ class GraphEmbeddingsImport: pulsar_client, topic = queue, schema = GraphEmbeddings ) + async def start(self): + await self.publisher.start() + async def destroy(self): self.running.stop() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 7896d588..8223461a 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -23,10 +23,12 @@ from . document_load import DocumentLoad from . triples_export import TriplesExport from . graph_embeddings_export import GraphEmbeddingsExport from . document_embeddings_export import DocumentEmbeddingsExport +from . entity_contexts_export import EntityContextsExport from . triples_import import TriplesImport from . graph_embeddings_import import GraphEmbeddingsImport from . document_embeddings_import import DocumentEmbeddingsImport +from . entity_contexts_import import EntityContextsImport from . mux import Mux @@ -57,12 +59,14 @@ export_dispatchers = { "triples": TriplesExport, "graph-embeddings": GraphEmbeddingsExport, "document-embeddings": DocumentEmbeddingsExport, + "entity-contexts": EntityContextsExport, } import_dispatchers = { "triples": TriplesImport, "graph-embeddings": GraphEmbeddingsImport, "document-embeddings": DocumentEmbeddingsImport, + "entity-contexts": EntityContextsImport, } class DispatcherWrapper: @@ -146,11 +150,17 @@ class DispatcherManager: intf_defs = self.flows[flow]["interfaces"] - if kind not in intf_defs: + # FIXME: The -store bit, does it make sense? + if kind == "entity-contexts": + int_kind = kind + "-load" + else: + int_kind = kind + "-store" + + if int_kind not in intf_defs: raise RuntimeError("This kind not supported by flow") # FIXME: The -store bit, does it make sense? - qconfig = intf_defs[kind + "-store"] + qconfig = intf_defs[int_kind] id = str(uuid.uuid4()) dispatcher = import_dispatchers[kind]( @@ -160,6 +170,8 @@ class DispatcherManager: queue = qconfig, ) + await dispatcher.start() + return dispatcher async def process_flow_export(self, ws, running, params): @@ -177,11 +189,16 @@ class DispatcherManager: intf_defs = self.flows[flow]["interfaces"] - if kind not in intf_defs: + # FIXME: The -store bit, does it make sense? + if kind == "entity-contexts": + int_kind = kind + "-load" + else: + int_kind = kind + "-store" + + if int_kind not in intf_defs: raise RuntimeError("This kind not supported by flow") - # FIXME: The -store bit, does it make sense? - qconfig = intf_defs[kind + "-store"] + qconfig = intf_defs[int_kind] id = str(uuid.uuid4()) dispatcher = export_dispatchers[kind]( diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py index 45ae55d7..bde3553a 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py @@ -63,6 +63,23 @@ def serialize_graph_embeddings(message): ], } +def serialize_entity_contexts(message): + return { + "metadata": { + "id": message.metadata.id, + "metadata": serialize_subgraph(message.metadata.metadata), + "user": message.metadata.user, + "collection": message.metadata.collection, + }, + "entities": [ + { + "context": entity.context, + "entity": serialize_value(entity.entity), + } + for entity in message.entities + ], + } + def serialize_document_embeddings(message): return { "metadata": { diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py index 9b59a0ed..687b424a 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_import.py @@ -22,6 +22,9 @@ class TriplesImport: pulsar_client, topic = queue, schema = Triples ) + async def start(self): + await self.publisher.start() + async def destroy(self): self.running.stop() diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py index a52f400e..88872e8d 100755 --- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py +++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py @@ -16,7 +16,10 @@ default_model = 'gpt-3.5-turbo' default_temperature = 0.0 default_max_output = 4096 default_api_key = os.getenv("OPENAI_TOKEN") -default_base_url = os.getenv("OPENAI_BASE_URL", None) +default_base_url = os.getenv("OPENAI_BASE_URL") + +if default_base_url is None or default_base_url == "": + default_base_url = "https://api.openai.com/v1" class Processor(LlmService): @@ -24,7 +27,7 @@ class Processor(LlmService): model = params.get("model", default_model) api_key = params.get("api_key", default_api_key) - base_url = params.get("base_url", default_base_url) + base_url = params.get("url", default_base_url) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) @@ -43,7 +46,11 @@ class Processor(LlmService): self.model = model self.temperature = temperature self.max_output = max_output - self.openai = OpenAI(base_url=base_url, api_key=api_key) + + if base_url: + self.openai = OpenAI(base_url=base_url, api_key=api_key) + else: + self.openai = OpenAI(api_key=api_key) print("Initialised", flush=True) @@ -102,7 +109,7 @@ class Processor(LlmService): # Apart from rate limits, treat all exceptions as unrecoverable - print(f"Exception: {e}") + print(f"Exception: {type(e)} {e}") raise e @staticmethod diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py index 5fa436b8..8cf0b719 100755 --- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py +++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py @@ -10,39 +10,42 @@ import pytesseract from pdf2image import convert_from_bytes from ... schema import Document, TextDocument, Metadata -from ... schema import document_ingest_queue, text_ingest_queue -from ... log_level import LogLevel -from ... base import ConsumerProducer +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec -module = "ocr" +default_ident = "pdf-decoder" -default_input_queue = document_ingest_queue -default_output_queue = text_ingest_queue -default_subscriber = module - -class Processor(ConsumerProducer): +class Processor(FlowProcessor): def __init__(self, **params): - input_queue = params.get("input_queue", default_input_queue) - output_queue = params.get("output_queue", default_output_queue) - subscriber = params.get("subscriber", default_subscriber) + id = params.get("id", default_ident) super(Processor, self).__init__( **params | { - "input_queue": input_queue, - "output_queue": output_queue, - "subscriber": subscriber, - "input_schema": Document, - "output_schema": TextDocument, + "id": id, } ) + self.register_specification( + ConsumerSpec( + name = "input", + schema = Document, + handler = self.on_message, + ) + ) + + self.register_specification( + ProducerSpec( + name = "output", + schema = TextDocument, + ) + ) + print("PDF OCR inited") - async def handle(self, msg): + async def on_message(self, msg, consumer, flow): - print("PDF message received") + print("PDF message received", flush=True) v = msg.value() @@ -65,19 +68,15 @@ class Processor(ConsumerProducer): text=text.encode("utf-8"), ) - await self.send(r) + await flow("output").send(r) print("Done.", flush=True) @staticmethod def add_args(parser): - - ConsumerProducer.add_args( - parser, default_input_queue, default_subscriber, - default_output_queue, - ) + FlowProcessor.add_args(parser) def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__)