Fix/async problem (#190)

* Back out previous change
* To multithreads
* Remove aiopulsar dependency
This commit is contained in:
cybermaggedon 2024-12-03 18:03:00 +00:00 committed by GitHub
parent 7e78aa6d91
commit 26865a515c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 149 additions and 136 deletions

View file

@@ -59,7 +59,6 @@ setuptools.setup(
"ibis",
"jsonschema",
"aiohttp",
"aiopulsar-py",
"pinecone[grpc]",
],
scripts=[

View file

@@ -41,15 +41,10 @@ class ServiceEndpoint:
self.operation = "service"
async def start(self, client):
async def start(self):
self.pub_task = asyncio.create_task(self.pub.run(client))
self.sub_task = asyncio.create_task(self.sub.run(client))
async def join(self):
await self.pub_task
await self.sub_task
self.pub.start()
self.sub.start()
def add_routes(self, app):
@@ -87,20 +82,18 @@ class ServiceEndpoint:
print(data)
q = await self.sub.subscribe(id)
q = self.sub.subscribe(id)
await self.pub.send(
id,
self.to_request(data),
await asyncio.to_thread(
self.pub.send, id, self.to_request(data)
)
print("Request sent")
try:
resp = await asyncio.wait_for(q.get(), self.timeout)
except:
raise RuntimeError("Timeout waiting for response")
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
except Exception as e:
raise RuntimeError("Timeout")
print("Response got")
print(resp)
if resp.error:
print("Error")
@@ -108,8 +101,6 @@ class ServiceEndpoint:
{ "error": resp.error.message }
)
print("Send response")
return web.json_response(
self.from_response(resp)
)
@@ -122,7 +113,7 @@ class ServiceEndpoint:
)
finally:
await self.sub.unsubscribe(id)
self.sub.unsubscribe(id)
class MultiResponseServiceEndpoint(ServiceEndpoint):
@@ -135,11 +126,10 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
data = await request.json()
q = await self.sub.subscribe(id)
q = self.sub.subscribe(id)
await self.pub.send(
id,
self.to_request(data),
await asyncio.to_thread(
self.pub.send, id, self.to_request(data)
)
# Keeps looking at responses...
@@ -147,8 +137,8 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
while True:
try:
resp = await asyncio.wait_for(q.get(), self.timeout)
except:
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
except Exception as e:
raise RuntimeError("Timeout waiting for response")
if resp.error:
@@ -173,4 +163,4 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
)
finally:
await self.sub.unsubscribe(id)
self.sub.unsubscribe(id)

View file

@@ -29,11 +29,9 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
schema=JsonSchema(GraphEmbeddings)
)
async def start(self, client):
async def start(self):
self.task = asyncio.create_task(
self.publisher.run(client)
)
self.publisher.start()
async def listener(self, ws, running):
@@ -56,7 +54,7 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
vectors=data["vectors"],
)
await self.publisher.send(None, elt)
self.publisher.send(None, elt)
running.stop()

View file

@@ -1,5 +1,6 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
@@ -28,31 +29,29 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
schema=JsonSchema(GraphEmbeddings)
)
async def start(self, client):
async def start(self):
self.task = asyncio.create_task(
self.subscriber.run(client)
)
self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
q = await self.subscriber.subscribe_all(id)
q = self.subscriber.subscribe_all(id)
while running.get():
try:
resp = await asyncio.wait_for(q.get(), 0.5)
resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError:
except queue.Empty:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
break
await self.subscriber.unsubscribe_all(id)
self.subscriber.unsubscribe_all(id)
running.stop()

View file

@@ -1,5 +1,8 @@
import asyncio
import queue
import time
import pulsar
import threading
class Publisher:
@@ -8,32 +11,43 @@ class Publisher:
self.pulsar_host = pulsar_host
self.topic = topic
self.schema = schema
self.q = asyncio.Queue(maxsize=max_size)
self.q = queue.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled
async def run(self, client):
def start(self):
self.task = threading.Thread(target=self.run)
self.task.start()
def run(self):
while True:
try:
async with client.create_producer(
topic=self.topic,
schema=self.schema,
chunking_enabled=self.chunking_enabled,
) as producer:
while True:
id, item = await self.q.get()
if id:
await producer.send(item, { "id": id })
else:
await producer.send(item)
client = pulsar.Client(
self.pulsar_host,
)
producer = client.create_producer(
topic=self.topic,
schema=self.schema,
chunking_enabled=self.chunking_enabled,
)
while True:
id, item = self.q.get()
if id:
producer.send(item, { "id": id })
else:
producer.send(item)
except Exception as e:
print("Exception:", e, flush=True)
# If handler drops out, sleep a retry
await asyncio.sleep(2)
time.sleep(2)
async def send(self, id, msg):
await self.q.put((id, msg))
def send(self, id, msg):
self.q.put((id, msg))

View file

@@ -17,7 +17,6 @@ from aiohttp import web
import logging
import os
import base64
import aiopulsar
import pulsar
from pulsar.schema import JsonSchema
@@ -167,7 +166,8 @@ class Api:
# content is valid base64
doc = base64.b64decode(data["data"])
resp = await self.document_out.send(
resp = await asyncio.to_thread(
self.document_out.send,
None,
Document(
metadata=Metadata(
@@ -212,7 +212,8 @@ class Api:
# Text is base64 encoded
text = base64.b64decode(data["text"]).decode(charset)
resp = await self.text_out.send(
resp = asyncio.to_thread(
self.text_out.send,
None,
TextDocument(
metadata=Metadata(
@@ -238,35 +239,13 @@ class Api:
{ "error": str(e) }
)
async def run_endpoints(self):
async with aiopulsar.connect(self.pulsar_host) as client:
for ep in self.endpoints:
await ep.start(client)
self.doc_ingest_pub_task = asyncio.create_task(
self.document_out.run(client)
)
self.text_ingest_pub_task = asyncio.create_task(
self.text_out.run(client)
)
print("Endpoints are running...")
# They never exit
for ep in self.endpoints:
await ep.join()
await self.doc_ingest_pub_task
await self.text_ingest_pub_task
print("Endpoints are stopped.")
async def app_factory(self):
self.endpoint_task = asyncio.create_task(self.run_endpoints())
for ep in self.endpoints:
await ep.start()
self.document_out.start()
self.text_out.start()
return self.app

View file

@@ -76,12 +76,6 @@ class SocketEndpoint:
async def start(self):
pass
async def join(self):
# Nothing to wait for
while True:
await asyncio.sleep(100)
def add_routes(self, app):
app.add_routes([

View file

@@ -1,10 +1,13 @@
import asyncio
import queue
import pulsar
import threading
import time
class Subscriber:
def __init__(self, pulsar_host, topic, subscription, consumer_name,
schema=None, max_size=10):
schema=None, max_size=100):
self.pulsar_host = pulsar_host
self.topic = topic
self.subscription = subscription
@@ -12,55 +15,95 @@ class Subscriber:
self.schema = schema
self.q = {}
self.full = {}
self.max_size = max_size
self.lock = threading.Lock()
def start(self):
self.task = threading.Thread(target=self.run)
self.task.start()
def run(self):
async def run(self, client):
while True:
try:
async with client.subscribe(
client = pulsar.Client(
self.pulsar_host,
)
consumer = client.subscribe(
topic=self.topic,
subscription_name=self.subscription,
consumer_name=self.consumer_name,
schema=self.schema,
) as consumer:
while True:
msg = await consumer.receive()
)
# Acknowledge successful reception of the message
await consumer.acknowledge(msg)
while True:
try:
id = msg.properties()["id"]
except:
id = None
msg = consumer.receive()
# Acknowledge successful reception of the message
consumer.acknowledge(msg)
try:
id = msg.properties()["id"]
except:
id = None
value = msg.value()
with self.lock:
value = msg.value()
if id in self.q:
await self.q[id].put(value)
try:
self.q[id].put(value, timeout=0.5)
except:
pass
for q in self.full.values():
await q.put(value)
try:
q.put(value, timeout=0.5)
except:
pass
except Exception as e:
print("Exception:", e, flush=True)
# If handler drops out, sleep a retry
await asyncio.sleep(2)
time.sleep(2)
def subscribe(self, id):
with self.lock:
q = queue.Queue(maxsize=self.max_size)
self.q[id] = q
async def subscribe(self, id):
q = asyncio.Queue()
self.q[id] = q
return q
async def unsubscribe(self, id):
if id in self.q:
del self.q[id]
def unsubscribe(self, id):
with self.lock:
if id in self.q:
# self.q[id].shutdown(immediate=True)
del self.q[id]
async def subscribe_all(self, id):
q = asyncio.Queue()
self.full[id] = q
def subscribe_all(self, id):
with self.lock:
q = queue.Queue(maxsize=self.max_size)
self.full[id] = q
return q
async def unsubscribe_all(self, id):
if id in self.full:
del self.full[id]
def unsubscribe_all(self, id):
with self.lock:
if id in self.full:
# self.full[id].shutdown(immediate=True)
del self.full[id]

View file

@@ -27,11 +27,9 @@ class TriplesLoadEndpoint(SocketEndpoint):
schema=JsonSchema(Triples)
)
async def start(self, client):
async def start(self):
self.task = asyncio.create_task(
self.publisher.run(client)
)
self.publisher.start()
async def listener(self, ws, running):
@@ -53,7 +51,7 @@ class TriplesLoadEndpoint(SocketEndpoint):
triples=to_subgraph(data["triples"]),
)
await self.publisher.send(None, elt)
self.publisher.send(None, elt)
running.stop()

View file

@@ -1,5 +1,6 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
@@ -26,31 +27,29 @@ class TriplesStreamEndpoint(SocketEndpoint):
schema=JsonSchema(Triples)
)
async def start(self, client):
async def start(self):
self.task = asyncio.create_task(
self.subscriber.run(client)
)
self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
q = await self.subscriber.subscribe_all(id)
q = self.subscriber.subscribe_all(id)
while running.get():
try:
resp = await asyncio.wait_for(q.get(), 0.5)
resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_triples(resp))
except TimeoutError:
except queue.Empty:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
break
await self.subscriber.unsubscribe_all(id)
self.subscriber.unsubscribe_all(id)
running.stop()