From 9f2bfbce0cbdc34282049e95bfc64a03417c6622 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 4 May 2026 10:30:03 +0100 Subject: [PATCH] Per-workspace queue routing for workspace-scoped services (#862) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Workspace identity is now determined by queue infrastructure instead of message body fields, closing a privilege-escalation vector where a caller could spoof workspace in the request payload. - Add WorkspaceProcessor base class: discovers workspaces from config at startup, creates per-workspace consumers (queue:workspace), and manages consumer lifecycle on workspace create/delete events - Roll out to librarian, flow-svc, knowledge cores, and config-svc - Config service gets a dual-queue regime: a system queue for cross-workspace ops (getvalues-all-ws, bootstrapper writes to __workspaces__) and per-workspace queues for tenant-scoped ops, with workspace discovery from its own Cassandra store - Remove workspace field from request schemas (FlowRequest, LibrarianRequest, KnowledgeRequest, CollectionManagementRequest) and from DocumentMetadata / ProcessingMetadata — table stores now accept workspace as an explicit parameter - Strip workspace encode/decode from all message translators and gateway serializers - Gateway enforces workspace existence: reject requests targeting non-existent workspaces instead of routing to queues with no consumer - Config service provisions new workspaces from __template__ on creation - Add workspace lifecycle hooks to AsyncProcessor so any processor can react to workspace create/delete without subclassing WorkspaceProcessor --- docs/tech-specs/workspace-scoped-services.md | 366 ++++++++++++++++++ tests/unit/test_base/test_flow_processor.py | 2 +- .../unit/test_cores/test_knowledge_manager.py | 71 ++-- tests/unit/test_gateway/test_capabilities.py | 6 +- .../test_gateway/test_dispatch_manager.py | 64 ++- .../test_librarian/test_chunked_upload.py | 112 +++--- .../test_metadata_preservation.py | 15 +- .../test_knowledge_translator_roundtrip.py | 3 +- trustgraph-base/trustgraph/api/library.py | 8 - trustgraph-base/trustgraph/api/types.py | 4 - trustgraph-base/trustgraph/base/__init__.py | 1 + .../trustgraph/base/async_processor.py | 19 + .../trustgraph/base/flow_processor.py | 6 +- .../trustgraph/base/librarian_client.py | 2 - .../trustgraph/base/workspace_processor.py | 66 ++++ .../messaging/translators/collection.py | 3 - .../trustgraph/messaging/translators/flow.py | 3 - .../messaging/translators/knowledge.py | 3 - .../messaging/translators/library.py | 3 - .../messaging/translators/metadata.py | 6 - .../trustgraph/schema/core/metadata.py | 4 +- .../trustgraph/schema/knowledge/knowledge.py | 5 +- .../trustgraph/schema/services/collection.py | 10 +- .../trustgraph/schema/services/config.py | 9 + .../trustgraph/schema/services/flow.py | 3 - .../trustgraph/schema/services/library.py | 11 +- .../trustgraph/cli/show_flow_state.py | 12 +- .../trustgraph/cli/show_processor_state.py | 20 +- .../trustgraph/cli/show_token_rate.py | 26 +- trustgraph-flow/trustgraph/bootstrap/base.py | 2 +- .../bootstrap/bootstrapper/service.py | 43 +- .../initialisers/default_flow_start.py | 100 +++-- .../trustgraph/config/service/config.py | 164 +++++--- .../trustgraph/config/service/service.py | 191 ++++++++- trustgraph-flow/trustgraph/cores/knowledge.py | 37 +- trustgraph-flow/trustgraph/cores/service.py | 86 ++-- .../trustgraph/flow/service/flow.py | 62 ++- .../trustgraph/flow/service/service.py | 
88 +++--
 trustgraph-flow/trustgraph/gateway/auth.py | 6 +
 .../trustgraph/gateway/capabilities.py | 13 +
 .../trustgraph/gateway/config/receiver.py | 39 +-
 .../trustgraph/gateway/dispatch/manager.py | 56 ++-
 .../trustgraph/gateway/dispatch/mux.py | 12 +-
 .../trustgraph/gateway/dispatch/serialize.py | 8 -
 .../trustgraph/gateway/endpoint/manager.py | 13 +-
 .../gateway/endpoint/registry_endpoint.py | 13 +-
 trustgraph-flow/trustgraph/gateway/service.py | 2 +-
 trustgraph-flow/trustgraph/iam/service/iam.py | 19 +-
 .../trustgraph/iam/service/service.py | 87 +++++
 .../librarian/collection_manager.py | 49 +--
 .../trustgraph/librarian/librarian.py | 105 +++--
 .../trustgraph/librarian/service.py | 168 ++++----
 trustgraph-flow/trustgraph/tables/library.py | 16 +-
 53 files changed, 1565 insertions(+), 677 deletions(-)
 create mode 100644 docs/tech-specs/workspace-scoped-services.md
 create mode 100644 trustgraph-base/trustgraph/base/workspace_processor.py

diff --git a/docs/tech-specs/workspace-scoped-services.md b/docs/tech-specs/workspace-scoped-services.md
new file mode 100644
index 00000000..afc940af
--- /dev/null
+++ b/docs/tech-specs/workspace-scoped-services.md
@@ -0,0 +1,366 @@
+---
+layout: default
+title: "Workspace-Scoped Services"
+parent: "Tech Specs"
+---
+
+# Workspace-Scoped Services
+
+## Problem Statement
+
+Workspace-scoped services (librarian, config, knowledge, collection
+management) currently operate on global queues — a single
+`request:tg:librarian` queue handles requests for all workspaces.
+Workspace identity is carried as a field in the request body, set by
+the gateway after authentication. This creates several problems:
+
+- **No structural isolation.** All workspaces share a single queue.
+  Workspace scoping depends entirely on a body field being populated
+  correctly. If the field is missing or wrong, the service operates
+  on the wrong workspace — or fails with a confusing error. This is
+  a security concern: workspace isolation should be enforced by
+  infrastructure, not by trusting a field.
+
+- **Redundant workspace fields.** Nested objects within requests
+  (e.g. `processing-metadata`, `document-metadata`) carry their own
+  `workspace` fields alongside the top-level request workspace.
+  The gateway resolves the top-level workspace but does not
+  propagate it into nested payloads. Services that read workspace
+  from a nested object instead of the top-level address see `None`
+  and fail.
+
+- **No workspace lifecycle awareness.** Workspace-scoped services
+  have no mechanism to learn when workspaces are created or deleted.
+  Flow processors discover workspaces indirectly through config
+  entries, but workspace-scoped services on global queues have no
+  equivalent. There is no event when a workspace appears or
+  disappears.
+
+- **Inconsistency with flow-scoped services.** Flow-scoped services
+  already use per-workspace, per-flow queue names
+  (`request:tg:{workspace}:embeddings:{class}`). Workspace-scoped
+  services are the exception — they sit on global queues while
+  everything else is structurally isolated.
+
+## Design
+
+### Per-workspace queues for workspace-scoped services
+
+Workspace-scoped services move from global queues to per-workspace
+queues. 
The queue name includes the workspace identifier:
+
+**Current (global):**
+```
+request:tg:librarian
+request:tg:config
+```
+
+**Proposed (per-workspace):**
+```
+request:tg:librarian:{workspace}
+request:tg:config:{workspace}
+```
+
+The gateway routes requests to the correct queue based on the
+resolved workspace from authentication — the same workspace that
+today gets written into the request body. The workspace is now part
+of the queue address, not just a field in the payload.
+
+Services subscribe to per-workspace queues. When a new workspace is
+created, they subscribe to its queue. When a workspace is deleted,
+they unsubscribe.
+
+### Workspace lifecycle via the `__workspaces__` config namespace
+
+Workspace lifecycle events are modelled as config changes in a
+reserved `__workspaces__` namespace. This mirrors the existing
+`__system__` namespace — a reserved space for infrastructure
+concerns that don't belong to any user workspace.
+
+When IAM creates a workspace, it writes an entry to the config
+service:
+
+```
+workspace: __workspaces__
+type: workspace
+key: <workspace-id>
+value: {"enabled": true}
+```
+
+When IAM deletes (or disables) a workspace, it updates or deletes
+the entry. The config service sees this as a normal config change
+and pushes a notification through the existing `ConfigPush`
+mechanism.
+
+This avoids introducing a new notification channel. The config
+service already has the machinery to notify subscribers of changes
+by type and workspace. Workspace lifecycle is just another config
+type that services can register handlers for.
+
+### Config push changes
+
+#### Remove `_`-prefix suppression
+
+The config service currently suppresses notifications for workspaces
+whose names start with `_`. This suppression is removed — the
+config service pushes notifications for all workspaces
+unconditionally.
+
+The filtering moves to the consumer side. `AsyncProcessor` already
+filters `_`-prefixed workspaces in its config handler dispatch
+(lines 212 and 315 of `async_processor.py`). This filtering is
+retained as the default behaviour, but handlers can opt in to
+infrastructure namespaces by registering for them explicitly (see
+`WorkspaceProcessor` below).
+
+#### Workspace change events
+
+The `ConfigPush` message gains a `workspace_changes` field alongside
+the existing `changes` field:
+
+```python
+@dataclass
+class ConfigPush:
+    version: int = 0
+
+    # Config changes: type -> [affected workspaces]
+    changes: dict[str, list[str]] = field(default_factory=dict)
+
+    # Workspace lifecycle: created/deleted workspace lists
+    workspace_changes: WorkspaceChanges | None = None
+
+@dataclass
+class WorkspaceChanges:
+    created: list[str] = field(default_factory=list)
+    deleted: list[str] = field(default_factory=list)
+```
+
+The config service populates `workspace_changes` when it detects
+changes to the `__workspaces__` config namespace. A new key
+appearing is a creation; a key being deleted is a deletion.
+
+Services that don't care about workspace lifecycle ignore the field.
+Services that do (workspace-scoped services, the gateway) react by
+subscribing to or tearing down per-workspace queues. A worked
+example of the resulting push appears in the appendix.
+
+### The `WorkspaceProcessor` base class
+
+A new base class sits between `AsyncProcessor` and `FlowProcessor`
+in the processor hierarchy:
+
+```
+AsyncProcessor → WorkspaceProcessor → FlowProcessor
+```
+
+`WorkspaceProcessor` manages per-workspace queue lifecycle the same
+way `FlowProcessor` manages per-flow lifecycle. It:
+
+1. 
On startup, discovers existing workspaces by fetching config from + the `__workspaces__` namespace (using the existing + `_fetch_type_all_workspaces` pattern). + +2. For each workspace, subscribes to the service's per-workspace + queue (e.g. `request:tg:librarian:{workspace}`). + +3. Registers a config handler for the `workspace` type in the + `__workspaces__` namespace. When a workspace is created, it + subscribes to the new queue. When a workspace is deleted, it + unsubscribes and tears down. + +4. Exposes hooks for derived classes: + - `on_workspace_created(workspace)` — called after subscribing + to the new workspace's queue. + - `on_workspace_deleted(workspace)` — called before + unsubscribing from the workspace's queue. + +`FlowProcessor` extends `WorkspaceProcessor` instead of +`AsyncProcessor`. Flows exist within workspaces, so the hierarchy +is natural: workspace creation triggers queue subscription, then +flow config changes within that workspace trigger flow start/stop. + +Services that are workspace-scoped but not flow-scoped (librarian, +knowledge, collection management) extend `WorkspaceProcessor` +directly. + +### Gateway routing changes + +The gateway currently dispatches workspace-scoped requests to global +service dispatchers. This changes to per-workspace dispatchers that +route to per-workspace queues. + +For HTTP requests, the resolved workspace from the URL path +(`/api/v1/workspaces/{w}/library`) determines the target queue. + +For WebSocket requests via the Mux, the resolved workspace from +`enforce_workspace` determines the target queue. The Mux already +resolves workspace before dispatching (line 214 of `mux.py`); the +change is that `invoke_global_service` uses workspace to select the +queue, rather than routing to a single global queue. + +System-level services (IAM) remain on global queues — they are not +workspace-scoped. + +### Workspace field on nested metadata objects + +With per-workspace queues, the workspace is part of the queue +address. Services know which workspace they are serving by which +queue a message arrived on. + +The `workspace` field on `DocumentMetadata` and +`ProcessingMetadata` in the librarian schema becomes a storage +attribute — the workspace the record belongs to, populated by the +service from the request context, not by the caller. The service +reads workspace from `request.workspace` (the resolved address) or +from the queue context, never from a nested payload field. + +Callers are not required to populate workspace on nested objects. +The service fills it in authoritatively from the request context +before storing. + +## Interaction with existing specs + +### IAM (`iam.md`, `iam-contract.md`) + +IAM is the authority for workspace existence. When IAM creates or +deletes a workspace, it writes to the `__workspaces__` config +namespace. This is a two-step operation: register the workspace in +IAM's own store (`iam_workspaces` table), then announce it via +config. + +The IAM service itself remains on a global queue — it is a +system-level service, not workspace-scoped. + +### Config service + +The config service is workspace-scoped — it stores per-workspace +configuration. Under this design, the config service moves to +per-workspace queues like other workspace-scoped services. + +On startup, the config service discovers workspaces from its own +store (it has direct access to the config tables, unlike other +services that fetch via request/response). It subscribes to +per-workspace queues for each known workspace. 
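+
+In hook terms, the pattern for any workspace-scoped service looks
+like the sketch below. This is illustrative only: `subscribe_queue`
+and `unsubscribe_queue` are placeholder names, not the processor's
+actual queue API.
+
+```python
+class ConfigService(WorkspaceProcessor):
+
+    async def on_workspace_created(self, workspace):
+        # Invoked once per workspace discovered at startup, and again
+        # whenever a created entry arrives via workspace_changes.
+        await self.subscribe_queue(f"request:tg:config:{workspace}")
+
+    async def on_workspace_deleted(self, workspace):
+        # Invoked before the workspace leaves the active set, so the
+        # queue is torn down while the workspace is still known.
+        await self.unsubscribe_queue(f"request:tg:config:{workspace}")
+```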
+ +When IAM writes a new workspace entry to the `__workspaces__` +namespace, the config service sees the write directly (it is the +config service), creates the per-workspace queue, and pushes the +notification. + +### Flow blueprints (`flow-blueprint-definition.md`) + +Flow blueprints already use `{workspace}` in queue name templates. +No changes needed — flows are created within an already-existing +workspace, so the per-workspace infrastructure is in place before +flow start. + +### Data ownership (`data-ownership-model.md`) + +This spec reinforces the data ownership model: a workspace is the +primary isolation boundary, and per-workspace queues make that +boundary structural rather than conventional. + +## Migration + +### Queue naming + +Existing deployments use global queues for workspace-scoped +services. Migration requires: + +1. Deploy updated services that subscribe to both global and + per-workspace queues during a transition period. +2. Update the gateway to route to per-workspace queues. +3. Drain the global queues. +4. Remove global queue subscriptions from services. + +### `__workspaces__` bootstrap + +On first start after migration, IAM populates the `__workspaces__` +config namespace with entries for all existing workspaces from +`iam_workspaces`. This seeds the config store so that +workspace-scoped services discover existing workspaces on startup. + +### Config push compatibility + +The `workspace_changes` field on `ConfigPush` is additive. +Services that don't understand it ignore it (the field defaults to +`None`). No breaking change to the push protocol. + +## Summary of changes + +| Component | Change | +|-----------|--------| +| Queue names | Workspace-scoped services move from `request:tg:{service}` to `request:tg:{service}:{workspace}` | +| `__workspaces__` namespace | New reserved config namespace for workspace lifecycle | +| IAM service | Writes to `__workspaces__` on workspace create/delete | +| Config service | Removes `_`-prefix notification suppression; generates `workspace_changes` events; moves to per-workspace queues | +| `ConfigPush` schema | Adds `workspace_changes` field (`WorkspaceChanges` dataclass) | +| `WorkspaceProcessor` | New base class managing per-workspace queue lifecycle | +| `FlowProcessor` | Extends `WorkspaceProcessor` instead of `AsyncProcessor` | +| `AsyncProcessor` | Relaxes `_`-prefix filtering to allow opt-in for infrastructure namespaces | +| Gateway | Routes workspace-scoped requests to per-workspace queues | +| Librarian schema | `workspace` on nested metadata becomes a service-populated storage attribute, not a caller-supplied address | + +## Implementation Plan + +### Phase 1: Foundation — `__workspaces__` namespace and config push + +- **`ConfigPush` schema** (`trustgraph-base/trustgraph/schema/services/config.py`): Add `WorkspaceChanges` dataclass and `workspace_changes` field. +- **Config push serialization** (`trustgraph-base/trustgraph/messaging/translators/`): Encode/decode the new field. +- **Config service** (`trustgraph-flow/trustgraph/config/`): Detect writes to `__workspaces__` namespace and populate `workspace_changes` on the push message. Remove `_`-prefix notification suppression. +- **`AsyncProcessor`** (`trustgraph-base/trustgraph/base/async_processor.py`): Relax `_`-prefix filtering so handlers can opt in to infrastructure namespaces. +- **IAM service** (`trustgraph-flow/trustgraph/iam/`): Write to `__workspaces__` config namespace on `create-workspace` and `delete-workspace`. 
Add bootstrap step to seed `__workspaces__` entries for existing workspaces. + +### Phase 2: `WorkspaceProcessor` base class + +- **New `WorkspaceProcessor`** (`trustgraph-base/trustgraph/base/workspace_processor.py`): Implements workspace discovery on startup, per-workspace queue subscribe/unsubscribe, workspace lifecycle handler registration, `on_workspace_created`/`on_workspace_deleted` hooks. +- **`FlowProcessor`** (`trustgraph-base/trustgraph/base/flow_processor.py`): Re-parent from `AsyncProcessor` to `WorkspaceProcessor`. +- **Verify** existing flow processors continue to work — the new layer should be transparent to them. + +### Phase 3: Per-workspace queues for workspace-scoped services + +- **Queue definitions** (`trustgraph-base/trustgraph/schema/`): Update queue names for librarian, config, knowledge, collection management to include `{workspace}`. +- **Librarian** (`trustgraph-flow/trustgraph/librarian/`): Extend `WorkspaceProcessor`. Remove reliance on workspace from nested metadata objects. +- **Knowledge service, collection management** and other workspace-scoped services: Extend `WorkspaceProcessor`. +- **Config service**: Self-bootstrap per-workspace queues from its own store on startup; subscribe to new workspace queues when `__workspaces__` entries appear. + +### Phase 4: Gateway routing + +- **Gateway dispatcher manager** (`trustgraph-flow/trustgraph/gateway/dispatch/manager.py`): Route workspace-scoped services to per-workspace queues using resolved workspace. System-level services (IAM) remain on global queues. +- **Mux** (`trustgraph-flow/trustgraph/gateway/dispatch/mux.py`): Pass workspace to `invoke_global_service` for workspace-scoped services. +- **HTTP endpoints** (`trustgraph-flow/trustgraph/gateway/endpoint/`): Route to per-workspace queues based on URL path workspace. + +### Phase 5: Schema cleanup + +- **`DocumentMetadata`, `ProcessingMetadata`** (`trustgraph-base/trustgraph/schema/services/library.py`): Remove `workspace` field from nested metadata objects, or retain as a service-populated storage attribute only. +- **Serialization** (`trustgraph-flow/trustgraph/gateway/dispatch/serialize.py`, `trustgraph-base/trustgraph/messaging/translators/metadata.py`): Update translators to match. +- **API client** (`trustgraph-base/trustgraph/api/library.py`): Stop sending workspace in nested payloads. +- **Librarian service** (`trustgraph-flow/trustgraph/librarian/`): Populate workspace on stored records from request context. + +### Dependencies + +``` +Phase 1 (foundation) + ↓ +Phase 2 (WorkspaceProcessor) + ↓ +Phase 3 (per-workspace queues) ←→ Phase 4 (gateway routing) + ↓ ↓ + Phase 5 (schema cleanup) +``` + +Phases 3 and 4 can be developed in parallel but must be deployed together — services expecting per-workspace queues need the gateway to route to them. + +## References + +- [Identity and Access Management](iam.md) — workspace registry, + authentication, and workspace resolution. +- [IAM Contract](iam-contract.md) — resource model and workspace as + address vs. parameter. +- [Data Ownership and Information Separation](data-ownership-model.md) + — workspace as isolation boundary. +- [Config Push and Poke](config-push-poke.md) — config notification + mechanism. +- [Flow Blueprint Definition](flow-blueprint-definition.md) — + `{workspace}` template variable in queue names. +- [Flow Service Queue Lifecycle](flow-service-queue-lifecycle.md) — + queue ownership and lifecycle model. 
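+
+## Appendix: worked lifecycle example
+
+To make the pieces concrete, suppose IAM registers a new workspace
+`alice` (the workspace name and version number here are invented for
+illustration). IAM writes the `__workspaces__` entry, and the config
+service emits a push built from the dataclasses defined above:
+
+```python
+push = ConfigPush(
+    version=42,
+    # The "workspace" type changed in the __workspaces__ namespace...
+    changes={"workspace": ["__workspaces__"]},
+    # ...and the change was the creation of workspace "alice".
+    workspace_changes=WorkspaceChanges(created=["alice"]),
+)
+```
+
+On receipt, every `WorkspaceProcessor` adds `alice` to its active
+set and subscribes to its own `request:tg:{service}:alice` queue,
+and the gateway begins routing requests resolved to `alice` to the
+per-workspace queues.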
diff --git a/tests/unit/test_base/test_flow_processor.py b/tests/unit/test_base/test_flow_processor.py index 350a8b43..b2e5f87e 100644 --- a/tests/unit/test_base/test_flow_processor.py +++ b/tests/unit/test_base/test_flow_processor.py @@ -233,7 +233,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase): mock_flow2.start.assert_called_once() @with_async_processor_patches - @patch('trustgraph.base.async_processor.AsyncProcessor.start') + @patch('trustgraph.base.workspace_processor.WorkspaceProcessor.start') async def test_start_calls_parent(self, mock_parent_start, *mocks): """Test that start() calls parent start method""" mock_parent_start.return_value = None diff --git a/tests/unit/test_cores/test_knowledge_manager.py b/tests/unit/test_cores/test_knowledge_manager.py index d677b82f..8f73dcc6 100644 --- a/tests/unit/test_cores/test_knowledge_manager.py +++ b/tests/unit/test_cores/test_knowledge_manager.py @@ -45,7 +45,6 @@ def mock_flow_config(): def mock_request(): """Mock knowledge load request.""" request = Mock() - request.workspace = "test-user" request.id = "test-doc-id" request.collection = "test-collection" request.flow = "test-flow" @@ -131,17 +130,17 @@ class TestKnowledgeManagerLoadCore: # Start the core loader background task knowledge_manager.background_task = None - await knowledge_manager.load_kg_core(mock_request, mock_respond) - + await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user") + # Wait for background processing import asyncio await asyncio.sleep(0.1) - + # Verify publishers were created and started assert mock_publisher_class.call_count == 2 mock_triples_pub.start.assert_called_once() mock_ge_pub.start.assert_called_once() - + # Verify triples were sent with correct collection mock_triples_pub.send.assert_called_once() sent_triples = mock_triples_pub.send.call_args[0][1] @@ -174,12 +173,12 @@ class TestKnowledgeManagerLoadCore: # Start the core loader background task knowledge_manager.background_task = None - await knowledge_manager.load_kg_core(mock_request, mock_respond) - + await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user") + # Wait for background processing import asyncio await asyncio.sleep(0.1) - + # Verify graph embeddings were sent with correct collection mock_ge_pub.send.assert_called_once() sent_ge = mock_ge_pub.send.call_args[0][1] @@ -191,7 +190,6 @@ class TestKnowledgeManagerLoadCore: """Test that load_kg_core falls back to 'default' when request.collection is None.""" # Create request with None collection mock_request = Mock() - mock_request.workspace = "test-user" mock_request.id = "test-doc-id" mock_request.collection = None # Should fall back to "default" mock_request.flow = "test-flow" @@ -213,12 +211,12 @@ class TestKnowledgeManagerLoadCore: # Start the core loader background task knowledge_manager.background_task = None - await knowledge_manager.load_kg_core(mock_request, mock_respond) - + await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user") + # Wait for background processing import asyncio await asyncio.sleep(0.1) - + # Verify triples were sent with default collection mock_triples_pub.send.assert_called_once() sent_triples = mock_triples_pub.send.call_args[0][1] @@ -246,13 +244,13 @@ class TestKnowledgeManagerLoadCore: mock_publisher_class.side_effect = [mock_triples_pub, mock_ge_pub] # Start the core loader background task - knowledge_manager.background_task = None - await knowledge_manager.load_kg_core(mock_request, mock_respond) - + 
knowledge_manager.background_task = None + await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user") + # Wait for background processing import asyncio await asyncio.sleep(0.1) - + # Verify both publishers were used with correct collection mock_triples_pub.send.assert_called_once() sent_triples = mock_triples_pub.send.call_args[0][1] @@ -267,7 +265,6 @@ class TestKnowledgeManagerLoadCore: """Test that load_kg_core validates flow configuration before processing.""" # Request with invalid flow mock_request = Mock() - mock_request.workspace = "test-user" mock_request.id = "test-doc-id" mock_request.collection = "test-collection" mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows @@ -276,12 +273,12 @@ class TestKnowledgeManagerLoadCore: # Start the core loader background task knowledge_manager.background_task = None - await knowledge_manager.load_kg_core(mock_request, mock_respond) - + await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user") + # Wait for background processing import asyncio await asyncio.sleep(0.1) - + # Should have responded with error mock_respond.assert_called() response = mock_respond.call_args[0][0] @@ -295,18 +292,17 @@ class TestKnowledgeManagerLoadCore: # Test missing ID mock_request = Mock() - mock_request.workspace = "test-user" mock_request.id = None # Missing mock_request.collection = "test-collection" mock_request.flow = "test-flow" knowledge_manager.background_task = None - await knowledge_manager.load_kg_core(mock_request, mock_respond) - + await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user") + # Wait for background processing import asyncio await asyncio.sleep(0.1) - + # Should respond with error mock_respond.assert_called() response = mock_respond.call_args[0][0] @@ -321,18 +317,17 @@ class TestKnowledgeManagerOtherMethods: async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples): """Test that get_kg_core preserves collection field from stored data.""" mock_request = Mock() - mock_request.workspace = "test-user" mock_request.id = "test-doc-id" - + mock_respond = AsyncMock() - + async def mock_get_triples(user, doc_id, receiver): await receiver(sample_triples) - + knowledge_manager.table_store.get_triples = mock_get_triples knowledge_manager.table_store.get_graph_embeddings = AsyncMock() - - await knowledge_manager.get_kg_core(mock_request, mock_respond) + + await knowledge_manager.get_kg_core(mock_request, mock_respond, "test-user") # Should have called respond for triples and final EOS assert mock_respond.call_count >= 2 @@ -352,14 +347,13 @@ class TestKnowledgeManagerOtherMethods: async def test_list_kg_cores(self, knowledge_manager): """Test listing knowledge cores.""" mock_request = Mock() - mock_request.workspace = "test-user" - + mock_respond = AsyncMock() - + # Mock return value knowledge_manager.table_store.list_kg_cores.return_value = ["doc1", "doc2", "doc3"] - - await knowledge_manager.list_kg_cores(mock_request, mock_respond) + + await knowledge_manager.list_kg_cores(mock_request, mock_respond, "test-user") # Verify table store was called correctly knowledge_manager.table_store.list_kg_cores.assert_called_once_with("test-user") @@ -374,12 +368,11 @@ class TestKnowledgeManagerOtherMethods: async def test_delete_kg_core(self, knowledge_manager): """Test deleting knowledge cores.""" mock_request = Mock() - mock_request.workspace = "test-user" mock_request.id = "test-doc-id" - + mock_respond = AsyncMock() - - await 
knowledge_manager.delete_kg_core(mock_request, mock_respond) + + await knowledge_manager.delete_kg_core(mock_request, mock_respond, "test-user") # Verify table store was called correctly knowledge_manager.table_store.delete_kg_core.assert_called_once_with("test-user", "test-doc-id") diff --git a/tests/unit/test_gateway/test_capabilities.py b/tests/unit/test_gateway/test_capabilities.py index 102e381e..4f781b16 100644 --- a/tests/unit/test_gateway/test_capabilities.py +++ b/tests/unit/test_gateway/test_capabilities.py @@ -34,7 +34,7 @@ class _Identity: self.source = "api-key" -def _allow_auth(identity=None): +def _allow_auth(identity=None, workspaces=None): """Build an Auth double that authenticates to ``identity`` and allows every authorise() call.""" auth = MagicMock() @@ -42,16 +42,18 @@ def _allow_auth(identity=None): return_value=identity or _Identity(), ) auth.authorise = AsyncMock(return_value=None) + auth.known_workspaces = workspaces or {"default", "acme"} return auth -def _deny_auth(identity=None): +def _deny_auth(identity=None, workspaces=None): """Build an Auth double that authenticates but denies authorise.""" auth = MagicMock() auth.authenticate = AsyncMock( return_value=identity or _Identity(), ) auth.authorise = AsyncMock(side_effect=access_denied()) + auth.known_workspaces = workspaces or {"default", "acme"} return auth diff --git a/tests/unit/test_gateway/test_dispatch_manager.py b/tests/unit/test_gateway/test_dispatch_manager.py index e399d712..8fe91830 100644 --- a/tests/unit/test_gateway/test_dispatch_manager.py +++ b/tests/unit/test_gateway/test_dispatch_manager.py @@ -176,7 +176,7 @@ class TestDispatcherManager: params = {"kind": "test_kind"} result = await manager.process_global_service("data", "responder", params) - manager.invoke_global_service.assert_called_once_with("data", "responder", "test_kind") + manager.invoke_global_service.assert_called_once_with("data", "responder", "test_kind", workspace=None) assert result == "global_result" @pytest.mark.asyncio @@ -185,24 +185,24 @@ class TestDispatcherManager: mock_backend = Mock() mock_config_receiver = Mock() manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) - + # Pre-populate with existing dispatcher mock_dispatcher = Mock() mock_dispatcher.process = AsyncMock(return_value="cached_result") - manager.dispatchers[(None, "config")] = mock_dispatcher - - result = await manager.invoke_global_service("data", "responder", "config") - + manager.dispatchers[(None, "iam")] = mock_dispatcher + + result = await manager.invoke_global_service("data", "responder", "iam") + mock_dispatcher.process.assert_called_once_with("data", "responder") assert result == "cached_result" @pytest.mark.asyncio async def test_invoke_global_service_creates_new_dispatcher(self): - """Test invoke_global_service creates new dispatcher""" + """Test invoke_global_service creates new dispatcher for system service""" mock_backend = Mock() mock_config_receiver = Mock() manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) - + with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers: mock_dispatcher_class = Mock() mock_dispatcher = Mock() @@ -210,25 +210,51 @@ class TestDispatcherManager: mock_dispatcher.process = AsyncMock(return_value="new_result") mock_dispatcher_class.return_value = mock_dispatcher mock_dispatchers.__getitem__.return_value = mock_dispatcher_class - - result = await manager.invoke_global_service("data", "responder", "config") - - # Verify dispatcher was 
created with correct parameters + + result = await manager.invoke_global_service("data", "responder", "iam") + mock_dispatcher_class.assert_called_once_with( backend=mock_backend, timeout=120, - consumer="api-gateway-config-request", - subscriber="api-gateway-config-request", + consumer="api-gateway-iam-request", + subscriber="api-gateway-iam-request", request_queue=None, response_queue=None ) mock_dispatcher.start.assert_called_once() mock_dispatcher.process.assert_called_once_with("data", "responder") - - # Verify dispatcher was cached - assert manager.dispatchers[(None, "config")] == mock_dispatcher + + assert manager.dispatchers[(None, "iam")] == mock_dispatcher assert result == "new_result" + @pytest.mark.asyncio + async def test_invoke_global_service_workspace_required_for_workspace_dispatchers(self): + """Workspace dispatchers (config, flow, etc.) require a workspace""" + mock_backend = Mock() + mock_config_receiver = Mock() + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) + + with pytest.raises(RuntimeError, match="Workspace is required for config"): + await manager.invoke_global_service("data", "responder", "config") + + @pytest.mark.asyncio + async def test_invoke_global_service_workspace_dispatcher_with_workspace(self): + """Workspace dispatchers work when workspace is provided""" + mock_backend = Mock() + mock_config_receiver = Mock() + manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) + + mock_dispatcher = Mock() + mock_dispatcher.process = AsyncMock(return_value="ws_result") + manager.dispatchers[("alice", "config")] = mock_dispatcher + + result = await manager.invoke_global_service( + "data", "responder", "config", workspace="alice", + ) + + mock_dispatcher.process.assert_called_once_with("data", "responder") + assert result == "ws_result" + def test_dispatch_flow_import_returns_method(self): """Test dispatch_flow_import returns correct method""" mock_backend = Mock() @@ -610,7 +636,7 @@ class TestDispatcherManager: mock_dispatchers.__getitem__.return_value = mock_dispatcher_class results = await asyncio.gather(*[ - manager.invoke_global_service("data", "responder", "config") + manager.invoke_global_service("data", "responder", "iam") for _ in range(5) ]) @@ -618,7 +644,7 @@ class TestDispatcherManager: "Dispatcher class instantiated more than once — duplicate consumer bug" ) assert mock_dispatcher.start.call_count == 1 - assert manager.dispatchers[(None, "config")] is mock_dispatcher + assert manager.dispatchers[(None, "iam")] is mock_dispatcher assert all(r == "result" for r in results) @pytest.mark.asyncio diff --git a/tests/unit/test_librarian/test_chunked_upload.py b/tests/unit/test_librarian/test_chunked_upload.py index 7e7be480..d5831ea3 100644 --- a/tests/unit/test_librarian/test_chunked_upload.py +++ b/tests/unit/test_librarian/test_chunked_upload.py @@ -33,12 +33,11 @@ def _make_librarian(min_chunk_size=1): def _make_doc_metadata( - doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc" + doc_id="doc-1", kind="application/pdf", title="Test Doc" ): meta = MagicMock() meta.id = doc_id meta.kind = kind - meta.workspace = workspace meta.title = title meta.time = 1700000000 meta.comments = "" @@ -47,21 +46,20 @@ def _make_doc_metadata( def _make_begin_request( - doc_id="doc-1", kind="application/pdf", workspace="alice", + doc_id="doc-1", kind="application/pdf", total_size=10_000_000, chunk_size=0 ): req = MagicMock() - req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, 
workspace=workspace) + req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind) req.total_size = total_size req.chunk_size = chunk_size return req -def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"): +def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, content=b"data"): req = MagicMock() req.upload_id = upload_id req.chunk_index = chunk_index - req.workspace = workspace req.content = base64.b64encode(content) return req @@ -76,7 +74,7 @@ def _make_session( if document_metadata is None: document_metadata = json.dumps({ "id": document_id, "kind": "application/pdf", - "workspace": workspace, "title": "Test", "time": 1700000000, + "title": "Test", "time": 1700000000, "comments": "", "tags": [], }) return { @@ -105,7 +103,7 @@ class TestBeginUpload: lib.blob_store.create_multipart_upload.return_value = "s3-upload-id" req = _make_begin_request(total_size=10_000_000) - resp = await lib.begin_upload(req) + resp = await lib.begin_upload(req, "alice") assert resp.error is None assert resp.upload_id is not None @@ -119,7 +117,7 @@ class TestBeginUpload: lib.blob_store.create_multipart_upload.return_value = "s3-id" req = _make_begin_request(total_size=10_000, chunk_size=3000) - resp = await lib.begin_upload(req) + resp = await lib.begin_upload(req, "alice") assert resp.chunk_size == 3000 assert resp.total_chunks == math.ceil(10_000 / 3000) @@ -130,7 +128,7 @@ class TestBeginUpload: req = _make_begin_request(kind="") with pytest.raises(RequestError, match="MIME type.*required"): - await lib.begin_upload(req) + await lib.begin_upload(req, "alice") @pytest.mark.asyncio async def test_rejects_duplicate_document(self): @@ -139,7 +137,7 @@ class TestBeginUpload: req = _make_begin_request() with pytest.raises(RequestError, match="already exists"): - await lib.begin_upload(req) + await lib.begin_upload(req, "alice") @pytest.mark.asyncio async def test_rejects_zero_size(self): @@ -148,7 +146,7 @@ class TestBeginUpload: req = _make_begin_request(total_size=0) with pytest.raises(RequestError, match="positive"): - await lib.begin_upload(req) + await lib.begin_upload(req, "alice") @pytest.mark.asyncio async def test_rejects_chunk_below_minimum(self): @@ -157,7 +155,7 @@ class TestBeginUpload: req = _make_begin_request(total_size=10_000, chunk_size=512) with pytest.raises(RequestError, match="below minimum"): - await lib.begin_upload(req) + await lib.begin_upload(req, "alice") @pytest.mark.asyncio async def test_calls_s3_create_multipart(self): @@ -166,7 +164,7 @@ class TestBeginUpload: lib.blob_store.create_multipart_upload.return_value = "s3-id" req = _make_begin_request(kind="application/pdf") - await lib.begin_upload(req) + await lib.begin_upload(req, "alice") lib.blob_store.create_multipart_upload.assert_called_once() # create_multipart_upload(object_id, kind) — positional args @@ -180,7 +178,7 @@ class TestBeginUpload: lib.blob_store.create_multipart_upload.return_value = "s3-id" req = _make_begin_request(total_size=5_000_000) - resp = await lib.begin_upload(req) + resp = await lib.begin_upload(req, "alice") lib.table_store.create_upload_session.assert_called_once() kwargs = lib.table_store.create_upload_session.call_args[1] @@ -195,7 +193,7 @@ class TestBeginUpload: lib.blob_store.create_multipart_upload.return_value = "s3-id" req = _make_begin_request(kind="text/plain", total_size=1000) - resp = await lib.begin_upload(req) + resp = await lib.begin_upload(req, "alice") assert resp.error is None @@ -213,7 +211,7 @@ class TestUploadChunk: 
lib.blob_store.upload_part.return_value = "etag-1" req = _make_upload_chunk_request(chunk_index=0, content=b"chunk data") - resp = await lib.upload_chunk(req) + resp = await lib.upload_chunk(req, "alice") assert resp.error is None assert resp.chunk_index == 0 @@ -229,7 +227,7 @@ class TestUploadChunk: lib.blob_store.upload_part.return_value = "etag" req = _make_upload_chunk_request(chunk_index=0) - await lib.upload_chunk(req) + await lib.upload_chunk(req, "alice") kwargs = lib.blob_store.upload_part.call_args[1] assert kwargs["part_number"] == 1 # 0-indexed chunk → 1-indexed part @@ -242,7 +240,7 @@ class TestUploadChunk: lib.blob_store.upload_part.return_value = "etag" req = _make_upload_chunk_request(chunk_index=3) - await lib.upload_chunk(req) + await lib.upload_chunk(req, "alice") kwargs = lib.blob_store.upload_part.call_args[1] assert kwargs["part_number"] == 4 @@ -254,7 +252,7 @@ class TestUploadChunk: req = _make_upload_chunk_request() with pytest.raises(RequestError, match="not found"): - await lib.upload_chunk(req) + await lib.upload_chunk(req, "alice") @pytest.mark.asyncio async def test_rejects_wrong_user(self): @@ -262,9 +260,9 @@ class TestUploadChunk: session = _make_session(workspace="alice") lib.table_store.get_upload_session.return_value = session - req = _make_upload_chunk_request(workspace="bob") + req = _make_upload_chunk_request() with pytest.raises(RequestError, match="Not authorized"): - await lib.upload_chunk(req) + await lib.upload_chunk(req, "bob") @pytest.mark.asyncio async def test_rejects_negative_chunk_index(self): @@ -274,7 +272,7 @@ class TestUploadChunk: req = _make_upload_chunk_request(chunk_index=-1) with pytest.raises(RequestError, match="Invalid chunk index"): - await lib.upload_chunk(req) + await lib.upload_chunk(req, "alice") @pytest.mark.asyncio async def test_rejects_out_of_range_chunk_index(self): @@ -284,7 +282,7 @@ class TestUploadChunk: req = _make_upload_chunk_request(chunk_index=5) with pytest.raises(RequestError, match="Invalid chunk index"): - await lib.upload_chunk(req) + await lib.upload_chunk(req, "alice") @pytest.mark.asyncio async def test_progress_tracking(self): @@ -297,7 +295,7 @@ class TestUploadChunk: lib.blob_store.upload_part.return_value = "e3" req = _make_upload_chunk_request(chunk_index=2) - resp = await lib.upload_chunk(req) + resp = await lib.upload_chunk(req, "alice") # Dict gets chunk 2 added (len=3), then +1 => 4 assert resp.chunks_received == 4 @@ -316,7 +314,7 @@ class TestUploadChunk: lib.blob_store.upload_part.return_value = "e2" req = _make_upload_chunk_request(chunk_index=1) - resp = await lib.upload_chunk(req) + resp = await lib.upload_chunk(req, "alice") # 3 chunks × 3000 = 9000 > 5000, so capped assert resp.bytes_received <= 5000 @@ -330,7 +328,7 @@ class TestUploadChunk: raw = b"hello world binary data" req = _make_upload_chunk_request(content=raw) - await lib.upload_chunk(req) + await lib.upload_chunk(req, "alice") kwargs = lib.blob_store.upload_part.call_args[1] assert kwargs["data"] == raw @@ -353,9 +351,8 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-1" - req.workspace = "alice" - resp = await lib.complete_upload(req) + resp = await lib.complete_upload(req, "alice") assert resp.error is None assert resp.document_id == "doc-1" @@ -375,9 +372,8 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-1" - req.workspace = "alice" - await lib.complete_upload(req) + await lib.complete_upload(req, "alice") parts = lib.blob_store.complete_multipart_upload.call_args[1]["parts"] 
part_numbers = [p[0] for p in parts] @@ -394,10 +390,9 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-1" - req.workspace = "alice" with pytest.raises(RequestError, match="Missing chunks"): - await lib.complete_upload(req) + await lib.complete_upload(req, "alice") @pytest.mark.asyncio async def test_rejects_expired_session(self): @@ -406,10 +401,9 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-gone" - req.workspace = "alice" with pytest.raises(RequestError, match="not found"): - await lib.complete_upload(req) + await lib.complete_upload(req, "alice") @pytest.mark.asyncio async def test_rejects_wrong_user(self): @@ -419,10 +413,9 @@ class TestCompleteUpload: req = MagicMock() req.upload_id = "up-1" - req.workspace = "bob" with pytest.raises(RequestError, match="Not authorized"): - await lib.complete_upload(req) + await lib.complete_upload(req, "bob") # --------------------------------------------------------------------------- @@ -439,9 +432,8 @@ class TestAbortUpload: req = MagicMock() req.upload_id = "up-1" - req.workspace = "alice" - resp = await lib.abort_upload(req) + resp = await lib.abort_upload(req, "alice") assert resp.error is None lib.blob_store.abort_multipart_upload.assert_called_once_with( @@ -456,10 +448,9 @@ class TestAbortUpload: req = MagicMock() req.upload_id = "up-gone" - req.workspace = "alice" with pytest.raises(RequestError, match="not found"): - await lib.abort_upload(req) + await lib.abort_upload(req, "alice") @pytest.mark.asyncio async def test_rejects_wrong_user(self): @@ -469,10 +460,9 @@ class TestAbortUpload: req = MagicMock() req.upload_id = "up-1" - req.workspace = "bob" with pytest.raises(RequestError, match="Not authorized"): - await lib.abort_upload(req) + await lib.abort_upload(req, "bob") # --------------------------------------------------------------------------- @@ -492,9 +482,8 @@ class TestGetUploadStatus: req = MagicMock() req.upload_id = "up-1" - req.workspace = "alice" - resp = await lib.get_upload_status(req) + resp = await lib.get_upload_status(req, "alice") assert resp.upload_state == "in-progress" assert resp.chunks_received == 3 @@ -510,9 +499,8 @@ class TestGetUploadStatus: req = MagicMock() req.upload_id = "up-expired" - req.workspace = "alice" - resp = await lib.get_upload_status(req) + resp = await lib.get_upload_status(req, "alice") assert resp.upload_state == "expired" @@ -527,9 +515,8 @@ class TestGetUploadStatus: req = MagicMock() req.upload_id = "up-1" - req.workspace = "alice" - resp = await lib.get_upload_status(req) + resp = await lib.get_upload_status(req, "alice") assert resp.missing_chunks == [] assert resp.chunks_received == 3 @@ -544,10 +531,9 @@ class TestGetUploadStatus: req = MagicMock() req.upload_id = "up-1" - req.workspace = "bob" with pytest.raises(RequestError, match="Not authorized"): - await lib.get_upload_status(req) + await lib.get_upload_status(req, "bob") # --------------------------------------------------------------------------- @@ -564,12 +550,11 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000) req = MagicMock() - req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 2000 chunks = [] - async for resp in lib.stream_document(req): + async for resp in lib.stream_document(req, "alice"): chunks.append(resp) assert len(chunks) == 3 # ceil(5000/2000) @@ -587,12 +572,11 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500) req = MagicMock() - req.workspace = "alice" req.document_id = 
"doc-1" req.chunk_size = 2000 chunks = [] - async for resp in lib.stream_document(req): + async for resp in lib.stream_document(req, "alice"): chunks.append(resp) assert len(chunks) == 1 @@ -608,12 +592,11 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100) req = MagicMock() - req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 2000 chunks = [] - async for resp in lib.stream_document(req): + async for resp in lib.stream_document(req, "alice"): chunks.append(resp) # Verify the byte ranges passed to get_range @@ -630,12 +613,11 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=b"x") req = MagicMock() - req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 0 # Should use default 1MB chunks = [] - async for resp in lib.stream_document(req): + async for resp in lib.stream_document(req, "alice"): chunks.append(resp) assert len(chunks) == 2 # ceil(2MB / 1MB) @@ -649,12 +631,11 @@ class TestStreamDocument: lib.blob_store.get_range = AsyncMock(return_value=raw) req = MagicMock() - req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 1000 chunks = [] - async for resp in lib.stream_document(req): + async for resp in lib.stream_document(req, "alice"): chunks.append(resp) assert chunks[0].content == base64.b64encode(raw) @@ -666,12 +647,11 @@ class TestStreamDocument: lib.blob_store.get_size = AsyncMock(return_value=5000) req = MagicMock() - req.workspace = "alice" req.document_id = "doc-1" req.chunk_size = 512 with pytest.raises(RequestError, match="below minimum"): - async for _ in lib.stream_document(req): + async for _ in lib.stream_document(req, "alice"): pass @@ -698,9 +678,8 @@ class TestListUploads: ] req = MagicMock() - req.workspace = "alice" - resp = await lib.list_uploads(req) + resp = await lib.list_uploads(req, "alice") assert resp.error is None assert len(resp.upload_sessions) == 1 @@ -713,8 +692,7 @@ class TestListUploads: lib.table_store.list_upload_sessions.return_value = [] req = MagicMock() - req.workspace = "alice" - resp = await lib.list_uploads(req) + resp = await lib.list_uploads(req, "alice") assert resp.upload_sessions == [] diff --git a/tests/unit/test_reliability/test_metadata_preservation.py b/tests/unit/test_reliability/test_metadata_preservation.py index 2170c763..b9d209be 100644 --- a/tests/unit/test_reliability/test_metadata_preservation.py +++ b/tests/unit/test_reliability/test_metadata_preservation.py @@ -30,7 +30,6 @@ class TestDocumentMetadataTranslator: "title": "Test Document", "comments": "No comments", "metadata": [], - "workspace": "alice", "tags": ["finance", "q4"], "parent-id": "doc-100", "document-type": "page", @@ -40,14 +39,12 @@ class TestDocumentMetadataTranslator: assert obj.time == 1710000000 assert obj.kind == "application/pdf" assert obj.title == "Test Document" - assert obj.workspace == "alice" assert obj.tags == ["finance", "q4"] assert obj.parent_id == "doc-100" assert obj.document_type == "page" wire = self.tx.encode(obj) assert wire["id"] == "doc-123" - assert wire["workspace"] == "alice" assert wire["parent-id"] == "doc-100" assert wire["document-type"] == "page" @@ -80,10 +77,9 @@ class TestDocumentMetadataTranslator: def test_falsy_fields_omitted_from_wire(self): """Empty string fields should be omitted from wire format.""" - obj = DocumentMetadata(id="", time=0, workspace="") + obj = DocumentMetadata(id="", time=0) wire = self.tx.encode(obj) assert "id" not in wire - assert "workspace" not in wire # 
---------------------------------------------------------------------------
@@ -101,7 +97,6 @@ class TestProcessingMetadataTranslator:
             "document-id": "doc-123",
             "time": 1710000000,
             "flow": "default",
-            "workspace": "alice",
             "collection": "my-collection",
             "tags": ["tag1"],
         }
@@ -109,20 +104,17 @@ class TestProcessingMetadataTranslator:
         assert obj.id == "proc-1"
         assert obj.document_id == "doc-123"
         assert obj.flow == "default"
-        assert obj.workspace == "alice"
         assert obj.collection == "my-collection"
         assert obj.tags == ["tag1"]

         wire = self.tx.encode(obj)
         assert wire["id"] == "proc-1"
         assert wire["document-id"] == "doc-123"
-        assert wire["workspace"] == "alice"
         assert wire["collection"] == "my-collection"

     def test_missing_fields_use_defaults(self):
         obj = self.tx.decode({})
         assert obj.id is None
-        assert obj.workspace is None
         assert obj.collection is None

     def test_tags_none_omitted(self):
@@ -135,10 +127,9 @@
         wire = self.tx.encode(obj)
         assert wire["tags"] == []

-    def test_workspace_and_collection_preserved(self):
+    def test_collection_preserved(self):
         """Core pipeline routing fields must survive round-trip."""
-        data = {"workspace": "bob", "collection": "research"}
+        data = {"collection": "research"}
         obj = self.tx.decode(data)
         wire = self.tx.encode(obj)
-        assert wire["workspace"] == "bob"
         assert wire["collection"] == "research"
diff --git a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
index 64f2e5d4..437b83c8 100644
--- a/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
+++ b/tests/unit/test_translators/test_knowledge_translator_roundtrip.py
@@ -41,7 +41,6 @@ def translator():
 def graph_embeddings_request():
     return KnowledgeRequest(
         operation="put-kg-core",
-        workspace="alice",
         id="doc-1",
         flow="default",
         collection="testcoll",
@@ -110,7 +109,6 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings:

         assert isinstance(decoded, KnowledgeRequest)
         assert decoded.operation == "put-kg-core"
-        assert decoded.workspace == "alice"
         assert decoded.id == "doc-1"
         assert decoded.flow == "default"
         assert decoded.collection == "testcoll"
diff --git a/trustgraph-base/trustgraph/api/library.py b/trustgraph-base/trustgraph/api/library.py
index 8f99e601..024e933d 100644
--- a/trustgraph-base/trustgraph/api/library.py
+++ b/trustgraph-base/trustgraph/api/library.py
@@ -217,7 +217,6 @@ class Library:
                 "title": title,
                 "comments": comments,
                 "metadata": triples,
-                "workspace": self.api.workspace,
                 "tags": tags
             },
             "content": base64.b64encode(document).decode("utf-8"),
@@ -249,7 +248,6 @@ class Library:
                 "kind": kind,
                 "title": title,
                 "comments": comments,
-                "workspace": self.api.workspace,
                 "tags": tags,
             },
             "total-size": total_size,
@@ -377,7 +375,6 @@ class Library:
                     )
                     for w in v["metadata"]
                 ],
-                workspace = v.get("workspace", ""),
                 tags = v["tags"],
                 parent_id = v.get("parent-id", ""),
                 document_type = v.get("document-type", "source"),
@@ -436,7 +433,6 @@ class Library:
                     )
                     for w in doc["metadata"]
                 ],
-                workspace = doc.get("workspace", ""),
                 tags = doc["tags"],
                 parent_id = doc.get("parent-id", ""),
                 document_type = doc.get("document-type", "source"),
@@ -485,7 +481,6 @@ class Library:
             "operation": "update-document",
             "workspace": self.api.workspace,
             "document-metadata": {
-                "workspace": self.api.workspace,
                 "document-id": id,
                 "time": metadata.time,
                 "title": metadata.title,
@@ -599,7 +594,6 @@ class Library:
             "document-id": document_id,
             "time": 
int(time.time()), "flow": flow, - "workspace": self.api.workspace, "collection": collection, "tags": tags, } @@ -681,7 +675,6 @@ class Library: document_id = v["document-id"], time = datetime.datetime.fromtimestamp(v["time"]), flow = v["flow"], - workspace = v.get("workspace", ""), collection = v["collection"], tags = v["tags"], ) @@ -945,7 +938,6 @@ class Library: "title": title, "comments": comments, "metadata": triples, - "workspace": self.api.workspace, "tags": tags, "parent-id": parent_id, "document-type": "extracted", diff --git a/trustgraph-base/trustgraph/api/types.py b/trustgraph-base/trustgraph/api/types.py index 129f807a..fafb8224 100644 --- a/trustgraph-base/trustgraph/api/types.py +++ b/trustgraph-base/trustgraph/api/types.py @@ -65,7 +65,6 @@ class DocumentMetadata: title: Document title comments: Additional comments or description metadata: List of RDF triples providing structured metadata - workspace: Workspace the document belongs to tags: List of tags for categorization parent_id: Parent document ID for child documents (empty for top-level docs) document_type: "source" for uploaded documents, "extracted" for derived content @@ -76,7 +75,6 @@ class DocumentMetadata: title : str comments : str metadata : List[Triple] - workspace : str tags : List[str] parent_id : str = "" document_type : str = "source" @@ -91,7 +89,6 @@ class ProcessingMetadata: document_id: ID of the document being processed time: Processing start timestamp flow: Flow instance handling the processing - workspace: Workspace the processing job belongs to collection: Target collection for processed data tags: List of tags for categorization """ @@ -99,7 +96,6 @@ class ProcessingMetadata: document_id : str time : datetime.datetime flow : str - workspace : str collection : str tags : List[str] diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 91622156..180994b4 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -7,6 +7,7 @@ from . publisher import Publisher from . subscriber import Subscriber from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics, SubscriberMetrics from . logging import add_logging_args, setup_logging +from . workspace_processor import WorkspaceProcessor from . flow_processor import FlowProcessor from . consumer_spec import ConsumerSpec from . parameter_spec import ParameterSpec diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index 9741d2f5..dd582078 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -71,6 +71,11 @@ class AsyncProcessor: # { "handler": async_fn, "types": set_or_none } self.config_handlers = [] + # Workspace lifecycle handlers, called when workspaces are + # created or deleted. 
Each entry is an async callable: + # async def handler(workspace_changes: WorkspaceChanges) + self.workspace_handlers = [] + # Track the current config version for dedup self.config_version = 0 @@ -251,6 +256,10 @@ class AsyncProcessor: "types": set(types) if types else None, }) + # Register a handler for workspace lifecycle events + def register_workspace_handler(self, handler: Callable[..., Any]) -> None: + self.workspace_handlers.append(handler) + # Called when a config notify message arrives async def on_config_notify(self, message, consumer, flow): @@ -266,6 +275,16 @@ class AsyncProcessor: ) return + # Dispatch workspace lifecycle events before config handlers + if v.workspace_changes and self.workspace_handlers: + for handler in self.workspace_handlers: + try: + await handler(v.workspace_changes) + except Exception as e: + logger.error( + f"Workspace handler failed: {e}", exc_info=True + ) + notify_types = set(changes.keys()) # Filter out handlers that don't care about any of the changed diff --git a/trustgraph-base/trustgraph/base/flow_processor.py b/trustgraph-base/trustgraph/base/flow_processor.py index aa7bf921..56533c96 100644 --- a/trustgraph-base/trustgraph/base/flow_processor.py +++ b/trustgraph-base/trustgraph/base/flow_processor.py @@ -14,7 +14,7 @@ from .. schema import Error from .. schema import config_request_queue, config_response_queue from .. schema import config_push_queue from .. log_level import LogLevel -from . async_processor import AsyncProcessor +from . workspace_processor import WorkspaceProcessor from . flow import Flow # Module logger @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) # Parent class for configurable processors, configured with flows by # the config service -class FlowProcessor(AsyncProcessor): +class FlowProcessor(WorkspaceProcessor): def __init__(self, **params): @@ -113,7 +113,7 @@ class FlowProcessor(AsyncProcessor): @staticmethod def add_args(parser: ArgumentParser) -> None: - AsyncProcessor.add_args(parser) + WorkspaceProcessor.add_args(parser) # parser.add_argument( # '--rate-limit-retry', diff --git a/trustgraph-base/trustgraph/base/librarian_client.py b/trustgraph-base/trustgraph/base/librarian_client.py index 1876602b..9d835ee7 100644 --- a/trustgraph-base/trustgraph/base/librarian_client.py +++ b/trustgraph-base/trustgraph/base/librarian_client.py @@ -202,7 +202,6 @@ class LibrarianClient: doc_metadata = DocumentMetadata( id=doc_id, - workspace=workspace, kind=kind, title=title or doc_id, parent_id=parent_id, @@ -227,7 +226,6 @@ class LibrarianClient: doc_metadata = DocumentMetadata( id=doc_id, - workspace=workspace, kind=kind, title=title or doc_id, document_type=document_type, diff --git a/trustgraph-base/trustgraph/base/workspace_processor.py b/trustgraph-base/trustgraph/base/workspace_processor.py new file mode 100644 index 00000000..79c1bd7a --- /dev/null +++ b/trustgraph-base/trustgraph/base/workspace_processor.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from argparse import ArgumentParser + +import logging + +from . 
async_processor import AsyncProcessor + +logger = logging.getLogger(__name__) + +WORKSPACES_NAMESPACE = "__workspaces__" +WORKSPACE_TYPE = "workspace" + + +class WorkspaceProcessor(AsyncProcessor): + + def __init__(self, **params): + + super(WorkspaceProcessor, self).__init__(**params) + + self.active_workspaces = set() + + self.register_workspace_handler(self._handle_workspace_changes) + + async def _discover_workspaces(self): + client = self._create_config_client() + try: + await client.start() + type_data, version = await self._fetch_type_all_workspaces( + client, WORKSPACE_TYPE, + ) + for ws in type_data: + if ws == WORKSPACES_NAMESPACE: + for workspace_id in type_data[ws]: + if workspace_id not in self.active_workspaces: + self.active_workspaces.add(workspace_id) + await self.on_workspace_created(workspace_id) + finally: + await client.stop() + + async def _handle_workspace_changes(self, workspace_changes): + for workspace_id in workspace_changes.created: + if workspace_id not in self.active_workspaces: + self.active_workspaces.add(workspace_id) + logger.info(f"Workspace created: {workspace_id}") + await self.on_workspace_created(workspace_id) + + for workspace_id in workspace_changes.deleted: + if workspace_id in self.active_workspaces: + logger.info(f"Workspace deleted: {workspace_id}") + await self.on_workspace_deleted(workspace_id) + self.active_workspaces.discard(workspace_id) + + async def on_workspace_created(self, workspace): + pass + + async def on_workspace_deleted(self, workspace): + pass + + async def start(self): + await super(WorkspaceProcessor, self).start() + await self._discover_workspaces() + + @staticmethod + def add_args(parser: ArgumentParser) -> None: + AsyncProcessor.add_args(parser) diff --git a/trustgraph-base/trustgraph/messaging/translators/collection.py b/trustgraph-base/trustgraph/messaging/translators/collection.py index cd07bc99..f687875a 100644 --- a/trustgraph-base/trustgraph/messaging/translators/collection.py +++ b/trustgraph-base/trustgraph/messaging/translators/collection.py @@ -9,7 +9,6 @@ class CollectionManagementRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> CollectionManagementRequest: return CollectionManagementRequest( operation=data.get("operation"), - workspace=data.get("workspace", ""), collection=data.get("collection"), timestamp=data.get("timestamp"), name=data.get("name"), @@ -24,8 +23,6 @@ class CollectionManagementRequestTranslator(MessageTranslator): if obj.operation is not None: result["operation"] = obj.operation - if obj.workspace: - result["workspace"] = obj.workspace if obj.collection is not None: result["collection"] = obj.collection if obj.timestamp is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/flow.py b/trustgraph-base/trustgraph/messaging/translators/flow.py index 07304c18..1e3e19f6 100644 --- a/trustgraph-base/trustgraph/messaging/translators/flow.py +++ b/trustgraph-base/trustgraph/messaging/translators/flow.py @@ -9,7 +9,6 @@ class FlowRequestTranslator(MessageTranslator): def decode(self, data: Dict[str, Any]) -> FlowRequest: return FlowRequest( operation=data.get("operation"), - workspace=data.get("workspace", ""), blueprint_name=data.get("blueprint-name"), blueprint_definition=data.get("blueprint-definition"), description=data.get("description"), @@ -22,8 +21,6 @@ class FlowRequestTranslator(MessageTranslator): if obj.operation is not None: result["operation"] = obj.operation - if obj.workspace is not None: - result["workspace"] = obj.workspace if 
obj.blueprint_name is not None: result["blueprint-name"] = obj.blueprint_name if obj.blueprint_definition is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py index 83cdbbf4..f2cc8e46 100644 --- a/trustgraph-base/trustgraph/messaging/translators/knowledge.py +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -45,7 +45,6 @@ class KnowledgeRequestTranslator(MessageTranslator): return KnowledgeRequest( operation=data.get("operation"), - workspace=data.get("workspace", ""), id=data.get("id"), flow=data.get("flow"), collection=data.get("collection"), @@ -58,8 +57,6 @@ class KnowledgeRequestTranslator(MessageTranslator): if obj.operation: result["operation"] = obj.operation - if obj.workspace: - result["workspace"] = obj.workspace if obj.id: result["id"] = obj.id if obj.flow: diff --git a/trustgraph-base/trustgraph/messaging/translators/library.py b/trustgraph-base/trustgraph/messaging/translators/library.py index d528097e..afcb35b3 100644 --- a/trustgraph-base/trustgraph/messaging/translators/library.py +++ b/trustgraph-base/trustgraph/messaging/translators/library.py @@ -49,7 +49,6 @@ class LibraryRequestTranslator(MessageTranslator): document_metadata=doc_metadata, processing_metadata=proc_metadata, content=content, - workspace=data.get("workspace", ""), collection=data.get("collection", ""), criteria=criteria, # Chunked upload fields @@ -76,8 +75,6 @@ class LibraryRequestTranslator(MessageTranslator): result["processing-metadata"] = self.proc_metadata_translator.encode(obj.processing_metadata) if obj.content: result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content - if obj.workspace: - result["workspace"] = obj.workspace if obj.collection: result["collection"] = obj.collection if obj.criteria is not None: diff --git a/trustgraph-base/trustgraph/messaging/translators/metadata.py b/trustgraph-base/trustgraph/messaging/translators/metadata.py index 9da5d5c0..7d213376 100644 --- a/trustgraph-base/trustgraph/messaging/translators/metadata.py +++ b/trustgraph-base/trustgraph/messaging/translators/metadata.py @@ -19,7 +19,6 @@ class DocumentMetadataTranslator(Translator): title=data.get("title"), comments=data.get("comments"), metadata=self.subgraph_translator.decode(metadata) if metadata is not None else [], - workspace=data.get("workspace"), tags=data.get("tags"), parent_id=data.get("parent-id", ""), document_type=data.get("document-type", "source"), @@ -40,8 +39,6 @@ class DocumentMetadataTranslator(Translator): result["comments"] = obj.comments if obj.metadata is not None: result["metadata"] = self.subgraph_translator.encode(obj.metadata) - if obj.workspace: - result["workspace"] = obj.workspace if obj.tags is not None: result["tags"] = obj.tags if obj.parent_id: @@ -61,7 +58,6 @@ class ProcessingMetadataTranslator(Translator): document_id=data.get("document-id"), time=data.get("time"), flow=data.get("flow"), - workspace=data.get("workspace"), collection=data.get("collection"), tags=data.get("tags") ) @@ -77,8 +73,6 @@ class ProcessingMetadataTranslator(Translator): result["time"] = obj.time if obj.flow: result["flow"] = obj.flow - if obj.workspace: - result["workspace"] = obj.workspace if obj.collection: result["collection"] = obj.collection if obj.tags is not None: diff --git a/trustgraph-base/trustgraph/schema/core/metadata.py b/trustgraph-base/trustgraph/schema/core/metadata.py index a307db4f..b243993a 100644 --- 
a/trustgraph-base/trustgraph/schema/core/metadata.py +++ b/trustgraph-base/trustgraph/schema/core/metadata.py @@ -8,7 +8,5 @@ class Metadata: # Root document identifier (set by librarian, preserved through pipeline) root: str = "" - # Collection the message belongs to. Workspace is NOT carried on the - # message — consumers derive it from flow.workspace (the flow the - # message arrived on), which is the trusted isolation boundary. + # Collection the message belongs to. collection: str = "" diff --git a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py index 37969566..64cb7082 100644 --- a/trustgraph-base/trustgraph/schema/knowledge/knowledge.py +++ b/trustgraph-base/trustgraph/schema/knowledge/knowledge.py @@ -17,7 +17,7 @@ from .embeddings import GraphEmbeddings # <- (error) # list-kg-cores -# -> (workspace) +# -> () # <- () # <- (error) @@ -27,9 +27,6 @@ class KnowledgeRequest: # load-kg-core, unload-kg-core operation: str = "" - # Workspace the cores belong to. Partition / isolation boundary. - workspace: str = "" - # get-kg-core, list-kg-cores, delete-kg-core, put-kg-core, # load-kg-core, unload-kg-core id: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/collection.py b/trustgraph-base/trustgraph/schema/services/collection.py index 13dd0607..2c0ce786 100644 --- a/trustgraph-base/trustgraph/schema/services/collection.py +++ b/trustgraph-base/trustgraph/schema/services/collection.py @@ -22,17 +22,9 @@ class CollectionMetadata: @dataclass class CollectionManagementRequest: - """Request for collection management operations. - - Collection-management is a global (non-flow-scoped) service, so the - workspace has to travel on the wire — it's the isolation boundary - for which workspace's collections the request operates on. - """ + """Request for collection management operations.""" operation: str = "" # e.g., "delete-collection" - # Workspace the collection belongs to. - workspace: str = "" - collection: str = "" timestamp: str = "" # ISO timestamp name: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/config.py b/trustgraph-base/trustgraph/schema/services/config.py index 3bcbc72c..0bdae3a4 100644 --- a/trustgraph-base/trustgraph/schema/services/config.py +++ b/trustgraph-base/trustgraph/schema/services/config.py @@ -70,6 +70,11 @@ class ConfigResponse: # Everything error: Error | None = None +@dataclass +class WorkspaceChanges: + created: list[str] = field(default_factory=list) + deleted: list[str] = field(default_factory=list) + @dataclass class ConfigPush: version: int = 0 @@ -80,6 +85,10 @@ class ConfigPush: # e.g. {"prompt": ["workspace-a", "workspace-b"], "schema": ["workspace-a"]} changes: dict[str, list[str]] = field(default_factory=dict) + # Workspace lifecycle events. Populated when a workspace entry + # is created or deleted in the __workspaces__ config namespace. 
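+    # Illustrative sketch (hypothetical workspace ids): a push that
+    # creates "acme" and deletes "old-ws" in one notification carries
+    #   WorkspaceChanges(created=["acme"], deleted=["old-ws"])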
+ workspace_changes: WorkspaceChanges | None = None + config_request_queue = queue('config', cls='request') config_response_queue = queue('config', cls='response') config_push_queue = queue('config', cls='notify') diff --git a/trustgraph-base/trustgraph/schema/services/flow.py b/trustgraph-base/trustgraph/schema/services/flow.py index 586c160d..4f2805e5 100644 --- a/trustgraph-base/trustgraph/schema/services/flow.py +++ b/trustgraph-base/trustgraph/schema/services/flow.py @@ -22,9 +22,6 @@ class FlowRequest: operation: str = "" # list-blueprints, get-blueprint, put-blueprint, delete-blueprint # list-flows, get-flow, start-flow, stop-flow - # Workspace scope — all operations act within this workspace - workspace: str = "" - # get_blueprint, put_blueprint, delete_blueprint, start_flow blueprint_name: str = "" diff --git a/trustgraph-base/trustgraph/schema/services/library.py b/trustgraph-base/trustgraph/schema/services/library.py index 961b47dc..24e74883 100644 --- a/trustgraph-base/trustgraph/schema/services/library.py +++ b/trustgraph-base/trustgraph/schema/services/library.py @@ -43,12 +43,12 @@ from ..core.metadata import Metadata # <- (error) # list-documents -# -> (workspace, collection?) +# -> (collection?) # <- (document_metadata[]) # <- (error) # list-processing -# -> (workspace, collection?) +# -> (collection?) # <- (processing_metadata[]) # <- (error) @@ -78,7 +78,7 @@ from ..core.metadata import Metadata # <- (error) # list-uploads -# -> (workspace) +# -> () # <- (uploads[]) # <- (error) @@ -90,7 +90,6 @@ class DocumentMetadata: title: str = "" comments: str = "" metadata: list[Triple] = field(default_factory=list) - workspace: str = "" tags: list[str] = field(default_factory=list) # Child document support parent_id: str = "" # Empty for top-level docs, set for children @@ -107,7 +106,6 @@ class ProcessingMetadata: document_id: str = "" time: int = 0 flow: str = "" - workspace: str = "" collection: str = "" tags: list[str] = field(default_factory=list) @@ -162,9 +160,6 @@ class LibrarianRequest: # add-document, upload-chunk content: bytes = b"" - # Workspace scopes every library operation. - workspace: str = "" - # list-documents?, list-processing? 
collection: str = "" diff --git a/trustgraph-cli/trustgraph/cli/show_flow_state.py b/trustgraph-cli/trustgraph/cli/show_flow_state.py index 3a733270..dd911f3e 100644 --- a/trustgraph-cli/trustgraph/cli/show_flow_state.py +++ b/trustgraph-cli/trustgraph/cli/show_flow_state.py @@ -22,15 +22,15 @@ def dump_status(metrics_url, api_url, flow_id, token=None, print() print(f"Flow {flow_id}") - show_processors(metrics_url, flow_id) + show_processors(metrics_url, flow_id, token=token) print() print(f"Blueprint {blueprint_name}") - show_processors(metrics_url, blueprint_name) + show_processors(metrics_url, blueprint_name, token=token) print() -def show_processors(metrics_url, flow_label): +def show_processors(metrics_url, flow_label, token=None): url = f"{metrics_url}/query" @@ -40,7 +40,11 @@ def show_processors(metrics_url, flow_label): "query": "consumer_state{" + expr + "}" } - resp = requests.get(url, params=params) + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + resp = requests.get(url, params=params, headers=headers) obj = resp.json() diff --git a/trustgraph-cli/trustgraph/cli/show_processor_state.py b/trustgraph-cli/trustgraph/cli/show_processor_state.py index 9de05bc6..b6c10138 100644 --- a/trustgraph-cli/trustgraph/cli/show_processor_state.py +++ b/trustgraph-cli/trustgraph/cli/show_processor_state.py @@ -2,16 +2,22 @@ Dump out TrustGraph processor states. """ +import os import requests import argparse default_metrics_url = "http://localhost:8088/api/metrics" +DEFAULT_TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None) -def dump_status(url): +def dump_status(metrics_url, token=None): - url = f"{url}/query?query=processor_info" + url = f"{metrics_url}/query?query=processor_info" - resp = requests.get(url) + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + resp = requests.get(url, headers=headers) obj = resp.json() @@ -39,11 +45,17 @@ def main(): help=f'Metrics URL (default: {default_metrics_url})', ) + parser.add_argument( + '-t', '--token', + default=DEFAULT_TOKEN, + help=f'Bearer token for authentication (default: TRUSTGRAPH_TOKEN env var)', + ) + args = parser.parse_args() try: - dump_status(args.metrics_url) + dump_status(args.metrics_url, args.token) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_token_rate.py b/trustgraph-cli/trustgraph/cli/show_token_rate.py index 04e7dd6a..3fc1a8d6 100644 --- a/trustgraph-cli/trustgraph/cli/show_token_rate.py +++ b/trustgraph-cli/trustgraph/cli/show_token_rate.py @@ -3,12 +3,14 @@ Dump out a stream of token rates, input, output and total. This is averaged across the time since tg-show-token-rate is started. 
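+
+For example, against a gateway that requires a bearer token (the token
+falls back to the TRUSTGRAPH_TOKEN environment variable; illustrative
+value shown):
+
+    tg-show-token-rate -t "$TRUSTGRAPH_TOKEN"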
""" +import os import requests import argparse import json import time default_metrics_url = "http://localhost:8088/api/metrics" +DEFAULT_TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None) class Collate: @@ -36,16 +38,20 @@ class Collate: return delta/time, self.total/self.time -def dump_status(metrics_url, number_samples, period): +def dump_status(metrics_url, number_samples, period, token=None): input_url = f"{metrics_url}/query?query=input_tokens_total" output_url = f"{metrics_url}/query?query=output_tokens_total" - resp = requests.get(input_url) + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + resp = requests.get(input_url, headers=headers) obj = resp.json() input = Collate(obj) - resp = requests.get(output_url) + resp = requests.get(output_url, headers=headers) obj = resp.json() output = Collate(obj) @@ -56,20 +62,20 @@ def dump_status(metrics_url, number_samples, period): time.sleep(period) - resp = requests.get(input_url) + resp = requests.get(input_url, headers=headers) obj = resp.json() inr, inl = input.record(obj, period) - resp = requests.get(output_url) + resp = requests.get(output_url, headers=headers) obj = resp.json() outr, outl = output.record(obj, period) - + print(f"{inl:10.1f} {outl:10.1f} {inl+outl:10.1f}") def main(): parser = argparse.ArgumentParser( - prog='tg-show-processor-state', + prog='tg-show-token-rate', description=__doc__, ) @@ -93,6 +99,12 @@ def main(): help=f'Metrics period (default: 100)', ) + parser.add_argument( + '-t', '--token', + default=DEFAULT_TOKEN, + help=f'Bearer token for authentication (default: TRUSTGRAPH_TOKEN env var)', + ) + args = parser.parse_args() try: diff --git a/trustgraph-flow/trustgraph/bootstrap/base.py b/trustgraph-flow/trustgraph/bootstrap/base.py index cb022a16..108e441c 100644 --- a/trustgraph-flow/trustgraph/bootstrap/base.py +++ b/trustgraph-flow/trustgraph/bootstrap/base.py @@ -21,7 +21,7 @@ class InitContext: logger: logging.Logger config: Any # ConfigClient - flow: Any # RequestResponse client for flow-svc + make_flow_client: Any # callable(workspace) -> RequestResponse class Initialiser: diff --git a/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py index eb6238d3..7b63a1af 100644 --- a/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py +++ b/trustgraph-flow/trustgraph/bootstrap/bootstrapper/service.py @@ -178,13 +178,13 @@ class Processor(AsyncProcessor): ), ) - def _make_flow_client(self): + def _make_flow_client(self, workspace): rr_id = str(uuid.uuid4()) return RequestResponse( backend=self.pubsub_backend, subscription=f"{self.id}--flow--{rr_id}", consumer_name=self.id, - request_topic=flow_request_queue, + request_topic=f"{flow_request_queue}:{workspace}", request_schema=FlowRequest, request_metrics=ProducerMetrics( processor=self.id, flow=None, name="flow-request", @@ -198,14 +198,8 @@ class Processor(AsyncProcessor): async def _open_clients(self): config = self._make_config_client() - flow = self._make_flow_client() await config.start() - try: - await flow.start() - except Exception: - await self._safe_stop(config) - raise - return config, flow + return config async def _safe_stop(self, client): try: @@ -217,7 +211,14 @@ class Processor(AsyncProcessor): # Service gate. 
# ------------------------------------------------------------------ - async def _gate_ready(self, config, flow): + def _gate_workspace(self): + for spec in self.specs: + ws = getattr(spec.instance, "workspace", None) + if ws and not ws.startswith("_"): + return ws + return None + + async def _gate_ready(self, config): try: await config.keys(SYSTEM_WORKSPACE, INIT_STATE_TYPE) except Exception as e: @@ -226,11 +227,16 @@ class Processor(AsyncProcessor): ) return False + workspace = self._gate_workspace() + if workspace is None: + return True + + flow = self._make_flow_client(workspace) try: + await flow.start() resp = await flow.request( FlowRequest( operation="list-blueprints", - workspace=SYSTEM_WORKSPACE, ), timeout=5, ) @@ -245,6 +251,8 @@ class Processor(AsyncProcessor): f"Gate: flow-svc not ready ({type(e).__name__}: {e})" ) return False + finally: + await self._safe_stop(flow) return True @@ -271,7 +279,7 @@ class Processor(AsyncProcessor): # Per-spec execution. # ------------------------------------------------------------------ - async def _run_spec(self, spec, config, flow): + async def _run_spec(self, spec, config): """Run a single initialiser spec. Returns one of: @@ -298,7 +306,7 @@ class Processor(AsyncProcessor): child_ctx = InitContext( logger=child_logger, config=config, - flow=flow, + make_flow_client=self._make_flow_client, ) child_logger.info( @@ -340,7 +348,7 @@ class Processor(AsyncProcessor): sleep_for = STEADY_INTERVAL try: - config, flow = await self._open_clients() + config = await self._open_clients() except Exception as e: logger.info( f"Failed to open clients " @@ -358,11 +366,11 @@ class Processor(AsyncProcessor): pre_results = {} for spec in pre_specs: pre_results[spec.name] = await self._run_spec( - spec, config, flow, + spec, config, ) # Phase 2: gate. - gate_ok = await self._gate_ready(config, flow) + gate_ok = await self._gate_ready(config) # Phase 3: post-service initialisers, if gate passed. post_results = {} @@ -373,7 +381,7 @@ class Processor(AsyncProcessor): ] for spec in post_specs: post_results[spec.name] = await self._run_spec( - spec, config, flow, + spec, config, ) # Cadence selection. @@ -388,7 +396,6 @@ class Processor(AsyncProcessor): finally: await self._safe_stop(config) - await self._safe_stop(flow) await asyncio.sleep(sleep_for) diff --git a/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py b/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py index 7e7f96bd..96d13d28 100644 --- a/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py +++ b/trustgraph-flow/trustgraph/bootstrap/initialisers/default_flow_start.py @@ -49,53 +49,67 @@ class DefaultFlowStart(Initialiser): async def run(self, ctx, old_flag, new_flag): - # Check whether the flow already exists. Belt-and-braces - # beyond the flag gate: if an operator stops and restarts the - # bootstrapper after the flow is already running, we don't - # want to blindly try to start it again. 
- list_resp = await ctx.flow.request( - FlowRequest( - operation="list-flows", - workspace=self.workspace, - ), - timeout=10, + workspaces = await ctx.config.keys( + "__workspaces__", "workspace", ) - if list_resp.error: + if self.workspace not in workspaces: raise RuntimeError( - f"list-flows failed: " - f"{list_resp.error.type}: {list_resp.error.message}" + f"Workspace {self.workspace!r} does not exist yet" ) - if self.flow_id in (list_resp.flow_ids or []): + flow = ctx.make_flow_client(self.workspace) + await flow.start() + + try: + + # Check whether the flow already exists. Belt-and-braces + # beyond the flag gate: if an operator stops and restarts the + # bootstrapper after the flow is already running, we don't + # want to blindly try to start it again. + list_resp = await flow.request( + FlowRequest( + operation="list-flows", + ), + timeout=10, + ) + if list_resp.error: + raise RuntimeError( + f"list-flows failed: " + f"{list_resp.error.type}: {list_resp.error.message}" + ) + + if self.flow_id in (list_resp.flow_ids or []): + ctx.logger.info( + f"Flow {self.flow_id!r} already running in workspace " + f"{self.workspace!r}; nothing to do" + ) + return + ctx.logger.info( - f"Flow {self.flow_id!r} already running in workspace " - f"{self.workspace!r}; nothing to do" - ) - return - - ctx.logger.info( - f"Starting flow {self.flow_id!r} " - f"(blueprint={self.blueprint!r}) " - f"in workspace {self.workspace!r}" - ) - - resp = await ctx.flow.request( - FlowRequest( - operation="start-flow", - workspace=self.workspace, - flow_id=self.flow_id, - blueprint_name=self.blueprint, - description=self.description, - parameters=self.parameters, - ), - timeout=30, - ) - if resp.error: - raise RuntimeError( - f"start-flow failed: " - f"{resp.error.type}: {resp.error.message}" + f"Starting flow {self.flow_id!r} " + f"(blueprint={self.blueprint!r}) " + f"in workspace {self.workspace!r}" ) - ctx.logger.info( - f"Flow {self.flow_id!r} started" - ) + resp = await flow.request( + FlowRequest( + operation="start-flow", + flow_id=self.flow_id, + blueprint_name=self.blueprint, + description=self.description, + parameters=self.parameters, + ), + timeout=30, + ) + if resp.error: + raise RuntimeError( + f"start-flow failed: " + f"{resp.error.type}: {resp.error.message}" + ) + + ctx.logger.info( + f"Flow {self.flow_id!r} started" + ) + + finally: + await flow.stop() diff --git a/trustgraph-flow/trustgraph/config/service/config.py b/trustgraph-flow/trustgraph/config/service/config.py index 36af6026..ced4cbe7 100644 --- a/trustgraph-flow/trustgraph/config/service/config.py +++ b/trustgraph-flow/trustgraph/config/service/config.py @@ -2,13 +2,17 @@ import logging from trustgraph.schema import ConfigResponse -from trustgraph.schema import ConfigValue, Error +from trustgraph.schema import ConfigValue, WorkspaceChanges, Error from ... 
tables.config import ConfigTableStore # Module logger logger = logging.getLogger(__name__) +WORKSPACES_NAMESPACE = "__workspaces__" +WORKSPACE_TYPE = "workspace" +TEMPLATE_WORKSPACE = "__template__" + class Configuration: def __init__(self, push, host, username, password, keyspace): @@ -26,9 +30,7 @@ class Configuration: async def get_version(self): return await self.table_store.get_version() - async def handle_get(self, v): - - workspace = v.workspace + async def handle_get(self, v, workspace): values = [ ConfigValue( @@ -46,18 +48,18 @@ class Configuration: values = values, ) - async def handle_list(self, v): + async def handle_list(self, v, workspace): return ConfigResponse( version = await self.get_version(), directory = await self.table_store.get_keys( - v.workspace, v.type + workspace, v.type ), ) - async def handle_getvalues(self, v): + async def handle_getvalues(self, v, workspace): - vals = await self.table_store.get_values(v.workspace, v.type) + vals = await self.table_store.get_values(workspace, v.type) values = map( lambda x: ConfigValue( @@ -93,9 +95,8 @@ class Configuration: values = values, ) - async def handle_delete(self, v): + async def handle_delete(self, v, workspace): - workspace = v.workspace types = list(set(k.type for k in v.keys)) for k in v.keys: @@ -103,14 +104,22 @@ class Configuration: await self.inc_version() - await self.push(changes={t: [workspace] for t in types}) + workspace_changes = None + if workspace == WORKSPACES_NAMESPACE and WORKSPACE_TYPE in types: + deleted = [k.key for k in v.keys if k.type == WORKSPACE_TYPE] + if deleted: + workspace_changes = WorkspaceChanges(deleted=deleted) + + await self.push( + changes={t: [workspace] for t in types}, + workspace_changes=workspace_changes, + ) return ConfigResponse( ) - async def handle_put(self, v): + async def handle_put(self, v, workspace): - workspace = v.workspace types = list(set(k.type for k in v.values)) for k in v.values: @@ -120,11 +129,49 @@ class Configuration: await self.inc_version() - await self.push(changes={t: [workspace] for t in types}) + workspace_changes = None + if workspace == WORKSPACES_NAMESPACE and WORKSPACE_TYPE in types: + created = [k.key for k in v.values if k.type == WORKSPACE_TYPE] + if created: + workspace_changes = WorkspaceChanges(created=created) + + await self.push( + changes={t: [workspace] for t in types}, + workspace_changes=workspace_changes, + ) return ConfigResponse( ) + async def provision_from_template(self, workspace): + """Copy all config from __template__ into a new workspace, + skipping keys that already exist (upsert-missing).""" + + template = await self.get_config(TEMPLATE_WORKSPACE) + + if not template: + logger.info( + f"No template config to provision for {workspace}" + ) + return 0 + + existing_types = await self.get_config(workspace) + + written = 0 + for type_name, entries in template.items(): + existing_keys = set(existing_types.get(type_name, {}).keys()) + for key, value in entries.items(): + if key not in existing_keys: + await self.table_store.put_config( + workspace, type_name, key, value + ) + written += 1 + + if written > 0: + await self.inc_version() + + return written + async def get_config(self, workspace): table = await self.table_store.get_all_for_workspace(workspace) @@ -138,62 +185,87 @@ class Configuration: return config - async def handle_config(self, v): + async def handle_config(self, v, workspace): - config = await self.get_config(v.workspace) + config = await self.get_config(workspace) return ConfigResponse( version = await 
self.get_version(),
             config = config,
         )
 
-    async def handle(self, msg):
+    async def handle_workspace(self, msg, workspace):
+        """Handle workspace-scoped config operations.
+        Workspace is provided by queue infrastructure."""
 
         logger.debug(
-            f"Handling config message: {msg.operation} "
-            f"workspace={msg.workspace}"
+            f"Handling workspace config message: {msg.operation} "
+            f"workspace={workspace}"
         )
 
-        # getvalues-all-ws spans all workspaces, so no workspace
-        # required; everything else is workspace-scoped.
-        if msg.operation != "getvalues-all-ws" and not msg.workspace:
-            return ConfigResponse(
-                error=Error(
-                    type = "bad-request",
-                    message = "Workspace is required"
-                )
-            )
-
         if msg.operation == "get":
-
-            resp = await self.handle_get(msg)
+            resp = await self.handle_get(msg, workspace)
 
         elif msg.operation == "list":
-
-            resp = await self.handle_list(msg)
+            resp = await self.handle_list(msg, workspace)
 
         elif msg.operation == "getvalues":
-
-            resp = await self.handle_getvalues(msg)
-
-        elif msg.operation == "getvalues-all-ws":
-
-            resp = await self.handle_getvalues_all_ws(msg)
+            resp = await self.handle_getvalues(msg, workspace)
 
         elif msg.operation == "delete":
-
-            resp = await self.handle_delete(msg)
+            resp = await self.handle_delete(msg, workspace)
 
         elif msg.operation == "put":
-
-            resp = await self.handle_put(msg)
+            resp = await self.handle_put(msg, workspace)
 
         elif msg.operation == "config":
-
-            resp = await self.handle_config(msg)
+            resp = await self.handle_config(msg, workspace)
+
+        else:
+            resp = ConfigResponse(
+                error=Error(
+                    type = "bad-operation",
+                    message = "Bad operation"
+                )
+            )
+
+        return resp
+
+    async def handle_system(self, msg):
+        """Handle system-level config operations.
+        Workspace, when needed, comes from message body."""
+
+        logger.debug(
+            f"Handling system config message: {msg.operation} "
+            f"workspace={msg.workspace}"
+        )
+
+        if msg.operation == "getvalues-all-ws":
+            resp = await self.handle_getvalues_all_ws(msg)
+
+        elif msg.operation in ("get", "list", "getvalues", "delete",
+                               "put", "config"):
+
+            if not msg.workspace:
+                return ConfigResponse(
+                    error=Error(
+                        type = "bad-request",
+                        message = "Workspace is required"
+                    )
+                )
+
+            handler = {
+                "get": self.handle_get,
+                "list": self.handle_list,
+                "getvalues": self.handle_getvalues,
+                "delete": self.handle_delete,
+                "put": self.handle_put,
+                "config": self.handle_config,
+            }[msg.operation]
+
+            resp = await handler(msg, msg.workspace)
 
         else:
-
             resp = ConfigResponse(
                 error=Error(
                     type = "bad-operation",
diff --git a/trustgraph-flow/trustgraph/config/service/service.py b/trustgraph-flow/trustgraph/config/service/service.py
index 058f4e4b..36b368eb 100644
--- a/trustgraph-flow/trustgraph/config/service/service.py
+++ b/trustgraph-flow/trustgraph/config/service/service.py
@@ -1,20 +1,30 @@
 """
-Config service. Manages system global configuration state
+Config service. Manages system global configuration state.
+
+Operates a dual-queue regime:
+- System queue (config-request): handles cross-workspace operations like
+  getvalues-all-ws and bootstrapper put/delete on __workspaces__.
+  The gateway NEVER routes to this queue.
+- Per-workspace queues (config-request:<workspace>): handles
+  workspace-scoped operations where workspace identity comes from
+  queue infrastructure, not message body.
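+
+For example (illustrative workspace id "acme"): a put scoped to
+workspace "acme" is consumed from config-request:acme, while a
+getvalues-all-ws arrives only on the bare config-request system queue.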
""" import logging +from functools import partial from trustgraph.schema import Error from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush +from trustgraph.schema import WorkspaceChanges from trustgraph.schema import config_request_queue, config_response_queue from trustgraph.schema import config_push_queue from trustgraph.base import AsyncProcessor, Consumer, Producer from trustgraph.base.cassandra_config import add_cassandra_args, resolve_cassandra_config -from . config import Configuration +from . config import Configuration, WORKSPACES_NAMESPACE, WORKSPACE_TYPE from ... base import ProcessorMetrics, ConsumerMetrics, ProducerMetrics from ... base import Consumer, Producer @@ -39,6 +49,11 @@ def is_reserved_workspace(workspace): """ return workspace.startswith("_") + +def workspace_queue(base_queue, workspace): + return f"{base_queue}:{workspace}" + + default_config_request_queue = config_request_queue default_config_response_queue = config_response_queue default_config_push_queue = config_push_queue @@ -48,7 +63,7 @@ default_cassandra_host = "cassandra" class Processor(AsyncProcessor): def __init__(self, **params): - + config_request_queue = params.get( "config_request_queue", default_config_request_queue ) @@ -70,7 +85,7 @@ class Processor(AsyncProcessor): password=cassandra_password, default_keyspace="config" ) - + # Store resolved configuration self.cassandra_host = hosts self.cassandra_username = username @@ -99,17 +114,17 @@ class Processor(AsyncProcessor): processor = self.id, flow = None, name = "config-push" ) - self.config_request_topic = config_request_queue + self.config_request_queue_base = config_request_queue self.config_request_subscriber = id - self.config_request_consumer = Consumer( + self.system_consumer = Consumer( taskgroup = self.taskgroup, backend = self.pubsub, flow = None, topic = config_request_queue, subscriber = id, schema = ConfigRequest, - handler = self.on_config_request, + handler = self.on_system_config_request, metrics = config_request_metrics, ) @@ -135,20 +150,120 @@ class Processor(AsyncProcessor): push = self.push ) + self.workspace_consumers = {} + + self.register_workspace_handler(self._handle_workspace_changes) + logger.info("Config service initialized") + async def _discover_workspaces(self): + logger.info("Discovering workspaces from Cassandra...") + try: + workspaces = await self.config.table_store.get_keys( + WORKSPACES_NAMESPACE, WORKSPACE_TYPE + ) + logger.info(f"Discovered workspaces: {workspaces}") + except Exception as e: + logger.error( + f"Workspace discovery failed: {e}", exc_info=True + ) + return + + for workspace_id in workspaces: + if workspace_id not in self.workspace_consumers: + await self._add_workspace_consumer(workspace_id) + + async def _handle_workspace_changes(self, workspace_changes): + for workspace_id in workspace_changes.created: + if workspace_id not in self.workspace_consumers: + logger.info(f"Workspace created: {workspace_id}") + await self._add_workspace_consumer(workspace_id) + await self._provision_workspace(workspace_id) + + for workspace_id in workspace_changes.deleted: + if workspace_id in self.workspace_consumers: + logger.info(f"Workspace deleted: {workspace_id}") + await self._remove_workspace_consumer(workspace_id) + + async def _provision_workspace(self, workspace_id): + try: + written = await self.config.provision_from_template( + workspace_id + ) + if written > 0: + logger.info( + f"Provisioned workspace {workspace_id} with " + f"{written} entries from template" + ) + # Notify other 
services about the new config + types = {} + template = await self.config.get_config(workspace_id) + for t in template: + types[t] = [workspace_id] + await self.push(changes=types) + except Exception as e: + logger.error( + f"Failed to provision workspace {workspace_id}: {e}", + exc_info=True, + ) + + async def _add_workspace_consumer(self, workspace_id): + queue = workspace_queue( + self.config_request_queue_base, workspace_id, + ) + + await self.pubsub.ensure_topic(queue) + + consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=queue, + subscriber=self.id, + schema=ConfigRequest, + handler=partial( + self.on_workspace_config_request, + workspace=workspace_id, + ), + metrics=ConsumerMetrics( + processor=self.id, flow=None, + name=f"config-request-{workspace_id}", + ), + ) + + await consumer.start() + self.workspace_consumers[workspace_id] = consumer + + logger.info( + f"Subscribed to workspace config queue: {workspace_id}" + ) + + async def _remove_workspace_consumer(self, workspace_id): + consumer = self.workspace_consumers.pop(workspace_id, None) + if consumer: + await consumer.stop() + logger.info( + f"Unsubscribed from workspace config queue: {workspace_id}" + ) + async def start(self): - await self.pubsub.ensure_topic(self.config_request_topic) + await self.pubsub.ensure_topic(self.config_request_queue_base) await self.push() # Startup poke: empty types = everything - await self.config_request_consumer.start() + await self.system_consumer.start() - async def push(self, changes=None): + # Start the config push subscriber so we receive our own + # workspace change notifications. + await self.config_sub_task.start() + + await self._discover_workspaces() + + async def push(self, changes=None, workspace_changes=None): # Suppress notifications from reserved workspaces (ids starting - # with "_", e.g. "__template__"). Stored config is preserved; - # only the broadcast is filtered. Keeps services oblivious to - # template / bootstrap state. + # with "_", e.g. "__template__") for regular config changes. + # The __workspaces__ namespace is handled separately via + # workspace_changes. if changes: filtered = {} for type_name, workspaces in changes.items(): @@ -165,16 +280,20 @@ class Processor(AsyncProcessor): resp = ConfigPush( version = version, changes = changes or {}, + workspace_changes = workspace_changes, ) await self.config_push_producer.send(resp) logger.info( f"Pushed config poke version {version}, " - f"changes={resp.changes}" + f"changes={resp.changes}, " + f"workspace_changes={resp.workspace_changes}" ) - - async def on_config_request(self, msg, consumer, flow): + + async def on_workspace_config_request( + self, msg, consumer, flow, *, workspace + ): try: @@ -183,16 +302,49 @@ class Processor(AsyncProcessor): # Sender-produced ID id = msg.properties()["id"] - logger.debug(f"Handling config request {id}...") + logger.debug( + f"Handling workspace config request {id} " + f"workspace={workspace}..." 
+ ) - resp = await self.config.handle(v) + resp = await self.config.handle_workspace(v, workspace) await self.config_response_producer.send( resp, properties={"id": id} ) except Exception as e: - + + resp = ConfigResponse( + error=Error( + type = "config-error", + message = str(e), + ), + ) + + await self.config_response_producer.send( + resp, properties={"id": id} + ) + + async def on_system_config_request(self, msg, consumer, flow): + + try: + + v = msg.value() + + # Sender-produced ID + id = msg.properties()["id"] + + logger.debug(f"Handling system config request {id}...") + + resp = await self.config.handle_system(v) + + await self.config_response_producer.send( + resp, properties={"id": id} + ) + + except Exception as e: + resp = ConfigResponse( error=Error( type = "config-error", @@ -228,4 +380,3 @@ class Processor(AsyncProcessor): def run(): Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-flow/trustgraph/cores/knowledge.py b/trustgraph-flow/trustgraph/cores/knowledge.py index ab5f78f0..c3ecfe96 100644 --- a/trustgraph-flow/trustgraph/cores/knowledge.py +++ b/trustgraph-flow/trustgraph/cores/knowledge.py @@ -28,12 +28,12 @@ class KnowledgeManager: self.background_task = None self.flow_config = flow_config - async def delete_kg_core(self, request, respond): + async def delete_kg_core(self, request, respond, workspace): logger.info("Deleting knowledge core...") await self.table_store.delete_kg_core( - request.workspace, request.id + workspace, request.id ) await respond( @@ -46,7 +46,7 @@ class KnowledgeManager: ) ) - async def get_kg_core(self, request, respond): + async def get_kg_core(self, request, respond, workspace): logger.info("Getting knowledge core...") @@ -61,9 +61,8 @@ class KnowledgeManager: ) ) - # Remove doc table row await self.table_store.get_triples( - request.workspace, + workspace, request.id, publish_triples, ) @@ -79,9 +78,8 @@ class KnowledgeManager: ) ) - # Remove doc table row await self.table_store.get_graph_embeddings( - request.workspace, + workspace, request.id, publish_ge, ) @@ -98,9 +96,9 @@ class KnowledgeManager: ) ) - async def list_kg_cores(self, request, respond): + async def list_kg_cores(self, request, respond, workspace): - ids = await self.table_store.list_kg_cores(request.workspace) + ids = await self.table_store.list_kg_cores(workspace) await respond( KnowledgeResponse( @@ -112,9 +110,7 @@ class KnowledgeManager: ) ) - async def put_kg_core(self, request, respond): - - workspace = request.workspace + async def put_kg_core(self, request, respond, workspace): if request.triples: await self.table_store.add_triples(workspace, request.triples) @@ -134,20 +130,18 @@ class KnowledgeManager: ) ) - async def load_kg_core(self, request, respond): + async def load_kg_core(self, request, respond, workspace): if self.background_task is None: self.background_task = asyncio.create_task( self.core_loader() ) - # Wait for it to start (yuck) -# await asyncio.sleep(0.5) - await self.loader_queue.put((request, respond)) + await self.loader_queue.put((request, respond, workspace)) # Not sending a response, the loader thread can do that - async def unload_kg_core(self, request, respond): + async def unload_kg_core(self, request, respond, workspace): await respond( KnowledgeResponse( @@ -168,7 +162,7 @@ class KnowledgeManager: while True: logger.debug("Waiting for next load...") - request, respond = await self.loader_queue.get() + request, respond, workspace = await self.loader_queue.get() logger.info(f"Loading knowledge: {request.id}") @@ -180,7 
+174,6 @@ class KnowledgeManager:
             if request.flow is None:
                 raise RuntimeError("Flow ID must be specified")
 
-            workspace = request.workspace
             ws_flows = self.flow_config.flows.get(workspace, {})
             if request.flow not in ws_flows:
                 raise RuntimeError(
@@ -262,9 +255,8 @@ class KnowledgeManager:
 
             logger.debug("Publishing triples...")
 
-            # Remove doc table row
             await self.table_store.get_triples(
-                request.workspace,
+                workspace,
                 request.id,
                 publish_triples,
             )
@@ -277,9 +269,8 @@ class KnowledgeManager:
 
             logger.debug("Publishing graph embeddings...")
 
-            # Remove doc table row
             await self.table_store.get_graph_embeddings(
-                request.workspace,
+                workspace,
                 request.id,
                 publish_ge,
             )
diff --git a/trustgraph-flow/trustgraph/cores/service.py b/trustgraph-flow/trustgraph/cores/service.py
index 15e8feb6..ac2f08cb 100755
--- a/trustgraph-flow/trustgraph/cores/service.py
+++ b/trustgraph-flow/trustgraph/cores/service.py
@@ -9,7 +9,8 @@
 import base64
 import json
 import logging
+from functools import partial
 
-from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber
+from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
 from .. base import ConsumerMetrics, ProducerMetrics
 from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
@@ -33,13 +33,18 @@
 default_knowledge_response_queue = knowledge_response_queue
 
 default_cassandra_host = "cassandra"
 
-class Processor(AsyncProcessor):
+
+def workspace_queue(base_queue, workspace):
+    return f"{base_queue}:{workspace}"
+
+
+class Processor(WorkspaceProcessor):
 
     def __init__(self, **params):
 
         id = params.get("id")
 
-        knowledge_request_queue = params.get(
+        self.knowledge_request_queue_base = params.get(
            "knowledge_request_queue", default_knowledge_request_queue
         )
 
@@ -51,7 +56,6 @@
         cassandra_username = params.get("cassandra_username")
         cassandra_password = params.get("cassandra_password")
 
-        # Resolve configuration with environment variable fallback
         hosts, username, password, keyspace = resolve_cassandra_config(
             host=cassandra_host,
             username=cassandra_username,
@@ -59,14 +63,13 @@
             default_keyspace="knowledge"
         )
 
-        # Store resolved configuration
         self.cassandra_host = hosts
         self.cassandra_username = username
         self.cassandra_password = password
 
         super(Processor, self).__init__(
             **params | {
-                "knowledge_request_queue": knowledge_request_queue,
+                "knowledge_request_queue": self.knowledge_request_queue_base,
                 "knowledge_response_queue": knowledge_response_queue,
                 "cassandra_host": self.cassandra_host,
                 "cassandra_username": self.cassandra_username,
@@ -74,28 +77,10 @@
             }
         )
 
-        knowledge_request_metrics = ConsumerMetrics(
-            processor = self.id, flow = None, name = "knowledge-request"
-        )
-
         knowledge_response_metrics = ProducerMetrics(
             processor = self.id, flow = None, name = "knowledge-response"
         )
 
-        self.knowledge_request_topic = knowledge_request_queue
-        self.knowledge_request_subscriber = id
-
-        self.knowledge_request_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = knowledge_request_queue,
-            subscriber = id,
-            schema = KnowledgeRequest,
-            handler = self.on_knowledge_request,
-            metrics = knowledge_request_metrics,
-        )
-
         self.knowledge_response_producer = Producer(
             backend = self.pubsub,
             topic = knowledge_response_queue,
@@ -115,13 +100,52 @@
 
         self.flows = {}
 
+        self.workspace_consumers = {}
+
         logger.info("Knowledge service initialized")
 
+    async def on_workspace_created(self, 
workspace): + + if workspace in self.workspace_consumers: + return + + queue = workspace_queue( + self.knowledge_request_queue_base, workspace, + ) + + await self.pubsub.ensure_topic(queue) + + consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=queue, + subscriber=self.id, + schema=KnowledgeRequest, + handler=partial( + self.on_knowledge_request, workspace=workspace, + ), + metrics=ConsumerMetrics( + processor=self.id, flow=None, + name=f"knowledge-request-{workspace}", + ), + ) + + await consumer.start() + self.workspace_consumers[workspace] = consumer + + logger.info(f"Subscribed to workspace queue: {workspace}") + + async def on_workspace_deleted(self, workspace): + + consumer = self.workspace_consumers.pop(workspace, None) + if consumer: + await consumer.stop() + logger.info(f"Unsubscribed from workspace queue: {workspace}") + async def start(self): - await self.pubsub.ensure_topic(self.knowledge_request_topic) await super(Processor, self).start() - await self.knowledge_request_consumer.start() await self.knowledge_response_producer.start() async def on_knowledge_config(self, workspace, config, version): @@ -140,7 +164,7 @@ class Processor(AsyncProcessor): logger.debug(f"Flows for {workspace}: {self.flows[workspace]}") - async def process_request(self, v, id): + async def process_request(self, v, id, workspace): if v.operation is None: raise RequestError("Null operation") @@ -163,9 +187,9 @@ class Processor(AsyncProcessor): await self.knowledge_response_producer.send( x, { "id": id } ) - return await impls[v.operation](v, respond) + return await impls[v.operation](v, respond, workspace) - async def on_knowledge_request(self, msg, consumer, flow): + async def on_knowledge_request(self, msg, consumer, flow, *, workspace): v = msg.value() @@ -179,7 +203,7 @@ class Processor(AsyncProcessor): # We don't send a response back here, the processing # implementation sends whatever it needs to send. 
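+            # Illustrative sketch (hypothetical workspace id): a
+            # list-kg-cores request consumed from the "acme" workspace
+            # queue is dispatched as list_kg_cores(v, respond, "acme").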
- await self.process_request(v, id) + await self.process_request(v, id, workspace) return @@ -215,7 +239,7 @@ class Processor(AsyncProcessor): @staticmethod def add_args(parser): - AsyncProcessor.add_args(parser) + WorkspaceProcessor.add_args(parser) parser.add_argument( '--knowledge-request-queue', diff --git a/trustgraph-flow/trustgraph/flow/service/flow.py b/trustgraph-flow/trustgraph/flow/service/flow.py index ed0158f6..5e440b11 100644 --- a/trustgraph-flow/trustgraph/flow/service/flow.py +++ b/trustgraph-flow/trustgraph/flow/service/flow.py @@ -118,10 +118,10 @@ class FlowConfig: return resolved - async def handle_list_blueprints(self, msg): + async def handle_list_blueprints(self, msg, workspace): names = list(await self.config.keys( - msg.workspace, "flow-blueprint" + workspace, "flow-blueprint" )) return FlowResponse( @@ -129,19 +129,19 @@ class FlowConfig: blueprint_names = names, ) - async def handle_get_blueprint(self, msg): + async def handle_get_blueprint(self, msg, workspace): return FlowResponse( error = None, blueprint_definition = await self.config.get( - msg.workspace, "flow-blueprint", msg.blueprint_name + workspace, "flow-blueprint", msg.blueprint_name ), ) - async def handle_put_blueprint(self, msg): + async def handle_put_blueprint(self, msg, workspace): await self.config.put( - msg.workspace, "flow-blueprint", + workspace, "flow-blueprint", msg.blueprint_name, msg.blueprint_definition ) @@ -149,31 +149,31 @@ class FlowConfig: error = None, ) - async def handle_delete_blueprint(self, msg): + async def handle_delete_blueprint(self, msg, workspace): logger.debug(f"Flow config message: {msg}") await self.config.delete( - msg.workspace, "flow-blueprint", msg.blueprint_name + workspace, "flow-blueprint", msg.blueprint_name ) return FlowResponse( error = None, ) - async def handle_list_flows(self, msg): + async def handle_list_flows(self, msg, workspace): - names = list(await self.config.keys(msg.workspace, "flow")) + names = list(await self.config.keys(workspace, "flow")) return FlowResponse( error = None, flow_ids = names, ) - async def handle_get_flow(self, msg): + async def handle_get_flow(self, msg, workspace): flow_data = await self.config.get( - msg.workspace, "flow", msg.flow_id + workspace, "flow", msg.flow_id ) flow = json.loads(flow_data) @@ -184,9 +184,7 @@ class FlowConfig: parameters = flow.get("parameters", {}), ) - async def handle_start_flow(self, msg): - - workspace = msg.workspace + async def handle_start_flow(self, msg, workspace): if msg.blueprint_name is None: raise RuntimeError("No blueprint name") @@ -222,7 +220,7 @@ class FlowConfig: logger.debug(f"Resolved parameters (with defaults): {parameters}") # Apply parameter substitution to template replacement function. - # {workspace} is substituted from msg.workspace to isolate + # {workspace} is substituted from workspace to isolate # queue names across workspaces. 
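+        # Illustrative sketch (hypothetical template name): a blueprint
+        # queue template "triples-store:{workspace}" resolves to
+        # "triples-store:acme" when the flow starts in workspace "acme".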
def repl_template_with_params(tmp): @@ -548,9 +546,7 @@ class FlowConfig: f"attempts: {topic}" ) - async def handle_stop_flow(self, msg): - - workspace = msg.workspace + async def handle_stop_flow(self, msg, workspace): if msg.flow_id is None: raise RuntimeError("No flow ID") @@ -641,37 +637,29 @@ class FlowConfig: error = None, ) - async def handle(self, msg): + async def handle(self, msg, workspace): logger.debug( f"Handling flow message: {msg.operation} " - f"workspace={msg.workspace}" + f"workspace={workspace}" ) - if not msg.workspace: - return FlowResponse( - error=Error( - type="bad-request", - message="Workspace is required", - ), - ) - if msg.operation == "list-blueprints": - resp = await self.handle_list_blueprints(msg) + resp = await self.handle_list_blueprints(msg, workspace) elif msg.operation == "get-blueprint": - resp = await self.handle_get_blueprint(msg) + resp = await self.handle_get_blueprint(msg, workspace) elif msg.operation == "put-blueprint": - resp = await self.handle_put_blueprint(msg) + resp = await self.handle_put_blueprint(msg, workspace) elif msg.operation == "delete-blueprint": - resp = await self.handle_delete_blueprint(msg) + resp = await self.handle_delete_blueprint(msg, workspace) elif msg.operation == "list-flows": - resp = await self.handle_list_flows(msg) + resp = await self.handle_list_flows(msg, workspace) elif msg.operation == "get-flow": - resp = await self.handle_get_flow(msg) + resp = await self.handle_get_flow(msg, workspace) elif msg.operation == "start-flow": - resp = await self.handle_start_flow(msg) + resp = await self.handle_start_flow(msg, workspace) elif msg.operation == "stop-flow": - resp = await self.handle_stop_flow(msg) + resp = await self.handle_stop_flow(msg, workspace) else: resp = FlowResponse( diff --git a/trustgraph-flow/trustgraph/flow/service/service.py b/trustgraph-flow/trustgraph/flow/service/service.py index 3dacf47d..5adcc962 100644 --- a/trustgraph-flow/trustgraph/flow/service/service.py +++ b/trustgraph-flow/trustgraph/flow/service/service.py @@ -4,6 +4,7 @@ Flow service. Manages flow lifecycle — starting and stopping flows by coordinating with the config service via pub/sub. 
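+
+Requests are consumed from per-workspace queues. For example (assuming
+the flow_request_queue base name is "flow-request"), a start-flow for
+workspace "acme" arrives on flow-request:acme.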
""" +from functools import partial import logging import uuid @@ -14,7 +15,7 @@ from trustgraph.schema import flow_request_queue, flow_response_queue from trustgraph.schema import ConfigRequest, ConfigResponse from trustgraph.schema import config_request_queue, config_response_queue -from trustgraph.base import AsyncProcessor, Consumer, Producer +from trustgraph.base import WorkspaceProcessor, Consumer, Producer from trustgraph.base import ConsumerMetrics, ProducerMetrics, SubscriberMetrics from trustgraph.base import ConfigClient @@ -29,11 +30,15 @@ default_flow_request_queue = flow_request_queue default_flow_response_queue = flow_response_queue -class Processor(AsyncProcessor): +def workspace_queue(base_queue, workspace): + return f"{base_queue}:{workspace}" + + +class Processor(WorkspaceProcessor): def __init__(self, **params): - flow_request_queue = params.get( + self.flow_request_queue_base = params.get( "flow_request_queue", default_flow_request_queue ) flow_response_queue = params.get( @@ -49,27 +54,10 @@ class Processor(AsyncProcessor): } ) - flow_request_metrics = ConsumerMetrics( - processor = self.id, flow = None, name = "flow-request" - ) flow_response_metrics = ProducerMetrics( processor = self.id, flow = None, name = "flow-response" ) - self.flow_request_topic = flow_request_queue - self.flow_request_subscriber = id - - self.flow_request_consumer = Consumer( - taskgroup = self.taskgroup, - backend = self.pubsub, - flow = None, - topic = flow_request_queue, - subscriber = id, - schema = FlowRequest, - handler = self.on_flow_request, - metrics = flow_request_metrics, - ) - self.flow_response_producer = Producer( backend = self.pubsub, topic = flow_response_queue, @@ -84,13 +72,6 @@ class Processor(AsyncProcessor): processor=self.id, flow=None, name="config-response", ) - # Unique subscription suffix per process instance. Pulsar's - # exclusive subscriptions reject a second consumer on the same - # (topic, subscription-name) — so a deterministic name here - # collides with its own ghost when the supervisor restarts the - # process before Pulsar has timed out the previous session - # (ConsumerBusy). Matches the uuid convention used elsewhere - # (gateway/config/receiver.py, AsyncProcessor._create_config_client). 
config_rr_id = str(uuid.uuid4()) self.config_client = ConfigClient( backend=self.pubsub, @@ -106,21 +87,58 @@ class Processor(AsyncProcessor): self.flow = FlowConfig(self.config_client, self.pubsub) + self.workspace_consumers = {} + logger.info("Flow service initialized") + async def on_workspace_created(self, workspace): + + if workspace in self.workspace_consumers: + return + + queue = workspace_queue( + self.flow_request_queue_base, workspace, + ) + + await self.pubsub.ensure_topic(queue) + + consumer = Consumer( + taskgroup=self.taskgroup, + backend=self.pubsub, + flow=None, + topic=queue, + subscriber=self.id, + schema=FlowRequest, + handler=partial( + self.on_flow_request, workspace=workspace, + ), + metrics=ConsumerMetrics( + processor=self.id, flow=None, + name=f"flow-request-{workspace}", + ), + ) + + await consumer.start() + self.workspace_consumers[workspace] = consumer + + logger.info(f"Subscribed to workspace queue: {workspace}") + + async def on_workspace_deleted(self, workspace): + + consumer = self.workspace_consumers.pop(workspace, None) + if consumer: + await consumer.stop() + logger.info(f"Unsubscribed from workspace queue: {workspace}") + async def start(self): - await self.pubsub.ensure_topic(self.flow_request_topic) + await super(Processor, self).start() await self.config_client.start() - # Discover workspaces with existing flow config and ensure - # their topics exist before we start accepting requests. workspaces = await self.config_client.workspaces_for_type("flow") await self.flow.ensure_existing_flow_topics(workspaces) - await self.flow_request_consumer.start() - - async def on_flow_request(self, msg, consumer, flow): + async def on_flow_request(self, msg, consumer, flow, *, workspace): try: @@ -131,7 +149,7 @@ class Processor(AsyncProcessor): logger.debug(f"Handling flow request {id}...") - resp = await self.flow.handle(v) + resp = await self.flow.handle(v, workspace) await self.flow_response_producer.send( resp, properties={"id": id} @@ -155,7 +173,7 @@ class Processor(AsyncProcessor): @staticmethod def add_args(parser): - AsyncProcessor.add_args(parser) + WorkspaceProcessor.add_args(parser) parser.add_argument( '--flow-request-queue', diff --git a/trustgraph-flow/trustgraph/gateway/auth.py b/trustgraph-flow/trustgraph/gateway/auth.py index 6abcbe15..1309ecfc 100644 --- a/trustgraph-flow/trustgraph/gateway/auth.py +++ b/trustgraph-flow/trustgraph/gateway/auth.py @@ -141,6 +141,12 @@ class IamAuth: self._authz_cache: dict[str, tuple[bool, float]] = {} self._authz_cache_lock = asyncio.Lock() + # Known workspaces, maintained by the config receiver. + # enforce_workspace checks this set to reject requests for + # non-existent workspaces before routing to a queue that + # has no consumer. + self.known_workspaces: set[str] = set() + # ------------------------------------------------------------------ # Short-lived client helper. 
Mirrors the pattern used by the # bootstrap framework and AsyncProcessor: a fresh uuid suffix per diff --git a/trustgraph-flow/trustgraph/gateway/capabilities.py b/trustgraph-flow/trustgraph/gateway/capabilities.py index 72ca51c7..dbbb01e0 100644 --- a/trustgraph-flow/trustgraph/gateway/capabilities.py +++ b/trustgraph-flow/trustgraph/gateway/capabilities.py @@ -67,12 +67,22 @@ async def enforce(request, auth, capability): return identity +def workspace_not_found(): + return web.HTTPNotFound( + text='{"error":"workspace not found"}', + content_type="application/json", + ) + + async def enforce_workspace(data, identity, auth, capability=None): """Default-fill the workspace on a request body and (optionally) authorise the caller for ``capability`` against that workspace. - Target workspace = ``data["workspace"]`` if supplied, else the caller's bound workspace. + - Rejects the request if the resolved workspace is not in + ``auth.known_workspaces`` (prevents routing to a queue with + no consumer). - On success, ``data["workspace"]`` is overwritten with the resolved value so downstream code sees a single canonical address. @@ -92,6 +102,9 @@ async def enforce_workspace(data, identity, auth, capability=None): target = requested or identity.workspace data["workspace"] = target + if target not in auth.known_workspaces: + raise workspace_not_found() + if capability is not None: await auth.authorise( identity, capability, {"workspace": target}, {}, diff --git a/trustgraph-flow/trustgraph/gateway/config/receiver.py b/trustgraph-flow/trustgraph/gateway/config/receiver.py index 5bc781a9..8c42381f 100755 --- a/trustgraph-flow/trustgraph/gateway/config/receiver.py +++ b/trustgraph-flow/trustgraph/gateway/config/receiver.py @@ -24,9 +24,10 @@ logger.setLevel(logging.INFO) class ConfigReceiver: - def __init__(self, backend): + def __init__(self, backend, auth=None): self.backend = backend + self.auth = auth self.flow_handlers = [] @@ -54,6 +55,15 @@ class ConfigReceiver: ) return + # Track workspace lifecycle + if v.workspace_changes and self.auth: + for ws in (v.workspace_changes.created or []): + self.auth.known_workspaces.add(ws) + logger.info(f"Workspace registered: {ws}") + for ws in (v.workspace_changes.deleted or []): + self.auth.known_workspaces.discard(ws) + logger.info(f"Workspace deregistered: {ws}") + # Gateway cares about flow config — check if any flow # types changed in any workspace flow_workspaces = changes.get("flow", []) @@ -195,6 +205,33 @@ class ConfigReceiver: try: await client.start() + # Discover all known workspaces + ws_resp = await client.request( + ConfigRequest( + operation="getvalues", + workspace="__workspaces__", + type="workspace", + ), + timeout=10, + ) + + if ws_resp.error: + raise RuntimeError( + f"Workspace discovery error: " + f"{ws_resp.error.message}" + ) + + discovered = { + v.key for v in ws_resp.values if v.key + } + + if self.auth: + self.auth.known_workspaces = discovered + + logger.info( + f"Known workspaces: {discovered}" + ) + # Discover workspaces that have any flow config resp = await client.request( ConfigRequest( diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index ea8770d7..db70127a 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -7,6 +7,12 @@ import logging # Module logger logger = logging.getLogger(__name__) +from ... schema import flow_request_queue +from ... 
diff --git a/trustgraph-flow/trustgraph/gateway/config/receiver.py b/trustgraph-flow/trustgraph/gateway/config/receiver.py
index 5bc781a9..8c42381f 100755
--- a/trustgraph-flow/trustgraph/gateway/config/receiver.py
+++ b/trustgraph-flow/trustgraph/gateway/config/receiver.py
@@ -24,9 +24,10 @@ logger.setLevel(logging.INFO)
 
 class ConfigReceiver:
 
-    def __init__(self, backend):
+    def __init__(self, backend, auth=None):
 
         self.backend = backend
+        self.auth = auth
 
         self.flow_handlers = []
 
@@ -54,6 +55,15 @@ class ConfigReceiver:
             )
             return
 
+        # Track workspace lifecycle
+        if v.workspace_changes and self.auth:
+            for ws in (v.workspace_changes.created or []):
+                self.auth.known_workspaces.add(ws)
+                logger.info(f"Workspace registered: {ws}")
+            for ws in (v.workspace_changes.deleted or []):
+                self.auth.known_workspaces.discard(ws)
+                logger.info(f"Workspace deregistered: {ws}")
+
         # Gateway cares about flow config — check if any flow
         # types changed in any workspace
         flow_workspaces = changes.get("flow", [])
@@ -195,6 +205,33 @@ class ConfigReceiver:
         try:
             await client.start()
 
+            # Discover all known workspaces
+            ws_resp = await client.request(
+                ConfigRequest(
+                    operation="getvalues",
+                    workspace="__workspaces__",
+                    type="workspace",
+                ),
+                timeout=10,
+            )
+
+            if ws_resp.error:
+                raise RuntimeError(
+                    f"Workspace discovery error: "
+                    f"{ws_resp.error.message}"
+                )
+
+            discovered = {
+                v.key for v in ws_resp.values if v.key
+            }
+
+            if self.auth:
+                self.auth.known_workspaces = discovered
+
+            logger.info(
+                f"Known workspaces: {discovered}"
+            )
+
             # Discover workspaces that have any flow config
             resp = await client.request(
                 ConfigRequest(
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py
index ea8770d7..db70127a 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py
@@ -7,6 +7,12 @@ import logging
 
 # Module logger
 logger = logging.getLogger(__name__)
 
+from ... schema import flow_request_queue
+from ... schema import librarian_request_queue
+from ... schema import knowledge_request_queue
+from ... schema import collection_request_queue
+from ... schema import config_request_queue
+
 from . config import ConfigRequestor
 from . flow import FlowRequestor
 from . iam import IamRequestor
@@ -70,15 +76,28 @@ request_response_dispatchers = {
     "sparql": SparqlQueryRequestor,
 }
 
-global_dispatchers = {
+system_dispatchers = {
+    "iam": IamRequestor,
+}
+
+workspace_dispatchers = {
     "config": ConfigRequestor,
     "flow": FlowRequestor,
-    "iam": IamRequestor,
     "librarian": LibrarianRequestor,
     "knowledge": KnowledgeRequestor,
     "collection-management": CollectionManagementRequestor,
 }
 
+workspace_default_request_queues = {
+    "config": config_request_queue,
+    "flow": flow_request_queue,
+    "librarian": librarian_request_queue,
+    "knowledge": knowledge_request_queue,
+    "collection-management": collection_request_queue,
+}
+
+global_dispatchers = {**system_dispatchers, **workspace_dispatchers}
+
 sender_dispatchers = {
     "text-load": TextLoad,
     "document-load": DocumentLoad,
@@ -219,11 +238,24 @@ class DispatcherManager:
 
     async def process_global_service(self, data, responder, params):
 
         kind = params.get("kind")
-        return await self.invoke_global_service(data, responder, kind)
+        workspace = params.get("workspace")
+        if not workspace and isinstance(data, dict):
+            workspace = data.get("workspace")
+        return await self.invoke_global_service(
+            data, responder, kind, workspace=workspace,
+        )
 
-    async def invoke_global_service(self, data, responder, kind):
+    async def invoke_global_service(self, data, responder, kind,
+                                    workspace=None):
 
-        key = (None, kind)
+        if kind in workspace_dispatchers:
+            if not workspace:
+                raise RuntimeError(
+                    f"Workspace is required for {kind}"
+                )
+            key = (workspace, kind)
+        else:
+            key = (None, kind)
 
         if key not in self.dispatchers:
 
             async with self.dispatcher_lock:
 
@@ -234,11 +266,21 @@ class DispatcherManager:
                     request_queue = self.queue_overrides[kind].get("request")
                     response_queue = self.queue_overrides[kind].get("response")
 
+                if kind in workspace_dispatchers and workspace:
+                    base_queue = (
+                        request_queue
+                        or workspace_default_request_queues[kind]
+                    )
+                    request_queue = f"{base_queue}:{workspace}"
+                    consumer_name = f"{self.prefix}-{kind}-{workspace}"
+                else:
+                    consumer_name = f"{self.prefix}-{kind}-request"
+
                 dispatcher = global_dispatchers[kind](
                     backend = self.backend,
                     timeout = 120,
-                    consumer = f"{self.prefix}-{kind}-request",
-                    subscriber = f"{self.prefix}-{kind}-request",
+                    consumer = consumer_name,
+                    subscriber = consumer_name,
                     request_queue = request_queue,
                     response_queue = response_queue,
                 )
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
index 03cd748b..02c0eed2 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/mux.py
@@ -190,6 +190,16 @@ class Mux:
             await self.auth.authorise(
                 self.identity, op.capability, resource, parameters,
             )
+        except _web.HTTPNotFound:
+            await self.ws.send_json({
+                "id": request_id,
+                "error": {
+                    "message": "workspace not found",
+                    "type": "workspace-not-found",
+                },
+                "complete": True,
+            })
+            return
         except _web.HTTPForbidden:
             await self.ws.send_json({
                 "id": request_id,
@@ -310,7 +320,7 @@ class Mux:
             else:
 
                 await self.dispatcher_manager.invoke_global_service(
-                    request, responder, svc
+                    request, responder, svc, workspace=workspace,
                 )
 
         except Exception as e:
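
The dispatcher cache keying above is the part worth internalising: workspace-scoped
services get one dispatcher, and one `<base>:<workspace>` request queue, per
workspace, while system services stay singletons. A condensed restatement
(`cache_key` is illustrative, not a function in the patch):

    def cache_key(kind, workspace):
        if kind in workspace_dispatchers:
            # e.g. ("acme", "librarian") -> queue "librarian-request:acme"
            return (workspace, kind)
        # System services such as "iam" keep a single shared dispatcher.
        return (None, kind)

    assert cache_key("iam", "acme") == (None, "iam")
    assert cache_key("librarian", "acme") == ("acme", "librarian")
    assert cache_key("librarian", "globex") == ("globex", "librarian")
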
diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py
index 28b0ded5..c5572ff8 100644
--- a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py
+++ b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py
@@ -116,9 +116,6 @@ def serialize_document_metadata(message):
     if message.metadata:
         ret["metadata"] = serialize_subgraph(message.metadata)
 
-    if message.workspace:
-        ret["workspace"] = message.workspace
-
     if message.tags is not None:
         ret["tags"] = message.tags
 
@@ -140,9 +137,6 @@ def serialize_processing_metadata(message):
     if message.flow:
         ret["flow"] = message.flow
 
-    if message.workspace:
-        ret["workspace"] = message.workspace
-
     if message.collection:
         ret["collection"] = message.collection
 
@@ -160,7 +154,6 @@ def to_document_metadata(x):
         title = x.get("title", None),
         comments = x.get("comments", None),
         metadata = to_subgraph(x["metadata"]),
-        workspace = x.get("workspace", None),
         tags = x.get("tags", None),
     )
 
@@ -171,7 +164,6 @@ def to_processing_metadata(x):
         document_id = x.get("document-id", None),
         time = x.get("time", None),
         flow = x.get("flow", None),
-        workspace = x.get("workspace", None),
         collection = x.get("collection", None),
         tags = x.get("tags", None),
     )
diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py
index ed5ef4b5..2b3bb83b 100644
--- a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py
+++ b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py
@@ -12,8 +12,8 @@ from . auth_endpoints import AuthEndpoints
 from . iam_endpoint import IamEndpoint
 from . registry_endpoint import RegistryRoutedVariableEndpoint
 
-from .. capabilities import PUBLIC, AUTHENTICATED, auth_failure
-from .. registry import lookup as _registry_lookup, RequestContext
+from .. capabilities import PUBLIC, AUTHENTICATED, auth_failure, workspace_not_found
+from .. registry import lookup as _registry_lookup, RequestContext, ResourceLevel
 
 from .. dispatch.manager import DispatcherManager
 
@@ -77,6 +77,10 @@ class _RoutedVariableEndpoint:
                 identity, op.capability, resource, parameters,
             )
 
+            ws = resource.get("workspace", "")
+            if ws and ws not in self.auth.known_workspaces:
+                raise workspace_not_found()
+
             async def responder(x, fin):
                 pass
 
@@ -140,6 +144,11 @@ class _RoutedSocketEndpoint:
             await self.auth.authorise(
                 identity, op.capability, resource, parameters,
             )
+
+            ws = resource.get("workspace", "")
+            if ws and ws not in self.auth.known_workspaces:
+                raise workspace_not_found()
+
         except web.HTTPException as e:
             return e
 
diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/registry_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/registry_endpoint.py
index 296376fa..0861bcd3 100644
--- a/trustgraph-flow/trustgraph/gateway/endpoint/registry_endpoint.py
+++ b/trustgraph-flow/trustgraph/gateway/endpoint/registry_endpoint.py
@@ -20,9 +20,9 @@ import logging
 
 from aiohttp import web
 
 from .. capabilities import (
-    PUBLIC, AUTHENTICATED, auth_failure,
+    PUBLIC, AUTHENTICATED, auth_failure, workspace_not_found,
 )
-from .. registry import lookup, RequestContext
+from .. registry import lookup, RequestContext, ResourceLevel
 
 logger = logging.getLogger("registry-endpoint")
 logger.setLevel(logging.INFO)
@@ -107,6 +107,15 @@ class RegistryRoutedVariableEndpoint:
             if "workspace" in resource:
                 body["workspace"] = resource["workspace"]
 
+            if (
+                op.resource_level in (
+                    ResourceLevel.WORKSPACE, ResourceLevel.FLOW,
+                )
+                and resource.get("workspace")
+                not in self.auth.known_workspaces
+            ):
+                raise workspace_not_found()
+
             async def responder(x, fin):
                 pass
 
diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py
index f75f3b25..0f6a5070 100755
--- a/trustgraph-flow/trustgraph/gateway/service.py
+++ b/trustgraph-flow/trustgraph/gateway/service.py
@@ -68,7 +68,7 @@ class Api:
             id=config.get("id", "api-gateway"),
         )
 
-        self.config_receiver = ConfigReceiver(self.pubsub_backend)
+        self.config_receiver = ConfigReceiver(self.pubsub_backend, auth=self.auth)
 
         # Build queue overrides dictionary from CLI arguments
         queue_overrides = {}
diff --git a/trustgraph-flow/trustgraph/iam/service/iam.py b/trustgraph-flow/trustgraph/iam/service/iam.py
index c89f65b0..3f9bc5b9 100644
--- a/trustgraph-flow/trustgraph/iam/service/iam.py
+++ b/trustgraph-flow/trustgraph/iam/service/iam.py
@@ -245,7 +245,8 @@ def _sign_jwt(kid, private_pem, claims):
 class IamService:
 
     def __init__(self, host, username, password, keyspace,
-                 bootstrap_mode, bootstrap_token=None):
+                 bootstrap_mode, bootstrap_token=None,
+                 on_workspace_created=None, on_workspace_deleted=None):
 
         self.table_store = IamTableStore(
             host, username, password, keyspace,
         )
@@ -267,6 +268,12 @@ class IamService:
         self.bootstrap_mode = bootstrap_mode
         self.bootstrap_token = bootstrap_token
 
+        # Callbacks for workspace lifecycle events. Called after the
+        # workspace is created/deleted in IAM's own store so that the
+        # processor can announce it via the config service.
+        self._on_workspace_created = on_workspace_created
+        self._on_workspace_deleted = on_workspace_deleted
+
         self._signing_key = None
         self._signing_key_lock = asyncio.Lock()
 
@@ -424,6 +431,9 @@ class IamService:
             created=now,
         )
 
+        if self._on_workspace_created:
+            await self._on_workspace_created(DEFAULT_WORKSPACE)
+
         admin_user_id = str(uuid.uuid4())
         admin_password = secrets.token_urlsafe(32)
         await self.table_store.put_user(
@@ -904,6 +914,10 @@ class IamService:
             enabled=v.workspace_record.enabled,
             created=now,
         )
+
+        if self._on_workspace_created:
+            await self._on_workspace_created(v.workspace_record.id)
+
         row = await self.table_store.get_workspace(v.workspace_record.id)
         return IamResponse(workspace=self._row_to_workspace_record(row))
 
@@ -982,6 +996,9 @@ class IamService:
         for kr in key_rows:
             await self.table_store.delete_api_key(kr[0])
 
+        if self._on_workspace_deleted:
+            await self._on_workspace_deleted(v.workspace_record.id)
+
         return IamResponse()
 
     # ------------------------------------------------------------------
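
The callbacks are plain awaitables invoked after IAM's own store has been updated,
so a failed announcement never leaves IAM without the workspace row. A hypothetical
wiring, assuming the Cassandra connection parameters are in scope (the real
callbacks are the announce methods in the service diff below):

    async def created(workspace_id):
        print(f"workspace ready: {workspace_id}")

    async def deleted(workspace_id):
        print(f"workspace gone: {workspace_id}")

    svc = IamService(
        host, username, password, keyspace,
        bootstrap_mode=False,
        on_workspace_created=created,
        on_workspace_deleted=deleted,
    )
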
diff --git a/trustgraph-flow/trustgraph/iam/service/service.py b/trustgraph-flow/trustgraph/iam/service/service.py
index 147bd56a..b6945e87 100644
--- a/trustgraph-flow/trustgraph/iam/service/service.py
+++ b/trustgraph-flow/trustgraph/iam/service/service.py
@@ -12,9 +12,13 @@ import os
 
 from trustgraph.schema import Error
 from trustgraph.schema import IamRequest, IamResponse
 from trustgraph.schema import iam_request_queue, iam_response_queue
+from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigValue
+from trustgraph.schema import config_request_queue, config_response_queue
 
 from trustgraph.base import AsyncProcessor, Consumer, Producer
 from trustgraph.base import ConsumerMetrics, ProducerMetrics
+from trustgraph.base.metrics import SubscriberMetrics
+from trustgraph.base.request_response_spec import RequestResponse
 from trustgraph.base.cassandra_config import (
     add_cassandra_args, resolve_cassandra_config,
 )
@@ -147,6 +151,8 @@ class Processor(AsyncProcessor):
             keyspace=keyspace,
             bootstrap_mode=self.bootstrap_mode,
             bootstrap_token=self.bootstrap_token,
+            on_workspace_created=self._announce_workspace_created,
+            on_workspace_deleted=self._announce_workspace_deleted,
         )
 
         logger.info(
@@ -160,6 +166,87 @@ class Processor(AsyncProcessor):
 
         await self.iam.auto_bootstrap_if_token_mode()
         await self.iam_request_consumer.start()
 
+    def _create_config_client(self):
+        import uuid
+        config_rr_id = str(uuid.uuid4())
+        config_req_metrics = ProducerMetrics(
+            processor=self.id, flow=None, name="config-request",
+        )
+        config_resp_metrics = SubscriberMetrics(
+            processor=self.id, flow=None, name="config-response",
+        )
+        return RequestResponse(
+            backend=self.pubsub,
+            subscription=f"{self.id}--config--{config_rr_id}",
+            consumer_name=self.id,
+            request_topic=config_request_queue,
+            request_schema=ConfigRequest,
+            request_metrics=config_req_metrics,
+            response_topic=config_response_queue,
+            response_schema=ConfigResponse,
+            response_metrics=config_resp_metrics,
+        )
+
+    async def _config_put(self, workspace, type, key, value):
+        client = self._create_config_client()
+        try:
+            await client.start()
+            await client.request(
+                ConfigRequest(
+                    operation="put",
+                    workspace=workspace,
+                    values=[ConfigValue(type=type, key=key, value=value)],
+                ),
+                timeout=10,
+            )
+        finally:
+            await client.stop()
+
+    async def _config_delete(self, workspace, type, key):
+        from trustgraph.schema import ConfigKey
+        client = self._create_config_client()
+        try:
+            await client.start()
+            await client.request(
+                ConfigRequest(
+                    operation="delete",
+                    workspace=workspace,
+                    keys=[ConfigKey(type=type, key=key)],
+                ),
+                timeout=10,
+            )
+        finally:
+            await client.stop()
+
+    async def _announce_workspace_created(self, workspace_id):
+        try:
+            await self._config_put(
+                "__workspaces__", "workspace", workspace_id,
+                '{"enabled": true}',
+            )
+            logger.info(
+                f"Announced workspace creation: {workspace_id}"
+            )
+        except Exception as e:
+            logger.error(
+                f"Failed to announce workspace creation "
+                f"{workspace_id}: {e}", exc_info=True,
+            )
+
+    async def _announce_workspace_deleted(self, workspace_id):
+        try:
+            await self._config_delete(
+                "__workspaces__", "workspace", workspace_id,
+            )
+            logger.info(
+                f"Announced workspace deletion: {workspace_id}"
+            )
+        except Exception as e:
+            logger.error(
+                f"Failed to announce workspace deletion "
+                f"{workspace_id}: {e}", exc_info=True,
+            )
+
     async def on_iam_request(self, msg, consumer, flow):
 
         id = None
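
Each announcement lands in the reserved __workspaces__ workspace as an entry of
type "workspace" keyed by the workspace id, with '{"enabled": true}' as its value.
Reading the registry back is symmetric; this is the same getvalues call the
gateway's ConfigReceiver issues at startup, sketched here assuming a started
config request/response client:

    resp = await client.request(
        ConfigRequest(
            operation="getvalues",
            workspace="__workspaces__",
            type="workspace",
        ),
        timeout=10,
    )
    # Keys of the returned values are the known workspace ids.
    known = {v.key for v in resp.values if v.key}
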
diff --git a/trustgraph-flow/trustgraph/librarian/collection_manager.py b/trustgraph-flow/trustgraph/librarian/collection_manager.py
index 09932adf..71f31d36 100644
--- a/trustgraph-flow/trustgraph/librarian/collection_manager.py
+++ b/trustgraph-flow/trustgraph/librarian/collection_manager.py
@@ -151,21 +151,11 @@ class CollectionManager:
             logger.error(f"Error ensuring collection exists: {e}")
             raise e
 
-    async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
-        """
-        List collections for a user from config service
-
-        Args:
-            request: Collection management request
-
-        Returns:
-            CollectionManagementResponse with list of collections
-        """
+    async def list_collections(self, request, workspace):
         try:
-            # Get all collections in this workspace from config service
             config_request = ConfigRequest(
                 operation='getvalues',
-                workspace=request.workspace,
+                workspace=workspace,
                 type='collection'
             )
 
@@ -210,18 +200,8 @@ class CollectionManager:
             logger.error(f"Error listing collections: {e}")
             raise RequestError(f"Failed to list collections: {str(e)}")
 
-    async def update_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
-        """
-        Update collection metadata via config service (creates if doesn't exist)
-
-        Args:
-            request: Collection management request
-
-        Returns:
-            CollectionManagementResponse with updated collection
-        """
+    async def update_collection(self, request, workspace):
         try:
-            # Create metadata from request
             name = request.name if request.name else request.collection
             description = request.description if request.description else ""
             tags = list(request.tags) if request.tags else []
@@ -233,10 +213,9 @@ class CollectionManager:
                 tags=tags
             )
 
-            # Send put request to config service
             config_request = ConfigRequest(
                 operation='put',
-                workspace=request.workspace,
+                workspace=workspace,
                 values=[ConfigValue(
                     type='collection',
                     key=request.collection,
@@ -249,7 +228,7 @@ class CollectionManager:
             if response.error:
                 raise RuntimeError(f"Config update failed: {response.error.message}")
 
-            logger.info(f"Collection {request.workspace}/{request.collection} updated in config service")
+            logger.info(f"Collection {workspace}/{request.collection} updated in config service")
 
             # Config service will trigger config push automatically
             # Storage services will receive update and create/update collections
@@ -264,23 +243,13 @@ class CollectionManager:
             logger.error(f"Error updating collection: {e}")
             raise RequestError(f"Failed to update collection: {str(e)}")
 
-    async def delete_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
-        """
-        Delete collection via config service
-
-        Args:
-            request: Collection management request
-
-        Returns:
-            CollectionManagementResponse indicating success or failure
-        """
+    async def delete_collection(self, request, workspace):
         try:
-            logger.info(f"Deleting collection {request.workspace}/{request.collection}")
+            logger.info(f"Deleting collection {workspace}/{request.collection}")
 
-            # Send delete request to config service
             config_request = ConfigRequest(
                 operation='delete',
-                workspace=request.workspace,
+                workspace=workspace,
                 keys=[ConfigKey(type='collection', key=request.collection)]
             )
 
@@ -289,7 +258,7 @@ class CollectionManager:
             if response.error:
                 raise RuntimeError(f"Config delete failed: {response.error.message}")
 
-            logger.info(f"Collection {request.workspace}/{request.collection} deleted from config service")
+            logger.info(f"Collection {workspace}/{request.collection} deleted from config service")
 
             # Config service will trigger config push automatically
             # Storage services will receive update and delete collections
diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py
index af1d69b1..653c573f 100644
--- a/trustgraph-flow/trustgraph/librarian/librarian.py
+++ b/trustgraph-flow/trustgraph/librarian/librarian.py
@@ -42,13 +42,13 @@ class Librarian:
         self.load_document = load_document
         self.min_chunk_size = min_chunk_size
 
-    async def add_document(self, request):
+    async def add_document(self, request, workspace):
 
         if not request.document_metadata.kind:
             raise RequestError("Document kind (MIME type) is required")
 
         if await self.table_store.document_exists(
-                request.document_metadata.workspace,
+                workspace,
                 request.document_metadata.id
         ):
             raise RuntimeError("Document already exists")
@@ -66,19 +66,19 @@ class Librarian:
         logger.debug("Adding to table...")
 
         await self.table_store.add_document(
-            request.document_metadata, object_id
+            workspace, request.document_metadata, object_id
         )
 
         logger.debug("Add complete")
 
         return LibrarianResponse()
 
-    async def remove_document(self, request):
+    async def remove_document(self, request, workspace):
 
         logger.debug("Removing document...")
 
         if not await self.table_store.document_exists(
-                request.workspace,
+                workspace,
                 request.document_id,
         ):
             raise RuntimeError("Document does not exist")
@@ -89,17 +89,17 @@ class Librarian:
             logger.debug(f"Cascade deleting child document {child.id}")
             try:
                 child_object_id = await self.table_store.get_document_object_id(
-                    child.workspace,
+                    workspace,
                     child.id
                 )
                 await self.blob_store.remove(child_object_id)
-                await self.table_store.remove_document(child.workspace, child.id)
+                await self.table_store.remove_document(workspace, child.id)
             except Exception as e:
                 logger.warning(f"Failed to delete child document {child.id}: {e}")
 
         # Now remove the parent document
         object_id = await self.table_store.get_document_object_id(
-            request.workspace,
+            workspace,
             request.document_id
         )
 
@@ -108,7 +108,7 @@ class Librarian:
 
         # Remove doc table row
         await self.table_store.remove_document(
-            request.workspace,
+            workspace,
             request.document_id
         )
 
@@ -116,30 +116,30 @@ class Librarian:
 
         return LibrarianResponse()
 
-    async def update_document(self, request):
+    async def update_document(self, request, workspace):
 
         logger.debug("Updating document...")
 
         # You can't update the document ID, workspace or kind.
 
         if not await self.table_store.document_exists(
-                request.document_metadata.workspace,
+                workspace,
                 request.document_metadata.id
         ):
             raise RuntimeError("Document does not exist")
 
-        await self.table_store.update_document(request.document_metadata)
+        await self.table_store.update_document(workspace, request.document_metadata)
 
         logger.debug("Update complete")
 
         return LibrarianResponse()
 
-    async def get_document_metadata(self, request):
+    async def get_document_metadata(self, request, workspace):
 
         logger.debug("Getting document metadata...")
 
         doc = await self.table_store.get_document(
-            request.workspace,
+            workspace,
             request.document_id
         )
 
@@ -151,12 +151,12 @@ class Librarian:
             content = None,
         )
 
-    async def get_document_content(self, request):
+    async def get_document_content(self, request, workspace):
 
         logger.debug("Getting document content...")
 
         object_id = await self.table_store.get_document_object_id(
-            request.workspace,
+            workspace,
             request.document_id
         )
 
@@ -172,7 +172,7 @@ class Librarian:
             content = base64.b64encode(content),
         )
 
-    async def add_processing(self, request):
+    async def add_processing(self, request, workspace):
 
         logger.debug("Adding processing metadata...")
 
@@ -180,18 +180,18 @@ class Librarian:
             raise RuntimeError("Collection parameter is required")
 
         if await self.table_store.processing_exists(
-                request.processing_metadata.workspace,
+                workspace,
                 request.processing_metadata.id
         ):
             raise RuntimeError("Processing already exists")
 
         doc = await self.table_store.get_document(
-            request.processing_metadata.workspace,
+            workspace,
             request.processing_metadata.document_id
         )
 
         object_id = await self.table_store.get_document_object_id(
-            request.processing_metadata.workspace,
+            workspace,
             request.processing_metadata.document_id
         )
 
@@ -203,7 +203,7 @@ class Librarian:
 
         logger.debug("Adding processing to table...")
 
-        await self.table_store.add_processing(request.processing_metadata)
+        await self.table_store.add_processing(workspace, request.processing_metadata)
 
         logger.debug("Invoking document processing...")
 
@@ -211,25 +211,26 @@ class Librarian:
             document = doc,
             processing = request.processing_metadata,
             content = content,
+            workspace = workspace,
         )
 
         logger.debug("Add complete")
 
         return LibrarianResponse()
 
-    async def remove_processing(self, request):
+    async def remove_processing(self, request, workspace):
 
         logger.debug("Removing processing metadata...")
 
         if not await self.table_store.processing_exists(
-                request.workspace,
+                workspace,
                 request.processing_id,
         ):
             raise RuntimeError("Processing object does not exist")
 
         # Remove doc table row
         await self.table_store.remove_processing(
-            request.workspace,
+            workspace,
             request.processing_id
         )
 
@@ -237,9 +238,9 @@ class Librarian:
 
         return LibrarianResponse()
 
-    async def list_documents(self, request):
+    async def list_documents(self, request, workspace):
 
-        docs = await self.table_store.list_documents(request.workspace)
+        docs = await self.table_store.list_documents(workspace)
 
         # Filter out child documents and answer documents by default
         include_children = getattr(request, 'include_children', False)
@@ -254,9 +255,9 @@ class Librarian:
             document_metadatas = docs,
         )
 
-    async def list_processing(self, request):
+    async def list_processing(self, request, workspace):
 
-        procs = await self.table_store.list_processing(request.workspace)
+        procs = await self.table_store.list_processing(workspace)
 
         return LibrarianResponse(
             processing_metadatas = procs,
@@ -264,7 +265,7 @@ class Librarian:
 
     # Chunked upload operations
-    async def begin_upload(self, request):
+    async def begin_upload(self, request, workspace):
         """
         Initialize a chunked upload session.
 
@@ -276,7 +277,7 @@ class Librarian:
             raise RequestError("Document kind (MIME type) is required")
 
         if await self.table_store.document_exists(
-            request.document_metadata.workspace,
+            workspace,
             request.document_metadata.id
         ):
             raise RequestError("Document already exists")
@@ -312,14 +313,13 @@ class Librarian:
             "kind": request.document_metadata.kind,
             "title": request.document_metadata.title,
             "comments": request.document_metadata.comments,
-            "workspace": request.document_metadata.workspace,
             "tags": request.document_metadata.tags,
         })
 
         # Store session in Cassandra
         await self.table_store.create_upload_session(
             upload_id=upload_id,
-            workspace=request.document_metadata.workspace,
+            workspace=workspace,
             document_id=request.document_metadata.id,
             document_metadata=doc_meta_json,
             s3_upload_id=s3_upload_id,
@@ -338,7 +338,7 @@ class Librarian:
             total_chunks=total_chunks,
         )
 
-    async def upload_chunk(self, request):
+    async def upload_chunk(self, request, workspace):
         """
         Upload a single chunk of a document.
 
@@ -352,7 +352,7 @@ class Librarian:
             raise RequestError("Upload session not found or expired")
 
         # Validate ownership
-        if session["workspace"] != request.workspace:
+        if session["workspace"] != workspace:
             raise RequestError("Not authorized to upload to this session")
 
         # Validate chunk index
@@ -405,7 +405,7 @@ class Librarian:
             total_bytes=session["total_size"],
         )
 
-    async def complete_upload(self, request):
+    async def complete_upload(self, request, workspace):
         """
         Finalize a chunked upload and create the document.
 
@@ -419,7 +419,7 @@ class Librarian:
             raise RequestError("Upload session not found or expired")
 
         # Validate ownership
-        if session["workspace"] != request.workspace:
+        if session["workspace"] != workspace:
             raise RequestError("Not authorized to complete this upload")
 
         # Verify all chunks received
@@ -457,13 +457,13 @@ class Librarian:
             kind=doc_meta_dict["kind"],
             title=doc_meta_dict.get("title", ""),
             comments=doc_meta_dict.get("comments", ""),
-            workspace=doc_meta_dict["workspace"],
             tags=doc_meta_dict.get("tags", []),
             metadata=[],  # Triples not supported in chunked upload yet
         )
 
         # Add document to table
-        await self.table_store.add_document(doc_metadata, session["object_id"])
+        workspace = session["workspace"]
+        await self.table_store.add_document(workspace, doc_metadata, session["object_id"])
 
         # Delete upload session
         await self.table_store.delete_upload_session(request.upload_id)
@@ -476,7 +476,7 @@ class Librarian:
             object_id=str(session["object_id"]),
         )
 
-    async def abort_upload(self, request):
+    async def abort_upload(self, request, workspace):
         """
         Cancel a chunked upload and clean up resources.
         """
@@ -488,7 +488,7 @@ class Librarian:
             raise RequestError("Upload session not found or expired")
 
         # Validate ownership
-        if session["workspace"] != request.workspace:
+        if session["workspace"] != workspace:
             raise RequestError("Not authorized to abort this upload")
 
         # Abort S3 multipart upload
@@ -504,7 +504,7 @@ class Librarian:
 
         return LibrarianResponse(error=None)
 
-    async def get_upload_status(self, request):
+    async def get_upload_status(self, request, workspace):
         """
         Get the status of an in-progress upload.
         """
""" @@ -520,7 +520,7 @@ class Librarian: ) # Validate ownership - if session["workspace"] != request.workspace: + if session["workspace"] != workspace: raise RequestError("Not authorized to view this upload") chunks_received = session["chunks_received"] @@ -546,13 +546,13 @@ class Librarian: total_bytes=session["total_size"], ) - async def list_uploads(self, request): + async def list_uploads(self, request, workspace): """ List all in-progress uploads for a workspace. """ - logger.debug(f"Listing uploads for workspace {request.workspace}") + logger.debug(f"Listing uploads for workspace {workspace}") - sessions = await self.table_store.list_upload_sessions(request.workspace) + sessions = await self.table_store.list_upload_sessions(workspace) upload_sessions = [ UploadSession( @@ -575,7 +575,7 @@ class Librarian: # Child document operations - async def add_child_document(self, request): + async def add_child_document(self, request, workspace): """ Add a child document linked to a parent document. @@ -591,7 +591,7 @@ class Librarian: # Verify parent exists if not await self.table_store.document_exists( - request.document_metadata.workspace, + workspace, request.document_metadata.parent_id ): raise RequestError( @@ -599,7 +599,7 @@ class Librarian: ) if await self.table_store.document_exists( - request.document_metadata.workspace, + workspace, request.document_metadata.id ): raise RequestError("Document already exists") @@ -622,7 +622,7 @@ class Librarian: logger.debug("Adding to table...") await self.table_store.add_document( - request.document_metadata, object_id + workspace, request.document_metadata, object_id ) logger.debug("Add child document complete") @@ -632,7 +632,7 @@ class Librarian: document_id=request.document_metadata.id, ) - async def list_children(self, request): + async def list_children(self, request, workspace): """ List all child documents for a given parent document. """ @@ -645,7 +645,7 @@ class Librarian: document_metadatas=children, ) - async def stream_document(self, request): + async def stream_document(self, request, workspace): """ Stream document content in chunks. @@ -665,7 +665,7 @@ class Librarian: ) object_id = await self.table_store.get_document_object_id( - request.workspace, + workspace, request.document_id ) @@ -697,4 +697,3 @@ class Librarian: total_bytes=total_size, is_final=is_last, ) - diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py index c24a5fe8..fae55571 100755 --- a/trustgraph-flow/trustgraph/librarian/service.py +++ b/trustgraph-flow/trustgraph/librarian/service.py @@ -10,7 +10,7 @@ import json import logging from datetime import datetime -from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber +from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber from .. base import ConsumerMetrics, ProducerMetrics from .. 
diff --git a/trustgraph-flow/trustgraph/librarian/service.py b/trustgraph-flow/trustgraph/librarian/service.py
index c24a5fe8..fae55571 100755
--- a/trustgraph-flow/trustgraph/librarian/service.py
+++ b/trustgraph-flow/trustgraph/librarian/service.py
@@ -10,7 +10,7 @@ import json
 import logging
 from datetime import datetime
 
-from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber
+from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
 from .. base import ConsumerMetrics, ProducerMetrics
 from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
@@ -46,6 +46,9 @@ default_collection_response_queue = collection_response_queue
 default_config_request_queue = config_request_queue
 default_config_response_queue = config_response_queue
 
+def workspace_queue(base_queue, workspace):
+    return f"{base_queue}:{workspace}"
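+
+# e.g. workspace_queue("librarian-request", "acme") == "librarian-request:acme";
+# the gateway dispatcher derives the same "<base>:<workspace>" names when
+# routing requests.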
+
 default_object_store_endpoint = "ceph-rgw:7480"
 default_object_store_access_key = "object-user"
 default_object_store_secret_key = "object-password"
@@ -56,15 +59,13 @@ default_min_chunk_size = 1  # No minimum by default (for Garage)
 
 bucket_name = "library"
 
-class Processor(AsyncProcessor):
+class Processor(WorkspaceProcessor):
 
     def __init__(self, **params):
 
         id = params.get("id")
 
-#        self.running = True
-
-        librarian_request_queue = params.get(
+        self.librarian_request_queue_base = params.get(
             "librarian_request_queue", default_librarian_request_queue
         )
 
@@ -72,7 +73,7 @@ class Processor(AsyncProcessor):
             "librarian_response_queue", default_librarian_response_queue
         )
 
-        collection_request_queue = params.get(
+        self.collection_request_queue_base = params.get(
             "collection_request_queue", default_collection_request_queue
         )
 
@@ -130,9 +131,9 @@ class Processor(AsyncProcessor):
 
         super(Processor, self).__init__(
             **params | {
-                "librarian_request_queue": librarian_request_queue,
+                "librarian_request_queue": self.librarian_request_queue_base,
                 "librarian_response_queue": librarian_response_queue,
-                "collection_request_queue": collection_request_queue,
+                "collection_request_queue": self.collection_request_queue_base,
                 "collection_response_queue": collection_response_queue,
                 "object_store_endpoint": object_store_endpoint,
                 "object_store_access_key": object_store_access_key,
@@ -142,40 +143,14 @@ class Processor(AsyncProcessor):
             }
         )
 
-        librarian_request_metrics = ConsumerMetrics(
-            processor = self.id, flow = None, name = "librarian-request"
-        )
-
         librarian_response_metrics = ProducerMetrics(
             processor = self.id, flow = None, name = "librarian-response"
         )
 
-        collection_request_metrics = ConsumerMetrics(
-            processor = self.id, flow = None, name = "collection-request"
-        )
-
         collection_response_metrics = ProducerMetrics(
             processor = self.id, flow = None, name = "collection-response"
         )
 
-        storage_response_metrics = ConsumerMetrics(
-            processor = self.id, flow = None, name = "storage-response"
-        )
-
-        self.librarian_request_topic = librarian_request_queue
-        self.librarian_request_subscriber = id
-
-        self.librarian_request_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = librarian_request_queue,
-            subscriber = id,
-            schema = LibrarianRequest,
-            handler = self.on_librarian_request,
-            metrics = librarian_request_metrics,
-        )
-
         self.librarian_response_producer = Producer(
             backend = self.pubsub,
             topic = librarian_response_queue,
@@ -183,20 +158,6 @@ class Processor(AsyncProcessor):
             metrics = librarian_response_metrics,
         )
 
-        self.collection_request_topic = collection_request_queue
-        self.collection_request_subscriber = id
-
-        self.collection_request_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = collection_request_queue,
-            subscriber = id,
-            schema = CollectionManagementRequest,
-            handler = self.on_collection_request,
-            metrics = collection_request_metrics,
-        )
-
         self.collection_response_producer = Producer(
             backend = self.pubsub,
             topic = collection_response_queue,
@@ -259,16 +220,80 @@ class Processor(AsyncProcessor):
 
         self.flows = {}
 
+        # Per-workspace consumers, keyed by workspace id
+        self.workspace_consumers = {}
+
         logger.info("Librarian service initialized")
 
+    async def on_workspace_created(self, workspace):
+
+        if workspace in self.workspace_consumers:
+            return
+
+        lib_queue = workspace_queue(
+            self.librarian_request_queue_base, workspace,
+        )
+        col_queue = workspace_queue(
+            self.collection_request_queue_base, workspace,
+        )
+
+        await self.pubsub.ensure_topic(lib_queue)
+        await self.pubsub.ensure_topic(col_queue)
+
+        lib_consumer = Consumer(
+            taskgroup=self.taskgroup,
+            backend=self.pubsub,
+            flow=None,
+            topic=lib_queue,
+            subscriber=self.id,
+            schema=LibrarianRequest,
+            handler=partial(
+                self.on_librarian_request, workspace=workspace,
+            ),
+            metrics=ConsumerMetrics(
+                processor=self.id, flow=None,
+                name=f"librarian-request-{workspace}",
+            ),
+        )
+
+        col_consumer = Consumer(
+            taskgroup=self.taskgroup,
+            backend=self.pubsub,
+            flow=None,
+            topic=col_queue,
+            subscriber=self.id,
+            schema=CollectionManagementRequest,
+            handler=partial(
+                self.on_collection_request, workspace=workspace,
+            ),
+            metrics=ConsumerMetrics(
+                processor=self.id, flow=None,
+                name=f"collection-request-{workspace}",
+            ),
+        )
+
+        await lib_consumer.start()
+        await col_consumer.start()
+
+        self.workspace_consumers[workspace] = {
+            "librarian": lib_consumer,
+            "collection": col_consumer,
+        }
+
+        logger.info(f"Subscribed to workspace queues: {workspace}")
+
+    async def on_workspace_deleted(self, workspace):
+
+        consumers = self.workspace_consumers.pop(workspace, None)
+        if consumers:
+            for consumer in consumers.values():
+                await consumer.stop()
+            logger.info(f"Unsubscribed from workspace queues: {workspace}")
+
     async def start(self):
 
-        await self.pubsub.ensure_topic(self.librarian_request_topic)
-        await self.pubsub.ensure_topic(self.collection_request_topic)
-
         await super(Processor, self).start()
 
-        await self.librarian_request_consumer.start()
         await self.librarian_response_producer.start()
-        await self.collection_request_consumer.start()
         await self.collection_response_producer.start()
         await self.config_request_producer.start()
         await self.config_response_consumer.start()
@@ -360,13 +385,12 @@ class Processor(AsyncProcessor):
 
         finally:
             await triples_pub.stop()
 
-    async def load_document(self, document, processing, content):
+    async def load_document(self, document, processing, content, workspace):
 
         logger.debug("Ready for document processing...")
 
         logger.debug(f"Document: {document}, processing: {processing}, content length: {len(content)}")
 
-        workspace = processing.workspace
         ws_flows = self.flows.get(workspace, {})
         if processing.flow not in ws_flows:
             raise RuntimeError(
@@ -429,20 +453,14 @@ class Processor(AsyncProcessor):
 
         logger.debug("Document submitted")
 
-    async def add_processing_with_collection(self, request):
-        """
-        Wrapper for add_processing that ensures collection exists
-        """
-        # Ensure collection exists when processing is added
+    async def add_processing_with_collection(self, request, workspace):
         if hasattr(request, 'processing_metadata') and request.processing_metadata:
-            workspace = request.processing_metadata.workspace
             collection = request.processing_metadata.collection
             await self.collection_manager.ensure_collection_exists(workspace, collection)
 
-        # Call the original add_processing method
-        return await self.librarian.add_processing(request)
+        return await self.librarian.add_processing(request, workspace)
 
-    async def process_request(self, v):
+    async def process_request(self, v, workspace):
 
         if v.operation is None:
             raise RequestError("Null operation")
@@ -475,9 +493,9 @@ class Processor(AsyncProcessor):
 
         if v.operation not in impls:
             raise RequestError(f"Invalid operation: {v.operation}")
 
-        return await impls[v.operation](v)
+        return await impls[v.operation](v, workspace)
 
-    async def on_librarian_request(self, msg, consumer, flow):
+    async def on_librarian_request(self, msg, consumer, flow, *, workspace):
 
         v = msg.value()
 
@@ -491,14 +509,14 @@ class Processor(AsyncProcessor):
 
             # Handle streaming operations specially
             if v.operation == "stream-document":
-                async for resp in self.librarian.stream_document(v):
+                async for resp in self.librarian.stream_document(v, workspace):
                     await self.librarian_response_producer.send(
                         resp, properties={"id": id}
                     )
                 return
 
             # Non-streaming operations
-            resp = await self.process_request(v)
+            resp = await self.process_request(v, workspace)
 
             await self.librarian_response_producer.send(
                 resp, properties={"id": id}
@@ -535,10 +553,7 @@ class Processor(AsyncProcessor):
 
         logger.debug("Librarian input processing complete")
 
-    async def process_collection_request(self, v):
-        """
-        Process collection management requests
-        """
+    async def process_collection_request(self, v, workspace):
 
         if v.operation is None:
             raise RequestError("Null operation")
 
@@ -553,19 +568,16 @@ class Processor(AsyncProcessor):
         if v.operation not in impls:
             raise RequestError(f"Invalid collection operation: {v.operation}")
 
-        return await impls[v.operation](v)
+        return await impls[v.operation](v, workspace)
 
-    async def on_collection_request(self, msg, consumer, flow):
-        """
-        Handle collection management request messages
-        """
+    async def on_collection_request(self, msg, consumer, flow, *, workspace):
 
         v = msg.value()
         id = msg.properties().get("id", "unknown")
 
         logger.info(f"Handling collection request {id}...")
 
         try:
-            resp = await self.process_collection_request(v)
+            resp = await self.process_collection_request(v, workspace)
             await self.collection_response_producer.send(
                 resp, properties={"id": id}
             )
@@ -597,7 +609,7 @@ class Processor(AsyncProcessor):
 
     @staticmethod
     def add_args(parser):
-        AsyncProcessor.add_args(parser)
+        WorkspaceProcessor.add_args(parser)
 
         parser.add_argument(
             '--librarian-request-queue',
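
The table store below still writes workspace into every row; it simply arrives as
an explicit first argument now instead of riding on the metadata object. The call
shape, with an illustrative workspace id:

    await table_store.add_document("acme", document_metadata, object_id)
    await table_store.update_document("acme", document_metadata)
    await table_store.add_processing("acme", processing_metadata)
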
diff --git a/trustgraph-flow/trustgraph/tables/library.py b/trustgraph-flow/trustgraph/tables/library.py
index 86706079..6dd5d3e4 100644
--- a/trustgraph-flow/trustgraph/tables/library.py
+++ b/trustgraph-flow/trustgraph/tables/library.py
@@ -312,7 +312,7 @@ class LibraryTableStore:
 
         return bool(rows)
 
-    async def add_document(self, document, object_id):
+    async def add_document(self, workspace, document, object_id):
 
         logger.info(f"Adding document {document.id} {object_id}")
 
@@ -332,7 +332,7 @@ class LibraryTableStore:
                 self.cassandra,
                 self.insert_document_stmt,
                 (
-                    document.id, document.workspace, int(document.time * 1000),
+                    document.id, workspace, int(document.time * 1000),
                     document.kind, document.title, document.comments,
                     metadata, document.tags, object_id, parent_id,
                     document_type
@@ -344,7 +344,7 @@ class LibraryTableStore:
 
         logger.debug("Add complete")
 
-    async def update_document(self, document):
+    async def update_document(self, workspace, document):
 
         logger.info(f"Updating document {document.id}")
 
@@ -362,7 +362,7 @@ class LibraryTableStore:
                 (
                     int(document.time * 1000), document.title,
                     document.comments, metadata, document.tags,
-                    document.workspace, document.id
+                    workspace, document.id
                 ),
             )
         except Exception:
@@ -404,7 +404,6 @@ class LibraryTableStore:
         lst = [
             DocumentMetadata(
                 id = row[0],
-                workspace = workspace,
                 time = int(time.mktime(row[1].timetuple())),
                 kind = row[2],
                 title = row[3],
@@ -446,7 +445,6 @@ class LibraryTableStore:
         lst = [
             DocumentMetadata(
                 id = row[0],
-                workspace = row[1],
                 time = int(time.mktime(row[2].timetuple())),
                 kind = row[3],
                 title = row[4],
@@ -487,7 +485,6 @@ class LibraryTableStore:
         for row in rows:
             doc = DocumentMetadata(
                 id = id,
-                workspace = workspace,
                 time = int(time.mktime(row[0].timetuple())),
                 kind = row[1],
                 title = row[2],
@@ -540,7 +537,7 @@ class LibraryTableStore:
 
         return bool(rows)
 
-    async def add_processing(self, processing):
+    async def add_processing(self, workspace, processing):
 
         logger.info(f"Adding processing {processing.id}")
 
@@ -551,7 +548,7 @@ class LibraryTableStore:
                 (
                     processing.id, processing.document_id,
                     int(processing.time * 1000), processing.flow,
-                    processing.workspace, processing.collection,
+                    workspace, processing.collection,
                     processing.tags
                 ),
             )
@@ -597,7 +594,6 @@ class LibraryTableStore:
                 document_id = row[1],
                 time = int(time.mktime(row[2].timetuple())),
                 flow = row[3],
-                workspace = workspace,
                 collection = row[4],
                 tags = row[5] if row[5] else [],
             )