Refactor socket threading (#219)

* Multiple requests can be handled in parallel.
* Refactor to fix timeout issue.
This commit is contained in:
cybermaggedon 2024-12-27 10:34:16 +00:00 committed by GitHub
parent 62d25effd5
commit 7f5296feca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 104 additions and 44 deletions

View file

@ -8,6 +8,13 @@ from aiohttp import web, WSMsgType
from . socket import SocketEndpoint
from . text_completion import TextCompletionRequestor
MAX_OUTSTANDING_REQUESTS = 15
WORKER_CLOSE_WAIT = 0.01
START_REQUEST_WAIT = 0.1
# This buffers requests until task start, so short-lived
MAX_QUEUE_SIZE = 10
class MuxEndpoint(SocketEndpoint):
def __init__(
@ -20,53 +27,113 @@ class MuxEndpoint(SocketEndpoint):
endpoint_path=path, auth=auth,
)
self.q = asyncio.Queue(maxsize=10)
self.services = services
async def start(self):
pass
async def async_thread(self, ws, running):
async def maybe_tidy_workers(self, workers):
while True:
try:
await asyncio.wait_for(
asyncio.shield(workers[0]),
WORKER_CLOSE_WAIT
)
# worker[0] now stopped
# FIXME: Delete reference???
workers.pop(0)
if len(workers) == 0:
break
# Loop iterates to try the next worker
except TimeoutError:
# worker[0] still running, move on
break
async def start_request_task(self, ws, id, svc, request, workers):
if svc not in self.services:
await ws.send_json({"id": id, "error": "Service not recognised"})
return
requestor = self.services[svc]
async def responder(resp, fin):
await ws.send_json({
"id": id,
"response": resp,
"complete": fin,
})
# Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS
while len(workers) > MAX_OUTSTANDING_REQUESTS:
# Fixes deadlock
# FIXME: Put it in its own loop
await asyncio.sleep(START_REQUEST_WAIT)
await self.maybe_tidy_workers(workers)
worker = asyncio.create_task(
requestor.process(request, responder)
)
workers.append(worker)
async def async_thread(self, ws, running, q):
# Worker threads, servicing
workers = []
while running.get():
try:
id, svc, request = await asyncio.wait_for(self.q.get(), 1)
if len(workers) > 0:
await self.maybe_tidy_workers(workers)
# Get next request on queue
id, svc, request = await asyncio.wait_for(q.get(), 1)
except TimeoutError:
continue
except Exception as e:
# This is an internal working error, may not be recoverable
print("Exception:", e)
await ws.send_json({"id": id, "error": str(e)})
break
try:
print(svc, request)
requestor = self.services[svc]
async def responder(resp, fin):
await ws.send_json({
"id": id,
"response": resp,
"complete": fin,
})
resp = await requestor.process(request, responder)
print(id, svc, request)
await self.start_request_task(ws, id, svc, request, workers)
except Exception as e:
print("Exception2:", e)
await ws.send_json({"error": str(e)})
running.stop()
async def listener(self, ws, running):
# The outstanding request queue, max size is MAX_QUEUE_SIZE
q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
async_task = asyncio.create_task(self.async_thread(
ws, running, q
))
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.ERROR:
break
else:
if msg.type == WSMsgType.TEXT:
try:
@ -81,7 +148,7 @@ class MuxEndpoint(SocketEndpoint):
if "id" not in data:
raise RuntimeError("Bad message")
await self.q.put(
await q.put(
(data["id"], data["service"], data["request"])
)
@ -90,5 +157,13 @@ class MuxEndpoint(SocketEndpoint):
await ws.send_json({"error": str(e)})
continue
elif msg.type == WSMsgType.ERROR:
break
elif msg.type == WSMsgType.CLOSE:
break
else:
break
running.stop()
await async_task

View file

@ -22,25 +22,16 @@ class SocketEndpoint:
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.ERROR:
break
if msg.type == WSMsgType.TEXT:
# Ignore incoming message
continue
elif msg.type == WSMsgType.BINARY:
# Ignore incoming message
continue
else:
# Ignore incoming messages
pass
break
running.stop()
async def async_thread(self, ws, running):
while running.get():
try:
await asyncio.sleep(1)
except TimeoutError:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
async def handle(self, request):
@ -56,12 +47,8 @@ class SocketEndpoint:
ws = web.WebSocketResponse()
await ws.prepare(request)
task = asyncio.create_task(self.async_thread(ws, running))
try:
await self.listener(ws, running)
except Exception as e:
print(e, flush=True)
@ -69,8 +56,6 @@ class SocketEndpoint:
await ws.close()
await task
return ws
async def start(self):