Fix/async problem (#190)

* Back out previous change
* Switch to multithreading
* Remove aiopulsar dependency
This commit is contained in:
cybermaggedon 2024-12-03 18:03:00 +00:00 committed by GitHub
parent 7e78aa6d91
commit 26865a515c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 149 additions and 136 deletions

View file

@ -59,7 +59,6 @@ setuptools.setup(
"ibis", "ibis",
"jsonschema", "jsonschema",
"aiohttp", "aiohttp",
"aiopulsar-py",
"pinecone[grpc]", "pinecone[grpc]",
], ],
scripts=[ scripts=[

View file

@ -41,15 +41,10 @@ class ServiceEndpoint:
self.operation = "service" self.operation = "service"
async def start(self, client): async def start(self):
self.pub_task = asyncio.create_task(self.pub.run(client)) self.pub.start()
self.sub_task = asyncio.create_task(self.sub.run(client)) self.sub.start()
async def join(self):
await self.pub_task
await self.sub_task
def add_routes(self, app): def add_routes(self, app):
@ -87,20 +82,18 @@ class ServiceEndpoint:
print(data) print(data)
q = await self.sub.subscribe(id) q = self.sub.subscribe(id)
await self.pub.send( await asyncio.to_thread(
id, self.pub.send, id, self.to_request(data)
self.to_request(data),
) )
print("Request sent")
try: try:
resp = await asyncio.wait_for(q.get(), self.timeout) resp = await asyncio.to_thread(q.get, timeout=self.timeout)
except: except Exception as e:
raise RuntimeError("Timeout waiting for response") raise RuntimeError("Timeout")
print("Response got") print(resp)
if resp.error: if resp.error:
print("Error") print("Error")
@ -108,8 +101,6 @@ class ServiceEndpoint:
{ "error": resp.error.message } { "error": resp.error.message }
) )
print("Send response")
return web.json_response( return web.json_response(
self.from_response(resp) self.from_response(resp)
) )
@ -122,7 +113,7 @@ class ServiceEndpoint:
) )
finally: finally:
await self.sub.unsubscribe(id) self.sub.unsubscribe(id)
class MultiResponseServiceEndpoint(ServiceEndpoint): class MultiResponseServiceEndpoint(ServiceEndpoint):
@ -135,11 +126,10 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
data = await request.json() data = await request.json()
q = await self.sub.subscribe(id) q = self.sub.subscribe(id)
await self.pub.send( await asyncio.to_thread(
id, self.pub.send, id, self.to_request(data)
self.to_request(data),
) )
# Keeps looking at responses... # Keeps looking at responses...
@ -147,8 +137,8 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
while True: while True:
try: try:
resp = await asyncio.wait_for(q.get(), self.timeout) resp = await asyncio.to_thread(q.get, timeout=self.timeout)
except: except Exception as e:
raise RuntimeError("Timeout waiting for response") raise RuntimeError("Timeout waiting for response")
if resp.error: if resp.error:
@ -173,4 +163,4 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
) )
finally: finally:
await self.sub.unsubscribe(id) self.sub.unsubscribe(id)

View file

@ -29,11 +29,9 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
schema=JsonSchema(GraphEmbeddings) schema=JsonSchema(GraphEmbeddings)
) )
async def start(self, client): async def start(self):
self.task = asyncio.create_task( self.publisher.start()
self.publisher.run(client)
)
async def listener(self, ws, running): async def listener(self, ws, running):
@ -56,7 +54,7 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
vectors=data["vectors"], vectors=data["vectors"],
) )
await self.publisher.send(None, elt) self.publisher.send(None, elt)
running.stop() running.stop()

View file

@ -1,5 +1,6 @@
import asyncio import asyncio
import queue
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
import uuid import uuid
@ -28,31 +29,29 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
schema=JsonSchema(GraphEmbeddings) schema=JsonSchema(GraphEmbeddings)
) )
async def start(self, client): async def start(self):
self.task = asyncio.create_task( self.subscriber.start()
self.subscriber.run(client)
)
async def async_thread(self, ws, running): async def async_thread(self, ws, running):
id = str(uuid.uuid4()) id = str(uuid.uuid4())
q = await self.subscriber.subscribe_all(id) q = self.subscriber.subscribe_all(id)
while running.get(): while running.get():
try: try:
resp = await asyncio.wait_for(q.get(), 0.5) resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_graph_embeddings(resp)) await ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError: except queue.Empty:
continue continue
except Exception as e: except Exception as e:
print(f"Exception: {str(e)}", flush=True) print(f"Exception: {str(e)}", flush=True)
break break
await self.subscriber.unsubscribe_all(id) self.subscriber.unsubscribe_all(id)
running.stop() running.stop()

View file

@ -1,5 +1,8 @@
import asyncio import queue
import time
import pulsar
import threading
class Publisher: class Publisher:
@ -8,32 +11,43 @@ class Publisher:
self.pulsar_host = pulsar_host self.pulsar_host = pulsar_host
self.topic = topic self.topic = topic
self.schema = schema self.schema = schema
self.q = asyncio.Queue(maxsize=max_size) self.q = queue.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled self.chunking_enabled = chunking_enabled
async def run(self, client): def start(self):
self.task = threading.Thread(target=self.run)
self.task.start()
def run(self):
while True: while True:
try: try:
async with client.create_producer(
topic=self.topic,
schema=self.schema,
chunking_enabled=self.chunking_enabled,
) as producer:
while True:
id, item = await self.q.get()
if id: client = pulsar.Client(
await producer.send(item, { "id": id }) self.pulsar_host,
else: )
await producer.send(item)
producer = client.create_producer(
topic=self.topic,
schema=self.schema,
chunking_enabled=self.chunking_enabled,
)
while True:
id, item = self.q.get()
if id:
producer.send(item, { "id": id })
else:
producer.send(item)
except Exception as e: except Exception as e:
print("Exception:", e, flush=True) print("Exception:", e, flush=True)
# If handler drops out, sleep a retry # If handler drops out, sleep a retry
await asyncio.sleep(2) time.sleep(2)
async def send(self, id, msg): def send(self, id, msg):
await self.q.put((id, msg)) self.q.put((id, msg))

View file

@ -17,7 +17,6 @@ from aiohttp import web
import logging import logging
import os import os
import base64 import base64
import aiopulsar
import pulsar import pulsar
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
@ -167,7 +166,8 @@ class Api:
# content is valid base64 # content is valid base64
doc = base64.b64decode(data["data"]) doc = base64.b64decode(data["data"])
resp = await self.document_out.send( resp = await asyncio.to_thread(
self.document_out.send,
None, None,
Document( Document(
metadata=Metadata( metadata=Metadata(
@ -212,7 +212,8 @@ class Api:
# Text is base64 encoded # Text is base64 encoded
text = base64.b64decode(data["text"]).decode(charset) text = base64.b64decode(data["text"]).decode(charset)
resp = await self.text_out.send( resp = asyncio.to_thread(
self.text_out.send,
None, None,
TextDocument( TextDocument(
metadata=Metadata( metadata=Metadata(
@ -238,35 +239,13 @@ class Api:
{ "error": str(e) } { "error": str(e) }
) )
async def run_endpoints(self):
async with aiopulsar.connect(self.pulsar_host) as client:
for ep in self.endpoints:
await ep.start(client)
self.doc_ingest_pub_task = asyncio.create_task(
self.document_out.run(client)
)
self.text_ingest_pub_task = asyncio.create_task(
self.text_out.run(client)
)
print("Endpoints are running...")
# They never exit
for ep in self.endpoints:
await ep.join()
await self.doc_ingest_pub_task
await self.text_ingest_pub_task
print("Endpoints are stopped.")
async def app_factory(self): async def app_factory(self):
self.endpoint_task = asyncio.create_task(self.run_endpoints()) for ep in self.endpoints:
await ep.start()
self.document_out.start()
self.text_out.start()
return self.app return self.app

View file

@ -76,12 +76,6 @@ class SocketEndpoint:
async def start(self): async def start(self):
pass pass
async def join(self):
# Nothing to wait for
while True:
await asyncio.sleep(100)
def add_routes(self, app): def add_routes(self, app):
app.add_routes([ app.add_routes([

View file

@ -1,10 +1,13 @@
import asyncio import queue
import pulsar
import threading
import time
class Subscriber: class Subscriber:
def __init__(self, pulsar_host, topic, subscription, consumer_name, def __init__(self, pulsar_host, topic, subscription, consumer_name,
schema=None, max_size=10): schema=None, max_size=100):
self.pulsar_host = pulsar_host self.pulsar_host = pulsar_host
self.topic = topic self.topic = topic
self.subscription = subscription self.subscription = subscription
@ -12,55 +15,95 @@ class Subscriber:
self.schema = schema self.schema = schema
self.q = {} self.q = {}
self.full = {} self.full = {}
self.max_size = max_size
self.lock = threading.Lock()
def start(self):
self.task = threading.Thread(target=self.run)
self.task.start()
def run(self):
async def run(self, client):
while True: while True:
try: try:
async with client.subscribe(
client = pulsar.Client(
self.pulsar_host,
)
consumer = client.subscribe(
topic=self.topic, topic=self.topic,
subscription_name=self.subscription, subscription_name=self.subscription,
consumer_name=self.consumer_name, consumer_name=self.consumer_name,
schema=self.schema, schema=self.schema,
) as consumer: )
while True:
msg = await consumer.receive()
# Acknowledge successful reception of the message while True:
await consumer.acknowledge(msg)
try: msg = consumer.receive()
id = msg.properties()["id"]
except: # Acknowledge successful reception of the message
id = None consumer.acknowledge(msg)
try:
id = msg.properties()["id"]
except:
id = None
value = msg.value()
with self.lock:
value = msg.value()
if id in self.q: if id in self.q:
await self.q[id].put(value) try:
self.q[id].put(value, timeout=0.5)
except:
pass
for q in self.full.values(): for q in self.full.values():
await q.put(value) try:
q.put(value, timeout=0.5)
except:
pass
except Exception as e: except Exception as e:
print("Exception:", e, flush=True) print("Exception:", e, flush=True)
# If handler drops out, sleep a retry # If handler drops out, sleep a retry
await asyncio.sleep(2) time.sleep(2)
def subscribe(self, id):
with self.lock:
q = queue.Queue(maxsize=self.max_size)
self.q[id] = q
async def subscribe(self, id):
q = asyncio.Queue()
self.q[id] = q
return q return q
async def unsubscribe(self, id): def unsubscribe(self, id):
if id in self.q:
del self.q[id] with self.lock:
if id in self.q:
# self.q[id].shutdown(immediate=True)
del self.q[id]
def subscribe_all(self, id):
with self.lock:
q = queue.Queue(maxsize=self.max_size)
self.full[id] = q
async def subscribe_all(self, id):
q = asyncio.Queue()
self.full[id] = q
return q return q
async def unsubscribe_all(self, id): def unsubscribe_all(self, id):
if id in self.full:
del self.full[id] with self.lock:
if id in self.full:
# self.full[id].shutdown(immediate=True)
del self.full[id]

View file

@ -27,11 +27,9 @@ class TriplesLoadEndpoint(SocketEndpoint):
schema=JsonSchema(Triples) schema=JsonSchema(Triples)
) )
async def start(self, client): async def start(self):
self.task = asyncio.create_task( self.publisher.start()
self.publisher.run(client)
)
async def listener(self, ws, running): async def listener(self, ws, running):
@ -53,7 +51,7 @@ class TriplesLoadEndpoint(SocketEndpoint):
triples=to_subgraph(data["triples"]), triples=to_subgraph(data["triples"]),
) )
await self.publisher.send(None, elt) self.publisher.send(None, elt)
running.stop() running.stop()

View file

@ -1,5 +1,6 @@
import asyncio import asyncio
import queue
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
import uuid import uuid
@ -26,31 +27,29 @@ class TriplesStreamEndpoint(SocketEndpoint):
schema=JsonSchema(Triples) schema=JsonSchema(Triples)
) )
async def start(self, client): async def start(self):
self.task = asyncio.create_task( self.subscriber.start()
self.subscriber.run(client)
)
async def async_thread(self, ws, running): async def async_thread(self, ws, running):
id = str(uuid.uuid4()) id = str(uuid.uuid4())
q = await self.subscriber.subscribe_all(id) q = self.subscriber.subscribe_all(id)
while running.get(): while running.get():
try: try:
resp = await asyncio.wait_for(q.get(), 0.5) resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_triples(resp)) await ws.send_json(serialize_triples(resp))
except TimeoutError: except queue.Empty:
continue continue
except Exception as e: except Exception as e:
print(f"Exception: {str(e)}", flush=True) print(f"Exception: {str(e)}", flush=True)
break break
await self.subscriber.unsubscribe_all(id) self.subscriber.unsubscribe_all(id)
running.stop() running.stop()