Fix/kg core save (#267)

* Add a 'localhost' Pulsar endpoint for docker env

* Fix broken socket endpoint streamers
* Add an optional, currently unused `listener` (Pulsar listener name) parameter to Publisher/Subscriber
* Fix graph embedding serialisation

* Fix graph embeddings (GE) load

* Remove the gossip settling delay; this is a single-node Cassandra.
This commit is contained in:
cybermaggedon 2025-01-13 14:42:33 +00:00 committed by GitHub
parent cd9a208432
commit 1280af3eff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 67 additions and 15 deletions

View file

@ -100,7 +100,7 @@ local url = import "values/url.jsonnet";
"managedLedgerDefaultWriteQuorum": "1",
"managedLedgerDefaultAckQuorum": "1",
"advertisedAddress": "pulsar",
"advertisedListeners": "external:pulsar://pulsar:6650",
"advertisedListeners": "external:pulsar://pulsar:6650,localhost:pulsar://localhost:6650",
"PULSAR_MEM": "-Xms512m -Xmx512m -XX:MaxDirectMemorySize=256m",
})
.with_port(6650, 6650, "pulsar")

View file

@ -13,7 +13,7 @@ local images = import "values/images.jsonnet";
engine.container("cassandra")
.with_image(images.cassandra)
.with_environment({
JVM_OPTS: "-Xms300M -Xmx300M",
JVM_OPTS: "-Xms300M -Xmx300M -Dcassandra.skip_wait_for_gossip_to_settle=0",
})
.with_limits("1.0", "1000M")
.with_reservations("0.5", "1000M")

View file

@ -51,8 +51,13 @@ async def load_ge(running, queue, url):
"user": msg["m"]["u"],
"collection": msg["m"]["c"],
},
"vectors": msg["v"],
"entity": msg["e"],
"entities": [
{
"entity": ent["e"],
"vectors": ent["v"],
}
for ent in msg["e"]
],
}
try:

View file

@ -57,8 +57,13 @@ async def fetch_ge(running, queue, user, collection, url):
"u": data["metadata"]["user"],
"c": data["metadata"]["collection"],
},
"v": data["vectors"],
"e": data["entity"],
"e": [
{
"e": ent["entity"],
"v": ent["vectors"],
}
for ent in data["entities"]
]
}
])
if msg.type == aiohttp.WSMsgType.ERROR:

View file

@ -5,7 +5,7 @@ import uuid
from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import GraphEmbeddings
from .. schema import GraphEmbeddings, EntityEmbeddings
from .. schema import graph_embeddings_store_queue
from . publisher import Publisher
@ -50,8 +50,13 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
entity=to_value(data["entity"]),
vectors=data["vectors"],
entities=[
EntityEmbeddings(
entity=to_value(ent["entity"]),
vectors=ent["vectors"],
)
for ent in data["entities"]
]
)
self.publisher.send(None, elt)

View file

@ -29,6 +29,16 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
schema=JsonSchema(GraphEmbeddings)
)
# Websocket listener for the graph-embeddings stream endpoint.
# Launches self.async_thread(ws, running) as a concurrent asyncio task, runs
# the parent-class listener, and returns only once the task has completed.
async def listener(self, ws, running):
# Schedule the streaming coroutine so it runs alongside the parent listener.
worker = asyncio.create_task(
self.async_thread(ws, running)
)
# Delegate to the parent class's listener (its behaviour is defined
# outside this view).
await super(GraphEmbeddingsStreamEndpoint, self).listener(ws, running)
# Wait for the streaming task to finish; any exception it raised
# propagates from this await.
await worker
async def start(self):
self.subscriber.start()
@ -44,6 +54,9 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError:
continue
except queue.Empty:
continue

View file

@ -7,12 +7,13 @@ import threading
class Publisher:
def __init__(self, pulsar_host, topic, schema=None, max_size=10,
chunking_enabled=True):
chunking_enabled=True, listener=None):
self.pulsar_host = pulsar_host
self.topic = topic
self.schema = schema
self.q = queue.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled
self.listener_name = listener
def start(self):
self.task = threading.Thread(target=self.run)
@ -25,7 +26,7 @@ class Publisher:
try:
client = pulsar.Client(
self.pulsar_host,
self.pulsar_host, listener_name=self.listener_name
)
producer = client.create_producer(

View file

@ -60,7 +60,10 @@ class ServiceRequestor:
while True:
try:
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
resp = await asyncio.to_thread(
q.get,
timeout=self.timeout
)
except Exception as e:
raise RuntimeError("Timeout")

View file

@ -51,7 +51,12 @@ def serialize_graph_embeddings(message):
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"vectors": message.vectors,
"entity": serialize_value(message.entity),
"entities": [
{
"vectors": entity.vectors,
"entity": serialize_value(entity.entity),
}
for entity in message.entities
],
}

View file

@ -7,7 +7,7 @@ import time
class Subscriber:
def __init__(self, pulsar_host, topic, subscription, consumer_name,
schema=None, max_size=100):
schema=None, max_size=100, listener=None):
self.pulsar_host = pulsar_host
self.topic = topic
self.subscription = subscription
@ -17,6 +17,7 @@ class Subscriber:
self.full = {}
self.max_size = max_size
self.lock = threading.Lock()
self.listener_name = listener
def start(self):
self.task = threading.Thread(target=self.run)
@ -30,6 +31,7 @@ class Subscriber:
client = pulsar.Client(
self.pulsar_host,
listener_name=self.listener_name,
)
consumer = client.subscribe(

View file

@ -27,6 +27,16 @@ class TriplesStreamEndpoint(SocketEndpoint):
schema=JsonSchema(Triples)
)
# Websocket listener for the triples stream endpoint.
# Launches self.async_thread(ws, running) as a concurrent asyncio task, runs
# the parent-class listener, and returns only once the task has completed.
async def listener(self, ws, running):
# Schedule the streaming coroutine so it runs alongside the parent listener.
worker = asyncio.create_task(
self.async_thread(ws, running)
)
# Delegate to the parent class's listener (its behaviour is defined
# outside this view).
await super(TriplesStreamEndpoint, self).listener(ws, running)
# Wait for the streaming task to finish; any exception it raised
# propagates from this await.
await worker
async def start(self):
self.subscriber.start()
@ -42,6 +52,9 @@ class TriplesStreamEndpoint(SocketEndpoint):
resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_triples(resp))
except TimeoutError:
continue
except queue.Empty:
continue