mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix/kg core save (#267)
* Add a 'localhost' Pulsar endpoint for docker env
* - Fix broken socket endpoint streamers
  - Add unused listener endpoints for publisher/subscriber
  - Fix graph embedding serialisation
* Fix GE load
* Remove Gossip settling delay, this is single-node Cassandra.
This commit is contained in:
parent
cd9a208432
commit
1280af3eff
11 changed files with 67 additions and 15 deletions
|
|
@ -100,7 +100,7 @@ local url = import "values/url.jsonnet";
|
|||
"managedLedgerDefaultWriteQuorum": "1",
|
||||
"managedLedgerDefaultAckQuorum": "1",
|
||||
"advertisedAddress": "pulsar",
|
||||
"advertisedListeners": "external:pulsar://pulsar:6650",
|
||||
"advertisedListeners": "external:pulsar://pulsar:6650,localhost:pulsar://localhost:6650",
|
||||
"PULSAR_MEM": "-Xms512m -Xmx512m -XX:MaxDirectMemorySize=256m",
|
||||
})
|
||||
.with_port(6650, 6650, "pulsar")
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ local images = import "values/images.jsonnet";
|
|||
engine.container("cassandra")
|
||||
.with_image(images.cassandra)
|
||||
.with_environment({
|
||||
JVM_OPTS: "-Xms300M -Xmx300M",
|
||||
JVM_OPTS: "-Xms300M -Xmx300M -Dcassandra.skip_wait_for_gossip_to_settle=0",
|
||||
})
|
||||
.with_limits("1.0", "1000M")
|
||||
.with_reservations("0.5", "1000M")
|
||||
|
|
|
|||
|
|
@ -51,8 +51,13 @@ async def load_ge(running, queue, url):
|
|||
"user": msg["m"]["u"],
|
||||
"collection": msg["m"]["c"],
|
||||
},
|
||||
"vectors": msg["v"],
|
||||
"entity": msg["e"],
|
||||
"entities": [
|
||||
{
|
||||
"entity": ent["e"],
|
||||
"vectors": ent["v"],
|
||||
}
|
||||
for ent in msg["e"]
|
||||
],
|
||||
}
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -57,8 +57,13 @@ async def fetch_ge(running, queue, user, collection, url):
|
|||
"u": data["metadata"]["user"],
|
||||
"c": data["metadata"]["collection"],
|
||||
},
|
||||
"v": data["vectors"],
|
||||
"e": data["entity"],
|
||||
"e": [
|
||||
{
|
||||
"e": ent["entity"],
|
||||
"v": ent["vectors"],
|
||||
}
|
||||
for ent in data["entities"]
|
||||
]
|
||||
}
|
||||
])
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import uuid
|
|||
from aiohttp import WSMsgType
|
||||
|
||||
from .. schema import Metadata
|
||||
from .. schema import GraphEmbeddings
|
||||
from .. schema import GraphEmbeddings, EntityEmbeddings
|
||||
from .. schema import graph_embeddings_store_queue
|
||||
|
||||
from . publisher import Publisher
|
||||
|
|
@ -50,8 +50,13 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
|||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
entity=to_value(data["entity"]),
|
||||
vectors=data["vectors"],
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=to_value(ent["entity"]),
|
||||
vectors=ent["vectors"],
|
||||
)
|
||||
for ent in data["entities"]
|
||||
]
|
||||
)
|
||||
|
||||
self.publisher.send(None, elt)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,16 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
schema=JsonSchema(GraphEmbeddings)
|
||||
)
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
worker = asyncio.create_task(
|
||||
self.async_thread(ws, running)
|
||||
)
|
||||
|
||||
await super(GraphEmbeddingsStreamEndpoint, self).listener(ws, running)
|
||||
|
||||
await worker
|
||||
|
||||
async def start(self):
|
||||
|
||||
self.subscriber.start()
|
||||
|
|
@ -44,6 +54,9 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_graph_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -7,12 +7,13 @@ import threading
|
|||
class Publisher:
|
||||
|
||||
def __init__(self, pulsar_host, topic, schema=None, max_size=10,
|
||||
chunking_enabled=True):
|
||||
chunking_enabled=True, listener=None):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.schema = schema
|
||||
self.q = queue.Queue(maxsize=max_size)
|
||||
self.chunking_enabled = chunking_enabled
|
||||
self.listener_name = listener
|
||||
|
||||
def start(self):
|
||||
self.task = threading.Thread(target=self.run)
|
||||
|
|
@ -25,7 +26,7 @@ class Publisher:
|
|||
try:
|
||||
|
||||
client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
self.pulsar_host, listener_name=self.listener_name
|
||||
)
|
||||
|
||||
producer = client.create_producer(
|
||||
|
|
|
|||
|
|
@ -60,7 +60,10 @@ class ServiceRequestor:
|
|||
while True:
|
||||
|
||||
try:
|
||||
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
|
||||
resp = await asyncio.to_thread(
|
||||
q.get,
|
||||
timeout=self.timeout
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Timeout")
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,12 @@ def serialize_graph_embeddings(message):
|
|||
"user": message.metadata.user,
|
||||
"collection": message.metadata.collection,
|
||||
},
|
||||
"vectors": message.vectors,
|
||||
"entity": serialize_value(message.entity),
|
||||
"entities": [
|
||||
{
|
||||
"vectors": entity.vectors,
|
||||
"entity": serialize_value(entity.entity),
|
||||
}
|
||||
for entity in message.entities
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import time
|
|||
class Subscriber:
|
||||
|
||||
def __init__(self, pulsar_host, topic, subscription, consumer_name,
|
||||
schema=None, max_size=100):
|
||||
schema=None, max_size=100, listener=None):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.subscription = subscription
|
||||
|
|
@ -17,6 +17,7 @@ class Subscriber:
|
|||
self.full = {}
|
||||
self.max_size = max_size
|
||||
self.lock = threading.Lock()
|
||||
self.listener_name = listener
|
||||
|
||||
def start(self):
|
||||
self.task = threading.Thread(target=self.run)
|
||||
|
|
@ -30,6 +31,7 @@ class Subscriber:
|
|||
|
||||
client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
listener_name=self.listener_name,
|
||||
)
|
||||
|
||||
consumer = client.subscribe(
|
||||
|
|
|
|||
|
|
@ -27,6 +27,16 @@ class TriplesStreamEndpoint(SocketEndpoint):
|
|||
schema=JsonSchema(Triples)
|
||||
)
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
worker = asyncio.create_task(
|
||||
self.async_thread(ws, running)
|
||||
)
|
||||
|
||||
await super(TriplesStreamEndpoint, self).listener(ws, running)
|
||||
|
||||
await worker
|
||||
|
||||
async def start(self):
|
||||
|
||||
self.subscriber.start()
|
||||
|
|
@ -42,6 +52,9 @@ class TriplesStreamEndpoint(SocketEndpoint):
|
|||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_triples(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue