diff --git a/trustgraph-cli/trustgraph/cli/bootstrap_iam.py b/trustgraph-cli/trustgraph/cli/bootstrap_iam.py index df282984..99a789e2 100644 --- a/trustgraph-cli/trustgraph/cli/bootstrap_iam.py +++ b/trustgraph-cli/trustgraph/cli/bootstrap_iam.py @@ -16,21 +16,21 @@ import sys import requests default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") -default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def bootstrap(url, token): +def bootstrap(url): - endpoint = url.rstrip("/") + "/api/v1/iam" + # Unauthenticated public endpoint — IAM refuses the bootstrap + # operation unless the service is running in bootstrap mode with + # empty tables, so the safety gate lives on the server side. + endpoint = url.rstrip("/") + "/api/v1/auth/bootstrap" headers = {"Content-Type": "application/json"} - if token: - headers["Authorization"] = f"Bearer {token}" resp = requests.post( endpoint, headers=headers, - data=json.dumps({"operation": "bootstrap"}), + data=json.dumps({}), ) if resp.status_code != 200: @@ -71,16 +71,11 @@ def main(): default=default_url, help=f"API URL (default: {default_url})", ) - parser.add_argument( - "-t", "--token", - default=default_token, - help="Gateway bearer token (default: $TRUSTGRAPH_TOKEN)", - ) args = parser.parse_args() try: - user_id, api_key = bootstrap(args.api_url, args.token) + user_id, api_key = bootstrap(args.api_url) except Exception as e: print("Exception:", e, file=sys.stderr, flush=True) sys.exit(1) diff --git a/trustgraph-flow/trustgraph/gateway/auth.py b/trustgraph-flow/trustgraph/gateway/auth.py index a693ca32..95743261 100644 --- a/trustgraph-flow/trustgraph/gateway/auth.py +++ b/trustgraph-flow/trustgraph/gateway/auth.py @@ -1,22 +1,264 @@ +""" +IAM-backed authentication for the API gateway. -class Authenticator: +Replaces the legacy GATEWAY_SECRET shared-token Authenticator. 
The +gateway is now stateless with respect to credentials: it either +verifies a JWT locally using the active IAM signing public key, or +resolves an API key by hash with a short local cache backed by the +IAM service. - def __init__(self, token=None, allow_all=False): +Identity returned by authenticate() is the (user_id, workspace, +roles) triple the rest of the gateway — capability checks, workspace +resolver, audit logging — needs. +""" - if not allow_all and token is None: - raise RuntimeError("Need a token") +import asyncio +import base64 +import hashlib +import json +import logging +import time +import uuid +from dataclasses import dataclass - if not allow_all and token == "": - raise RuntimeError("Need a token") +from aiohttp import web - self.token = token - self.allow_all = allow_all +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 - def permitted(self, token, roles): +from ..base.iam_client import IamClient +from ..base.metrics import ProducerMetrics, SubscriberMetrics +from ..schema import ( + IamRequest, IamResponse, + iam_request_queue, iam_response_queue, +) - if self.allow_all: return True +logger = logging.getLogger("auth") - if self.token != token: return False +API_KEY_CACHE_TTL = 60 # seconds - return True +@dataclass +class Identity: + user_id: str + workspace: str + roles: list + source: str # "api-key" | "jwt" + + +def _auth_failure(): + return web.HTTPUnauthorized( + text='{"error":"auth failure"}', + content_type="application/json", + ) + + +def _access_denied(): + return web.HTTPForbidden( + text='{"error":"access denied"}', + content_type="application/json", + ) + + +def _b64url_decode(s): + pad = "=" * (-len(s) % 4) + return base64.urlsafe_b64decode(s + pad) + + +def _verify_jwt_eddsa(token, public_pem): + """Verify an Ed25519 JWT and return its claims. Raises on any + validation failure. 
Refuses non-EdDSA algorithms.""" + parts = token.split(".") + if len(parts) != 3: + raise ValueError("malformed JWT") + h_b64, p_b64, s_b64 = parts + signing_input = f"{h_b64}.{p_b64}".encode("ascii") + header = json.loads(_b64url_decode(h_b64)) + if header.get("alg") != "EdDSA": + raise ValueError(f"unsupported alg: {header.get('alg')!r}") + + key = serialization.load_pem_public_key(public_pem.encode("ascii")) + if not isinstance(key, ed25519.Ed25519PublicKey): + raise ValueError("public key is not Ed25519") + + signature = _b64url_decode(s_b64) + key.verify(signature, signing_input) # raises InvalidSignature + + claims = json.loads(_b64url_decode(p_b64)) + exp = claims.get("exp") + if exp is None or exp < time.time(): + raise ValueError("expired") + return claims + + +class IamAuth: + """Resolves bearer credentials via the IAM service. + + Used by every gateway endpoint that needs authentication. Fetches + the IAM signing public key at startup (cached in memory). API + keys are resolved via the IAM service with a local hash→identity + cache (short TTL so revoked keys stop working within the TTL + window without any push mechanism).""" + + def __init__(self, backend, id="api-gateway"): + self.backend = backend + self.id = id + + # Populated at start() via IAM. + self._signing_public_pem = None + + # API-key cache: plaintext_sha256_hex -> (Identity, expires_ts) + self._key_cache = {} + self._key_cache_lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # Short-lived client helper. Mirrors the pattern used by the + # bootstrap framework and AsyncProcessor: a fresh uuid suffix per + # invocation so Pulsar exclusive subscriptions don't collide with + # ghosts from prior calls. 
+ # ------------------------------------------------------------------ + + def _make_client(self): + rr_id = str(uuid.uuid4()) + return IamClient( + backend=self.backend, + subscription=f"{self.id}--iam--{rr_id}", + consumer_name=self.id, + request_topic=iam_request_queue, + request_schema=IamRequest, + request_metrics=ProducerMetrics( + processor=self.id, flow=None, name="iam-request", + ), + response_topic=iam_response_queue, + response_schema=IamResponse, + response_metrics=SubscriberMetrics( + processor=self.id, flow=None, name="iam-response", + ), + ) + + async def _with_client(self, op): + """Open a short-lived IamClient, run ``op(client)``, close.""" + client = self._make_client() + await client.start() + try: + return await op(client) + finally: + try: + await client.stop() + except Exception: + pass + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def start(self, max_retries=30, retry_delay=2.0): + """Fetch the signing public key from IAM. Retries on + failure — the gateway may be starting before IAM is ready.""" + + async def _fetch(client): + return await client.get_signing_key_public() + + for attempt in range(max_retries): + try: + pem = await self._with_client(_fetch) + if pem: + self._signing_public_pem = pem + logger.info( + "IamAuth: fetched IAM signing public key " + f"({len(pem)} bytes)" + ) + return + except Exception as e: + logger.info( + f"IamAuth: waiting for IAM signing key " + f"({type(e).__name__}: {e}); " + f"retry {attempt + 1}/{max_retries}" + ) + await asyncio.sleep(retry_delay) + + # Don't prevent startup forever. A later authenticate() call + # will try again via the JWT path. 
+ logger.warning( + "IamAuth: could not fetch IAM signing key at startup; " + "JWT validation will fail until it's available" + ) + + # ------------------------------------------------------------------ + # Authentication + # ------------------------------------------------------------------ + + async def authenticate(self, request): + """Extract and validate the Bearer credential from an HTTP + request. Returns an ``Identity``. Raises HTTPUnauthorized + (401 / "auth failure") on any failure mode — the caller + cannot distinguish missing / malformed / invalid / expired / + revoked credentials.""" + + header = request.headers.get("Authorization", "") + if not header.startswith("Bearer "): + raise _auth_failure() + token = header[len("Bearer "):].strip() + if not token: + raise _auth_failure() + + # API keys always start with "tg_". JWTs have two dots and + # no "tg_" prefix. Discriminate cheaply. + if token.startswith("tg_"): + return await self._resolve_api_key(token) + if token.count(".") == 2: + return self._verify_jwt(token) + raise _auth_failure() + + def _verify_jwt(self, token): + if not self._signing_public_pem: + raise _auth_failure() + try: + claims = _verify_jwt_eddsa(token, self._signing_public_pem) + except Exception as e: + logger.debug(f"JWT validation failed: {type(e).__name__}: {e}") + raise _auth_failure() + + sub = claims.get("sub", "") + ws = claims.get("workspace", "") + roles = list(claims.get("roles", [])) + if not sub or not ws: + raise _auth_failure() + + return Identity( + user_id=sub, workspace=ws, roles=roles, source="jwt", + ) + + async def _resolve_api_key(self, plaintext): + h = hashlib.sha256(plaintext.encode("utf-8")).hexdigest() + + cached = self._key_cache.get(h) + now = time.time() + if cached and cached[1] > now: + return cached[0] + + async with self._key_cache_lock: + cached = self._key_cache.get(h) + if cached and cached[1] > now: + return cached[0] + + try: + async def _call(client): + return await 
client.resolve_api_key(plaintext) + user_id, workspace, roles = await self._with_client(_call) + except Exception as e: + logger.debug( + f"API key resolution failed: " + f"{type(e).__name__}: {e}" + ) + raise _auth_failure() + + if not user_id or not workspace: + raise _auth_failure() + + identity = Identity( + user_id=user_id, workspace=workspace, + roles=list(roles), source="api-key", + ) + self._key_cache[h] = (identity, now + API_KEY_CACHE_TTL) + return identity diff --git a/trustgraph-flow/trustgraph/gateway/capabilities.py b/trustgraph-flow/trustgraph/gateway/capabilities.py new file mode 100644 index 00000000..5413a4b1 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/capabilities.py @@ -0,0 +1,163 @@ +""" +Capability vocabulary and OSS role bundles. + +See docs/tech-specs/capabilities.md for the authoritative description. +The mapping below is the data form of the OSS bundle table in that +spec. Enterprise editions may replace this module with their own +role table; the vocabulary (capability strings) is shared. + +The module also exposes: + +- ``PUBLIC`` — a sentinel indicating an endpoint requires no + authentication (login, bootstrap). +- ``AUTHENTICATED`` — a sentinel indicating an endpoint requires a + valid identity but no specific capability (e.g. change-password). +- ``check(roles, capability)`` — the union-of-bundles membership test. +""" + +from aiohttp import web + + +PUBLIC = "__public__" +AUTHENTICATED = "__authenticated__" + + +# Capability vocabulary. Mirrors the "Capability list" tables in +# capabilities.md. Kept as a set of valid strings so the gateway can +# fail-closed on an endpoint that declares an unknown capability. 
def check(roles, capability):
    """Tell whether at least one of ``roles`` grants ``capability``.

    A capability outside ``KNOWN_CAPABILITIES`` is treated as an
    endpoint misconfiguration and always denied.  A role missing from
    ``ROLE_CAPABILITIES`` contributes no capabilities.
    """
    if capability not in KNOWN_CAPABILITIES:
        # Fail closed on a capability string nobody declared.
        return False
    return any(
        capability in ROLE_CAPABILITIES.get(role, ())
        for role in roles
    )
def enforce_workspace(data, identity):
    """Pin a request body to the caller's workspace.

    For a dict body: if it already names a non-empty workspace other
    than ``identity.workspace`` the request is rejected with the 403
    from ``access_denied()``; otherwise ``data["workspace"]`` is set to
    ``identity.workspace``.  A body that is not a dict is handed back
    untouched.

    Returns the same ``data`` object (mutated in place when a dict).
    """
    # Nothing to validate or inject on a non-dict body.
    if not isinstance(data, dict):
        return data

    asked_for = data.get("workspace", "")
    if asked_for and asked_for != identity.workspace:
        raise access_denied()

    data["workspace"] = identity.workspace
    return data
class AuthEndpoints:
    """Groups the three auth-surface handlers (login, bootstrap,
    change-password).  Each handler builds an IAM request dict and
    forwards it through the dispatcher passed in as ``iam_dispatcher``."""

    def __init__(self, iam_dispatcher, auth):
        self.iam = iam_dispatcher
        self.auth = auth

    async def start(self):
        pass

    def add_routes(self, app):
        app.add_routes([
            web.post("/api/v1/auth/login", self.login),
            web.post("/api/v1/auth/bootstrap", self.bootstrap),
            web.post(
                "/api/v1/auth/change-password",
                self.change_password,
            ),
        ])

    async def _forward(self, body):
        """Send ``body`` to the IAM dispatcher and return its reply."""
        # The dispatcher wants a responder callback; these endpoints
        # only use the returned value, so the callback is a no-op.
        async def responder(x, fin):
            pass
        return await self.iam.process(body, responder)

    @staticmethod
    async def _read_json_object(request):
        """Return the request body as a dict, or ``None`` when it is
        not valid JSON or not a JSON object (e.g. a list or string)."""
        try:
            body = await request.json()
        except Exception:
            return None
        return body if isinstance(body, dict) else None

    @staticmethod
    def _invalid_json():
        return web.json_response(
            {"error": "invalid json"}, status=400,
        )

    @staticmethod
    def _masked_failure():
        return web.json_response(
            {"error": "auth failure"}, status=401,
        )

    async def login(self, request):
        """Public.  Accepts {username, password, workspace?}.  Returns
        the IAM reply on success; a masked 401 "auth failure" whenever
        the reply carries an ``error`` key; 400 when the body is not a
        JSON object."""
        await enforce(request, self.auth, PUBLIC)
        body = await self._read_json_object(request)
        if body is None:
            return self._invalid_json()
        req = {
            "operation": "login",
            "username": body.get("username", ""),
            "password": body.get("password", ""),
            "workspace": body.get("workspace", ""),
        }
        resp = await self._forward(req)
        if "error" in resp:
            return self._masked_failure()
        return web.json_response(resp)

    async def bootstrap(self, request):
        """Public.  Forwards a ``bootstrap`` operation to IAM; the
        request body is ignored.  Any reply carrying an ``error`` key
        is returned as a masked 401 "auth failure"."""
        await enforce(request, self.auth, PUBLIC)
        resp = await self._forward({"operation": "bootstrap"})
        if "error" in resp:
            return self._masked_failure()
        return web.json_response(resp)

    async def change_password(self, request):
        """Authenticated (any role).  Accepts {current_password,
        new_password}; ``user_id`` is taken from the authenticated
        identity, never from the body, so a caller can only change
        their own password here.

        An IAM error of type ``auth-failed`` becomes a masked 401; any
        other IAM error becomes a 400 carrying the error message."""
        identity = await enforce(request, self.auth, AUTHENTICATED)
        body = await self._read_json_object(request)
        if body is None:
            return self._invalid_json()
        req = {
            "operation": "change-password",
            "user_id": identity.user_id,
            "password": body.get("current_password", ""),
            "new_password": body.get("new_password", ""),
        }
        resp = await self._forward(req)
        if "error" in resp:
            err = resp["error"]
            # The error is read as a {type, message} dict; tolerate a
            # bare string or None instead of crashing on .get().
            if isinstance(err, dict):
                err_type = err.get("type", "")
                message = err.get("message", "error")
            else:
                err_type = ""
                message = str(err) if err else "error"
            if err_type == "auth-failed":
                return self._masked_failure()
            return web.json_response({"error": message}, status=400)
        return web.json_response(resp)
capabilities import enforce, enforce_workspace + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) + class ConstantEndpoint: - def __init__(self, endpoint_path, auth, dispatcher): + def __init__(self, endpoint_path, auth, dispatcher, capability): self.path = endpoint_path - self.auth = auth - self.operation = "service" - + self.capability = capability self.dispatcher = dispatcher async def start(self): pass def add_routes(self, app): - app.add_routes([ web.post(self.path, self.handle), ]) @@ -31,22 +30,14 @@ class ConstantEndpoint: logger.debug(f"Processing request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + identity = await enforce(request, self.auth, self.capability) try: - data = await request.json() + if identity is not None: + enforce_workspace(data, identity) + async def responder(x, fin): pass @@ -54,10 +45,8 @@ class ConstantEndpoint: return web.json_response(resp) + except web.HTTPException: + raise except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py b/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py index b949a499..f28f293d 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py @@ -4,16 +4,18 @@ from aiohttp import web from trustgraph.i18n import get_language_pack +from .. 
capabilities import enforce + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) class I18nPackEndpoint: - def __init__(self, endpoint_path: str, auth): + def __init__(self, endpoint_path: str, auth, capability): self.path = endpoint_path self.auth = auth - self.operation = "service" + self.capability = capability async def start(self): pass @@ -26,26 +28,13 @@ class I18nPackEndpoint: async def handle(self, request): logger.debug(f"Processing i18n pack request: {request.path}") - token = "" - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except Exception: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + await enforce(request, self.auth, self.capability) lang = request.match_info.get("lang") or "en" - # This is a path traversal defense, and is a critical sec defense. - # Do not remove! + # Path-traversal defense — critical, do not remove. if "/" in lang or ".." in lang: return web.HTTPBadRequest(reason="Invalid language code") pack = get_language_pack(lang) - return web.json_response(pack) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py index fb8b0b76..472ee2fd 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py @@ -8,72 +8,278 @@ from . variable_endpoint import VariableEndpoint from . socket import SocketEndpoint from . metrics import MetricsEndpoint from . i18n import I18nPackEndpoint +from . auth_endpoints import AuthEndpoints + +from .. capabilities import PUBLIC, AUTHENTICATED from .. dispatch.manager import DispatcherManager + +# Capability required for each kind on the /api/v1/{kind} generic +# endpoint (global services). 
Coarse gating — the IAM bundle split +# of "read vs write" per admin subsystem is not applied here because +# this endpoint forwards an opaque operation in the body. Writes +# are the upper bound on what the endpoint can do, so we gate on +# the write/admin capability. +GLOBAL_KIND_CAPABILITY = { + "config": "config:write", + "flow": "flows:write", + "librarian": "documents:write", + "knowledge": "knowledge:write", + "collection-management": "collections:write", + # IAM endpoints land on /api/v1/iam and require the admin bundle. + # Login / bootstrap / change-password are served by + # AuthEndpoints, which handle their own gating (PUBLIC / + # AUTHENTICATED). + "iam": "users:admin", +} + + +# Capability required for each kind on the +# /api/v1/flow/{flow}/service/{kind} endpoint (per-flow data-plane). +FLOW_KIND_CAPABILITY = { + "agent": "agent", + "text-completion": "llm", + "prompt": "llm", + "mcp-tool": "mcp", + "graph-rag": "graph:read", + "document-rag": "documents:read", + "embeddings": "embeddings", + "graph-embeddings": "graph:read", + "document-embeddings": "documents:read", + "triples": "graph:read", + "rows": "rows:read", + "nlp-query": "rows:read", + "structured-query": "rows:read", + "structured-diag": "rows:read", + "row-embeddings": "rows:read", + "sparql": "graph:read", +} + + +# Capability for the streaming flow import/export endpoints, +# keyed by the "kind" URL segment. +FLOW_IMPORT_CAPABILITY = { + "triples": "graph:write", + "graph-embeddings": "graph:write", + "document-embeddings": "documents:write", + "entity-contexts": "documents:write", + "rows": "rows:write", +} + +FLOW_EXPORT_CAPABILITY = { + "triples": "graph:read", + "graph-embeddings": "graph:read", + "document-embeddings": "documents:read", + "entity-contexts": "documents:read", +} + + +from .. 
class _RoutedSocketEndpoint:
    """WebSocket endpoint whose required capability is looked up per
    request from the URL's ``kind`` parameter.  Used for the flow
    import/export streaming endpoints."""

    def __init__(self, endpoint_path, auth, dispatcher, capability_map):
        self.path = endpoint_path
        self.auth = auth
        self.dispatcher = dispatcher
        self._capability_map = capability_map

    async def start(self):
        pass

    def add_routes(self, app):
        app.add_routes([web.get(self.path, self.handle)])

    async def handle(self, request):
        """Resolve the capability for ``kind`` and hand the request to
        a ``SocketEndpoint`` bound to that capability.

        Returns a 404 JSON response when ``kind`` has no entry in the
        capability map.
        """
        kind = request.match_info.get("kind", "")
        cap = self._capability_map.get(kind)
        if cap is None:
            return web.json_response(
                {"error": "unknown kind"}, status=404,
            )

        # SocketEndpoint.handle already does the ?token= check, the
        # authentication and the capability check for a capability
        # string, so they are not repeated here; repeating them
        # resolved every credential twice per handshake.
        # NOTE(review): this relies on the map holding only capability
        # strings, never the PUBLIC / AUTHENTICATED sentinels.
        # A fresh instance per request avoids mutating shared state
        # across concurrent requests.
        ws_ep = SocketEndpoint(
            endpoint_path=self.path,
            auth=self.auth,
            dispatcher=self.dispatcher,
            capability=cap,
        )
        return await ws_ep.handle(request)
Must come + # before the generic /api/v1/{kind} routes to win the + # match for /api/v1/auth/* paths. aiohttp routes in + # registration order, so we prepend here. + AuthEndpoints( + iam_dispatcher=dispatcher_manager.dispatch_auth_iam(), + auth=auth, + ), + I18nPackEndpoint( - endpoint_path = "/api/v1/i18n/packs/{lang}", - auth = auth, + endpoint_path="/api/v1/i18n/packs/{lang}", + auth=auth, + capability=PUBLIC, ), MetricsEndpoint( - endpoint_path = "/api/metrics", - prometheus_url = prometheus_url, - auth = auth, + endpoint_path="/api/metrics", + prometheus_url=prometheus_url, + auth=auth, + capability="metrics:read", ), - VariableEndpoint( - endpoint_path = "/api/v1/{kind}", auth = auth, - dispatcher = dispatcher_manager.dispatch_global_service(), + + # Global services: capability chosen per-kind. + _RoutedVariableEndpoint( + endpoint_path="/api/v1/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_global_service(), + capability_map=GLOBAL_KIND_CAPABILITY, ), + SocketEndpoint( - endpoint_path = "/api/v1/socket", - auth = auth, - dispatcher = dispatcher_manager.dispatch_socket() + endpoint_path="/api/v1/socket", + auth=auth, + dispatcher=dispatcher_manager.dispatch_socket(), + capability=AUTHENTICATED, ), - VariableEndpoint( - endpoint_path = "/api/v1/flow/{flow}/service/{kind}", - auth = auth, - dispatcher = dispatcher_manager.dispatch_flow_service(), + + # Per-flow request/response services — capability per kind. + _RoutedVariableEndpoint( + endpoint_path="/api/v1/flow/{flow}/service/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_flow_service(), + capability_map=FLOW_KIND_CAPABILITY, ), - SocketEndpoint( - endpoint_path = "/api/v1/flow/{flow}/import/{kind}", - auth = auth, - dispatcher = dispatcher_manager.dispatch_flow_import() + + # Per-flow streaming import/export — capability per kind. 
+ _RoutedSocketEndpoint( + endpoint_path="/api/v1/flow/{flow}/import/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_flow_import(), + capability_map=FLOW_IMPORT_CAPABILITY, ), - SocketEndpoint( - endpoint_path = "/api/v1/flow/{flow}/export/{kind}", - auth = auth, - dispatcher = dispatcher_manager.dispatch_flow_export() + _RoutedSocketEndpoint( + endpoint_path="/api/v1/flow/{flow}/export/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_flow_export(), + capability_map=FLOW_EXPORT_CAPABILITY, + ), + + StreamEndpoint( + endpoint_path="/api/v1/import-core", + auth=auth, + method="POST", + dispatcher=dispatcher_manager.dispatch_core_import(), + # Cross-subject import — require the admin bundle via a + # single representative capability. + capability="users:admin", ), StreamEndpoint( - endpoint_path = "/api/v1/import-core", - auth = auth, - method = "POST", - dispatcher = dispatcher_manager.dispatch_core_import(), + endpoint_path="/api/v1/export-core", + auth=auth, + method="GET", + dispatcher=dispatcher_manager.dispatch_core_export(), + capability="users:admin", ), StreamEndpoint( - endpoint_path = "/api/v1/export-core", - auth = auth, - method = "GET", - dispatcher = dispatcher_manager.dispatch_core_export(), - ), - StreamEndpoint( - endpoint_path = "/api/v1/document-stream", - auth = auth, - method = "GET", - dispatcher = dispatcher_manager.dispatch_document_stream(), + endpoint_path="/api/v1/document-stream", + auth=auth, + method="GET", + dispatcher=dispatcher_manager.dispatch_document_stream(), + capability="documents:read", ), ] @@ -84,4 +290,3 @@ class EndpointManager: async def start(self): for ep in self.endpoints: await ep.start() - diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py b/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py index 903a199c..6832d1e3 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py @@ -10,17 +10,19 @@ import 
asyncio import uuid import logging +from .. capabilities import enforce + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) class MetricsEndpoint: - def __init__(self, prometheus_url, endpoint_path, auth): + def __init__(self, prometheus_url, endpoint_path, auth, capability): self.prometheus_url = prometheus_url self.path = endpoint_path self.auth = auth - self.operation = "service" + self.capability = capability async def start(self): pass @@ -35,17 +37,7 @@ class MetricsEndpoint: logger.debug(f"Processing metrics request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + await enforce(request, self.auth, self.capability) path = request.match_info["path"] url = ( diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py index 9065761c..e3decbd2 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py @@ -4,6 +4,9 @@ from aiohttp import web, WSMsgType import logging from .. running import Running +from .. 
capabilities import ( + PUBLIC, AUTHENTICATED, check, auth_failure, access_denied, +) logger = logging.getLogger("socket") logger.setLevel(logging.INFO) @@ -11,12 +14,12 @@ logger.setLevel(logging.INFO) class SocketEndpoint: def __init__( - self, endpoint_path, auth, dispatcher, + self, endpoint_path, auth, dispatcher, capability, ): self.path = endpoint_path self.auth = auth - self.operation = "socket" + self.capability = capability self.dispatcher = dispatcher @@ -61,15 +64,29 @@ class SocketEndpoint: raise async def handle(self, request): - """Enhanced handler with better cleanup""" - try: - token = request.query['token'] - except: - token = "" + """Enhanced handler with better cleanup. + + Auth: WebSocket clients pass the bearer token on the + ``?token=...`` query string; we wrap it into a synthetic + Authorization header before delegating to the standard auth + path so the IAM-backed flow (JWT / API key) applies uniformly. + The first-frame auth protocol described in the IAM spec is + a future upgrade.""" + + if self.capability != PUBLIC: + token = request.query.get("token", "") + if not token: + return auth_failure() + try: + identity = await self.auth.authenticate( + _QueryTokenRequest(token) + ) + except web.HTTPException as e: + return e + if self.capability != AUTHENTICATED: + if not check(identity.roles, self.capability): + return access_denied() - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() - # 50MB max message size ws = web.WebSocketResponse(max_msg_size=52428800) @@ -150,3 +167,11 @@ class SocketEndpoint: web.get(self.path, self.handle), ]) + +class _QueryTokenRequest: + """Minimal shim that exposes headers["Authorization"] to + IamAuth.authenticate(), derived from a query-string token.""" + + def __init__(self, token): + self.headers = {"Authorization": f"Bearer {token}"} + diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py 
index 38d8846f..7b0c4692 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py @@ -1,82 +1,64 @@ -import asyncio -from aiohttp import web import logging +from aiohttp import web + +from .. capabilities import enforce + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) + class StreamEndpoint: - def __init__(self, endpoint_path, auth, dispatcher, method="POST"): - + def __init__( + self, endpoint_path, auth, dispatcher, capability, method="POST", + ): self.path = endpoint_path - self.auth = auth - self.operation = "service" + self.capability = capability self.method = method - self.dispatcher = dispatcher async def start(self): pass def add_routes(self, app): - if self.method == "POST": - app.add_routes([ - web.post(self.path, self.handle), - ]) + app.add_routes([web.post(self.path, self.handle)]) elif self.method == "GET": - app.add_routes([ - web.get(self.path, self.handle), - ]) + app.add_routes([web.get(self.path, self.handle)]) else: - raise RuntimeError("Bad method" + self.method) + raise RuntimeError("Bad method " + self.method) async def handle(self, request): logger.debug(f"Processing request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + await enforce(request, self.auth, self.capability) try: - data = request.content async def error(err): - return web.HTTPInternalServerError(text = err) + return web.HTTPInternalServerError(text=err) async def ok( - status=200, reason="OK", type="application/octet-stream" + status=200, reason="OK", + type="application/octet-stream", ): response = web.StreamResponse( - status = status, reason = reason, - headers = {"Content-Type": type} + status=status, reason=reason, + 
headers={"Content-Type": type}, ) await response.prepare(request) return response - resp = await self.dispatcher.process( - data, error, ok, request - ) - + resp = await self.dispatcher.process(data, error, ok, request) return resp + except web.HTTPException: + raise except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py index 608de71b..5e0d9d21 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py @@ -1,27 +1,27 @@ -import asyncio -from aiohttp import web import logging +from aiohttp import web + +from .. capabilities import enforce, enforce_workspace + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) + class VariableEndpoint: - def __init__(self, endpoint_path, auth, dispatcher): + def __init__(self, endpoint_path, auth, dispatcher, capability): self.path = endpoint_path - self.auth = auth - self.operation = "service" - + self.capability = capability self.dispatcher = dispatcher async def start(self): pass def add_routes(self, app): - app.add_routes([ web.post(self.path, self.handle), ]) @@ -30,35 +30,25 @@ class VariableEndpoint: logger.debug(f"Processing request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + identity = await enforce(request, self.auth, self.capability) try: - data = await request.json() + if identity is not None: + enforce_workspace(data, identity) + async def responder(x, fin): pass resp = await 
self.dispatcher.process( - data, responder, request.match_info + data, responder, request.match_info, ) return web.json_response(resp) + except web.HTTPException: + raise except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index 4e465bf7..bbe42908 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -12,7 +12,7 @@ import os from trustgraph.base.logging import setup_logging, add_logging_args from trustgraph.base.pubsub import get_pubsub, add_pubsub_args -from . auth import Authenticator +from . auth import IamAuth from . config.receiver import ConfigReceiver from . dispatch.manager import DispatcherManager @@ -35,7 +35,6 @@ default_prometheus_url = os.getenv("PROMETHEUS_URL", "http://prometheus:9090") default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) default_timeout = 600 default_port = 8088 -default_api_token = os.getenv("GATEWAY_SECRET", "") class Api: @@ -60,13 +59,14 @@ class Api: if not self.prometheus_url.endswith("/"): self.prometheus_url += "/" - api_token = config.get("api_token", default_api_token) - - # Token not set, or token equal empty string means no auth - if api_token: - self.auth = Authenticator(token=api_token) - else: - self.auth = Authenticator(allow_all=True) + # IAM-backed authentication. The legacy GATEWAY_SECRET + # shared-token path has been removed — there is no + # "open for everyone" fallback. The gateway cannot + # authenticate any request until IAM is reachable. 
+ self.auth = IamAuth( + backend=self.pubsub_backend, + id=config.get("id", "api-gateway"), + ) self.config_receiver = ConfigReceiver(self.pubsub_backend) @@ -132,12 +132,18 @@ class Api: ] async def app_factory(self): - + self.app = web.Application( middlewares=[], client_max_size=256 * 1024 * 1024 ) + # Fetch IAM signing public key before accepting traffic. + # Blocks for a bounded retry window; the gateway starts even + # if IAM is still unreachable (JWT validation will 401 until + # the key is available). + await self.auth.start() + await self.config_receiver.start() for ep in self.endpoints: @@ -189,12 +195,6 @@ def run(): help=f'API request timeout in seconds (default: {default_timeout})', ) - parser.add_argument( - '--api-token', - default=default_api_token, - help=f'Secret API token (default: no auth)', - ) - add_logging_args(parser) parser.add_argument(