mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-01 11:26:22 +02:00
Refactor socket threading (#219)
* Multiple requests can be handled in parallel. * Refactor to fix timeout issue.
This commit is contained in:
parent
62d25effd5
commit
7f5296feca
2 changed files with 104 additions and 44 deletions
|
|
@ -8,6 +8,13 @@ from aiohttp import web, WSMsgType
|
|||
from . socket import SocketEndpoint
|
||||
from . text_completion import TextCompletionRequestor
|
||||
|
||||
MAX_OUTSTANDING_REQUESTS = 15
|
||||
WORKER_CLOSE_WAIT = 0.01
|
||||
START_REQUEST_WAIT = 0.1
|
||||
|
||||
# This buffers requests until task start, so short-lived
|
||||
MAX_QUEUE_SIZE = 10
|
||||
|
||||
class MuxEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(
|
||||
|
|
@ -20,53 +27,113 @@ class MuxEndpoint(SocketEndpoint):
|
|||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.q = asyncio.Queue(maxsize=10)
|
||||
|
||||
self.services = services
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
async def maybe_tidy_workers(self, workers):
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(workers[0]),
|
||||
WORKER_CLOSE_WAIT
|
||||
)
|
||||
|
||||
# worker[0] now stopped
|
||||
# FIXME: Delete reference???
|
||||
|
||||
workers.pop(0)
|
||||
|
||||
if len(workers) == 0:
|
||||
break
|
||||
|
||||
# Loop iterates to try the next worker
|
||||
|
||||
except TimeoutError:
|
||||
# worker[0] still running, move on
|
||||
break
|
||||
|
||||
async def start_request_task(self, ws, id, svc, request, workers):
|
||||
|
||||
if svc not in self.services:
|
||||
await ws.send_json({"id": id, "error": "Service not recognised"})
|
||||
return
|
||||
|
||||
requestor = self.services[svc]
|
||||
|
||||
async def responder(resp, fin):
|
||||
await ws.send_json({
|
||||
"id": id,
|
||||
"response": resp,
|
||||
"complete": fin,
|
||||
})
|
||||
|
||||
# Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS
|
||||
while len(workers) > MAX_OUTSTANDING_REQUESTS:
|
||||
|
||||
# Fixes deadlock
|
||||
# FIXME: Put it in its own loop
|
||||
await asyncio.sleep(START_REQUEST_WAIT)
|
||||
|
||||
await self.maybe_tidy_workers(workers)
|
||||
|
||||
worker = asyncio.create_task(
|
||||
requestor.process(request, responder)
|
||||
)
|
||||
|
||||
workers.append(worker)
|
||||
|
||||
async def async_thread(self, ws, running, q):
|
||||
|
||||
# Worker threads, servicing
|
||||
workers = []
|
||||
|
||||
while running.get():
|
||||
|
||||
try:
|
||||
id, svc, request = await asyncio.wait_for(self.q.get(), 1)
|
||||
|
||||
if len(workers) > 0:
|
||||
await self.maybe_tidy_workers(workers)
|
||||
|
||||
# Get next request on queue
|
||||
id, svc, request = await asyncio.wait_for(q.get(), 1)
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
# This is an internal working error, may not be recoverable
|
||||
print("Exception:", e)
|
||||
await ws.send_json({"id": id, "error": str(e)})
|
||||
break
|
||||
|
||||
try:
|
||||
|
||||
print(svc, request)
|
||||
|
||||
requestor = self.services[svc]
|
||||
|
||||
async def responder(resp, fin):
|
||||
await ws.send_json({
|
||||
"id": id,
|
||||
"response": resp,
|
||||
"complete": fin,
|
||||
})
|
||||
|
||||
resp = await requestor.process(request, responder)
|
||||
print(id, svc, request)
|
||||
await self.start_request_task(ws, id, svc, request, workers)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print("Exception2:", e)
|
||||
await ws.send_json({"error": str(e)})
|
||||
|
||||
running.stop()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
# The outstanding request queue, max size is MAX_QUEUE_SIZE
|
||||
q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
|
||||
|
||||
async_task = asyncio.create_task(self.async_thread(
|
||||
ws, running, q
|
||||
))
|
||||
|
||||
async for msg in ws:
|
||||
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
break
|
||||
else:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
|
||||
try:
|
||||
|
||||
|
|
@ -81,7 +148,7 @@ class MuxEndpoint(SocketEndpoint):
|
|||
if "id" not in data:
|
||||
raise RuntimeError("Bad message")
|
||||
|
||||
await self.q.put(
|
||||
await q.put(
|
||||
(data["id"], data["service"], data["request"])
|
||||
)
|
||||
|
||||
|
|
@ -90,5 +157,13 @@ class MuxEndpoint(SocketEndpoint):
|
|||
await ws.send_json({"error": str(e)})
|
||||
continue
|
||||
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
break
|
||||
elif msg.type == WSMsgType.CLOSE:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
running.stop()
|
||||
|
||||
await async_task
|
||||
|
|
|
|||
|
|
@ -22,25 +22,16 @@ class SocketEndpoint:
|
|||
|
||||
async for msg in ws:
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
break
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
# Ignore incoming message
|
||||
continue
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
# Ignore incoming message
|
||||
continue
|
||||
else:
|
||||
# Ignore incoming messages
|
||||
pass
|
||||
break
|
||||
|
||||
running.stop()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
|
||||
async def handle(self, request):
|
||||
|
||||
|
|
@ -56,12 +47,8 @@ class SocketEndpoint:
|
|||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
task = asyncio.create_task(self.async_thread(ws, running))
|
||||
|
||||
try:
|
||||
|
||||
await self.listener(ws, running)
|
||||
|
||||
except Exception as e:
|
||||
print(e, flush=True)
|
||||
|
||||
|
|
@ -69,8 +56,6 @@ class SocketEndpoint:
|
|||
|
||||
await ws.close()
|
||||
|
||||
await task
|
||||
|
||||
return ws
|
||||
|
||||
async def start(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue