mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-04 04:42:36 +02:00
Websocket auth protocol
This commit is contained in:
parent
d5dabad001
commit
843e68cded
5 changed files with 130 additions and 6 deletions
|
|
@ -108,12 +108,18 @@ class DispatcherWrapper:
|
||||||
class DispatcherManager:
|
class DispatcherManager:
|
||||||
|
|
||||||
def __init__(self, backend, config_receiver, prefix="api-gateway",
|
def __init__(self, backend, config_receiver, prefix="api-gateway",
|
||||||
queue_overrides=None):
|
queue_overrides=None, auth=None):
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
self.config_receiver = config_receiver
|
self.config_receiver = config_receiver
|
||||||
self.config_receiver.add_handler(self)
|
self.config_receiver.add_handler(self)
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
|
|
||||||
|
# Gateway IamAuth — used by the socket mux for first-frame
|
||||||
|
# auth. ``None`` keeps the legacy "caller-supplied
|
||||||
|
# workspace" behaviour for anything that instantiates Mux
|
||||||
|
# directly without auth.
|
||||||
|
self.auth = auth
|
||||||
|
|
||||||
# Store queue overrides for global services
|
# Store queue overrides for global services
|
||||||
# Format: {"config": {"request": "...", "response": "..."}, ...}
|
# Format: {"config": {"request": "...", "response": "..."}, ...}
|
||||||
self.queue_overrides = queue_overrides or {}
|
self.queue_overrides = queue_overrides or {}
|
||||||
|
|
@ -325,7 +331,10 @@ class DispatcherManager:
|
||||||
|
|
||||||
async def process_socket(self, ws, running, params):
|
async def process_socket(self, ws, running, params):
|
||||||
|
|
||||||
dispatcher = Mux(self, ws, running)
|
# The mux self-authenticates via the first-frame protocol;
|
||||||
|
# pass the gateway's IamAuth so it can validate tokens
|
||||||
|
# without reaching back into the endpoint layer.
|
||||||
|
dispatcher = Mux(self, ws, running, auth=self.auth)
|
||||||
|
|
||||||
return dispatcher
|
return dispatcher
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,26 @@ MAX_QUEUE_SIZE = 10
|
||||||
|
|
||||||
class Mux:
|
class Mux:
|
||||||
|
|
||||||
def __init__(self, dispatcher_manager, ws, running):
|
def __init__(self, dispatcher_manager, ws, running, auth=None):
|
||||||
|
"""
|
||||||
|
``auth`` — an ``IamAuth`` when the enclosing endpoint is
|
||||||
|
configured for in-band first-frame auth. ``None`` for the
|
||||||
|
legacy ``?token=`` path (kept for the flow import/export
|
||||||
|
streaming endpoints).
|
||||||
|
"""
|
||||||
|
|
||||||
self.dispatcher_manager = dispatcher_manager
|
self.dispatcher_manager = dispatcher_manager
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
self.running = running
|
self.running = running
|
||||||
|
self.auth = auth
|
||||||
|
|
||||||
|
# Authenticated identity, populated by the first-frame auth
|
||||||
|
# protocol. ``None`` means the socket is not yet
|
||||||
|
# authenticated; any non-auth frame is refused. If
|
||||||
|
# ``auth`` is ``None`` (legacy path) the mux acts as if
|
||||||
|
# already authenticated and uses client-supplied workspace
|
||||||
|
# values (pre-existing behaviour).
|
||||||
|
self.identity = None
|
||||||
|
|
||||||
self.q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
|
self.q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
|
||||||
|
|
||||||
|
|
@ -31,6 +46,41 @@ class Mux:
|
||||||
if self.ws:
|
if self.ws:
|
||||||
await self.ws.close()
|
await self.ws.close()
|
||||||
|
|
||||||
|
async def _handle_auth_frame(self, data):
|
||||||
|
"""Process a ``{"type": "auth", "token": "..."}`` frame.
|
||||||
|
On success, updates ``self.identity`` and returns an
|
||||||
|
``auth-ok`` response frame. On failure, returns the masked
|
||||||
|
auth-failure frame. Never raises — auth failures keep the
|
||||||
|
socket open so the client can retry without reconnecting
|
||||||
|
(important for browsers, which treat a handshake-time 401
|
||||||
|
as terminal)."""
|
||||||
|
token = data.get("token", "")
|
||||||
|
if not token or self.auth is None:
|
||||||
|
await self.ws.send_json({
|
||||||
|
"type": "auth-failed",
|
||||||
|
"error": "auth failure",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
class _Shim:
|
||||||
|
def __init__(self, tok):
|
||||||
|
self.headers = {"Authorization": f"Bearer {tok}"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
identity = await self.auth.authenticate(_Shim(token))
|
||||||
|
except Exception:
|
||||||
|
await self.ws.send_json({
|
||||||
|
"type": "auth-failed",
|
||||||
|
"error": "auth failure",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
self.identity = identity
|
||||||
|
await self.ws.send_json({
|
||||||
|
"type": "auth-ok",
|
||||||
|
"workspace": identity.workspace,
|
||||||
|
})
|
||||||
|
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
|
|
||||||
request_id = None
|
request_id = None
|
||||||
|
|
@ -38,6 +88,18 @@ class Mux:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
data = msg.json()
|
data = msg.json()
|
||||||
|
|
||||||
|
# In-band auth protocol: the client sends
|
||||||
|
# ``{"type": "auth", "token": "..."}`` as its first frame
|
||||||
|
# (and any time it wants to re-auth: JWT refresh, token
|
||||||
|
# rotation, workspace switch in a future multi-workspace
|
||||||
|
# enterprise). The protocol coexists with legacy
|
||||||
|
# non-auth sockets (self.auth is None) — on those, every
|
||||||
|
# frame is a request and workspace is caller-supplied.
|
||||||
|
if isinstance(data, dict) and data.get("type") == "auth":
|
||||||
|
await self._handle_auth_frame(data)
|
||||||
|
return
|
||||||
|
|
||||||
request_id = data.get("id")
|
request_id = data.get("id")
|
||||||
|
|
||||||
if "request" not in data:
|
if "request" not in data:
|
||||||
|
|
@ -46,9 +108,42 @@ class Mux:
|
||||||
if "id" not in data:
|
if "id" not in data:
|
||||||
raise RuntimeError("Bad message")
|
raise RuntimeError("Bad message")
|
||||||
|
|
||||||
|
# First-frame auth gating: if the enclosing endpoint is
|
||||||
|
# configured for in-band auth, reject all non-auth frames
|
||||||
|
# until an auth-ok has been issued.
|
||||||
|
if self.auth is not None and self.identity is None:
|
||||||
|
await self.ws.send_json({
|
||||||
|
"id": request_id,
|
||||||
|
"error": {
|
||||||
|
"message": "auth failure",
|
||||||
|
"type": "auth-required",
|
||||||
|
},
|
||||||
|
"complete": True,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
# Workspace resolution. Authenticated sockets override
|
||||||
|
# the client-supplied workspace with the resolved value
|
||||||
|
# from the identity; mismatch is an access-denied error.
|
||||||
|
if self.identity is not None:
|
||||||
|
requested_ws = data.get("workspace", "")
|
||||||
|
if requested_ws and requested_ws != self.identity.workspace:
|
||||||
|
await self.ws.send_json({
|
||||||
|
"id": request_id,
|
||||||
|
"error": {
|
||||||
|
"message": "access denied",
|
||||||
|
"type": "access-denied",
|
||||||
|
},
|
||||||
|
"complete": True,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
workspace = self.identity.workspace
|
||||||
|
else:
|
||||||
|
workspace = data.get("workspace", "default")
|
||||||
|
|
||||||
await self.q.put((
|
await self.q.put((
|
||||||
data["id"],
|
data["id"],
|
||||||
data.get("workspace", "default"),
|
workspace,
|
||||||
data.get("flow"),
|
data.get("flow"),
|
||||||
data["service"],
|
data["service"],
|
||||||
data["request"]
|
data["request"]
|
||||||
|
|
|
||||||
|
|
@ -229,11 +229,17 @@ class EndpointManager:
|
||||||
capability_map=GLOBAL_KIND_CAPABILITY,
|
capability_map=GLOBAL_KIND_CAPABILITY,
|
||||||
),
|
),
|
||||||
|
|
||||||
|
# /api/v1/socket: WebSocket handshake accepts
|
||||||
|
# unconditionally; the Mux dispatcher runs the
|
||||||
|
# first-frame auth protocol. Handshake-time 401s break
|
||||||
|
# browser reconnection, so authentication is always
|
||||||
|
# in-band for this endpoint.
|
||||||
SocketEndpoint(
|
SocketEndpoint(
|
||||||
endpoint_path="/api/v1/socket",
|
endpoint_path="/api/v1/socket",
|
||||||
auth=auth,
|
auth=auth,
|
||||||
dispatcher=dispatcher_manager.dispatch_socket(),
|
dispatcher=dispatcher_manager.dispatch_socket(),
|
||||||
capability=AUTHENTICATED,
|
capability=AUTHENTICATED, # informational only; bypassed
|
||||||
|
in_band_auth=True,
|
||||||
),
|
),
|
||||||
|
|
||||||
# Per-flow request/response services — capability per kind.
|
# Per-flow request/response services — capability per kind.
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,24 @@ class SocketEndpoint:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, endpoint_path, auth, dispatcher, capability,
|
self, endpoint_path, auth, dispatcher, capability,
|
||||||
|
in_band_auth=False,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
``in_band_auth=True`` skips the handshake-time auth check.
|
||||||
|
The WebSocket handshake always succeeds; the dispatcher is
|
||||||
|
expected to gate itself via the first-frame auth protocol
|
||||||
|
(see ``Mux``).
|
||||||
|
|
||||||
|
This avoids the browser problem where a 401 on the handshake
|
||||||
|
is treated as permanent and prevents reconnection, and lets
|
||||||
|
long-lived sockets refresh their credential mid-session by
|
||||||
|
sending a new auth frame.
|
||||||
|
"""
|
||||||
|
|
||||||
self.path = endpoint_path
|
self.path = endpoint_path
|
||||||
self.auth = auth
|
self.auth = auth
|
||||||
self.capability = capability
|
self.capability = capability
|
||||||
|
self.in_band_auth = in_band_auth
|
||||||
|
|
||||||
self.dispatcher = dispatcher
|
self.dispatcher = dispatcher
|
||||||
|
|
||||||
|
|
@ -73,7 +86,7 @@ class SocketEndpoint:
|
||||||
The first-frame auth protocol described in the IAM spec is
|
The first-frame auth protocol described in the IAM spec is
|
||||||
a future upgrade."""
|
a future upgrade."""
|
||||||
|
|
||||||
if self.capability != PUBLIC:
|
if not self.in_band_auth and self.capability != PUBLIC:
|
||||||
token = request.query.get("token", "")
|
token = request.query.get("token", "")
|
||||||
if not token:
|
if not token:
|
||||||
return auth_failure()
|
return auth_failure()
|
||||||
|
|
|
||||||
|
|
@ -118,6 +118,7 @@ class Api:
|
||||||
config_receiver = self.config_receiver,
|
config_receiver = self.config_receiver,
|
||||||
prefix = "gateway",
|
prefix = "gateway",
|
||||||
queue_overrides = queue_overrides,
|
queue_overrides = queue_overrides,
|
||||||
|
auth = self.auth,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.endpoint_manager = EndpointManager(
|
self.endpoint_manager = EndpointManager(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue