diff --git a/docs/tech-specs/capabilities.md b/docs/tech-specs/capabilities.md
new file mode 100644
index 00000000..60f5acbf
--- /dev/null
+++ b/docs/tech-specs/capabilities.md
@@ -0,0 +1,218 @@
+---
+layout: default
+title: "Capability Vocabulary Technical Specification"
+parent: "Tech Specs"
+---
+
+# Capability Vocabulary Technical Specification
+
+## Overview
+
+Authorisation in TrustGraph is **capability-based**. Every gateway
+endpoint maps to exactly one *capability*; a user's roles each grant
+a set of capabilities; an authenticated request is permitted when
+the required capability is a member of the union of the caller's
+role capability sets.
+
+This document defines the capability vocabulary — the closed list
+of capability strings that the gateway recognises — and the
+open-source edition's role bundles.
+
+The capability mechanism is shared between the open-source edition
+and potential third-party enterprise editions. The open-source
+edition ships a fixed three-role bundle (`reader`, `writer`,
+`admin`). Enterprise editions may define additional roles by
+composing their own capability bundles from the same vocabulary;
+no protocol, gateway, or backend-service change is required.
+
+## Motivation
+
+The original IAM spec used hierarchical "minimum role" checks
+(`admin` implies `writer` implies `reader`). That shape is simple
+but paints the role model into a corner: any enterprise need to
+grant a subset of admin abilities (a helpdesk that can reset
+passwords but not edit flows; an analyst who can query but not
+ingest) requires a protocol-level change.
+
+A capability vocabulary decouples "what a request needs" from
+"what roles a user has" and makes the role table pure data. The
+open-source bundles can stay coarse while the enterprise role
+table expands without any code change.
+
+## Design
+
+### Capability string format
+
+`<subject>:<verb>`, or `<subject>` alone for capabilities with no
+natural read/write split. All lowercase, kebab-case for
+multi-word subsystems.
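
As a non-normative illustration of the format rule, it can be pinned
down with a short validator; the regex and function name below are
ours, not part of the spec:

```python
import re

# Non-normative sketch of the capability string format:
# "<subject>" or "<subject>:<verb>", all lowercase, with
# kebab-case permitted for multi-word subsystems.
_SEGMENT = r"[a-z]+(?:-[a-z]+)*"
_CAPABILITY = re.compile(rf"{_SEGMENT}(?::{_SEGMENT})?")

def is_valid_capability(s: str) -> bool:
    """True if s is a well-formed capability string."""
    return _CAPABILITY.fullmatch(s) is not None
```

Under this reading, `graph:read`, `agent`, and `keys:self` are
well-formed, while `Graph:Read` and `graph:read:write` are rejected.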
+ +### Capability list + +**Data plane** + +| Capability | Covers | +|---|---| +| `agent` | agent (query-only; no write counterpart) | +| `graph:read` | graph-rag, graph-embeddings-query, triples-query, sparql, graph-embeddings-export, triples-export | +| `graph:write` | triples-import, graph-embeddings-import | +| `documents:read` | document-rag, document-embeddings-query, document-embeddings-export, entity-contexts-export, document-stream-export, library list / fetch | +| `documents:write` | document-embeddings-import, entity-contexts-import, text-load, document-load, library add / replace / delete | +| `rows:read` | rows-query, row-embeddings-query, nlp-query, structured-query, structured-diag | +| `rows:write` | rows-import | +| `llm` | text-completion, prompt (stateless invocation) | +| `embeddings` | Raw text-embedding service (stateless compute; typed-data embedding stores live under their data-subject capability) | +| `mcp` | mcp-tool | +| `collections:read` | List / describe collections | +| `collections:write` | Create / delete collections | +| `knowledge:read` | List / get knowledge cores | +| `knowledge:write` | Create / delete knowledge cores | + +**Control plane** + +| Capability | Covers | +|---|---| +| `config:read` | Read workspace config | +| `config:write` | Write workspace config | +| `flows:read` | List / describe flows, blueprints, flow classes | +| `flows:write` | Start / stop / update flows | +| `users:read` | List / get users within the workspace | +| `users:write` | Create / update / disable users within the workspace | +| `users:admin` | Assign / remove roles on users within the workspace | +| `keys:self` | Create / revoke / list **own** API keys | +| `keys:admin` | Create / revoke / list **any user's** API keys within the workspace | +| `workspaces:admin` | Create / delete / disable workspaces (system-level) | +| `iam:admin` | JWT signing-key rotation, IAM-level operations | +| `metrics:read` | Prometheus metrics proxy | + +### 
Open-source role bundles + +The open-source edition ships three roles: + +| Role | Capabilities | +|---|---| +| `reader` | `agent`, `graph:read`, `documents:read`, `rows:read`, `llm`, `embeddings`, `mcp`, `collections:read`, `knowledge:read`, `flows:read`, `config:read`, `keys:self` | +| `writer` | everything in `reader` **+** `graph:write`, `documents:write`, `rows:write`, `collections:write`, `knowledge:write` | +| `admin` | everything in `writer` **+** `config:write`, `flows:write`, `users:read`, `users:write`, `users:admin`, `keys:admin`, `workspaces:admin`, `iam:admin`, `metrics:read` | + +Open-source bundles are deliberately coarse. `workspaces:admin` and +`iam:admin` live inside `admin` without a separate role; a single +`admin` user holds the keys to the whole deployment. + +### The `agent` capability and composition + +The `agent` capability is granted independently of the capabilities +it composes under the hood (`llm`, `graph`, `documents`, `rows`, +`mcp`, etc.). A user holding `agent` but not `llm` can still cause +LLM invocations because the agent implementation chooses which +services to invoke on the caller's behalf. + +This is deliberate. A common policy is "allow controlled access +via the agent, deny raw model calls" — granting `agent` without +granting `llm` expresses exactly that. An administrator granting +`agent` should treat it as a grant of everything the agent +composes at deployment time. + +### Authorisation evaluation + +For a request bearing a resolved set of roles +`R = {r1, r2, ...}` against an endpoint that requires capability +`c`: + +``` +allow if c IN union(bundle(r) for r in R) +``` + +No hierarchy, no precedence, no role-order sensitivity. A user +with a single role is the common case; a user with multiple roles +gets the union of their bundles. + +### Enforcement boundary + +Capability checks — and authentication — are applied **only at the +API gateway**, on requests arriving from external callers. 
+Operations originating inside the platform (backend service to +backend service, agent to LLM, flow-svc to config-svc, bootstrap +initialisers, scheduled reconcilers, autonomous flow steps) are +**not capability-checked**. Backend services trust the workspace +set by the gateway on inbound pub/sub messages and trust +internally-originated messages without further authorisation. + +This policy has four consequences that are part of the spec, not +accidents of implementation: + +1. **The gateway is the single trust boundary for user + authorisation.** Every backend service is a downstream consumer + of an already-authorised workspace scope. +2. **Pub/sub carries workspace, not user identity.** Messages on + the bus do not carry credentials or the identity that originated + a request; they carry the resolved workspace only. This keeps + the bus protocol free of secrets and aligns with the workspace + resolver's role as the gateway-side narrowing step. +3. **Composition is transitive.** Granting a capability that the + platform composes internally (for example, `agent`) transitively + grants everything that capability composes under the hood, + because the downstream calls are internal-origin and are not + re-checked. The composite nature of `agent` described above is + a consequence of this policy, not a special case. +4. **Internal-origin operations have no user.** Bootstrap, + reconcilers, and other platform-initiated work act with + system-level authority. The workspace field on such messages + identifies which workspace's data is being touched, not who + asked. + +**Trust model.** Whoever has pub/sub access is implicitly trusted +to act as any workspace. Defense-in-depth within the backend is +not part of this design; the security perimeter is the gateway +and the bus itself (TLS / network isolation between the bus and +any untrusted network). 
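
The evaluation rule and role bundles above can be sketched in a few
lines. This is a non-normative illustration: the bundle contents are
trimmed, and the function name is ours, not part of the spec.

```python
# Non-normative sketch of the gateway-side authorisation check.
# The role table is pure data; a request is allowed when the
# required capability is in the union of the caller's role bundles.
READER = {"agent", "graph:read", "documents:read", "rows:read",
          "llm", "keys:self"}                      # trimmed bundle
WRITER = READER | {"graph:write", "documents:write", "rows:write"}
ADMIN = WRITER | {"config:write", "flows:write", "users:admin"}

ROLE_TABLE: dict[str, set[str]] = {
    "reader": READER, "writer": WRITER, "admin": ADMIN,
}

def is_allowed(roles: list[str], required: str) -> bool:
    # Unknown role names contribute zero capabilities.
    granted = set().union(*(ROLE_TABLE.get(r, set()) for r in roles))
    return required in granted
```

No hierarchy is consulted: a caller with roles `["reader", "writer"]`
passes a `rows:write` check purely because `writer`'s bundle contains
that capability.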
+
+### Unknown capabilities and unknown roles
+
+- An endpoint declaring an unknown capability is a server-side bug
+  and fails closed (403, logged).
+- A user carrying a role name that is not defined in the role table
+  is ignored for authorisation purposes and logged as a warning.
+  Behaviour is deterministic: unknown roles contribute zero
+  capabilities.
+
+### Capability scope
+
+Every capability is **implicitly scoped to the caller's resolved
+workspace**. A `users:write` capability does not permit a user
+in workspace `acme` to create users in workspace `beta` — the
+workspace-resolver has already narrowed the request to one
+workspace before the capability check runs. See the IAM
+specification for the workspace-resolver contract.
+
+The three exceptions are the system-level capabilities
+`workspaces:admin` and `iam:admin`, which operate across
+workspaces by definition, and `metrics:read`, which returns
+process-level series not scoped to any workspace.
+
+## Enterprise extensibility
+
+Enterprise editions extend the role table additively, composing
+bundles from the same closed vocabulary:
+
+```
+data-analyst:    {agent, graph:read, documents:read, rows:read,
+                  collections:read, knowledge:read}
+helpdesk:        {users:read, users:write, users:admin, keys:admin}
+data-engineer:   writer + {flows:read, config:read}
+workspace-owner: admin − {workspaces:admin, iam:admin}
+```
+
+None of this requires a protocol change — the wire-protocol `roles`
+field on user records is already a set, the gateway's capability
+check is already capability-based, and the capability vocabulary is
+closed. Enterprises may introduce roles whose bundles compose the
+same capabilities differently.
+
+When an enterprise introduces a new capability (e.g. for a feature
+that does not exist in open source), the capability string is
+added to the vocabulary and recognised by the gateway build that
+ships that feature.
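
Because bundles are sets over a closed vocabulary, the derived roles
in the example above are plain set algebra. A non-normative sketch
(bundle contents abbreviated; `WRITER` and `ADMIN` stand in for the
open-source bundles):

```python
# Non-normative: composing enterprise role bundles from existing
# ones with set union and difference. Bundle contents abbreviated.
WRITER = {"agent", "graph:read", "graph:write", "documents:read",
          "documents:write", "rows:read", "rows:write"}
ADMIN = WRITER | {"config:write", "flows:write", "users:admin",
                  "workspaces:admin", "iam:admin"}

ENTERPRISE_ROLES = {
    # data-engineer: writer + {flows:read, config:read}
    "data-engineer": WRITER | {"flows:read", "config:read"},
    # workspace-owner: admin minus the system-level capabilities
    "workspace-owner": ADMIN - {"workspaces:admin", "iam:admin"},
}
```

The role table stays pure data; nothing in the gateway or backend
changes when a deployment ships a different `ENTERPRISE_ROLES` map.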
+ +## References + +- [Identity and Access Management Specification](iam.md) +- [Architecture Principles](architecture-principles.md) diff --git a/docs/tech-specs/iam-protocol.md b/docs/tech-specs/iam-protocol.md new file mode 100644 index 00000000..8638e7e9 --- /dev/null +++ b/docs/tech-specs/iam-protocol.md @@ -0,0 +1,329 @@ +--- +layout: default +title: "IAM Service Protocol Technical Specification" +parent: "Tech Specs" +--- + +# IAM Service Protocol Technical Specification + +## Overview + +The IAM service is a backend processor, reached over the standard +request/response pub/sub pattern. It is the authority for users, +workspaces, API keys, and login credentials. The API gateway +delegates to it for authentication resolution and for all user / +workspace / key management. + +This document defines the wire protocol: the `IamRequest` and +`IamResponse` dataclasses, the operation set, the per-operation +input and output fields, the error taxonomy, and the initial HTTP +forwarding endpoint used while IAM is being integrated into the +gateway. + +Architectural context — roles, capabilities, workspace scoping, +enforcement boundary — lives in [`iam.md`](iam.md) and +[`capabilities.md`](capabilities.md). + +## Transport + +- **Request topic:** `request:tg/request/iam-request` +- **Response topic:** `response:tg/response/iam-response` +- **Pattern:** request/response, correlated by the `id` message + property, the same pattern used by `config-svc` and `flow-svc`. +- **Caller:** the API gateway only. Under the enforcement-boundary + policy (see capabilities spec), the IAM service trusts the bus + and performs no per-request authentication or capability check + against the caller. The gateway has already evaluated capability + membership and workspace scoping before sending the request. + +## Dataclasses + +### `IamRequest` + +```python +@dataclass +class IamRequest: + # One of the operation strings below. + operation: str = "" + + # Scope of this request. 
Required on every workspace-scoped + # operation. Omitted (or empty) for system-level ops + # (workspace CRUD, signing-key ops, bootstrap, resolve-api-key, + # login). + workspace: str = "" + + # Acting user id, for audit. Set by the gateway to the + # authenticated caller's id on user-initiated operations. + # Empty for internal-origin (bootstrap, reconcilers) and for + # resolve-api-key / login (no actor yet). + actor: str = "" + + # --- identity selectors --- + user_id: str = "" + username: str = "" # login; unique within a workspace + key_id: str = "" # revoke-api-key, list-api-keys (own) + api_key: str = "" # resolve-api-key (plaintext) + + # --- credentials --- + password: str = "" # login, change-password (current) + new_password: str = "" # change-password + + # --- user fields --- + user: UserInput | None = None # create-user, update-user + + # --- workspace fields --- + workspace_record: WorkspaceInput | None = None # create-workspace, update-workspace + + # --- api key fields --- + key: ApiKeyInput | None = None # create-api-key +``` + +### `IamResponse` + +```python +@dataclass +class IamResponse: + # Populated on success of operations that return them. + user: UserRecord | None = None # create-user, get-user, update-user + users: list[UserRecord] = field(default_factory=list) # list-users + workspace: WorkspaceRecord | None = None # create-workspace, get-workspace, update-workspace + workspaces: list[WorkspaceRecord] = field(default_factory=list) # list-workspaces + + # create-api-key returns the plaintext once. Never populated + # on any other operation. + api_key_plaintext: str = "" + api_key: ApiKeyRecord | None = None # create-api-key + api_keys: list[ApiKeyRecord] = field(default_factory=list) # list-api-keys + + # login, rotate-signing-key + jwt: str = "" + jwt_expires: str = "" # ISO-8601 UTC + + # get-signing-key-public + signing_key_public: str = "" # PEM + + # resolve-api-key returns who this key authenticates as. 
+ resolved_user_id: str = "" + resolved_workspace: str = "" + resolved_roles: list[str] = field(default_factory=list) + + # reset-password + temporary_password: str = "" # returned once to the operator + + # bootstrap: on first run, the initial admin's one-time API key + # is returned for the operator to capture. + bootstrap_admin_user_id: str = "" + bootstrap_admin_api_key: str = "" + + # Present on any failed operation. + error: Error | None = None +``` + +### Value types + +```python +@dataclass +class UserInput: + username: str = "" + name: str = "" + email: str = "" + password: str = "" # only on create-user; never on update-user + roles: list[str] = field(default_factory=list) + enabled: bool = True + must_change_password: bool = False + +@dataclass +class UserRecord: + id: str = "" + workspace: str = "" + username: str = "" + name: str = "" + email: str = "" + roles: list[str] = field(default_factory=list) + enabled: bool = True + must_change_password: bool = False + created: str = "" # ISO-8601 UTC + # Password hash is never included in any response. + +@dataclass +class WorkspaceInput: + id: str = "" + name: str = "" + enabled: bool = True + +@dataclass +class WorkspaceRecord: + id: str = "" + name: str = "" + enabled: bool = True + created: str = "" # ISO-8601 UTC + +@dataclass +class ApiKeyInput: + user_id: str = "" + name: str = "" # operator-facing label, e.g. "laptop" + expires: str = "" # optional ISO-8601 UTC; empty = no expiry + +@dataclass +class ApiKeyRecord: + id: str = "" + user_id: str = "" + name: str = "" + prefix: str = "" # first 4 chars of plaintext, for identification in lists + expires: str = "" # empty = no expiry + created: str = "" + last_used: str = "" # empty if never used + # key_hash is never included in any response. 
+``` + +## Operations + +| Operation | Request fields | Response fields | Notes | +|---|---|---|---| +| `login` | `username`, `password`, `workspace` (optional) | `jwt`, `jwt_expires` | If `workspace` omitted, IAM resolves to the user's assigned workspace. | +| `resolve-api-key` | `api_key` (plaintext) | `resolved_user_id`, `resolved_workspace`, `resolved_roles` | Gateway-internal. Service returns `auth-failed` for unknown / expired / revoked keys. | +| `change-password` | `user_id`, `password` (current), `new_password` | — | Self-service. IAM validates `password` against stored hash. | +| `reset-password` | `user_id` | `temporary_password` | Admin-initiated. IAM generates a random password, sets `must_change_password=true` on the user, returns the plaintext once. | +| `create-user` | `workspace`, `user` | `user` | Admin-only. `user.password` is hashed and stored; `user.roles` must be subset of known roles. | +| `list-users` | `workspace` | `users` | | +| `get-user` | `workspace`, `user_id` | `user` | | +| `update-user` | `workspace`, `user_id`, `user` | `user` | `password` field on `user` is rejected; use `change-password` / `reset-password`. | +| `disable-user` | `workspace`, `user_id` | — | Soft-delete; sets `enabled=false`. Revokes all the user's API keys. | +| `create-workspace` | `workspace_record` | `workspace` | System-level. | +| `list-workspaces` | — | `workspaces` | System-level. | +| `get-workspace` | `workspace_record` (id only) | `workspace` | System-level. | +| `update-workspace` | `workspace_record` | `workspace` | System-level. | +| `disable-workspace` | `workspace_record` (id only) | — | System-level. Sets `enabled=false`; revokes all workspace API keys; disables all users in the workspace. | +| `create-api-key` | `workspace`, `key` | `api_key_plaintext`, `api_key` | Plaintext returned **once**; only hash stored. `key.name` required. 
| +| `list-api-keys` | `workspace`, `user_id` | `api_keys` | | +| `revoke-api-key` | `workspace`, `key_id` | — | Deletes the key record. | +| `get-signing-key-public` | — | `signing_key_public` | Gateway fetches this at startup. | +| `rotate-signing-key` | — | — | System-level. Introduces a new signing key; old key continues to validate JWTs for a grace period (implementation-defined, minimum 1h). | +| `bootstrap` | — | `bootstrap_admin_user_id`, `bootstrap_admin_api_key` | If IAM tables are empty, creates the initial `default` workspace, an `admin` user, an initial API key, and an initial signing key; returns them once. No-op on subsequent calls (returns empty fields). | + +## Error taxonomy + +All errors are carried in the `IamResponse.error` field. `error.type` +is one of the values below; `error.message` is a human-readable +string that is **not** surfaced verbatim to external callers (the +gateway maps to `auth failure` / `access denied` per the IAM error +policy). + +| `type` | When | +|---|---| +| `invalid-argument` | Malformed request (missing required field, unknown operation, invalid format). | +| `not-found` | Named resource does not exist (`user_id`, `key_id`, workspace). | +| `duplicate` | Create operation collides with an existing resource (username, workspace id, key name). | +| `auth-failed` | `login` with wrong credentials; `resolve-api-key` with unknown / expired / revoked key; `change-password` with wrong current password. Single bucket to deny oracle attacks. | +| `weak-password` | Password does not meet policy (length, complexity — policy defined at service level). | +| `disabled` | Target user or workspace has `enabled=false`. | +| `operation-not-permitted` | Non-admin attempting system-level operation, or workspace-scoped operation attempting to affect another workspace. | +| `internal-error` | Unexpected IAM-side failure. Log and surface as 500 at the gateway. 
| + +The gateway is responsible for translating `auth-failed` and +`operation-not-permitted` into the obfuscated external error +response (`"auth failure"` / `"access denied"`); `invalid-argument` +becomes a descriptive 400; `not-found` / `duplicate` / +`weak-password` / `disabled` become descriptive 4xx but never leak +IAM-internal detail. + +## Credential storage + +- **Passwords** are stored using a slow KDF (bcrypt / argon2id — the + service picks; documented as an implementation detail). The + `password_hash` column stores the full KDF-encoded string + (algorithm, cost, salt, hash). Not a plain SHA-256. +- **API keys** are stored as SHA-256 of the plaintext. API keys + are 128-bit random values (`tg_` + base64url); the entropy + makes a slow hash unnecessary. The hash serves as the primary + key on the `iam_api_keys` table, enabling O(1) lookup on + `resolve-api-key`. +- **JWT signing key** is stored as an RSA or Ed25519 private key + (implementation choice) in a dedicated `iam_signing_keys` table + with a `kid`, `created`, and optional `retired` timestamp. At + most one active key; up to N retired keys are kept for a grace + period to validate previously-issued JWTs. + +Passwords, API-key plaintext, and signing-key private material are +never returned in any response other than the explicit one-time +responses above (`reset-password`, `create-api-key`, `bootstrap`). + +## Bootstrap modes + +`iam-svc` requires a bootstrap mode to be chosen at startup. There is +no default — an unset or invalid mode causes the service to refuse +to start. The purpose is to force the operator to make an explicit +security decision rather than rely on an implicit "safe" fallback. + +| Mode | Startup behaviour | `bootstrap` operation | Suitability | +|---|---|---|---| +| `token` | On first start with empty tables, auto-seeds the `default` workspace, admin user, admin API key (using the operator-provided `--bootstrap-token`), and an initial signing key. No-op on subsequent starts. 
| Refused — returns `auth-failed` / `"auth failure"` regardless of caller. | Production, any public-exposure deployment. | +| `bootstrap` | No startup seeding. Tables remain empty until the `bootstrap` operation is invoked over the pub/sub bus (typically via `tg-bootstrap-iam`). | Live while tables are empty. Generates and returns the admin API key once. Refused (`auth-failed`) once tables are populated. | Dev / compose up / CI. **Not safe under public exposure** — any caller reaching the gateway's `/api/v1/iam` forwarder before the operator can cause a token to be issued to them. Operators choosing this mode accept that risk. | + +### Error masking + +In both modes, any refused invocation of the `bootstrap` operation +returns the same error (`auth-failed` / `"auth failure"`). A caller +cannot distinguish: + +- "service is in token mode" +- "service is in bootstrap mode but already bootstrapped" +- "operation forbidden" + +This matches the general IAM error-policy stance (see `iam.md`) and +prevents externally enumerating IAM's state. + +### Bootstrap-token lifecycle + +The bootstrap token — whether operator-supplied (`token` mode) or +service-generated (`bootstrap` mode) — is a one-time credential. It +is stored as admin's single API key, tagged `name="bootstrap"`. The +operator's first admin action after bootstrap should be: + +1. Create a durable admin user and API key (or issue a durable API + key to the bootstrap admin). +2. Revoke the bootstrap key via `revoke-api-key`. +3. Remove the bootstrap token from any deployment configuration. + +The `name="bootstrap"` marker makes bootstrap keys easy to detect in +tooling (e.g. a `tg-list-api-keys` filter). + +## HTTP forwarding (initial integration) + +For the initial gateway integration — before the IAM service is +wired into the authentication middleware — the gateway exposes a +single forwarding endpoint: + +``` +POST /api/v1/iam +``` + +- Request body is a JSON encoding of `IamRequest`. 
+- Response body is a JSON encoding of `IamResponse`. +- The gateway's existing authentication (`GATEWAY_SECRET` bearer) + gates access to this endpoint so the IAM protocol can be + exercised end-to-end in tests without touching the live auth + path. +- This endpoint is **not** the final shape. Once the middleware is + in place, per-operation REST endpoints replace it (for example + `POST /api/v1/auth/login`, `POST /api/v1/users`, `DELETE + /api/v1/api-keys/{id}`), and this generic forwarder is removed. + +The endpoint performs only message marshalling: it does not read +or rewrite fields in the request, and it applies no capability +check. All authorisation for user / workspace / key management +lands in the subsequent middleware work. + +## Non-goals for this spec + +- REST endpoint shape for the final gateway surface — covered in + Phase 2 of the IAM implementation plan, not here. +- OIDC / SAML external IdP protocol — out of scope for open source. +- Key-signing algorithm choice, password KDF choice, JWT claim + layout — implementation details captured in code + ADRs, not + locked in the protocol spec. + +## References + +- [Identity and Access Management Specification](iam.md) +- [Capability Vocabulary Specification](capabilities.md) diff --git a/docs/tech-specs/iam.md b/docs/tech-specs/iam.md index cb1399fe..50b64444 100644 --- a/docs/tech-specs/iam.md +++ b/docs/tech-specs/iam.md @@ -423,6 +423,37 @@ resolve API keys and to handle login requests. User management operations (create user, revoke key, etc.) also go through the IAM service. +### Error policy + +External error responses carry **no diagnostic detail** for +authentication or access-control failures. The goal is to give an +attacker probing the endpoint no signal about which condition they +tripped. 
+ +| Category | HTTP | Body | WebSocket frame | +|----------|------|------|-----------------| +| Authentication failure | `401 Unauthorized` | `{"error": "auth failure"}` | `{"type": "auth-failed", "error": "auth failure"}` | +| Access control failure | `403 Forbidden` | `{"error": "access denied"}` | `{"error": "access denied"}` (endpoint-specific frame type) | + +"Authentication failure" covers missing credential, malformed +credential, invalid signature, expired token, revoked API key, and +unknown API key — all indistinguishable to the caller. + +"Access control failure" covers role insufficient, workspace +mismatch, user disabled, and workspace disabled — all +indistinguishable to the caller. + +**Server-side logging is richer.** The audit log records the specific +reason (`"workspace-mismatch: user alice assigned 'acme', requested +'beta'"`, `"role-insufficient: admin required, user has writer"`, +etc.) for operators and post-incident forensics. These messages never +appear in responses. + +Other error classes (bad request, internal error) remain descriptive +because they do not reveal anything about the auth or access-control +surface — e.g. `"missing required field 'workspace'"` or +`"invalid JSON"` is fine. + ### Gateway changes The current `Authenticator` class is replaced with a thin authentication @@ -713,6 +744,16 @@ These are not implemented but the architecture does not preclude them: - **Multi-workspace access.** Users could be granted access to additional workspaces beyond their primary assignment. The workspace validation step checks a grant list instead of a single assignment. +- **Workspace resolver.** Workspace resolution on each authenticated + request — "given this user and this requested workspace, which + workspace (if any) may the request operate on?" — is encapsulated + in a single pluggable resolver. 
The open-source edition ships a + resolver that permits only the user's single assigned workspace; + enterprise editions that implement multi-workspace access swap in a + resolver that consults a permitted set. The wire protocol (the + optional `workspace` field on the authenticated request) is + identical in both editions, so clients written against one edition + work unchanged against the other. - **Rules-based access control.** A separate access control service could evaluate fine-grained policies (per-collection permissions, operation-level restrictions, time-based access). The gateway diff --git a/iam-testing.txt b/iam-testing.txt new file mode 100644 index 00000000..0d03ffc3 --- /dev/null +++ b/iam-testing.txt @@ -0,0 +1,252 @@ + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation": "bootstrap"}' + + + + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation": "resolve-api-key", "api_key": "tg_r-n43hDWV9WOY06w6o5YpevAxirlS33D"}' + + + + + + + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation": "resolve-api-key", "api_key": "asdalsdjasdkasdasda"}' + + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"list-users","workspace":"default"}' + + + + # 1. Admin creates a writer user "alice" + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{ + "operation": "create-user", + "workspace": "default", + "user": { + "username": "alice", + "name": "Alice", + "email": "alice@example.com", + "password": "changeme", + "roles": ["writer"] + } + }' + # expect: {"user": {"id": "", ...}} — grab alice's uuid + + # 2. 
Issue alice an API key + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{ + "operation": "create-api-key", + "workspace": "default", + "key": { + "user_id": "f2363a10-3b83-44ea-a008-43caae8ba607", + "name": "alice-laptop" + } + }' + # expect: {"api_key_plaintext": "tg_...", "api_key": {"id": "", "prefix": "tg_xxxx", ...}} + + # 3. Resolve alice's key — should return alice's id + workspace + writer role + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"resolve-api-key","api_key":"tg_gt4buvk5NG-QS7oP_0Gk5yTWyj1qensf"}' + + # expect: {"resolved_user_id":"","resolved_workspace":"default","resolved_roles":["writer"]} + + # 4. List alice's keys (admin view of alice's keys) + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"list-api-keys","workspace":"default","user_id":"f2363a10-3b83-44ea-a008-43caae8ba607"}' + # expect: {"api_keys": [{"id":"","user_id":"","name":"alice-laptop","prefix":"tg_xxxx",...}]} + + # 5. Revoke alice's key + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"revoke-api-key","workspace":"default","key_id":"55f1c1f7-5448-49fd-9eda-56c192b61177"}' + + + # expect: {} (empty, no error) + + # 6. Confirm the revoked key no longer resolves + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"resolve-api-key","api_key":"tg_gt4buvk5NG-QS7oP_0Gk5yTWyj1qensf"}' + # expect: {"error":{"type":"auth-failed","message":"unknown api key"}} + + + +---------------------------------------------------------------------------- + + You'll want to re-bootstrap a fresh deployment to pick up the new signing-key row (or accept that login will lazily generate one on first + call). Then: + + # 1. 
Create a user with a known password (admin's password is random) + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"create-user","workspace":"default","user":{"username":"alice","password":"s3cret","roles":["writer"]}}' + + + + # 2. Log alice in + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"login","username":"alice","password":"s3cret"}' + # expect: {"jwt":"eyJ...","jwt_expires":"2026-..."} + + # 3. Fetch the public key (what the gateway will use later to verify) + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"get-signing-key-public"}' + + # expect: {"signing_key_public":"-----BEGIN PUBLIC KEY-----\n..."} + + # 4. Wrong password + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Authorization: Bearer $GATEWAY_SECRET" \ + -H "Content-Type: application/json" \ + -d '{"operation":"login","username":"alice","password":"nope"}' + + + + # expect: {"error":{"type":"auth-failed","message":"bad credentials"}} + + + + + +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAseLB/a9Bo/RN/Rb/x763 ++vdxmUKG75oWsXBmbwZGDXyN6fwqZ3L7cEje93qK0PYFuCHxhY1Hn0gW7FZ8ovH+ +qEksekUlpfPYqKGiT5Mb0DKk49D4yKkIbJFugWalpwIilvRbQO0jy3V8knqGQ1xL +NfNYFrI2Rxe0Tq2OHVYc5YwYbyj1nz2TY5fd9qrzXtGRv5HZztkl25lWhRvG9G0K +urKDdBDbi894gIYorXvcwZw/b1GDXG/aUy/By1Oy3hXnCLsN8pA3nA437TTTWxHx +QgPH15jIF9hezO+3/ESZ7EhVEtgmwTxPddfXRa0ZoT6JyWOgcloKtnP4Lp9eQ4va +yQIDAQAB +-----END PUBLIC KEY----- + + + + + + New operations: + - change-password — self-service. Requires current + new password. + - reset-password — admin-driven. Generates a random temporary, sets must_change_password=true, returns plaintext once. + - get-user, update-user, disable-user — workspace-scoped. update-user refuses to change username (immutable — error if different) and refuses + password-via-update. 
disable-user also revokes all the user's API keys, per spec. + - create-workspace, list-workspaces, get-workspace, update-workspace, disable-workspace — system-level. disable-workspace cascades: disables + all users + revokes all their keys. Rejects ids starting with _ (reserved, per the bootstrap framework convention). + - rotate-signing-key — generates a new Ed25519 key, retires the current one (sets retired timestamp; row stays for future grace-period + validation), switches the in-memory cache. + + Touched files: + - trustgraph-flow/trustgraph/tables/iam.py — added retire_signing_key, update_user_profile, update_user_password, update_user_enabled, + update_workspace. + - trustgraph-flow/trustgraph/iam/service/iam.py — 12 new handlers + dispatch entries. + - trustgraph-base/trustgraph/base/iam_client.py — matching client helpers for all of them. + + Smoke-test suggestions: + + # change password for alice (from "s3cret" → "n3wer") + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"change-password","user_id":"b2960feb-caef-401d-af65-01bdb6960cad","password":"s3cret","new_password":"n3wer"}' + + # login with new password + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"login","username":"alice","password":"n3wer"}' + + # admin resets alice's password + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"reset-password","workspace":"default","user_id":"b2960feb-caef-401d-af65-01bdb6960cad"}' + + + # → {"temporary_password":"..."} + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"login","username":"alice","password":"fH2ttyrIcVXCIkH_"}' + + + # create a second workspace + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d 
'{"operation":"create-workspace","workspace_record":{"id":"acme","name":"Acme Corp","enabled":true}}' + + + # rotate signing key (next login produces a JWT signed by a new kid) + + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -d '{"operation":"rotate-signing-key"}' + + + + + + + curl -s -X POST "http://localhost:8088/api/v1/flow" \ + -H "Authorization: Bearer tg_bs_kBAhfejiEJmbcO1gElbxk3MpV7wQFygP" \ + -H "Content-Type: application/json" \ + -d '{"operation":"list-flows"}' + + curl -s -X POST "http://localhost:8088/api/v1/iam" \ + -H "Authorization: Bearer tg_bs_kBAhfejiEJmbcO1gElbxk3MpV7wQFygP" \ + -H "Content-Type: application/json" \ + -d '{"operation":"list-users"}' + + + + curl -s -X POST http://localhost:8088/api/v1/iam \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer tg_bs_kBAhfejiEJmbcO1gElbxk3MpV7wQFygP" \ + -d '{ + "operation": "create-user", + "workspace": "default", + "user": { + "username": "alice", + "name": "Alice", + "email": "alice@example.com", + "password": "s3cret", + "roles": ["writer"] + } + }' + + + + + # Login (public, no token needed) → returns a JWT + curl -s -X POST "http://localhost:8088/api/v1/auth/login" \ + -H "Content-Type: application/json" \ + -d '{"username":"alice","password":"s3cret"}' + + + + export TRUSTGRAPH_TOKEN=$(tg-bootstrap-iam) # on fresh bootstrap-mode deployment + # or set to your existing admin API key + + tg-create-user --username alice --roles writer + # → prints alice's user id + + ALICE_ID= + + ALICE_KEY=$(tg-create-api-key --user-id $ALICE_ID --name alice-laptop) + # → alice's plaintext API key + + tg-list-users + tg-list-api-keys --user-id $ALICE_ID + + tg-revoke-api-key --key-id <...> + tg-disable-user --user-id $ALICE_ID + + # User self-service: + tg-login --username alice # prompts for password, prints JWT + tg-change-password # prompts for current + new + + diff --git a/tests/unit/test_gateway/test_auth.py 
b/tests/unit/test_gateway/test_auth.py index d4d4fc2b..ba2b9bc2 100644 --- a/tests/unit/test_gateway/test_auth.py +++ b/tests/unit/test_gateway/test_auth.py @@ -1,69 +1,312 @@ """ -Tests for Gateway Authentication +Tests for gateway/auth.py — IamAuth, JWT verification, API key +resolution cache. + +JWTs are signed with real Ed25519 keypairs generated per-test, so +the crypto path is exercised end-to-end without mocks. API-key +resolution is tested against a stubbed IamClient since the real +one requires pub/sub. """ +import base64 +import json +import time +from unittest.mock import AsyncMock, Mock, patch + import pytest +from aiohttp import web +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 -from trustgraph.gateway.auth import Authenticator +from trustgraph.gateway.auth import ( + IamAuth, Identity, + _b64url_decode, _verify_jwt_eddsa, + API_KEY_CACHE_TTL, +) -class TestAuthenticator: - """Test cases for Authenticator class""" +# -- helpers --------------------------------------------------------------- - def test_authenticator_initialization_with_token(self): - """Test Authenticator initialization with valid token""" - auth = Authenticator(token="test-token-123") - - assert auth.token == "test-token-123" - assert auth.allow_all is False - def test_authenticator_initialization_with_allow_all(self): - """Test Authenticator initialization with allow_all=True""" - auth = Authenticator(allow_all=True) - - assert auth.token is None - assert auth.allow_all is True +def _b64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") - def test_authenticator_initialization_without_token_raises_error(self): - """Test Authenticator initialization without token raises RuntimeError""" - with pytest.raises(RuntimeError, match="Need a token"): - Authenticator() - def test_authenticator_initialization_with_empty_token_raises_error(self): - """Test Authenticator 
initialization with empty token raises RuntimeError""" - with pytest.raises(RuntimeError, match="Need a token"): - Authenticator(token="") +def make_keypair(): + priv = ed25519.Ed25519PrivateKey.generate() + public_pem = priv.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode("ascii") + return priv, public_pem - def test_permitted_with_allow_all_returns_true(self): - """Test permitted method returns True when allow_all is enabled""" - auth = Authenticator(allow_all=True) - - # Should return True regardless of token or roles - assert auth.permitted("any-token", []) is True - assert auth.permitted("different-token", ["admin"]) is True - assert auth.permitted(None, ["user"]) is True - def test_permitted_with_matching_token_returns_true(self): - """Test permitted method returns True with matching token""" - auth = Authenticator(token="secret-token") - - # Should return True when tokens match - assert auth.permitted("secret-token", []) is True - assert auth.permitted("secret-token", ["admin", "user"]) is True +def sign_jwt(priv, claims, alg="EdDSA"): + header = {"alg": alg, "typ": "JWT", "kid": "kid-test"} + h = _b64url(json.dumps(header, separators=(",", ":"), sort_keys=True).encode()) + p = _b64url(json.dumps(claims, separators=(",", ":"), sort_keys=True).encode()) + signing_input = f"{h}.{p}".encode("ascii") + if alg == "EdDSA": + sig = priv.sign(signing_input) + else: + raise ValueError(f"test helper doesn't sign {alg}") + return f"{h}.{p}.{_b64url(sig)}" - def test_permitted_with_non_matching_token_returns_false(self): - """Test permitted method returns False with non-matching token""" - auth = Authenticator(token="secret-token") - - # Should return False when tokens don't match - assert auth.permitted("wrong-token", []) is False - assert auth.permitted("different-token", ["admin"]) is False - assert auth.permitted(None, ["user"]) is False - def 
test_permitted_with_token_and_allow_all_returns_true(self): - """Test permitted method with both token and allow_all set""" - auth = Authenticator(token="test-token", allow_all=True) - - # allow_all should take precedence - assert auth.permitted("any-token", []) is True - assert auth.permitted("wrong-token", ["admin"]) is True \ No newline at end of file +def make_request(auth_header): + """Minimal stand-in for an aiohttp request — IamAuth only reads + ``request.headers["Authorization"]``.""" + req = Mock() + req.headers = {} + if auth_header is not None: + req.headers["Authorization"] = auth_header + return req + + +# -- pure helpers ---------------------------------------------------------- + + +class TestB64UrlDecode: + + def test_round_trip_without_padding(self): + data = b"hello" + encoded = _b64url(data) + assert _b64url_decode(encoded) == data + + def test_handles_various_lengths(self): + for s in (b"a", b"ab", b"abc", b"abcd", b"abcde"): + assert _b64url_decode(_b64url(s)) == s + + +# -- JWT verification ----------------------------------------------------- + + +class TestVerifyJwtEddsa: + + def test_valid_jwt_passes(self): + priv, pub = make_keypair() + claims = { + "sub": "user-1", "workspace": "default", + "roles": ["reader"], + "iat": int(time.time()), + "exp": int(time.time()) + 60, + } + token = sign_jwt(priv, claims) + got = _verify_jwt_eddsa(token, pub) + assert got["sub"] == "user-1" + assert got["workspace"] == "default" + + def test_expired_jwt_rejected(self): + priv, pub = make_keypair() + claims = { + "sub": "user-1", "workspace": "default", "roles": [], + "iat": int(time.time()) - 3600, + "exp": int(time.time()) - 1, + } + token = sign_jwt(priv, claims) + with pytest.raises(ValueError, match="expired"): + _verify_jwt_eddsa(token, pub) + + def test_bad_signature_rejected(self): + priv_a, _ = make_keypair() + _, pub_b = make_keypair() + claims = { + "sub": "user-1", "workspace": "default", "roles": [], + "iat": int(time.time()), + "exp": 
int(time.time()) + 60, + } + token = sign_jwt(priv_a, claims) + # pub_b never signed this token. + with pytest.raises(Exception): + _verify_jwt_eddsa(token, pub_b) + + def test_malformed_jwt_rejected(self): + _, pub = make_keypair() + with pytest.raises(ValueError, match="malformed"): + _verify_jwt_eddsa("not-a-jwt", pub) + + def test_unsupported_algorithm_rejected(self): + priv, pub = make_keypair() + # Manually build an "alg":"HS256" header — no signer needed + # since we expect it to bail before verifying. + header = {"alg": "HS256", "typ": "JWT", "kid": "x"} + payload = { + "sub": "user-1", "workspace": "default", "roles": [], + "iat": int(time.time()), "exp": int(time.time()) + 60, + } + h = _b64url(json.dumps(header, separators=(",", ":")).encode()) + p = _b64url(json.dumps(payload, separators=(",", ":")).encode()) + sig = _b64url(b"not-a-real-sig") + token = f"{h}.{p}.{sig}" + with pytest.raises(ValueError, match="unsupported alg"): + _verify_jwt_eddsa(token, pub) + + +# -- Identity -------------------------------------------------------------- + + +class TestIdentity: + + def test_fields(self): + i = Identity( + user_id="u", workspace="w", roles=["reader"], source="api-key", + ) + assert i.user_id == "u" + assert i.workspace == "w" + assert i.roles == ["reader"] + assert i.source == "api-key" + + +# -- IamAuth.authenticate -------------------------------------------------- + + +class TestIamAuthDispatch: + """``authenticate()`` chooses between the JWT and API-key paths + by shape of the bearer.""" + + @pytest.mark.asyncio + async def test_no_authorization_header_raises_401(self): + auth = IamAuth(backend=Mock()) + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request(None)) + + @pytest.mark.asyncio + async def test_non_bearer_header_raises_401(self): + auth = IamAuth(backend=Mock()) + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Basic whatever")) + + @pytest.mark.asyncio + async def 
test_empty_bearer_raises_401(self): + auth = IamAuth(backend=Mock()) + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer ")) + + @pytest.mark.asyncio + async def test_unknown_format_raises_401(self): + # Not tg_... and not dotted-JWT shape. + auth = IamAuth(backend=Mock()) + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request("Bearer garbage")) + + @pytest.mark.asyncio + async def test_valid_jwt_resolves_to_identity(self): + priv, pub = make_keypair() + claims = { + "sub": "user-1", "workspace": "default", + "roles": ["writer"], + "iat": int(time.time()), + "exp": int(time.time()) + 60, + } + token = sign_jwt(priv, claims) + + auth = IamAuth(backend=Mock()) + auth._signing_public_pem = pub + + ident = await auth.authenticate( + make_request(f"Bearer {token}") + ) + assert ident.user_id == "user-1" + assert ident.workspace == "default" + assert ident.roles == ["writer"] + assert ident.source == "jwt" + + @pytest.mark.asyncio + async def test_jwt_without_public_key_fails(self): + # If the gateway hasn't fetched IAM's public key yet, JWTs + # must not validate — even ones that would otherwise pass. 
+ priv, _ = make_keypair() + claims = { + "sub": "user-1", "workspace": "default", "roles": [], + "iat": int(time.time()), "exp": int(time.time()) + 60, + } + token = sign_jwt(priv, claims) + auth = IamAuth(backend=Mock()) + # _signing_public_pem defaults to None + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate(make_request(f"Bearer {token}")) + + @pytest.mark.asyncio + async def test_api_key_path(self): + auth = IamAuth(backend=Mock()) + + async def fake_resolve(api_key): + assert api_key == "tg_testkey" + return ("user-xyz", "default", ["admin"]) + + async def fake_with_client(op): + return await op(Mock(resolve_api_key=fake_resolve)) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + ident = await auth.authenticate( + make_request("Bearer tg_testkey") + ) + assert ident.user_id == "user-xyz" + assert ident.workspace == "default" + assert ident.roles == ["admin"] + assert ident.source == "api-key" + + @pytest.mark.asyncio + async def test_api_key_rejection_masked_as_401(self): + auth = IamAuth(backend=Mock()) + + async def fake_with_client(op): + raise RuntimeError("auth-failed: unknown api key") + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + with pytest.raises(web.HTTPUnauthorized): + await auth.authenticate( + make_request("Bearer tg_bogus") + ) + + +# -- API key cache --------------------------------------------------------- + + +class TestApiKeyCache: + + @pytest.mark.asyncio + async def test_cache_hit_skips_iam(self): + auth = IamAuth(backend=Mock()) + calls = {"n": 0} + + async def fake_with_client(op): + calls["n"] += 1 + return await op(Mock( + resolve_api_key=AsyncMock( + return_value=("u", "default", ["reader"]), + ) + )) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + await auth.authenticate(make_request("Bearer tg_k1")) + await auth.authenticate(make_request("Bearer tg_k1")) + await auth.authenticate(make_request("Bearer tg_k1")) + + # Only the 
first lookup reaches IAM; the rest are cache hits. + assert calls["n"] == 1 + + @pytest.mark.asyncio + async def test_different_keys_are_separately_cached(self): + auth = IamAuth(backend=Mock()) + seen = [] + + async def fake_with_client(op): + async def resolve(plaintext): + seen.append(plaintext) + return ("u-" + plaintext, "default", ["reader"]) + return await op(Mock(resolve_api_key=resolve)) + + with patch.object(auth, "_with_client", side_effect=fake_with_client): + a = await auth.authenticate(make_request("Bearer tg_a")) + b = await auth.authenticate(make_request("Bearer tg_b")) + + assert a.user_id == "u-tg_a" + assert b.user_id == "u-tg_b" + assert seen == ["tg_a", "tg_b"] + + @pytest.mark.asyncio + async def test_cache_has_ttl_constant_set(self): + # Not a behaviour test — just ensures we don't accidentally + # set TTL to 0 (which would defeat the cache) or to a week. + assert 10 <= API_KEY_CACHE_TTL <= 3600 diff --git a/tests/unit/test_gateway/test_capabilities.py b/tests/unit/test_gateway/test_capabilities.py new file mode 100644 index 00000000..063e9ea4 --- /dev/null +++ b/tests/unit/test_gateway/test_capabilities.py @@ -0,0 +1,203 @@ +""" +Tests for gateway/capabilities.py — the capability + role + workspace +model that underpins all gateway authorisation. 
+""" + +import pytest +from aiohttp import web + +from trustgraph.gateway.capabilities import ( + PUBLIC, AUTHENTICATED, + KNOWN_CAPABILITIES, ROLE_DEFINITIONS, + check, enforce_workspace, access_denied, auth_failure, +) + + +# -- test fixtures --------------------------------------------------------- + + +class _Identity: + """Minimal stand-in for auth.Identity — the capability module + accesses ``.workspace`` and ``.roles``.""" + def __init__(self, workspace, roles): + self.user_id = "user-1" + self.workspace = workspace + self.roles = list(roles) + + +def reader_in(ws): + return _Identity(ws, ["reader"]) + + +def writer_in(ws): + return _Identity(ws, ["writer"]) + + +def admin_in(ws): + return _Identity(ws, ["admin"]) + + +# -- role table sanity ----------------------------------------------------- + + +class TestRoleTable: + + def test_oss_roles_present(self): + assert set(ROLE_DEFINITIONS.keys()) == {"reader", "writer", "admin"} + + def test_admin_is_cross_workspace(self): + assert ROLE_DEFINITIONS["admin"]["workspace_scope"] == "*" + + def test_reader_writer_are_assigned_scope(self): + assert ROLE_DEFINITIONS["reader"]["workspace_scope"] == "assigned" + assert ROLE_DEFINITIONS["writer"]["workspace_scope"] == "assigned" + + def test_admin_superset_of_writer(self): + admin = ROLE_DEFINITIONS["admin"]["capabilities"] + writer = ROLE_DEFINITIONS["writer"]["capabilities"] + assert writer.issubset(admin) + + def test_writer_superset_of_reader(self): + writer = ROLE_DEFINITIONS["writer"]["capabilities"] + reader = ROLE_DEFINITIONS["reader"]["capabilities"] + assert reader.issubset(writer) + + def test_admin_has_users_admin(self): + assert "users:admin" in ROLE_DEFINITIONS["admin"]["capabilities"] + + def test_writer_does_not_have_users_admin(self): + assert "users:admin" not in ROLE_DEFINITIONS["writer"]["capabilities"] + + def test_every_bundled_capability_is_known(self): + for role in ROLE_DEFINITIONS.values(): + for cap in role["capabilities"]: + assert cap in 
KNOWN_CAPABILITIES + + +# -- check() --------------------------------------------------------------- + + +class TestCheck: + + def test_reader_has_reader_cap_in_own_workspace(self): + assert check(reader_in("default"), "graph:read", "default") + + def test_reader_does_not_have_writer_cap(self): + assert not check(reader_in("default"), "graph:write", "default") + + def test_reader_cannot_act_in_other_workspace(self): + assert not check(reader_in("default"), "graph:read", "acme") + + def test_writer_has_write_in_own_workspace(self): + assert check(writer_in("default"), "graph:write", "default") + + def test_writer_cannot_act_in_other_workspace(self): + assert not check(writer_in("default"), "graph:write", "acme") + + def test_admin_has_everything_everywhere(self): + for cap in ("graph:read", "graph:write", "config:write", + "users:admin", "metrics:read"): + assert check(admin_in("default"), cap, "acme"), ( + f"admin should have {cap} in acme" + ) + + def test_admin_has_caps_without_explicit_workspace(self): + assert check(admin_in("default"), "users:admin") + + def test_default_target_is_identity_workspace(self): + # Reader with no target workspace → should check against own + assert check(reader_in("default"), "graph:read") + + def test_unknown_capability_returns_false(self): + assert not check(admin_in("default"), "nonsense:cap", "default") + + def test_unknown_role_contributes_nothing(self): + ident = _Identity("default", ["made-up-role"]) + assert not check(ident, "graph:read", "default") + + def test_multi_role_union(self): + # If a user is both reader and admin, they inherit admin's + # cross-workspace powers. 
+ ident = _Identity("default", ["reader", "admin"]) + assert check(ident, "users:admin", "acme") + + +# -- enforce_workspace() --------------------------------------------------- + + +class TestEnforceWorkspace: + + def test_reader_in_own_workspace_allowed(self): + data = {"workspace": "default", "operation": "x"} + enforce_workspace(data, reader_in("default")) + assert data["workspace"] == "default" + + def test_reader_no_workspace_injects_assigned(self): + data = {"operation": "x"} + enforce_workspace(data, reader_in("default")) + assert data["workspace"] == "default" + + def test_reader_mismatched_workspace_denied(self): + data = {"workspace": "acme", "operation": "x"} + with pytest.raises(web.HTTPForbidden): + enforce_workspace(data, reader_in("default")) + + def test_admin_can_target_any_workspace(self): + data = {"workspace": "acme", "operation": "x"} + enforce_workspace(data, admin_in("default")) + assert data["workspace"] == "acme" + + def test_admin_no_workspace_defaults_to_assigned(self): + data = {"operation": "x"} + enforce_workspace(data, admin_in("default")) + assert data["workspace"] == "default" + + def test_writer_same_workspace_specified_allowed(self): + data = {"workspace": "default"} + enforce_workspace(data, writer_in("default")) + assert data["workspace"] == "default" + + def test_non_dict_passthrough(self): + # Non-dict bodies are returned unchanged (e.g. streaming). + result = enforce_workspace("not-a-dict", reader_in("default")) + assert result == "not-a-dict" + + def test_with_capability_tightens_check(self): + # Reader lacks graph:write; workspace-only check would pass + # (scope is fine), but combined check must reject. 
+ data = {"workspace": "default"} + with pytest.raises(web.HTTPForbidden): + enforce_workspace( + data, reader_in("default"), capability="graph:write", + ) + + def test_with_capability_passes_when_granted(self): + data = {"workspace": "default"} + enforce_workspace( + data, reader_in("default"), capability="graph:read", + ) + assert data["workspace"] == "default" + + +# -- helpers --------------------------------------------------------------- + + +class TestResponseHelpers: + + def test_auth_failure_is_401(self): + exc = auth_failure() + assert exc.status == 401 + assert "auth failure" in exc.text + + def test_access_denied_is_403(self): + exc = access_denied() + assert exc.status == 403 + assert "access denied" in exc.text + + +class TestSentinels: + + def test_public_and_authenticated_are_distinct(self): + assert PUBLIC != AUTHENTICATED + assert PUBLIC not in KNOWN_CAPABILITIES + assert AUTHENTICATED not in KNOWN_CAPABILITIES diff --git a/tests/unit/test_gateway/test_dispatch_manager.py b/tests/unit/test_gateway/test_dispatch_manager.py index f091a46d..e399d712 100644 --- a/tests/unit/test_gateway/test_dispatch_manager.py +++ b/tests/unit/test_gateway/test_dispatch_manager.py @@ -42,7 +42,7 @@ class TestDispatcherManager: mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) assert manager.backend == mock_backend assert manager.config_receiver == mock_config_receiver @@ -59,7 +59,10 @@ class TestDispatcherManager: mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver, prefix="custom-prefix") + manager = DispatcherManager( + mock_backend, mock_config_receiver, + auth=Mock(), prefix="custom-prefix", + ) assert manager.prefix == "custom-prefix" @@ -68,7 +71,7 @@ class TestDispatcherManager: """Test start_flow method""" mock_backend = Mock() mock_config_receiver = 
Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) flow_data = {"name": "test_flow", "steps": []} @@ -82,7 +85,7 @@ class TestDispatcherManager: """Test stop_flow method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Pre-populate with a flow flow_data = {"name": "test_flow", "steps": []} @@ -96,7 +99,7 @@ class TestDispatcherManager: """Test dispatch_global_service returns DispatcherWrapper""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) wrapper = manager.dispatch_global_service() @@ -107,7 +110,7 @@ class TestDispatcherManager: """Test dispatch_core_export returns DispatcherWrapper""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) wrapper = manager.dispatch_core_export() @@ -118,7 +121,7 @@ class TestDispatcherManager: """Test dispatch_core_import returns DispatcherWrapper""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) wrapper = manager.dispatch_core_import() @@ -130,7 +133,7 @@ class TestDispatcherManager: """Test process_core_import method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) with patch('trustgraph.gateway.dispatch.manager.CoreImport') as mock_core_import: mock_importer = Mock() @@ -148,7 +151,7 @@ class 
TestDispatcherManager: """Test process_core_export method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) with patch('trustgraph.gateway.dispatch.manager.CoreExport') as mock_core_export: mock_exporter = Mock() @@ -166,7 +169,7 @@ class TestDispatcherManager: """Test process_global_service method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) manager.invoke_global_service = AsyncMock(return_value="global_result") @@ -181,7 +184,7 @@ class TestDispatcherManager: """Test invoke_global_service with existing dispatcher""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Pre-populate with existing dispatcher mock_dispatcher = Mock() @@ -198,7 +201,7 @@ class TestDispatcherManager: """Test invoke_global_service creates new dispatcher""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers: mock_dispatcher_class = Mock() @@ -230,7 +233,7 @@ class TestDispatcherManager: """Test dispatch_flow_import returns correct method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) result = manager.dispatch_flow_import() @@ -240,7 +243,7 @@ class TestDispatcherManager: """Test dispatch_flow_export returns correct method""" mock_backend = Mock() 
mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) result = manager.dispatch_flow_export() @@ -250,7 +253,7 @@ class TestDispatcherManager: """Test dispatch_socket returns correct method""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) result = manager.dispatch_socket() @@ -260,7 +263,7 @@ class TestDispatcherManager: """Test dispatch_flow_service returns DispatcherWrapper""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) wrapper = manager.dispatch_flow_service() @@ -272,7 +275,7 @@ class TestDispatcherManager: """Test process_flow_import with valid flow and kind""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow manager.flows[("default", "test_flow")] = { @@ -308,7 +311,7 @@ class TestDispatcherManager: """Test process_flow_import with invalid flow""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) params = {"flow": "invalid_flow", "kind": "triples"} @@ -323,7 +326,7 @@ class TestDispatcherManager: warnings.simplefilter("ignore", RuntimeWarning) mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow manager.flows[("default", "test_flow")] = { @@ -345,7 +348,7 @@ class 
TestDispatcherManager: """Test process_flow_export with valid flow and kind""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow manager.flows[("default", "test_flow")] = { @@ -378,26 +381,47 @@ class TestDispatcherManager: @pytest.mark.asyncio async def test_process_socket(self): - """Test process_socket method""" + """process_socket constructs a Mux with the manager's auth + instance passed through — this is the gateway's trust path + for first-frame WebSocket authentication. A Mux cannot be + built without auth (tested separately); this test pins that + the dispatcher-manager threads the correct auth value into + the Mux constructor call.""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) - + mock_auth = Mock() + manager = DispatcherManager( + mock_backend, mock_config_receiver, auth=mock_auth, + ) + with patch('trustgraph.gateway.dispatch.manager.Mux') as mock_mux: mock_mux_instance = Mock() mock_mux.return_value = mock_mux_instance - + result = await manager.process_socket("ws", "running", {}) - - mock_mux.assert_called_once_with(manager, "ws", "running") + + mock_mux.assert_called_once_with( + manager, "ws", "running", auth=mock_auth, + ) assert result == mock_mux_instance + def test_dispatcher_manager_requires_auth(self): + """Constructing a DispatcherManager without an auth argument + must fail — a no-auth DispatcherManager would produce a + Mux without authentication, silently downgrading the socket + auth path.""" + mock_backend = Mock() + mock_config_receiver = Mock() + + with pytest.raises(ValueError, match="auth"): + DispatcherManager(mock_backend, mock_config_receiver, auth=None) + @pytest.mark.asyncio async def test_process_flow_service(self): """Test process_flow_service method""" mock_backend = Mock() 
mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) manager.invoke_flow_service = AsyncMock(return_value="flow_result") @@ -412,7 +436,7 @@ class TestDispatcherManager: """Test invoke_flow_service with existing dispatcher""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Add flow to the flows dictionary manager.flows[("default", "test_flow")] = {"services": {"agent": {}}} @@ -432,7 +456,7 @@ class TestDispatcherManager: """Test invoke_flow_service creates request-response dispatcher""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow manager.flows[("default", "test_flow")] = { @@ -476,7 +500,7 @@ class TestDispatcherManager: """Test invoke_flow_service creates sender dispatcher""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow manager.flows[("default", "test_flow")] = { @@ -516,7 +540,7 @@ class TestDispatcherManager: """Test invoke_flow_service with invalid flow""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) with pytest.raises(RuntimeError, match="Invalid flow"): await manager.invoke_flow_service("data", "responder", "default", "invalid_flow", "agent") @@ -526,7 +550,7 @@ class TestDispatcherManager: """Test invoke_flow_service with kind not supported by flow""" mock_backend = Mock() mock_config_receiver = Mock() - 
manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow without agent interface manager.flows[("default", "test_flow")] = { @@ -543,7 +567,7 @@ class TestDispatcherManager: """Test invoke_flow_service with invalid kind""" mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) # Setup test flow with interface but unsupported kind manager.flows[("default", "test_flow")] = { @@ -570,7 +594,7 @@ class TestDispatcherManager: """ mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) async def slow_start(): # Yield to the event loop so other coroutines get a chance to run, @@ -606,7 +630,7 @@ class TestDispatcherManager: """ mock_backend = Mock() mock_config_receiver = Mock() - manager = DispatcherManager(mock_backend, mock_config_receiver) + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) manager.flows[("default", "test_flow")] = { "interfaces": { diff --git a/tests/unit/test_gateway/test_dispatch_mux.py b/tests/unit/test_gateway/test_dispatch_mux.py index a0bc9460..c1baa920 100644 --- a/tests/unit/test_gateway/test_dispatch_mux.py +++ b/tests/unit/test_gateway/test_dispatch_mux.py @@ -12,6 +12,19 @@ from trustgraph.gateway.dispatch.mux import Mux, MAX_QUEUE_SIZE class TestMux: """Test cases for Mux class""" + def test_mux_requires_auth(self): + """Constructing a Mux without an ``auth`` argument must + fail. 
The Mux implements the first-frame auth protocol and + there is no no-auth mode — a no-auth Mux would silently + accept every frame without authenticating it.""" + with pytest.raises(ValueError, match="auth"): + Mux( + dispatcher_manager=MagicMock(), + ws=MagicMock(), + running=MagicMock(), + auth=None, + ) + def test_mux_initialization(self): """Test Mux initialization""" mock_dispatcher_manager = MagicMock() @@ -21,7 +34,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) assert mux.dispatcher_manager == mock_dispatcher_manager @@ -40,7 +54,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Call destroy @@ -61,7 +76,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=None, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Call destroy @@ -81,7 +97,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Mock message with valid JSON @@ -108,7 +125,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Mock message without request field @@ -137,7 +155,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Mock message without id field @@ -164,7 +183,8 @@ class TestMux: mux = Mux( dispatcher_manager=mock_dispatcher_manager, ws=mock_ws, - running=mock_running + running=mock_running, + auth=MagicMock(), ) # Mock message with invalid JSON diff --git a/tests/unit/test_gateway/test_endpoint_constant.py b/tests/unit/test_gateway/test_endpoint_constant.py index f208c967..98588e55 100644 --- 
a/tests/unit/test_gateway/test_endpoint_constant.py +++ b/tests/unit/test_gateway/test_endpoint_constant.py @@ -13,29 +13,36 @@ class TestConstantEndpoint: """Test cases for ConstantEndpoint class""" def test_constant_endpoint_initialization(self): - """Test ConstantEndpoint initialization""" + """Construction records the configured capability on the + instance. The capability is a required argument — no + permissive default — and the test passes an explicit + value to demonstrate the contract.""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = ConstantEndpoint( endpoint_path="/api/test", auth=mock_auth, - dispatcher=mock_dispatcher + dispatcher=mock_dispatcher, + capability="config:read", ) - + assert endpoint.path == "/api/test" assert endpoint.auth == mock_auth assert endpoint.dispatcher == mock_dispatcher - assert endpoint.operation == "service" + assert endpoint.capability == "config:read" @pytest.mark.asyncio async def test_constant_endpoint_start_method(self): """Test ConstantEndpoint start method (should be no-op)""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - - endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher) - + + endpoint = ConstantEndpoint( + "/api/test", mock_auth, mock_dispatcher, + capability="config:read", + ) + # start() should complete without error await endpoint.start() @@ -44,10 +51,13 @@ class TestConstantEndpoint: mock_auth = MagicMock() mock_dispatcher = MagicMock() mock_app = MagicMock() - - endpoint = ConstantEndpoint("/api/test", mock_auth, mock_dispatcher) + + endpoint = ConstantEndpoint( + "/api/test", mock_auth, mock_dispatcher, + capability="config:read", + ) endpoint.add_routes(mock_app) - + # Verify add_routes was called with POST route mock_app.add_routes.assert_called_once() # The call should include web.post with the path and handler diff --git a/tests/unit/test_gateway/test_endpoint_i18n.py b/tests/unit/test_gateway/test_endpoint_i18n.py index ab693cdf..c2b51568 100644 --- 
a/tests/unit/test_gateway/test_endpoint_i18n.py +++ b/tests/unit/test_gateway/test_endpoint_i18n.py @@ -1,4 +1,12 @@ -"""Tests for Gateway i18n pack endpoint.""" +"""Tests for Gateway i18n pack endpoint. + +Production registers this endpoint with ``capability=PUBLIC``: the +login UI needs to render its own i18n strings before any user has +authenticated, so the endpoint is deliberately pre-auth. These +tests exercise the PUBLIC configuration — that is the production +contract. Behaviour of authenticated endpoints is covered by the +IamAuth tests in ``test_auth.py``. +""" import json from unittest.mock import MagicMock @@ -7,6 +15,7 @@ import pytest from aiohttp import web from trustgraph.gateway.endpoint.i18n import I18nPackEndpoint +from trustgraph.gateway.capabilities import PUBLIC class TestI18nPackEndpoint: @@ -17,23 +26,28 @@ class TestI18nPackEndpoint: endpoint = I18nPackEndpoint( endpoint_path="/api/v1/i18n/packs/{lang}", auth=mock_auth, + capability=PUBLIC, ) assert endpoint.path == "/api/v1/i18n/packs/{lang}" assert endpoint.auth == mock_auth - assert endpoint.operation == "service" + assert endpoint.capability == PUBLIC @pytest.mark.asyncio async def test_i18n_endpoint_start_method(self): mock_auth = MagicMock() - endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth) + endpoint = I18nPackEndpoint( + "/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC, + ) await endpoint.start() def test_add_routes_registers_get_handler(self): mock_auth = MagicMock() mock_app = MagicMock() - endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth) + endpoint = I18nPackEndpoint( + "/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC, + ) endpoint.add_routes(mock_app) mock_app.add_routes.assert_called_once() @@ -41,35 +55,55 @@ class TestI18nPackEndpoint: assert len(call_args) == 1 @pytest.mark.asyncio - async def test_handle_unauthorized_on_invalid_auth_scheme(self): + async def test_handle_returns_pack_without_authenticating(self): + 
"""The PUBLIC endpoint serves the language pack without + invoking the auth handler at all — pre-login UI must be + reachable. The test uses an auth mock that raises if + touched, so any auth attempt by the endpoint is caught.""" mock_auth = MagicMock() - mock_auth.permitted.return_value = True - endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth) + def _should_not_be_called(*args, **kwargs): + raise AssertionError( + "PUBLIC endpoint must not invoke auth.authenticate" + ) + mock_auth.authenticate = _should_not_be_called + + endpoint = I18nPackEndpoint( + "/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC, + ) request = MagicMock() request.path = "/api/v1/i18n/packs/en" + # A caller-supplied Authorization header of any form should + # be ignored — PUBLIC means we don't look at it. request.headers = {"Authorization": "Token abc"} request.match_info = {"lang": "en"} - resp = await endpoint.handle(request) - assert isinstance(resp, web.HTTPUnauthorized) - - @pytest.mark.asyncio - async def test_handle_returns_pack_when_permitted(self): - mock_auth = MagicMock() - mock_auth.permitted.return_value = True - - endpoint = I18nPackEndpoint("/api/v1/i18n/packs/{lang}", mock_auth) - - request = MagicMock() - request.path = "/api/v1/i18n/packs/en" - request.headers = {} - request.match_info = {"lang": "en"} - resp = await endpoint.handle(request) assert resp.status == 200 payload = json.loads(resp.body.decode("utf-8")) assert isinstance(payload, dict) assert "cli.verify_system_status.title" in payload + + @pytest.mark.asyncio + async def test_handle_rejects_path_traversal(self): + """The ``lang`` path parameter is reflected through to the + filesystem-backed pack loader. 
The endpoint contains an
+        explicit defense against ``/`` and ``..`` in the value; this
+        test pins that defense in place."""
+        mock_auth = MagicMock()
+        endpoint = I18nPackEndpoint(
+            "/api/v1/i18n/packs/{lang}", mock_auth, capability=PUBLIC,
+        )
+
+        for bad in ("../../etc/passwd", "en/../fr", "a/b"):
+            request = MagicMock()
+            request.path = f"/api/v1/i18n/packs/{bad}"
+            request.headers = {}
+            request.match_info = {"lang": bad}
+
+            resp = await endpoint.handle(request)
+            assert isinstance(resp, web.HTTPBadRequest), (
+                f"path-traversal defense did not reject lang={bad!r}"
+            )
diff --git a/tests/unit/test_gateway/test_endpoint_manager.py b/tests/unit/test_gateway/test_endpoint_manager.py
index 4766f8d7..cf12565c 100644
--- a/tests/unit/test_gateway/test_endpoint_manager.py
+++ b/tests/unit/test_gateway/test_endpoint_manager.py
@@ -12,30 +12,24 @@ class TestEndpointManager:
    """Test cases for EndpointManager class"""

    def test_endpoint_manager_initialization(self):
-        """Test EndpointManager initialization creates all endpoints"""
+        """EndpointManager wires up the full endpoint set and
+        records dispatcher_manager / timeout on the instance."""
        mock_dispatcher_manager = MagicMock()
        mock_auth = MagicMock()
-
-        # Mock dispatcher methods
-        mock_dispatcher_manager.dispatch_global_service.return_value = MagicMock()
-        mock_dispatcher_manager.dispatch_socket.return_value = MagicMock()
-        mock_dispatcher_manager.dispatch_flow_service.return_value = MagicMock()
-        mock_dispatcher_manager.dispatch_flow_import.return_value = MagicMock()
-        mock_dispatcher_manager.dispatch_flow_export.return_value = MagicMock()
-        mock_dispatcher_manager.dispatch_core_import.return_value = MagicMock()
-        mock_dispatcher_manager.dispatch_core_export.return_value = MagicMock()
-
+
+        # The dispatcher_manager exposes a small set of factory
+        # methods — MagicMock auto-creates them, returning fresh
+        # MagicMocks on each call.
manager = EndpointManager( dispatcher_manager=mock_dispatcher_manager, auth=mock_auth, prometheus_url="http://prometheus:9090", - timeout=300 + timeout=300, ) - + assert manager.dispatcher_manager == mock_dispatcher_manager assert manager.timeout == 300 - assert manager.services == {} - assert len(manager.endpoints) > 0 # Should have multiple endpoints + assert len(manager.endpoints) > 0 def test_endpoint_manager_with_default_timeout(self): """Test EndpointManager with default timeout value""" @@ -79,9 +73,15 @@ class TestEndpointManager: prometheus_url="http://test:9090" ) - # Verify all dispatcher methods were called during initialization + # Each dispatcher factory is invoked exactly once during + # construction — one per endpoint that needs a dedicated + # wire. dispatch_auth_iam is the dedicated factory for the + # AuthEndpoints forwarder (login / bootstrap / + # change-password), distinct from dispatch_global_service + # (the generic /api/v1/{kind} route). mock_dispatcher_manager.dispatch_global_service.assert_called_once() - mock_dispatcher_manager.dispatch_socket.assert_called() # Called twice + mock_dispatcher_manager.dispatch_auth_iam.assert_called_once() + mock_dispatcher_manager.dispatch_socket.assert_called_once() mock_dispatcher_manager.dispatch_flow_service.assert_called_once() mock_dispatcher_manager.dispatch_flow_import.assert_called_once() mock_dispatcher_manager.dispatch_flow_export.assert_called_once() diff --git a/tests/unit/test_gateway/test_endpoint_metrics.py b/tests/unit/test_gateway/test_endpoint_metrics.py index bacf551d..6d911bbd 100644 --- a/tests/unit/test_gateway/test_endpoint_metrics.py +++ b/tests/unit/test_gateway/test_endpoint_metrics.py @@ -12,31 +12,35 @@ class TestMetricsEndpoint: """Test cases for MetricsEndpoint class""" def test_metrics_endpoint_initialization(self): - """Test MetricsEndpoint initialization""" + """Construction records the configured capability on the + instance. 
In production MetricsEndpoint is gated by + 'metrics:read' so that's the natural value to pass.""" mock_auth = MagicMock() - + endpoint = MetricsEndpoint( prometheus_url="http://prometheus:9090", endpoint_path="/metrics", - auth=mock_auth + auth=mock_auth, + capability="metrics:read", ) - + assert endpoint.prometheus_url == "http://prometheus:9090" assert endpoint.path == "/metrics" assert endpoint.auth == mock_auth - assert endpoint.operation == "service" + assert endpoint.capability == "metrics:read" @pytest.mark.asyncio async def test_metrics_endpoint_start_method(self): """Test MetricsEndpoint start method (should be no-op)""" mock_auth = MagicMock() - + endpoint = MetricsEndpoint( prometheus_url="http://localhost:9090", endpoint_path="/metrics", - auth=mock_auth + auth=mock_auth, + capability="metrics:read", ) - + # start() should complete without error await endpoint.start() @@ -44,15 +48,16 @@ class TestMetricsEndpoint: """Test add_routes method registers GET route with wildcard path""" mock_auth = MagicMock() mock_app = MagicMock() - + endpoint = MetricsEndpoint( prometheus_url="http://prometheus:9090", endpoint_path="/metrics", - auth=mock_auth + auth=mock_auth, + capability="metrics:read", ) - + endpoint.add_routes(mock_app) - + # Verify add_routes was called with GET route mock_app.add_routes.assert_called_once() # The call should include web.get with wildcard path pattern diff --git a/tests/unit/test_gateway/test_endpoint_socket.py b/tests/unit/test_gateway/test_endpoint_socket.py index 83eb38c2..189bc32b 100644 --- a/tests/unit/test_gateway/test_endpoint_socket.py +++ b/tests/unit/test_gateway/test_endpoint_socket.py @@ -1,5 +1,12 @@ """ -Tests for Gateway Socket Endpoint +Tests for Gateway Socket Endpoint. + +In production the only SocketEndpoint registered with HTTP-layer +auth is ``/api/v1/socket`` using ``capability=AUTHENTICATED`` with +``in_band_auth=True`` (first-frame auth over the websocket frames, +not at the handshake). 
The tests below use AUTHENTICATED as the +representative capability; construction / worker / listener +behaviour is independent of which capability is configured. """ import pytest @@ -7,41 +14,47 @@ from unittest.mock import MagicMock, AsyncMock from aiohttp import WSMsgType from trustgraph.gateway.endpoint.socket import SocketEndpoint +from trustgraph.gateway.capabilities import AUTHENTICATED class TestSocketEndpoint: """Test cases for SocketEndpoint class""" def test_socket_endpoint_initialization(self): - """Test SocketEndpoint initialization""" + """Construction records the configured capability on the + instance. No permissive default is applied.""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = SocketEndpoint( endpoint_path="/api/socket", auth=mock_auth, - dispatcher=mock_dispatcher + dispatcher=mock_dispatcher, + capability=AUTHENTICATED, ) - + assert endpoint.path == "/api/socket" assert endpoint.auth == mock_auth assert endpoint.dispatcher == mock_dispatcher - assert endpoint.operation == "socket" + assert endpoint.capability == AUTHENTICATED @pytest.mark.asyncio async def test_worker_method(self): """Test SocketEndpoint worker method""" mock_auth = MagicMock() mock_dispatcher = AsyncMock() - - endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher) - + + endpoint = SocketEndpoint( + "/api/socket", mock_auth, mock_dispatcher, + capability=AUTHENTICATED, + ) + mock_ws = MagicMock() mock_running = MagicMock() - + # Call worker method await endpoint.worker(mock_ws, mock_dispatcher, mock_running) - + # Verify dispatcher.run was called mock_dispatcher.run.assert_called_once() @@ -50,8 +63,11 @@ class TestSocketEndpoint: """Test SocketEndpoint listener method with text message""" mock_auth = MagicMock() mock_dispatcher = AsyncMock() - - endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher) + + endpoint = SocketEndpoint( + "/api/socket", mock_auth, mock_dispatcher, + capability=AUTHENTICATED, + ) # Mock websocket 
with text message mock_msg = MagicMock() @@ -80,8 +96,11 @@ class TestSocketEndpoint: """Test SocketEndpoint listener method with binary message""" mock_auth = MagicMock() mock_dispatcher = AsyncMock() - - endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher) + + endpoint = SocketEndpoint( + "/api/socket", mock_auth, mock_dispatcher, + capability=AUTHENTICATED, + ) # Mock websocket with binary message mock_msg = MagicMock() @@ -110,8 +129,11 @@ class TestSocketEndpoint: """Test SocketEndpoint listener method with close message""" mock_auth = MagicMock() mock_dispatcher = AsyncMock() - - endpoint = SocketEndpoint("/api/socket", mock_auth, mock_dispatcher) + + endpoint = SocketEndpoint( + "/api/socket", mock_auth, mock_dispatcher, + capability=AUTHENTICATED, + ) # Mock websocket with close message mock_msg = MagicMock() diff --git a/tests/unit/test_gateway/test_endpoint_stream.py b/tests/unit/test_gateway/test_endpoint_stream.py index b99946c8..a3b49465 100644 --- a/tests/unit/test_gateway/test_endpoint_stream.py +++ b/tests/unit/test_gateway/test_endpoint_stream.py @@ -12,48 +12,57 @@ class TestStreamEndpoint: """Test cases for StreamEndpoint class""" def test_stream_endpoint_initialization_with_post(self): - """Test StreamEndpoint initialization with POST method""" + """Construction records the configured capability on the + instance. StreamEndpoint is used in production for the + core-import / core-export / document-stream routes; a + document-write capability is a realistic value for a POST + stream (e.g. 
core-import).""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = StreamEndpoint( endpoint_path="/api/stream", auth=mock_auth, dispatcher=mock_dispatcher, - method="POST" + capability="documents:write", + method="POST", ) - + assert endpoint.path == "/api/stream" assert endpoint.auth == mock_auth assert endpoint.dispatcher == mock_dispatcher - assert endpoint.operation == "service" + assert endpoint.capability == "documents:write" assert endpoint.method == "POST" def test_stream_endpoint_initialization_with_get(self): - """Test StreamEndpoint initialization with GET method""" + """GET stream — export-style endpoint, read capability.""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = StreamEndpoint( endpoint_path="/api/stream", auth=mock_auth, dispatcher=mock_dispatcher, - method="GET" + capability="documents:read", + method="GET", ) - + assert endpoint.method == "GET" def test_stream_endpoint_initialization_default_method(self): - """Test StreamEndpoint initialization with default POST method""" + """Test StreamEndpoint initialization with default POST method. 
+        The method defaults to POST as a convenience; the capability
+        has no default and is always required."""
        mock_auth = MagicMock()
        mock_dispatcher = MagicMock()
-
+
        endpoint = StreamEndpoint(
            endpoint_path="/api/stream",
            auth=mock_auth,
-            dispatcher=mock_dispatcher
+            dispatcher=mock_dispatcher,
+            capability="documents:write",
        )
-
+
        assert endpoint.method == "POST"  # Default value

    @pytest.mark.asyncio
@@ -61,9 +70,12 @@ class TestStreamEndpoint:
        """Test StreamEndpoint start method (should be no-op)"""
        mock_auth = MagicMock()
        mock_dispatcher = MagicMock()
-
-        endpoint = StreamEndpoint("/api/stream", mock_auth, mock_dispatcher)
-
+
+        endpoint = StreamEndpoint(
+            "/api/stream", mock_auth, mock_dispatcher,
+            capability="documents:write",
+        )
+
        # start() should complete without error
        await endpoint.start()

@@ -72,16 +84,17 @@ class TestStreamEndpoint:
        mock_auth = MagicMock()
        mock_dispatcher = MagicMock()
        mock_app = MagicMock()
-
+
        endpoint = StreamEndpoint(
            endpoint_path="/api/stream",
            auth=mock_auth,
            dispatcher=mock_dispatcher,
-            method="POST"
+            capability="documents:write",
+            method="POST",
        )
-
+
        endpoint.add_routes(mock_app)
-
+
        # Verify add_routes was called with POST route
        mock_app.add_routes.assert_called_once()
        call_args = mock_app.add_routes.call_args[0][0]

@@ -92,16 +105,17 @@ class TestStreamEndpoint:
        mock_auth = MagicMock()
        mock_dispatcher = MagicMock()
        mock_app = MagicMock()
-
+
        endpoint = StreamEndpoint(
            endpoint_path="/api/stream",
            auth=mock_auth,
            dispatcher=mock_dispatcher,
-            method="GET"
+            capability="documents:read",
+            method="GET",
        )
-
+
        endpoint.add_routes(mock_app)
-
+
        # Verify add_routes was called with GET route
        mock_app.add_routes.assert_called_once()
        call_args = mock_app.add_routes.call_args[0][0]

@@ -112,13 +126,14 @@ class TestStreamEndpoint:
        mock_auth = MagicMock()
        mock_dispatcher = MagicMock()
        mock_app = MagicMock()
-
+
        endpoint = StreamEndpoint(
            endpoint_path="/api/stream",
            auth=mock_auth,
            dispatcher=mock_dispatcher,
-            method="INVALID"
+
capability="documents:write", + method="INVALID", ) - + with pytest.raises(RuntimeError, match="Bad method"): endpoint.add_routes(mock_app) \ No newline at end of file diff --git a/tests/unit/test_gateway/test_endpoint_variable.py b/tests/unit/test_gateway/test_endpoint_variable.py index ffaf4e9a..1cdc8f9f 100644 --- a/tests/unit/test_gateway/test_endpoint_variable.py +++ b/tests/unit/test_gateway/test_endpoint_variable.py @@ -12,29 +12,36 @@ class TestVariableEndpoint: """Test cases for VariableEndpoint class""" def test_variable_endpoint_initialization(self): - """Test VariableEndpoint initialization""" + """Construction records the configured capability on the + instance. VariableEndpoint is used in production for the + /api/v1/{kind} admin-scoped global service routes, so a + write-side capability is a realistic value for the test.""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - + endpoint = VariableEndpoint( endpoint_path="/api/variable", auth=mock_auth, - dispatcher=mock_dispatcher + dispatcher=mock_dispatcher, + capability="config:write", ) - + assert endpoint.path == "/api/variable" assert endpoint.auth == mock_auth assert endpoint.dispatcher == mock_dispatcher - assert endpoint.operation == "service" + assert endpoint.capability == "config:write" @pytest.mark.asyncio async def test_variable_endpoint_start_method(self): """Test VariableEndpoint start method (should be no-op)""" mock_auth = MagicMock() mock_dispatcher = MagicMock() - - endpoint = VariableEndpoint("/api/var", mock_auth, mock_dispatcher) - + + endpoint = VariableEndpoint( + "/api/var", mock_auth, mock_dispatcher, + capability="config:write", + ) + # start() should complete without error await endpoint.start() @@ -43,10 +50,13 @@ class TestVariableEndpoint: mock_auth = MagicMock() mock_dispatcher = MagicMock() mock_app = MagicMock() - - endpoint = VariableEndpoint("/api/variable", mock_auth, mock_dispatcher) + + endpoint = VariableEndpoint( + "/api/variable", mock_auth, 
mock_dispatcher, + capability="config:write", + ) endpoint.add_routes(mock_app) - + # Verify add_routes was called with POST route mock_app.add_routes.assert_called_once() call_args = mock_app.add_routes.call_args[0][0] diff --git a/tests/unit/test_gateway/test_service.py b/tests/unit/test_gateway/test_service.py index 71428db4..107e6819 100644 --- a/tests/unit/test_gateway/test_service.py +++ b/tests/unit/test_gateway/test_service.py @@ -1,355 +1,179 @@ """ -Tests for Gateway Service API +Tests for gateway/service.py — the Api class that wires together +the pub/sub backend, IAM auth, config receiver, dispatcher manager, +and endpoint manager. + +The legacy ``GATEWAY_SECRET`` / ``default_api_token`` / allow-all +surface is gone, so the tests here focus on the Api's construction +and composition rather than the removed auth behaviour. IamAuth's +own behaviour is covered in test_auth.py. """ import pytest -import asyncio -from unittest.mock import Mock, patch, MagicMock, AsyncMock +from unittest.mock import AsyncMock, Mock, patch from aiohttp import web -import pulsar -from trustgraph.gateway.service import Api, run, default_pulsar_host, default_prometheus_url, default_timeout, default_port, default_api_token - -# Tests for Gateway Service API +from trustgraph.gateway.service import ( + Api, + default_pulsar_host, default_prometheus_url, + default_timeout, default_port, +) +from trustgraph.gateway.auth import IamAuth -class TestApi: - """Test cases for Api class""" - +# -- constants ------------------------------------------------------------- - def test_api_initialization_with_defaults(self): - """Test Api initialization with default values""" - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_backend = Mock() - mock_get_pubsub.return_value = mock_backend - api = Api() +class TestDefaults: - assert api.port == default_port - assert api.timeout == default_timeout - assert api.pulsar_host == default_pulsar_host - assert api.pulsar_api_key 
is None - assert api.prometheus_url == default_prometheus_url + "/" - assert api.auth.allow_all is True + def test_exports_default_constants(self): + # These are consumed by CLIs / tests / docs. Sanity-check + # that they're the expected shape. + assert default_port == 8088 + assert default_timeout == 600 + assert default_pulsar_host.startswith("pulsar://") + assert default_prometheus_url.startswith("http") - # Verify get_pubsub was called - mock_get_pubsub.assert_called_once() - def test_api_initialization_with_custom_config(self): - """Test Api initialization with custom configuration""" +# -- Api construction ------------------------------------------------------ + + +@pytest.fixture +def mock_backend(): + return Mock() + + +@pytest.fixture +def api(mock_backend): + with patch( + "trustgraph.gateway.service.get_pubsub", + return_value=mock_backend, + ): + yield Api() + + +class TestApiConstruction: + + def test_defaults(self, api): + assert api.port == default_port + assert api.timeout == default_timeout + assert api.pulsar_host == default_pulsar_host + assert api.pulsar_api_key is None + # prometheus_url gets normalised with a trailing slash + assert api.prometheus_url == default_prometheus_url + "/" + + def test_auth_is_iam_backed(self, api): + # Any Api always gets an IamAuth. There is no "no auth" mode + # (GATEWAY_SECRET / allow_all has been removed — see IAM spec). + assert isinstance(api.auth, IamAuth) + + def test_components_wired(self, api): + assert api.config_receiver is not None + assert api.dispatcher_manager is not None + assert api.endpoint_manager is not None + + def test_dispatcher_manager_has_auth(self, api): + # The Mux uses this handle for first-frame socket auth. 
+ assert api.dispatcher_manager.auth is api.auth + + def test_custom_config(self, mock_backend): config = { "port": 9000, "timeout": 300, "pulsar_host": "pulsar://custom-host:6650", - "pulsar_api_key": "test-api-key", - "pulsar_listener": "custom-listener", + "pulsar_api_key": "custom-key", "prometheus_url": "http://custom-prometheus:9090", - "api_token": "secret-token" } + with patch( + "trustgraph.gateway.service.get_pubsub", + return_value=mock_backend, + ): + a = Api(**config) - with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub: - mock_backend = Mock() - mock_get_pubsub.return_value = mock_backend + assert a.port == 9000 + assert a.timeout == 300 + assert a.pulsar_host == "pulsar://custom-host:6650" + assert a.pulsar_api_key == "custom-key" + # Trailing slash added. + assert a.prometheus_url == "http://custom-prometheus:9090/" - api = Api(**config) + def test_prometheus_url_already_has_trailing_slash(self, mock_backend): + with patch( + "trustgraph.gateway.service.get_pubsub", + return_value=mock_backend, + ): + a = Api(prometheus_url="http://p:9090/") + assert a.prometheus_url == "http://p:9090/" - assert api.port == 9000 - assert api.timeout == 300 - assert api.pulsar_host == "pulsar://custom-host:6650" - assert api.pulsar_api_key == "test-api-key" - assert api.prometheus_url == "http://custom-prometheus:9090/" - assert api.auth.token == "secret-token" - assert api.auth.allow_all is False + def test_queue_overrides_parsed_for_config(self, mock_backend): + with patch( + "trustgraph.gateway.service.get_pubsub", + return_value=mock_backend, + ): + a = Api( + config_request_queue="alt-config-req", + config_response_queue="alt-config-resp", + ) + overrides = a.dispatcher_manager.queue_overrides + assert overrides.get("config", {}).get("request") == "alt-config-req" + assert overrides.get("config", {}).get("response") == "alt-config-resp" - # Verify get_pubsub was called with config - mock_get_pubsub.assert_called_once_with(**config) - def 
test_api_initialization_with_pulsar_api_key(self):
-        """Test Api initialization with Pulsar API key authentication"""
-        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
-            mock_get_pubsub.return_value = Mock()
+# -- app_factory -----------------------------------------------------------
-            api = Api(pulsar_api_key="test-key")
-            # Verify api key was stored
-            assert api.pulsar_api_key == "test-key"
-            mock_get_pubsub.assert_called_once()
-
-    def test_api_initialization_prometheus_url_normalization(self):
-        """Test that prometheus_url gets normalized with trailing slash"""
-        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
-            mock_get_pubsub.return_value = Mock()
-
-            # Test URL without trailing slash
-            api = Api(prometheus_url="http://prometheus:9090")
-            assert api.prometheus_url == "http://prometheus:9090/"
-
-            # Test URL with trailing slash
-            api = Api(prometheus_url="http://prometheus:9090/")
-            assert api.prometheus_url == "http://prometheus:9090/"
-
-    def test_api_initialization_empty_api_token_means_no_auth(self):
-        """Test that empty API token results in allow_all authentication"""
-        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
-            mock_get_pubsub.return_value = Mock()
-
-            api = Api(api_token="")
-            assert api.auth.allow_all is True
-
-    def test_api_initialization_none_api_token_means_no_auth(self):
-        """Test that None API token results in allow_all authentication"""
-        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
-            mock_get_pubsub.return_value = Mock()
-
-            api = Api(api_token=None)
-            assert api.auth.allow_all is True
+class TestAppFactory:

     @pytest.mark.asyncio
-    async def test_app_factory_creates_application(self):
-        """Test that app_factory creates aiohttp application"""
-        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
-            mock_get_pubsub.return_value = Mock()
-
-            api = Api()
-
-            # Mock the dependencies
-            api.config_receiver = Mock()
-            api.config_receiver.start = AsyncMock()
-            api.endpoint_manager = Mock()
-            api.endpoint_manager.add_routes = Mock()
-            api.endpoint_manager.start = AsyncMock()
-
-            app = await api.app_factory()
-
-            assert isinstance(app, web.Application)
-            assert app._client_max_size == 256 * 1024 * 1024
-
-            # Verify that config receiver was started
-            api.config_receiver.start.assert_called_once()
-
-            # Verify that endpoint manager was configured
-            api.endpoint_manager.add_routes.assert_called_once_with(app)
-            api.endpoint_manager.start.assert_called_once()
+    async def test_creates_aiohttp_app(self, api):
+        # Stub out the long-tail dependencies that reach out to IAM /
+        # pub/sub so we can exercise the factory in isolation.
+        api.auth.start = AsyncMock()
+        api.config_receiver = Mock()
+        api.config_receiver.start = AsyncMock()
+        api.endpoint_manager = Mock()
+        api.endpoint_manager.add_routes = Mock()
+        api.endpoint_manager.start = AsyncMock()
+        api.endpoints = []
+
+        app = await api.app_factory()
+
+        assert isinstance(app, web.Application)
+        assert app._client_max_size == 256 * 1024 * 1024
+        api.auth.start.assert_called_once()
+        api.config_receiver.start.assert_called_once()
+        api.endpoint_manager.add_routes.assert_called_once_with(app)
+        api.endpoint_manager.start.assert_called_once()

     @pytest.mark.asyncio
-    async def test_app_factory_with_custom_endpoints(self):
-        """Test app_factory with custom endpoints"""
-        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
-            mock_get_pubsub.return_value = Mock()
-
-            api = Api()
-
-            # Mock custom endpoints
-            mock_endpoint1 = Mock()
-            mock_endpoint1.add_routes = Mock()
-            mock_endpoint1.start = AsyncMock()
-
-            mock_endpoint2 = Mock()
-            mock_endpoint2.add_routes = Mock()
-            mock_endpoint2.start = AsyncMock()
-
-            api.endpoints = [mock_endpoint1, mock_endpoint2]
-
-            # Mock the dependencies
-            api.config_receiver = Mock()
-            api.config_receiver.start = AsyncMock()
-            api.endpoint_manager = Mock()
-            api.endpoint_manager.add_routes = Mock()
-            api.endpoint_manager.start = AsyncMock()
-
-            app = await api.app_factory()
-
-            # Verify custom endpoints were configured
-            mock_endpoint1.add_routes.assert_called_once_with(app)
-            mock_endpoint1.start.assert_called_once()
-            mock_endpoint2.add_routes.assert_called_once_with(app)
-            mock_endpoint2.start.assert_called_once()
+    async def test_auth_start_runs_before_accepting_traffic(self, api):
+        """``auth.start()`` fetches the IAM signing key, and must
+        complete (or time out) before the gateway begins accepting
+        requests. It's the first await in app_factory."""
+        order = []

-    def test_run_method_calls_web_run_app(self):
-        """Test that run method calls web.run_app"""
-        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub, \
-             patch('aiohttp.web.run_app') as mock_run_app:
-            mock_get_pubsub.return_value = Mock()
+        # AsyncMock.side_effect expects a sync callable (its return
+        # value becomes the coroutine's return); a plain list.append
+        # avoids the "coroutine was never awaited" trap of an async
+        # side_effect.
+        api.auth.start = AsyncMock(
+            side_effect=lambda: order.append("auth"),
+        )
+        api.config_receiver = Mock()
+        api.config_receiver.start = AsyncMock(
+            side_effect=lambda: order.append("config"),
+        )
+        api.endpoint_manager = Mock()
+        api.endpoint_manager.add_routes = Mock()
+        api.endpoint_manager.start = AsyncMock(
+            side_effect=lambda: order.append("endpoints"),
+        )
+        api.endpoints = []

-            # Api.run() passes self.app_factory() — a coroutine — to
-            # web.run_app, which would normally consume it inside its own
-            # event loop. Since we mock run_app, close the coroutine here
-            # so it doesn't leak as an "unawaited coroutine" RuntimeWarning.
-            def _consume_coro(coro, **kwargs):
-                coro.close()
-            mock_run_app.side_effect = _consume_coro
+        await api.app_factory()

-            api = Api(port=8080)
-            api.run()
-
-            # Verify run_app was called once with the correct port
-            mock_run_app.assert_called_once()
-            args, kwargs = mock_run_app.call_args
-            assert len(args) == 1  # Should have one positional arg (the coroutine)
-            assert kwargs == {'port': 8080}  # Should have port keyword arg
-
-    def test_api_components_initialization(self):
-        """Test that all API components are properly initialized"""
-        with patch('trustgraph.gateway.service.get_pubsub') as mock_get_pubsub:
-            mock_get_pubsub.return_value = Mock()
-
-            api = Api()
-
-            # Verify all components are initialized
-            assert api.config_receiver is not None
-            assert api.dispatcher_manager is not None
-            assert api.endpoint_manager is not None
-            assert api.endpoints == []
-
-            # Verify component relationships
-            assert api.dispatcher_manager.backend == api.pubsub_backend
-            assert api.dispatcher_manager.config_receiver == api.config_receiver
-            assert api.endpoint_manager.dispatcher_manager == api.dispatcher_manager
-            # EndpointManager doesn't store auth directly, it passes it to individual endpoints
-
-
-class TestRunFunction:
-    """Test cases for the run() function"""
-
-    def test_run_function_with_metrics_enabled(self):
-        """Test run function with metrics enabled"""
-        import warnings
-        # Suppress the specific async warning with a broader pattern
-        warnings.filterwarnings("ignore", message=".*Api.app_factory.*was never awaited", category=RuntimeWarning)
-
-        with patch('argparse.ArgumentParser.parse_args') as mock_parse_args, \
-             patch('trustgraph.gateway.service.start_http_server') as mock_start_http_server:
-
-            # Mock command line arguments
-            mock_args = Mock()
-            mock_args.metrics = True
-            mock_args.metrics_port = 8000
-            mock_parse_args.return_value = mock_args
-
-            # Create a simple mock instance without any async methods
-            mock_api_instance = Mock()
-            mock_api_instance.run = Mock()
-
-            # Create a mock Api class without importing the real one
-            mock_api = Mock(return_value=mock_api_instance)
-
-            # Patch using context manager to avoid importing the real Api class
-            with patch('trustgraph.gateway.service.Api', mock_api):
-                # Mock vars() to return a dict
-                with patch('builtins.vars') as mock_vars:
-                    mock_vars.return_value = {
-                        'metrics': True,
-                        'metrics_port': 8000,
-                        'pulsar_host': default_pulsar_host,
-                        'timeout': default_timeout
-                    }
-
-                    run()
-
-                    # Verify metrics server was started
-                    mock_start_http_server.assert_called_once_with(8000)
-
-                    # Verify Api was created and run was called
-                    mock_api.assert_called_once()
-                    mock_api_instance.run.assert_called_once()
-
-    @patch('trustgraph.gateway.service.start_http_server')
-    @patch('argparse.ArgumentParser.parse_args')
-    def test_run_function_with_metrics_disabled(self, mock_parse_args, mock_start_http_server):
-        """Test run function with metrics disabled"""
-        # Mock command line arguments
-        mock_args = Mock()
-        mock_args.metrics = False
-        mock_parse_args.return_value = mock_args
-
-        # Create a simple mock instance without any async methods
-        mock_api_instance = Mock()
-        mock_api_instance.run = Mock()
-
-        # Patch the Api class inside the test without using decorators
-        with patch('trustgraph.gateway.service.Api') as mock_api:
-            mock_api.return_value = mock_api_instance
-
-            # Mock vars() to return a dict
-            with patch('builtins.vars') as mock_vars:
-                mock_vars.return_value = {
-                    'metrics': False,
-                    'metrics_port': 8000,
-                    'pulsar_host': default_pulsar_host,
-                    'timeout': default_timeout
-                }
-
-                run()
-
-                # Verify metrics server was NOT started
-                mock_start_http_server.assert_not_called()
-
-                # Verify Api was created and run was called
-                mock_api.assert_called_once()
-                mock_api_instance.run.assert_called_once()
-
-    @patch('argparse.ArgumentParser.parse_args')
-    def test_run_function_argument_parsing(self, mock_parse_args):
-        """Test that run function properly parses command line arguments"""
-        # Mock command line arguments
-        mock_args = Mock()
-        mock_args.metrics = False
-        mock_parse_args.return_value = mock_args
-
-        # Create a simple mock instance without any async methods
-        mock_api_instance = Mock()
-        mock_api_instance.run = Mock()
-
-        # Mock vars() to return a dict with all expected arguments
-        expected_args = {
-            'pulsar_host': 'pulsar://test:6650',
-            'pulsar_api_key': 'test-key',
-            'pulsar_listener': 'test-listener',
-            'prometheus_url': 'http://test-prometheus:9090',
-            'port': 9000,
-            'timeout': 300,
-            'api_token': 'secret',
-            'log_level': 'INFO',
-            'metrics': False,
-            'metrics_port': 8001
-        }
-
-        # Patch the Api class inside the test without using decorators
-        with patch('trustgraph.gateway.service.Api') as mock_api:
-            mock_api.return_value = mock_api_instance
-
-            with patch('builtins.vars') as mock_vars:
-                mock_vars.return_value = expected_args
-
-                run()
-
-                # Verify Api was created with the parsed arguments
-                mock_api.assert_called_once_with(**expected_args)
-                mock_api_instance.run.assert_called_once()
-
-    def test_run_function_creates_argument_parser(self):
-        """Test that run function creates argument parser with correct arguments"""
-        with patch('argparse.ArgumentParser') as mock_parser_class:
-            mock_parser = Mock()
-            mock_parser_class.return_value = mock_parser
-            mock_parser.parse_args.return_value = Mock(metrics=False)
-
-            with patch('trustgraph.gateway.service.Api') as mock_api, \
-                 patch('builtins.vars') as mock_vars:
-                mock_vars.return_value = {'metrics': False}
-                mock_api.return_value = Mock()
-
-                run()
-
-                # Verify ArgumentParser was created
-                mock_parser_class.assert_called_once()
-
-                # Verify add_argument was called for each expected argument
-                expected_arguments = [
-                    'pulsar-host', 'pulsar-api-key', 'pulsar-listener',
-                    'prometheus-url', 'port', 'timeout', 'api-token',
-                    'log-level', 'metrics', 'metrics-port'
-                ]
-
-                # Check that add_argument was called multiple times (once for each arg)
-                assert mock_parser.add_argument.call_count >= len(expected_arguments)
\ No newline at end of file
+        # auth.start must be first (before config receiver, before
+        # any endpoint starts).
+        assert order[0] == "auth"
+        # All three must have run.
+        assert set(order) == {"auth", "config", "endpoints"}
diff --git a/tests/unit/test_gateway/test_socket_graceful_shutdown.py b/tests/unit/test_gateway/test_socket_graceful_shutdown.py
index 1a63227d..23f22d30 100644
--- a/tests/unit/test_gateway/test_socket_graceful_shutdown.py
+++ b/tests/unit/test_gateway/test_socket_graceful_shutdown.py
@@ -1,4 +1,15 @@
-"""Unit tests for SocketEndpoint graceful shutdown functionality."""
+"""Unit tests for SocketEndpoint graceful shutdown functionality.
+
+These tests exercise SocketEndpoint in its handshake-auth
+configuration (``in_band_auth=False``) — the mode used in production
+for the flow import/export streaming endpoints. The mux socket at
+``/api/v1/socket`` uses ``in_band_auth=True`` instead, where the
+handshake always accepts and authentication runs on the first
+WebSocket frame; that path is covered by the Mux tests.
+
+Every endpoint constructor here passes an explicit capability — no
+permissive default is relied upon.
+"""

 import pytest
 import asyncio
@@ -6,13 +17,31 @@ from unittest.mock import AsyncMock, MagicMock, patch
 from aiohttp import web, WSMsgType

 from trustgraph.gateway.endpoint.socket import SocketEndpoint
 from trustgraph.gateway.running import Running
+from trustgraph.gateway.auth import Identity
+
+
+# Representative capability used across these tests — corresponds to
+# the flow-import streaming endpoint pattern that uses this class.
+TEST_CAP = "graph:write"
+
+
+def _valid_identity(roles=("admin",)):
+    return Identity(
+        user_id="test-user",
+        workspace="default",
+        roles=list(roles),
+        source="api-key",
+    )


 @pytest.fixture
 def mock_auth():
-    """Mock authentication service."""
+    """Mock IAM-backed authenticator. Successful by default —
+    ``authenticate`` returns a valid admin identity. Tests that
+    need the auth failure path override the ``authenticate``
+    attribute locally."""
     auth = MagicMock()
-    auth.permitted.return_value = True
+    auth.authenticate = AsyncMock(return_value=_valid_identity())
     return auth
@@ -25,7 +54,7 @@ def mock_dispatcher_factory():
         dispatcher.receive = AsyncMock()
         dispatcher.destroy = AsyncMock()
         return dispatcher
-    
+
     return dispatcher_factory
@@ -35,7 +64,8 @@ def socket_endpoint(mock_auth, mock_dispatcher_factory):
     return SocketEndpoint(
         endpoint_path="/test-socket",
         auth=mock_auth,
-        dispatcher=mock_dispatcher_factory
+        dispatcher=mock_dispatcher_factory,
+        capability=TEST_CAP,
     )
@@ -61,7 +91,10 @@ def mock_request():
 @pytest.mark.asyncio
 async def test_listener_graceful_shutdown_on_close():
     """Test listener handles websocket close gracefully."""
-    socket_endpoint = SocketEndpoint("/test", MagicMock(), AsyncMock())
+    socket_endpoint = SocketEndpoint(
+        "/test", MagicMock(), AsyncMock(),
+        capability=TEST_CAP,
+    )

     # Mock websocket that closes after one message
     ws = AsyncMock()
@@ -99,9 +132,9 @@ async def test_listener_graceful_shutdown_on_close():

 @pytest.mark.asyncio
 async def test_handle_normal_flow():
-    """Test normal websocket handling flow."""
+    """Valid bearer → handshake accepted, dispatcher created."""
     mock_auth = MagicMock()
-    mock_auth.permitted.return_value = True
+    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())

     dispatcher_created = False

     async def mock_dispatcher_factory(ws, running, match_info):
@@ -111,7 +144,10 @@ async def test_handle_normal_flow():
         dispatcher.destroy = AsyncMock()
         return dispatcher

-    socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
+    socket_endpoint = SocketEndpoint(
+        "/test", mock_auth, mock_dispatcher_factory,
+        capability=TEST_CAP,
+    )

     request = MagicMock()
     request.query = {"token": "valid-token"}
@@ -155,7 +191,7 @@ async def test_handle_normal_flow():
 async def test_handle_exception_group_cleanup():
     """Test exception group triggers dispatcher cleanup."""
     mock_auth = MagicMock()
-    mock_auth.permitted.return_value = True
+    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())

     mock_dispatcher = AsyncMock()
     mock_dispatcher.destroy = AsyncMock()
@@ -163,7 +199,10 @@ async def test_handle_exception_group_cleanup():
     async def mock_dispatcher_factory(ws, running, match_info):
         return mock_dispatcher

-    socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
+    socket_endpoint = SocketEndpoint(
+        "/test", mock_auth, mock_dispatcher_factory,
+        capability=TEST_CAP,
+    )

     request = MagicMock()
     request.query = {"token": "valid-token"}
@@ -222,7 +261,7 @@ async def test_handle_exception_group_cleanup():
 async def test_handle_dispatcher_cleanup_timeout():
     """Test dispatcher cleanup with timeout."""
     mock_auth = MagicMock()
-    mock_auth.permitted.return_value = True
+    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())

     # Mock dispatcher that takes long to destroy
     mock_dispatcher = AsyncMock()
@@ -231,7 +270,10 @@ async def test_handle_dispatcher_cleanup_timeout():
     async def mock_dispatcher_factory(ws, running, match_info):
         return mock_dispatcher

-    socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
+    socket_endpoint = SocketEndpoint(
+        "/test", mock_auth, mock_dispatcher_factory,
+        capability=TEST_CAP,
+    )

     request = MagicMock()
     request.query = {"token": "valid-token"}
@@ -285,49 +327,67 @@ async def test_handle_dispatcher_cleanup_timeout():

 @pytest.mark.asyncio
 async def test_handle_unauthorized_request():
-    """Test handling of unauthorized requests."""
+    """A bearer that the IAM layer rejects causes the handshake to
+    fail with 401. IamAuth surfaces an HTTPUnauthorized; the
+    endpoint propagates it. Note that the endpoint intentionally
+    does NOT distinguish 'bad token', 'expired', 'revoked', etc. —
+    that's the IAM error-masking policy."""
     mock_auth = MagicMock()
-    mock_auth.permitted.return_value = False  # Unauthorized
-
-    socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
-
+    mock_auth.authenticate = AsyncMock(side_effect=web.HTTPUnauthorized(
+        text='{"error":"auth failure"}',
+        content_type="application/json",
+    ))
+
+    socket_endpoint = SocketEndpoint(
+        "/test", mock_auth, AsyncMock(),
+        capability=TEST_CAP,
+    )
+
     request = MagicMock()
     request.query = {"token": "invalid-token"}
-
+
     result = await socket_endpoint.handle(request)
-
-    # Should return HTTP 401
+
     assert isinstance(result, web.HTTPUnauthorized)
-
-    # Should have checked permission
-    mock_auth.permitted.assert_called_once_with("invalid-token", "socket")
+    # authenticate must have been invoked with a synthetic request
+    # carrying Bearer <token>. The endpoint wraps the query-
+    # string token into an Authorization header for a uniform auth
+    # path — the IAM layer does not look at query strings directly.
+    mock_auth.authenticate.assert_called_once()
+    passed_req = mock_auth.authenticate.call_args.args[0]
+    assert passed_req.headers["Authorization"] == "Bearer invalid-token"


 @pytest.mark.asyncio
 async def test_handle_missing_token():
-    """Test handling of requests with missing token."""
+    """Request with no ``token`` query param → 401 before any
+    IAM call is made (cheap short-circuit)."""
     mock_auth = MagicMock()
-    mock_auth.permitted.return_value = False
-
-    socket_endpoint = SocketEndpoint("/test", mock_auth, AsyncMock())
-
+    mock_auth.authenticate = AsyncMock(
+        side_effect=AssertionError(
+            "authenticate must not be invoked when no token is present"
+        ),
+    )
+
+    socket_endpoint = SocketEndpoint(
+        "/test", mock_auth, AsyncMock(),
+        capability=TEST_CAP,
+    )
+
     request = MagicMock()
     request.query = {}  # No token
-
+
     result = await socket_endpoint.handle(request)
-
-    # Should return HTTP 401
+
     assert isinstance(result, web.HTTPUnauthorized)
-
-    # Should have checked permission with empty token
-    mock_auth.permitted.assert_called_once_with("", "socket")
+    mock_auth.authenticate.assert_not_called()


 @pytest.mark.asyncio
 async def test_handle_websocket_already_closed():
     """Test handling when websocket is already closed."""
     mock_auth = MagicMock()
-    mock_auth.permitted.return_value = True
+    mock_auth.authenticate = AsyncMock(return_value=_valid_identity())

     mock_dispatcher = AsyncMock()
     mock_dispatcher.destroy = AsyncMock()
@@ -335,7 +395,10 @@ async def test_handle_websocket_already_closed():
     async def mock_dispatcher_factory(ws, running, match_info):
         return mock_dispatcher

-    socket_endpoint = SocketEndpoint("/test", mock_auth, mock_dispatcher_factory)
+    socket_endpoint = SocketEndpoint(
+        "/test", mock_auth, mock_dispatcher_factory,
+        capability=TEST_CAP,
+    )

     request = MagicMock()
     request.query = {"token": "valid-token"}
diff --git a/trustgraph-base/trustgraph/api/async_socket_client.py b/trustgraph-base/trustgraph/api/async_socket_client.py
index e5d553ea..ca9146b9 100644
--- a/trustgraph-base/trustgraph/api/async_socket_client.py
+++ b/trustgraph-base/trustgraph/api/async_socket_client.py
@@ -49,21 +49,67 @@ class AsyncSocketClient:
         return f"ws://{url}"

     def _build_ws_url(self):
-        ws_url = f"{self.url.rstrip('/')}/api/v1/socket"
-        if self.token:
-            ws_url = f"{ws_url}?token={self.token}"
-        return ws_url
+        # /api/v1/socket uses the first-frame auth protocol — the
+        # token is sent as the first frame after connecting rather
+        # than in the URL. This avoids browser issues with 401 on
+        # the WebSocket handshake and lets long-lived sockets
+        # refresh credentials mid-session.
+        return f"{self.url.rstrip('/')}/api/v1/socket"

     async def connect(self):
-        """Establish the persistent websocket connection."""
+        """Establish the persistent websocket connection and run the
+        first-frame auth handshake."""
         if self._connected:
             return

+        if not self.token:
+            raise ProtocolException(
+                "AsyncSocketClient requires a token for first-frame "
+                "auth against /api/v1/socket"
+            )
+
         ws_url = self._build_ws_url()
         self._connect_cm = websockets.connect(
             ws_url, ping_interval=20, ping_timeout=self.timeout
         )
         self._socket = await self._connect_cm.__aenter__()
+
+        # First-frame auth: send {"type":"auth","token":"..."} and
+        # wait for auth-ok / auth-failed. Run before starting the
+        # reader task so the response isn't consumed by the reader's
+        # id-based routing.
+        await self._socket.send(json.dumps({
+            "type": "auth", "token": self.token,
+        }))
+        try:
+            raw = await asyncio.wait_for(
+                self._socket.recv(), timeout=self.timeout,
+            )
+        except asyncio.TimeoutError:
+            await self._socket.close()
+            raise ProtocolException("Timeout waiting for auth response")
+
+        try:
+            resp = json.loads(raw)
+        except Exception:
+            await self._socket.close()
+            raise ProtocolException(
+                f"Unexpected non-JSON auth response: {raw!r}"
+            )
+
+        if resp.get("type") == "auth-ok":
+            self.workspace = resp.get("workspace", self.workspace)
+        elif resp.get("type") == "auth-failed":
+            await self._socket.close()
+            raise ProtocolException(
+                f"auth failure: {resp.get('error', 'unknown')}"
+            )
+        else:
+            await self._socket.close()
+            raise ProtocolException(
+                f"Unexpected auth response: {resp!r}"
+            )
+
         self._connected = True
         self._reader_task = asyncio.create_task(self._reader())
diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py
index 4eade3e8..aeb15f85 100644
--- a/trustgraph-base/trustgraph/api/socket_client.py
+++ b/trustgraph-base/trustgraph/api/socket_client.py
@@ -112,10 +112,10 @@ class SocketClient:
         return f"ws://{url}"

     def _build_ws_url(self):
-        ws_url = f"{self.url.rstrip('/')}/api/v1/socket"
-        if self.token:
-            ws_url = f"{ws_url}?token={self.token}"
-        return ws_url
+        # /api/v1/socket uses the first-frame auth protocol — the
+        # token is sent as the first frame after connecting rather
+        # than in the URL.
+        return f"{self.url.rstrip('/')}/api/v1/socket"

     def _get_loop(self):
         """Get or create the event loop, reusing across calls."""
@@ -132,15 +132,58 @@ class SocketClient:
         return self._loop

     async def _ensure_connected(self):
-        """Lazily establish the persistent websocket connection."""
+        """Lazily establish the persistent websocket connection and
+        run the first-frame auth handshake."""
         if self._connected:
             return

+        if not self.token:
+            raise ProtocolException(
+                "SocketClient requires a token for first-frame auth "
+                "against /api/v1/socket"
+            )
+
         ws_url = self._build_ws_url()
         self._connect_cm = websockets.connect(
             ws_url, ping_interval=20, ping_timeout=self.timeout
         )
         self._socket = await self._connect_cm.__aenter__()
+
+        # First-frame auth — run before starting the reader so the
+        # auth-ok / auth-failed response isn't consumed by the reader
+        # loop's id-based routing.
+        await self._socket.send(json.dumps({
+            "type": "auth", "token": self.token,
+        }))
+        try:
+            raw = await asyncio.wait_for(
+                self._socket.recv(), timeout=self.timeout,
+            )
+        except asyncio.TimeoutError:
+            await self._socket.close()
+            raise ProtocolException("Timeout waiting for auth response")
+
+        try:
+            resp = json.loads(raw)
+        except Exception:
+            await self._socket.close()
+            raise ProtocolException(
+                f"Unexpected non-JSON auth response: {raw!r}"
+            )
+
+        if resp.get("type") == "auth-ok":
+            self.workspace = resp.get("workspace", self.workspace)
+        elif resp.get("type") == "auth-failed":
+            await self._socket.close()
+            raise ProtocolException(
+                f"auth failure: {resp.get('error', 'unknown')}"
+            )
+        else:
+            await self._socket.close()
+            raise ProtocolException(
+                f"Unexpected auth response: {resp!r}"
+            )
+
         self._connected = True
         self._reader_task = asyncio.create_task(self._reader())
diff --git a/trustgraph-base/trustgraph/base/iam_client.py b/trustgraph-base/trustgraph/base/iam_client.py
new file mode 100644
index 00000000..5cfda7c8
--- /dev/null
+++ b/trustgraph-base/trustgraph/base/iam_client.py
@@ -0,0 +1,279 @@
+
+from . request_response_spec import RequestResponse, RequestResponseSpec
+from .. schema import (
+    IamRequest, IamResponse,
+    UserInput, WorkspaceInput, ApiKeyInput,
+)
+
+IAM_TIMEOUT = 10
+
+
+class IamClient(RequestResponse):
+    """Client for the IAM service request/response pub/sub protocol.
+
+    Mirrors ``ConfigClient``: a thin wrapper around ``RequestResponse``
+    that knows the IAM request / response schemas. Only the subset of
+    operations actually implemented by the server today has helper
+    methods here; callers that need an unimplemented operation can
+    build ``IamRequest`` and call ``request()`` directly.
+    """
+
+    async def _request(self, timeout=IAM_TIMEOUT, **kwargs):
+        resp = await self.request(
+            IamRequest(**kwargs),
+            timeout=timeout,
+        )
+        if resp.error:
+            raise RuntimeError(
+                f"{resp.error.type}: {resp.error.message}"
+            )
+        return resp
+
+    async def bootstrap(self, timeout=IAM_TIMEOUT):
+        """Initial-run IAM self-seed. Returns a tuple of
+        ``(admin_user_id, admin_api_key_plaintext)``. Both are empty
+        strings on repeat calls — the operation is a no-op once the
+        IAM tables are populated."""
+        resp = await self._request(
+            operation="bootstrap", timeout=timeout,
+        )
+        return resp.bootstrap_admin_user_id, resp.bootstrap_admin_api_key
+
+    async def resolve_api_key(self, api_key, timeout=IAM_TIMEOUT):
+        """Resolve a plaintext API key to its identity triple.
+
+        Returns ``(user_id, workspace, roles)`` or raises
+        ``RuntimeError`` with error type ``auth-failed`` if the key is
+        unknown / expired / revoked."""
+        resp = await self._request(
+            operation="resolve-api-key",
+            api_key=api_key,
+            timeout=timeout,
+        )
+        return (
+            resp.resolved_user_id,
+            resp.resolved_workspace,
+            list(resp.resolved_roles),
+        )
+
+    async def create_user(self, workspace, user, actor="",
+                          timeout=IAM_TIMEOUT):
+        """Create a user. ``user`` is a ``UserInput``."""
+        resp = await self._request(
+            operation="create-user",
+            workspace=workspace,
+            actor=actor,
+            user=user,
+            timeout=timeout,
+        )
+        return resp.user
+
+    async def list_users(self, workspace, actor="", timeout=IAM_TIMEOUT):
+        resp = await self._request(
+            operation="list-users",
+            workspace=workspace,
+            actor=actor,
+            timeout=timeout,
+        )
+        return list(resp.users)
+
+    async def create_api_key(self, workspace, key, actor="",
+                             timeout=IAM_TIMEOUT):
+        """Create an API key. ``key`` is an ``ApiKeyInput``. Returns
+        ``(plaintext, record)`` — plaintext is returned once and the
+        caller is responsible for surfacing it to the operator."""
+        resp = await self._request(
+            operation="create-api-key",
+            workspace=workspace,
+            actor=actor,
+            key=key,
+            timeout=timeout,
+        )
+        return resp.api_key_plaintext, resp.api_key
+
+    async def list_api_keys(self, workspace, user_id, actor="",
+                            timeout=IAM_TIMEOUT):
+        resp = await self._request(
+            operation="list-api-keys",
+            workspace=workspace,
+            actor=actor,
+            user_id=user_id,
+            timeout=timeout,
+        )
+        return list(resp.api_keys)
+
+    async def revoke_api_key(self, workspace, key_id, actor="",
+                             timeout=IAM_TIMEOUT):
+        await self._request(
+            operation="revoke-api-key",
+            workspace=workspace,
+            actor=actor,
+            key_id=key_id,
+            timeout=timeout,
+        )
+
+    async def login(self, username, password, workspace="",
+                    timeout=IAM_TIMEOUT):
+        """Validate credentials and return ``(jwt, expires_iso)``.
+        ``workspace`` is optional; defaults at the server to the
+        OSS default workspace."""
+        resp = await self._request(
+            operation="login",
+            workspace=workspace,
+            username=username,
+            password=password,
+            timeout=timeout,
+        )
+        return resp.jwt, resp.jwt_expires
+
+    async def get_signing_key_public(self, timeout=IAM_TIMEOUT):
+        """Return the active JWT signing public key in PEM. The
+        gateway calls this at startup and caches the result."""
+        resp = await self._request(
+            operation="get-signing-key-public",
+            timeout=timeout,
+        )
+        return resp.signing_key_public
+
+    async def change_password(self, user_id, current_password,
+                              new_password, timeout=IAM_TIMEOUT):
+        await self._request(
+            operation="change-password",
+            user_id=user_id,
+            password=current_password,
+            new_password=new_password,
+            timeout=timeout,
+        )
+
+    async def reset_password(self, workspace, user_id, actor="",
+                             timeout=IAM_TIMEOUT):
+        """Admin-driven password reset. Returns the plaintext
+        temporary password (returned once)."""
+        resp = await self._request(
+            operation="reset-password",
+            workspace=workspace,
+            actor=actor,
+            user_id=user_id,
+            timeout=timeout,
+        )
+        return resp.temporary_password
+
+    async def get_user(self, workspace, user_id, actor="",
+                       timeout=IAM_TIMEOUT):
+        resp = await self._request(
+            operation="get-user",
+            workspace=workspace,
+            actor=actor,
+            user_id=user_id,
+            timeout=timeout,
+        )
+        return resp.user
+
+    async def update_user(self, workspace, user_id, user, actor="",
+                          timeout=IAM_TIMEOUT):
+        resp = await self._request(
+            operation="update-user",
+            workspace=workspace,
+            actor=actor,
+            user_id=user_id,
+            user=user,
+            timeout=timeout,
+        )
+        return resp.user
+
+    async def disable_user(self, workspace, user_id, actor="",
+                           timeout=IAM_TIMEOUT):
+        await self._request(
+            operation="disable-user",
+            workspace=workspace,
+            actor=actor,
+            user_id=user_id,
+            timeout=timeout,
+        )
+
+    async def enable_user(self, workspace, user_id, actor="",
+                          timeout=IAM_TIMEOUT):
+        await self._request(
+            operation="enable-user",
+            workspace=workspace,
+            actor=actor,
+            user_id=user_id,
+            timeout=timeout,
+        )
+
+    async def delete_user(self, workspace, user_id, actor="",
+                          timeout=IAM_TIMEOUT):
+        await self._request(
+            operation="delete-user",
+            workspace=workspace,
+            actor=actor,
+            user_id=user_id,
+            timeout=timeout,
+        )
+
+    async def create_workspace(self, workspace_record, actor="",
+                               timeout=IAM_TIMEOUT):
+        resp = await self._request(
+            operation="create-workspace",
+            actor=actor,
+            workspace_record=workspace_record,
+            timeout=timeout,
+        )
+        return resp.workspace
+
+    async def list_workspaces(self, actor="", timeout=IAM_TIMEOUT):
+        resp = await self._request(
+            operation="list-workspaces",
+            actor=actor,
+            timeout=timeout,
+        )
+        return list(resp.workspaces)
+
+    async def get_workspace(self, workspace_id, actor="",
+                            timeout=IAM_TIMEOUT):
+        resp = await self._request(
+            operation="get-workspace",
+            actor=actor,
+            workspace_record=WorkspaceInput(id=workspace_id),
+            timeout=timeout,
+        )
+        return resp.workspace
+
+    async def update_workspace(self, workspace_record, actor="",
+                               timeout=IAM_TIMEOUT):
+        resp = await self._request(
+            operation="update-workspace",
+            actor=actor,
+            workspace_record=workspace_record,
+            timeout=timeout,
+        )
+        return resp.workspace
+
+    async def disable_workspace(self, workspace_id, actor="",
+                                timeout=IAM_TIMEOUT):
+        await self._request(
+            operation="disable-workspace",
+            actor=actor,
+            workspace_record=WorkspaceInput(id=workspace_id),
+            timeout=timeout,
+        )
+
+    async def rotate_signing_key(self, actor="", timeout=IAM_TIMEOUT):
+        await self._request(
+            operation="rotate-signing-key",
+            actor=actor,
+            timeout=timeout,
+        )
+
+
+class IamClientSpec(RequestResponseSpec):
+    def __init__(self, request_name, response_name):
+        super().__init__(
+            request_name=request_name,
+            request_schema=IamRequest,
+            response_name=response_name,
+            response_schema=IamResponse,
+            impl=IamClient,
+        )
diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py
index 30f5061c..9fcfa6f7 100644
--- a/trustgraph-base/trustgraph/messaging/__init__.py
+++ b/trustgraph-base/trustgraph/messaging/__init__.py
@@ -15,6 +15,7 @@ from .translators.library import LibraryRequestTranslator, LibraryResponseTransl
 from .translators.document_loading import DocumentTranslator, TextDocumentTranslator
 from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator
 from .translators.flow import FlowRequestTranslator, FlowResponseTranslator
+from .translators.iam import IamRequestTranslator, IamResponseTranslator
 from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator
 from .translators.tool import ToolRequestTranslator, ToolResponseTranslator
 from .translators.embeddings_query import (
@@ -85,11 +86,17 @@ TranslatorRegistry.register_service(
 )

 TranslatorRegistry.register_service(
-    "flow", 
-    FlowRequestTranslator(), 
+    "flow",
+    FlowRequestTranslator(),
     FlowResponseTranslator()
 )

+TranslatorRegistry.register_service(
+    "iam",
+    IamRequestTranslator(),
+    IamResponseTranslator()
+)
+
 TranslatorRegistry.register_service(
     "prompt",
     PromptRequestTranslator(),
diff --git a/trustgraph-base/trustgraph/messaging/translators/iam.py b/trustgraph-base/trustgraph/messaging/translators/iam.py
new file mode 100644
index 00000000..4a717bba
--- /dev/null
+++ b/trustgraph-base/trustgraph/messaging/translators/iam.py
@@ -0,0 +1,194 @@
+from typing import Dict, Any, Tuple
+
+from ...schema import IamRequest, IamResponse
+from ...schema import (
+    UserInput, UserRecord,
+    WorkspaceInput, WorkspaceRecord,
+    ApiKeyInput, ApiKeyRecord,
+)
+from .base import MessageTranslator
+
+
+def _user_input_from_dict(d):
+    if d is None:
+        return None
+    return UserInput(
+        username=d.get("username", ""),
+        name=d.get("name", ""),
+        email=d.get("email", ""),
+        password=d.get("password", ""),
+        roles=list(d.get("roles", [])),
+        enabled=d.get("enabled", True),
+        must_change_password=d.get("must_change_password", False),
+    )
+
+
+def _workspace_input_from_dict(d):
+    if d is None:
+        return None
+    return WorkspaceInput(
+        id=d.get("id", ""),
+        name=d.get("name", ""),
+        enabled=d.get("enabled", True),
+    )
+
+
+def _api_key_input_from_dict(d):
+    if d is None:
+        return None
+    return ApiKeyInput(
+        user_id=d.get("user_id", ""),
+        name=d.get("name", ""),
+        expires=d.get("expires", ""),
+    )
+
+
+def _user_record_to_dict(r):
+    if r is None:
+        return None
+    return {
+        "id": r.id,
+        "workspace": r.workspace,
+        "username": r.username,
+        "name": r.name,
+        "email": r.email,
+        "roles": list(r.roles),
+        "enabled": r.enabled,
+        "must_change_password": r.must_change_password,
+        "created": r.created,
+    }
+
+
+def _workspace_record_to_dict(r):
+    if r is None:
+        return None
+    return {
+        "id": r.id,
+        "name": r.name,
+        "enabled": r.enabled,
+        "created": r.created,
+    }
+
+
+def _api_key_record_to_dict(r):
+    if r is None:
+        return None
+    return {
+        "id": r.id,
+        "user_id": r.user_id,
+        "name": r.name,
+        "prefix": r.prefix,
+        "expires": r.expires,
+        "created": r.created,
+        "last_used": r.last_used,
+    }
+
+
+class IamRequestTranslator(MessageTranslator):
+
+    def decode(self, data: Dict[str, Any]) -> IamRequest:
+        return IamRequest(
+            operation=data.get("operation", ""),
+            workspace=data.get("workspace", ""),
+            actor=data.get("actor", ""),
+            user_id=data.get("user_id", ""),
+            username=data.get("username", ""),
+            key_id=data.get("key_id", ""),
+            api_key=data.get("api_key", ""),
+            password=data.get("password", ""),
+            new_password=data.get("new_password", ""),
+            user=_user_input_from_dict(data.get("user")),
+            workspace_record=_workspace_input_from_dict(
+                data.get("workspace_record")
+            ),
+            key=_api_key_input_from_dict(data.get("key")),
+        )
+
+    def encode(self, obj: IamRequest) -> Dict[str, Any]:
+        result = {"operation": obj.operation}
+        for fname in (
+            "workspace", "actor", "user_id", "username", "key_id",
+            "api_key", "password", "new_password",
+        ):
+            v = getattr(obj, fname, "")
+            if v:
+                result[fname] = v
+        if obj.user is not None:
+            result["user"] = {
+                "username": obj.user.username,
+                "name": obj.user.name,
+                "email": obj.user.email,
+                "password": obj.user.password,
+                "roles": list(obj.user.roles),
+                "enabled": obj.user.enabled,
+                "must_change_password": obj.user.must_change_password,
+            }
+        if obj.workspace_record is not None:
+            result["workspace_record"] = {
+                "id": obj.workspace_record.id,
+                "name": obj.workspace_record.name,
+                "enabled": obj.workspace_record.enabled,
+            }
+        if obj.key is not None:
+            result["key"] = {
+                "user_id": obj.key.user_id,
+                "name": obj.key.name,
+                "expires": obj.key.expires,
+            }
+        return result
+
+
+class IamResponseTranslator(MessageTranslator):
+
+    def decode(self, data: Dict[str, Any]) -> IamResponse:
+        raise NotImplementedError(
+            "IamResponse is a server-produced message; no HTTP→schema "
+            "path is needed"
+        )
+
+    def encode(self, obj: IamResponse) -> Dict[str, Any]:
+        result: Dict[str, Any] = {}
+
+        if obj.user is not None:
+            result["user"] = _user_record_to_dict(obj.user)
+        if obj.users:
+            result["users"] = [_user_record_to_dict(u) for u in obj.users]
+        if obj.workspace is not None:
+            result["workspace"] = _workspace_record_to_dict(obj.workspace)
+        if obj.workspaces:
+            result["workspaces"] = [
+                _workspace_record_to_dict(w) for w in obj.workspaces
+            ]
+        if obj.api_key_plaintext:
+            result["api_key_plaintext"] = obj.api_key_plaintext
+        if obj.api_key is not None:
+            result["api_key"] = _api_key_record_to_dict(obj.api_key)
+        if obj.api_keys:
+            result["api_keys"] = [
+                _api_key_record_to_dict(k) for k in obj.api_keys
+            ]
+        if obj.jwt:
+            result["jwt"] = obj.jwt
+        if obj.jwt_expires:
+            result["jwt_expires"] = obj.jwt_expires
+        if obj.signing_key_public:
+            result["signing_key_public"] = obj.signing_key_public
+        if obj.resolved_user_id:
+            result["resolved_user_id"] = obj.resolved_user_id
+        if obj.resolved_workspace:
+            result["resolved_workspace"] = obj.resolved_workspace
+        if obj.resolved_roles:
+            result["resolved_roles"] = list(obj.resolved_roles)
+        if obj.temporary_password:
+            result["temporary_password"] = obj.temporary_password
+        if obj.bootstrap_admin_user_id:
+            result["bootstrap_admin_user_id"] = obj.bootstrap_admin_user_id
+        if obj.bootstrap_admin_api_key:
+            result["bootstrap_admin_api_key"] = obj.bootstrap_admin_api_key
+
+        return result
+
+    def encode_with_completion(
+        self, obj: IamResponse,
+    ) -> Tuple[Dict[str, Any], bool]:
+        return self.encode(obj), True
diff --git a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py
index 550b7d12..2a214201 100644
--- a/trustgraph-base/trustgraph/schema/services/__init__.py
+++ b/trustgraph-base/trustgraph/schema/services/__init__.py
@@ -5,6 +5,7 @@ from .agent import *
 from .flow import *
 from .prompt import *
 from .config import *
+from .iam import *
 from .library import *
 from .lookup import *
 from .nlp_query import *
diff --git a/trustgraph-base/trustgraph/schema/services/iam.py b/trustgraph-base/trustgraph/schema/services/iam.py
new file mode 100644
index 00000000..1e3ab1ab
--- /dev/null
+++ b/trustgraph-base/trustgraph/schema/services/iam.py
@@ -0,0 +1,142 @@
+
+from dataclasses import dataclass, field
+
+from ..core.topic import queue
+from ..core.primitives import Error
+
+############################################################################
+
+# IAM service — see docs/tech-specs/iam-protocol.md for the full protocol.
+#
+# Transport: request/response pub/sub, correlated by the `id` message
+# property. Caller is the API gateway only; the IAM service trusts
+# the bus per the enforcement-boundary policy (no per-request auth
+# against the caller).
+
+
+@dataclass
+class UserInput:
+    username: str = ""
+    name: str = ""
+    email: str = ""
+    # Only populated on create-user; never on update-user.
+    password: str = ""
+    roles: list[str] = field(default_factory=list)
+    enabled: bool = True
+    must_change_password: bool = False
+
+
+@dataclass
+class UserRecord:
+    id: str = ""
+    workspace: str = ""
+    username: str = ""
+    name: str = ""
+    email: str = ""
+    roles: list[str] = field(default_factory=list)
+    enabled: bool = True
+    must_change_password: bool = False
+    created: str = ""
+
+
+@dataclass
+class WorkspaceInput:
+    id: str = ""
+    name: str = ""
+    enabled: bool = True
+
+
+@dataclass
+class WorkspaceRecord:
+    id: str = ""
+    name: str = ""
+    enabled: bool = True
+    created: str = ""
+
+
+@dataclass
+class ApiKeyInput:
+    user_id: str = ""
+    name: str = ""
+    expires: str = ""
+
+
+@dataclass
+class ApiKeyRecord:
+    id: str = ""
+    user_id: str = ""
+    name: str = ""
+    # First 4 chars of the plaintext token, for operator identification
+    # in list-api-keys. Never enough to reconstruct the key.
+ prefix: str = "" + expires: str = "" + created: str = "" + last_used: str = "" + + +@dataclass +class IamRequest: + operation: str = "" + + # Workspace scope. Required on workspace-scoped operations; + # omitted for system-level ops (workspace CRUD, signing-key + # ops, bootstrap, resolve-api-key, login). + workspace: str = "" + + # Acting user id for audit. Empty for internal-origin and for + # operations that resolve an identity (login, resolve-api-key). + actor: str = "" + + user_id: str = "" + username: str = "" + key_id: str = "" + api_key: str = "" + + password: str = "" + new_password: str = "" + + user: UserInput | None = None + workspace_record: WorkspaceInput | None = None + key: ApiKeyInput | None = None + + +@dataclass +class IamResponse: + user: UserRecord | None = None + users: list[UserRecord] = field(default_factory=list) + + workspace: WorkspaceRecord | None = None + workspaces: list[WorkspaceRecord] = field(default_factory=list) + + # create-api-key returns the plaintext once; never populated + # on any other operation. 
+ api_key_plaintext: str = "" + api_key: ApiKeyRecord | None = None + api_keys: list[ApiKeyRecord] = field(default_factory=list) + + # login, rotate-signing-key + jwt: str = "" + jwt_expires: str = "" + + # get-signing-key-public + signing_key_public: str = "" + + # resolve-api-key + resolved_user_id: str = "" + resolved_workspace: str = "" + resolved_roles: list[str] = field(default_factory=list) + + # reset-password + temporary_password: str = "" + + # bootstrap + bootstrap_admin_user_id: str = "" + bootstrap_admin_api_key: str = "" + + error: Error | None = None + + +iam_request_queue = queue('iam', cls='request') +iam_response_queue = queue('iam', cls='response') + +############################################################################ diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index d316ae4f..728079c8 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -40,6 +40,20 @@ tg-get-flow-blueprint = "trustgraph.cli.get_flow_blueprint:main" tg-get-kg-core = "trustgraph.cli.get_kg_core:main" tg-get-document-content = "trustgraph.cli.get_document_content:main" tg-graph-to-turtle = "trustgraph.cli.graph_to_turtle:main" +tg-bootstrap-iam = "trustgraph.cli.bootstrap_iam:main" +tg-login = "trustgraph.cli.login:main" +tg-create-user = "trustgraph.cli.create_user:main" +tg-list-users = "trustgraph.cli.list_users:main" +tg-disable-user = "trustgraph.cli.disable_user:main" +tg-enable-user = "trustgraph.cli.enable_user:main" +tg-delete-user = "trustgraph.cli.delete_user:main" +tg-change-password = "trustgraph.cli.change_password:main" +tg-reset-password = "trustgraph.cli.reset_password:main" +tg-create-api-key = "trustgraph.cli.create_api_key:main" +tg-list-api-keys = "trustgraph.cli.list_api_keys:main" +tg-revoke-api-key = "trustgraph.cli.revoke_api_key:main" +tg-list-workspaces = "trustgraph.cli.list_workspaces:main" +tg-create-workspace = "trustgraph.cli.create_workspace:main" tg-invoke-agent = 
"trustgraph.cli.invoke_agent:main" tg-invoke-document-rag = "trustgraph.cli.invoke_document_rag:main" tg-invoke-graph-rag = "trustgraph.cli.invoke_graph_rag:main" diff --git a/trustgraph-cli/trustgraph/cli/_iam.py b/trustgraph-cli/trustgraph/cli/_iam.py new file mode 100644 index 00000000..f5278c0c --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/_iam.py @@ -0,0 +1,75 @@ +""" +Shared helpers for IAM CLI tools. + +All IAM operations go through the gateway's ``/api/v1/iam`` forwarder, +with the three public auth operations (``login``, ``bootstrap``, +``change-password``) served via ``/api/v1/auth/...`` instead. These +helpers encapsulate the HTTP plumbing so each CLI can stay focused +on its own argument parsing and output formatting. +""" + +import json +import os +import sys + +import requests + + +DEFAULT_URL = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") +DEFAULT_TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None) + + +def _fmt_error(resp_json): + err = resp_json.get("error", {}) + if isinstance(err, dict): + t = err.get("type", "") + m = err.get("message", "") + return f"{t}: {m}" if t else m or "error" + return str(err) + + +def _post(url, path, token, body): + endpoint = url.rstrip("/") + path + headers = {"Content-Type": "application/json"} + if token: + headers["Authorization"] = f"Bearer {token}" + + resp = requests.post( + endpoint, headers=headers, data=json.dumps(body), + ) + + if resp.status_code != 200: + try: + payload = resp.json() + detail = _fmt_error(payload) + except Exception: + detail = resp.text + raise RuntimeError(f"HTTP {resp.status_code}: {detail}") + + body = resp.json() + if "error" in body: + raise RuntimeError(_fmt_error(body)) + return body + + +def call_iam(url, token, request): + """Forward an IAM request through ``/api/v1/iam``. 
``request`` is + the ``IamRequest`` dict shape.""" + return _post(url, "/api/v1/iam", token, request) + + +def call_auth(url, path, token, body): + """Hit one of the public auth endpoints + (``/api/v1/auth/login``, ``/api/v1/auth/change-password``, etc.). + ``token`` is optional — login and bootstrap don't need one.""" + return _post(url, path, token, body) + + +def run_main(fn, parser): + """Standard error-handling wrapper for CLI main() bodies.""" + args = parser.parse_args() + try: + fn(args) + except Exception as e: + print("Exception:", e, file=sys.stderr, flush=True) + sys.exit(1) diff --git a/trustgraph-cli/trustgraph/cli/bootstrap_iam.py b/trustgraph-cli/trustgraph/cli/bootstrap_iam.py new file mode 100644 index 00000000..99a789e2 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/bootstrap_iam.py @@ -0,0 +1,94 @@ +""" +Bootstraps the IAM service. Only works when iam-svc is running in +bootstrap mode with empty tables. Prints the initial admin API key +to stdout. + +This is a one-time, trust-sensitive operation. The resulting token +is shown once and never again — capture it on use. Rotate and +revoke it as soon as a real admin API key has been issued. +""" + +import argparse +import json +import os +import sys + +import requests + +default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") + + +def bootstrap(url): + + # Unauthenticated public endpoint — IAM refuses the bootstrap + # operation unless the service is running in bootstrap mode with + # empty tables, so the safety gate lives on the server side. 
+ endpoint = url.rstrip("/") + "/api/v1/auth/bootstrap" + + headers = {"Content-Type": "application/json"} + + resp = requests.post( + endpoint, + headers=headers, + data=json.dumps({}), + ) + + if resp.status_code != 200: + raise RuntimeError( + f"HTTP {resp.status_code}: {resp.text}" + ) + + body = resp.json() + + if "error" in body: + raise RuntimeError( + f"IAM {body['error'].get('type', 'error')}: " + f"{body['error'].get('message', '')}" + ) + + api_key = body.get("bootstrap_admin_api_key") + user_id = body.get("bootstrap_admin_user_id") + + if not api_key: + raise RuntimeError( + "IAM response did not contain a bootstrap token — the " + "service may already be bootstrapped, or may be running " + "in token mode." + ) + + return user_id, api_key + + +def main(): + + parser = argparse.ArgumentParser( + prog="tg-bootstrap-iam", + description=__doc__, + ) + + parser.add_argument( + "-u", "--api-url", + default=default_url, + help=f"API URL (default: {default_url})", + ) + + args = parser.parse_args() + + try: + user_id, api_key = bootstrap(args.api_url) + except Exception as e: + print("Exception:", e, file=sys.stderr, flush=True) + sys.exit(1) + + # Stdout gets machine-readable output (the key). Any operator + # context goes to stderr. + print(f"Admin user id: {user_id}", file=sys.stderr) + print( + "Admin API key (shown once, capture now):", + file=sys.stderr, + ) + print(api_key) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/change_password.py b/trustgraph-cli/trustgraph/cli/change_password.py new file mode 100644 index 00000000..c914b30f --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/change_password.py @@ -0,0 +1,46 @@ +""" +Change your own password. Requires the current password. 
+""" + +import argparse +import getpass + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_auth, run_main + + +def do_change_password(args): + current = args.current or getpass.getpass("Current password: ") + new = args.new or getpass.getpass("New password: ") + + call_auth( + args.api_url, "/api/v1/auth/change-password", args.token, + {"current_password": current, "new_password": new}, + ) + print("Password changed.") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-change-password", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--current", default=None, + help="Current password (prompted if omitted)", + ) + parser.add_argument( + "--new", default=None, + help="New password (prompted if omitted)", + ) + run_main(do_change_password, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/create_api_key.py b/trustgraph-cli/trustgraph/cli/create_api_key.py new file mode 100644 index 00000000..2b269041 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/create_api_key.py @@ -0,0 +1,71 @@ +""" +Create an API key for a user. Prints the plaintext key to stdout — +shown once only. 
+""" + +import argparse +import sys + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_create_api_key(args): + key = { + "user_id": args.user_id, + "name": args.name, + } + if args.expires: + key["expires"] = args.expires + + req = {"operation": "create-api-key", "key": key} + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + plaintext = resp.get("api_key_plaintext", "") + rec = resp.get("api_key", {}) + print(f"Key id: {rec.get('id', '')}", file=sys.stderr) + print(f"Name: {rec.get('name', '')}", file=sys.stderr) + print(f"Prefix: {rec.get('prefix', '')}", file=sys.stderr) + print( + "API key (shown once, capture now):", file=sys.stderr, + ) + print(plaintext) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-create-api-key", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, + help="Owner user id", + ) + parser.add_argument( + "--name", required=True, + help="Operator-facing label (e.g. 'laptop', 'ci')", + ) + parser.add_argument( + "--expires", default=None, + help="ISO-8601 expiry (optional; empty = no expiry)", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_create_api_key, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/create_user.py b/trustgraph-cli/trustgraph/cli/create_user.py new file mode 100644 index 00000000..c9253aca --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/create_user.py @@ -0,0 +1,87 @@ +""" +Create a user in the caller's workspace. Prints the new user id. 
+""" + +import argparse +import getpass +import sys + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_create_user(args): + password = args.password + if not password: + password = getpass.getpass( + f"Password for new user {args.username}: " + ) + + user = { + "username": args.username, + "password": password, + "roles": args.roles, + } + if args.name: + user["name"] = args.name + if args.email: + user["email"] = args.email + if args.must_change_password: + user["must_change_password"] = True + + req = {"operation": "create-user", "user": user} + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + rec = resp.get("user", {}) + print(f"User id: {rec.get('id', '')}", file=sys.stderr) + print(f"Username: {rec.get('username', '')}", file=sys.stderr) + print(f"Roles: {', '.join(rec.get('roles', []))}", file=sys.stderr) + print(rec.get("id", "")) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-create-user", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--username", required=True, help="Username (unique in workspace)", + ) + parser.add_argument( + "--password", default=None, + help="Password (prompted if omitted)", + ) + parser.add_argument( + "--name", default=None, help="Display name", + ) + parser.add_argument( + "--email", default=None, help="Email", + ) + parser.add_argument( + "--roles", nargs="+", default=["reader"], + help="One or more role names (default: reader)", + ) + parser.add_argument( + "--must-change-password", action="store_true", + help="Force password change on next login", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned 
workspace)" + ), + ) + run_main(do_create_user, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/create_workspace.py b/trustgraph-cli/trustgraph/cli/create_workspace.py new file mode 100644 index 00000000..f8367720 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/create_workspace.py @@ -0,0 +1,46 @@ +""" +Create a workspace (system-level; requires admin). +""" + +import argparse + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_create_workspace(args): + ws = {"id": args.workspace_id, "enabled": True} + if args.name: + ws["name"] = args.name + + resp = call_iam(args.api_url, args.token, { + "operation": "create-workspace", + "workspace_record": ws, + }) + rec = resp.get("workspace", {}) + print(f"Workspace created: {rec.get('id', '')}") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-create-workspace", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--workspace-id", required=True, + help="New workspace id (must not start with '_')", + ) + parser.add_argument( + "--name", default=None, help="Display name", + ) + run_main(do_create_workspace, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/delete_user.py b/trustgraph-cli/trustgraph/cli/delete_user.py new file mode 100644 index 00000000..dbdf7877 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/delete_user.py @@ -0,0 +1,62 @@ +""" +Delete a user. Removes the user record, their username lookup, +and all their API keys. The freed username becomes available for +re-use. + +Irreversible. Use tg-disable-user if you want to preserve the +record (audit trail, username squatting protection). 
+""" + +import argparse + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_delete_user(args): + if not args.yes: + confirm = input( + f"Delete user {args.user_id}? This is irreversible. " + f"[type 'yes' to confirm]: " + ) + if confirm.strip() != "yes": + print("Aborted.") + return + + req = {"operation": "delete-user", "user_id": args.user_id} + if args.workspace: + req["workspace"] = args.workspace + call_iam(args.api_url, args.token, req) + print(f"Deleted user {args.user_id}") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-delete-user", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, help="User id to delete", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + parser.add_argument( + "--yes", action="store_true", + help="Skip the interactive confirmation prompt", + ) + run_main(do_delete_user, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/disable_user.py b/trustgraph-cli/trustgraph/cli/disable_user.py new file mode 100644 index 00000000..e142644b --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/disable_user.py @@ -0,0 +1,45 @@ +""" +Disable a user. Soft-deletes (enabled=false) and revokes all their +API keys. 
+""" + +import argparse + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_disable_user(args): + req = {"operation": "disable-user", "user_id": args.user_id} + if args.workspace: + req["workspace"] = args.workspace + call_iam(args.api_url, args.token, req) + print(f"Disabled user {args.user_id}") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-disable-user", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, help="User id to disable", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_disable_user, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/enable_user.py b/trustgraph-cli/trustgraph/cli/enable_user.py new file mode 100644 index 00000000..c762366a --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/enable_user.py @@ -0,0 +1,45 @@ +""" +Re-enable a previously disabled user. Does not restore their API +keys — those must be re-issued by an admin. 
+""" + +import argparse + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_enable_user(args): + req = {"operation": "enable-user", "user_id": args.user_id} + if args.workspace: + req["workspace"] = args.workspace + call_iam(args.api_url, args.token, req) + print(f"Enabled user {args.user_id}") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-enable-user", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, help="User id to enable", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_enable_user, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/list_api_keys.py b/trustgraph-cli/trustgraph/cli/list_api_keys.py new file mode 100644 index 00000000..f969890e --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/list_api_keys.py @@ -0,0 +1,69 @@ +""" +List the API keys for a user. 
+""" + +import argparse + +import tabulate + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_list_api_keys(args): + req = {"operation": "list-api-keys", "user_id": args.user_id} + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + keys = resp.get("api_keys", []) + if not keys: + print("No keys.") + return + + rows = [ + [ + k.get("id", ""), + k.get("name", ""), + k.get("prefix", ""), + k.get("created", ""), + k.get("last_used", "") or "—", + k.get("expires", "") or "never", + ] + for k in keys + ] + print(tabulate.tabulate( + rows, + headers=["id", "name", "prefix", "created", "last used", "expires"], + tablefmt="pretty", + stralign="left", + )) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-list-api-keys", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, + help="Owner user id", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_list_api_keys, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/list_users.py b/trustgraph-cli/trustgraph/cli/list_users.py new file mode 100644 index 00000000..25bc1901 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/list_users.py @@ -0,0 +1,65 @@ +""" +List users in the caller's workspace. 
+""" + +import argparse + +import tabulate + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_list_users(args): + req = {"operation": "list-users"} + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + users = resp.get("users", []) + if not users: + print("No users.") + return + + rows = [ + [ + u.get("id", ""), + u.get("username", ""), + u.get("name", ""), + ", ".join(u.get("roles", [])), + "yes" if u.get("enabled") else "no", + "yes" if u.get("must_change_password") else "no", + ] + for u in users + ] + print(tabulate.tabulate( + rows, + headers=["id", "username", "name", "roles", "enabled", "change-pw"], + tablefmt="pretty", + stralign="left", + )) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-list-users", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_list_users, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/list_workspaces.py b/trustgraph-cli/trustgraph/cli/list_workspaces.py new file mode 100644 index 00000000..170d330c --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/list_workspaces.py @@ -0,0 +1,53 @@ +""" +List workspaces (system-level; requires admin). 
+""" + +import argparse + +import tabulate + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_list_workspaces(args): + resp = call_iam( + args.api_url, args.token, {"operation": "list-workspaces"}, + ) + workspaces = resp.get("workspaces", []) + if not workspaces: + print("No workspaces.") + return + rows = [ + [ + w.get("id", ""), + w.get("name", ""), + "yes" if w.get("enabled") else "no", + w.get("created", ""), + ] + for w in workspaces + ] + print(tabulate.tabulate( + rows, + headers=["id", "name", "enabled", "created"], + tablefmt="pretty", + stralign="left", + )) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-list-workspaces", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + run_main(do_list_workspaces, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/login.py b/trustgraph-cli/trustgraph/cli/login.py new file mode 100644 index 00000000..0e87c3b0 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/login.py @@ -0,0 +1,62 @@ +""" +Log in with username / password. Prints the resulting JWT to +stdout so it can be captured for subsequent CLI use. +""" + +import argparse +import getpass +import sys + +from ._iam import DEFAULT_URL, call_auth, run_main + + +def do_login(args): + password = args.password + if not password: + password = getpass.getpass(f"Password for {args.username}: ") + + body = { + "username": args.username, + "password": password, + } + if args.workspace: + body["workspace"] = args.workspace + + resp = call_auth(args.api_url, "/api/v1/auth/login", None, body) + + jwt = resp.get("jwt", "") + expires = resp.get("jwt_expires", "") + + if expires: + print(f"JWT expires: {expires}", file=sys.stderr) + # Machine-readable on stdout. 
+ print(jwt) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-login", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "--username", required=True, help="Username", + ) + parser.add_argument( + "--password", default=None, + help="Password (prompted if omitted)", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Optional workspace to log in against. Defaults to " + "the user's assigned workspace." + ), + ) + run_main(do_login, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/reset_password.py b/trustgraph-cli/trustgraph/cli/reset_password.py new file mode 100644 index 00000000..600f00e1 --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/reset_password.py @@ -0,0 +1,54 @@ +""" +Admin: reset another user's password. Prints a one-time temporary +password to stdout. The user is forced to change it on next login. 
+""" + +import argparse +import sys + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_reset_password(args): + req = {"operation": "reset-password", "user_id": args.user_id} + if args.workspace: + req["workspace"] = args.workspace + resp = call_iam(args.api_url, args.token, req) + + tmp = resp.get("temporary_password", "") + if not tmp: + raise RuntimeError( + "IAM returned no temporary password — unexpected" + ) + print("Temporary password (shown once, capture now):", file=sys.stderr) + print(tmp) + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-reset-password", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--user-id", required=True, + help="Target user id", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_reset_password, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-cli/trustgraph/cli/revoke_api_key.py b/trustgraph-cli/trustgraph/cli/revoke_api_key.py new file mode 100644 index 00000000..3976b56f --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/revoke_api_key.py @@ -0,0 +1,44 @@ +""" +Revoke an API key by id. 
+""" + +import argparse + +from ._iam import DEFAULT_URL, DEFAULT_TOKEN, call_iam, run_main + + +def do_revoke_api_key(args): + req = {"operation": "revoke-api-key", "key_id": args.key_id} + if args.workspace: + req["workspace"] = args.workspace + call_iam(args.api_url, args.token, req) + print(f"Revoked key {args.key_id}") + + +def main(): + parser = argparse.ArgumentParser( + prog="tg-revoke-api-key", description=__doc__, + ) + parser.add_argument( + "-u", "--api-url", default=DEFAULT_URL, + help=f"API URL (default: {DEFAULT_URL})", + ) + parser.add_argument( + "-t", "--token", default=DEFAULT_TOKEN, + help="Auth token (default: $TRUSTGRAPH_TOKEN)", + ) + parser.add_argument( + "--key-id", required=True, help="Key id to revoke", + ) + parser.add_argument( + "-w", "--workspace", default=None, + help=( + "Target workspace (admin only; defaults to caller's " + "assigned workspace)" + ), + ) + run_main(do_revoke_api_key, parser) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index cc7dac63..d8c690b5 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -63,6 +63,7 @@ chunker-token = "trustgraph.chunking.token:run" bootstrap = "trustgraph.bootstrap.bootstrapper:run" config-svc = "trustgraph.config.service:run" flow-svc = "trustgraph.flow.service:run" +iam-svc = "trustgraph.iam.service:run" doc-embeddings-query-milvus = "trustgraph.query.doc_embeddings.milvus:run" doc-embeddings-query-pinecone = "trustgraph.query.doc_embeddings.pinecone:run" doc-embeddings-query-qdrant = "trustgraph.query.doc_embeddings.qdrant:run" diff --git a/trustgraph-flow/trustgraph/gateway/auth.py b/trustgraph-flow/trustgraph/gateway/auth.py index a693ca32..95743261 100644 --- a/trustgraph-flow/trustgraph/gateway/auth.py +++ b/trustgraph-flow/trustgraph/gateway/auth.py @@ -1,22 +1,264 @@ +""" +IAM-backed authentication for the API gateway. 
-class Authenticator: +Replaces the legacy GATEWAY_SECRET shared-token Authenticator. The +gateway is now stateless with respect to credentials: it either +verifies a JWT locally using the active IAM signing public key, or +resolves an API key by hash with a short local cache backed by the +IAM service. - def __init__(self, token=None, allow_all=False): +Identity returned by authenticate() is the (user_id, workspace, +roles) triple the rest of the gateway — capability checks, workspace +resolver, audit logging — needs. +""" - if not allow_all and token is None: - raise RuntimeError("Need a token") +import asyncio +import base64 +import hashlib +import json +import logging +import time +import uuid +from dataclasses import dataclass - if not allow_all and token == "": - raise RuntimeError("Need a token") +from aiohttp import web - self.token = token - self.allow_all = allow_all +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 - def permitted(self, token, roles): +from ..base.iam_client import IamClient +from ..base.metrics import ProducerMetrics, SubscriberMetrics +from ..schema import ( + IamRequest, IamResponse, + iam_request_queue, iam_response_queue, +) - if self.allow_all: return True +logger = logging.getLogger("auth") - if self.token != token: return False +API_KEY_CACHE_TTL = 60 # seconds - return True +@dataclass +class Identity: + user_id: str + workspace: str + roles: list + source: str # "api-key" | "jwt" + + +def _auth_failure(): + return web.HTTPUnauthorized( + text='{"error":"auth failure"}', + content_type="application/json", + ) + + +def _access_denied(): + return web.HTTPForbidden( + text='{"error":"access denied"}', + content_type="application/json", + ) + + +def _b64url_decode(s): + pad = "=" * (-len(s) % 4) + return base64.urlsafe_b64decode(s + pad) + + +def _verify_jwt_eddsa(token, public_pem): + """Verify an Ed25519 JWT and return its claims. 
Raises on any + validation failure. Refuses non-EdDSA algorithms.""" + parts = token.split(".") + if len(parts) != 3: + raise ValueError("malformed JWT") + h_b64, p_b64, s_b64 = parts + signing_input = f"{h_b64}.{p_b64}".encode("ascii") + header = json.loads(_b64url_decode(h_b64)) + if header.get("alg") != "EdDSA": + raise ValueError(f"unsupported alg: {header.get('alg')!r}") + + key = serialization.load_pem_public_key(public_pem.encode("ascii")) + if not isinstance(key, ed25519.Ed25519PublicKey): + raise ValueError("public key is not Ed25519") + + signature = _b64url_decode(s_b64) + key.verify(signature, signing_input) # raises InvalidSignature + + claims = json.loads(_b64url_decode(p_b64)) + exp = claims.get("exp") + if exp is None or exp < time.time(): + raise ValueError("expired") + return claims + + +class IamAuth: + """Resolves bearer credentials via the IAM service. + + Used by every gateway endpoint that needs authentication. Fetches + the IAM signing public key at startup (cached in memory). API + keys are resolved via the IAM service with a local hash→identity + cache (short TTL so revoked keys stop working within the TTL + window without any push mechanism).""" + + def __init__(self, backend, id="api-gateway"): + self.backend = backend + self.id = id + + # Populated at start() via IAM. + self._signing_public_pem = None + + # API-key cache: plaintext_sha256_hex -> (Identity, expires_ts) + self._key_cache = {} + self._key_cache_lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # Short-lived client helper. Mirrors the pattern used by the + # bootstrap framework and AsyncProcessor: a fresh uuid suffix per + # invocation so Pulsar exclusive subscriptions don't collide with + # ghosts from prior calls. 
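The `_verify_jwt_eddsa` helper above can be exercised end-to-end with a throwaway key pair. This is a minimal sketch, not TrustGraph code: it assumes the `cryptography` package (which the gateway already imports), and the claim names (`sub`, `workspace`, `roles`, `exp`) mirror the fields the gateway reads.

```python
import base64
import json
import time

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519


def b64url(b):
    # JWT segments use unpadded base64url
    return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii")


def mint(priv, claims):
    # Illustrative token minting; in TrustGraph the IAM service does this
    header = b64url(json.dumps({"alg": "EdDSA", "typ": "JWT"}).encode())
    payload = b64url(json.dumps(claims).encode())
    signing_input = f"{header}.{payload}".encode("ascii")
    return f"{header}.{payload}.{b64url(priv.sign(signing_input))}"


priv = ed25519.Ed25519PrivateKey.generate()
pem = priv.public_key().public_bytes(
    serialization.Encoding.PEM,
    serialization.PublicFormat.SubjectPublicKeyInfo,
).decode("ascii")

token = mint(priv, {
    "sub": "alice", "workspace": "default",
    "roles": ["reader"], "exp": time.time() + 300,
})

# Verify the same way _verify_jwt_eddsa does: check the alg, check the
# Ed25519 signature over "header.payload", then decode the claims.
h, p, s = token.split(".")
pad = lambda x: x + "=" * (-len(x) % 4)
assert json.loads(base64.urlsafe_b64decode(pad(h)))["alg"] == "EdDSA"
pub = serialization.load_pem_public_key(pem.encode("ascii"))
pub.verify(base64.urlsafe_b64decode(pad(s)), f"{h}.{p}".encode("ascii"))
claims = json.loads(base64.urlsafe_b64decode(pad(p)))
print(claims["sub"])  # alice
```

Tampering with any segment makes `pub.verify` raise `InvalidSignature`, which the gateway maps onto the same masked 401 as every other failure mode.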
+ # ------------------------------------------------------------------ + + def _make_client(self): + rr_id = str(uuid.uuid4()) + return IamClient( + backend=self.backend, + subscription=f"{self.id}--iam--{rr_id}", + consumer_name=self.id, + request_topic=iam_request_queue, + request_schema=IamRequest, + request_metrics=ProducerMetrics( + processor=self.id, flow=None, name="iam-request", + ), + response_topic=iam_response_queue, + response_schema=IamResponse, + response_metrics=SubscriberMetrics( + processor=self.id, flow=None, name="iam-response", + ), + ) + + async def _with_client(self, op): + """Open a short-lived IamClient, run ``op(client)``, close.""" + client = self._make_client() + await client.start() + try: + return await op(client) + finally: + try: + await client.stop() + except Exception: + pass + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def start(self, max_retries=30, retry_delay=2.0): + """Fetch the signing public key from IAM. Retries on + failure — the gateway may be starting before IAM is ready.""" + + async def _fetch(client): + return await client.get_signing_key_public() + + for attempt in range(max_retries): + try: + pem = await self._with_client(_fetch) + if pem: + self._signing_public_pem = pem + logger.info( + "IamAuth: fetched IAM signing public key " + f"({len(pem)} bytes)" + ) + return + except Exception as e: + logger.info( + f"IamAuth: waiting for IAM signing key " + f"({type(e).__name__}: {e}); " + f"retry {attempt + 1}/{max_retries}" + ) + await asyncio.sleep(retry_delay) + + # Don't prevent startup forever. A later authenticate() call + # will try again via the JWT path. 
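The API-key path never caches plaintext: entries are keyed on the SHA-256 of the presented key, and a revoked key keeps resolving only until its cache entry ages out. A stdlib-only sketch of that TTL behaviour; `KeyCache` and its `resolver` callback are hypothetical stand-ins for `IamAuth._resolve_api_key` and the IAM round-trip, with an injectable clock so the expiry is testable without sleeping.

```python
import hashlib
import time

API_KEY_CACHE_TTL = 60  # seconds, mirroring the gateway constant


class KeyCache:
    """sha256(plaintext) -> (identity, expiry). A revoked key keeps
    resolving from cache until its entry expires; the next lookup
    after that goes back to the resolver."""

    def __init__(self, resolver, ttl=API_KEY_CACHE_TTL, clock=time.time):
        self.resolver = resolver   # stand-in for the IAM call
        self.ttl = ttl
        self.clock = clock         # injectable for testing
        self.cache = {}

    def resolve(self, plaintext):
        h = hashlib.sha256(plaintext.encode("utf-8")).hexdigest()
        now = self.clock()
        hit = self.cache.get(h)
        if hit and hit[1] > now:
            return hit[0]
        identity = self.resolver(plaintext)
        self.cache[h] = (identity, now + self.ttl)
        return identity


calls = []
t = [0.0]
cache = KeyCache(
    lambda k: calls.append(k) or ("alice", "default", ["reader"]),
    clock=lambda: t[0],
)

cache.resolve("tg_secret")
cache.resolve("tg_secret")   # cache hit: no second resolver call
t[0] = 61.0                  # past the TTL
cache.resolve("tg_secret")   # expired entry: resolver consulted again
print(len(calls))            # 2
```

This is why revocation takes effect within at most `API_KEY_CACHE_TTL` seconds without any push mechanism from IAM to the gateway.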
+ logger.warning( + "IamAuth: could not fetch IAM signing key at startup; " + "JWT validation will fail until it's available" + ) + + # ------------------------------------------------------------------ + # Authentication + # ------------------------------------------------------------------ + + async def authenticate(self, request): + """Extract and validate the Bearer credential from an HTTP + request. Returns an ``Identity``. Raises HTTPUnauthorized + (401 / "auth failure") on any failure mode — the caller + cannot distinguish missing / malformed / invalid / expired / + revoked credentials.""" + + header = request.headers.get("Authorization", "") + if not header.startswith("Bearer "): + raise _auth_failure() + token = header[len("Bearer "):].strip() + if not token: + raise _auth_failure() + + # API keys always start with "tg_". JWTs have two dots and + # no "tg_" prefix. Discriminate cheaply. + if token.startswith("tg_"): + return await self._resolve_api_key(token) + if token.count(".") == 2: + return self._verify_jwt(token) + raise _auth_failure() + + def _verify_jwt(self, token): + if not self._signing_public_pem: + raise _auth_failure() + try: + claims = _verify_jwt_eddsa(token, self._signing_public_pem) + except Exception as e: + logger.debug(f"JWT validation failed: {type(e).__name__}: {e}") + raise _auth_failure() + + sub = claims.get("sub", "") + ws = claims.get("workspace", "") + roles = list(claims.get("roles", [])) + if not sub or not ws: + raise _auth_failure() + + return Identity( + user_id=sub, workspace=ws, roles=roles, source="jwt", + ) + + async def _resolve_api_key(self, plaintext): + h = hashlib.sha256(plaintext.encode("utf-8")).hexdigest() + + cached = self._key_cache.get(h) + now = time.time() + if cached and cached[1] > now: + return cached[0] + + async with self._key_cache_lock: + cached = self._key_cache.get(h) + if cached and cached[1] > now: + return cached[0] + + try: + async def _call(client): + return await 
client.resolve_api_key(plaintext) + user_id, workspace, roles = await self._with_client(_call) + except Exception as e: + logger.debug( + f"API key resolution failed: " + f"{type(e).__name__}: {e}" + ) + raise _auth_failure() + + if not user_id or not workspace: + raise _auth_failure() + + identity = Identity( + user_id=user_id, workspace=workspace, + roles=list(roles), source="api-key", + ) + self._key_cache[h] = (identity, now + API_KEY_CACHE_TTL) + return identity diff --git a/trustgraph-flow/trustgraph/gateway/capabilities.py b/trustgraph-flow/trustgraph/gateway/capabilities.py new file mode 100644 index 00000000..15e25684 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/capabilities.py @@ -0,0 +1,238 @@ +""" +Capability vocabulary, role definitions, and authorisation helpers. + +See docs/tech-specs/capabilities.md for the authoritative description. +The data here is the OSS bundle table in that spec. Enterprise +editions may replace this module with their own role table; the +vocabulary (capability strings) is shared. + +Role model +---------- +A role has two dimensions: + + 1. **capability set** — which operations the role grants. + 2. **workspace scope** — which workspaces the role is active in. + +The authorisation question is: *given the caller's roles, a required +capability, and a target workspace, does any role grant the +capability AND apply to the target workspace?* + +Workspace scope values recognised here: + + - ``"assigned"`` — the role applies only to the caller's own + assigned workspace (stored on their user record). + - ``"*"`` — the role applies to every workspace. + +Enterprise editions can add richer scopes (explicit permitted-set, +patterns, etc.) without changing the wire protocol. + +Sentinels +--------- +- ``PUBLIC`` — endpoint requires no authentication. +- ``AUTHENTICATED`` — endpoint requires a valid identity, no + specific capability. 
+""" + +from aiohttp import web + + +PUBLIC = "__public__" +AUTHENTICATED = "__authenticated__" + + +# Capability vocabulary. Mirrors the "Capability list" tables in +# capabilities.md. Kept as a set so the gateway can fail-closed on +# an endpoint that declares an unknown capability. +KNOWN_CAPABILITIES = { + # Data plane + "agent", + "graph:read", "graph:write", + "documents:read", "documents:write", + "rows:read", "rows:write", + "llm", + "embeddings", + "mcp", + # Control plane + "config:read", "config:write", + "flows:read", "flows:write", + "users:read", "users:write", "users:admin", + "keys:self", "keys:admin", + "workspaces:admin", + "iam:admin", + "metrics:read", + "collections:read", "collections:write", + "knowledge:read", "knowledge:write", +} + + +# Capability sets used below. +_READER_CAPS = { + "agent", + "graph:read", + "documents:read", + "rows:read", + "llm", + "embeddings", + "mcp", + "config:read", + "flows:read", + "collections:read", + "knowledge:read", + "keys:self", +} + +_WRITER_CAPS = _READER_CAPS | { + "graph:write", + "documents:write", + "rows:write", + "collections:write", + "knowledge:write", +} + +_ADMIN_CAPS = _WRITER_CAPS | { + "config:write", + "flows:write", + "users:read", "users:write", "users:admin", + "keys:admin", + "workspaces:admin", + "iam:admin", + "metrics:read", +} + + +# Role definitions. Each role has a capability set and a workspace +# scope. Enterprise overrides this mapping. 
+ROLE_DEFINITIONS = { + "reader": { + "capabilities": _READER_CAPS, + "workspace_scope": "assigned", + }, + "writer": { + "capabilities": _WRITER_CAPS, + "workspace_scope": "assigned", + }, + "admin": { + "capabilities": _ADMIN_CAPS, + "workspace_scope": "*", + }, +} + + +def _scope_permits(role_name, target_workspace, assigned_workspace): + """Does the given role apply to ``target_workspace``?""" + role = ROLE_DEFINITIONS.get(role_name) + if role is None: + return False + scope = role["workspace_scope"] + if scope == "*": + return True + if scope == "assigned": + return target_workspace == assigned_workspace + # Future scope types (lists, patterns) extend here. + return False + + +def check(identity, capability, target_workspace=None): + """Is ``identity`` permitted to invoke ``capability`` on + ``target_workspace``? + + Passes iff some role held by the caller both (a) grants + ``capability`` and (b) is active in ``target_workspace``. + + ``target_workspace`` defaults to the caller's assigned workspace, + which makes this function usable for system-level operations and + for authenticated endpoints that don't take a workspace argument + (the call collapses to "do any of my roles grant this cap?").""" + if capability not in KNOWN_CAPABILITIES: + return False + + target = target_workspace or identity.workspace + + for role_name in identity.roles: + role = ROLE_DEFINITIONS.get(role_name) + if role is None: + continue + if capability not in role["capabilities"]: + continue + if _scope_permits(role_name, target, identity.workspace): + return True + return False + + +def access_denied(): + return web.HTTPForbidden( + text='{"error":"access denied"}', + content_type="application/json", + ) + + +def auth_failure(): + return web.HTTPUnauthorized( + text='{"error":"auth failure"}', + content_type="application/json", + ) + + +async def enforce(request, auth, capability): + """Authenticate + capability-check for endpoints that carry no + workspace dimension on the request 
(metrics, i18n, etc.). + + For endpoints that carry a workspace field on the body, call + :func:`enforce_workspace` *after* parsing the body to validate + the workspace and re-check the capability in that scope. Most + endpoints do both. + + - ``PUBLIC``: no authentication, returns ``None``. + - ``AUTHENTICATED``: any valid identity. + - capability string: identity must have it, checked against the + caller's assigned workspace (adequate for endpoints whose + capability is system-level, e.g. ``metrics:read``, or where + the real workspace-aware check happens in + :func:`enforce_workspace` after body parsing).""" + if capability == PUBLIC: + return None + + identity = await auth.authenticate(request) + + if capability == AUTHENTICATED: + return identity + + if not check(identity, capability): + raise access_denied() + + return identity + + +def enforce_workspace(data, identity, capability=None): + """Resolve + validate the workspace on a request body. + + - Target workspace = ``data["workspace"]`` if supplied, else the + caller's assigned workspace. + - At least one of the caller's roles must (a) be active in the + target workspace and, if ``capability`` is given, (b) grant + ``capability``. Otherwise 403. + - On success, ``data["workspace"]`` is overwritten with the + resolved value — callers can rely on the outgoing message + having the gateway's chosen workspace rather than any + caller-supplied value. + + For ``capability=None`` the workspace scope alone is checked — + useful when the body has a workspace but the endpoint already + passed its capability check (e.g. 
via :func:`enforce`).""" + if not isinstance(data, dict): + return data + + requested = data.get("workspace", "") + target = requested or identity.workspace + + for role_name in identity.roles: + role = ROLE_DEFINITIONS.get(role_name) + if role is None: + continue + if capability is not None and capability not in role["capabilities"]: + continue + if _scope_permits(role_name, target, identity.workspace): + data["workspace"] = target + return data + + raise access_denied() diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/iam.py b/trustgraph-flow/trustgraph/gateway/dispatch/iam.py new file mode 100644 index 00000000..386233f5 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/iam.py @@ -0,0 +1,40 @@ + +from ... schema import IamRequest, IamResponse +from ... schema import iam_request_queue, iam_response_queue +from ... messaging import TranslatorRegistry + +from . requestor import ServiceRequestor + + +class IamRequestor(ServiceRequestor): + def __init__(self, backend, consumer, subscriber, timeout=120, + request_queue=None, response_queue=None): + + if request_queue is None: + request_queue = iam_request_queue + if response_queue is None: + response_queue = iam_response_queue + + super().__init__( + backend=backend, + consumer_name=consumer, + subscription=subscriber, + request_queue=request_queue, + response_queue=response_queue, + request_schema=IamRequest, + response_schema=IamResponse, + timeout=timeout, + ) + + self.request_translator = ( + TranslatorRegistry.get_request_translator("iam") + ) + self.response_translator = ( + TranslatorRegistry.get_response_translator("iam") + ) + + def to_request(self, body): + return self.request_translator.decode(body) + + def from_response(self, message): + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index b238bb5b..ea8770d7 100644 --- 
a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -9,6 +9,7 @@ logger = logging.getLogger(__name__) from . config import ConfigRequestor from . flow import FlowRequestor +from . iam import IamRequestor from . librarian import LibrarianRequestor from . knowledge import KnowledgeRequestor from . collection_management import CollectionManagementRequestor @@ -72,6 +73,7 @@ request_response_dispatchers = { global_dispatchers = { "config": ConfigRequestor, "flow": FlowRequestor, + "iam": IamRequestor, "librarian": LibrarianRequestor, "knowledge": KnowledgeRequestor, "collection-management": CollectionManagementRequestor, @@ -105,13 +107,31 @@ class DispatcherWrapper: class DispatcherManager: - def __init__(self, backend, config_receiver, prefix="api-gateway", - queue_overrides=None): + def __init__(self, backend, config_receiver, auth, + prefix="api-gateway", queue_overrides=None): + """ + ``auth`` is required. It flows into the Mux for first-frame + WebSocket authentication and into downstream dispatcher + construction. There is no permissive default — constructing + a DispatcherManager without an authenticator would be a + silent downgrade to no-auth on the socket path. + """ + if auth is None: + raise ValueError( + "DispatcherManager requires an 'auth' argument — there " + "is no no-auth mode" + ) + self.backend = backend self.config_receiver = config_receiver self.config_receiver.add_handler(self) self.prefix = prefix + # Gateway IamAuth — used by the socket Mux for first-frame + # auth and by any dispatcher that needs to resolve caller + # identity out-of-band. 
+ self.auth = auth + # Store queue overrides for global services # Format: {"config": {"request": "...", "response": "..."}, ...} self.queue_overrides = queue_overrides or {} @@ -163,6 +183,15 @@ class DispatcherManager: def dispatch_global_service(self): return DispatcherWrapper(self.process_global_service) + def dispatch_auth_iam(self): + """Pre-configured IAM dispatcher for the gateway's auth + endpoints (login, bootstrap, change-password). Pins the + kind to ``iam`` so these handlers don't have to supply URL + params the global dispatcher would expect.""" + async def _process(data, responder): + return await self.invoke_global_service(data, responder, "iam") + return DispatcherWrapper(_process) + def dispatch_core_export(self): return DispatcherWrapper(self.process_core_export) @@ -314,7 +343,10 @@ class DispatcherManager: async def process_socket(self, ws, running, params): - dispatcher = Mux(self, ws, running) + # The mux self-authenticates via the first-frame protocol; + # pass the gateway's IamAuth so it can validate tokens + # without reaching back into the endpoint layer. + dispatcher = Mux(self, ws, running, auth=self.auth) return dispatcher diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py index 3d610dca..013cd1ea 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py @@ -16,11 +16,28 @@ MAX_QUEUE_SIZE = 10 class Mux: - def __init__(self, dispatcher_manager, ws, running): + def __init__(self, dispatcher_manager, ws, running, auth): + """ + ``auth`` is required — the Mux implements the first-frame + auth protocol described in ``iam.md`` and will refuse any + non-auth frame until an ``auth-ok`` has been issued. There + is no no-auth mode. 
+ """ + if auth is None: + raise ValueError( + "Mux requires an 'auth' argument — there is no " + "no-auth mode" + ) self.dispatcher_manager = dispatcher_manager self.ws = ws self.running = running + self.auth = auth + + # Authenticated identity, populated by the first-frame auth + # protocol. ``None`` means the socket is not yet + # authenticated; any non-auth frame is refused. + self.identity = None self.q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE) @@ -31,6 +48,41 @@ class Mux: if self.ws: await self.ws.close() + async def _handle_auth_frame(self, data): + """Process a ``{"type": "auth", "token": "..."}`` frame. + On success, updates ``self.identity`` and returns an + ``auth-ok`` response frame. On failure, returns the masked + auth-failure frame. Never raises — auth failures keep the + socket open so the client can retry without reconnecting + (important for browsers, which treat a handshake-time 401 + as terminal).""" + token = data.get("token", "") + if not token: + await self.ws.send_json({ + "type": "auth-failed", + "error": "auth failure", + }) + return + + class _Shim: + def __init__(self, tok): + self.headers = {"Authorization": f"Bearer {tok}"} + + try: + identity = await self.auth.authenticate(_Shim(token)) + except Exception: + await self.ws.send_json({ + "type": "auth-failed", + "error": "auth failure", + }) + return + + self.identity = identity + await self.ws.send_json({ + "type": "auth-ok", + "workspace": identity.workspace, + }) + async def receive(self, msg): request_id = None @@ -38,6 +90,16 @@ class Mux: try: data = msg.json() + + # In-band auth protocol: the client sends + # ``{"type": "auth", "token": "..."}`` as its first frame + # (and any time it wants to re-auth: JWT refresh, token + # rotation, etc). Auth is always required on a Mux — + # there is no no-auth mode. 
+ if isinstance(data, dict) and data.get("type") == "auth": + await self._handle_auth_frame(data) + return + request_id = data.get("id") if "request" not in data: @@ -46,9 +108,49 @@ class Mux: if "id" not in data: raise RuntimeError("Bad message") + # Reject all non-auth frames until an ``auth-ok`` has + # been issued. + if self.identity is None: + await self.ws.send_json({ + "id": request_id, + "error": { + "message": "auth failure", + "type": "auth-required", + }, + "complete": True, + }) + return + + # Workspace resolution. Role workspace scope determines + # which target workspaces are permitted. The resolved + # value is written to both the envelope and the inner + # request payload so clients don't have to repeat it + # per-message (same convenience HTTP callers get via + # enforce_workspace). + from ..capabilities import enforce_workspace + from aiohttp import web as _web + + try: + enforce_workspace(data, self.identity) + inner = data.get("request") + if isinstance(inner, dict): + enforce_workspace(inner, self.identity) + except _web.HTTPForbidden: + await self.ws.send_json({ + "id": request_id, + "error": { + "message": "access denied", + "type": "access-denied", + }, + "complete": True, + }) + return + + workspace = data["workspace"] + await self.q.put(( data["id"], - data.get("workspace", "default"), + workspace, data.get("flow"), data["service"], data["request"] diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py b/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py new file mode 100644 index 00000000..6037fc4b --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/endpoint/auth_endpoints.py @@ -0,0 +1,115 @@ +""" +Gateway auth endpoints. 
+ +Three dedicated paths: + POST /api/v1/auth/login — unauthenticated; username/password → JWT + POST /api/v1/auth/bootstrap — unauthenticated; IAM bootstrap op + POST /api/v1/auth/change-password — authenticated; any role + +These are the only IAM-surface operations that can be reached from +outside. Everything else routes through ``/api/v1/iam`` gated by +``users:admin``. +""" + +import logging + +from aiohttp import web + +from .. capabilities import enforce, PUBLIC, AUTHENTICATED + +logger = logging.getLogger("auth-endpoints") +logger.setLevel(logging.INFO) + + +class AuthEndpoints: + """Groups the three auth-surface handlers. Each forwards to the + IAM service via the existing ``IamRequestor`` dispatcher.""" + + def __init__(self, iam_dispatcher, auth): + self.iam = iam_dispatcher + self.auth = auth + + async def start(self): + pass + + def add_routes(self, app): + app.add_routes([ + web.post("/api/v1/auth/login", self.login), + web.post("/api/v1/auth/bootstrap", self.bootstrap), + web.post( + "/api/v1/auth/change-password", + self.change_password, + ), + ]) + + async def _forward(self, body): + async def responder(x, fin): + pass + return await self.iam.process(body, responder) + + async def login(self, request): + """Public. Accepts {username, password, workspace?}. Returns + {jwt, jwt_expires} on success; IAM's masked auth failure on + anything else.""" + await enforce(request, self.auth, PUBLIC) + try: + body = await request.json() + except Exception: + return web.json_response( + {"error": "invalid json"}, status=400, + ) + req = { + "operation": "login", + "username": body.get("username", ""), + "password": body.get("password", ""), + "workspace": body.get("workspace", ""), + } + resp = await self._forward(req) + if "error" in resp: + return web.json_response( + {"error": "auth failure"}, status=401, + ) + return web.json_response(resp) + + async def bootstrap(self, request): + """Public. 
Valid only when IAM is running in bootstrap mode + with empty tables. In every other case the IAM service + returns a masked auth-failure.""" + await enforce(request, self.auth, PUBLIC) + resp = await self._forward({"operation": "bootstrap"}) + if "error" in resp: + return web.json_response( + {"error": "auth failure"}, status=401, + ) + return web.json_response(resp) + + async def change_password(self, request): + """Authenticated (any role). Accepts {current_password, + new_password}; user_id is taken from the authenticated + identity — the caller cannot change someone else's password + this way (reset-password is the admin path).""" + identity = await enforce(request, self.auth, AUTHENTICATED) + try: + body = await request.json() + except Exception: + return web.json_response( + {"error": "invalid json"}, status=400, + ) + req = { + "operation": "change-password", + "user_id": identity.user_id, + "password": body.get("current_password", ""), + "new_password": body.get("new_password", ""), + } + resp = await self._forward(req) + if "error" in resp: + err_type = resp.get("error", {}).get("type", "") + if err_type == "auth-failed": + return web.json_response( + {"error": "auth failure"}, status=401, + ) + return web.json_response( + {"error": resp.get("error", {}).get("message", "error")}, + status=400, + ) + return web.json_response(resp) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py index 58ba1738..ee9c0447 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/constant_endpoint.py @@ -1,28 +1,27 @@ -import asyncio -from aiohttp import web -import uuid import logging +from aiohttp import web + +from .. 
capabilities import enforce, enforce_workspace + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) + class ConstantEndpoint: - def __init__(self, endpoint_path, auth, dispatcher): + def __init__(self, endpoint_path, auth, dispatcher, capability): self.path = endpoint_path - self.auth = auth - self.operation = "service" - + self.capability = capability self.dispatcher = dispatcher async def start(self): pass def add_routes(self, app): - app.add_routes([ web.post(self.path, self.handle), ]) @@ -31,22 +30,14 @@ class ConstantEndpoint: logger.debug(f"Processing request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + identity = await enforce(request, self.auth, self.capability) try: - data = await request.json() + if identity is not None: + enforce_workspace(data, identity) + async def responder(x, fin): pass @@ -54,10 +45,8 @@ class ConstantEndpoint: return web.json_response(resp) + except web.HTTPException: + raise except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py b/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py index b949a499..f28f293d 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/i18n.py @@ -4,16 +4,18 @@ from aiohttp import web from trustgraph.i18n import get_language_pack +from .. 
capabilities import enforce + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) class I18nPackEndpoint: - def __init__(self, endpoint_path: str, auth): + def __init__(self, endpoint_path: str, auth, capability): self.path = endpoint_path self.auth = auth - self.operation = "service" + self.capability = capability async def start(self): pass @@ -26,26 +28,13 @@ class I18nPackEndpoint: async def handle(self, request): logger.debug(f"Processing i18n pack request: {request.path}") - token = "" - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except Exception: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + await enforce(request, self.auth, self.capability) lang = request.match_info.get("lang") or "en" - # This is a path traversal defense, and is a critical sec defense. - # Do not remove! + # Path-traversal defense — critical, do not remove. if "/" in lang or ".." in lang: return web.HTTPBadRequest(reason="Invalid language code") pack = get_language_pack(lang) - return web.json_response(pack) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py index fb8b0b76..69b11e07 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py @@ -8,72 +8,278 @@ from . variable_endpoint import VariableEndpoint from . socket import SocketEndpoint from . metrics import MetricsEndpoint from . i18n import I18nPackEndpoint +from . auth_endpoints import AuthEndpoints + +from .. capabilities import PUBLIC, AUTHENTICATED from .. dispatch.manager import DispatcherManager + +# Capability required for each kind on the /api/v1/{kind} generic +# endpoint (global services). 
Coarse gating — the IAM bundle split +# of "read vs write" per admin subsystem is not applied here because +# this endpoint forwards an opaque operation in the body. Writes +# are the upper bound on what the endpoint can do, so we gate on +# the write/admin capability. +GLOBAL_KIND_CAPABILITY = { + "config": "config:write", + "flow": "flows:write", + "librarian": "documents:write", + "knowledge": "knowledge:write", + "collection-management": "collections:write", + # IAM endpoints land on /api/v1/iam and require the admin bundle. + # Login / bootstrap / change-password are served by + # AuthEndpoints, which handle their own gating (PUBLIC / + # AUTHENTICATED). + "iam": "users:admin", +} + + +# Capability required for each kind on the +# /api/v1/flow/{flow}/service/{kind} endpoint (per-flow data-plane). +FLOW_KIND_CAPABILITY = { + "agent": "agent", + "text-completion": "llm", + "prompt": "llm", + "mcp-tool": "mcp", + "graph-rag": "graph:read", + "document-rag": "documents:read", + "embeddings": "embeddings", + "graph-embeddings": "graph:read", + "document-embeddings": "documents:read", + "triples": "graph:read", + "rows": "rows:read", + "nlp-query": "rows:read", + "structured-query": "rows:read", + "structured-diag": "rows:read", + "row-embeddings": "rows:read", + "sparql": "graph:read", +} + + +# Capability for the streaming flow import/export endpoints, +# keyed by the "kind" URL segment. +FLOW_IMPORT_CAPABILITY = { + "triples": "graph:write", + "graph-embeddings": "graph:write", + "document-embeddings": "documents:write", + "entity-contexts": "documents:write", + "rows": "rows:write", +} + +FLOW_EXPORT_CAPABILITY = { + "triples": "graph:read", + "graph-embeddings": "graph:read", + "document-embeddings": "documents:read", + "entity-contexts": "documents:read", +} + + +from .. 
capabilities import enforce, enforce_workspace +import logging as _mgr_logging +_mgr_logger = _mgr_logging.getLogger("endpoint") + + +class _RoutedVariableEndpoint: + """HTTP endpoint whose required capability is looked up per + request from the URL's ``kind`` parameter. Used for the two + generic dispatch paths (``/api/v1/{kind}`` and + ``/api/v1/flow/{flow}/service/{kind}``). Self-contained rather + than subclassing ``VariableEndpoint`` to avoid mutating shared + state across concurrent requests.""" + + def __init__(self, endpoint_path, auth, dispatcher, capability_map): + self.path = endpoint_path + self.auth = auth + self.dispatcher = dispatcher + self._capability_map = capability_map + + async def start(self): + pass + + def add_routes(self, app): + app.add_routes([web.post(self.path, self.handle)]) + + async def handle(self, request): + kind = request.match_info.get("kind", "") + cap = self._capability_map.get(kind) + if cap is None: + return web.json_response( + {"error": "unknown kind"}, status=404, + ) + + identity = await enforce(request, self.auth, cap) + + try: + data = await request.json() + if identity is not None: + enforce_workspace(data, identity) + + async def responder(x, fin): + pass + + resp = await self.dispatcher.process( + data, responder, request.match_info, + ) + return web.json_response(resp) + + except web.HTTPException: + raise + except Exception as e: + _mgr_logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) + + +class _RoutedSocketEndpoint: + """WebSocket endpoint whose required capability is looked up per + request from the URL's ``kind`` parameter. 
Used for the flow + import/export streaming endpoints.""" + + def __init__(self, endpoint_path, auth, dispatcher, capability_map): + self.path = endpoint_path + self.auth = auth + self.dispatcher = dispatcher + self._capability_map = capability_map + + async def start(self): + pass + + def add_routes(self, app): + app.add_routes([web.get(self.path, self.handle)]) + + async def handle(self, request): + from .. capabilities import check, auth_failure, access_denied + + kind = request.match_info.get("kind", "") + cap = self._capability_map.get(kind) + if cap is None: + return web.json_response( + {"error": "unknown kind"}, status=404, + ) + + token = request.query.get("token", "") + if not token: + return auth_failure() + + from . socket import _QueryTokenRequest + try: + identity = await self.auth.authenticate( + _QueryTokenRequest(token) + ) + except web.HTTPException as e: + return e + if not check(identity, cap): + return access_denied() + + # Delegate the websocket handling to a standalone SocketEndpoint + # with the resolved capability, bypassing the per-request mutation + # concern by instantiating fresh state. + ws_ep = SocketEndpoint( + endpoint_path=self.path, + auth=self.auth, + dispatcher=self.dispatcher, + capability=cap, + ) + return await ws_ep.handle(request) + + class EndpointManager: def __init__( - self, dispatcher_manager, auth, prometheus_url, timeout=600 + self, dispatcher_manager, auth, prometheus_url, timeout=600, ): self.dispatcher_manager = dispatcher_manager self.timeout = timeout - self.services = { - } - self.endpoints = [ + + # Auth surface — public / authenticated-any. Must come + # before the generic /api/v1/{kind} routes to win the + # match for /api/v1/auth/* paths. aiohttp routes in + # registration order, so we prepend here. 
+        self.endpoints = [
+ AuthEndpoints( + iam_dispatcher=dispatcher_manager.dispatch_auth_iam(), + auth=auth, + ), + I18nPackEndpoint( - endpoint_path = "/api/v1/i18n/packs/{lang}", - auth = auth, + endpoint_path="/api/v1/i18n/packs/{lang}", + auth=auth, + capability=PUBLIC, ), MetricsEndpoint( - endpoint_path = "/api/metrics", - prometheus_url = prometheus_url, - auth = auth, + endpoint_path="/api/metrics", + prometheus_url=prometheus_url, + auth=auth, + capability="metrics:read", ), - VariableEndpoint( - endpoint_path = "/api/v1/{kind}", auth = auth, - dispatcher = dispatcher_manager.dispatch_global_service(), + + # Global services: capability chosen per-kind. + _RoutedVariableEndpoint( + endpoint_path="/api/v1/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_global_service(), + capability_map=GLOBAL_KIND_CAPABILITY, ), + + # /api/v1/socket: WebSocket handshake accepts + # unconditionally; the Mux dispatcher runs the + # first-frame auth protocol. Handshake-time 401s break + # browser reconnection, so authentication is always + # in-band for this endpoint. SocketEndpoint( - endpoint_path = "/api/v1/socket", - auth = auth, - dispatcher = dispatcher_manager.dispatch_socket() + endpoint_path="/api/v1/socket", + auth=auth, + dispatcher=dispatcher_manager.dispatch_socket(), + capability=AUTHENTICATED, # informational only; bypassed + in_band_auth=True, ), - VariableEndpoint( - endpoint_path = "/api/v1/flow/{flow}/service/{kind}", - auth = auth, - dispatcher = dispatcher_manager.dispatch_flow_service(), + + # Per-flow request/response services — capability per kind. + _RoutedVariableEndpoint( + endpoint_path="/api/v1/flow/{flow}/service/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_flow_service(), + capability_map=FLOW_KIND_CAPABILITY, ), - SocketEndpoint( - endpoint_path = "/api/v1/flow/{flow}/import/{kind}", - auth = auth, - dispatcher = dispatcher_manager.dispatch_flow_import() + + # Per-flow streaming import/export — capability per kind. 
+ _RoutedSocketEndpoint( + endpoint_path="/api/v1/flow/{flow}/import/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_flow_import(), + capability_map=FLOW_IMPORT_CAPABILITY, ), - SocketEndpoint( - endpoint_path = "/api/v1/flow/{flow}/export/{kind}", - auth = auth, - dispatcher = dispatcher_manager.dispatch_flow_export() + _RoutedSocketEndpoint( + endpoint_path="/api/v1/flow/{flow}/export/{kind}", + auth=auth, + dispatcher=dispatcher_manager.dispatch_flow_export(), + capability_map=FLOW_EXPORT_CAPABILITY, + ), + + StreamEndpoint( + endpoint_path="/api/v1/import-core", + auth=auth, + method="POST", + dispatcher=dispatcher_manager.dispatch_core_import(), + # Cross-subject import — require the admin bundle via a + # single representative capability. + capability="users:admin", ), StreamEndpoint( - endpoint_path = "/api/v1/import-core", - auth = auth, - method = "POST", - dispatcher = dispatcher_manager.dispatch_core_import(), + endpoint_path="/api/v1/export-core", + auth=auth, + method="GET", + dispatcher=dispatcher_manager.dispatch_core_export(), + capability="users:admin", ), StreamEndpoint( - endpoint_path = "/api/v1/export-core", - auth = auth, - method = "GET", - dispatcher = dispatcher_manager.dispatch_core_export(), - ), - StreamEndpoint( - endpoint_path = "/api/v1/document-stream", - auth = auth, - method = "GET", - dispatcher = dispatcher_manager.dispatch_document_stream(), + endpoint_path="/api/v1/document-stream", + auth=auth, + method="GET", + dispatcher=dispatcher_manager.dispatch_document_stream(), + capability="documents:read", ), ] @@ -84,4 +290,3 @@ class EndpointManager: async def start(self): for ep in self.endpoints: await ep.start() - diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py b/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py index 903a199c..6832d1e3 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/metrics.py @@ -10,17 +10,19 @@ import 
asyncio import uuid import logging +from .. capabilities import enforce + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) class MetricsEndpoint: - def __init__(self, prometheus_url, endpoint_path, auth): + def __init__(self, prometheus_url, endpoint_path, auth, capability): self.prometheus_url = prometheus_url self.path = endpoint_path self.auth = auth - self.operation = "service" + self.capability = capability async def start(self): pass @@ -35,17 +37,7 @@ class MetricsEndpoint: logger.debug(f"Processing metrics request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + await enforce(request, self.auth, self.capability) path = request.match_info["path"] url = ( diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py index 9065761c..08629ea2 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/socket.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/socket.py @@ -4,6 +4,9 @@ from aiohttp import web, WSMsgType import logging from .. running import Running +from .. capabilities import ( + PUBLIC, AUTHENTICATED, check, auth_failure, access_denied, +) logger = logging.getLogger("socket") logger.setLevel(logging.INFO) @@ -11,12 +14,25 @@ logger.setLevel(logging.INFO) class SocketEndpoint: def __init__( - self, endpoint_path, auth, dispatcher, + self, endpoint_path, auth, dispatcher, capability, + in_band_auth=False, ): + """ + ``in_band_auth=True`` skips the handshake-time auth check. + The WebSocket handshake always succeeds; the dispatcher is + expected to gate itself via the first-frame auth protocol + (see ``Mux``). 
+ + This avoids the browser problem where a 401 on the handshake + is treated as permanent and prevents reconnection, and lets + long-lived sockets refresh their credential mid-session by + sending a new auth frame. + """ self.path = endpoint_path self.auth = auth - self.operation = "socket" + self.capability = capability + self.in_band_auth = in_band_auth self.dispatcher = dispatcher @@ -61,15 +77,29 @@ class SocketEndpoint: raise async def handle(self, request): - """Enhanced handler with better cleanup""" - try: - token = request.query['token'] - except: - token = "" + """Enhanced handler with better cleanup. + + Auth: WebSocket clients pass the bearer token on the + ``?token=...`` query string; we wrap it into a synthetic + Authorization header before delegating to the standard auth + path so the IAM-backed flow (JWT / API key) applies uniformly. + The first-frame auth protocol described in the IAM spec is + a future upgrade.""" + + if not self.in_band_auth and self.capability != PUBLIC: + token = request.query.get("token", "") + if not token: + return auth_failure() + try: + identity = await self.auth.authenticate( + _QueryTokenRequest(token) + ) + except web.HTTPException as e: + return e + if self.capability != AUTHENTICATED: + if not check(identity, self.capability): + return access_denied() - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() - # 50MB max message size ws = web.WebSocketResponse(max_msg_size=52428800) @@ -150,3 +180,11 @@ class SocketEndpoint: web.get(self.path, self.handle), ]) + +class _QueryTokenRequest: + """Minimal shim that exposes headers["Authorization"] to + IamAuth.authenticate(), derived from a query-string token.""" + + def __init__(self, token): + self.headers = {"Authorization": f"Bearer {token}"} + diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py index 38d8846f..7b0c4692 100644 --- 
a/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py @@ -1,82 +1,64 @@ -import asyncio -from aiohttp import web import logging +from aiohttp import web + +from .. capabilities import enforce + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) + class StreamEndpoint: - def __init__(self, endpoint_path, auth, dispatcher, method="POST"): - + def __init__( + self, endpoint_path, auth, dispatcher, capability, method="POST", + ): self.path = endpoint_path - self.auth = auth - self.operation = "service" + self.capability = capability self.method = method - self.dispatcher = dispatcher async def start(self): pass def add_routes(self, app): - if self.method == "POST": - app.add_routes([ - web.post(self.path, self.handle), - ]) + app.add_routes([web.post(self.path, self.handle)]) elif self.method == "GET": - app.add_routes([ - web.get(self.path, self.handle), - ]) + app.add_routes([web.get(self.path, self.handle)]) else: - raise RuntimeError("Bad method" + self.method) + raise RuntimeError("Bad method " + self.method) async def handle(self, request): logger.debug(f"Processing request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + await enforce(request, self.auth, self.capability) try: - data = request.content async def error(err): - return web.HTTPInternalServerError(text = err) + return web.HTTPInternalServerError(text=err) async def ok( - status=200, reason="OK", type="application/octet-stream" + status=200, reason="OK", + type="application/octet-stream", ): response = web.StreamResponse( - status = status, reason = reason, - headers = {"Content-Type": type} + status=status, reason=reason, + headers={"Content-Type": type}, ) await 
response.prepare(request) return response - resp = await self.dispatcher.process( - data, error, ok, request - ) - + resp = await self.dispatcher.process(data, error, ok, request) return resp + except web.HTTPException: + raise except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py index 608de71b..5e0d9d21 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py @@ -1,27 +1,27 @@ -import asyncio -from aiohttp import web import logging +from aiohttp import web + +from .. capabilities import enforce, enforce_workspace + logger = logging.getLogger("endpoint") logger.setLevel(logging.INFO) + class VariableEndpoint: - def __init__(self, endpoint_path, auth, dispatcher): + def __init__(self, endpoint_path, auth, dispatcher, capability): self.path = endpoint_path - self.auth = auth - self.operation = "service" - + self.capability = capability self.dispatcher = dispatcher async def start(self): pass def add_routes(self, app): - app.add_routes([ web.post(self.path, self.handle), ]) @@ -30,35 +30,25 @@ class VariableEndpoint: logger.debug(f"Processing request: {request.path}") - try: - ht = request.headers["Authorization"] - tokens = ht.split(" ", 2) - if tokens[0] != "Bearer": - return web.HTTPUnauthorized() - token = tokens[1] - except: - token = "" - - if not self.auth.permitted(token, self.operation): - return web.HTTPUnauthorized() + identity = await enforce(request, self.auth, self.capability) try: - data = await request.json() + if identity is not None: + enforce_workspace(data, identity) + async def responder(x, fin): pass resp = await self.dispatcher.process( - data, responder, 
request.match_info + data, responder, request.match_info, ) return web.json_response(resp) + except web.HTTPException: + raise except Exception as e: - logging.error(f"Exception: {e}") - - return web.json_response( - { "error": str(e) } - ) - + logger.error(f"Exception: {e}", exc_info=True) + return web.json_response({"error": str(e)}) diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index 4e465bf7..f75f3b25 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -12,7 +12,7 @@ import os from trustgraph.base.logging import setup_logging, add_logging_args from trustgraph.base.pubsub import get_pubsub, add_pubsub_args -from . auth import Authenticator +from . auth import IamAuth from . config.receiver import ConfigReceiver from . dispatch.manager import DispatcherManager @@ -35,7 +35,6 @@ default_prometheus_url = os.getenv("PROMETHEUS_URL", "http://prometheus:9090") default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) default_timeout = 600 default_port = 8088 -default_api_token = os.getenv("GATEWAY_SECRET", "") class Api: @@ -60,13 +59,14 @@ class Api: if not self.prometheus_url.endswith("/"): self.prometheus_url += "/" - api_token = config.get("api_token", default_api_token) - - # Token not set, or token equal empty string means no auth - if api_token: - self.auth = Authenticator(token=api_token) - else: - self.auth = Authenticator(allow_all=True) + # IAM-backed authentication. The legacy GATEWAY_SECRET + # shared-token path has been removed — there is no + # "open for everyone" fallback. The gateway cannot + # authenticate any request until IAM is reachable. 
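Every endpoint class above delegates to a shared `enforce` helper whose body is outside this diff. A plausible stdlib-only shape of that gate, with aiohttp's `HTTPUnauthorized` / `HTTPForbidden` stood in by plain exceptions and the `PUBLIC` / `AUTHENTICATED` sentinel values assumed, not taken from the source:

```python
import asyncio

class AuthFailure(Exception):
    """Stands in for aiohttp's web.HTTPUnauthorized in this sketch."""

class AccessDenied(Exception):
    """Stands in for aiohttp's web.HTTPForbidden in this sketch."""

PUBLIC = "public"                # sentinel values are assumptions
AUTHENTICATED = "authenticated"

async def enforce(headers, authenticate, capability):
    # Plausible shape of the shared gate: bearer-token parsing,
    # then IAM authentication, then a capability membership check.
    if capability == PUBLIC:
        return None
    scheme, _, token = headers.get("Authorization", "").partition(" ")
    if scheme != "Bearer" or not token:
        raise AuthFailure()
    identity = await authenticate(token)    # e.g. JWT / API-key lookup
    if capability != AUTHENTICATED:
        if capability not in identity["capabilities"]:
            raise AccessDenied()
    return identity

async def fake_auth(token):
    # Hypothetical authenticator returning a minimal identity dict.
    return {"capabilities": {"metrics:read"}}

identity = asyncio.run(
    enforce({"Authorization": "Bearer t0k3n"}, fake_auth, "metrics:read")
)
assert identity["capabilities"] == {"metrics:read"}
```

The important property, visible in the diff even without the helper body, is that `enforce` returns the identity on success so callers can apply workspace scoping via `enforce_workspace`.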
+ self.auth = IamAuth( + backend=self.pubsub_backend, + id=config.get("id", "api-gateway"), + ) self.config_receiver = ConfigReceiver(self.pubsub_backend) @@ -118,6 +118,7 @@ class Api: config_receiver = self.config_receiver, prefix = "gateway", queue_overrides = queue_overrides, + auth = self.auth, ) self.endpoint_manager = EndpointManager( @@ -132,12 +133,18 @@ class Api: ] async def app_factory(self): - + self.app = web.Application( middlewares=[], client_max_size=256 * 1024 * 1024 ) + # Fetch IAM signing public key before accepting traffic. + # Blocks for a bounded retry window; the gateway starts even + # if IAM is still unreachable (JWT validation will 401 until + # the key is available). + await self.auth.start() + await self.config_receiver.start() for ep in self.endpoints: @@ -189,12 +196,6 @@ def run(): help=f'API request timeout in seconds (default: {default_timeout})', ) - parser.add_argument( - '--api-token', - default=default_api_token, - help=f'Secret API token (default: no auth)', - ) - add_logging_args(parser) parser.add_argument( diff --git a/trustgraph-flow/trustgraph/iam/__init__.py b/trustgraph-flow/trustgraph/iam/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trustgraph-flow/trustgraph/iam/service/__init__.py b/trustgraph-flow/trustgraph/iam/service/__init__.py new file mode 100644 index 00000000..98f4d9da --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/service/__init__.py @@ -0,0 +1 @@ +from . service import * diff --git a/trustgraph-flow/trustgraph/iam/service/__main__.py b/trustgraph-flow/trustgraph/iam/service/__main__.py new file mode 100644 index 00000000..a731dd63 --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/service/__main__.py @@ -0,0 +1,4 @@ + +from . 
service import run
+
+run()
diff --git a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py
new file mode 100644
index 00000000..6e7c7aa5
--- /dev/null
+++ b/trustgraph-flow/trustgraph/iam/service/iam.py
@@ -0,0 +1,1132 @@
+"""
+IAM business logic. Handles ``IamRequest`` messages and builds
+``IamResponse`` messages. Does not concern itself with transport.
+
+See docs/tech-specs/iam-protocol.md for the wire-level contract and
+docs/tech-specs/iam.md for the surrounding architecture.
+"""
+
+import asyncio
+import base64
+import datetime
+import hashlib
+import json
+import logging
+import os
+import secrets
+import uuid
+
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import ed25519
+
+from trustgraph.schema import (
+    IamResponse, Error,
+    UserRecord, WorkspaceRecord, ApiKeyRecord,
+)
+
+from ... tables.iam import IamTableStore
+
+logger = logging.getLogger(__name__)
+
+
+DEFAULT_WORKSPACE = "default"
+BOOTSTRAP_ADMIN_USERNAME = "admin"
+BOOTSTRAP_ADMIN_NAME = "Administrator"
+
+PBKDF2_ITERATIONS = 600_000
+API_KEY_PREFIX = "tg_"
+API_KEY_RANDOM_BYTES = 24
+
+JWT_ISSUER = "trustgraph-iam"
+JWT_TTL_SECONDS = 3600
+
+
+def _now_iso():
+    return datetime.datetime.now(datetime.timezone.utc).isoformat()
+
+
+def _now_dt():
+    return datetime.datetime.now(datetime.timezone.utc)
+
+
+def _iso(dt):
+    if dt is None:
+        return ""
+    if isinstance(dt, str):
+        return dt
+    if dt.tzinfo is None:
+        dt = dt.replace(tzinfo=datetime.timezone.utc)
+    return dt.isoformat()
+
+
+def _hash_password(password):
+    """Return an encoded PBKDF2-SHA-256 hash of ``password``.
+
+    Format: ``pbkdf2-sha256$<iterations>$<b64-salt>$<b64-hash>``.
+    Stored verbatim in the password_hash column so the algorithm
+    and cost can be evolved later (new rows get a new prefix; old
+    rows are verified with their own parameters).
+ """ + salt = os.urandom(16) + dk = hashlib.pbkdf2_hmac( + "sha256", password.encode("utf-8"), salt, PBKDF2_ITERATIONS, + ) + return ( + f"pbkdf2-sha256${PBKDF2_ITERATIONS}" + f"${base64.b64encode(salt).decode('ascii')}" + f"${base64.b64encode(dk).decode('ascii')}" + ) + + +def _verify_password(password, encoded): + """Constant-time verify ``password`` against an encoded hash.""" + try: + algo, iters, b64_salt, b64_hash = encoded.split("$") + except ValueError: + return False + if algo != "pbkdf2-sha256": + return False + try: + iters = int(iters) + salt = base64.b64decode(b64_salt) + target = base64.b64decode(b64_hash) + except Exception: + return False + dk = hashlib.pbkdf2_hmac( + "sha256", password.encode("utf-8"), salt, iters, + ) + return secrets.compare_digest(dk, target) + + +def _generate_api_key(): + """Return a fresh API-key plaintext of the form ``tg_``.""" + return API_KEY_PREFIX + secrets.token_urlsafe(API_KEY_RANDOM_BYTES) + + +def _hash_api_key(plaintext): + """SHA-256 hex digest of an API key plaintext. Used as the + primary key in ``iam_api_keys`` so ``resolve-api-key`` is O(1).""" + return hashlib.sha256(plaintext.encode("utf-8")).hexdigest() + + +def _err(type, message): + return IamResponse(error=Error(type=type, message=message)) + + +def _parse_expires(s): + if not s: + return None + try: + return datetime.datetime.fromisoformat(s) + except Exception: + return None + + +def _b64url(data): + """URL-safe base64 encode without padding, as required by JWT.""" + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _generate_signing_keypair(): + """Return (kid, private_pem, public_pem) for a fresh Ed25519 + keypair. 
Ed25519 / EdDSA: small (32-byte public key), fast, + deterministic, side-channel-resistant by construction, free of + NIST-curve baggage.""" + key = ed25519.Ed25519PrivateKey.generate() + private_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode("ascii") + public_pem = key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode("ascii") + kid = uuid.uuid4().hex[:16] + return kid, private_pem, public_pem + + +def _sign_jwt(kid, private_pem, claims): + """Produce a compact-serialisation EdDSA (Ed25519) JWT for + ``claims``.""" + key = serialization.load_pem_private_key( + private_pem.encode("ascii"), password=None, + ) + if not isinstance(key, ed25519.Ed25519PrivateKey): + raise RuntimeError( + f"signing key is not Ed25519: {type(key).__name__}" + ) + + header = {"alg": "EdDSA", "typ": "JWT", "kid": kid} + header_b = _b64url(json.dumps( + header, separators=(",", ":"), sort_keys=True, + ).encode("utf-8")) + payload_b = _b64url(json.dumps( + claims, separators=(",", ":"), sort_keys=True, + ).encode("utf-8")) + signing_input = f"{header_b}.{payload_b}".encode("ascii") + signature = key.sign(signing_input) + + return f"{header_b}.{payload_b}.{_b64url(signature)}" + + +class IamService: + + def __init__(self, host, username, password, keyspace, + bootstrap_mode, bootstrap_token=None): + self.table_store = IamTableStore( + host, username, password, keyspace, + ) + # bootstrap_mode: "token" or "bootstrap". In "token" mode the + # service auto-seeds on first start using the provided + # bootstrap_token and the ``bootstrap`` operation is refused + # thereafter (indistinguishable from an already-bootstrapped + # deployment per the error policy). In "bootstrap" mode the + # ``bootstrap`` operation is live until tables are populated. 
+ if bootstrap_mode not in ("token", "bootstrap"): + raise ValueError( + f"bootstrap_mode must be 'token' or 'bootstrap', " + f"got {bootstrap_mode!r}" + ) + if bootstrap_mode == "token" and not bootstrap_token: + raise ValueError( + "bootstrap_mode='token' requires bootstrap_token" + ) + self.bootstrap_mode = bootstrap_mode + self.bootstrap_token = bootstrap_token + + self._signing_key = None + self._signing_key_lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # Dispatch + # ------------------------------------------------------------------ + + async def handle(self, v): + op = v.operation + + try: + if op == "bootstrap": + return await self.handle_bootstrap(v) + if op == "resolve-api-key": + return await self.handle_resolve_api_key(v) + if op == "create-user": + return await self.handle_create_user(v) + if op == "list-users": + return await self.handle_list_users(v) + if op == "create-api-key": + return await self.handle_create_api_key(v) + if op == "list-api-keys": + return await self.handle_list_api_keys(v) + if op == "revoke-api-key": + return await self.handle_revoke_api_key(v) + if op == "login": + return await self.handle_login(v) + if op == "get-signing-key-public": + return await self.handle_get_signing_key_public(v) + if op == "change-password": + return await self.handle_change_password(v) + if op == "reset-password": + return await self.handle_reset_password(v) + if op == "get-user": + return await self.handle_get_user(v) + if op == "update-user": + return await self.handle_update_user(v) + if op == "disable-user": + return await self.handle_disable_user(v) + if op == "enable-user": + return await self.handle_enable_user(v) + if op == "delete-user": + return await self.handle_delete_user(v) + if op == "create-workspace": + return await self.handle_create_workspace(v) + if op == "list-workspaces": + return await self.handle_list_workspaces(v) + if op == "get-workspace": + return await 
self.handle_get_workspace(v) + if op == "update-workspace": + return await self.handle_update_workspace(v) + if op == "disable-workspace": + return await self.handle_disable_workspace(v) + if op == "rotate-signing-key": + return await self.handle_rotate_signing_key(v) + + return _err( + "invalid-argument", + f"unknown or not-yet-implemented operation: {op!r}", + ) + + except Exception as e: + logger.error( + f"IAM {op} failed: {type(e).__name__}: {e}", + exc_info=True, + ) + return _err("internal-error", str(e)) + + # ------------------------------------------------------------------ + # Record conversion + # ------------------------------------------------------------------ + + def _row_to_user_record(self, row): + ( + id, workspace, username, name, email, _password_hash, + roles, enabled, must_change_password, created, + ) = row + return UserRecord( + id=id or "", + workspace=workspace or "", + username=username or "", + name=name or "", + email=email or "", + roles=sorted(roles) if roles else [], + enabled=bool(enabled), + must_change_password=bool(must_change_password), + created=_iso(created), + ) + + def _row_to_api_key_record(self, row): + ( + _key_hash, id, user_id, name, prefix, expires, + created, last_used, + ) = row + return ApiKeyRecord( + id=id or "", + user_id=user_id or "", + name=name or "", + prefix=prefix or "", + expires=_iso(expires), + created=_iso(created), + last_used=_iso(last_used), + ) + + # ------------------------------------------------------------------ + # bootstrap + # ------------------------------------------------------------------ + + async def auto_bootstrap_if_token_mode(self): + """Called from the service processor at startup. In + ``token`` mode, if tables are empty, seeds the default + workspace / admin / signing key using the operator-provided + bootstrap token. The admin's API key plaintext is *the* + ``bootstrap_token`` — the operator already knows it, nothing + needs to be returned or logged. 
+ + In ``bootstrap`` mode this is a no-op; seeding happens on + explicit ``bootstrap`` operation invocation.""" + if self.bootstrap_mode != "token": + return + + if await self.table_store.any_workspace_exists(): + logger.info( + "IAM: token mode, tables already populated; skipping " + "auto-bootstrap" + ) + return + + logger.info("IAM: token mode, empty tables; auto-bootstrapping") + await self._seed_tables(self.bootstrap_token) + logger.info( + "IAM: auto-bootstrap complete using operator-provided token" + ) + + async def _seed_tables(self, api_key_plaintext): + """Shared seeding logic used by token-mode auto-bootstrap and + bootstrap-mode handle_bootstrap. Creates the default + workspace, admin user, admin API key (using the given + plaintext), and an initial signing key. Returns the admin + user id.""" + now = _now_dt() + + await self.table_store.put_workspace( + id=DEFAULT_WORKSPACE, + name="Default", + enabled=True, + created=now, + ) + + admin_user_id = str(uuid.uuid4()) + admin_password = secrets.token_urlsafe(32) + await self.table_store.put_user( + id=admin_user_id, + workspace=DEFAULT_WORKSPACE, + username=BOOTSTRAP_ADMIN_USERNAME, + name=BOOTSTRAP_ADMIN_NAME, + email="", + password_hash=_hash_password(admin_password), + roles=["admin"], + enabled=True, + must_change_password=True, + created=now, + ) + + key_id = str(uuid.uuid4()) + await self.table_store.put_api_key( + key_hash=_hash_api_key(api_key_plaintext), + id=key_id, + user_id=admin_user_id, + name="bootstrap", + prefix=api_key_plaintext[:len(API_KEY_PREFIX) + 4], + expires=None, + created=now, + last_used=None, + ) + + kid, private_pem, public_pem = _generate_signing_keypair() + await self.table_store.put_signing_key( + kid=kid, + private_pem=private_pem, + public_pem=public_pem, + created=now, + retired=None, + ) + self._signing_key = (kid, private_pem, public_pem) + + logger.info( + f"IAM seeded: workspace={DEFAULT_WORKSPACE!r}, " + f"admin user_id={admin_user_id}, signing key kid={kid}" + ) + 
return admin_user_id + + async def handle_bootstrap(self, v): + """Explicit bootstrap op. Only available in ``bootstrap`` + mode and only when tables are empty. Every other case is + masked to a generic auth failure — the caller cannot + distinguish 'not in bootstrap mode' from 'already + bootstrapped' from 'operation forbidden'.""" + if self.bootstrap_mode != "bootstrap": + return _err("auth-failed", "auth failure") + + if await self.table_store.any_workspace_exists(): + return _err("auth-failed", "auth failure") + + plaintext = _generate_api_key() + admin_user_id = await self._seed_tables(plaintext) + + return IamResponse( + bootstrap_admin_user_id=admin_user_id, + bootstrap_admin_api_key=plaintext, + ) + + # ------------------------------------------------------------------ + # Signing key helpers + # ------------------------------------------------------------------ + + async def _get_active_signing_key(self): + """Return ``(kid, private_pem, public_pem)`` for the active + signing key. Loads from Cassandra on first call. 
Generates + and persists a new key if none exists — covers the case where + ``login`` is called before ``bootstrap`` (shouldn't happen in + practice but keeps the service internally consistent).""" + if self._signing_key is not None: + return self._signing_key + + async with self._signing_key_lock: + if self._signing_key is not None: + return self._signing_key + + rows = await self.table_store.list_signing_keys() + active = [r for r in rows if r[4] is None] + + if active: + row = active[0] + self._signing_key = (row[0], row[1], row[2]) + logger.info( + f"IAM: loaded active signing key kid={row[0]}" + ) + return self._signing_key + + kid, private_pem, public_pem = _generate_signing_keypair() + await self.table_store.put_signing_key( + kid=kid, + private_pem=private_pem, + public_pem=public_pem, + created=_now_dt(), + retired=None, + ) + self._signing_key = (kid, private_pem, public_pem) + logger.info( + f"IAM: generated active signing key kid={kid} " + f"(no existing key found)" + ) + return self._signing_key + + # ------------------------------------------------------------------ + # login + # ------------------------------------------------------------------ + + async def handle_login(self, v): + if not v.username: + return _err("auth-failed", "username required") + if not v.password: + return _err("auth-failed", "password required") + + # Login accepts an optional workspace parameter. If omitted + # we use the default workspace (OSS single-workspace + # assumption). Multi-workspace enterprise editions swap in a + # resolver that looks across the caller's permitted set. 
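The credential check in `handle_login` below rests on the PBKDF2 helpers defined earlier. A self-contained roundtrip showing the encoded format they agree on; the iteration count is lowered here to keep the sketch fast (the service itself uses 600,000):

```python
import base64
import hashlib
import os
import secrets

ITERATIONS = 1_000  # sketch only; PBKDF2_ITERATIONS is 600_000

def hash_password(password):
    # Encode salt and derived key alongside the parameters so old
    # rows remain verifiable after a cost change.
    salt = os.urandom(16)
    dk = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, ITERATIONS)
    return (
        f"pbkdf2-sha256${ITERATIONS}"
        f"${base64.b64encode(salt).decode()}"
        f"${base64.b64encode(dk).decode()}"
    )

def verify_password(password, encoded):
    algo, iters, b64_salt, b64_hash = encoded.split("$")
    if algo != "pbkdf2-sha256":
        return False
    dk = hashlib.pbkdf2_hmac(
        "sha256", password.encode(),
        base64.b64decode(b64_salt), int(iters),
    )
    # Constant-time compare, as in _verify_password.
    return secrets.compare_digest(dk, base64.b64decode(b64_hash))

enc = hash_password("hunter2")
assert verify_password("hunter2", enc)
assert not verify_password("wrong", enc)
```

Because the parameters ride along in the encoded string, raising the iteration count later only affects newly written rows.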
+ workspace = v.workspace or DEFAULT_WORKSPACE + + user_id = await self.table_store.get_user_id_by_username( + workspace, v.username, + ) + if not user_id: + return _err("auth-failed", "no such user") + + user_row = await self.table_store.get_user(user_id) + if user_row is None: + return _err("auth-failed", "user disappeared") + + ( + id, ws, _username, _name, _email, password_hash, + roles, enabled, _mcp, _created, + ) = user_row + + if not enabled: + return _err("auth-failed", "user disabled") + if not password_hash or not _verify_password( + v.password, password_hash, + ): + return _err("auth-failed", "bad credentials") + + ws_row = await self.table_store.get_workspace(ws) + if ws_row is None or not ws_row[2]: + return _err("auth-failed", "workspace disabled") + + kid, private_pem, _ = await self._get_active_signing_key() + + now_ts = int(_now_dt().timestamp()) + exp_ts = now_ts + JWT_TTL_SECONDS + claims = { + "iss": JWT_ISSUER, + "sub": id, + "workspace": ws, + "roles": sorted(roles) if roles else [], + "iat": now_ts, + "exp": exp_ts, + } + token = _sign_jwt(kid, private_pem, claims) + + expires_iso = datetime.datetime.fromtimestamp( + exp_ts, tz=datetime.timezone.utc, + ).isoformat() + + return IamResponse(jwt=token, jwt_expires=expires_iso) + + # ------------------------------------------------------------------ + # get-signing-key-public + # ------------------------------------------------------------------ + + async def handle_get_signing_key_public(self, v): + _, _, public_pem = await self._get_active_signing_key() + return IamResponse(signing_key_public=public_pem) + + # ------------------------------------------------------------------ + # Record-conversion helper for workspaces + # ------------------------------------------------------------------ + + def _row_to_workspace_record(self, row): + id, name, enabled, created = row + return WorkspaceRecord( + id=id or "", + name=name or "", + enabled=bool(enabled), + created=_iso(created), + ) + + async def 
_user_in_workspace(self, user_id, workspace): + """Return (user_row, error_response_or_None). Loads the user + record, verifies it exists, is enabled, and belongs to + ``workspace``. The workspace scope check rejects cross- + workspace admin attempts.""" + user_row = await self.table_store.get_user(user_id) + if user_row is None: + return None, _err("not-found", "user not found") + if user_row[1] != workspace: + return None, _err( + "operation-not-permitted", + "user is in a different workspace", + ) + return user_row, None + + # ------------------------------------------------------------------ + # change-password + # ------------------------------------------------------------------ + + async def handle_change_password(self, v): + if not v.user_id: + return _err("invalid-argument", "user_id required") + if not v.password: + return _err("invalid-argument", "password (current) required") + if not v.new_password: + return _err("invalid-argument", "new_password required") + + user_row = await self.table_store.get_user(v.user_id) + if user_row is None: + return _err("auth-failed", "no such user") + + _id, _ws, _un, _name, _email, password_hash, _r, enabled, _mcp, _c = ( + user_row + ) + if not enabled: + return _err("auth-failed", "user disabled") + if not password_hash or not _verify_password( + v.password, password_hash, + ): + return _err("auth-failed", "bad credentials") + + await self.table_store.update_user_password( + id=v.user_id, + password_hash=_hash_password(v.new_password), + must_change_password=False, + ) + return IamResponse() + + # ------------------------------------------------------------------ + # reset-password + # ------------------------------------------------------------------ + + async def handle_reset_password(self, v): + if not v.workspace: + return _err( + "invalid-argument", + "workspace required for reset-password", + ) + if not v.user_id: + return _err("invalid-argument", "user_id required") + + _, err = await 
self._user_in_workspace(v.user_id, v.workspace) + if err is not None: + return err + + temporary = secrets.token_urlsafe(12) + await self.table_store.update_user_password( + id=v.user_id, + password_hash=_hash_password(temporary), + must_change_password=True, + ) + return IamResponse(temporary_password=temporary) + + # ------------------------------------------------------------------ + # get-user / update-user / disable-user + # ------------------------------------------------------------------ + + async def handle_get_user(self, v): + if not v.workspace: + return _err("invalid-argument", "workspace required") + if not v.user_id: + return _err("invalid-argument", "user_id required") + + user_row, err = await self._user_in_workspace( + v.user_id, v.workspace, + ) + if err is not None: + return err + return IamResponse(user=self._row_to_user_record(user_row)) + + async def handle_update_user(self, v): + """Update user profile fields: name, email, roles, enabled, + must_change_password. Username is immutable — change it by + creating a new user and disabling the old one. Password + changes go through change-password / reset-password.""" + if not v.workspace: + return _err("invalid-argument", "workspace required") + if not v.user_id: + return _err("invalid-argument", "user_id required") + if v.user is None: + return _err("invalid-argument", "user field required") + if v.user.password: + return _err( + "invalid-argument", + "password cannot be changed via update-user; " + "use change-password or reset-password", + ) + if v.user.username and v.user.username != "": + # Compare to existing. Username-change not allowed. 
+ existing, err = await self._user_in_workspace( + v.user_id, v.workspace, + ) + if err is not None: + return err + if v.user.username != existing[2]: + return _err( + "invalid-argument", + "username is immutable; create a new user " + "instead", + ) + else: + existing, err = await self._user_in_workspace( + v.user_id, v.workspace, + ) + if err is not None: + return err + + # Carry forward fields the caller didn't provide. + ( + _id, _ws, _username, cur_name, cur_email, _pw, + cur_roles, cur_enabled, cur_mcp, _created, + ) = existing + + new_name = v.user.name if v.user.name else cur_name + new_email = v.user.email if v.user.email else cur_email + new_roles = list(v.user.roles) if v.user.roles else list( + cur_roles or [], + ) + new_enabled = v.user.enabled if v.user.enabled is not None else ( + cur_enabled + ) + new_mcp = ( + v.user.must_change_password + if v.user.must_change_password is not None + else cur_mcp + ) + + await self.table_store.update_user_profile( + id=v.user_id, + name=new_name, + email=new_email, + roles=new_roles, + enabled=new_enabled, + must_change_password=new_mcp, + ) + + updated = await self.table_store.get_user(v.user_id) + return IamResponse(user=self._row_to_user_record(updated)) + + async def handle_disable_user(self, v): + """Soft-delete: set enabled=false and revoke every API key + belonging to the user.""" + if not v.workspace: + return _err("invalid-argument", "workspace required") + if not v.user_id: + return _err("invalid-argument", "user_id required") + + _, err = await self._user_in_workspace(v.user_id, v.workspace) + if err is not None: + return err + + await self.table_store.update_user_enabled( + id=v.user_id, enabled=False, + ) + + # Revoke all their API keys. + key_rows = await self.table_store.list_api_keys_by_user(v.user_id) + for kr in key_rows: + await self.table_store.delete_api_key(kr[0]) + + return IamResponse() + + async def handle_enable_user(self, v): + """Re-enable a previously disabled user. 
Does not restore + API keys — those have to be re-issued by the admin.""" + if not v.workspace: + return _err("invalid-argument", "workspace required") + if not v.user_id: + return _err("invalid-argument", "user_id required") + + _, err = await self._user_in_workspace(v.user_id, v.workspace) + if err is not None: + return err + + await self.table_store.update_user_enabled( + id=v.user_id, enabled=True, + ) + return IamResponse() + + async def handle_delete_user(self, v): + """Hard-delete a user. Removes the ``iam_users`` row, the + ``iam_users_by_username`` lookup row, and every API key + belonging to the user. + + Unlike disable, this frees the username for re-use and + removes the user's personal data from storage (intended to + cover GDPR erasure-style requirements). When audit logging + lands, the decision to delete vs. anonymise referenced audit + rows will need to be revisited.""" + if not v.workspace: + return _err("invalid-argument", "workspace required") + if not v.user_id: + return _err("invalid-argument", "user_id required") + + user_row, err = await self._user_in_workspace( + v.user_id, v.workspace, + ) + if err is not None: + return err + + # user_row indices match get_user columns. Username is [2]. + username = user_row[2] + + # Revoke all API keys. + key_rows = await self.table_store.list_api_keys_by_user(v.user_id) + for kr in key_rows: + await self.table_store.delete_api_key(kr[0]) + + # Remove username lookup. + if username: + await self.table_store.delete_username_lookup( + v.workspace, username, + ) + + # Remove user record. 
+ await self.table_store.delete_user(v.user_id) + + return IamResponse() + + # ------------------------------------------------------------------ + # Workspace CRUD + # ------------------------------------------------------------------ + + async def handle_create_workspace(self, v): + if v.workspace_record is None or not v.workspace_record.id: + return _err( + "invalid-argument", + "workspace_record.id required for create-workspace", + ) + if v.workspace_record.id.startswith("_"): + return _err( + "invalid-argument", + "workspace ids beginning with '_' are reserved", + ) + + existing = await self.table_store.get_workspace( + v.workspace_record.id, + ) + if existing is not None: + return _err("duplicate", "workspace already exists") + + now = _now_dt() + await self.table_store.put_workspace( + id=v.workspace_record.id, + name=v.workspace_record.name or v.workspace_record.id, + enabled=v.workspace_record.enabled, + created=now, + ) + row = await self.table_store.get_workspace(v.workspace_record.id) + return IamResponse(workspace=self._row_to_workspace_record(row)) + + async def handle_list_workspaces(self, v): + rows = await self.table_store.list_workspaces() + return IamResponse( + workspaces=[ + self._row_to_workspace_record(r) for r in rows + ], + ) + + async def handle_get_workspace(self, v): + if v.workspace_record is None or not v.workspace_record.id: + return _err("invalid-argument", "workspace_record.id required") + row = await self.table_store.get_workspace(v.workspace_record.id) + if row is None: + return _err("not-found", "workspace not found") + return IamResponse(workspace=self._row_to_workspace_record(row)) + + async def handle_update_workspace(self, v): + """Update workspace name / enabled. 
The id is immutable.""" + if v.workspace_record is None or not v.workspace_record.id: + return _err("invalid-argument", "workspace_record.id required") + row = await self.table_store.get_workspace(v.workspace_record.id) + if row is None: + return _err("not-found", "workspace not found") + + _, cur_name, cur_enabled, _created = row + new_name = ( + v.workspace_record.name + if v.workspace_record.name else cur_name + ) + new_enabled = ( + v.workspace_record.enabled + if v.workspace_record.enabled is not None + else cur_enabled + ) + + await self.table_store.update_workspace( + id=v.workspace_record.id, + name=new_name, + enabled=new_enabled, + ) + updated = await self.table_store.get_workspace( + v.workspace_record.id, + ) + return IamResponse( + workspace=self._row_to_workspace_record(updated), + ) + + async def handle_disable_workspace(self, v): + """Set enabled=false, disable every user in the workspace, + revoke every API key belonging to those users.""" + if v.workspace_record is None or not v.workspace_record.id: + return _err("invalid-argument", "workspace_record.id required") + + row = await self.table_store.get_workspace(v.workspace_record.id) + if row is None: + return _err("not-found", "workspace not found") + + await self.table_store.update_workspace( + id=v.workspace_record.id, + name=row[1] or v.workspace_record.id, + enabled=False, + ) + + user_rows = await self.table_store.list_users_by_workspace( + v.workspace_record.id, + ) + for ur in user_rows: + user_id = ur[0] + await self.table_store.update_user_enabled( + id=user_id, enabled=False, + ) + key_rows = await self.table_store.list_api_keys_by_user(user_id) + for kr in key_rows: + await self.table_store.delete_api_key(kr[0]) + + return IamResponse() + + # ------------------------------------------------------------------ + # rotate-signing-key + # ------------------------------------------------------------------ + + async def handle_rotate_signing_key(self, v): + """Create a new Ed25519 signing 
key, retire the current + active key, switch the in-memory cache over. + + The retired key row is kept in ``iam_signing_keys`` so the + gateway's JWT validator can continue to validate previously- + issued tokens during the grace period. Actual grace-period + enforcement (time-window acceptance at the validator) lands + with the gateway auth middleware work.""" + + # Retire the currently-active key, if any. + current = await self._get_active_signing_key() + now = _now_dt() + if current is not None: + cur_kid, _cur_priv, _cur_pub = current + await self.table_store.retire_signing_key( + kid=cur_kid, retired=now, + ) + + new_kid, new_priv, new_pub = _generate_signing_keypair() + await self.table_store.put_signing_key( + kid=new_kid, + private_pem=new_priv, + public_pem=new_pub, + created=now, + retired=None, + ) + self._signing_key = (new_kid, new_priv, new_pub) + logger.info( + f"IAM: rotated signing key. " + f"New kid={new_kid}, retired kid={(current or (None,))[0]}" + ) + return IamResponse() + + # ------------------------------------------------------------------ + # resolve-api-key + # ------------------------------------------------------------------ + + async def handle_resolve_api_key(self, v): + if not v.api_key: + return _err("auth-failed", "no api key") + + row = await self.table_store.get_api_key_by_hash( + _hash_api_key(v.api_key), + ) + if row is None: + return _err("auth-failed", "unknown api key") + + ( + _key_hash, _id, user_id, _name, _prefix, expires, + _created, _last_used, + ) = row + + if expires is not None: + exp_dt = expires + if isinstance(exp_dt, str): + exp_dt = datetime.datetime.fromisoformat(exp_dt) + if exp_dt.tzinfo is None: + exp_dt = exp_dt.replace(tzinfo=datetime.timezone.utc) + if exp_dt < _now_dt(): + return _err("auth-failed", "api key expired") + + user_row = await self.table_store.get_user(user_id) + if user_row is None: + return _err("auth-failed", "owning user missing") + user = self._row_to_user_record(user_row) + if not 
user.enabled: + return _err("auth-failed", "owning user disabled") + + # Workspace-disabled check. + ws_row = await self.table_store.get_workspace(user.workspace) + if ws_row is None or not ws_row[2]: + return _err("auth-failed", "owning workspace disabled") + + return IamResponse( + resolved_user_id=user.id, + resolved_workspace=user.workspace, + resolved_roles=list(user.roles), + ) + + # ------------------------------------------------------------------ + # create-user + # ------------------------------------------------------------------ + + async def handle_create_user(self, v): + if not v.workspace: + return _err( + "invalid-argument", "workspace required for create-user", + ) + if v.user is None: + return _err( + "invalid-argument", "user field required for create-user", + ) + if not v.user.username: + return _err("invalid-argument", "user.username required") + if not v.user.password: + return _err("invalid-argument", "user.password required") + + # Workspace must exist and be enabled. + ws = await self.table_store.get_workspace(v.workspace) + if ws is None or not ws[2]: + return _err("not-found", "workspace not found or disabled") + + # Uniqueness on username within workspace. 
+ existing = await self.table_store.get_user_id_by_username( + v.workspace, v.user.username, + ) + if existing: + return _err("duplicate", "username already exists") + + user_id = str(uuid.uuid4()) + now = _now_dt() + + await self.table_store.put_user( + id=user_id, + workspace=v.workspace, + username=v.user.username, + name=v.user.name or v.user.username, + email=v.user.email or "", + password_hash=_hash_password(v.user.password), + roles=list(v.user.roles or []), + enabled=v.user.enabled, + must_change_password=v.user.must_change_password, + created=now, + ) + + row = await self.table_store.get_user(user_id) + return IamResponse(user=self._row_to_user_record(row)) + + # ------------------------------------------------------------------ + # list-users + # ------------------------------------------------------------------ + + async def handle_list_users(self, v): + if not v.workspace: + return _err( + "invalid-argument", "workspace required for list-users", + ) + + rows = await self.table_store.list_users_by_workspace(v.workspace) + return IamResponse( + users=[self._row_to_user_record(r) for r in rows], + ) + + # ------------------------------------------------------------------ + # create-api-key + # ------------------------------------------------------------------ + + async def handle_create_api_key(self, v): + if not v.workspace: + return _err( + "invalid-argument", "workspace required for create-api-key", + ) + if v.key is None or not v.key.user_id: + return _err("invalid-argument", "key.user_id required") + if not v.key.name: + return _err("invalid-argument", "key.name required") + + # Target user must exist and belong to the caller's workspace. 
+ user_row = await self.table_store.get_user(v.key.user_id) + if user_row is None: + return _err("not-found", "user not found") + if user_row[1] != v.workspace: + return _err( + "operation-not-permitted", + "target user is in a different workspace", + ) + + plaintext = _generate_api_key() + key_id = str(uuid.uuid4()) + now = _now_dt() + expires_dt = _parse_expires(v.key.expires) + + await self.table_store.put_api_key( + key_hash=_hash_api_key(plaintext), + id=key_id, + user_id=v.key.user_id, + name=v.key.name, + prefix=plaintext[:len(API_KEY_PREFIX) + 4], + expires=expires_dt, + created=now, + last_used=None, + ) + + row = await self.table_store.get_api_key_by_hash( + _hash_api_key(plaintext), + ) + return IamResponse( + api_key_plaintext=plaintext, + api_key=self._row_to_api_key_record(row), + ) + + # ------------------------------------------------------------------ + # list-api-keys + # ------------------------------------------------------------------ + + async def handle_list_api_keys(self, v): + if not v.workspace: + return _err( + "invalid-argument", + "workspace required for list-api-keys", + ) + if not v.user_id: + return _err( + "invalid-argument", "user_id required for list-api-keys", + ) + + # Workspace-scope check: user must live in this workspace. 
+ user_row = await self.table_store.get_user(v.user_id) + if user_row is None or user_row[1] != v.workspace: + return _err("not-found", "user not found in workspace") + + rows = await self.table_store.list_api_keys_by_user(v.user_id) + return IamResponse( + api_keys=[self._row_to_api_key_record(r) for r in rows], + ) + + # ------------------------------------------------------------------ + # revoke-api-key + # ------------------------------------------------------------------ + + async def handle_revoke_api_key(self, v): + if not v.workspace: + return _err( + "invalid-argument", + "workspace required for revoke-api-key", + ) + if not v.key_id: + return _err("invalid-argument", "key_id required") + + row = await self.table_store.get_api_key_by_id(v.key_id) + if row is None: + return _err("not-found", "api key not found") + + key_hash, _id, user_id, _name, _prefix, _expires, _c, _lu = row + # Workspace-scope check via the owning user. + user_row = await self.table_store.get_user(user_id) + if user_row is None or user_row[1] != v.workspace: + return _err( + "operation-not-permitted", + "key belongs to a different workspace", + ) + + await self.table_store.delete_api_key(key_hash) + return IamResponse() diff --git a/trustgraph-flow/trustgraph/iam/service/service.py b/trustgraph-flow/trustgraph/iam/service/service.py new file mode 100644 index 00000000..8ea31cf0 --- /dev/null +++ b/trustgraph-flow/trustgraph/iam/service/service.py @@ -0,0 +1,210 @@ +""" +IAM service processor. Terminates the IAM request queue and forwards +each request to the IamService business logic, then returns the +response on the IAM response queue. + +Shape mirrors trustgraph.config.service. 
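+
+Example launch (a sketch: the console-script name ``iam-svc`` is taken
+from ``default_ident`` below, and the flags from ``add_args`` /
+``add_cassandra_args`` — verify both against your deployment; the key
+value is a placeholder)::
+
+    iam-svc \
+        --bootstrap-mode token \
+        --bootstrap-token "$INITIAL_ADMIN_KEY" \
+        --cassandra-host cassandra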
+""" + +import logging + +from trustgraph.schema import Error +from trustgraph.schema import IamRequest, IamResponse +from trustgraph.schema import iam_request_queue, iam_response_queue + +from trustgraph.base import AsyncProcessor, Consumer, Producer +from trustgraph.base import ConsumerMetrics, ProducerMetrics +from trustgraph.base.cassandra_config import ( + add_cassandra_args, resolve_cassandra_config, +) + +from . iam import IamService + +logger = logging.getLogger(__name__) + +default_ident = "iam-svc" + +default_iam_request_queue = iam_request_queue +default_iam_response_queue = iam_response_queue + + +class Processor(AsyncProcessor): + + def __init__(self, **params): + + iam_req_q = params.get( + "iam_request_queue", default_iam_request_queue, + ) + iam_resp_q = params.get( + "iam_response_queue", default_iam_response_queue, + ) + + bootstrap_mode = params.get("bootstrap_mode") + bootstrap_token = params.get("bootstrap_token") + + if bootstrap_mode not in ("token", "bootstrap"): + raise RuntimeError( + "iam-svc: --bootstrap-mode is required. Set to 'token' " + "(with --bootstrap-token) for production, or 'bootstrap' " + "to enable the explicit bootstrap operation over the " + "pub/sub bus (dev / quick-start only, not safe under " + "public exposure). Refusing to start." + ) + if bootstrap_mode == "token" and not bootstrap_token: + raise RuntimeError( + "iam-svc: --bootstrap-mode=token requires " + "--bootstrap-token. Refusing to start." + ) + if bootstrap_mode == "bootstrap" and bootstrap_token: + raise RuntimeError( + "iam-svc: --bootstrap-token is not accepted when " + "--bootstrap-mode=bootstrap. Ambiguous intent. " + "Refusing to start." 
+ ) + + self.bootstrap_mode = bootstrap_mode + self.bootstrap_token = bootstrap_token + + cassandra_host = params.get("cassandra_host") + cassandra_username = params.get("cassandra_username") + cassandra_password = params.get("cassandra_password") + + hosts, username, password, keyspace = resolve_cassandra_config( + host=cassandra_host, + username=cassandra_username, + password=cassandra_password, + default_keyspace="iam", + ) + + self.cassandra_host = hosts + self.cassandra_username = username + self.cassandra_password = password + + super().__init__( + **params | { + "iam_request_schema": IamRequest.__name__, + "iam_response_schema": IamResponse.__name__, + "cassandra_host": self.cassandra_host, + "cassandra_username": self.cassandra_username, + "cassandra_password": self.cassandra_password, + } + ) + + iam_request_metrics = ConsumerMetrics( + processor=self.id, flow=None, name="iam-request", + ) + iam_response_metrics = ProducerMetrics( + processor=self.id, flow=None, name="iam-response", + ) + + self.iam_request_topic = iam_req_q + + self.iam_request_consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=iam_req_q, + subscriber=self.id, + schema=IamRequest, + handler=self.on_iam_request, + metrics=iam_request_metrics, + ) + + self.iam_response_producer = Producer( + backend=self.pubsub, + topic=iam_resp_q, + schema=IamResponse, + metrics=iam_response_metrics, + ) + + self.iam = IamService( + host=self.cassandra_host, + username=self.cassandra_username, + password=self.cassandra_password, + keyspace=keyspace, + bootstrap_mode=self.bootstrap_mode, + bootstrap_token=self.bootstrap_token, + ) + + logger.info( + f"IAM service initialised (bootstrap-mode={self.bootstrap_mode})" + ) + + async def start(self): + await self.pubsub.ensure_topic(self.iam_request_topic) + # Token-mode auto-bootstrap runs before we accept requests so + # the first inbound call always sees a populated table. 
+ await self.iam.auto_bootstrap_if_token_mode() + await self.iam_request_consumer.start() + + async def on_iam_request(self, msg, consumer, flow): + + id = None + try: + v = msg.value() + id = msg.properties()["id"] + logger.debug( + f"Handling IAM request {id} op={v.operation!r}" + ) + resp = await self.iam.handle(v) + await self.iam_response_producer.send( + resp, properties={"id": id}, + ) + except Exception as e: + logger.error( + f"IAM request failed: {type(e).__name__}: {e}", + exc_info=True, + ) + resp = IamResponse( + error=Error(type="internal-error", message=str(e)), + ) + if id is not None: + await self.iam_response_producer.send( + resp, properties={"id": id}, + ) + + @staticmethod + def add_args(parser): + AsyncProcessor.add_args(parser) + + parser.add_argument( + "--iam-request-queue", + default=default_iam_request_queue, + help=f"IAM request queue (default: {default_iam_request_queue})", + ) + parser.add_argument( + "--iam-response-queue", + default=default_iam_response_queue, + help=f"IAM response queue (default: {default_iam_response_queue})", + ) + parser.add_argument( + "--bootstrap-mode", + default=None, + choices=["token", "bootstrap"], + help=( + "IAM bootstrap mode (required). " + "'token' = operator supplies the initial admin API " + "key via --bootstrap-token; auto-seeds on first start, " + "bootstrap operation refused. " + "'bootstrap' = bootstrap operation is live over the " + "bus until tables are populated; a token is generated " + "and returned by tg-bootstrap-iam. Unsafe to run " + "'bootstrap' mode with public exposure." + ), + ) + parser.add_argument( + "--bootstrap-token", + default=None, + help=( + "Initial admin API key plaintext, required when " + "--bootstrap-mode=token. Treat as a one-time " + "credential: the operator should rotate to a new key " + "and revoke this one after first use." 
+ ), + ) + + add_cassandra_args(parser) + + +def run(): + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/tables/iam.py b/trustgraph-flow/trustgraph/tables/iam.py new file mode 100644 index 00000000..3d41ebbd --- /dev/null +++ b/trustgraph-flow/trustgraph/tables/iam.py @@ -0,0 +1,422 @@ +""" +IAM Cassandra table store. + +Tables: + - iam_workspaces (id primary key) + - iam_users (id primary key) + iam_users_by_username lookup table + (workspace, username) -> id + - iam_api_keys (key_hash primary key) with secondary index on user_id + - iam_signing_keys (kid primary key) — RSA keypairs for JWT signing + +See docs/tech-specs/iam-protocol.md for the wire-level context. +""" + +import logging + +from cassandra.cluster import Cluster +from cassandra.auth import PlainTextAuthProvider +from ssl import SSLContext, PROTOCOL_TLSv1_2 + +from . cassandra_async import async_execute + +logger = logging.getLogger(__name__) + + +class IamTableStore: + + def __init__( + self, + cassandra_host, cassandra_username, cassandra_password, + keyspace, + ): + self.keyspace = keyspace + + logger.info("IAM: connecting to Cassandra...") + + if isinstance(cassandra_host, str): + cassandra_host = [h.strip() for h in cassandra_host.split(",")] + + if cassandra_username and cassandra_password: + ssl_context = SSLContext(PROTOCOL_TLSv1_2) + auth_provider = PlainTextAuthProvider( + username=cassandra_username, password=cassandra_password, + ) + self.cluster = Cluster( + cassandra_host, + auth_provider=auth_provider, + ssl_context=ssl_context, + ) + else: + self.cluster = Cluster(cassandra_host) + + self.cassandra = self.cluster.connect() + + logger.info("IAM: connected.") + + self._ensure_schema() + self._prepare_statements() + + def _ensure_schema(self): + # FIXME: Replication factor should be configurable. 
+        self.cassandra.execute(f"""
+            create keyspace if not exists {self.keyspace}
+            with replication = {{
+                'class' : 'SimpleStrategy',
+                'replication_factor' : 1
+            }};
+        """)
+        self.cassandra.set_keyspace(self.keyspace)
+
+        self.cassandra.execute("""
+            CREATE TABLE IF NOT EXISTS iam_workspaces (
+                id text PRIMARY KEY,
+                name text,
+                enabled boolean,
+                created timestamp
+            );
+        """)
+
+        self.cassandra.execute("""
+            CREATE TABLE IF NOT EXISTS iam_users (
+                id text PRIMARY KEY,
+                workspace text,
+                username text,
+                name text,
+                email text,
+                password_hash text,
+                roles set<text>,
+                enabled boolean,
+                must_change_password boolean,
+                created timestamp
+            );
+        """)
+
+        self.cassandra.execute("""
+            CREATE TABLE IF NOT EXISTS iam_users_by_username (
+                workspace text,
+                username text,
+                user_id text,
+                PRIMARY KEY ((workspace), username)
+            );
+        """)
+
+        self.cassandra.execute("""
+            CREATE TABLE IF NOT EXISTS iam_api_keys (
+                key_hash text PRIMARY KEY,
+                id text,
+                user_id text,
+                name text,
+                prefix text,
+                expires timestamp,
+                created timestamp,
+                last_used timestamp
+            );
+        """)
+
+        self.cassandra.execute("""
+            CREATE INDEX IF NOT EXISTS iam_api_keys_user_id_idx
+            ON iam_api_keys (user_id);
+        """)
+
+        self.cassandra.execute("""
+            CREATE INDEX IF NOT EXISTS iam_api_keys_id_idx
+            ON iam_api_keys (id);
+        """)
+
+        self.cassandra.execute("""
+            CREATE TABLE IF NOT EXISTS iam_signing_keys (
+                kid text PRIMARY KEY,
+                private_pem text,
+                public_pem text,
+                created timestamp,
+                retired timestamp
+            );
+        """)
+
+        logger.info("IAM: Cassandra schema OK.")
+
+    def _prepare_statements(self):
+        c = self.cassandra
+
+        self.put_workspace_stmt = c.prepare("""
+            INSERT INTO iam_workspaces (id, name, enabled, created)
+            VALUES (?, ?, ?, ?)
+        """)
+        self.get_workspace_stmt = c.prepare("""
+            SELECT id, name, enabled, created FROM iam_workspaces
+            WHERE id = ?
+ """) + self.list_workspaces_stmt = c.prepare(""" + SELECT id, name, enabled, created FROM iam_workspaces + """) + + self.put_user_stmt = c.prepare(""" + INSERT INTO iam_users ( + id, workspace, username, name, email, password_hash, + roles, enabled, must_change_password, created + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """) + self.get_user_stmt = c.prepare(""" + SELECT id, workspace, username, name, email, password_hash, + roles, enabled, must_change_password, created + FROM iam_users WHERE id = ? + """) + self.list_users_by_workspace_stmt = c.prepare(""" + SELECT id, workspace, username, name, email, password_hash, + roles, enabled, must_change_password, created + FROM iam_users WHERE workspace = ? ALLOW FILTERING + """) + + self.put_username_lookup_stmt = c.prepare(""" + INSERT INTO iam_users_by_username (workspace, username, user_id) + VALUES (?, ?, ?) + """) + self.get_user_id_by_username_stmt = c.prepare(""" + SELECT user_id FROM iam_users_by_username + WHERE workspace = ? AND username = ? + """) + self.delete_username_lookup_stmt = c.prepare(""" + DELETE FROM iam_users_by_username + WHERE workspace = ? AND username = ? + """) + self.delete_user_stmt = c.prepare(""" + DELETE FROM iam_users WHERE id = ? + """) + + self.put_api_key_stmt = c.prepare(""" + INSERT INTO iam_api_keys ( + key_hash, id, user_id, name, prefix, expires, + created, last_used + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """) + self.get_api_key_by_hash_stmt = c.prepare(""" + SELECT key_hash, id, user_id, name, prefix, expires, + created, last_used + FROM iam_api_keys WHERE key_hash = ? + """) + self.get_api_key_by_id_stmt = c.prepare(""" + SELECT key_hash, id, user_id, name, prefix, expires, + created, last_used + FROM iam_api_keys WHERE id = ? + """) + self.list_api_keys_by_user_stmt = c.prepare(""" + SELECT key_hash, id, user_id, name, prefix, expires, + created, last_used + FROM iam_api_keys WHERE user_id = ? 
+ """) + self.delete_api_key_stmt = c.prepare(""" + DELETE FROM iam_api_keys WHERE key_hash = ? + """) + + self.put_signing_key_stmt = c.prepare(""" + INSERT INTO iam_signing_keys ( + kid, private_pem, public_pem, created, retired + ) + VALUES (?, ?, ?, ?, ?) + """) + self.list_signing_keys_stmt = c.prepare(""" + SELECT kid, private_pem, public_pem, created, retired + FROM iam_signing_keys + """) + self.retire_signing_key_stmt = c.prepare(""" + UPDATE iam_signing_keys SET retired = ? WHERE kid = ? + """) + + self.update_user_profile_stmt = c.prepare(""" + UPDATE iam_users + SET name = ?, email = ?, roles = ?, enabled = ?, + must_change_password = ? + WHERE id = ? + """) + self.update_user_password_stmt = c.prepare(""" + UPDATE iam_users + SET password_hash = ?, must_change_password = ? + WHERE id = ? + """) + self.update_user_enabled_stmt = c.prepare(""" + UPDATE iam_users SET enabled = ? WHERE id = ? + """) + + self.update_workspace_stmt = c.prepare(""" + UPDATE iam_workspaces SET name = ?, enabled = ? + WHERE id = ? 
+ """) + + # ------------------------------------------------------------------ + # Workspaces + # ------------------------------------------------------------------ + + async def put_workspace(self, id, name, enabled, created): + await async_execute( + self.cassandra, self.put_workspace_stmt, + (id, name, enabled, created), + ) + + async def get_workspace(self, id): + rows = await async_execute( + self.cassandra, self.get_workspace_stmt, (id,), + ) + return rows[0] if rows else None + + async def list_workspaces(self): + return await async_execute( + self.cassandra, self.list_workspaces_stmt, + ) + + # ------------------------------------------------------------------ + # Users + # ------------------------------------------------------------------ + + async def put_user( + self, id, workspace, username, name, email, password_hash, + roles, enabled, must_change_password, created, + ): + await async_execute( + self.cassandra, self.put_user_stmt, + ( + id, workspace, username, name, email, password_hash, + set(roles) if roles else set(), + enabled, must_change_password, created, + ), + ) + await async_execute( + self.cassandra, self.put_username_lookup_stmt, + (workspace, username, id), + ) + + async def get_user(self, id): + rows = await async_execute( + self.cassandra, self.get_user_stmt, (id,), + ) + return rows[0] if rows else None + + async def get_user_id_by_username(self, workspace, username): + rows = await async_execute( + self.cassandra, self.get_user_id_by_username_stmt, + (workspace, username), + ) + return rows[0][0] if rows else None + + async def list_users_by_workspace(self, workspace): + return await async_execute( + self.cassandra, self.list_users_by_workspace_stmt, (workspace,), + ) + + async def delete_user(self, id): + await async_execute( + self.cassandra, self.delete_user_stmt, (id,), + ) + + async def delete_username_lookup(self, workspace, username): + await async_execute( + self.cassandra, self.delete_username_lookup_stmt, + (workspace, 
username), + ) + + # ------------------------------------------------------------------ + # API keys + # ------------------------------------------------------------------ + + async def put_api_key( + self, key_hash, id, user_id, name, prefix, expires, + created, last_used, + ): + await async_execute( + self.cassandra, self.put_api_key_stmt, + (key_hash, id, user_id, name, prefix, expires, + created, last_used), + ) + + async def get_api_key_by_hash(self, key_hash): + rows = await async_execute( + self.cassandra, self.get_api_key_by_hash_stmt, (key_hash,), + ) + return rows[0] if rows else None + + async def get_api_key_by_id(self, id): + rows = await async_execute( + self.cassandra, self.get_api_key_by_id_stmt, (id,), + ) + return rows[0] if rows else None + + async def list_api_keys_by_user(self, user_id): + return await async_execute( + self.cassandra, self.list_api_keys_by_user_stmt, (user_id,), + ) + + async def delete_api_key(self, key_hash): + await async_execute( + self.cassandra, self.delete_api_key_stmt, (key_hash,), + ) + + # ------------------------------------------------------------------ + # Signing keys + # ------------------------------------------------------------------ + + async def put_signing_key(self, kid, private_pem, public_pem, + created, retired): + await async_execute( + self.cassandra, self.put_signing_key_stmt, + (kid, private_pem, public_pem, created, retired), + ) + + async def list_signing_keys(self): + return await async_execute( + self.cassandra, self.list_signing_keys_stmt, + ) + + async def retire_signing_key(self, kid, retired): + await async_execute( + self.cassandra, self.retire_signing_key_stmt, + (retired, kid), + ) + + # ------------------------------------------------------------------ + # User partial updates + # ------------------------------------------------------------------ + + async def update_user_profile( + self, id, name, email, roles, enabled, must_change_password, + ): + await async_execute( + self.cassandra, 
self.update_user_profile_stmt, + ( + name, email, + set(roles) if roles else set(), + enabled, must_change_password, id, + ), + ) + + async def update_user_password( + self, id, password_hash, must_change_password, + ): + await async_execute( + self.cassandra, self.update_user_password_stmt, + (password_hash, must_change_password, id), + ) + + async def update_user_enabled(self, id, enabled): + await async_execute( + self.cassandra, self.update_user_enabled_stmt, + (enabled, id), + ) + + # ------------------------------------------------------------------ + # Workspace updates + # ------------------------------------------------------------------ + + async def update_workspace(self, id, name, enabled): + await async_execute( + self.cassandra, self.update_workspace_stmt, + (name, enabled, id), + ) + + # ------------------------------------------------------------------ + # Bootstrap helpers + # ------------------------------------------------------------------ + + async def any_workspace_exists(self): + rows = await self.list_workspaces() + return bool(rows)
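+
+    # Usage sketch (illustrative only; `store`, `ws_id`, `admin_id`,
+    # `pw_hash` and `now` are placeholder names, not part of the store):
+    # first-boot bootstrap creates a workspace, then an admin user with
+    # must_change_password set, using only the methods defined above.
+    #
+    #     if not await store.any_workspace_exists():
+    #         await store.put_workspace(ws_id, "default", True, now)
+    #         await store.put_user(
+    #             admin_id, ws_id, "admin", "Administrator", None,
+    #             pw_hash, {"admin"}, True, True, now,
+    #         )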