Websocket auth protocol

This commit is contained in:
Cyber MacGeddon 2026-04-23 20:11:41 +01:00
parent d5dabad001
commit 843e68cded
5 changed files with 130 additions and 6 deletions

View file

@ -108,12 +108,18 @@ class DispatcherWrapper:
class DispatcherManager:
def __init__(self, backend, config_receiver, prefix="api-gateway",
queue_overrides=None):
queue_overrides=None, auth=None):
self.backend = backend
self.config_receiver = config_receiver
self.config_receiver.add_handler(self)
self.prefix = prefix
# Gateway IamAuth — used by the socket mux for first-frame
# auth. ``None`` keeps the legacy "caller-supplied
# workspace" behaviour for anything that instantiates Mux
# directly without auth.
self.auth = auth
# Store queue overrides for global services
# Format: {"config": {"request": "...", "response": "..."}, ...}
self.queue_overrides = queue_overrides or {}
@ -325,7 +331,10 @@ class DispatcherManager:
async def process_socket(self, ws, running, params):
dispatcher = Mux(self, ws, running)
# The mux self-authenticates via the first-frame protocol;
# pass the gateway's IamAuth so it can validate tokens
# without reaching back into the endpoint layer.
dispatcher = Mux(self, ws, running, auth=self.auth)
return dispatcher

View file

@ -16,11 +16,26 @@ MAX_QUEUE_SIZE = 10
class Mux:
def __init__(self, dispatcher_manager, ws, running):
def __init__(self, dispatcher_manager, ws, running, auth=None):
"""
``auth`` an ``IamAuth`` when the enclosing endpoint is
configured for in-band first-frame auth. ``None`` for the
legacy ``?token=`` path (kept for the flow import/export
streaming endpoints).
"""
self.dispatcher_manager = dispatcher_manager
self.ws = ws
self.running = running
self.auth = auth
# Authenticated identity, populated by the first-frame auth
# protocol. ``None`` means the socket is not yet
# authenticated; any non-auth frame is refused. If
# ``auth`` is ``None`` (legacy path) the mux acts as if
# already authenticated and uses client-supplied workspace
# values (pre-existing behaviour).
self.identity = None
self.q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
@ -31,6 +46,41 @@ class Mux:
if self.ws:
await self.ws.close()
async def _handle_auth_frame(self, data):
"""Process a ``{"type": "auth", "token": "..."}`` frame.
On success, updates ``self.identity`` and returns an
``auth-ok`` response frame. On failure, returns the masked
auth-failure frame. Never raises auth failures keep the
socket open so the client can retry without reconnecting
(important for browsers, which treat a handshake-time 401
as terminal)."""
token = data.get("token", "")
if not token or self.auth is None:
await self.ws.send_json({
"type": "auth-failed",
"error": "auth failure",
})
return
class _Shim:
def __init__(self, tok):
self.headers = {"Authorization": f"Bearer {tok}"}
try:
identity = await self.auth.authenticate(_Shim(token))
except Exception:
await self.ws.send_json({
"type": "auth-failed",
"error": "auth failure",
})
return
self.identity = identity
await self.ws.send_json({
"type": "auth-ok",
"workspace": identity.workspace,
})
async def receive(self, msg):
request_id = None
@ -38,6 +88,18 @@ class Mux:
try:
data = msg.json()
# In-band auth protocol: the client sends
# ``{"type": "auth", "token": "..."}`` as its first frame
# (and any time it wants to re-auth: JWT refresh, token
# rotation, workspace switch in a future multi-workspace
# enterprise). The protocol coexists with legacy
# non-auth sockets (self.auth is None) — on those, every
# frame is a request and workspace is caller-supplied.
if isinstance(data, dict) and data.get("type") == "auth":
await self._handle_auth_frame(data)
return
request_id = data.get("id")
if "request" not in data:
@ -46,9 +108,42 @@ class Mux:
if "id" not in data:
raise RuntimeError("Bad message")
# First-frame auth gating: if the enclosing endpoint is
# configured for in-band auth, reject all non-auth frames
# until an auth-ok has been issued.
if self.auth is not None and self.identity is None:
await self.ws.send_json({
"id": request_id,
"error": {
"message": "auth failure",
"type": "auth-required",
},
"complete": True,
})
return
# Workspace resolution. Authenticated sockets override
# the client-supplied workspace with the resolved value
# from the identity; mismatch is an access-denied error.
if self.identity is not None:
requested_ws = data.get("workspace", "")
if requested_ws and requested_ws != self.identity.workspace:
await self.ws.send_json({
"id": request_id,
"error": {
"message": "access denied",
"type": "access-denied",
},
"complete": True,
})
return
workspace = self.identity.workspace
else:
workspace = data.get("workspace", "default")
await self.q.put((
data["id"],
data.get("workspace", "default"),
workspace,
data.get("flow"),
data["service"],
data["request"]

View file

@ -229,11 +229,17 @@ class EndpointManager:
capability_map=GLOBAL_KIND_CAPABILITY,
),
# /api/v1/socket: WebSocket handshake accepts
# unconditionally; the Mux dispatcher runs the
# first-frame auth protocol. Handshake-time 401s break
# browser reconnection, so authentication is always
# in-band for this endpoint.
SocketEndpoint(
endpoint_path="/api/v1/socket",
auth=auth,
dispatcher=dispatcher_manager.dispatch_socket(),
capability=AUTHENTICATED,
capability=AUTHENTICATED, # informational only; bypassed
in_band_auth=True,
),
# Per-flow request/response services — capability per kind.

View file

@ -15,11 +15,24 @@ class SocketEndpoint:
def __init__(
self, endpoint_path, auth, dispatcher, capability,
in_band_auth=False,
):
"""
``in_band_auth=True`` skips the handshake-time auth check.
The WebSocket handshake always succeeds; the dispatcher is
expected to gate itself via the first-frame auth protocol
(see ``Mux``).
This avoids the browser problem where a 401 on the handshake
is treated as permanent and prevents reconnection, and lets
long-lived sockets refresh their credential mid-session by
sending a new auth frame.
"""
self.path = endpoint_path
self.auth = auth
self.capability = capability
self.in_band_auth = in_band_auth
self.dispatcher = dispatcher
@ -73,7 +86,7 @@ class SocketEndpoint:
The first-frame auth protocol described in the IAM spec is
a future upgrade."""
if self.capability != PUBLIC:
if not self.in_band_auth and self.capability != PUBLIC:
token = request.query.get("token", "")
if not token:
return auth_failure()

View file

@ -118,6 +118,7 @@ class Api:
config_receiver = self.config_receiver,
prefix = "gateway",
queue_overrides = queue_overrides,
auth = self.auth,
)
self.endpoint_manager = EndpointManager(