From 7be781b6e2a62a7ba93c3827258e108934011998 Mon Sep 17 00:00:00 2001 From: Cyber MacGeddon Date: Thu, 23 Apr 2026 13:47:49 +0100 Subject: [PATCH] Add JWT login support --- trustgraph-base/trustgraph/base/iam_client.py | 23 +++ trustgraph-flow/trustgraph/iam/service/iam.py | 188 +++++++++++++++++- 2 files changed, 210 insertions(+), 1 deletion(-) diff --git a/trustgraph-base/trustgraph/base/iam_client.py b/trustgraph-base/trustgraph/base/iam_client.py index 887b37bc..da016eb2 100644 --- a/trustgraph-base/trustgraph/base/iam_client.py +++ b/trustgraph-base/trustgraph/base/iam_client.py @@ -112,6 +112,29 @@ class IamClient(RequestResponse): timeout=timeout, ) + async def login(self, username, password, workspace="", + timeout=IAM_TIMEOUT): + """Validate credentials and return ``(jwt, expires_iso)``. + ``workspace`` is optional; defaults at the server to the + OSS default workspace.""" + resp = await self._request( + operation="login", + workspace=workspace, + username=username, + password=password, + timeout=timeout, + ) + return resp.jwt, resp.jwt_expires + + async def get_signing_key_public(self, timeout=IAM_TIMEOUT): + """Return the active JWT signing public key in PEM. The + gateway calls this at startup and caches the result.""" + resp = await self._request( + operation="get-signing-key-public", + timeout=timeout, + ) + return resp.signing_key_public + class IamClientSpec(RequestResponseSpec): def __init__(self, request_name, response_name): diff --git a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py index 45bd01f6..023362f4 100644 --- a/trustgraph-flow/trustgraph/iam/service/iam.py +++ b/trustgraph-flow/trustgraph/iam/service/iam.py @@ -6,14 +6,19 @@ See docs/tech-specs/iam-protocol.md for the wire-level contract and docs/tech-specs/iam.md for the surrounding architecture. """ +import asyncio import base64 import datetime import hashlib +import json import logging import os import secrets import uuid +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa + from trustgraph.schema import ( IamResponse, Error, UserRecord, WorkspaceRecord, ApiKeyRecord, @@ -32,6 +37,10 @@ PBKDF2_ITERATIONS = 600_000 API_KEY_PREFIX = "tg_" API_KEY_RANDOM_BYTES = 24 +JWT_ISSUER = "trustgraph-iam" +JWT_TTL_SECONDS = 3600 +RSA_KEY_SIZE = 2048 + def _now_iso(): return datetime.datetime.now(datetime.timezone.utc).isoformat() @@ -114,12 +123,58 @@ def _parse_expires(s): return None +def _b64url(data): + """URL-safe base64 encode without padding, as required by JWT.""" + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _generate_signing_keypair(): + """Return (kid, private_pem, public_pem) for a fresh RSA keypair.""" + key = rsa.generate_private_key( + public_exponent=65537, key_size=RSA_KEY_SIZE, + ) + private_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode("ascii") + public_pem = key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode("ascii") + kid = uuid.uuid4().hex[:16] + return kid, private_pem, public_pem + + +def _sign_jwt(kid, private_pem, claims): + """Produce a compact-serialisation RS256 JWT for ``claims``.""" + header = {"alg": "RS256", "typ": "JWT", "kid": kid} + header_b = _b64url(json.dumps( + header, separators=(",", ":"), sort_keys=True, + ).encode("utf-8")) + payload_b = _b64url(json.dumps( + claims, separators=(",", ":"), sort_keys=True, + ).encode("utf-8")) + signing_input = f"{header_b}.{payload_b}".encode("ascii") + + key = serialization.load_pem_private_key( + private_pem.encode("ascii"), password=None, + ) + signature = key.sign(signing_input, padding.PKCS1v15(), hashes.SHA256()) + return f"{header_b}.{payload_b}.{_b64url(signature)}" + + class IamService: def __init__(self, host, username, password, keyspace): self.table_store = IamTableStore( host, username, password, keyspace, ) + # Active signing key cache: (kid, private_pem, public_pem) or + # None. Loaded lazily on first use; refreshed whenever a key + # is created. + self._signing_key = None + self._signing_key_lock = asyncio.Lock() # ------------------------------------------------------------------ # Dispatch @@ -143,6 +198,10 @@ class IamService: return await self.handle_list_api_keys(v) if op == "revoke-api-key": return await self.handle_revoke_api_key(v) + if op == "login": + return await self.handle_login(v) + if op == "get-signing-key-public": + return await self.handle_get_signing_key_public(v) return _err( "invalid-argument", @@ -251,9 +310,23 @@ class IamService: last_used=None, ) + # Initial JWT signing key. + kid, private_pem, public_pem = _generate_signing_keypair() + await self.table_store.put_signing_key( + kid=kid, + private_pem=private_pem, + public_pem=public_pem, + created=now, + retired=None, + ) + # Populate cache so login calls in this process don't go + # back to Cassandra on first use. + self._signing_key = (kid, private_pem, public_pem) + logger.info( f"IAM bootstrap: created workspace={DEFAULT_WORKSPACE!r}, " - f"admin user_id={admin_user_id}, initial API key issued" + f"admin user_id={admin_user_id}, initial API key issued, " + f"signing key kid={kid}" ) return IamResponse( @@ -261,6 +334,119 @@ class IamService: bootstrap_admin_api_key=plaintext, ) + # ------------------------------------------------------------------ + # Signing key helpers + # ------------------------------------------------------------------ + + async def _get_active_signing_key(self): + """Return ``(kid, private_pem, public_pem)`` for the active + signing key. Loads from Cassandra on first call. Generates + and persists a new key if none exists — covers the case where + ``login`` is called before ``bootstrap`` (shouldn't happen in + practice but keeps the service internally consistent).""" + if self._signing_key is not None: + return self._signing_key + + async with self._signing_key_lock: + if self._signing_key is not None: + return self._signing_key + + rows = await self.table_store.list_signing_keys() + active = [r for r in rows if r[4] is None] + + if active: + row = active[0] + self._signing_key = (row[0], row[1], row[2]) + logger.info( + f"IAM: loaded active signing key kid={row[0]}" + ) + return self._signing_key + + kid, private_pem, public_pem = _generate_signing_keypair() + await self.table_store.put_signing_key( + kid=kid, + private_pem=private_pem, + public_pem=public_pem, + created=_now_dt(), + retired=None, + ) + self._signing_key = (kid, private_pem, public_pem) + logger.info( + f"IAM: generated active signing key kid={kid} " + f"(no existing key found)" + ) + return self._signing_key + + # ------------------------------------------------------------------ + # login + # ------------------------------------------------------------------ + + async def handle_login(self, v): + if not v.username: + return _err("auth-failed", "username required") + if not v.password: + return _err("auth-failed", "password required") + + # Login accepts an optional workspace parameter. If omitted + # we use the default workspace (OSS single-workspace + # assumption). Multi-workspace enterprise editions swap in a + # resolver that looks across the caller's permitted set. + workspace = v.workspace or DEFAULT_WORKSPACE + + user_id = await self.table_store.get_user_id_by_username( + workspace, v.username, + ) + if not user_id: + return _err("auth-failed", "no such user") + + user_row = await self.table_store.get_user(user_id) + if user_row is None: + return _err("auth-failed", "user disappeared") + + ( + id, ws, _username, _name, _email, password_hash, + roles, enabled, _mcp, _created, + ) = user_row + + if not enabled: + return _err("auth-failed", "user disabled") + if not password_hash or not _verify_password( + v.password, password_hash, + ): + return _err("auth-failed", "bad credentials") + + ws_row = await self.table_store.get_workspace(ws) + if ws_row is None or not ws_row[2]: + return _err("auth-failed", "workspace disabled") + + kid, private_pem, _ = await self._get_active_signing_key() + + now_ts = int(_now_dt().timestamp()) + exp_ts = now_ts + JWT_TTL_SECONDS + claims = { + "iss": JWT_ISSUER, + "sub": id, + "workspace": ws, + "roles": sorted(roles) if roles else [], + "iat": now_ts, + "exp": exp_ts, + } + token = _sign_jwt(kid, private_pem, claims) + + expires_iso = datetime.datetime.fromtimestamp( + exp_ts, tz=datetime.timezone.utc, + ).isoformat() + + return IamResponse(jwt=token, jwt_expires=expires_iso) + + # ------------------------------------------------------------------ + # get-signing-key-public + # ------------------------------------------------------------------ + + async def handle_get_signing_key_public(self, v): + _, _, public_pem = await self._get_active_signing_key() + return IamResponse(signing_key_public=public_pem) + # ------------------------------------------------------------------ # resolve-api-key # ------------------------------------------------------------------