Add JWT login support

This commit is contained in:
Cyber MacGeddon 2026-04-23 13:47:49 +01:00
parent 0ca0f9999c
commit 7be781b6e2
2 changed files with 210 additions and 1 deletions

View file

@ -6,14 +6,19 @@ See docs/tech-specs/iam-protocol.md for the wire-level contract and
docs/tech-specs/iam.md for the surrounding architecture.
"""
import asyncio
import base64
import datetime
import hashlib
import json
import logging
import os
import secrets
import uuid
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from trustgraph.schema import (
IamResponse, Error,
UserRecord, WorkspaceRecord, ApiKeyRecord,
@ -32,6 +37,10 @@ PBKDF2_ITERATIONS = 600_000
API_KEY_PREFIX = "tg_"
API_KEY_RANDOM_BYTES = 24
JWT_ISSUER = "trustgraph-iam"
JWT_TTL_SECONDS = 3600
RSA_KEY_SIZE = 2048
def _now_iso():
return datetime.datetime.now(datetime.timezone.utc).isoformat()
@ -114,12 +123,58 @@ def _parse_expires(s):
return None
def _b64url(data):
"""URL-safe base64 encode without padding, as required by JWT."""
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def _generate_signing_keypair():
"""Return (kid, private_pem, public_pem) for a fresh RSA keypair."""
key = rsa.generate_private_key(
public_exponent=65537, key_size=RSA_KEY_SIZE,
)
private_pem = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
).decode("ascii")
public_pem = key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
).decode("ascii")
kid = uuid.uuid4().hex[:16]
return kid, private_pem, public_pem
def _sign_jwt(kid, private_pem, claims):
"""Produce a compact-serialisation RS256 JWT for ``claims``."""
header = {"alg": "RS256", "typ": "JWT", "kid": kid}
header_b = _b64url(json.dumps(
header, separators=(",", ":"), sort_keys=True,
).encode("utf-8"))
payload_b = _b64url(json.dumps(
claims, separators=(",", ":"), sort_keys=True,
).encode("utf-8"))
signing_input = f"{header_b}.{payload_b}".encode("ascii")
key = serialization.load_pem_private_key(
private_pem.encode("ascii"), password=None,
)
signature = key.sign(signing_input, padding.PKCS1v15(), hashes.SHA256())
return f"{header_b}.{payload_b}.{_b64url(signature)}"
class IamService:
def __init__(self, host, username, password, keyspace):
self.table_store = IamTableStore(
host, username, password, keyspace,
)
# Active signing key cache: (kid, private_pem, public_pem) or
# None. Loaded lazily on first use; refreshed whenever a key
# is created.
self._signing_key = None
self._signing_key_lock = asyncio.Lock()
# ------------------------------------------------------------------
# Dispatch
@ -143,6 +198,10 @@ class IamService:
return await self.handle_list_api_keys(v)
if op == "revoke-api-key":
return await self.handle_revoke_api_key(v)
if op == "login":
return await self.handle_login(v)
if op == "get-signing-key-public":
return await self.handle_get_signing_key_public(v)
return _err(
"invalid-argument",
@ -251,9 +310,23 @@ class IamService:
last_used=None,
)
# Initial JWT signing key.
kid, private_pem, public_pem = _generate_signing_keypair()
await self.table_store.put_signing_key(
kid=kid,
private_pem=private_pem,
public_pem=public_pem,
created=now,
retired=None,
)
# Populate cache so login calls in this process don't go
# back to Cassandra on first use.
self._signing_key = (kid, private_pem, public_pem)
logger.info(
f"IAM bootstrap: created workspace={DEFAULT_WORKSPACE!r}, "
f"admin user_id={admin_user_id}, initial API key issued"
f"admin user_id={admin_user_id}, initial API key issued, "
f"signing key kid={kid}"
)
return IamResponse(
@ -261,6 +334,119 @@ class IamService:
bootstrap_admin_api_key=plaintext,
)
# ------------------------------------------------------------------
# Signing key helpers
# ------------------------------------------------------------------
async def _get_active_signing_key(self):
"""Return ``(kid, private_pem, public_pem)`` for the active
signing key. Loads from Cassandra on first call. Generates
and persists a new key if none exists covers the case where
``login`` is called before ``bootstrap`` (shouldn't happen in
practice but keeps the service internally consistent)."""
if self._signing_key is not None:
return self._signing_key
async with self._signing_key_lock:
if self._signing_key is not None:
return self._signing_key
rows = await self.table_store.list_signing_keys()
active = [r for r in rows if r[4] is None]
if active:
row = active[0]
self._signing_key = (row[0], row[1], row[2])
logger.info(
f"IAM: loaded active signing key kid={row[0]}"
)
return self._signing_key
kid, private_pem, public_pem = _generate_signing_keypair()
await self.table_store.put_signing_key(
kid=kid,
private_pem=private_pem,
public_pem=public_pem,
created=_now_dt(),
retired=None,
)
self._signing_key = (kid, private_pem, public_pem)
logger.info(
f"IAM: generated active signing key kid={kid} "
f"(no existing key found)"
)
return self._signing_key
# ------------------------------------------------------------------
# login
# ------------------------------------------------------------------
async def handle_login(self, v):
if not v.username:
return _err("auth-failed", "username required")
if not v.password:
return _err("auth-failed", "password required")
# Login accepts an optional workspace parameter. If omitted
# we use the default workspace (OSS single-workspace
# assumption). Multi-workspace enterprise editions swap in a
# resolver that looks across the caller's permitted set.
workspace = v.workspace or DEFAULT_WORKSPACE
user_id = await self.table_store.get_user_id_by_username(
workspace, v.username,
)
if not user_id:
return _err("auth-failed", "no such user")
user_row = await self.table_store.get_user(user_id)
if user_row is None:
return _err("auth-failed", "user disappeared")
(
id, ws, _username, _name, _email, password_hash,
roles, enabled, _mcp, _created,
) = user_row
if not enabled:
return _err("auth-failed", "user disabled")
if not password_hash or not _verify_password(
v.password, password_hash,
):
return _err("auth-failed", "bad credentials")
ws_row = await self.table_store.get_workspace(ws)
if ws_row is None or not ws_row[2]:
return _err("auth-failed", "workspace disabled")
kid, private_pem, _ = await self._get_active_signing_key()
now_ts = int(_now_dt().timestamp())
exp_ts = now_ts + JWT_TTL_SECONDS
claims = {
"iss": JWT_ISSUER,
"sub": id,
"workspace": ws,
"roles": sorted(roles) if roles else [],
"iat": now_ts,
"exp": exp_ts,
}
token = _sign_jwt(kid, private_pem, claims)
expires_iso = datetime.datetime.fromtimestamp(
exp_ts, tz=datetime.timezone.utc,
).isoformat()
return IamResponse(jwt=token, jwt_expires=expires_iso)
# ------------------------------------------------------------------
# get-signing-key-public
# ------------------------------------------------------------------
async def handle_get_signing_key_public(self, v):
_, _, public_pem = await self._get_active_signing_key()
return IamResponse(signing_key_public=public_pem)
# ------------------------------------------------------------------
# resolve-api-key
# ------------------------------------------------------------------