mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-09 06:45:13 +02:00
Merge pull request #391 from trustgraph-ai/release/v0.23
0.23 -> master
This commit is contained in:
commit
232062b54c
9 changed files with 218 additions and 35 deletions
|
|
@ -22,6 +22,9 @@ class DocumentEmbeddingsImport:
|
|||
pulsar_client, topic = queue, schema = DocumentEmbeddings
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
await self.publisher.start()
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import uuid
|
||||
|
||||
from ... schema import EntityContexts
|
||||
from ... base import Subscriber
|
||||
|
||||
from . serialize import serialize_entity_contexts
|
||||
|
||||
class EntityContextsExport:
|
||||
|
||||
def __init__(
|
||||
self, ws, running, pulsar_client, queue, consumer, subscriber
|
||||
):
|
||||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
self.pulsar_client = pulsar_client
|
||||
self.queue = queue
|
||||
self.consumer = consumer
|
||||
self.subscriber = subscriber
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
await self.ws.close()
|
||||
|
||||
async def receive(self, msg):
|
||||
# Ignore incoming info from websocket
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
|
||||
subs = Subscriber(
|
||||
client = self.pulsar_client, topic = self.queue,
|
||||
consumer_name = self.consumer, subscription = self.subscriber,
|
||||
schema = EntityContexts
|
||||
)
|
||||
|
||||
await subs.start()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
q = await subs.subscribe_all(id)
|
||||
|
||||
while self.running.get():
|
||||
try:
|
||||
|
||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||
await self.ws.send_json(serialize_entity_contexts(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
await subs.unsubscribe_all(id)
|
||||
|
||||
await subs.stop()
|
||||
|
||||
await self.ws.close()
|
||||
self.running.stop()
|
||||
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
from ... schema import EntityContexts, EntityContext
|
||||
from ... base import Publisher
|
||||
|
||||
from . serialize import to_subgraph, to_value
|
||||
|
||||
class EntityContextsImport:
|
||||
|
||||
def __init__(
|
||||
self, ws, running, pulsar_client, queue
|
||||
):
|
||||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
|
||||
self.publisher = Publisher(
|
||||
pulsar_client, topic = queue, schema = EntityContexts
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
await self.publisher.start()
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
await self.publisher.stop()
|
||||
|
||||
async def receive(self, msg):
|
||||
|
||||
data = msg.json()
|
||||
|
||||
elt = EntityContexts(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
entities=[
|
||||
EntityContext(
|
||||
entity=to_value(ent["entity"]),
|
||||
context=ent["context"],
|
||||
)
|
||||
for ent in data["entities"]
|
||||
]
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
|
||||
async def run(self):
|
||||
|
||||
while self.running.get():
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
self.ws = None
|
||||
|
||||
|
|
@ -22,6 +22,9 @@ class GraphEmbeddingsImport:
|
|||
pulsar_client, topic = queue, schema = GraphEmbeddings
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
await self.publisher.start()
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -23,10 +23,12 @@ from . document_load import DocumentLoad
|
|||
from . triples_export import TriplesExport
|
||||
from . graph_embeddings_export import GraphEmbeddingsExport
|
||||
from . document_embeddings_export import DocumentEmbeddingsExport
|
||||
from . entity_contexts_export import EntityContextsExport
|
||||
|
||||
from . triples_import import TriplesImport
|
||||
from . graph_embeddings_import import GraphEmbeddingsImport
|
||||
from . document_embeddings_import import DocumentEmbeddingsImport
|
||||
from . entity_contexts_import import EntityContextsImport
|
||||
|
||||
from . mux import Mux
|
||||
|
||||
|
|
@ -57,12 +59,14 @@ export_dispatchers = {
|
|||
"triples": TriplesExport,
|
||||
"graph-embeddings": GraphEmbeddingsExport,
|
||||
"document-embeddings": DocumentEmbeddingsExport,
|
||||
"entity-contexts": EntityContextsExport,
|
||||
}
|
||||
|
||||
import_dispatchers = {
|
||||
"triples": TriplesImport,
|
||||
"graph-embeddings": GraphEmbeddingsImport,
|
||||
"document-embeddings": DocumentEmbeddingsImport,
|
||||
"entity-contexts": EntityContextsImport,
|
||||
}
|
||||
|
||||
class DispatcherWrapper:
|
||||
|
|
@ -146,11 +150,17 @@ class DispatcherManager:
|
|||
|
||||
intf_defs = self.flows[flow]["interfaces"]
|
||||
|
||||
if kind not in intf_defs:
|
||||
# FIXME: The -store bit, does it make sense?
|
||||
if kind == "entity-contexts":
|
||||
int_kind = kind + "-load"
|
||||
else:
|
||||
int_kind = kind + "-store"
|
||||
|
||||
if int_kind not in intf_defs:
|
||||
raise RuntimeError("This kind not supported by flow")
|
||||
|
||||
# FIXME: The -store bit, does it make sense?
|
||||
qconfig = intf_defs[kind + "-store"]
|
||||
qconfig = intf_defs[int_kind]
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
dispatcher = import_dispatchers[kind](
|
||||
|
|
@ -160,6 +170,8 @@ class DispatcherManager:
|
|||
queue = qconfig,
|
||||
)
|
||||
|
||||
await dispatcher.start()
|
||||
|
||||
return dispatcher
|
||||
|
||||
async def process_flow_export(self, ws, running, params):
|
||||
|
|
@ -177,11 +189,16 @@ class DispatcherManager:
|
|||
|
||||
intf_defs = self.flows[flow]["interfaces"]
|
||||
|
||||
if kind not in intf_defs:
|
||||
# FIXME: The -store bit, does it make sense?
|
||||
if kind == "entity-contexts":
|
||||
int_kind = kind + "-load"
|
||||
else:
|
||||
int_kind = kind + "-store"
|
||||
|
||||
if int_kind not in intf_defs:
|
||||
raise RuntimeError("This kind not supported by flow")
|
||||
|
||||
# FIXME: The -store bit, does it make sense?
|
||||
qconfig = intf_defs[kind + "-store"]
|
||||
qconfig = intf_defs[int_kind]
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
dispatcher = export_dispatchers[kind](
|
||||
|
|
|
|||
|
|
@ -63,6 +63,23 @@ def serialize_graph_embeddings(message):
|
|||
],
|
||||
}
|
||||
|
||||
def serialize_entity_contexts(message):
|
||||
return {
|
||||
"metadata": {
|
||||
"id": message.metadata.id,
|
||||
"metadata": serialize_subgraph(message.metadata.metadata),
|
||||
"user": message.metadata.user,
|
||||
"collection": message.metadata.collection,
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"context": entity.context,
|
||||
"entity": serialize_value(entity.entity),
|
||||
}
|
||||
for entity in message.entities
|
||||
],
|
||||
}
|
||||
|
||||
def serialize_document_embeddings(message):
|
||||
return {
|
||||
"metadata": {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ class TriplesImport:
|
|||
pulsar_client, topic = queue, schema = Triples
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
await self.publisher.start()
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,10 @@ default_model = 'gpt-3.5-turbo'
|
|||
default_temperature = 0.0
|
||||
default_max_output = 4096
|
||||
default_api_key = os.getenv("OPENAI_TOKEN")
|
||||
default_base_url = os.getenv("OPENAI_BASE_URL", None)
|
||||
default_base_url = os.getenv("OPENAI_BASE_URL")
|
||||
|
||||
if default_base_url is None or default_base_url == "":
|
||||
default_base_url = "https://api.openai.com/v1"
|
||||
|
||||
class Processor(LlmService):
|
||||
|
||||
|
|
@ -24,7 +27,7 @@ class Processor(LlmService):
|
|||
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key", default_api_key)
|
||||
base_url = params.get("base_url", default_base_url)
|
||||
base_url = params.get("url", default_base_url)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
|
||||
|
|
@ -43,7 +46,11 @@ class Processor(LlmService):
|
|||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.openai = OpenAI(base_url=base_url, api_key=api_key)
|
||||
|
||||
if base_url:
|
||||
self.openai = OpenAI(base_url=base_url, api_key=api_key)
|
||||
else:
|
||||
self.openai = OpenAI(api_key=api_key)
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
|
|
@ -102,7 +109,7 @@ class Processor(LlmService):
|
|||
|
||||
# Apart from rate limits, treat all exceptions as unrecoverable
|
||||
|
||||
print(f"Exception: {e}")
|
||||
print(f"Exception: {type(e)} {e}")
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -10,39 +10,42 @@ import pytesseract
|
|||
from pdf2image import convert_from_bytes
|
||||
|
||||
from ... schema import Document, TextDocument, Metadata
|
||||
from ... schema import document_ingest_queue, text_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
|
||||
module = "ocr"
|
||||
default_ident = "pdf-decoder"
|
||||
|
||||
default_input_queue = document_ingest_queue
|
||||
default_output_queue = text_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
id = params.get("id", default_ident)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Document,
|
||||
"output_schema": TextDocument,
|
||||
"id": id,
|
||||
}
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = Document,
|
||||
handler = self.on_message,
|
||||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "output",
|
||||
schema = TextDocument,
|
||||
)
|
||||
)
|
||||
|
||||
print("PDF OCR inited")
|
||||
|
||||
async def handle(self, msg):
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
|
||||
print("PDF message received")
|
||||
print("PDF message received", flush=True)
|
||||
|
||||
v = msg.value()
|
||||
|
||||
|
|
@ -65,19 +68,15 @@ class Processor(ConsumerProducer):
|
|||
text=text.encode("utf-8"),
|
||||
)
|
||||
|
||||
await self.send(r)
|
||||
await flow("output").send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(module, __doc__)
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue