mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 08:56:21 +02:00
Fix/kg core save (#267)
* Add a 'localhost' Pulsar endpoint for docker env * - Fix broken socket endpoint streamers - Add unused listener endpoints for publisher/subscriber - Fix graph embedding serialisation * Fix GE load * Remove Gossip settling delay, this is single-node Cassandra.
This commit is contained in:
parent
cd9a208432
commit
1280af3eff
11 changed files with 67 additions and 15 deletions
|
|
@ -5,7 +5,7 @@ import uuid
|
|||
from aiohttp import WSMsgType
|
||||
|
||||
from .. schema import Metadata
|
||||
from .. schema import GraphEmbeddings
|
||||
from .. schema import GraphEmbeddings, EntityEmbeddings
|
||||
from .. schema import graph_embeddings_store_queue
|
||||
|
||||
from . publisher import Publisher
|
||||
|
|
@ -50,8 +50,13 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
|||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
entity=to_value(data["entity"]),
|
||||
vectors=data["vectors"],
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=to_value(ent["entity"]),
|
||||
vectors=ent["vectors"],
|
||||
)
|
||||
for ent in data["entities"]
|
||||
]
|
||||
)
|
||||
|
||||
self.publisher.send(None, elt)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,16 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
schema=JsonSchema(GraphEmbeddings)
|
||||
)
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
worker = asyncio.create_task(
|
||||
self.async_thread(ws, running)
|
||||
)
|
||||
|
||||
await super(GraphEmbeddingsStreamEndpoint, self).listener(ws, running)
|
||||
|
||||
await worker
|
||||
|
||||
async def start(self):
|
||||
|
||||
self.subscriber.start()
|
||||
|
|
@ -44,6 +54,9 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_graph_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -7,12 +7,13 @@ import threading
|
|||
class Publisher:
|
||||
|
||||
def __init__(self, pulsar_host, topic, schema=None, max_size=10,
|
||||
chunking_enabled=True):
|
||||
chunking_enabled=True, listener=None):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.schema = schema
|
||||
self.q = queue.Queue(maxsize=max_size)
|
||||
self.chunking_enabled = chunking_enabled
|
||||
self.listener_name = listener
|
||||
|
||||
def start(self):
|
||||
self.task = threading.Thread(target=self.run)
|
||||
|
|
@ -25,7 +26,7 @@ class Publisher:
|
|||
try:
|
||||
|
||||
client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
self.pulsar_host, listener_name=self.listener_name
|
||||
)
|
||||
|
||||
producer = client.create_producer(
|
||||
|
|
|
|||
|
|
@ -60,7 +60,10 @@ class ServiceRequestor:
|
|||
while True:
|
||||
|
||||
try:
|
||||
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
|
||||
resp = await asyncio.to_thread(
|
||||
q.get,
|
||||
timeout=self.timeout
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Timeout")
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,12 @@ def serialize_graph_embeddings(message):
|
|||
"user": message.metadata.user,
|
||||
"collection": message.metadata.collection,
|
||||
},
|
||||
"vectors": message.vectors,
|
||||
"entity": serialize_value(message.entity),
|
||||
"entities": [
|
||||
{
|
||||
"vectors": entity.vectors,
|
||||
"entity": serialize_value(entity.entity),
|
||||
}
|
||||
for entity in message.entities
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import time
|
|||
class Subscriber:
|
||||
|
||||
def __init__(self, pulsar_host, topic, subscription, consumer_name,
|
||||
schema=None, max_size=100):
|
||||
schema=None, max_size=100, listener=None):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.subscription = subscription
|
||||
|
|
@ -17,6 +17,7 @@ class Subscriber:
|
|||
self.full = {}
|
||||
self.max_size = max_size
|
||||
self.lock = threading.Lock()
|
||||
self.listener_name = listener
|
||||
|
||||
def start(self):
|
||||
self.task = threading.Thread(target=self.run)
|
||||
|
|
@ -30,6 +31,7 @@ class Subscriber:
|
|||
|
||||
client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
listener_name=self.listener_name,
|
||||
)
|
||||
|
||||
consumer = client.subscribe(
|
||||
|
|
|
|||
|
|
@ -27,6 +27,16 @@ class TriplesStreamEndpoint(SocketEndpoint):
|
|||
schema=JsonSchema(Triples)
|
||||
)
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
worker = asyncio.create_task(
|
||||
self.async_thread(ws, running)
|
||||
)
|
||||
|
||||
await super(TriplesStreamEndpoint, self).listener(ws, running)
|
||||
|
||||
await worker
|
||||
|
||||
async def start(self):
|
||||
|
||||
self.subscriber.start()
|
||||
|
|
@ -42,6 +52,9 @@ class TriplesStreamEndpoint(SocketEndpoint):
|
|||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_triples(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue