Fix/kg core save (#267)

* Add a 'localhost' Pulsar endpoint for docker env

* - Fix broken socket endpoint streamers
- Add unused listener endpoints for publisher/subscriber
- Fix graph embedding serialisation

* Fix GE load

* Remove Gossip settling delay; this is single-node Cassandra.
This commit is contained in:
cybermaggedon 2025-01-13 14:42:33 +00:00 committed by GitHub
parent cd9a208432
commit 1280af3eff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 67 additions and 15 deletions

View file

@@ -100,7 +100,7 @@ local url = import "values/url.jsonnet";
"managedLedgerDefaultWriteQuorum": "1", "managedLedgerDefaultWriteQuorum": "1",
"managedLedgerDefaultAckQuorum": "1", "managedLedgerDefaultAckQuorum": "1",
"advertisedAddress": "pulsar", "advertisedAddress": "pulsar",
"advertisedListeners": "external:pulsar://pulsar:6650", "advertisedListeners": "external:pulsar://pulsar:6650,localhost:pulsar://localhost:6650",
"PULSAR_MEM": "-Xms512m -Xmx512m -XX:MaxDirectMemorySize=256m", "PULSAR_MEM": "-Xms512m -Xmx512m -XX:MaxDirectMemorySize=256m",
}) })
.with_port(6650, 6650, "pulsar") .with_port(6650, 6650, "pulsar")

View file

@@ -13,7 +13,7 @@ local images = import "values/images.jsonnet";
engine.container("cassandra") engine.container("cassandra")
.with_image(images.cassandra) .with_image(images.cassandra)
.with_environment({ .with_environment({
JVM_OPTS: "-Xms300M -Xmx300M", JVM_OPTS: "-Xms300M -Xmx300M -Dcassandra.skip_wait_for_gossip_to_settle=0",
}) })
.with_limits("1.0", "1000M") .with_limits("1.0", "1000M")
.with_reservations("0.5", "1000M") .with_reservations("0.5", "1000M")

View file

@@ -51,8 +51,13 @@ async def load_ge(running, queue, url):
"user": msg["m"]["u"], "user": msg["m"]["u"],
"collection": msg["m"]["c"], "collection": msg["m"]["c"],
}, },
"vectors": msg["v"], "entities": [
"entity": msg["e"], {
"entity": ent["e"],
"vectors": ent["v"],
}
for ent in msg["e"]
],
} }
try: try:

View file

@@ -57,8 +57,13 @@ async def fetch_ge(running, queue, user, collection, url):
"u": data["metadata"]["user"], "u": data["metadata"]["user"],
"c": data["metadata"]["collection"], "c": data["metadata"]["collection"],
}, },
"v": data["vectors"], "e": [
"e": data["entity"], {
"e": ent["entity"],
"v": ent["vectors"],
}
for ent in data["entities"]
]
} }
]) ])
if msg.type == aiohttp.WSMsgType.ERROR: if msg.type == aiohttp.WSMsgType.ERROR:

View file

@@ -5,7 +5,7 @@ import uuid
from aiohttp import WSMsgType from aiohttp import WSMsgType
from .. schema import Metadata from .. schema import Metadata
from .. schema import GraphEmbeddings from .. schema import GraphEmbeddings, EntityEmbeddings
from .. schema import graph_embeddings_store_queue from .. schema import graph_embeddings_store_queue
from . publisher import Publisher from . publisher import Publisher
@@ -50,8 +50,13 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
user=data["metadata"]["user"], user=data["metadata"]["user"],
collection=data["metadata"]["collection"], collection=data["metadata"]["collection"],
), ),
entity=to_value(data["entity"]), entities=[
vectors=data["vectors"], EntityEmbeddings(
entity=to_value(ent["entity"]),
vectors=ent["vectors"],
)
for ent in data["entities"]
]
) )
self.publisher.send(None, elt) self.publisher.send(None, elt)

View file

@@ -29,6 +29,16 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
schema=JsonSchema(GraphEmbeddings) schema=JsonSchema(GraphEmbeddings)
) )
async def listener(self, ws, running):
worker = asyncio.create_task(
self.async_thread(ws, running)
)
await super(GraphEmbeddingsStreamEndpoint, self).listener(ws, running)
await worker
async def start(self): async def start(self):
self.subscriber.start() self.subscriber.start()
@@ -44,6 +54,9 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
resp = await asyncio.to_thread(q.get, timeout=0.5) resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_graph_embeddings(resp)) await ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError:
continue
except queue.Empty: except queue.Empty:
continue continue

View file

@@ -7,12 +7,13 @@ import threading
class Publisher: class Publisher:
def __init__(self, pulsar_host, topic, schema=None, max_size=10, def __init__(self, pulsar_host, topic, schema=None, max_size=10,
chunking_enabled=True): chunking_enabled=True, listener=None):
self.pulsar_host = pulsar_host self.pulsar_host = pulsar_host
self.topic = topic self.topic = topic
self.schema = schema self.schema = schema
self.q = queue.Queue(maxsize=max_size) self.q = queue.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled self.chunking_enabled = chunking_enabled
self.listener_name = listener
def start(self): def start(self):
self.task = threading.Thread(target=self.run) self.task = threading.Thread(target=self.run)
@@ -25,7 +26,7 @@ class Publisher:
try: try:
client = pulsar.Client( client = pulsar.Client(
self.pulsar_host, self.pulsar_host, listener_name=self.listener_name
) )
producer = client.create_producer( producer = client.create_producer(

View file

@@ -60,7 +60,10 @@ class ServiceRequestor:
while True: while True:
try: try:
resp = await asyncio.to_thread(q.get, timeout=self.timeout) resp = await asyncio.to_thread(
q.get,
timeout=self.timeout
)
except Exception as e: except Exception as e:
raise RuntimeError("Timeout") raise RuntimeError("Timeout")

View file

@@ -51,7 +51,12 @@ def serialize_graph_embeddings(message):
"user": message.metadata.user, "user": message.metadata.user,
"collection": message.metadata.collection, "collection": message.metadata.collection,
}, },
"vectors": message.vectors, "entities": [
"entity": serialize_value(message.entity), {
"vectors": entity.vectors,
"entity": serialize_value(entity.entity),
}
for entity in message.entities
],
} }

View file

@@ -7,7 +7,7 @@ import time
class Subscriber: class Subscriber:
def __init__(self, pulsar_host, topic, subscription, consumer_name, def __init__(self, pulsar_host, topic, subscription, consumer_name,
schema=None, max_size=100): schema=None, max_size=100, listener=None):
self.pulsar_host = pulsar_host self.pulsar_host = pulsar_host
self.topic = topic self.topic = topic
self.subscription = subscription self.subscription = subscription
@@ -17,6 +17,7 @@ class Subscriber:
self.full = {} self.full = {}
self.max_size = max_size self.max_size = max_size
self.lock = threading.Lock() self.lock = threading.Lock()
self.listener_name = listener
def start(self): def start(self):
self.task = threading.Thread(target=self.run) self.task = threading.Thread(target=self.run)
@@ -30,6 +31,7 @@ class Subscriber:
client = pulsar.Client( client = pulsar.Client(
self.pulsar_host, self.pulsar_host,
listener_name=self.listener_name,
) )
consumer = client.subscribe( consumer = client.subscribe(

View file

@@ -27,6 +27,16 @@ class TriplesStreamEndpoint(SocketEndpoint):
schema=JsonSchema(Triples) schema=JsonSchema(Triples)
) )
async def listener(self, ws, running):
worker = asyncio.create_task(
self.async_thread(ws, running)
)
await super(TriplesStreamEndpoint, self).listener(ws, running)
await worker
async def start(self): async def start(self):
self.subscriber.start() self.subscriber.start()
@@ -42,6 +52,9 @@ class TriplesStreamEndpoint(SocketEndpoint):
resp = await asyncio.to_thread(q.get, timeout=0.5) resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_triples(resp)) await ws.send_json(serialize_triples(resp))
except TimeoutError:
continue
except queue.Empty: except queue.Empty:
continue continue