trustgraph/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
Cyber MacGeddon 9ae79ff712 Updated CLI
2026-04-24 12:41:46 +01:00

306 lines
9.3 KiB
Python

import asyncio
import queue
import uuid
import logging
# Module logger
logger = logging.getLogger(__name__)

# Upper bound on concurrently running request tasks per socket.
MAX_OUTSTANDING_REQUESTS = 15

# Seconds to wait for a worker to finish when none has finished yet.
WORKER_CLOSE_WAIT = 0.01

# Seconds to sleep between checks while at the outstanding-request limit.
START_REQUEST_WAIT = 0.1

# This buffers requests until task start, so short-lived
MAX_QUEUE_SIZE = 10


class _BearerRequest:
    """Request stand-in carrying only an ``Authorization: Bearer`` header.

    Passed to ``auth.authenticate`` so a token received in a websocket
    frame can be checked by the same authenticator object (presumably
    the one used for HTTP requests; it is only known here to accept an
    object with a ``headers`` mapping).
    """

    def __init__(self, token):
        self.headers = {"Authorization": f"Bearer {token}"}


class Mux:
    """Multiplex service requests and responses over one websocket.

    ``receive`` validates inbound frames and places them on an internal
    queue; ``run`` drains that queue and starts one asyncio task per
    request, so several requests can be in flight at once. Responses and
    errors are sent back on the socket tagged with the request ``id``.
    """

    def __init__(self, dispatcher_manager, ws, running, auth=None):
        """
        ``auth`` — an ``IamAuth`` when the enclosing endpoint is
        configured for in-band first-frame auth. ``None`` for the
        legacy ``?token=`` path (kept for the flow import/export
        streaming endpoints).
        """
        self.dispatcher_manager = dispatcher_manager
        self.ws = ws
        self.running = running
        self.auth = auth

        # Authenticated identity, populated by the first-frame auth
        # protocol. ``None`` means the socket is not yet
        # authenticated; any non-auth frame is refused. If
        # ``auth`` is ``None`` (legacy path) the mux acts as if
        # already authenticated and uses client-supplied workspace
        # values (pre-existing behaviour).
        self.identity = None

        self.q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)

    async def destroy(self):
        """Stop the run loop and close the websocket, if there is one."""
        self.running.stop()

        # Explicit None check: truthiness would skip the close for a
        # socket object that defines ``__len__`` and happens to be empty.
        if self.ws is not None:
            await self.ws.close()

    async def _send_error(self, request_id, exc):
        """Send a terminal error frame describing ``exc``.

        The frame carries ``"id"`` only when ``request_id`` is not
        ``None``. Does nothing if the socket has already been dropped
        (``self.ws is None``). Exceptions raised by ``send_json``
        propagate to the caller.
        """
        ws = self.ws
        if ws is None:
            return

        frame = {}
        if request_id is not None:
            frame["id"] = request_id
        frame["error"] = {"message": str(exc), "type": "error"}
        frame["complete"] = True

        await ws.send_json(frame)

    async def _try_send_error(self, request_id, exc):
        """Like ``_send_error`` but logs, rather than raises, on failure."""
        try:
            await self._send_error(request_id, exc)
        except Exception:
            logger.error(
                "Could not deliver error frame to client", exc_info=True
            )

    async def _send_auth_failed(self):
        """Send the masked auth-failure frame."""
        await self.ws.send_json({
            "type": "auth-failed",
            "error": "auth failure",
        })

    async def _handle_auth_frame(self, data):
        """Process a ``{"type": "auth", "token": "..."}`` frame.

        On success, updates ``self.identity`` and sends an ``auth-ok``
        frame carrying the identity's workspace. On failure (missing
        token, no authenticator configured, or the authenticator
        raising) sends the masked ``auth-failed`` frame and leaves
        ``self.identity`` unchanged. Authentication failures do not
        raise — the socket stays open so the client can retry without
        reconnecting (important for browsers, which treat a
        handshake-time 401 as terminal).
        """
        token = data.get("token", "")

        if not token or self.auth is None:
            await self._send_auth_failed()
            return

        try:
            identity = await self.auth.authenticate(_BearerRequest(token))
        except Exception:
            await self._send_auth_failed()
            return

        self.identity = identity
        await self.ws.send_json({
            "type": "auth-ok",
            "workspace": identity.workspace,
        })

    async def receive(self, msg):
        """Validate one inbound frame and queue it for ``run``.

        ``msg`` must provide ``json()``. Auth frames are handled
        in-line. Request frames must be JSON objects with ``id``,
        ``request`` and ``service``; ``flow`` and ``workspace`` are
        optional. Any problem is reported to the client as an error
        frame (tagged with the request id when one was supplied).
        """
        request_id = None

        try:
            data = msg.json()

            # Anything other than a JSON object cannot be a valid frame.
            if not isinstance(data, dict):
                raise RuntimeError("Bad message")

            # In-band auth protocol: the client sends
            # ``{"type": "auth", "token": "..."}`` as its first frame
            # (and any time it wants to re-auth: JWT refresh, token
            # rotation, workspace switch in a future multi-workspace
            # enterprise). The protocol coexists with legacy
            # non-auth sockets (self.auth is None) — on those, every
            # frame is a request and workspace is caller-supplied.
            if data.get("type") == "auth":
                await self._handle_auth_frame(data)
                return

            request_id = data.get("id")

            if "request" not in data:
                raise RuntimeError("Bad message")

            if "id" not in data:
                raise RuntimeError("Bad message")

            # First-frame auth gating: if the enclosing endpoint is
            # configured for in-band auth, reject all non-auth frames
            # until an auth-ok has been issued.
            if self.auth is not None and self.identity is None:
                await self.ws.send_json({
                    "id": request_id,
                    "error": {
                        "message": "auth failure",
                        "type": "auth-required",
                    },
                    "complete": True,
                })
                return

            # Checked after the auth gate so unauthenticated sockets
            # still get the auth-required response first.
            if "service" not in data:
                raise RuntimeError("Bad message")

            # Workspace resolution. On authenticated sockets the
            # gateway's role-scope rules apply: role workspace scope
            # determines which target workspaces are permitted. The
            # resolved value is written to both the envelope and the
            # inner request payload so clients don't have to repeat it
            # per-message (same convenience HTTP callers get via
            # enforce_workspace).
            if self.identity is not None:
                from ..capabilities import enforce_workspace
                from aiohttp import web as _web

                try:
                    enforce_workspace(data, self.identity)
                    inner = data.get("request")
                    if isinstance(inner, dict):
                        enforce_workspace(inner, self.identity)
                except _web.HTTPForbidden:
                    await self.ws.send_json({
                        "id": request_id,
                        "error": {
                            "message": "access denied",
                            "type": "access-denied",
                        },
                        "complete": True,
                    })
                    return

                workspace = data["workspace"]
            else:
                workspace = data.get("workspace", "default")

            await self.q.put((
                data["id"],
                workspace,
                data.get("flow"),
                data["service"],
                data["request"],
            ))

        except Exception as e:
            logger.error(f"Receive exception: {str(e)}", exc_info=True)
            # ``is not None`` inside the helper keeps falsy ids such as
            # 0 attached to the error frame.
            await self._send_error(request_id, e)

    async def maybe_tidy_workers(self, workers):
        """Remove finished workers from ``workers`` in place.

        If nothing has finished yet, waits up to ``WORKER_CLOSE_WAIT``
        seconds for the first worker to complete. Every finished worker
        is then dropped regardless of its position, so a long-running
        request at the head of the list does not keep completed ones
        counted as outstanding. A worker that ended with an exception
        is logged and dropped rather than re-raised into the caller.
        """
        if not workers:
            return

        if not any(w.done() for w in workers):
            # ``asyncio.wait`` neither raises on timeout nor cancels
            # the tasks it is waiting on.
            await asyncio.wait(
                workers,
                timeout=WORKER_CLOSE_WAIT,
                return_when=asyncio.FIRST_COMPLETED,
            )

        still_running = []

        for worker in workers:

            if not worker.done():
                still_running.append(worker)
                continue

            if worker.cancelled():
                continue

            # Retrieving the exception also marks it as handled.
            exc = worker.exception()
            if exc is not None:
                logger.error(
                    "Request worker failed: %s", exc, exc_info=exc
                )

        # Slice-assign so the caller's list object is the one updated.
        workers[:] = still_running

    async def start_request_task(
            self, ws, id, workspace, flow, svc, request, workers,
    ):
        """Start a task servicing one request and append it to ``workers``.

        Blocks while ``MAX_OUTSTANDING_REQUESTS`` workers are already
        outstanding, tidying finished workers between short sleeps.
        ``ws`` is accepted for interface compatibility; responses are
        sent on ``self.ws``.
        """
        # Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS
        while len(workers) >= MAX_OUTSTANDING_REQUESTS:

            await self.maybe_tidy_workers(workers)

            if len(workers) < MAX_OUTSTANDING_REQUESTS:
                break

            # Yield to the event loop so running workers can progress
            # (the original note here: "Fixes deadlock").
            await asyncio.sleep(START_REQUEST_WAIT)

        async def responder(resp, fin):
            await self.ws.send_json({
                "id": id,
                "response": resp,
                "complete": fin,
            })

        worker = asyncio.create_task(
            self.request_task(
                id, request, responder, workspace, flow, svc,
            )
        )

        workers.append(worker)

    async def request_task(
            self, id, request, responder, workspace, flow, svc,
    ):
        """Invoke the service for one request.

        A truthy ``flow`` selects ``invoke_flow_service``; otherwise
        ``invoke_global_service`` is used (which takes no workspace or
        flow). Any exception from the dispatcher is reported to the
        client as an error frame tagged with ``id``.
        """
        try:

            if flow:
                await self.dispatcher_manager.invoke_flow_service(
                    request, responder, workspace, flow, svc,
                )
            else:
                await self.dispatcher_manager.invoke_global_service(
                    request, responder, svc
                )

        except Exception as e:
            await self._send_error(id, e)

    async def run(self):
        """Drain the request queue until ``self.running`` is stopped.

        Each queued item is handed to ``start_request_task``. A failure
        while tidying workers or reading the queue is treated as fatal:
        it is reported to the client without a request id (it does not
        belong to any particular request) and the loop ends. On exit
        the running flag is stopped and the websocket is closed.
        """
        # Worker tasks, servicing requests
        workers = []

        while self.running.get():

            try:

                if workers:
                    await self.maybe_tidy_workers(workers)

                # Get next request on queue; the timeout lets the loop
                # re-check the running flag regularly.
                item = await asyncio.wait_for(self.q.get(), 1)
                request_id, workspace, flow, svc, request = item

            except asyncio.TimeoutError:
                continue

            except Exception as e:
                # This is an internal working error, may not be recoverable
                logger.error(f"Run prepare exception: {e}", exc_info=True)
                await self._try_send_error(None, e)
                break

            try:

                await self.start_request_task(
                    self.ws, request_id, workspace, flow, svc, request,
                    workers,
                )

            except Exception as e:
                logger.error(f"Exception in mux: {e}", exc_info=True)
                await self._try_send_error(request_id, e)

        self.running.stop()

        if self.ws is not None:
            await self.ws.close()
            self.ws = None