mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-01 11:26:22 +02:00
Fix/async problem (#190)
* Back out previous change
* To multithreads
* Remove aiopulsar dependency
This commit is contained in:
parent 7e78aa6d91
commit 26865a515c
10 changed files with 149 additions and 136 deletions
|
|
@ -59,7 +59,6 @@ setuptools.setup(
|
||||||
"ibis",
|
"ibis",
|
||||||
"jsonschema",
|
"jsonschema",
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
"aiopulsar-py",
|
|
||||||
"pinecone[grpc]",
|
"pinecone[grpc]",
|
||||||
],
|
],
|
||||||
scripts=[
|
scripts=[
|
||||||
|
|
|
||||||
|
|
@ -41,15 +41,10 @@ class ServiceEndpoint:
|
||||||
|
|
||||||
self.operation = "service"
|
self.operation = "service"
|
||||||
|
|
||||||
async def start(self, client):
|
async def start(self):
|
||||||
|
|
||||||
self.pub_task = asyncio.create_task(self.pub.run(client))
|
self.pub.start()
|
||||||
self.sub_task = asyncio.create_task(self.sub.run(client))
|
self.sub.start()
|
||||||
|
|
||||||
async def join(self):
|
|
||||||
|
|
||||||
await self.pub_task
|
|
||||||
await self.sub_task
|
|
||||||
|
|
||||||
def add_routes(self, app):
|
def add_routes(self, app):
|
||||||
|
|
||||||
|
|
@ -87,20 +82,18 @@ class ServiceEndpoint:
|
||||||
|
|
||||||
print(data)
|
print(data)
|
||||||
|
|
||||||
q = await self.sub.subscribe(id)
|
q = self.sub.subscribe(id)
|
||||||
|
|
||||||
await self.pub.send(
|
await asyncio.to_thread(
|
||||||
id,
|
self.pub.send, id, self.to_request(data)
|
||||||
self.to_request(data),
|
|
||||||
)
|
)
|
||||||
print("Request sent")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await asyncio.wait_for(q.get(), self.timeout)
|
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
|
||||||
except:
|
except Exception as e:
|
||||||
raise RuntimeError("Timeout waiting for response")
|
raise RuntimeError("Timeout")
|
||||||
|
|
||||||
print("Response got")
|
print(resp)
|
||||||
|
|
||||||
if resp.error:
|
if resp.error:
|
||||||
print("Error")
|
print("Error")
|
||||||
|
|
@ -108,8 +101,6 @@ class ServiceEndpoint:
|
||||||
{ "error": resp.error.message }
|
{ "error": resp.error.message }
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Send response")
|
|
||||||
|
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
self.from_response(resp)
|
self.from_response(resp)
|
||||||
)
|
)
|
||||||
|
|
@ -122,7 +113,7 @@ class ServiceEndpoint:
|
||||||
)
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await self.sub.unsubscribe(id)
|
self.sub.unsubscribe(id)
|
||||||
|
|
||||||
|
|
||||||
class MultiResponseServiceEndpoint(ServiceEndpoint):
|
class MultiResponseServiceEndpoint(ServiceEndpoint):
|
||||||
|
|
@ -135,11 +126,10 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
|
||||||
|
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
|
||||||
q = await self.sub.subscribe(id)
|
q = self.sub.subscribe(id)
|
||||||
|
|
||||||
await self.pub.send(
|
await asyncio.to_thread(
|
||||||
id,
|
self.pub.send, id, self.to_request(data)
|
||||||
self.to_request(data),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keeps looking at responses...
|
# Keeps looking at responses...
|
||||||
|
|
@ -147,8 +137,8 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await asyncio.wait_for(q.get(), self.timeout)
|
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
|
||||||
except:
|
except Exception as e:
|
||||||
raise RuntimeError("Timeout waiting for response")
|
raise RuntimeError("Timeout waiting for response")
|
||||||
|
|
||||||
if resp.error:
|
if resp.error:
|
||||||
|
|
@ -173,4 +163,4 @@ class MultiResponseServiceEndpoint(ServiceEndpoint):
|
||||||
)
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await self.sub.unsubscribe(id)
|
self.sub.unsubscribe(id)
|
||||||
|
|
|
||||||
|
|
@ -29,11 +29,9 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
||||||
schema=JsonSchema(GraphEmbeddings)
|
schema=JsonSchema(GraphEmbeddings)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def start(self, client):
|
async def start(self):
|
||||||
|
|
||||||
self.task = asyncio.create_task(
|
self.publisher.start()
|
||||||
self.publisher.run(client)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def listener(self, ws, running):
|
async def listener(self, ws, running):
|
||||||
|
|
||||||
|
|
@ -56,7 +54,7 @@ class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
||||||
vectors=data["vectors"],
|
vectors=data["vectors"],
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.publisher.send(None, elt)
|
self.publisher.send(None, elt)
|
||||||
|
|
||||||
|
|
||||||
running.stop()
|
running.stop()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import queue
|
||||||
from pulsar.schema import JsonSchema
|
from pulsar.schema import JsonSchema
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
@ -28,31 +29,29 @@ class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
||||||
schema=JsonSchema(GraphEmbeddings)
|
schema=JsonSchema(GraphEmbeddings)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def start(self, client):
|
async def start(self):
|
||||||
|
|
||||||
self.task = asyncio.create_task(
|
self.subscriber.start()
|
||||||
self.subscriber.run(client)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def async_thread(self, ws, running):
|
async def async_thread(self, ws, running):
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
|
|
||||||
q = await self.subscriber.subscribe_all(id)
|
q = self.subscriber.subscribe_all(id)
|
||||||
|
|
||||||
while running.get():
|
while running.get():
|
||||||
try:
|
try:
|
||||||
resp = await asyncio.wait_for(q.get(), 0.5)
|
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||||
await ws.send_json(serialize_graph_embeddings(resp))
|
await ws.send_json(serialize_graph_embeddings(resp))
|
||||||
|
|
||||||
except TimeoutError:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Exception: {str(e)}", flush=True)
|
print(f"Exception: {str(e)}", flush=True)
|
||||||
break
|
break
|
||||||
|
|
||||||
await self.subscriber.unsubscribe_all(id)
|
self.subscriber.unsubscribe_all(id)
|
||||||
|
|
||||||
running.stop()
|
running.stop()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
|
|
||||||
import asyncio
|
import queue
|
||||||
|
import time
|
||||||
|
import pulsar
|
||||||
|
import threading
|
||||||
|
|
||||||
class Publisher:
|
class Publisher:
|
||||||
|
|
||||||
|
|
@ -8,32 +11,43 @@ class Publisher:
|
||||||
self.pulsar_host = pulsar_host
|
self.pulsar_host = pulsar_host
|
||||||
self.topic = topic
|
self.topic = topic
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.q = asyncio.Queue(maxsize=max_size)
|
self.q = queue.Queue(maxsize=max_size)
|
||||||
self.chunking_enabled = chunking_enabled
|
self.chunking_enabled = chunking_enabled
|
||||||
|
|
||||||
async def run(self, client):
|
def start(self):
|
||||||
|
self.task = threading.Thread(target=self.run)
|
||||||
|
self.task.start()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with client.create_producer(
|
|
||||||
topic=self.topic,
|
|
||||||
schema=self.schema,
|
|
||||||
chunking_enabled=self.chunking_enabled,
|
|
||||||
) as producer:
|
|
||||||
while True:
|
|
||||||
id, item = await self.q.get()
|
|
||||||
|
|
||||||
if id:
|
client = pulsar.Client(
|
||||||
await producer.send(item, { "id": id })
|
self.pulsar_host,
|
||||||
else:
|
)
|
||||||
await producer.send(item)
|
|
||||||
|
producer = client.create_producer(
|
||||||
|
topic=self.topic,
|
||||||
|
schema=self.schema,
|
||||||
|
chunking_enabled=self.chunking_enabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
|
||||||
|
id, item = self.q.get()
|
||||||
|
|
||||||
|
if id:
|
||||||
|
producer.send(item, { "id": id })
|
||||||
|
else:
|
||||||
|
producer.send(item)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Exception:", e, flush=True)
|
print("Exception:", e, flush=True)
|
||||||
|
|
||||||
# If handler drops out, sleep a retry
|
# If handler drops out, sleep a retry
|
||||||
await asyncio.sleep(2)
|
time.sleep(2)
|
||||||
|
|
||||||
async def send(self, id, msg):
|
def send(self, id, msg):
|
||||||
await self.q.put((id, msg))
|
self.q.put((id, msg))
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ from aiohttp import web
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import base64
|
import base64
|
||||||
import aiopulsar
|
|
||||||
|
|
||||||
import pulsar
|
import pulsar
|
||||||
from pulsar.schema import JsonSchema
|
from pulsar.schema import JsonSchema
|
||||||
|
|
@ -167,7 +166,8 @@ class Api:
|
||||||
# content is valid base64
|
# content is valid base64
|
||||||
doc = base64.b64decode(data["data"])
|
doc = base64.b64decode(data["data"])
|
||||||
|
|
||||||
resp = await self.document_out.send(
|
resp = await asyncio.to_thread(
|
||||||
|
self.document_out.send,
|
||||||
None,
|
None,
|
||||||
Document(
|
Document(
|
||||||
metadata=Metadata(
|
metadata=Metadata(
|
||||||
|
|
@ -212,7 +212,8 @@ class Api:
|
||||||
# Text is base64 encoded
|
# Text is base64 encoded
|
||||||
text = base64.b64decode(data["text"]).decode(charset)
|
text = base64.b64decode(data["text"]).decode(charset)
|
||||||
|
|
||||||
resp = await self.text_out.send(
|
resp = asyncio.to_thread(
|
||||||
|
self.text_out.send,
|
||||||
None,
|
None,
|
||||||
TextDocument(
|
TextDocument(
|
||||||
metadata=Metadata(
|
metadata=Metadata(
|
||||||
|
|
@ -238,35 +239,13 @@ class Api:
|
||||||
{ "error": str(e) }
|
{ "error": str(e) }
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_endpoints(self):
|
|
||||||
|
|
||||||
async with aiopulsar.connect(self.pulsar_host) as client:
|
|
||||||
|
|
||||||
for ep in self.endpoints:
|
|
||||||
await ep.start(client)
|
|
||||||
|
|
||||||
self.doc_ingest_pub_task = asyncio.create_task(
|
|
||||||
self.document_out.run(client)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.text_ingest_pub_task = asyncio.create_task(
|
|
||||||
self.text_out.run(client)
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Endpoints are running...")
|
|
||||||
|
|
||||||
# They never exit
|
|
||||||
for ep in self.endpoints:
|
|
||||||
await ep.join()
|
|
||||||
|
|
||||||
await self.doc_ingest_pub_task
|
|
||||||
await self.text_ingest_pub_task
|
|
||||||
|
|
||||||
print("Endpoints are stopped.")
|
|
||||||
|
|
||||||
async def app_factory(self):
|
async def app_factory(self):
|
||||||
|
|
||||||
self.endpoint_task = asyncio.create_task(self.run_endpoints())
|
for ep in self.endpoints:
|
||||||
|
await ep.start()
|
||||||
|
|
||||||
|
self.document_out.start()
|
||||||
|
self.text_out.start()
|
||||||
|
|
||||||
return self.app
|
return self.app
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -76,12 +76,6 @@ class SocketEndpoint:
|
||||||
async def start(self):
|
async def start(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def join(self):
|
|
||||||
|
|
||||||
# Nothing to wait for
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(100)
|
|
||||||
|
|
||||||
def add_routes(self, app):
|
def add_routes(self, app):
|
||||||
|
|
||||||
app.add_routes([
|
app.add_routes([
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,13 @@
|
||||||
|
|
||||||
import asyncio
|
import queue
|
||||||
|
import pulsar
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
class Subscriber:
|
class Subscriber:
|
||||||
|
|
||||||
def __init__(self, pulsar_host, topic, subscription, consumer_name,
|
def __init__(self, pulsar_host, topic, subscription, consumer_name,
|
||||||
schema=None, max_size=10):
|
schema=None, max_size=100):
|
||||||
self.pulsar_host = pulsar_host
|
self.pulsar_host = pulsar_host
|
||||||
self.topic = topic
|
self.topic = topic
|
||||||
self.subscription = subscription
|
self.subscription = subscription
|
||||||
|
|
@ -12,55 +15,95 @@ class Subscriber:
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.q = {}
|
self.q = {}
|
||||||
self.full = {}
|
self.full = {}
|
||||||
|
self.max_size = max_size
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self.task = threading.Thread(target=self.run)
|
||||||
|
self.task.start()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
|
||||||
async def run(self, client):
|
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with client.subscribe(
|
|
||||||
|
client = pulsar.Client(
|
||||||
|
self.pulsar_host,
|
||||||
|
)
|
||||||
|
|
||||||
|
consumer = client.subscribe(
|
||||||
topic=self.topic,
|
topic=self.topic,
|
||||||
subscription_name=self.subscription,
|
subscription_name=self.subscription,
|
||||||
consumer_name=self.consumer_name,
|
consumer_name=self.consumer_name,
|
||||||
schema=self.schema,
|
schema=self.schema,
|
||||||
) as consumer:
|
)
|
||||||
while True:
|
|
||||||
msg = await consumer.receive()
|
|
||||||
|
|
||||||
# Acknowledge successful reception of the message
|
while True:
|
||||||
await consumer.acknowledge(msg)
|
|
||||||
|
|
||||||
try:
|
msg = consumer.receive()
|
||||||
id = msg.properties()["id"]
|
|
||||||
except:
|
# Acknowledge successful reception of the message
|
||||||
id = None
|
consumer.acknowledge(msg)
|
||||||
|
|
||||||
|
try:
|
||||||
|
id = msg.properties()["id"]
|
||||||
|
except:
|
||||||
|
id = None
|
||||||
|
|
||||||
|
value = msg.value()
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
|
||||||
value = msg.value()
|
|
||||||
if id in self.q:
|
if id in self.q:
|
||||||
await self.q[id].put(value)
|
try:
|
||||||
|
self.q[id].put(value, timeout=0.5)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
for q in self.full.values():
|
for q in self.full.values():
|
||||||
await q.put(value)
|
try:
|
||||||
|
q.put(value, timeout=0.5)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Exception:", e, flush=True)
|
print("Exception:", e, flush=True)
|
||||||
|
|
||||||
# If handler drops out, sleep a retry
|
# If handler drops out, sleep a retry
|
||||||
await asyncio.sleep(2)
|
time.sleep(2)
|
||||||
|
|
||||||
|
def subscribe(self, id):
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
|
||||||
|
q = queue.Queue(maxsize=self.max_size)
|
||||||
|
self.q[id] = q
|
||||||
|
|
||||||
async def subscribe(self, id):
|
|
||||||
q = asyncio.Queue()
|
|
||||||
self.q[id] = q
|
|
||||||
return q
|
return q
|
||||||
|
|
||||||
async def unsubscribe(self, id):
|
def unsubscribe(self, id):
|
||||||
if id in self.q:
|
|
||||||
del self.q[id]
|
with self.lock:
|
||||||
|
|
||||||
|
if id in self.q:
|
||||||
|
# self.q[id].shutdown(immediate=True)
|
||||||
|
del self.q[id]
|
||||||
|
|
||||||
async def subscribe_all(self, id):
|
def subscribe_all(self, id):
|
||||||
q = asyncio.Queue()
|
|
||||||
self.full[id] = q
|
with self.lock:
|
||||||
|
|
||||||
|
q = queue.Queue(maxsize=self.max_size)
|
||||||
|
self.full[id] = q
|
||||||
|
|
||||||
return q
|
return q
|
||||||
|
|
||||||
async def unsubscribe_all(self, id):
|
def unsubscribe_all(self, id):
|
||||||
if id in self.full:
|
|
||||||
del self.full[id]
|
with self.lock:
|
||||||
|
|
||||||
|
if id in self.full:
|
||||||
|
# self.full[id].shutdown(immediate=True)
|
||||||
|
del self.full[id]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,11 +27,9 @@ class TriplesLoadEndpoint(SocketEndpoint):
|
||||||
schema=JsonSchema(Triples)
|
schema=JsonSchema(Triples)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def start(self, client):
|
async def start(self):
|
||||||
|
|
||||||
self.task = asyncio.create_task(
|
self.publisher.start()
|
||||||
self.publisher.run(client)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def listener(self, ws, running):
|
async def listener(self, ws, running):
|
||||||
|
|
||||||
|
|
@ -53,7 +51,7 @@ class TriplesLoadEndpoint(SocketEndpoint):
|
||||||
triples=to_subgraph(data["triples"]),
|
triples=to_subgraph(data["triples"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.publisher.send(None, elt)
|
self.publisher.send(None, elt)
|
||||||
|
|
||||||
|
|
||||||
running.stop()
|
running.stop()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import queue
|
||||||
from pulsar.schema import JsonSchema
|
from pulsar.schema import JsonSchema
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
@ -26,31 +27,29 @@ class TriplesStreamEndpoint(SocketEndpoint):
|
||||||
schema=JsonSchema(Triples)
|
schema=JsonSchema(Triples)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def start(self, client):
|
async def start(self):
|
||||||
|
|
||||||
self.task = asyncio.create_task(
|
self.subscriber.start()
|
||||||
self.subscriber.run(client)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def async_thread(self, ws, running):
|
async def async_thread(self, ws, running):
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
|
|
||||||
q = await self.subscriber.subscribe_all(id)
|
q = self.subscriber.subscribe_all(id)
|
||||||
|
|
||||||
while running.get():
|
while running.get():
|
||||||
try:
|
try:
|
||||||
resp = await asyncio.wait_for(q.get(), 0.5)
|
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||||
await ws.send_json(serialize_triples(resp))
|
await ws.send_json(serialize_triples(resp))
|
||||||
|
|
||||||
except TimeoutError:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Exception: {str(e)}", flush=True)
|
print(f"Exception: {str(e)}", flush=True)
|
||||||
break
|
break
|
||||||
|
|
||||||
await self.subscriber.unsubscribe_all(id)
|
self.subscriber.unsubscribe_all(id)
|
||||||
|
|
||||||
running.stop()
|
running.stop()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue