mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-03 12:22:37 +02:00
Add JWT login support
This commit is contained in:
parent
0ca0f9999c
commit
7be781b6e2
2 changed files with 210 additions and 1 deletions
|
|
@ -6,14 +6,19 @@ See docs/tech-specs/iam-protocol.md for the wire-level contract and
|
|||
docs/tech-specs/iam.md for the surrounding architecture.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import uuid
|
||||
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
|
||||
from trustgraph.schema import (
|
||||
IamResponse, Error,
|
||||
UserRecord, WorkspaceRecord, ApiKeyRecord,
|
||||
|
|
@ -32,6 +37,10 @@ PBKDF2_ITERATIONS = 600_000
|
|||
API_KEY_PREFIX = "tg_"
|
||||
API_KEY_RANDOM_BYTES = 24
|
||||
|
||||
JWT_ISSUER = "trustgraph-iam"
|
||||
JWT_TTL_SECONDS = 3600
|
||||
RSA_KEY_SIZE = 2048
|
||||
|
||||
|
||||
def _now_iso():
|
||||
return datetime.datetime.now(datetime.timezone.utc).isoformat()
|
||||
|
|
@ -114,12 +123,58 @@ def _parse_expires(s):
|
|||
return None
|
||||
|
||||
|
||||
def _b64url(data):
|
||||
"""URL-safe base64 encode without padding, as required by JWT."""
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def _generate_signing_keypair():
|
||||
"""Return (kid, private_pem, public_pem) for a fresh RSA keypair."""
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=RSA_KEY_SIZE,
|
||||
)
|
||||
private_pem = key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
).decode("ascii")
|
||||
public_pem = key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
).decode("ascii")
|
||||
kid = uuid.uuid4().hex[:16]
|
||||
return kid, private_pem, public_pem
|
||||
|
||||
|
||||
def _sign_jwt(kid, private_pem, claims):
|
||||
"""Produce a compact-serialisation RS256 JWT for ``claims``."""
|
||||
header = {"alg": "RS256", "typ": "JWT", "kid": kid}
|
||||
header_b = _b64url(json.dumps(
|
||||
header, separators=(",", ":"), sort_keys=True,
|
||||
).encode("utf-8"))
|
||||
payload_b = _b64url(json.dumps(
|
||||
claims, separators=(",", ":"), sort_keys=True,
|
||||
).encode("utf-8"))
|
||||
signing_input = f"{header_b}.{payload_b}".encode("ascii")
|
||||
|
||||
key = serialization.load_pem_private_key(
|
||||
private_pem.encode("ascii"), password=None,
|
||||
)
|
||||
signature = key.sign(signing_input, padding.PKCS1v15(), hashes.SHA256())
|
||||
return f"{header_b}.{payload_b}.{_b64url(signature)}"
|
||||
|
||||
|
||||
class IamService:
|
||||
|
||||
def __init__(self, host, username, password, keyspace):
|
||||
self.table_store = IamTableStore(
|
||||
host, username, password, keyspace,
|
||||
)
|
||||
# Active signing key cache: (kid, private_pem, public_pem) or
|
||||
# None. Loaded lazily on first use; refreshed whenever a key
|
||||
# is created.
|
||||
self._signing_key = None
|
||||
self._signing_key_lock = asyncio.Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Dispatch
|
||||
|
|
@ -143,6 +198,10 @@ class IamService:
|
|||
return await self.handle_list_api_keys(v)
|
||||
if op == "revoke-api-key":
|
||||
return await self.handle_revoke_api_key(v)
|
||||
if op == "login":
|
||||
return await self.handle_login(v)
|
||||
if op == "get-signing-key-public":
|
||||
return await self.handle_get_signing_key_public(v)
|
||||
|
||||
return _err(
|
||||
"invalid-argument",
|
||||
|
|
@ -251,9 +310,23 @@ class IamService:
|
|||
last_used=None,
|
||||
)
|
||||
|
||||
# Initial JWT signing key.
|
||||
kid, private_pem, public_pem = _generate_signing_keypair()
|
||||
await self.table_store.put_signing_key(
|
||||
kid=kid,
|
||||
private_pem=private_pem,
|
||||
public_pem=public_pem,
|
||||
created=now,
|
||||
retired=None,
|
||||
)
|
||||
# Populate cache so login calls in this process don't go
|
||||
# back to Cassandra on first use.
|
||||
self._signing_key = (kid, private_pem, public_pem)
|
||||
|
||||
logger.info(
|
||||
f"IAM bootstrap: created workspace={DEFAULT_WORKSPACE!r}, "
|
||||
f"admin user_id={admin_user_id}, initial API key issued"
|
||||
f"admin user_id={admin_user_id}, initial API key issued, "
|
||||
f"signing key kid={kid}"
|
||||
)
|
||||
|
||||
return IamResponse(
|
||||
|
|
@ -261,6 +334,119 @@ class IamService:
|
|||
bootstrap_admin_api_key=plaintext,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Signing key helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _get_active_signing_key(self):
|
||||
"""Return ``(kid, private_pem, public_pem)`` for the active
|
||||
signing key. Loads from Cassandra on first call. Generates
|
||||
and persists a new key if none exists — covers the case where
|
||||
``login`` is called before ``bootstrap`` (shouldn't happen in
|
||||
practice but keeps the service internally consistent)."""
|
||||
if self._signing_key is not None:
|
||||
return self._signing_key
|
||||
|
||||
async with self._signing_key_lock:
|
||||
if self._signing_key is not None:
|
||||
return self._signing_key
|
||||
|
||||
rows = await self.table_store.list_signing_keys()
|
||||
active = [r for r in rows if r[4] is None]
|
||||
|
||||
if active:
|
||||
row = active[0]
|
||||
self._signing_key = (row[0], row[1], row[2])
|
||||
logger.info(
|
||||
f"IAM: loaded active signing key kid={row[0]}"
|
||||
)
|
||||
return self._signing_key
|
||||
|
||||
kid, private_pem, public_pem = _generate_signing_keypair()
|
||||
await self.table_store.put_signing_key(
|
||||
kid=kid,
|
||||
private_pem=private_pem,
|
||||
public_pem=public_pem,
|
||||
created=_now_dt(),
|
||||
retired=None,
|
||||
)
|
||||
self._signing_key = (kid, private_pem, public_pem)
|
||||
logger.info(
|
||||
f"IAM: generated active signing key kid={kid} "
|
||||
f"(no existing key found)"
|
||||
)
|
||||
return self._signing_key
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# login
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def handle_login(self, v):
|
||||
if not v.username:
|
||||
return _err("auth-failed", "username required")
|
||||
if not v.password:
|
||||
return _err("auth-failed", "password required")
|
||||
|
||||
# Login accepts an optional workspace parameter. If omitted
|
||||
# we use the default workspace (OSS single-workspace
|
||||
# assumption). Multi-workspace enterprise editions swap in a
|
||||
# resolver that looks across the caller's permitted set.
|
||||
workspace = v.workspace or DEFAULT_WORKSPACE
|
||||
|
||||
user_id = await self.table_store.get_user_id_by_username(
|
||||
workspace, v.username,
|
||||
)
|
||||
if not user_id:
|
||||
return _err("auth-failed", "no such user")
|
||||
|
||||
user_row = await self.table_store.get_user(user_id)
|
||||
if user_row is None:
|
||||
return _err("auth-failed", "user disappeared")
|
||||
|
||||
(
|
||||
id, ws, _username, _name, _email, password_hash,
|
||||
roles, enabled, _mcp, _created,
|
||||
) = user_row
|
||||
|
||||
if not enabled:
|
||||
return _err("auth-failed", "user disabled")
|
||||
if not password_hash or not _verify_password(
|
||||
v.password, password_hash,
|
||||
):
|
||||
return _err("auth-failed", "bad credentials")
|
||||
|
||||
ws_row = await self.table_store.get_workspace(ws)
|
||||
if ws_row is None or not ws_row[2]:
|
||||
return _err("auth-failed", "workspace disabled")
|
||||
|
||||
kid, private_pem, _ = await self._get_active_signing_key()
|
||||
|
||||
now_ts = int(_now_dt().timestamp())
|
||||
exp_ts = now_ts + JWT_TTL_SECONDS
|
||||
claims = {
|
||||
"iss": JWT_ISSUER,
|
||||
"sub": id,
|
||||
"workspace": ws,
|
||||
"roles": sorted(roles) if roles else [],
|
||||
"iat": now_ts,
|
||||
"exp": exp_ts,
|
||||
}
|
||||
token = _sign_jwt(kid, private_pem, claims)
|
||||
|
||||
expires_iso = datetime.datetime.fromtimestamp(
|
||||
exp_ts, tz=datetime.timezone.utc,
|
||||
).isoformat()
|
||||
|
||||
return IamResponse(jwt=token, jwt_expires=expires_iso)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# get-signing-key-public
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def handle_get_signing_key_public(self, v):
|
||||
_, _, public_pem = await self._get_active_signing_key()
|
||||
return IamResponse(signing_key_public=public_pem)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# resolve-api-key
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue