mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-30 02:46:23 +02:00
Fix/async problem (#190)
* Back out previous change * To multithreads * Remove aiopulsar dependency
This commit is contained in:
parent
7e78aa6d91
commit
26865a515c
10 changed files with 149 additions and 136 deletions
|
|
@ -59,7 +59,6 @@ setuptools.setup(
|
|||
"ibis",
|
||||
"jsonschema",
|
||||
"aiohttp",
|
||||
"aiopulsar-py",
|
||||
"pinecone[grpc]",
|
||||
],
|
||||
scripts=[
|
||||
|
|
|
|||
|
|
@ -41,15 +41,10 @@ class ServiceEndpoint:
|
|||
|
||||
self.operation = "service"
|
||||
|
||||
async def start(self, client):
|
||||
async def start(self):
|
||||
|
||||
self.pub_task = asyncio.create_task(self.pub.run(client))
|
||||
self.sub_task = asyncio.create_task(self.sub.run(client))
|
||||
|
||||
async def join(self):
|
||||
|
||||
await self.pub_task
|
||||
await self.sub_task
|
||||
self.pub.start()
|
||||
self.sub.start()
|
||||
|
||||
def add_routes(self, app):
|
||||
|
||||
|
|
@ -87,20 +82,18 @@ class ServiceEndpoint:
|
|||
|
||||
print(data)
|
||||
|
||||
q = await self.sub.subscribe(id)
|
||||
q = self.sub.subscribe(id)
|
||||
|
||||
await self.pub.send(
|
||||
id,
|
||||
self.to_request(data),
|
||||
await asyncio.to_thread(
|
||||
self.pub.send, id, self.to_request(data)
|
||||
)
|
||||
print("Request sent")
|
||||
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), self.timeout)
|
||||
except:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Timeout")
|
||||
|
||||
print("Response got")
|
||||
print(resp)
|
||||
|
||||
if resp.error:
|
||||
print("Error")
|
||||
|
|
@ -108,8 +101,6 @@ class ServiceEndpoint:
|
|||
{ "error": resp.error.message }
|
||||
)
|
||||
|
||||
print("Send response")
|
||||
|
||||
return web.json_response(
|
||||
self.from_response(resp)
|
||||
)
|
||||
|
|
@ -122,7 +113,7 @@ class ServiceEndpoint:
|
|||
)
|
||||
|
||||
finally:
|
||||
await self.sub.unsubscribe(id)
|
||||
self.sub.unsubscribe(id)
|
||||
|
||||
|
||||
class MultiResponseServiceEndpoint(ServiceEndpoint):
|
||||
|
|
@ -135,11 +126,10 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
|
|||
|
||||
data = await request.json()
|
||||
|
||||
q = await self.sub.subscribe(id)
|
||||
q = self.sub.subscribe(id)
|
||||
|
||||
await self.pub.send(
|
||||
id,
|
||||
self.to_request(data),
|
||||
await asyncio.to_thread(
|
||||
self.pub.send, id, self.to_request(data)
|
||||
)
|
||||
|
||||
# Keeps looking at responses...
|
||||
|
|
@ -147,8 +137,8 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
|
|||
while True:
|
||||
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), self.timeout)
|
||||
except:
|
||||
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
|
||||
if resp.error:
|
||||
|
|
@ -173,4 +163,4 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
|
|||
)
|
||||
|
||||
finally:
|
||||
await self.sub.unsubscribe(id)
|
||||
self.sub.unsubscribe(id)
|
||||
|
|
|
|||
|
|
@ -29,11 +29,9 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
|||
schema=JsonSchema(GraphEmbeddings)
|
||||
)
|
||||
|
||||
async def start(self, client):
|
||||
async def start(self):
|
||||
|
||||
self.task = asyncio.create_task(
|
||||
self.publisher.run(client)
|
||||
)
|
||||
self.publisher.start()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
|
|
@ -56,7 +54,7 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
|||
vectors=data["vectors"],
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
self.publisher.send(None, elt)
|
||||
|
||||
|
||||
running.stop()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
|
||||
|
|
@ -28,31 +29,29 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
|||
schema=JsonSchema(GraphEmbeddings)
|
||||
)
|
||||
|
||||
async def start(self, client):
|
||||
async def start(self):
|
||||
|
||||
self.task = asyncio.create_task(
|
||||
self.subscriber.run(client)
|
||||
)
|
||||
self.subscriber.start()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
q = await self.subscriber.subscribe_all(id)
|
||||
q = self.subscriber.subscribe_all(id)
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), 0.5)
|
||||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_graph_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
await self.subscriber.unsubscribe_all(id)
|
||||
self.subscriber.unsubscribe_all(id)
|
||||
|
||||
running.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import time
|
||||
import pulsar
|
||||
import threading
|
||||
|
||||
class Publisher:
|
||||
|
||||
|
|
@ -8,32 +11,43 @@ class Publisher:
|
|||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.schema = schema
|
||||
self.q = asyncio.Queue(maxsize=max_size)
|
||||
self.q = queue.Queue(maxsize=max_size)
|
||||
self.chunking_enabled = chunking_enabled
|
||||
|
||||
async def run(self, client):
|
||||
def start(self):
|
||||
self.task = threading.Thread(target=self.run)
|
||||
self.task.start()
|
||||
|
||||
def run(self):
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
async with client.create_producer(
|
||||
topic=self.topic,
|
||||
schema=self.schema,
|
||||
chunking_enabled=self.chunking_enabled,
|
||||
) as producer:
|
||||
while True:
|
||||
id, item = await self.q.get()
|
||||
|
||||
if id:
|
||||
await producer.send(item, { "id": id })
|
||||
else:
|
||||
await producer.send(item)
|
||||
client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
)
|
||||
|
||||
producer = client.create_producer(
|
||||
topic=self.topic,
|
||||
schema=self.schema,
|
||||
chunking_enabled=self.chunking_enabled,
|
||||
)
|
||||
|
||||
while True:
|
||||
|
||||
id, item = self.q.get()
|
||||
|
||||
if id:
|
||||
producer.send(item, { "id": id })
|
||||
else:
|
||||
producer.send(item)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
# If handler drops out, sleep a retry
|
||||
await asyncio.sleep(2)
|
||||
time.sleep(2)
|
||||
|
||||
async def send(self, id, msg):
|
||||
await self.q.put((id, msg))
|
||||
def send(self, id, msg):
|
||||
self.q.put((id, msg))
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from aiohttp import web
|
|||
import logging
|
||||
import os
|
||||
import base64
|
||||
import aiopulsar
|
||||
|
||||
import pulsar
|
||||
from pulsar.schema import JsonSchema
|
||||
|
|
@ -167,7 +166,8 @@ class Api:
|
|||
# content is valid base64
|
||||
doc = base64.b64decode(data["data"])
|
||||
|
||||
resp = await self.document_out.send(
|
||||
resp = await asyncio.to_thread(
|
||||
self.document_out.send,
|
||||
None,
|
||||
Document(
|
||||
metadata=Metadata(
|
||||
|
|
@ -212,7 +212,8 @@ class Api:
|
|||
# Text is base64 encoded
|
||||
text = base64.b64decode(data["text"]).decode(charset)
|
||||
|
||||
resp = await self.text_out.send(
|
||||
resp = asyncio.to_thread(
|
||||
self.text_out.send,
|
||||
None,
|
||||
TextDocument(
|
||||
metadata=Metadata(
|
||||
|
|
@ -238,35 +239,13 @@ class Api:
|
|||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
async def run_endpoints(self):
|
||||
|
||||
async with aiopulsar.connect(self.pulsar_host) as client:
|
||||
|
||||
for ep in self.endpoints:
|
||||
await ep.start(client)
|
||||
|
||||
self.doc_ingest_pub_task = asyncio.create_task(
|
||||
self.document_out.run(client)
|
||||
)
|
||||
|
||||
self.text_ingest_pub_task = asyncio.create_task(
|
||||
self.text_out.run(client)
|
||||
)
|
||||
|
||||
print("Endpoints are running...")
|
||||
|
||||
# They never exit
|
||||
for ep in self.endpoints:
|
||||
await ep.join()
|
||||
|
||||
await self.doc_ingest_pub_task
|
||||
await self.text_ingest_pub_task
|
||||
|
||||
print("Endpoints are stopped.")
|
||||
|
||||
async def app_factory(self):
|
||||
|
||||
self.endpoint_task = asyncio.create_task(self.run_endpoints())
|
||||
for ep in self.endpoints:
|
||||
await ep.start()
|
||||
|
||||
self.document_out.start()
|
||||
self.text_out.start()
|
||||
|
||||
return self.app
|
||||
|
||||
|
|
|
|||
|
|
@ -76,12 +76,6 @@ class SocketEndpoint:
|
|||
async def start(self):
|
||||
pass
|
||||
|
||||
async def join(self):
|
||||
|
||||
# Nothing to wait for
|
||||
while True:
|
||||
await asyncio.sleep(100)
|
||||
|
||||
def add_routes(self, app):
|
||||
|
||||
app.add_routes([
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import pulsar
|
||||
import threading
|
||||
import time
|
||||
|
||||
class Subscriber:
|
||||
|
||||
def __init__(self, pulsar_host, topic, subscription, consumer_name,
|
||||
schema=None, max_size=10):
|
||||
schema=None, max_size=100):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.subscription = subscription
|
||||
|
|
@ -12,55 +15,95 @@ class Subscriber:
|
|||
self.schema = schema
|
||||
self.q = {}
|
||||
self.full = {}
|
||||
self.max_size = max_size
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def start(self):
|
||||
self.task = threading.Thread(target=self.run)
|
||||
self.task.start()
|
||||
|
||||
def run(self):
|
||||
|
||||
async def run(self, client):
|
||||
while True:
|
||||
|
||||
try:
|
||||
async with client.subscribe(
|
||||
|
||||
client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
)
|
||||
|
||||
consumer = client.subscribe(
|
||||
topic=self.topic,
|
||||
subscription_name=self.subscription,
|
||||
consumer_name=self.consumer_name,
|
||||
schema=self.schema,
|
||||
) as consumer:
|
||||
while True:
|
||||
msg = await consumer.receive()
|
||||
)
|
||||
|
||||
# Acknowledge successful reception of the message
|
||||
await consumer.acknowledge(msg)
|
||||
while True:
|
||||
|
||||
try:
|
||||
id = msg.properties()["id"]
|
||||
except:
|
||||
id = None
|
||||
msg = consumer.receive()
|
||||
|
||||
# Acknowledge successful reception of the message
|
||||
consumer.acknowledge(msg)
|
||||
|
||||
try:
|
||||
id = msg.properties()["id"]
|
||||
except:
|
||||
id = None
|
||||
|
||||
value = msg.value()
|
||||
|
||||
with self.lock:
|
||||
|
||||
value = msg.value()
|
||||
if id in self.q:
|
||||
await self.q[id].put(value)
|
||||
try:
|
||||
self.q[id].put(value, timeout=0.5)
|
||||
except:
|
||||
pass
|
||||
|
||||
for q in self.full.values():
|
||||
await q.put(value)
|
||||
try:
|
||||
q.put(value, timeout=0.5)
|
||||
except:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
# If handler drops out, sleep a retry
|
||||
await asyncio.sleep(2)
|
||||
time.sleep(2)
|
||||
|
||||
def subscribe(self, id):
|
||||
|
||||
with self.lock:
|
||||
|
||||
q = queue.Queue(maxsize=self.max_size)
|
||||
self.q[id] = q
|
||||
|
||||
async def subscribe(self, id):
|
||||
q = asyncio.Queue()
|
||||
self.q[id] = q
|
||||
return q
|
||||
|
||||
async def unsubscribe(self, id):
|
||||
if id in self.q:
|
||||
del self.q[id]
|
||||
def unsubscribe(self, id):
|
||||
|
||||
with self.lock:
|
||||
|
||||
if id in self.q:
|
||||
# self.q[id].shutdown(immediate=True)
|
||||
del self.q[id]
|
||||
|
||||
async def subscribe_all(self, id):
|
||||
q = asyncio.Queue()
|
||||
self.full[id] = q
|
||||
def subscribe_all(self, id):
|
||||
|
||||
with self.lock:
|
||||
|
||||
q = queue.Queue(maxsize=self.max_size)
|
||||
self.full[id] = q
|
||||
|
||||
return q
|
||||
|
||||
async def unsubscribe_all(self, id):
|
||||
if id in self.full:
|
||||
del self.full[id]
|
||||
def unsubscribe_all(self, id):
|
||||
|
||||
with self.lock:
|
||||
|
||||
if id in self.full:
|
||||
# self.full[id].shutdown(immediate=True)
|
||||
del self.full[id]
|
||||
|
||||
|
|
|
|||
|
|
@ -27,11 +27,9 @@ class TriplesLoadEndpoint(SocketEndpoint):
|
|||
schema=JsonSchema(Triples)
|
||||
)
|
||||
|
||||
async def start(self, client):
|
||||
async def start(self):
|
||||
|
||||
self.task = asyncio.create_task(
|
||||
self.publisher.run(client)
|
||||
)
|
||||
self.publisher.start()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
|
|
@ -53,7 +51,7 @@ class TriplesLoadEndpoint(SocketEndpoint):
|
|||
triples=to_subgraph(data["triples"]),
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
self.publisher.send(None, elt)
|
||||
|
||||
|
||||
running.stop()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
|
||||
|
|
@ -26,31 +27,29 @@ class TriplesStreamEndpoint(SocketEndpoint):
|
|||
schema=JsonSchema(Triples)
|
||||
)
|
||||
|
||||
async def start(self, client):
|
||||
async def start(self):
|
||||
|
||||
self.task = asyncio.create_task(
|
||||
self.subscriber.run(client)
|
||||
)
|
||||
self.subscriber.start()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
q = await self.subscriber.subscribe_all(id)
|
||||
q = self.subscriber.subscribe_all(id)
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), 0.5)
|
||||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_triples(resp))
|
||||
|
||||
except TimeoutError:
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
await self.subscriber.unsubscribe_all(id)
|
||||
self.subscriber.unsubscribe_all(id)
|
||||
|
||||
running.stop()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue