Merge branch 'release/v2.4'

This commit is contained in:
Cyber MacGeddon 2026-05-11 15:15:50 +01:00
commit 159b1e2824
98 changed files with 2026 additions and 1445 deletions

View file

@ -11,8 +11,9 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.13 && \ RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \ python -m ensurepip --upgrade && \
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \ pip3 install --no-cache-dir pulsar-client==3.11.0 && \
dnf clean all dnf clean all
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------

View file

@ -11,8 +11,9 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.13 && \ RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \ python -m ensurepip --upgrade && \
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \ pip3 install --no-cache-dir pulsar-client==3.11.0 && \
dnf clean all dnf clean all
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------

View file

@ -11,18 +11,19 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.13 && \ RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \ python -m ensurepip --upgrade && \
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
pip3 install --no-cache-dir build wheel aiohttp rdflib && \ pip3 install --no-cache-dir build wheel aiohttp rdflib && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \ pip3 install --no-cache-dir pulsar-client==3.11.0 && \
dnf clean all dnf clean all
RUN pip3 install --no-cache-dir \ RUN pip3 install --no-cache-dir \
anthropic cohere mistralai openai \ anthropic cohere mistralai openai \
ollama \ ollama \
langchain==0.3.25 langchain-core==0.3.60 \ langchain==1.2.16 langchain-core==1.3.2 \
langchain-text-splitters==0.3.8 \ langchain-text-splitters==1.1.2 \
langchain-community==0.3.24 \ langchain-community==0.4.1 \
pymilvus \ pymilvus \
pulsar-client==3.7.0 scylla-driver pyyaml \ pulsar-client==3.11.0 scylla-driver pyyaml \
neo4j tiktoken falkordb && \ neo4j tiktoken falkordb && \
pip3 cache purge pip3 cache purge

View file

@ -8,8 +8,9 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.12 && \ RUN dnf install -y python3.12 && \
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \ alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
python -m ensurepip --upgrade && \ python -m ensurepip --upgrade && \
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \ pip3 install --no-cache-dir pulsar-client==3.11.0 && \
dnf clean all dnf clean all
# This won't work on ARM # This won't work on ARM
@ -19,15 +20,15 @@ RUN dnf install -y python3.12 && \
RUN pip3 install torch RUN pip3 install torch
RUN pip3 install --no-cache-dir \ RUN pip3 install --no-cache-dir \
langchain==0.3.25 langchain-core==0.3.60 langchain-huggingface==0.2.0 \ langchain==1.2.16 langchain-core==1.3.2 langchain-huggingface==1.2.2 \
langchain-community==0.3.24 \ langchain-community==0.4.1 \
sentence-transformers==4.1.0 transformers==4.51.3 \ sentence-transformers==5.4.1 transformers==5.7.0 \
huggingface-hub==0.31.2 \ huggingface-hub==1.13.0 \
pulsar-client==3.7.0 pulsar-client==3.11.0
# Most commonly used embeddings model, just build it into the container # Most commonly used embeddings model, just build it into the container
# image # image
RUN huggingface-cli download sentence-transformers/all-MiniLM-L6-v2 RUN hf download sentence-transformers/all-MiniLM-L6-v2
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
# Build a container which contains the built Python packages. The build # Build a container which contains the built Python packages. The build

View file

@ -11,6 +11,7 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.13 && \ RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \ python -m ensurepip --upgrade && \
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
pip3 install --no-cache-dir mcp websockets && \ pip3 install --no-cache-dir mcp websockets && \
dnf clean all dnf clean all

View file

@ -12,8 +12,9 @@ RUN dnf install -y python3.13 && \
dnf install -y tesseract poppler-utils && \ dnf install -y tesseract poppler-utils && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \ python -m ensurepip --upgrade && \
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \ pip3 install --no-cache-dir pulsar-client==3.11.0 && \
dnf clean all dnf clean all
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------

View file

@ -10,8 +10,9 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.13 libxcb mesa-libGL && \ RUN dnf install -y python3.13 libxcb mesa-libGL && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \ python -m ensurepip --upgrade && \
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \ pip3 install --no-cache-dir pulsar-client==3.11.0 && \
dnf clean all dnf clean all
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------

View file

@ -11,8 +11,9 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.13 && \ RUN dnf install -y python3.13 && \
alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \ alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
python -m ensurepip --upgrade && \ python -m ensurepip --upgrade && \
pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
pip3 install --no-cache-dir build wheel aiohttp && \ pip3 install --no-cache-dir build wheel aiohttp && \
pip3 install --no-cache-dir pulsar-client==3.7.0 && \ pip3 install --no-cache-dir pulsar-client==3.11.0 && \
pip3 install --no-cache-dir google-cloud-aiplatform && \ pip3 install --no-cache-dir google-cloud-aiplatform && \
dnf clean all dnf clean all

View file

@ -0,0 +1,366 @@
---
layout: default
title: "Workspace-Scoped Services"
parent: "Tech Specs"
---
# Workspace-Scoped Services
## Problem Statement
Workspace-scoped services (librarian, config, knowledge, collection
management) currently operate on global queues — a single
`request:tg:librarian` queue handles requests for all workspaces.
Workspace identity is carried as a field in the request body, set by
the gateway after authentication. This creates several problems:
- **No structural isolation.** All workspaces share a single queue.
Workspace scoping depends entirely on a body field being populated
correctly. If the field is missing or wrong, the service operates
on the wrong workspace — or fails with a confusing error. This is
a security concern: workspace isolation should be enforced by
infrastructure, not by trusting a field.
- **Redundant workspace fields.** Nested objects within requests
(e.g. `processing-metadata`, `document-metadata`) carry their own
`workspace` fields alongside the top-level request workspace.
The gateway resolves the top-level workspace but does not
propagate into nested payloads. Services that read workspace from
a nested object instead of the top-level address see `None` and
fail.
- **No workspace lifecycle awareness.** Workspace-scoped services
have no mechanism to learn when workspaces are created or deleted.
Flow processors discover workspaces indirectly through config
entries, but workspace-scoped services on global queues have no
equivalent. There is no event when a workspace appears or
disappears.
- **Inconsistency with flow-scoped services.** Flow-scoped services
already use per-workspace, per-flow queue names
(`request:tg:{workspace}:embeddings:{class}`). Workspace-scoped
services are the exception — they sit on global queues while
everything else is structurally isolated.
## Design
### Per-workspace queues for workspace-scoped services
Workspace-scoped services move from global queues to per-workspace
queues. The queue name includes the workspace identifier:
**Current (global):**
```
request:tg:librarian
request:tg:config
```
**Proposed (per-workspace):**
```
request:tg:librarian:{workspace}
request:tg:config:{workspace}
```
The gateway routes requests to the correct queue based on the
resolved workspace from authentication — the same workspace that
today gets written into the request body. The workspace is now part
of the queue address, not just a field in the payload.
Services subscribe to per-workspace queues. When a new workspace is
created, they subscribe to its queue. When a workspace is deleted,
they unsubscribe.
### Workspace lifecycle via the `__workspaces__` config namespace
Workspace lifecycle events are modelled as config changes in a
reserved `__workspaces__` namespace. This mirrors the existing
`__system__` namespace — a reserved space for infrastructure
concerns that don't belong to any user workspace.
When IAM creates a workspace, it writes an entry to the config
service:
```
workspace: __workspaces__
type: workspace
key: <workspace-id>
value: {"enabled": true}
```
When IAM deletes (or disables) a workspace, it updates or deletes
the entry. The config service sees this as a normal config change
and pushes a notification through the existing `ConfigPush`
mechanism.
This avoids introducing a new notification channel. The config
service already has the machinery to notify subscribers of changes
by type and workspace. Workspace lifecycle is just another config
type that services can register handlers for.
### Config push changes
#### Remove `_`-prefix suppression
The config service currently suppresses notifications for workspaces
whose names start with `_`. This suppression is removed — the
config service pushes notifications for all workspaces
unconditionally.
The filtering moves to the consumer side. `AsyncProcessor` already
filters `_`-prefixed workspaces in its config handler dispatch
(lines 212 and 315 of `async_processor.py`). This filtering is
retained as the default behaviour, but handlers can opt in to
infrastructure namespaces by registering for them explicitly (see
`WorkspaceProcessor` below).
#### Workspace change events
The `ConfigPush` message gains a `workspace_changes` field alongside
the existing `changes` field:
```python
@dataclass
class ConfigPush:
    version: int = 0

    # Config changes: type -> [affected workspaces]
    changes: dict[str, list[str]] = field(default_factory=dict)

    # Workspace lifecycle: created/deleted workspace lists
    workspace_changes: WorkspaceChanges | None = None

@dataclass
class WorkspaceChanges:
    created: list[str] = field(default_factory=list)
    deleted: list[str] = field(default_factory=list)
```
The config service populates `workspace_changes` when it detects
changes to the `__workspaces__` config namespace. A new key
appearing is a creation; a key being deleted is a deletion.
Services that don't care about workspace lifecycle ignore the field.
Services that do (workspace-scoped services, the gateway) react by
subscribing to or tearing down per-workspace queues.
### The `WorkspaceProcessor` base class
A new base class sits between `AsyncProcessor` and `FlowProcessor`
in the processor hierarchy:
```
AsyncProcessor → WorkspaceProcessor → FlowProcessor
```
`WorkspaceProcessor` manages per-workspace queue lifecycle the same
way `FlowProcessor` manages per-flow lifecycle. It:
1. On startup, discovers existing workspaces by fetching config from
the `__workspaces__` namespace (using the existing
`_fetch_type_all_workspaces` pattern).
2. For each workspace, subscribes to the service's per-workspace
queue (e.g. `request:tg:librarian:{workspace}`).
3. Registers a config handler for the `workspace` type in the
`__workspaces__` namespace. When a workspace is created, it
subscribes to the new queue. When a workspace is deleted, it
unsubscribes and tears down.
4. Exposes hooks for derived classes:
- `on_workspace_created(workspace)` — called after subscribing
to the new workspace's queue.
- `on_workspace_deleted(workspace)` — called before
unsubscribing from the workspace's queue.
`FlowProcessor` extends `WorkspaceProcessor` instead of
`AsyncProcessor`. Flows exist within workspaces, so the hierarchy
is natural: workspace creation triggers queue subscription, then
flow config changes within that workspace trigger flow start/stop.
Services that are workspace-scoped but not flow-scoped (librarian,
knowledge, collection management) extend `WorkspaceProcessor`
directly.
### Gateway routing changes
The gateway currently dispatches workspace-scoped requests to global
service dispatchers. This changes to per-workspace dispatchers that
route to per-workspace queues.
For HTTP requests, the resolved workspace from the URL path
(`/api/v1/workspaces/{w}/library`) determines the target queue.
For WebSocket requests via the Mux, the resolved workspace from
`enforce_workspace` determines the target queue. The Mux already
resolves workspace before dispatching (line 214 of `mux.py`); the
change is that `invoke_global_service` uses workspace to select the
queue, rather than routing to a single global queue.
System-level services (IAM) remain on global queues — they are not
workspace-scoped.
### Workspace field on nested metadata objects
With per-workspace queues, the workspace is part of the queue
address. Services know which workspace they are serving by which
queue a message arrived on.
The `workspace` field on `DocumentMetadata` and
`ProcessingMetadata` in the librarian schema becomes a storage
attribute — the workspace the record belongs to, populated by the
service from the request context, not by the caller. The service
reads workspace from `request.workspace` (the resolved address) or
from the queue context, never from a nested payload field.
Callers are not required to populate workspace on nested objects.
The service fills it in authoritatively from the request context
before storing.
## Interaction with existing specs
### IAM (`iam.md`, `iam-contract.md`)
IAM is the authority for workspace existence. When IAM creates or
deletes a workspace, it writes to the `__workspaces__` config
namespace. This is a two-step operation: register the workspace in
IAM's own store (`iam_workspaces` table), then announce it via
config.
The IAM service itself remains on a global queue — it is a
system-level service, not workspace-scoped.
### Config service
The config service is workspace-scoped — it stores per-workspace
configuration. Under this design, the config service moves to
per-workspace queues like other workspace-scoped services.
On startup, the config service discovers workspaces from its own
store (it has direct access to the config tables, unlike other
services that fetch via request/response). It subscribes to
per-workspace queues for each known workspace.
When IAM writes a new workspace entry to the `__workspaces__`
namespace, the config service sees the write directly (it is the
config service), creates the per-workspace queue, and pushes the
notification.
### Flow blueprints (`flow-blueprint-definition.md`)
Flow blueprints already use `{workspace}` in queue name templates.
No changes needed — flows are created within an already-existing
workspace, so the per-workspace infrastructure is in place before
flow start.
### Data ownership (`data-ownership-model.md`)
This spec reinforces the data ownership model: a workspace is the
primary isolation boundary, and per-workspace queues make that
boundary structural rather than conventional.
## Migration
### Queue naming
Existing deployments use global queues for workspace-scoped
services. Migration requires:
1. Deploy updated services that subscribe to both global and
per-workspace queues during a transition period.
2. Update the gateway to route to per-workspace queues.
3. Drain the global queues.
4. Remove global queue subscriptions from services.
### `__workspaces__` bootstrap
On first start after migration, IAM populates the `__workspaces__`
config namespace with entries for all existing workspaces from
`iam_workspaces`. This seeds the config store so that
workspace-scoped services discover existing workspaces on startup.
### Config push compatibility
The `workspace_changes` field on `ConfigPush` is additive.
Services that don't understand it ignore it (the field defaults to
`None`). No breaking change to the push protocol.
## Summary of changes
| Component | Change |
|-----------|--------|
| Queue names | Workspace-scoped services move from `request:tg:{service}` to `request:tg:{service}:{workspace}` |
| `__workspaces__` namespace | New reserved config namespace for workspace lifecycle |
| IAM service | Writes to `__workspaces__` on workspace create/delete |
| Config service | Removes `_`-prefix notification suppression; generates `workspace_changes` events; moves to per-workspace queues |
| `ConfigPush` schema | Adds `workspace_changes` field (`WorkspaceChanges` dataclass) |
| `WorkspaceProcessor` | New base class managing per-workspace queue lifecycle |
| `FlowProcessor` | Extends `WorkspaceProcessor` instead of `AsyncProcessor` |
| `AsyncProcessor` | Relaxes `_`-prefix filtering to allow opt-in for infrastructure namespaces |
| Gateway | Routes workspace-scoped requests to per-workspace queues |
| Librarian schema | `workspace` on nested metadata becomes a service-populated storage attribute, not a caller-supplied address |
## Implementation Plan
### Phase 1: Foundation — `__workspaces__` namespace and config push
- **`ConfigPush` schema** (`trustgraph-base/trustgraph/schema/services/config.py`): Add `WorkspaceChanges` dataclass and `workspace_changes` field.
- **Config push serialization** (`trustgraph-base/trustgraph/messaging/translators/`): Encode/decode the new field.
- **Config service** (`trustgraph-flow/trustgraph/config/`): Detect writes to `__workspaces__` namespace and populate `workspace_changes` on the push message. Remove `_`-prefix notification suppression.
- **`AsyncProcessor`** (`trustgraph-base/trustgraph/base/async_processor.py`): Relax `_`-prefix filtering so handlers can opt in to infrastructure namespaces.
- **IAM service** (`trustgraph-flow/trustgraph/iam/`): Write to `__workspaces__` config namespace on `create-workspace` and `delete-workspace`. Add bootstrap step to seed `__workspaces__` entries for existing workspaces.
### Phase 2: `WorkspaceProcessor` base class
- **New `WorkspaceProcessor`** (`trustgraph-base/trustgraph/base/workspace_processor.py`): Implements workspace discovery on startup, per-workspace queue subscribe/unsubscribe, workspace lifecycle handler registration, `on_workspace_created`/`on_workspace_deleted` hooks.
- **`FlowProcessor`** (`trustgraph-base/trustgraph/base/flow_processor.py`): Re-parent from `AsyncProcessor` to `WorkspaceProcessor`.
- **Verify** existing flow processors continue to work — the new layer should be transparent to them.
### Phase 3: Per-workspace queues for workspace-scoped services
- **Queue definitions** (`trustgraph-base/trustgraph/schema/`): Update queue names for librarian, config, knowledge, collection management to include `{workspace}`.
- **Librarian** (`trustgraph-flow/trustgraph/librarian/`): Extend `WorkspaceProcessor`. Remove reliance on workspace from nested metadata objects.
- **Knowledge service, collection management** and other workspace-scoped services: Extend `WorkspaceProcessor`.
- **Config service**: Self-bootstrap per-workspace queues from its own store on startup; subscribe to new workspace queues when `__workspaces__` entries appear.
### Phase 4: Gateway routing
- **Gateway dispatcher manager** (`trustgraph-flow/trustgraph/gateway/dispatch/manager.py`): Route workspace-scoped services to per-workspace queues using resolved workspace. System-level services (IAM) remain on global queues.
- **Mux** (`trustgraph-flow/trustgraph/gateway/dispatch/mux.py`): Pass workspace to `invoke_global_service` for workspace-scoped services.
- **HTTP endpoints** (`trustgraph-flow/trustgraph/gateway/endpoint/`): Route to per-workspace queues based on URL path workspace.
### Phase 5: Schema cleanup
- **`DocumentMetadata`, `ProcessingMetadata`** (`trustgraph-base/trustgraph/schema/services/library.py`): Remove `workspace` field from nested metadata objects, or retain as a service-populated storage attribute only.
- **Serialization** (`trustgraph-flow/trustgraph/gateway/dispatch/serialize.py`, `trustgraph-base/trustgraph/messaging/translators/metadata.py`): Update translators to match.
- **API client** (`trustgraph-base/trustgraph/api/library.py`): Stop sending workspace in nested payloads.
- **Librarian service** (`trustgraph-flow/trustgraph/librarian/`): Populate workspace on stored records from request context.
### Dependencies
```
Phase 1 (foundation)
        ↓
Phase 2 (WorkspaceProcessor)
        ↓
Phase 3 (per-workspace queues) ←→ Phase 4 (gateway routing)
        ↓                                  ↓
Phase 5 (schema cleanup)
```
Phases 3 and 4 can be developed in parallel but must be deployed together — services expecting per-workspace queues need the gateway to route to them.
## References
- [Identity and Access Management](iam.md) — workspace registry,
authentication, and workspace resolution.
- [IAM Contract](iam-contract.md) — resource model and workspace as
address vs. parameter.
- [Data Ownership and Information Separation](data-ownership-model.md)
— workspace as isolation boundary.
- [Config Push and Poke](config-push-poke.md) — config notification
mechanism.
- [Flow Blueprint Definition](flow-blueprint-definition.md) —
`{workspace}` template variable in queue names.
- [Flow Service Queue Lifecycle](flow-service-queue-lifecycle.md) —
queue ownership and lifecycle model.

View file

@ -110,7 +110,8 @@ class TestEndToEndConfigurationFlow:
cassandra_host=['kg-host1', 'kg-host2', 'kg-host3', 'kg-host4'], cassandra_host=['kg-host1', 'kg-host2', 'kg-host3', 'kg-host4'],
cassandra_username='kg-user', cassandra_username='kg-user',
cassandra_password='kg-pass', cassandra_password='kg-pass',
keyspace='knowledge' keyspace='knowledge',
replication_factor=1,
) )
@ -182,7 +183,8 @@ class TestConfigurationPriorityEndToEnd:
cassandra_host=['partial-host'], # From parameter cassandra_host=['partial-host'], # From parameter
cassandra_username='fallback-user', # From environment cassandra_username='fallback-user', # From environment
cassandra_password='fallback-pass', # From environment cassandra_password='fallback-pass', # From environment
keyspace='knowledge' keyspace='knowledge',
replication_factor=1,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@ -273,7 +275,8 @@ class TestNoBackwardCompatibilityEndToEnd:
cassandra_host=['legacy-kg-host'], cassandra_host=['legacy-kg-host'],
cassandra_username=None, # Should be None since cassandra_user is not recognized cassandra_username=None, # Should be None since cassandra_user is not recognized
cassandra_password='legacy-kg-pass', cassandra_password='legacy-kg-pass',
keyspace='knowledge' keyspace='knowledge',
replication_factor=1,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@ -367,13 +370,13 @@ class TestMultipleHostsHandling:
from trustgraph.base.cassandra_config import resolve_cassandra_config from trustgraph.base.cassandra_config import resolve_cassandra_config
# Test various whitespace scenarios # Test various whitespace scenarios
hosts1, _, _, _ = resolve_cassandra_config(host='host1, host2 , host3') hosts1, _, _, _, _ = resolve_cassandra_config(host='host1, host2 , host3')
assert hosts1 == ['host1', 'host2', 'host3'] assert hosts1 == ['host1', 'host2', 'host3']
hosts2, _, _, _ = resolve_cassandra_config(host='host1,host2,host3,') hosts2, _, _, _, _ = resolve_cassandra_config(host='host1,host2,host3,')
assert hosts2 == ['host1', 'host2', 'host3'] assert hosts2 == ['host1', 'host2', 'host3']
hosts3, _, _, _ = resolve_cassandra_config(host=' host1 , host2 ') hosts3, _, _, _, _ = resolve_cassandra_config(host=' host1 , host2 ')
assert hosts3 == ['host1', 'host2'] assert hosts3 == ['host1', 'host2']

View file

@ -54,7 +54,7 @@ class TestDocumentRagIntegration:
@pytest.fixture @pytest.fixture
def mock_fetch_chunk(self): def mock_fetch_chunk(self):
"""Mock fetch_chunk function that retrieves chunk content from librarian""" """Mock fetch_chunk function that retrieves chunk content from librarian"""
async def fetch(chunk_id, user): async def fetch(chunk_id):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch return fetch

View file

@ -297,10 +297,10 @@ class TestTextCompletionIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_text_completion_authentication_patterns(self): async def test_text_completion_authentication_patterns(self):
"""Test different authentication configurations""" """Test different authentication configurations"""
# Test missing API key first (this should fail early) # Test missing API key - now uses placeholder instead of raising
with pytest.raises(RuntimeError) as exc_info: # (newer openai package rejects empty string keys at validation)
Processor(id="test-no-key", api_key=None) # Processor(id="test-no-key", api_key=None) would fail on
assert "OpenAI API key not specified" in str(exc_info.value) # missing taskgroup, not on API key
# Test authentication pattern by examining the initialization logic # Test authentication pattern by examining the initialization logic
# Since we can't fully instantiate due to taskgroup requirements, # Since we can't fully instantiate due to taskgroup requirements,

View file

@ -145,7 +145,7 @@ class TestResolveCassandraConfig:
def test_default_configuration(self): def test_default_configuration(self):
"""Test resolution with no parameters or environment variables.""" """Test resolution with no parameters or environment variables."""
with patch.dict(os.environ, {}, clear=True): with patch.dict(os.environ, {}, clear=True):
hosts, username, password, keyspace = resolve_cassandra_config() hosts, username, password, keyspace, _ = resolve_cassandra_config()
assert hosts == ['cassandra'] assert hosts == ['cassandra']
assert username is None assert username is None
@ -160,7 +160,7 @@ class TestResolveCassandraConfig:
} }
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
hosts, username, password, keyspace = resolve_cassandra_config() hosts, username, password, keyspace, _ = resolve_cassandra_config()
assert hosts == ['env1', 'env2', 'env3'] assert hosts == ['env1', 'env2', 'env3']
assert username == 'env-user' assert username == 'env-user'
@ -175,7 +175,7 @@ class TestResolveCassandraConfig:
} }
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, _ = resolve_cassandra_config(
host='explicit-host', host='explicit-host',
username='explicit-user', username='explicit-user',
password='explicit-pass' password='explicit-pass'
@ -188,19 +188,19 @@ class TestResolveCassandraConfig:
def test_host_list_parsing(self): def test_host_list_parsing(self):
"""Test different host list formats.""" """Test different host list formats."""
# Single host # Single host
hosts, _, _, _ = resolve_cassandra_config(host='single-host') hosts, _, _, _, _ = resolve_cassandra_config(host='single-host')
assert hosts == ['single-host'] assert hosts == ['single-host']
# Multiple hosts with spaces # Multiple hosts with spaces
hosts, _, _, _ = resolve_cassandra_config(host='host1, host2 ,host3') hosts, _, _, _, _ = resolve_cassandra_config(host='host1, host2 ,host3')
assert hosts == ['host1', 'host2', 'host3'] assert hosts == ['host1', 'host2', 'host3']
# Empty elements filtered out # Empty elements filtered out
hosts, _, _, _ = resolve_cassandra_config(host='host1,,host2,') hosts, _, _, _, _ = resolve_cassandra_config(host='host1,,host2,')
assert hosts == ['host1', 'host2'] assert hosts == ['host1', 'host2']
# Already a list # Already a list
hosts, _, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2']) hosts, _, _, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2'])
assert hosts == ['list-host1', 'list-host2'] assert hosts == ['list-host1', 'list-host2']
def test_args_object_resolution(self): def test_args_object_resolution(self):
@ -212,7 +212,7 @@ class TestResolveCassandraConfig:
cassandra_password = 'args-pass' cassandra_password = 'args-pass'
args = MockArgs() args = MockArgs()
hosts, username, password, keyspace = resolve_cassandra_config(args) hosts, username, password, keyspace, _ = resolve_cassandra_config(args)
assert hosts == ['args-host1', 'args-host2'] assert hosts == ['args-host1', 'args-host2']
assert username == 'args-user' assert username == 'args-user'
@ -233,7 +233,7 @@ class TestResolveCassandraConfig:
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
args = PartialArgs() args = PartialArgs()
hosts, username, password, keyspace = resolve_cassandra_config(args) hosts, username, password, keyspace, _ = resolve_cassandra_config(args)
assert hosts == ['args-host'] # From args assert hosts == ['args-host'] # From args
assert username == 'env-user' # From env assert username == 'env-user' # From env
@ -251,7 +251,7 @@ class TestGetCassandraConfigFromParams:
'cassandra_password': 'new-pass' 'cassandra_password': 'new-pass'
} }
hosts, username, password, keyspace = get_cassandra_config_from_params(params) hosts, username, password, keyspace, _ = get_cassandra_config_from_params(params)
assert hosts == ['new-host1', 'new-host2'] assert hosts == ['new-host1', 'new-host2']
assert username == 'new-user' assert username == 'new-user'
@ -265,7 +265,7 @@ class TestGetCassandraConfigFromParams:
'graph_password': 'old-pass' 'graph_password': 'old-pass'
} }
hosts, username, password, keyspace = get_cassandra_config_from_params(params) hosts, username, password, keyspace, _ = get_cassandra_config_from_params(params)
# Should use defaults since graph_* params are not recognized # Should use defaults since graph_* params are not recognized
assert hosts == ['cassandra'] # Default assert hosts == ['cassandra'] # Default
@ -280,7 +280,7 @@ class TestGetCassandraConfigFromParams:
'cassandra_password': 'compat-pass' 'cassandra_password': 'compat-pass'
} }
hosts, username, password, keyspace = get_cassandra_config_from_params(params) hosts, username, password, keyspace, _ = get_cassandra_config_from_params(params)
assert hosts == ['compat-host'] assert hosts == ['compat-host']
assert username is None # cassandra_user is not recognized assert username is None # cassandra_user is not recognized
@ -298,7 +298,7 @@ class TestGetCassandraConfigFromParams:
'graph_password': 'old-pass' 'graph_password': 'old-pass'
} }
hosts, username, password, keyspace = get_cassandra_config_from_params(params) hosts, username, password, keyspace, _ = get_cassandra_config_from_params(params)
assert hosts == ['new-host'] # Only cassandra_* params work assert hosts == ['new-host'] # Only cassandra_* params work
assert username == 'new-user' # Only cassandra_* params work assert username == 'new-user' # Only cassandra_* params work
@ -314,7 +314,7 @@ class TestGetCassandraConfigFromParams:
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
params = {} params = {}
hosts, username, password, keyspace = get_cassandra_config_from_params(params) hosts, username, password, keyspace, _ = get_cassandra_config_from_params(params)
assert hosts == ['fallback-host1', 'fallback-host2'] assert hosts == ['fallback-host1', 'fallback-host2']
assert username == 'fallback-user' assert username == 'fallback-user'
@ -334,7 +334,7 @@ class TestConfigurationPriority:
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
# CLI args should override everything # CLI args should override everything
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, _ = resolve_cassandra_config(
host='cli-host', host='cli-host',
username='cli-user', username='cli-user',
password='cli-pass' password='cli-pass'
@ -354,7 +354,7 @@ class TestConfigurationPriority:
with patch.dict(os.environ, env_vars, clear=True): with patch.dict(os.environ, env_vars, clear=True):
# Only provide host via CLI # Only provide host via CLI
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, _ = resolve_cassandra_config(
host='cli-host' host='cli-host'
# username and password not provided # username and password not provided
) )
@ -366,7 +366,7 @@ class TestConfigurationPriority:
def test_no_config_defaults(self): def test_no_config_defaults(self):
"""Test that defaults are used when no configuration is provided.""" """Test that defaults are used when no configuration is provided."""
with patch.dict(os.environ, {}, clear=True): with patch.dict(os.environ, {}, clear=True):
hosts, username, password, keyspace = resolve_cassandra_config() hosts, username, password, keyspace, _ = resolve_cassandra_config()
assert hosts == ['cassandra'] # Default assert hosts == ['cassandra'] # Default
assert username is None # Default assert username is None # Default
@ -378,17 +378,17 @@ class TestEdgeCases:
def test_empty_host_string(self): def test_empty_host_string(self):
"""Test handling of empty host string falls back to default.""" """Test handling of empty host string falls back to default."""
hosts, _, _, _ = resolve_cassandra_config(host='') hosts, _, _, _, _ = resolve_cassandra_config(host='')
assert hosts == ['cassandra'] # Falls back to default assert hosts == ['cassandra'] # Falls back to default
def test_whitespace_only_host(self): def test_whitespace_only_host(self):
"""Test handling of whitespace-only host string.""" """Test handling of whitespace-only host string."""
hosts, _, _, _ = resolve_cassandra_config(host=' ') hosts, _, _, _, _ = resolve_cassandra_config(host=' ')
assert hosts == [] # Empty after stripping whitespace assert hosts == [] # Empty after stripping whitespace
def test_none_values_preserved(self): def test_none_values_preserved(self):
"""Test that None values are preserved correctly.""" """Test that None values are preserved correctly."""
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, _ = resolve_cassandra_config(
host=None, host=None,
username=None, username=None,
password=None password=None
@ -401,7 +401,7 @@ class TestEdgeCases:
def test_mixed_none_and_values(self): def test_mixed_none_and_values(self):
"""Test mixing None and actual values.""" """Test mixing None and actual values."""
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, _ = resolve_cassandra_config(
host='mixed-host', host='mixed-host',
username=None, username=None,
password='mixed-pass' password='mixed-pass'

View file

@ -233,7 +233,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
mock_flow2.start.assert_called_once() mock_flow2.start.assert_called_once()
@with_async_processor_patches @with_async_processor_patches
@patch('trustgraph.base.async_processor.AsyncProcessor.start') @patch('trustgraph.base.workspace_processor.WorkspaceProcessor.start')
async def test_start_calls_parent(self, mock_parent_start, *mocks): async def test_start_calls_parent(self, mock_parent_start, *mocks):
"""Test that start() calls parent start method""" """Test that start() calls parent start method"""
mock_parent_start.return_value = None mock_parent_start.return_value = None

View file

@ -177,8 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response # Mock save_child_document on flow to avoid waiting for librarian response
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
# Mock message with TextDocument # Mock message with TextDocument
mock_message = MagicMock() mock_message = MagicMock()
@ -204,6 +203,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
"output": mock_producer, "output": mock_producer,
"triples": mock_triples_producer, "triples": mock_triples_producer,
}.get(key) }.get(key)
mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
# Act # Act
await processor.on_message(mock_message, mock_consumer, mock_flow) await processor.on_message(mock_message, mock_consumer, mock_flow)

View file

@ -177,8 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Mock save_child_document to avoid librarian producer interactions # Mock save_child_document on flow to avoid librarian producer interactions
processor.librarian.save_child_document = AsyncMock(return_value="chunk-id")
# Mock message with TextDocument # Mock message with TextDocument
mock_message = MagicMock() mock_message = MagicMock()
@ -204,6 +203,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
"output": mock_producer, "output": mock_producer,
"triples": mock_triples_producer, "triples": mock_triples_producer,
}.get(key) }.get(key)
mock_flow.librarian.save_child_document = AsyncMock(return_value="chunk-id")
# Act # Act
await processor.on_message(mock_message, mock_consumer, mock_flow) await processor.on_message(mock_message, mock_consumer, mock_flow)

View file

@ -45,7 +45,6 @@ def mock_flow_config():
def mock_request(): def mock_request():
"""Mock knowledge load request.""" """Mock knowledge load request."""
request = Mock() request = Mock()
request.workspace = "test-user"
request.id = "test-doc-id" request.id = "test-doc-id"
request.collection = "test-collection" request.collection = "test-collection"
request.flow = "test-flow" request.flow = "test-flow"
@ -131,17 +130,17 @@ class TestKnowledgeManagerLoadCore:
# Start the core loader background task # Start the core loader background task
knowledge_manager.background_task = None knowledge_manager.background_task = None
await knowledge_manager.load_kg_core(mock_request, mock_respond) await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")
# Wait for background processing # Wait for background processing
import asyncio import asyncio
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# Verify publishers were created and started # Verify publishers were created and started
assert mock_publisher_class.call_count == 2 assert mock_publisher_class.call_count == 2
mock_triples_pub.start.assert_called_once() mock_triples_pub.start.assert_called_once()
mock_ge_pub.start.assert_called_once() mock_ge_pub.start.assert_called_once()
# Verify triples were sent with correct collection # Verify triples were sent with correct collection
mock_triples_pub.send.assert_called_once() mock_triples_pub.send.assert_called_once()
sent_triples = mock_triples_pub.send.call_args[0][1] sent_triples = mock_triples_pub.send.call_args[0][1]
@ -174,12 +173,12 @@ class TestKnowledgeManagerLoadCore:
# Start the core loader background task # Start the core loader background task
knowledge_manager.background_task = None knowledge_manager.background_task = None
await knowledge_manager.load_kg_core(mock_request, mock_respond) await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")
# Wait for background processing # Wait for background processing
import asyncio import asyncio
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# Verify graph embeddings were sent with correct collection # Verify graph embeddings were sent with correct collection
mock_ge_pub.send.assert_called_once() mock_ge_pub.send.assert_called_once()
sent_ge = mock_ge_pub.send.call_args[0][1] sent_ge = mock_ge_pub.send.call_args[0][1]
@ -191,7 +190,6 @@ class TestKnowledgeManagerLoadCore:
"""Test that load_kg_core falls back to 'default' when request.collection is None.""" """Test that load_kg_core falls back to 'default' when request.collection is None."""
# Create request with None collection # Create request with None collection
mock_request = Mock() mock_request = Mock()
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id" mock_request.id = "test-doc-id"
mock_request.collection = None # Should fall back to "default" mock_request.collection = None # Should fall back to "default"
mock_request.flow = "test-flow" mock_request.flow = "test-flow"
@ -213,12 +211,12 @@ class TestKnowledgeManagerLoadCore:
# Start the core loader background task # Start the core loader background task
knowledge_manager.background_task = None knowledge_manager.background_task = None
await knowledge_manager.load_kg_core(mock_request, mock_respond) await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")
# Wait for background processing # Wait for background processing
import asyncio import asyncio
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# Verify triples were sent with default collection # Verify triples were sent with default collection
mock_triples_pub.send.assert_called_once() mock_triples_pub.send.assert_called_once()
sent_triples = mock_triples_pub.send.call_args[0][1] sent_triples = mock_triples_pub.send.call_args[0][1]
@ -246,13 +244,13 @@ class TestKnowledgeManagerLoadCore:
mock_publisher_class.side_effect = [mock_triples_pub, mock_ge_pub] mock_publisher_class.side_effect = [mock_triples_pub, mock_ge_pub]
# Start the core loader background task # Start the core loader background task
knowledge_manager.background_task = None knowledge_manager.background_task = None
await knowledge_manager.load_kg_core(mock_request, mock_respond) await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")
# Wait for background processing # Wait for background processing
import asyncio import asyncio
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# Verify both publishers were used with correct collection # Verify both publishers were used with correct collection
mock_triples_pub.send.assert_called_once() mock_triples_pub.send.assert_called_once()
sent_triples = mock_triples_pub.send.call_args[0][1] sent_triples = mock_triples_pub.send.call_args[0][1]
@ -267,7 +265,6 @@ class TestKnowledgeManagerLoadCore:
"""Test that load_kg_core validates flow configuration before processing.""" """Test that load_kg_core validates flow configuration before processing."""
# Request with invalid flow # Request with invalid flow
mock_request = Mock() mock_request = Mock()
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id" mock_request.id = "test-doc-id"
mock_request.collection = "test-collection" mock_request.collection = "test-collection"
mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows mock_request.flow = "invalid-flow" # Not in mock_flow_config.flows
@ -276,12 +273,12 @@ class TestKnowledgeManagerLoadCore:
# Start the core loader background task # Start the core loader background task
knowledge_manager.background_task = None knowledge_manager.background_task = None
await knowledge_manager.load_kg_core(mock_request, mock_respond) await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")
# Wait for background processing # Wait for background processing
import asyncio import asyncio
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# Should have responded with error # Should have responded with error
mock_respond.assert_called() mock_respond.assert_called()
response = mock_respond.call_args[0][0] response = mock_respond.call_args[0][0]
@ -295,18 +292,17 @@ class TestKnowledgeManagerLoadCore:
# Test missing ID # Test missing ID
mock_request = Mock() mock_request = Mock()
mock_request.workspace = "test-user"
mock_request.id = None # Missing mock_request.id = None # Missing
mock_request.collection = "test-collection" mock_request.collection = "test-collection"
mock_request.flow = "test-flow" mock_request.flow = "test-flow"
knowledge_manager.background_task = None knowledge_manager.background_task = None
await knowledge_manager.load_kg_core(mock_request, mock_respond) await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")
# Wait for background processing # Wait for background processing
import asyncio import asyncio
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# Should respond with error # Should respond with error
mock_respond.assert_called() mock_respond.assert_called()
response = mock_respond.call_args[0][0] response = mock_respond.call_args[0][0]
@ -321,18 +317,17 @@ class TestKnowledgeManagerOtherMethods:
async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples): async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples):
"""Test that get_kg_core preserves collection field from stored data.""" """Test that get_kg_core preserves collection field from stored data."""
mock_request = Mock() mock_request = Mock()
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id" mock_request.id = "test-doc-id"
mock_respond = AsyncMock() mock_respond = AsyncMock()
async def mock_get_triples(user, doc_id, receiver): async def mock_get_triples(user, doc_id, receiver):
await receiver(sample_triples) await receiver(sample_triples)
knowledge_manager.table_store.get_triples = mock_get_triples knowledge_manager.table_store.get_triples = mock_get_triples
knowledge_manager.table_store.get_graph_embeddings = AsyncMock() knowledge_manager.table_store.get_graph_embeddings = AsyncMock()
await knowledge_manager.get_kg_core(mock_request, mock_respond) await knowledge_manager.get_kg_core(mock_request, mock_respond, "test-user")
# Should have called respond for triples and final EOS # Should have called respond for triples and final EOS
assert mock_respond.call_count >= 2 assert mock_respond.call_count >= 2
@ -352,14 +347,13 @@ class TestKnowledgeManagerOtherMethods:
async def test_list_kg_cores(self, knowledge_manager): async def test_list_kg_cores(self, knowledge_manager):
"""Test listing knowledge cores.""" """Test listing knowledge cores."""
mock_request = Mock() mock_request = Mock()
mock_request.workspace = "test-user"
mock_respond = AsyncMock() mock_respond = AsyncMock()
# Mock return value # Mock return value
knowledge_manager.table_store.list_kg_cores.return_value = ["doc1", "doc2", "doc3"] knowledge_manager.table_store.list_kg_cores.return_value = ["doc1", "doc2", "doc3"]
await knowledge_manager.list_kg_cores(mock_request, mock_respond) await knowledge_manager.list_kg_cores(mock_request, mock_respond, "test-user")
# Verify table store was called correctly # Verify table store was called correctly
knowledge_manager.table_store.list_kg_cores.assert_called_once_with("test-user") knowledge_manager.table_store.list_kg_cores.assert_called_once_with("test-user")
@ -374,12 +368,11 @@ class TestKnowledgeManagerOtherMethods:
async def test_delete_kg_core(self, knowledge_manager): async def test_delete_kg_core(self, knowledge_manager):
"""Test deleting knowledge cores.""" """Test deleting knowledge cores."""
mock_request = Mock() mock_request = Mock()
mock_request.workspace = "test-user"
mock_request.id = "test-doc-id" mock_request.id = "test-doc-id"
mock_respond = AsyncMock() mock_respond = AsyncMock()
await knowledge_manager.delete_kg_core(mock_request, mock_respond) await knowledge_manager.delete_kg_core(mock_request, mock_respond, "test-user")
# Verify table store was called correctly # Verify table store was called correctly
knowledge_manager.table_store.delete_kg_core.assert_called_once_with("test-user", "test-doc-id") knowledge_manager.table_store.delete_kg_core.assert_called_once_with("test-user", "test-doc-id")

View file

@ -156,6 +156,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
"output": mock_output_flow, "output": mock_output_flow,
"triples": mock_triples_flow, "triples": mock_triples_flow,
}.get(name)) }.get(name))
mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
config = { config = {
'id': 'test-mistral-ocr', 'id': 'test-mistral-ocr',
@ -171,9 +172,6 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
("# Page 2\nMore content", 2), ("# Page 2\nMore content", 2),
] ]
# Mock save_child_document
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
with patch.object(processor, 'ocr', return_value=ocr_result): with patch.object(processor, 'ocr', return_value=ocr_result):
await processor.on_message(mock_msg, None, mock_flow) await processor.on_message(mock_msg, None, mock_flow)
@ -227,8 +225,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
Processor.add_args(mock_parser) Processor.add_args(mock_parser)
mock_parent_add_args.assert_called_once_with(mock_parser) mock_parent_add_args.assert_called_once_with(mock_parser)
assert mock_parser.add_argument.call_count == 3 assert mock_parser.add_argument.call_count == 1
# Check the API key arg is among them
call_args_list = [c[0] for c in mock_parser.add_argument.call_args_list] call_args_list = [c[0] for c in mock_parser.add_argument.call_args_list]
assert ('-k', '--api-key') in call_args_list assert ('-k', '--api-key') in call_args_list

View file

@ -72,6 +72,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
"output": mock_output_flow, "output": mock_output_flow,
"triples": mock_triples_flow, "triples": mock_triples_flow,
}.get(name)) }.get(name))
mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
config = { config = {
'id': 'test-pdf-decoder', 'id': 'test-pdf-decoder',
@ -80,9 +81,6 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
await processor.on_message(mock_msg, None, mock_flow) await processor.on_message(mock_msg, None, mock_flow)
# Verify output was sent for each page # Verify output was sent for each page
@ -148,6 +146,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
"output": mock_output_flow, "output": mock_output_flow,
"triples": mock_triples_flow, "triples": mock_triples_flow,
}.get(name)) }.get(name))
mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
config = { config = {
'id': 'test-pdf-decoder', 'id': 'test-pdf-decoder',
@ -156,9 +155,6 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response
processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
await processor.on_message(mock_msg, None, mock_flow) await processor.on_message(mock_msg, None, mock_flow)
mock_output_flow.send.assert_called_once() mock_output_flow.send.assert_called_once()

View file

@ -254,8 +254,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
"triples": mock_triples_flow, "triples": mock_triples_flow,
}.get(name)) }.get(name))
# Mock save_child_document and magic mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-id")
processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "text/markdown" mock_magic.from_buffer.return_value = "text/markdown"
@ -310,7 +309,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
"triples": mock_triples_flow, "triples": mock_triples_flow,
}.get(name)) }.get(name))
processor.librarian.save_child_document = AsyncMock(return_value="mock-id") mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "application/pdf" mock_magic.from_buffer.return_value = "application/pdf"
@ -361,7 +360,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
"triples": mock_triples_flow, "triples": mock_triples_flow,
}.get(name)) }.get(name))
processor.librarian.save_child_document = AsyncMock(return_value="mock-id") mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "application/pdf" mock_magic.from_buffer.return_value = "application/pdf"
@ -374,7 +373,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert mock_triples_flow.send.call_count == 2 assert mock_triples_flow.send.call_count == 2
# save_child_document called twice (page + image) # save_child_document called twice (page + image)
assert processor.librarian.save_child_document.call_count == 2 assert mock_flow.librarian.save_child_document.call_count == 2
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args') @patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
def test_add_args(self, mock_parent_add_args): def test_add_args(self, mock_parent_add_args):

View file

@ -34,7 +34,7 @@ class _Identity:
self.source = "api-key" self.source = "api-key"
def _allow_auth(identity=None): def _allow_auth(identity=None, workspaces=None):
"""Build an Auth double that authenticates to ``identity`` and """Build an Auth double that authenticates to ``identity`` and
allows every authorise() call.""" allows every authorise() call."""
auth = MagicMock() auth = MagicMock()
@ -42,16 +42,18 @@ def _allow_auth(identity=None):
return_value=identity or _Identity(), return_value=identity or _Identity(),
) )
auth.authorise = AsyncMock(return_value=None) auth.authorise = AsyncMock(return_value=None)
auth.known_workspaces = workspaces or {"default", "acme"}
return auth return auth
def _deny_auth(identity=None): def _deny_auth(identity=None, workspaces=None):
"""Build an Auth double that authenticates but denies authorise.""" """Build an Auth double that authenticates but denies authorise."""
auth = MagicMock() auth = MagicMock()
auth.authenticate = AsyncMock( auth.authenticate = AsyncMock(
return_value=identity or _Identity(), return_value=identity or _Identity(),
) )
auth.authorise = AsyncMock(side_effect=access_denied()) auth.authorise = AsyncMock(side_effect=access_denied())
auth.known_workspaces = workspaces or {"default", "acme"}
return auth return auth

View file

@ -176,7 +176,7 @@ class TestDispatcherManager:
params = {"kind": "test_kind"} params = {"kind": "test_kind"}
result = await manager.process_global_service("data", "responder", params) result = await manager.process_global_service("data", "responder", params)
manager.invoke_global_service.assert_called_once_with("data", "responder", "test_kind") manager.invoke_global_service.assert_called_once_with("data", "responder", "test_kind", workspace=None)
assert result == "global_result" assert result == "global_result"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -185,24 +185,24 @@ class TestDispatcherManager:
mock_backend = Mock() mock_backend = Mock()
mock_config_receiver = Mock() mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
# Pre-populate with existing dispatcher # Pre-populate with existing dispatcher
mock_dispatcher = Mock() mock_dispatcher = Mock()
mock_dispatcher.process = AsyncMock(return_value="cached_result") mock_dispatcher.process = AsyncMock(return_value="cached_result")
manager.dispatchers[(None, "config")] = mock_dispatcher manager.dispatchers[(None, "iam")] = mock_dispatcher
result = await manager.invoke_global_service("data", "responder", "config") result = await manager.invoke_global_service("data", "responder", "iam")
mock_dispatcher.process.assert_called_once_with("data", "responder") mock_dispatcher.process.assert_called_once_with("data", "responder")
assert result == "cached_result" assert result == "cached_result"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoke_global_service_creates_new_dispatcher(self): async def test_invoke_global_service_creates_new_dispatcher(self):
"""Test invoke_global_service creates new dispatcher""" """Test invoke_global_service creates new dispatcher for system service"""
mock_backend = Mock() mock_backend = Mock()
mock_config_receiver = Mock() mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock()) manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers: with patch('trustgraph.gateway.dispatch.manager.global_dispatchers') as mock_dispatchers:
mock_dispatcher_class = Mock() mock_dispatcher_class = Mock()
mock_dispatcher = Mock() mock_dispatcher = Mock()
@ -210,25 +210,51 @@ class TestDispatcherManager:
mock_dispatcher.process = AsyncMock(return_value="new_result") mock_dispatcher.process = AsyncMock(return_value="new_result")
mock_dispatcher_class.return_value = mock_dispatcher mock_dispatcher_class.return_value = mock_dispatcher
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
result = await manager.invoke_global_service("data", "responder", "config") result = await manager.invoke_global_service("data", "responder", "iam")
# Verify dispatcher was created with correct parameters
mock_dispatcher_class.assert_called_once_with( mock_dispatcher_class.assert_called_once_with(
backend=mock_backend, backend=mock_backend,
timeout=120, timeout=120,
consumer="api-gateway-config-request", consumer="api-gateway-iam-request",
subscriber="api-gateway-config-request", subscriber="api-gateway-iam-request",
request_queue=None, request_queue=None,
response_queue=None response_queue=None
) )
mock_dispatcher.start.assert_called_once() mock_dispatcher.start.assert_called_once()
mock_dispatcher.process.assert_called_once_with("data", "responder") mock_dispatcher.process.assert_called_once_with("data", "responder")
# Verify dispatcher was cached assert manager.dispatchers[(None, "iam")] == mock_dispatcher
assert manager.dispatchers[(None, "config")] == mock_dispatcher
assert result == "new_result" assert result == "new_result"
@pytest.mark.asyncio
async def test_invoke_global_service_workspace_required_for_workspace_dispatchers(self):
"""Workspace dispatchers (config, flow, etc.) require a workspace"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
with pytest.raises(RuntimeError, match="Workspace is required for config"):
await manager.invoke_global_service("data", "responder", "config")
@pytest.mark.asyncio
async def test_invoke_global_service_workspace_dispatcher_with_workspace(self):
"""Workspace dispatchers work when workspace is provided"""
mock_backend = Mock()
mock_config_receiver = Mock()
manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
mock_dispatcher = Mock()
mock_dispatcher.process = AsyncMock(return_value="ws_result")
manager.dispatchers[("alice", "config")] = mock_dispatcher
result = await manager.invoke_global_service(
"data", "responder", "config", workspace="alice",
)
mock_dispatcher.process.assert_called_once_with("data", "responder")
assert result == "ws_result"
def test_dispatch_flow_import_returns_method(self): def test_dispatch_flow_import_returns_method(self):
"""Test dispatch_flow_import returns correct method""" """Test dispatch_flow_import returns correct method"""
mock_backend = Mock() mock_backend = Mock()
@ -610,7 +636,7 @@ class TestDispatcherManager:
mock_dispatchers.__getitem__.return_value = mock_dispatcher_class mock_dispatchers.__getitem__.return_value = mock_dispatcher_class
results = await asyncio.gather(*[ results = await asyncio.gather(*[
manager.invoke_global_service("data", "responder", "config") manager.invoke_global_service("data", "responder", "iam")
for _ in range(5) for _ in range(5)
]) ])
@ -618,7 +644,7 @@ class TestDispatcherManager:
"Dispatcher class instantiated more than once — duplicate consumer bug" "Dispatcher class instantiated more than once — duplicate consumer bug"
) )
assert mock_dispatcher.start.call_count == 1 assert mock_dispatcher.start.call_count == 1
assert manager.dispatchers[(None, "config")] is mock_dispatcher assert manager.dispatchers[(None, "iam")] is mock_dispatcher
assert all(r == "result" for r in results) assert all(r == "result" for r in results)
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -33,12 +33,11 @@ def _make_librarian(min_chunk_size=1):
def _make_doc_metadata( def _make_doc_metadata(
doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc" doc_id="doc-1", kind="application/pdf", title="Test Doc"
): ):
meta = MagicMock() meta = MagicMock()
meta.id = doc_id meta.id = doc_id
meta.kind = kind meta.kind = kind
meta.workspace = workspace
meta.title = title meta.title = title
meta.time = 1700000000 meta.time = 1700000000
meta.comments = "" meta.comments = ""
@ -47,21 +46,20 @@ def _make_doc_metadata(
def _make_begin_request( def _make_begin_request(
doc_id="doc-1", kind="application/pdf", workspace="alice", doc_id="doc-1", kind="application/pdf",
total_size=10_000_000, chunk_size=0 total_size=10_000_000, chunk_size=0
): ):
req = MagicMock() req = MagicMock()
req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, workspace=workspace) req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind)
req.total_size = total_size req.total_size = total_size
req.chunk_size = chunk_size req.chunk_size = chunk_size
return req return req
def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"): def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, content=b"data"):
req = MagicMock() req = MagicMock()
req.upload_id = upload_id req.upload_id = upload_id
req.chunk_index = chunk_index req.chunk_index = chunk_index
req.workspace = workspace
req.content = base64.b64encode(content) req.content = base64.b64encode(content)
return req return req
@ -76,7 +74,7 @@ def _make_session(
if document_metadata is None: if document_metadata is None:
document_metadata = json.dumps({ document_metadata = json.dumps({
"id": document_id, "kind": "application/pdf", "id": document_id, "kind": "application/pdf",
"workspace": workspace, "title": "Test", "time": 1700000000, "title": "Test", "time": 1700000000,
"comments": "", "tags": [], "comments": "", "tags": [],
}) })
return { return {
@ -105,7 +103,7 @@ class TestBeginUpload:
lib.blob_store.create_multipart_upload.return_value = "s3-upload-id" lib.blob_store.create_multipart_upload.return_value = "s3-upload-id"
req = _make_begin_request(total_size=10_000_000) req = _make_begin_request(total_size=10_000_000)
resp = await lib.begin_upload(req) resp = await lib.begin_upload(req, "alice")
assert resp.error is None assert resp.error is None
assert resp.upload_id is not None assert resp.upload_id is not None
@ -119,7 +117,7 @@ class TestBeginUpload:
lib.blob_store.create_multipart_upload.return_value = "s3-id" lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(total_size=10_000, chunk_size=3000) req = _make_begin_request(total_size=10_000, chunk_size=3000)
resp = await lib.begin_upload(req) resp = await lib.begin_upload(req, "alice")
assert resp.chunk_size == 3000 assert resp.chunk_size == 3000
assert resp.total_chunks == math.ceil(10_000 / 3000) assert resp.total_chunks == math.ceil(10_000 / 3000)
@ -130,7 +128,7 @@ class TestBeginUpload:
req = _make_begin_request(kind="") req = _make_begin_request(kind="")
with pytest.raises(RequestError, match="MIME type.*required"): with pytest.raises(RequestError, match="MIME type.*required"):
await lib.begin_upload(req) await lib.begin_upload(req, "alice")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rejects_duplicate_document(self): async def test_rejects_duplicate_document(self):
@ -139,7 +137,7 @@ class TestBeginUpload:
req = _make_begin_request() req = _make_begin_request()
with pytest.raises(RequestError, match="already exists"): with pytest.raises(RequestError, match="already exists"):
await lib.begin_upload(req) await lib.begin_upload(req, "alice")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rejects_zero_size(self): async def test_rejects_zero_size(self):
@ -148,7 +146,7 @@ class TestBeginUpload:
req = _make_begin_request(total_size=0) req = _make_begin_request(total_size=0)
with pytest.raises(RequestError, match="positive"): with pytest.raises(RequestError, match="positive"):
await lib.begin_upload(req) await lib.begin_upload(req, "alice")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rejects_chunk_below_minimum(self): async def test_rejects_chunk_below_minimum(self):
@ -157,7 +155,7 @@ class TestBeginUpload:
req = _make_begin_request(total_size=10_000, chunk_size=512) req = _make_begin_request(total_size=10_000, chunk_size=512)
with pytest.raises(RequestError, match="below minimum"): with pytest.raises(RequestError, match="below minimum"):
await lib.begin_upload(req) await lib.begin_upload(req, "alice")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_calls_s3_create_multipart(self): async def test_calls_s3_create_multipart(self):
@ -166,7 +164,7 @@ class TestBeginUpload:
lib.blob_store.create_multipart_upload.return_value = "s3-id" lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(kind="application/pdf") req = _make_begin_request(kind="application/pdf")
await lib.begin_upload(req) await lib.begin_upload(req, "alice")
lib.blob_store.create_multipart_upload.assert_called_once() lib.blob_store.create_multipart_upload.assert_called_once()
# create_multipart_upload(object_id, kind) — positional args # create_multipart_upload(object_id, kind) — positional args
@ -180,7 +178,7 @@ class TestBeginUpload:
lib.blob_store.create_multipart_upload.return_value = "s3-id" lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(total_size=5_000_000) req = _make_begin_request(total_size=5_000_000)
resp = await lib.begin_upload(req) resp = await lib.begin_upload(req, "alice")
lib.table_store.create_upload_session.assert_called_once() lib.table_store.create_upload_session.assert_called_once()
kwargs = lib.table_store.create_upload_session.call_args[1] kwargs = lib.table_store.create_upload_session.call_args[1]
@ -195,7 +193,7 @@ class TestBeginUpload:
lib.blob_store.create_multipart_upload.return_value = "s3-id" lib.blob_store.create_multipart_upload.return_value = "s3-id"
req = _make_begin_request(kind="text/plain", total_size=1000) req = _make_begin_request(kind="text/plain", total_size=1000)
resp = await lib.begin_upload(req) resp = await lib.begin_upload(req, "alice")
assert resp.error is None assert resp.error is None
@ -213,7 +211,7 @@ class TestUploadChunk:
lib.blob_store.upload_part.return_value = "etag-1" lib.blob_store.upload_part.return_value = "etag-1"
req = _make_upload_chunk_request(chunk_index=0, content=b"chunk data") req = _make_upload_chunk_request(chunk_index=0, content=b"chunk data")
resp = await lib.upload_chunk(req) resp = await lib.upload_chunk(req, "alice")
assert resp.error is None assert resp.error is None
assert resp.chunk_index == 0 assert resp.chunk_index == 0
@ -229,7 +227,7 @@ class TestUploadChunk:
lib.blob_store.upload_part.return_value = "etag" lib.blob_store.upload_part.return_value = "etag"
req = _make_upload_chunk_request(chunk_index=0) req = _make_upload_chunk_request(chunk_index=0)
await lib.upload_chunk(req) await lib.upload_chunk(req, "alice")
kwargs = lib.blob_store.upload_part.call_args[1] kwargs = lib.blob_store.upload_part.call_args[1]
assert kwargs["part_number"] == 1 # 0-indexed chunk → 1-indexed part assert kwargs["part_number"] == 1 # 0-indexed chunk → 1-indexed part
@ -242,7 +240,7 @@ class TestUploadChunk:
lib.blob_store.upload_part.return_value = "etag" lib.blob_store.upload_part.return_value = "etag"
req = _make_upload_chunk_request(chunk_index=3) req = _make_upload_chunk_request(chunk_index=3)
await lib.upload_chunk(req) await lib.upload_chunk(req, "alice")
kwargs = lib.blob_store.upload_part.call_args[1] kwargs = lib.blob_store.upload_part.call_args[1]
assert kwargs["part_number"] == 4 assert kwargs["part_number"] == 4
@ -254,7 +252,7 @@ class TestUploadChunk:
req = _make_upload_chunk_request() req = _make_upload_chunk_request()
with pytest.raises(RequestError, match="not found"): with pytest.raises(RequestError, match="not found"):
await lib.upload_chunk(req) await lib.upload_chunk(req, "alice")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rejects_wrong_user(self): async def test_rejects_wrong_user(self):
@ -262,9 +260,9 @@ class TestUploadChunk:
session = _make_session(workspace="alice") session = _make_session(workspace="alice")
lib.table_store.get_upload_session.return_value = session lib.table_store.get_upload_session.return_value = session
req = _make_upload_chunk_request(workspace="bob") req = _make_upload_chunk_request()
with pytest.raises(RequestError, match="Not authorized"): with pytest.raises(RequestError, match="Not authorized"):
await lib.upload_chunk(req) await lib.upload_chunk(req, "bob")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rejects_negative_chunk_index(self): async def test_rejects_negative_chunk_index(self):
@ -274,7 +272,7 @@ class TestUploadChunk:
req = _make_upload_chunk_request(chunk_index=-1) req = _make_upload_chunk_request(chunk_index=-1)
with pytest.raises(RequestError, match="Invalid chunk index"): with pytest.raises(RequestError, match="Invalid chunk index"):
await lib.upload_chunk(req) await lib.upload_chunk(req, "alice")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rejects_out_of_range_chunk_index(self): async def test_rejects_out_of_range_chunk_index(self):
@ -284,7 +282,7 @@ class TestUploadChunk:
req = _make_upload_chunk_request(chunk_index=5) req = _make_upload_chunk_request(chunk_index=5)
with pytest.raises(RequestError, match="Invalid chunk index"): with pytest.raises(RequestError, match="Invalid chunk index"):
await lib.upload_chunk(req) await lib.upload_chunk(req, "alice")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_progress_tracking(self): async def test_progress_tracking(self):
@ -297,7 +295,7 @@ class TestUploadChunk:
lib.blob_store.upload_part.return_value = "e3" lib.blob_store.upload_part.return_value = "e3"
req = _make_upload_chunk_request(chunk_index=2) req = _make_upload_chunk_request(chunk_index=2)
resp = await lib.upload_chunk(req) resp = await lib.upload_chunk(req, "alice")
# Dict gets chunk 2 added (len=3), then +1 => 4 # Dict gets chunk 2 added (len=3), then +1 => 4
assert resp.chunks_received == 4 assert resp.chunks_received == 4
@ -316,7 +314,7 @@ class TestUploadChunk:
lib.blob_store.upload_part.return_value = "e2" lib.blob_store.upload_part.return_value = "e2"
req = _make_upload_chunk_request(chunk_index=1) req = _make_upload_chunk_request(chunk_index=1)
resp = await lib.upload_chunk(req) resp = await lib.upload_chunk(req, "alice")
# 3 chunks × 3000 = 9000 > 5000, so capped # 3 chunks × 3000 = 9000 > 5000, so capped
assert resp.bytes_received <= 5000 assert resp.bytes_received <= 5000
@ -330,7 +328,7 @@ class TestUploadChunk:
raw = b"hello world binary data" raw = b"hello world binary data"
req = _make_upload_chunk_request(content=raw) req = _make_upload_chunk_request(content=raw)
await lib.upload_chunk(req) await lib.upload_chunk(req, "alice")
kwargs = lib.blob_store.upload_part.call_args[1] kwargs = lib.blob_store.upload_part.call_args[1]
assert kwargs["data"] == raw assert kwargs["data"] == raw
@ -353,9 +351,8 @@ class TestCompleteUpload:
req = MagicMock() req = MagicMock()
req.upload_id = "up-1" req.upload_id = "up-1"
req.workspace = "alice"
resp = await lib.complete_upload(req) resp = await lib.complete_upload(req, "alice")
assert resp.error is None assert resp.error is None
assert resp.document_id == "doc-1" assert resp.document_id == "doc-1"
@ -375,9 +372,8 @@ class TestCompleteUpload:
req = MagicMock() req = MagicMock()
req.upload_id = "up-1" req.upload_id = "up-1"
req.workspace = "alice"
await lib.complete_upload(req) await lib.complete_upload(req, "alice")
parts = lib.blob_store.complete_multipart_upload.call_args[1]["parts"] parts = lib.blob_store.complete_multipart_upload.call_args[1]["parts"]
part_numbers = [p[0] for p in parts] part_numbers = [p[0] for p in parts]
@ -394,10 +390,9 @@ class TestCompleteUpload:
req = MagicMock() req = MagicMock()
req.upload_id = "up-1" req.upload_id = "up-1"
req.workspace = "alice"
with pytest.raises(RequestError, match="Missing chunks"): with pytest.raises(RequestError, match="Missing chunks"):
await lib.complete_upload(req) await lib.complete_upload(req, "alice")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rejects_expired_session(self): async def test_rejects_expired_session(self):
@ -406,10 +401,9 @@ class TestCompleteUpload:
req = MagicMock() req = MagicMock()
req.upload_id = "up-gone" req.upload_id = "up-gone"
req.workspace = "alice"
with pytest.raises(RequestError, match="not found"): with pytest.raises(RequestError, match="not found"):
await lib.complete_upload(req) await lib.complete_upload(req, "alice")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rejects_wrong_user(self): async def test_rejects_wrong_user(self):
@ -419,10 +413,9 @@ class TestCompleteUpload:
req = MagicMock() req = MagicMock()
req.upload_id = "up-1" req.upload_id = "up-1"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"): with pytest.raises(RequestError, match="Not authorized"):
await lib.complete_upload(req) await lib.complete_upload(req, "bob")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -439,9 +432,8 @@ class TestAbortUpload:
req = MagicMock() req = MagicMock()
req.upload_id = "up-1" req.upload_id = "up-1"
req.workspace = "alice"
resp = await lib.abort_upload(req) resp = await lib.abort_upload(req, "alice")
assert resp.error is None assert resp.error is None
lib.blob_store.abort_multipart_upload.assert_called_once_with( lib.blob_store.abort_multipart_upload.assert_called_once_with(
@ -456,10 +448,9 @@ class TestAbortUpload:
req = MagicMock() req = MagicMock()
req.upload_id = "up-gone" req.upload_id = "up-gone"
req.workspace = "alice"
with pytest.raises(RequestError, match="not found"): with pytest.raises(RequestError, match="not found"):
await lib.abort_upload(req) await lib.abort_upload(req, "alice")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rejects_wrong_user(self): async def test_rejects_wrong_user(self):
@ -469,10 +460,9 @@ class TestAbortUpload:
req = MagicMock() req = MagicMock()
req.upload_id = "up-1" req.upload_id = "up-1"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"): with pytest.raises(RequestError, match="Not authorized"):
await lib.abort_upload(req) await lib.abort_upload(req, "bob")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -492,9 +482,8 @@ class TestGetUploadStatus:
req = MagicMock() req = MagicMock()
req.upload_id = "up-1" req.upload_id = "up-1"
req.workspace = "alice"
resp = await lib.get_upload_status(req) resp = await lib.get_upload_status(req, "alice")
assert resp.upload_state == "in-progress" assert resp.upload_state == "in-progress"
assert resp.chunks_received == 3 assert resp.chunks_received == 3
@ -510,9 +499,8 @@ class TestGetUploadStatus:
req = MagicMock() req = MagicMock()
req.upload_id = "up-expired" req.upload_id = "up-expired"
req.workspace = "alice"
resp = await lib.get_upload_status(req) resp = await lib.get_upload_status(req, "alice")
assert resp.upload_state == "expired" assert resp.upload_state == "expired"
@ -527,9 +515,8 @@ class TestGetUploadStatus:
req = MagicMock() req = MagicMock()
req.upload_id = "up-1" req.upload_id = "up-1"
req.workspace = "alice"
resp = await lib.get_upload_status(req) resp = await lib.get_upload_status(req, "alice")
assert resp.missing_chunks == [] assert resp.missing_chunks == []
assert resp.chunks_received == 3 assert resp.chunks_received == 3
@ -544,10 +531,9 @@ class TestGetUploadStatus:
req = MagicMock() req = MagicMock()
req.upload_id = "up-1" req.upload_id = "up-1"
req.workspace = "bob"
with pytest.raises(RequestError, match="Not authorized"): with pytest.raises(RequestError, match="Not authorized"):
await lib.get_upload_status(req) await lib.get_upload_status(req, "bob")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -564,12 +550,11 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000) lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)
req = MagicMock() req = MagicMock()
req.workspace = "alice"
req.document_id = "doc-1" req.document_id = "doc-1"
req.chunk_size = 2000 req.chunk_size = 2000
chunks = [] chunks = []
async for resp in lib.stream_document(req): async for resp in lib.stream_document(req, "alice"):
chunks.append(resp) chunks.append(resp)
assert len(chunks) == 3 # ceil(5000/2000) assert len(chunks) == 3 # ceil(5000/2000)
@ -587,12 +572,11 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500) lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)
req = MagicMock() req = MagicMock()
req.workspace = "alice"
req.document_id = "doc-1" req.document_id = "doc-1"
req.chunk_size = 2000 req.chunk_size = 2000
chunks = [] chunks = []
async for resp in lib.stream_document(req): async for resp in lib.stream_document(req, "alice"):
chunks.append(resp) chunks.append(resp)
assert len(chunks) == 1 assert len(chunks) == 1
@ -608,12 +592,11 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100) lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)
req = MagicMock() req = MagicMock()
req.workspace = "alice"
req.document_id = "doc-1" req.document_id = "doc-1"
req.chunk_size = 2000 req.chunk_size = 2000
chunks = [] chunks = []
async for resp in lib.stream_document(req): async for resp in lib.stream_document(req, "alice"):
chunks.append(resp) chunks.append(resp)
# Verify the byte ranges passed to get_range # Verify the byte ranges passed to get_range
@ -630,12 +613,11 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=b"x") lib.blob_store.get_range = AsyncMock(return_value=b"x")
req = MagicMock() req = MagicMock()
req.workspace = "alice"
req.document_id = "doc-1" req.document_id = "doc-1"
req.chunk_size = 0 # Should use default 1MB req.chunk_size = 0 # Should use default 1MB
chunks = [] chunks = []
async for resp in lib.stream_document(req): async for resp in lib.stream_document(req, "alice"):
chunks.append(resp) chunks.append(resp)
assert len(chunks) == 2 # ceil(2MB / 1MB) assert len(chunks) == 2 # ceil(2MB / 1MB)
@ -649,12 +631,11 @@ class TestStreamDocument:
lib.blob_store.get_range = AsyncMock(return_value=raw) lib.blob_store.get_range = AsyncMock(return_value=raw)
req = MagicMock() req = MagicMock()
req.workspace = "alice"
req.document_id = "doc-1" req.document_id = "doc-1"
req.chunk_size = 1000 req.chunk_size = 1000
chunks = [] chunks = []
async for resp in lib.stream_document(req): async for resp in lib.stream_document(req, "alice"):
chunks.append(resp) chunks.append(resp)
assert chunks[0].content == base64.b64encode(raw) assert chunks[0].content == base64.b64encode(raw)
@ -666,12 +647,11 @@ class TestStreamDocument:
lib.blob_store.get_size = AsyncMock(return_value=5000) lib.blob_store.get_size = AsyncMock(return_value=5000)
req = MagicMock() req = MagicMock()
req.workspace = "alice"
req.document_id = "doc-1" req.document_id = "doc-1"
req.chunk_size = 512 req.chunk_size = 512
with pytest.raises(RequestError, match="below minimum"): with pytest.raises(RequestError, match="below minimum"):
async for _ in lib.stream_document(req): async for _ in lib.stream_document(req, "alice"):
pass pass
@ -698,9 +678,8 @@ class TestListUploads:
] ]
req = MagicMock() req = MagicMock()
req.workspace = "alice"
resp = await lib.list_uploads(req) resp = await lib.list_uploads(req, "alice")
assert resp.error is None assert resp.error is None
assert len(resp.upload_sessions) == 1 assert len(resp.upload_sessions) == 1
@ -713,8 +692,7 @@ class TestListUploads:
lib.table_store.list_upload_sessions.return_value = [] lib.table_store.list_upload_sessions.return_value = []
req = MagicMock() req = MagicMock()
req.workspace = "alice"
resp = await lib.list_uploads(req) resp = await lib.list_uploads(req, "alice")
assert resp.upload_sessions == [] assert resp.upload_sessions == []

View file

@ -30,7 +30,6 @@ class TestDocumentMetadataTranslator:
"title": "Test Document", "title": "Test Document",
"comments": "No comments", "comments": "No comments",
"metadata": [], "metadata": [],
"workspace": "alice",
"tags": ["finance", "q4"], "tags": ["finance", "q4"],
"parent-id": "doc-100", "parent-id": "doc-100",
"document-type": "page", "document-type": "page",
@ -40,14 +39,12 @@ class TestDocumentMetadataTranslator:
assert obj.time == 1710000000 assert obj.time == 1710000000
assert obj.kind == "application/pdf" assert obj.kind == "application/pdf"
assert obj.title == "Test Document" assert obj.title == "Test Document"
assert obj.workspace == "alice"
assert obj.tags == ["finance", "q4"] assert obj.tags == ["finance", "q4"]
assert obj.parent_id == "doc-100" assert obj.parent_id == "doc-100"
assert obj.document_type == "page" assert obj.document_type == "page"
wire = self.tx.encode(obj) wire = self.tx.encode(obj)
assert wire["id"] == "doc-123" assert wire["id"] == "doc-123"
assert wire["workspace"] == "alice"
assert wire["parent-id"] == "doc-100" assert wire["parent-id"] == "doc-100"
assert wire["document-type"] == "page" assert wire["document-type"] == "page"
@ -80,10 +77,9 @@ class TestDocumentMetadataTranslator:
def test_falsy_fields_omitted_from_wire(self): def test_falsy_fields_omitted_from_wire(self):
"""Empty string fields should be omitted from wire format.""" """Empty string fields should be omitted from wire format."""
obj = DocumentMetadata(id="", time=0, workspace="") obj = DocumentMetadata(id="", time=0)
wire = self.tx.encode(obj) wire = self.tx.encode(obj)
assert "id" not in wire assert "id" not in wire
assert "workspace" not in wire
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -101,7 +97,6 @@ class TestProcessingMetadataTranslator:
"document-id": "doc-123", "document-id": "doc-123",
"time": 1710000000, "time": 1710000000,
"flow": "default", "flow": "default",
"workspace": "alice",
"collection": "my-collection", "collection": "my-collection",
"tags": ["tag1"], "tags": ["tag1"],
} }
@ -109,20 +104,17 @@ class TestProcessingMetadataTranslator:
assert obj.id == "proc-1" assert obj.id == "proc-1"
assert obj.document_id == "doc-123" assert obj.document_id == "doc-123"
assert obj.flow == "default" assert obj.flow == "default"
assert obj.workspace == "alice"
assert obj.collection == "my-collection" assert obj.collection == "my-collection"
assert obj.tags == ["tag1"] assert obj.tags == ["tag1"]
wire = self.tx.encode(obj) wire = self.tx.encode(obj)
assert wire["id"] == "proc-1" assert wire["id"] == "proc-1"
assert wire["document-id"] == "doc-123" assert wire["document-id"] == "doc-123"
assert wire["workspace"] == "alice"
assert wire["collection"] == "my-collection" assert wire["collection"] == "my-collection"
def test_missing_fields_use_defaults(self): def test_missing_fields_use_defaults(self):
obj = self.tx.decode({}) obj = self.tx.decode({})
assert obj.id is None assert obj.id is None
assert obj.workspace is None
assert obj.collection is None assert obj.collection is None
def test_tags_none_omitted(self): def test_tags_none_omitted(self):
@ -135,10 +127,9 @@ class TestProcessingMetadataTranslator:
wire = self.tx.encode(obj) wire = self.tx.encode(obj)
assert wire["tags"] == [] assert wire["tags"] == []
def test_workspace_and_collection_preserved(self): def test_collection_preserved(self):
"""Core pipeline routing fields must survive round-trip.""" """Core pipeline routing fields must survive round-trip."""
data = {"workspace": "bob", "collection": "research"} data = {"collection": "research"}
obj = self.tx.decode(data) obj = self.tx.decode(data)
wire = self.tx.encode(obj) wire = self.tx.encode(obj)
assert wire["workspace"] == "bob"
assert wire["collection"] == "research" assert wire["collection"] == "research"

View file

@ -27,7 +27,7 @@ CHUNK_CONTENT = {
@pytest.fixture @pytest.fixture
def mock_fetch_chunk(): def mock_fetch_chunk():
"""Create a mock fetch_chunk function""" """Create a mock fetch_chunk function"""
async def fetch(chunk_id, user): async def fetch(chunk_id):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
return fetch return fetch
@ -203,7 +203,7 @@ class TestQuery:
mock_rag.doc_embeddings_client = mock_doc_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock fetch_chunk function # Mock fetch_chunk function
async def mock_fetch(chunk_id, user): async def mock_fetch(chunk_id):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch mock_rag.fetch_chunk = mock_fetch
@ -361,7 +361,7 @@ class TestQuery:
mock_rag.doc_embeddings_client = mock_doc_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client
# Mock fetch_chunk # Mock fetch_chunk
async def mock_fetch(chunk_id, user): async def mock_fetch(chunk_id):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch mock_rag.fetch_chunk = mock_fetch
@ -437,7 +437,7 @@ class TestQuery:
mock_rag.embeddings_client = mock_embeddings_client mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client
async def mock_fetch(chunk_id, user): async def mock_fetch(chunk_id):
return f"Content for {chunk_id}" return f"Content for {chunk_id}"
mock_rag.fetch_chunk = mock_fetch mock_rag.fetch_chunk = mock_fetch
@ -594,7 +594,7 @@ class TestQuery:
mock_rag.embeddings_client = mock_embeddings_client mock_rag.embeddings_client = mock_embeddings_client
mock_rag.doc_embeddings_client = mock_doc_embeddings_client mock_rag.doc_embeddings_client = mock_doc_embeddings_client
async def mock_fetch(chunk_id, user): async def mock_fetch(chunk_id):
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}") return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
mock_rag.fetch_chunk = mock_fetch mock_rag.fetch_chunk = mock_fetch

View file

@ -105,7 +105,7 @@ def build_mock_clients():
] ]
# 4. Chunk content # 4. Chunk content
async def mock_fetch(chunk_id, user): async def mock_fetch(chunk_id):
return { return {
CHUNK_A: CHUNK_A_CONTENT, CHUNK_A: CHUNK_A_CONTENT,
CHUNK_B: CHUNK_B_CONTENT, CHUNK_B: CHUNK_B_CONTENT,

View file

@ -218,7 +218,8 @@ class TestKgStoreConfiguration:
cassandra_host=['kg-env-host1', 'kg-env-host2', 'kg-env-host3'], cassandra_host=['kg-env-host1', 'kg-env-host2', 'kg-env-host3'],
cassandra_username='kg-env-user', cassandra_username='kg-env-user',
cassandra_password='kg-env-pass', cassandra_password='kg-env-pass',
keyspace='knowledge' keyspace='knowledge',
replication_factor=1,
) )
@patch('trustgraph.storage.knowledge.store.KnowledgeTableStore') @patch('trustgraph.storage.knowledge.store.KnowledgeTableStore')
@ -239,7 +240,8 @@ class TestKgStoreConfiguration:
cassandra_host=['explicit-host'], cassandra_host=['explicit-host'],
cassandra_username='explicit-user', cassandra_username='explicit-user',
cassandra_password='explicit-pass', cassandra_password='explicit-pass',
keyspace='knowledge' keyspace='knowledge',
replication_factor=1,
) )
@patch('trustgraph.storage.knowledge.store.KnowledgeTableStore') @patch('trustgraph.storage.knowledge.store.KnowledgeTableStore')
@ -260,7 +262,8 @@ class TestKgStoreConfiguration:
cassandra_host=['compat-host'], cassandra_host=['compat-host'],
cassandra_username=None, # Should be None since cassandra_user is ignored cassandra_username=None, # Should be None since cassandra_user is ignored
cassandra_password='compat-pass', cassandra_password='compat-pass',
keyspace='knowledge' keyspace='knowledge',
replication_factor=1,
) )
@patch('trustgraph.storage.knowledge.store.KnowledgeTableStore') @patch('trustgraph.storage.knowledge.store.KnowledgeTableStore')
@ -277,7 +280,8 @@ class TestKgStoreConfiguration:
cassandra_host=['cassandra'], cassandra_host=['cassandra'],
cassandra_username=None, cassandra_username=None,
cassandra_password=None, cassandra_password=None,
keyspace='knowledge' keyspace='knowledge',
replication_factor=1,
) )
@ -425,5 +429,6 @@ class TestConfigurationPriorityIntegration:
cassandra_host=['param-host'], # From parameter cassandra_host=['param-host'], # From parameter
cassandra_username='env-user', # From environment cassandra_username='env-user', # From environment
cassandra_password='env-pass', # From environment cassandra_password='env-pass', # From environment
keyspace='knowledge' keyspace='knowledge',
replication_factor=1,
) )

View file

@ -171,14 +171,16 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
@patch('trustgraph.base.llm_service.LlmService.__init__') @patch('trustgraph.base.llm_service.LlmService.__init__')
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_openai_class): async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_openai_class):
"""Test processor initialization without API key (should fail)""" """Test processor initialization without API key uses placeholder"""
# Arrange # Arrange
mock_openai_client = MagicMock()
mock_openai_class.return_value = mock_openai_client
mock_async_init.return_value = None mock_async_init.return_value = None
mock_llm_init.return_value = None mock_llm_init.return_value = None
config = { config = {
'model': 'gpt-3.5-turbo', 'model': 'gpt-3.5-turbo',
'api_key': None, # No API key provided 'api_key': None,
'url': 'https://api.openai.com/v1', 'url': 'https://api.openai.com/v1',
'temperature': 0.0, 'temperature': 0.0,
'max_output': 4096, 'max_output': 4096,
@ -187,9 +189,10 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
'id': 'test-processor' 'id': 'test-processor'
} }
# Act & Assert processor = Processor(**config)
with pytest.raises(RuntimeError, match="OpenAI API key not specified"): mock_openai_class.assert_called_once_with(
processor = Processor(**config) base_url='https://api.openai.com/v1', api_key='not-set'
)
@patch('trustgraph.model.text_completion.openai.llm.OpenAI') @patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__') @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')

View file

@ -41,7 +41,6 @@ def translator():
def graph_embeddings_request(): def graph_embeddings_request():
return KnowledgeRequest( return KnowledgeRequest(
operation="put-kg-core", operation="put-kg-core",
workspace="alice",
id="doc-1", id="doc-1",
flow="default", flow="default",
collection="testcoll", collection="testcoll",
@ -110,7 +109,7 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings:
assert isinstance(decoded, KnowledgeRequest) assert isinstance(decoded, KnowledgeRequest)
assert decoded.operation == "put-kg-core" assert decoded.operation == "put-kg-core"
assert decoded.workspace == "alice" assert decoded.id == "doc-1"
assert decoded.id == "doc-1" assert decoded.id == "doc-1"
assert decoded.flow == "default" assert decoded.flow == "default"
assert decoded.collection == "testcoll" assert decoded.collection == "testcoll"

View file

@ -17,6 +17,7 @@ dependencies = [
"pika", "pika",
"confluent-kafka", "confluent-kafka",
"pyyaml", "pyyaml",
"websockets",
] ]
classifiers = [ classifiers = [
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",

View file

@ -217,7 +217,6 @@ class Library:
"title": title, "title": title,
"comments": comments, "comments": comments,
"metadata": triples, "metadata": triples,
"workspace": self.api.workspace,
"tags": tags "tags": tags
}, },
"content": base64.b64encode(document).decode("utf-8"), "content": base64.b64encode(document).decode("utf-8"),
@ -249,7 +248,6 @@ class Library:
"kind": kind, "kind": kind,
"title": title, "title": title,
"comments": comments, "comments": comments,
"workspace": self.api.workspace,
"tags": tags, "tags": tags,
}, },
"total-size": total_size, "total-size": total_size,
@ -377,7 +375,6 @@ class Library:
) )
for w in v["metadata"] for w in v["metadata"]
], ],
workspace = v.get("workspace", ""),
tags = v["tags"], tags = v["tags"],
parent_id = v.get("parent-id", ""), parent_id = v.get("parent-id", ""),
document_type = v.get("document-type", "source"), document_type = v.get("document-type", "source"),
@ -436,7 +433,6 @@ class Library:
) )
for w in doc["metadata"] for w in doc["metadata"]
], ],
workspace = doc.get("workspace", ""),
tags = doc["tags"], tags = doc["tags"],
parent_id = doc.get("parent-id", ""), parent_id = doc.get("parent-id", ""),
document_type = doc.get("document-type", "source"), document_type = doc.get("document-type", "source"),
@ -485,7 +481,6 @@ class Library:
"operation": "update-document", "operation": "update-document",
"workspace": self.api.workspace, "workspace": self.api.workspace,
"document-metadata": { "document-metadata": {
"workspace": self.api.workspace,
"document-id": id, "document-id": id,
"time": metadata.time, "time": metadata.time,
"title": metadata.title, "title": metadata.title,
@ -599,7 +594,6 @@ class Library:
"document-id": document_id, "document-id": document_id,
"time": int(time.time()), "time": int(time.time()),
"flow": flow, "flow": flow,
"workspace": self.api.workspace,
"collection": collection, "collection": collection,
"tags": tags, "tags": tags,
} }
@ -681,7 +675,6 @@ class Library:
document_id = v["document-id"], document_id = v["document-id"],
time = datetime.datetime.fromtimestamp(v["time"]), time = datetime.datetime.fromtimestamp(v["time"]),
flow = v["flow"], flow = v["flow"],
workspace = v.get("workspace", ""),
collection = v["collection"], collection = v["collection"],
tags = v["tags"], tags = v["tags"],
) )
@ -945,7 +938,6 @@ class Library:
"title": title, "title": title,
"comments": comments, "comments": comments,
"metadata": triples, "metadata": triples,
"workspace": self.api.workspace,
"tags": tags, "tags": tags,
"parent-id": parent_id, "parent-id": parent_id,
"document-type": "extracted", "document-type": "extracted",

View file

@ -65,7 +65,6 @@ class DocumentMetadata:
title: Document title title: Document title
comments: Additional comments or description comments: Additional comments or description
metadata: List of RDF triples providing structured metadata metadata: List of RDF triples providing structured metadata
workspace: Workspace the document belongs to
tags: List of tags for categorization tags: List of tags for categorization
parent_id: Parent document ID for child documents (empty for top-level docs) parent_id: Parent document ID for child documents (empty for top-level docs)
document_type: "source" for uploaded documents, "extracted" for derived content document_type: "source" for uploaded documents, "extracted" for derived content
@ -76,7 +75,6 @@ class DocumentMetadata:
title : str title : str
comments : str comments : str
metadata : List[Triple] metadata : List[Triple]
workspace : str
tags : List[str] tags : List[str]
parent_id : str = "" parent_id : str = ""
document_type : str = "source" document_type : str = "source"
@ -91,7 +89,6 @@ class ProcessingMetadata:
document_id: ID of the document being processed document_id: ID of the document being processed
time: Processing start timestamp time: Processing start timestamp
flow: Flow instance handling the processing flow: Flow instance handling the processing
workspace: Workspace the processing job belongs to
collection: Target collection for processed data collection: Target collection for processed data
tags: List of tags for categorization tags: List of tags for categorization
""" """
@ -99,7 +96,6 @@ class ProcessingMetadata:
document_id : str document_id : str
time : datetime.datetime time : datetime.datetime
flow : str flow : str
workspace : str
collection : str collection : str
tags : List[str] tags : List[str]

View file

@ -7,6 +7,7 @@ from . publisher import Publisher
from . subscriber import Subscriber from . subscriber import Subscriber
from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics, SubscriberMetrics from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics, SubscriberMetrics
from . logging import add_logging_args, setup_logging from . logging import add_logging_args, setup_logging
from . workspace_processor import WorkspaceProcessor
from . flow_processor import FlowProcessor from . flow_processor import FlowProcessor
from . consumer_spec import ConsumerSpec from . consumer_spec import ConsumerSpec
from . parameter_spec import ParameterSpec from . parameter_spec import ParameterSpec
@ -15,6 +16,7 @@ from . subscriber_spec import SubscriberSpec
from . request_response_spec import RequestResponseSpec from . request_response_spec import RequestResponseSpec
from . llm_service import LlmService, LlmResult, LlmChunk from . llm_service import LlmService, LlmResult, LlmChunk
from . librarian_client import LibrarianClient from . librarian_client import LibrarianClient
from . librarian_spec import LibrarianSpec
from . chunking_service import ChunkingService from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec from . embeddings_client import EmbeddingsClientSpec

View file

@ -71,6 +71,11 @@ class AsyncProcessor:
# { "handler": async_fn, "types": set_or_none } # { "handler": async_fn, "types": set_or_none }
self.config_handlers = [] self.config_handlers = []
# Workspace lifecycle handlers, called when workspaces are
# created or deleted. Each entry is an async callable:
# async def handler(workspace_changes: WorkspaceChanges)
self.workspace_handlers = []
# Track the current config version for dedup # Track the current config version for dedup
self.config_version = 0 self.config_version = 0
@ -209,6 +214,8 @@ class AsyncProcessor:
# Call the handler once per workspace # Call the handler once per workspace
for ws, config in per_ws.items(): for ws, config in per_ws.items():
if ws.startswith("_"):
continue
await entry["handler"](ws, config, version) await entry["handler"](ws, config, version)
logger.info( logger.info(
@ -249,6 +256,10 @@ class AsyncProcessor:
"types": set(types) if types else None, "types": set(types) if types else None,
}) })
# Register a handler for workspace lifecycle events
def register_workspace_handler(self, handler: Callable[..., Any]) -> None:
self.workspace_handlers.append(handler)
# Called when a config notify message arrives # Called when a config notify message arrives
async def on_config_notify(self, message, consumer, flow): async def on_config_notify(self, message, consumer, flow):
@ -264,6 +275,16 @@ class AsyncProcessor:
) )
return return
# Dispatch workspace lifecycle events before config handlers
if v.workspace_changes and self.workspace_handlers:
for handler in self.workspace_handlers:
try:
await handler(v.workspace_changes)
except Exception as e:
logger.error(
f"Workspace handler failed: {e}", exc_info=True
)
notify_types = set(changes.keys()) notify_types = set(changes.keys())
# Filter out handlers that don't care about any of the changed # Filter out handlers that don't care about any of the changed
@ -310,6 +331,8 @@ class AsyncProcessor:
per_ws.setdefault(ws, {})[t] = kv per_ws.setdefault(ws, {})[t] = kv
for ws, config in per_ws.items(): for ws, config in per_ws.items():
if ws.startswith("_"):
continue
await entry["handler"]( await entry["handler"](
ws, config, notify_version, ws, config, notify_version,
) )

View file

@ -151,7 +151,7 @@ def resolve_cassandra_config(
def get_cassandra_config_from_params( def get_cassandra_config_from_params(
params: dict, params: dict,
default_keyspace: Optional[str] = None default_keyspace: Optional[str] = None
) -> Tuple[List[str], Optional[str], Optional[str], Optional[str]]: ) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
""" """
Extract and resolve Cassandra configuration from a parameters dictionary. Extract and resolve Cassandra configuration from a parameters dictionary.
@ -160,14 +160,12 @@ def get_cassandra_config_from_params(
default_keyspace: Optional default keyspace if not specified in params default_keyspace: Optional default keyspace if not specified in params
Returns: Returns:
tuple: (hosts_list, username, password, keyspace) tuple: (hosts_list, username, password, keyspace, replication_factor)
""" """
# Get Cassandra parameters
host = params.get('cassandra_host') host = params.get('cassandra_host')
username = params.get('cassandra_username') username = params.get('cassandra_username')
password = params.get('cassandra_password') password = params.get('cassandra_password')
# Use resolve function to handle defaults and list conversion
return resolve_cassandra_config( return resolve_cassandra_config(
host=host, host=host,
username=username, username=username,

View file

@ -4,13 +4,11 @@ for chunk-size and chunk-overlap parameters, and librarian client for
fetching large document content. fetching large document content.
""" """
import asyncio
import base64
import logging import logging
from .flow_processor import FlowProcessor from .flow_processor import FlowProcessor
from .parameter_spec import ParameterSpec from .parameter_spec import ParameterSpec
from .librarian_client import LibrarianClient from .librarian_spec import LibrarianSpec
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,35 +33,27 @@ class ChunkingService(FlowProcessor):
ParameterSpec(name="chunk-overlap") ParameterSpec(name="chunk-overlap")
) )
# Librarian client self.register_specification(
self.librarian = LibrarianClient( LibrarianSpec()
id=id,
backend=self.pubsub,
taskgroup=self.taskgroup,
) )
logger.debug("ChunkingService initialized with parameter specifications") logger.debug("ChunkingService initialized with parameter specifications")
async def start(self): async def get_document_text(self, doc, flow):
await super(ChunkingService, self).start()
await self.librarian.start()
async def get_document_text(self, doc, workspace):
""" """
Get text content from a TextDocument, fetching from librarian if needed. Get text content from a TextDocument, fetching from librarian if needed.
Args: Args:
doc: TextDocument with either inline text or document_id doc: TextDocument with either inline text or document_id
workspace: Workspace for librarian lookup (from flow.workspace) flow: Flow object with librarian client
Returns: Returns:
str: The document text content str: The document text content
""" """
if doc.document_id and not doc.text: if doc.document_id and not doc.text:
logger.info(f"Fetching document {doc.document_id} from librarian...") logger.info(f"Fetching document {doc.document_id} from librarian...")
text = await self.librarian.fetch_document_text( text = await flow.librarian.fetch_document_text(
document_id=doc.document_id, document_id=doc.document_id,
workspace=workspace,
) )
logger.info(f"Fetched {len(text)} characters from librarian") logger.info(f"Fetched {len(text)} characters from librarian")
return text return text

View file

@ -1,6 +1,4 @@
import asyncio
class Flow: class Flow:
""" """
Runtime representation of a deployed flow process. Runtime representation of a deployed flow process.
@ -22,16 +20,22 @@ class Flow:
self.parameter = {} self.parameter = {}
self.librarian = None
for spec in processor.specifications: for spec in processor.specifications:
spec.add(self, processor, defn) spec.add(self, processor, defn)
async def start(self): async def start(self):
if self.librarian:
await self.librarian.start()
for c in self.consumer.values(): for c in self.consumer.values():
await c.start() await c.start()
async def stop(self): async def stop(self):
for c in self.consumer.values(): for c in self.consumer.values():
await c.stop() await c.stop()
if self.librarian:
await self.librarian.stop()
def __call__(self, key): def __call__(self, key):
if key in self.producer: return self.producer[key] if key in self.producer: return self.producer[key]

View file

@ -14,7 +14,7 @@ from .. schema import Error
from .. schema import config_request_queue, config_response_queue from .. schema import config_request_queue, config_response_queue
from .. schema import config_push_queue from .. schema import config_push_queue
from .. log_level import LogLevel from .. log_level import LogLevel
from . async_processor import AsyncProcessor from . workspace_processor import WorkspaceProcessor
from . flow import Flow from . flow import Flow
# Module logger # Module logger
@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
# Parent class for configurable processors, configured with flows by # Parent class for configurable processors, configured with flows by
# the config service # the config service
class FlowProcessor(AsyncProcessor): class FlowProcessor(WorkspaceProcessor):
def __init__(self, **params): def __init__(self, **params):
@ -113,7 +113,7 @@ class FlowProcessor(AsyncProcessor):
@staticmethod @staticmethod
def add_args(parser: ArgumentParser) -> None: def add_args(parser: ArgumentParser) -> None:
AsyncProcessor.add_args(parser) WorkspaceProcessor.add_args(parser)
# parser.add_argument( # parser.add_argument(
# '--rate-limit-retry', # '--rate-limit-retry',

View file

@ -10,7 +10,7 @@ Usage:
id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
) )
await self.librarian.start() await self.librarian.start()
content = await self.librarian.fetch_document_content(doc_id, workspace) content = await self.librarian.fetch_document_content(doc_id)
""" """
import asyncio import asyncio
@ -39,9 +39,14 @@ class LibrarianClient:
librarian_response_q = params.get( librarian_response_q = params.get(
"librarian_response_queue", librarian_response_queue, "librarian_response_queue", librarian_response_queue,
) )
subscriber = params.get(
"librarian_subscriber", f"{id}-librarian",
)
flow_name = params.get("flow_name")
librarian_request_metrics = ProducerMetrics( librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request", processor=id, flow=flow_name, name="librarian-request",
) )
self._producer = Producer( self._producer = Producer(
@ -52,7 +57,7 @@ class LibrarianClient:
) )
librarian_response_metrics = ConsumerMetrics( librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response", processor=id, flow=flow_name, name="librarian-response",
) )
self._consumer = Consumer( self._consumer = Consumer(
@ -60,7 +65,7 @@ class LibrarianClient:
backend=backend, backend=backend,
flow=None, flow=None,
topic=librarian_response_q, topic=librarian_response_q,
subscriber=f"{id}-librarian", subscriber=subscriber,
schema=LibrarianResponse, schema=LibrarianResponse,
handler=self._on_response, handler=self._on_response,
metrics=librarian_response_metrics, metrics=librarian_response_metrics,
@ -76,6 +81,11 @@ class LibrarianClient:
await self._producer.start() await self._producer.start()
await self._consumer.start() await self._consumer.start()
async def stop(self):
"""Stop the librarian producer and consumer."""
await self._consumer.stop()
await self._producer.stop()
async def _on_response(self, msg, consumer, flow): async def _on_response(self, msg, consumer, flow):
"""Route librarian responses to the right waiter.""" """Route librarian responses to the right waiter."""
response = msg.value() response = msg.value()
@ -150,7 +160,7 @@ class LibrarianClient:
finally: finally:
self._streams.pop(request_id, None) self._streams.pop(request_id, None)
async def fetch_document_content(self, document_id, workspace, timeout=120): async def fetch_document_content(self, document_id, timeout=120):
"""Fetch document content using streaming. """Fetch document content using streaming.
Returns base64-encoded content. Caller is responsible for decoding. Returns base64-encoded content. Caller is responsible for decoding.
@ -158,7 +168,6 @@ class LibrarianClient:
req = LibrarianRequest( req = LibrarianRequest(
operation="stream-document", operation="stream-document",
document_id=document_id, document_id=document_id,
workspace=workspace,
) )
chunks = await self.stream(req, timeout=timeout) chunks = await self.stream(req, timeout=timeout)
@ -176,24 +185,23 @@ class LibrarianClient:
return base64.b64encode(raw) return base64.b64encode(raw)
async def fetch_document_text(self, document_id, workspace, timeout=120): async def fetch_document_text(self, document_id, timeout=120):
"""Fetch document content and decode as UTF-8 text.""" """Fetch document content and decode as UTF-8 text."""
content = await self.fetch_document_content( content = await self.fetch_document_content(
document_id, workspace, timeout=timeout, document_id, timeout=timeout,
) )
return base64.b64decode(content).decode("utf-8") return base64.b64decode(content).decode("utf-8")
async def fetch_document_metadata(self, document_id, workspace, timeout=120): async def fetch_document_metadata(self, document_id, timeout=120):
"""Fetch document metadata from the librarian.""" """Fetch document metadata from the librarian."""
req = LibrarianRequest( req = LibrarianRequest(
operation="get-document-metadata", operation="get-document-metadata",
document_id=document_id, document_id=document_id,
workspace=workspace,
) )
response = await self.request(req, timeout=timeout) response = await self.request(req, timeout=timeout)
return response.document_metadata return response.document_metadata
async def save_child_document(self, doc_id, parent_id, workspace, content, async def save_child_document(self, doc_id, parent_id, content,
document_type="chunk", title=None, document_type="chunk", title=None,
kind="text/plain", timeout=120): kind="text/plain", timeout=120):
"""Save a child document to the librarian.""" """Save a child document to the librarian."""
@ -202,7 +210,6 @@ class LibrarianClient:
doc_metadata = DocumentMetadata( doc_metadata = DocumentMetadata(
id=doc_id, id=doc_id,
workspace=workspace,
kind=kind, kind=kind,
title=title or doc_id, title=title or doc_id,
parent_id=parent_id, parent_id=parent_id,
@ -218,7 +225,7 @@ class LibrarianClient:
await self.request(req, timeout=timeout) await self.request(req, timeout=timeout)
return doc_id return doc_id
async def save_document(self, doc_id, workspace, content, title=None, async def save_document(self, doc_id, content, title=None,
document_type="answer", kind="text/plain", document_type="answer", kind="text/plain",
timeout=120): timeout=120):
"""Save a document to the librarian.""" """Save a document to the librarian."""
@ -227,7 +234,6 @@ class LibrarianClient:
doc_metadata = DocumentMetadata( doc_metadata = DocumentMetadata(
id=doc_id, id=doc_id,
workspace=workspace,
kind=kind, kind=kind,
title=title or doc_id, title=title or doc_id,
document_type=document_type, document_type=document_type,
@ -238,7 +244,6 @@ class LibrarianClient:
document_id=doc_id, document_id=doc_id,
document_metadata=doc_metadata, document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"), content=base64.b64encode(content).decode("utf-8"),
workspace=workspace,
) )
await self.request(req, timeout=timeout) await self.request(req, timeout=timeout)

View file

@ -0,0 +1,31 @@
from __future__ import annotations
import uuid
from typing import Any
from . spec import Spec
from . librarian_client import LibrarianClient
class LibrarianSpec(Spec):
def __init__(self, request_name="librarian-request",
response_name="librarian-response"):
self.request_name = request_name
self.response_name = response_name
def add(self, flow: Any, processor: Any, definition: dict[str, Any]) -> None:
client = LibrarianClient(
id=flow.id,
backend=processor.pubsub,
taskgroup=processor.taskgroup,
librarian_request_queue=definition["topics"][self.request_name],
librarian_response_queue=definition["topics"][self.response_name],
librarian_subscriber=(
processor.id + "--" + flow.workspace + "--" +
flow.name + "--librarian--" + str(uuid.uuid4())
),
flow_name=flow.name,
)
flow.librarian = client

View file

@ -0,0 +1,66 @@
from __future__ import annotations
from argparse import ArgumentParser
import logging
from . async_processor import AsyncProcessor
logger = logging.getLogger(__name__)
WORKSPACES_NAMESPACE = "__workspaces__"
WORKSPACE_TYPE = "workspace"
class WorkspaceProcessor(AsyncProcessor):
def __init__(self, **params):
super(WorkspaceProcessor, self).__init__(**params)
self.active_workspaces = set()
self.register_workspace_handler(self._handle_workspace_changes)
async def _discover_workspaces(self):
client = self._create_config_client()
try:
await client.start()
type_data, version = await self._fetch_type_all_workspaces(
client, WORKSPACE_TYPE,
)
for ws in type_data:
if ws == WORKSPACES_NAMESPACE:
for workspace_id in type_data[ws]:
if workspace_id not in self.active_workspaces:
self.active_workspaces.add(workspace_id)
await self.on_workspace_created(workspace_id)
finally:
await client.stop()
async def _handle_workspace_changes(self, workspace_changes):
for workspace_id in workspace_changes.created:
if workspace_id not in self.active_workspaces:
self.active_workspaces.add(workspace_id)
logger.info(f"Workspace created: {workspace_id}")
await self.on_workspace_created(workspace_id)
for workspace_id in workspace_changes.deleted:
if workspace_id in self.active_workspaces:
logger.info(f"Workspace deleted: {workspace_id}")
await self.on_workspace_deleted(workspace_id)
self.active_workspaces.discard(workspace_id)
async def on_workspace_created(self, workspace):
pass
async def on_workspace_deleted(self, workspace):
pass
async def start(self):
await super(WorkspaceProcessor, self).start()
await self._discover_workspaces()
@staticmethod
def add_args(parser: ArgumentParser) -> None:
AsyncProcessor.add_args(parser)

View file

@ -9,7 +9,6 @@ class CollectionManagementRequestTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> CollectionManagementRequest: def decode(self, data: Dict[str, Any]) -> CollectionManagementRequest:
return CollectionManagementRequest( return CollectionManagementRequest(
operation=data.get("operation"), operation=data.get("operation"),
workspace=data.get("workspace", ""),
collection=data.get("collection"), collection=data.get("collection"),
timestamp=data.get("timestamp"), timestamp=data.get("timestamp"),
name=data.get("name"), name=data.get("name"),
@ -24,8 +23,6 @@ class CollectionManagementRequestTranslator(MessageTranslator):
if obj.operation is not None: if obj.operation is not None:
result["operation"] = obj.operation result["operation"] = obj.operation
if obj.workspace:
result["workspace"] = obj.workspace
if obj.collection is not None: if obj.collection is not None:
result["collection"] = obj.collection result["collection"] = obj.collection
if obj.timestamp is not None: if obj.timestamp is not None:

View file

@ -9,7 +9,6 @@ class FlowRequestTranslator(MessageTranslator):
def decode(self, data: Dict[str, Any]) -> FlowRequest: def decode(self, data: Dict[str, Any]) -> FlowRequest:
return FlowRequest( return FlowRequest(
operation=data.get("operation"), operation=data.get("operation"),
workspace=data.get("workspace", ""),
blueprint_name=data.get("blueprint-name"), blueprint_name=data.get("blueprint-name"),
blueprint_definition=data.get("blueprint-definition"), blueprint_definition=data.get("blueprint-definition"),
description=data.get("description"), description=data.get("description"),
@ -22,8 +21,6 @@ class FlowRequestTranslator(MessageTranslator):
if obj.operation is not None: if obj.operation is not None:
result["operation"] = obj.operation result["operation"] = obj.operation
if obj.workspace is not None:
result["workspace"] = obj.workspace
if obj.blueprint_name is not None: if obj.blueprint_name is not None:
result["blueprint-name"] = obj.blueprint_name result["blueprint-name"] = obj.blueprint_name
if obj.blueprint_definition is not None: if obj.blueprint_definition is not None:

View file

@ -45,7 +45,6 @@ class KnowledgeRequestTranslator(MessageTranslator):
return KnowledgeRequest( return KnowledgeRequest(
operation=data.get("operation"), operation=data.get("operation"),
workspace=data.get("workspace", ""),
id=data.get("id"), id=data.get("id"),
flow=data.get("flow"), flow=data.get("flow"),
collection=data.get("collection"), collection=data.get("collection"),
@ -58,8 +57,6 @@ class KnowledgeRequestTranslator(MessageTranslator):
if obj.operation: if obj.operation:
result["operation"] = obj.operation result["operation"] = obj.operation
if obj.workspace:
result["workspace"] = obj.workspace
if obj.id: if obj.id:
result["id"] = obj.id result["id"] = obj.id
if obj.flow: if obj.flow:

View file

@ -49,7 +49,6 @@ class LibraryRequestTranslator(MessageTranslator):
document_metadata=doc_metadata, document_metadata=doc_metadata,
processing_metadata=proc_metadata, processing_metadata=proc_metadata,
content=content, content=content,
workspace=data.get("workspace", ""),
collection=data.get("collection", ""), collection=data.get("collection", ""),
criteria=criteria, criteria=criteria,
# Chunked upload fields # Chunked upload fields
@ -76,8 +75,6 @@ class LibraryRequestTranslator(MessageTranslator):
result["processing-metadata"] = self.proc_metadata_translator.encode(obj.processing_metadata) result["processing-metadata"] = self.proc_metadata_translator.encode(obj.processing_metadata)
if obj.content: if obj.content:
result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content
if obj.workspace:
result["workspace"] = obj.workspace
if obj.collection: if obj.collection:
result["collection"] = obj.collection result["collection"] = obj.collection
if obj.criteria is not None: if obj.criteria is not None:

View file

@ -19,7 +19,6 @@ class DocumentMetadataTranslator(Translator):
title=data.get("title"), title=data.get("title"),
comments=data.get("comments"), comments=data.get("comments"),
metadata=self.subgraph_translator.decode(metadata) if metadata is not None else [], metadata=self.subgraph_translator.decode(metadata) if metadata is not None else [],
workspace=data.get("workspace"),
tags=data.get("tags"), tags=data.get("tags"),
parent_id=data.get("parent-id", ""), parent_id=data.get("parent-id", ""),
document_type=data.get("document-type", "source"), document_type=data.get("document-type", "source"),
@ -40,8 +39,6 @@ class DocumentMetadataTranslator(Translator):
result["comments"] = obj.comments result["comments"] = obj.comments
if obj.metadata is not None: if obj.metadata is not None:
result["metadata"] = self.subgraph_translator.encode(obj.metadata) result["metadata"] = self.subgraph_translator.encode(obj.metadata)
if obj.workspace:
result["workspace"] = obj.workspace
if obj.tags is not None: if obj.tags is not None:
result["tags"] = obj.tags result["tags"] = obj.tags
if obj.parent_id: if obj.parent_id:
@ -61,7 +58,6 @@ class ProcessingMetadataTranslator(Translator):
document_id=data.get("document-id"), document_id=data.get("document-id"),
time=data.get("time"), time=data.get("time"),
flow=data.get("flow"), flow=data.get("flow"),
workspace=data.get("workspace"),
collection=data.get("collection"), collection=data.get("collection"),
tags=data.get("tags") tags=data.get("tags")
) )
@ -77,8 +73,6 @@ class ProcessingMetadataTranslator(Translator):
result["time"] = obj.time result["time"] = obj.time
if obj.flow: if obj.flow:
result["flow"] = obj.flow result["flow"] = obj.flow
if obj.workspace:
result["workspace"] = obj.workspace
if obj.collection: if obj.collection:
result["collection"] = obj.collection result["collection"] = obj.collection
if obj.tags is not None: if obj.tags is not None:

View file

@ -8,7 +8,5 @@ class Metadata:
# Root document identifier (set by librarian, preserved through pipeline) # Root document identifier (set by librarian, preserved through pipeline)
root: str = "" root: str = ""
# Collection the message belongs to. Workspace is NOT carried on the # Collection the message belongs to.
# message — consumers derive it from flow.workspace (the flow the
# message arrived on), which is the trusted isolation boundary.
collection: str = "" collection: str = ""

View file

@ -17,7 +17,7 @@ from .embeddings import GraphEmbeddings
# <- (error) # <- (error)
# list-kg-cores # list-kg-cores
# -> (workspace) # -> ()
# <- () # <- ()
# <- (error) # <- (error)
@ -27,9 +27,6 @@ class KnowledgeRequest:
# load-kg-core, unload-kg-core # load-kg-core, unload-kg-core
operation: str = "" operation: str = ""
# Workspace the cores belong to. Partition / isolation boundary.
workspace: str = ""
# get-kg-core, list-kg-cores, delete-kg-core, put-kg-core, # get-kg-core, list-kg-cores, delete-kg-core, put-kg-core,
# load-kg-core, unload-kg-core # load-kg-core, unload-kg-core
id: str = "" id: str = ""

View file

@ -22,17 +22,9 @@ class CollectionMetadata:
@dataclass @dataclass
class CollectionManagementRequest: class CollectionManagementRequest:
"""Request for collection management operations. """Request for collection management operations."""
Collection-management is a global (non-flow-scoped) service, so the
workspace has to travel on the wire it's the isolation boundary
for which workspace's collections the request operates on.
"""
operation: str = "" # e.g., "delete-collection" operation: str = "" # e.g., "delete-collection"
# Workspace the collection belongs to.
workspace: str = ""
collection: str = "" collection: str = ""
timestamp: str = "" # ISO timestamp timestamp: str = "" # ISO timestamp
name: str = "" name: str = ""

View file

@ -70,6 +70,11 @@ class ConfigResponse:
# Everything # Everything
error: Error | None = None error: Error | None = None
@dataclass
class WorkspaceChanges:
created: list[str] = field(default_factory=list)
deleted: list[str] = field(default_factory=list)
@dataclass @dataclass
class ConfigPush: class ConfigPush:
version: int = 0 version: int = 0
@ -80,6 +85,10 @@ class ConfigPush:
# e.g. {"prompt": ["workspace-a", "workspace-b"], "schema": ["workspace-a"]} # e.g. {"prompt": ["workspace-a", "workspace-b"], "schema": ["workspace-a"]}
changes: dict[str, list[str]] = field(default_factory=dict) changes: dict[str, list[str]] = field(default_factory=dict)
# Workspace lifecycle events. Populated when a workspace entry
# is created or deleted in the __workspaces__ config namespace.
workspace_changes: WorkspaceChanges | None = None
config_request_queue = queue('config', cls='request') config_request_queue = queue('config', cls='request')
config_response_queue = queue('config', cls='response') config_response_queue = queue('config', cls='response')
config_push_queue = queue('config', cls='notify') config_push_queue = queue('config', cls='notify')

View file

@ -22,9 +22,6 @@ class FlowRequest:
operation: str = "" # list-blueprints, get-blueprint, put-blueprint, delete-blueprint operation: str = "" # list-blueprints, get-blueprint, put-blueprint, delete-blueprint
# list-flows, get-flow, start-flow, stop-flow # list-flows, get-flow, start-flow, stop-flow
# Workspace scope — all operations act within this workspace
workspace: str = ""
# get_blueprint, put_blueprint, delete_blueprint, start_flow # get_blueprint, put_blueprint, delete_blueprint, start_flow
blueprint_name: str = "" blueprint_name: str = ""

View file

@ -43,12 +43,12 @@ from ..core.metadata import Metadata
# <- (error) # <- (error)
# list-documents # list-documents
# -> (workspace, collection?) # -> (collection?)
# <- (document_metadata[]) # <- (document_metadata[])
# <- (error) # <- (error)
# list-processing # list-processing
# -> (workspace, collection?) # -> (collection?)
# <- (processing_metadata[]) # <- (processing_metadata[])
# <- (error) # <- (error)
@ -78,7 +78,7 @@ from ..core.metadata import Metadata
# <- (error) # <- (error)
# list-uploads # list-uploads
# -> (workspace) # -> ()
# <- (uploads[]) # <- (uploads[])
# <- (error) # <- (error)
@ -90,7 +90,6 @@ class DocumentMetadata:
title: str = "" title: str = ""
comments: str = "" comments: str = ""
metadata: list[Triple] = field(default_factory=list) metadata: list[Triple] = field(default_factory=list)
workspace: str = ""
tags: list[str] = field(default_factory=list) tags: list[str] = field(default_factory=list)
# Child document support # Child document support
parent_id: str = "" # Empty for top-level docs, set for children parent_id: str = "" # Empty for top-level docs, set for children
@ -107,7 +106,6 @@ class ProcessingMetadata:
document_id: str = "" document_id: str = ""
time: int = 0 time: int = 0
flow: str = "" flow: str = ""
workspace: str = ""
collection: str = "" collection: str = ""
tags: list[str] = field(default_factory=list) tags: list[str] = field(default_factory=list)
@ -162,9 +160,6 @@ class LibrarianRequest:
# add-document, upload-chunk # add-document, upload-chunk
content: bytes = b"" content: bytes = b""
# Workspace scopes every library operation.
workspace: str = ""
# list-documents?, list-processing? # list-documents?, list-processing?
collection: str = "" collection: str = ""

View file

@ -22,15 +22,15 @@ def dump_status(metrics_url, api_url, flow_id, token=None,
print() print()
print(f"Flow {flow_id}") print(f"Flow {flow_id}")
show_processors(metrics_url, flow_id) show_processors(metrics_url, flow_id, token=token)
print() print()
print(f"Blueprint {blueprint_name}") print(f"Blueprint {blueprint_name}")
show_processors(metrics_url, blueprint_name) show_processors(metrics_url, blueprint_name, token=token)
print() print()
def show_processors(metrics_url, flow_label): def show_processors(metrics_url, flow_label, token=None):
url = f"{metrics_url}/query" url = f"{metrics_url}/query"
@ -40,7 +40,11 @@ def show_processors(metrics_url, flow_label):
"query": "consumer_state{" + expr + "}" "query": "consumer_state{" + expr + "}"
} }
resp = requests.get(url, params=params) headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
resp = requests.get(url, params=params, headers=headers)
obj = resp.json() obj = resp.json()

View file

@ -2,16 +2,22 @@
Dump out TrustGraph processor states. Dump out TrustGraph processor states.
""" """
import os
import requests import requests
import argparse import argparse
default_metrics_url = "http://localhost:8088/api/metrics" default_metrics_url = "http://localhost:8088/api/metrics"
DEFAULT_TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None)
def dump_status(url): def dump_status(metrics_url, token=None):
url = f"{url}/query?query=processor_info" url = f"{metrics_url}/query?query=processor_info"
resp = requests.get(url) headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
resp = requests.get(url, headers=headers)
obj = resp.json() obj = resp.json()
@ -39,11 +45,17 @@ def main():
help=f'Metrics URL (default: {default_metrics_url})', help=f'Metrics URL (default: {default_metrics_url})',
) )
parser.add_argument(
'-t', '--token',
default=DEFAULT_TOKEN,
help=f'Bearer token for authentication (default: TRUSTGRAPH_TOKEN env var)',
)
args = parser.parse_args() args = parser.parse_args()
try: try:
dump_status(args.metrics_url) dump_status(args.metrics_url, args.token)
except Exception as e: except Exception as e:

View file

@ -3,12 +3,14 @@ Dump out a stream of token rates, input, output and total. This is averaged
across the time since tg-show-token-rate is started. across the time since tg-show-token-rate is started.
""" """
import os
import requests import requests
import argparse import argparse
import json import json
import time import time
default_metrics_url = "http://localhost:8088/api/metrics" default_metrics_url = "http://localhost:8088/api/metrics"
DEFAULT_TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None)
class Collate: class Collate:
@ -36,16 +38,20 @@ class Collate:
return delta/time, self.total/self.time return delta/time, self.total/self.time
def dump_status(metrics_url, number_samples, period): def dump_status(metrics_url, number_samples, period, token=None):
input_url = f"{metrics_url}/query?query=input_tokens_total" input_url = f"{metrics_url}/query?query=input_tokens_total"
output_url = f"{metrics_url}/query?query=output_tokens_total" output_url = f"{metrics_url}/query?query=output_tokens_total"
resp = requests.get(input_url) headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
resp = requests.get(input_url, headers=headers)
obj = resp.json() obj = resp.json()
input = Collate(obj) input = Collate(obj)
resp = requests.get(output_url) resp = requests.get(output_url, headers=headers)
obj = resp.json() obj = resp.json()
output = Collate(obj) output = Collate(obj)
@ -56,20 +62,20 @@ def dump_status(metrics_url, number_samples, period):
time.sleep(period) time.sleep(period)
resp = requests.get(input_url) resp = requests.get(input_url, headers=headers)
obj = resp.json() obj = resp.json()
inr, inl = input.record(obj, period) inr, inl = input.record(obj, period)
resp = requests.get(output_url) resp = requests.get(output_url, headers=headers)
obj = resp.json() obj = resp.json()
outr, outl = output.record(obj, period) outr, outl = output.record(obj, period)
print(f"{inl:10.1f} {outl:10.1f} {inl+outl:10.1f}") print(f"{inl:10.1f} {outl:10.1f} {inl+outl:10.1f}")
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='tg-show-processor-state', prog='tg-show-token-rate',
description=__doc__, description=__doc__,
) )
@ -93,6 +99,12 @@ def main():
help=f'Metrics period (default: 100)', help=f'Metrics period (default: 100)',
) )
parser.add_argument(
'-t', '--token',
default=DEFAULT_TOKEN,
help=f'Bearer token for authentication (default: TRUSTGRAPH_TOKEN env var)',
)
args = parser.parse_args() args = parser.parse_args()
try: try:

View file

@ -61,6 +61,10 @@ class FlowContext:
def __call__(self, service_name): def __call__(self, service_name):
return self._flow(service_name) return self._flow(service_name)
@property
def librarian(self):
return self._flow.librarian
class UsageTracker: class UsageTracker:
"""Accumulates token usage across multiple prompt calls.""" """Accumulates token usage across multiple prompt calls."""
@ -320,9 +324,9 @@ class PatternBase:
f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought" f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
) )
try: try:
await self.processor.save_answer_content( await flow.librarian.save_document(
doc_id=thought_doc_id, doc_id=thought_doc_id,
workspace=flow.workspace,
content=act.thought, content=act.thought,
title=f"Agent Thought: {act.name}", title=f"Agent Thought: {act.name}",
) )
@ -389,9 +393,9 @@ class PatternBase:
f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
) )
try: try:
await self.processor.save_answer_content( await flow.librarian.save_document(
doc_id=observation_doc_id, doc_id=observation_doc_id,
workspace=flow.workspace,
content=observation_text, content=observation_text,
title=f"Agent Observation", title=f"Agent Observation",
) )
@ -445,9 +449,9 @@ class PatternBase:
if answer_text: if answer_text:
answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer" answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
try: try:
await self.processor.save_answer_content( await flow.librarian.save_document(
doc_id=answer_doc_id, doc_id=answer_doc_id,
workspace=flow.workspace,
content=answer_text, content=answer_text,
title=f"Agent Answer: {request.question[:50]}...", title=f"Agent Answer: {request.question[:50]}...",
) )
@ -521,8 +525,8 @@ class PatternBase:
doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc" doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc"
try: try:
await self.processor.save_answer_content( await flow.librarian.save_document(
doc_id=doc_id, workspace=flow.workspace, doc_id=doc_id,
content=answer_text, content=answer_text,
title=f"Finding: {goal[:60]}", title=f"Finding: {goal[:60]}",
) )
@ -574,8 +578,8 @@ class PatternBase:
doc_id = f"urn:trustgraph:agent:{session_id}/step/{index}/doc" doc_id = f"urn:trustgraph:agent:{session_id}/step/{index}/doc"
try: try:
await self.processor.save_answer_content( await flow.librarian.save_document(
doc_id=doc_id, workspace=flow.workspace, doc_id=doc_id,
content=answer_text, content=answer_text,
title=f"Step result: {goal[:60]}", title=f"Step result: {goal[:60]}",
) )
@ -606,8 +610,8 @@ class PatternBase:
doc_id = f"urn:trustgraph:agent:{session_id}/synthesis/doc" doc_id = f"urn:trustgraph:agent:{session_id}/synthesis/doc"
try: try:
await self.processor.save_answer_content( await flow.librarian.save_document(
doc_id=doc_id, workspace=flow.workspace, doc_id=doc_id,
content=answer_text, content=answer_text,
title="Synthesis", title="Synthesis",
) )

View file

@ -7,26 +7,17 @@ to select between ReactPattern, PlanThenExecutePattern, and
SupervisorPattern at runtime. SupervisorPattern at runtime.
""" """
import asyncio
import base64
import json import json
import functools import functools
import logging import logging
import uuid
from datetime import datetime
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
from ... base import ProducerSpec from ... base import ProducerSpec, LibrarianSpec
from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ..orchestrator.pattern_base import UsageTracker, PatternBase from ..orchestrator.pattern_base import UsageTracker, PatternBase
from ... schema import Triples, Metadata from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from trustgraph.provenance import ( from trustgraph.provenance import (
agent_session_uri, agent_session_uri,
@ -52,8 +43,6 @@ logger = logging.getLogger(__name__)
default_ident = "agent-manager" default_ident = "agent-manager"
default_max_iterations = 10 default_max_iterations = 10
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(AgentService): class Processor(AgentService):
@ -151,94 +140,9 @@ class Processor(AgentService):
) )
) )
# Librarian client self.register_specification(
librarian_request_q = params.get( LibrarianSpec()
"librarian_request_queue", default_librarian_request_queue
) )
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
self.pending_librarian_requests = {}
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
response = msg.value()
request_id = msg.properties().get("id")
if request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, workspace, content, title=None,
timeout=120):
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
workspace=workspace,
kind="text/plain",
title=title or "Agent Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
workspace=workspace,
)
future = asyncio.get_event_loop().create_future()
self.pending_librarian_requests[request_id] = future
try:
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: "
f"{response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
def provenance_session_uri(self, session_id): def provenance_session_uri(self, session_id):
return agent_session_uri(session_id) return agent_session_uri(session_id)

View file

@ -3,7 +3,6 @@ Simple agent infrastructure broadly implements the ReAct flow.
""" """
import asyncio import asyncio
import base64
import json import json
import re import re
import sys import sys
@ -19,14 +18,10 @@ logger = logging.getLogger(__name__)
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
from ... base import ProducerSpec from ... base import ProducerSpec, LibrarianSpec
from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ... schema import Triples, Metadata from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
# Provenance imports for agent explainability # Provenance imports for agent explainability
from trustgraph.provenance import ( from trustgraph.provenance import (
@ -51,8 +46,6 @@ from . types import Final, Action, Tool, Argument
default_ident = "agent-manager" default_ident = "agent-manager"
default_max_iterations = 10 default_max_iterations = 10
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(AgentService): class Processor(AgentService):
@ -141,112 +134,9 @@ class Processor(AgentService):
) )
) )
# Librarian client for storing answer content self.register_specification(
librarian_request_q = params.get( LibrarianSpec()
"librarian_request_queue", default_librarian_request_queue
) )
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_librarian_requests = {}
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
workspace: Workspace for isolation
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
workspace=workspace,
kind="text/plain",
title=title or "Agent Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
workspace=workspace,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_librarian_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_tools_config(self, workspace, config, version): async def on_tools_config(self, workspace, config, version):
@ -611,9 +501,9 @@ class Processor(AgentService):
if act_decision.thought: if act_decision.thought:
t_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought" t_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
try: try:
await self.save_answer_content( await flow.librarian.save_document(
doc_id=t_doc_id, doc_id=t_doc_id,
workspace=flow.workspace,
content=act_decision.thought, content=act_decision.thought,
title=f"Agent Thought: {act_decision.name}", title=f"Agent Thought: {act_decision.name}",
) )
@ -691,9 +581,9 @@ class Processor(AgentService):
if f: if f:
answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer" answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
try: try:
await self.save_answer_content( await flow.librarian.save_document(
doc_id=answer_doc_id, doc_id=answer_doc_id,
workspace=flow.workspace,
content=f, content=f,
title=f"Agent Answer: {request.question[:50]}...", title=f"Agent Answer: {request.question[:50]}...",
) )
@ -768,9 +658,8 @@ class Processor(AgentService):
if act.observation: if act.observation:
observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation" observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
try: try:
await self.save_answer_content( await flow.librarian.save_document(
doc_id=observation_doc_id, doc_id=observation_doc_id,
workspace=flow.workspace,
content=act.observation, content=act.observation,
title=f"Agent Observation", title=f"Agent Observation",
) )

View file

@ -21,7 +21,8 @@ class InitContext:
logger: logging.Logger logger: logging.Logger
config: Any # ConfigClient config: Any # ConfigClient
flow: Any # RequestResponse client for flow-svc make_flow_client: Any # callable(workspace) -> RequestResponse
make_iam_client: Any # callable() -> RequestResponse
class Initialiser: class Initialiser:
@ -35,7 +36,7 @@ class Initialiser:
* ``wait_for_services`` (bool, default ``True``): when ``True`` the * ``wait_for_services`` (bool, default ``True``): when ``True`` the
initialiser only runs after the bootstrapper's service gate has initialiser only runs after the bootstrapper's service gate has
passed (config-svc and flow-svc reachable). Set ``False`` for passed (config-svc reachable). Set ``False`` for
initialisers that bring up infrastructure the gate itself initialisers that bring up infrastructure the gate itself
depends on principally Pulsar topology, without which depends on principally Pulsar topology, without which
config-svc cannot come online. config-svc cannot come online.

View file

@ -28,6 +28,10 @@ from trustgraph.schema import (
FlowRequest, FlowResponse, FlowRequest, FlowResponse,
flow_request_queue, flow_response_queue, flow_request_queue, flow_response_queue,
) )
from trustgraph.schema import (
IamRequest, IamResponse,
iam_request_queue, iam_response_queue,
)
from .. base import Initialiser, InitContext from .. base import Initialiser, InitContext
@ -178,34 +182,46 @@ class Processor(AsyncProcessor):
), ),
) )
def _make_flow_client(self): def _make_flow_client(self, workspace):
rr_id = str(uuid.uuid4()) rr_id = str(uuid.uuid4())
return RequestResponse( return RequestResponse(
backend=self.pubsub_backend, backend=self.pubsub_backend,
subscription=f"{self.id}--flow--{rr_id}", subscription=f"{self.id}--flow--{rr_id}",
consumer_name=self.id, consumer_name=self.id,
request_topic=flow_request_queue, request_topic=f"{flow_request_queue}:{workspace}",
request_schema=FlowRequest, request_schema=FlowRequest,
request_metrics=ProducerMetrics( request_metrics=ProducerMetrics(
processor=self.id, flow=None, name="flow-request", processor=self.id, flow=None, name="flow-request",
), ),
response_topic=flow_response_queue, response_topic=f"{flow_response_queue}:{workspace}",
response_schema=FlowResponse, response_schema=FlowResponse,
response_metrics=SubscriberMetrics( response_metrics=SubscriberMetrics(
processor=self.id, flow=None, name="flow-response", processor=self.id, flow=None, name="flow-response",
), ),
) )
def _make_iam_client(self):
rr_id = str(uuid.uuid4())
return RequestResponse(
backend=self.pubsub_backend,
subscription=f"{self.id}--iam--{rr_id}",
consumer_name=self.id,
request_topic=iam_request_queue,
request_schema=IamRequest,
request_metrics=ProducerMetrics(
processor=self.id, flow=None, name="iam-request",
),
response_topic=iam_response_queue,
response_schema=IamResponse,
response_metrics=SubscriberMetrics(
processor=self.id, flow=None, name="iam-response",
),
)
async def _open_clients(self): async def _open_clients(self):
config = self._make_config_client() config = self._make_config_client()
flow = self._make_flow_client()
await config.start() await config.start()
try: return config
await flow.start()
except Exception:
await self._safe_stop(config)
raise
return config, flow
async def _safe_stop(self, client): async def _safe_stop(self, client):
try: try:
@ -217,7 +233,7 @@ class Processor(AsyncProcessor):
# Service gate. # Service gate.
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def _gate_ready(self, config, flow): async def _gate_ready(self, config):
try: try:
await config.keys(SYSTEM_WORKSPACE, INIT_STATE_TYPE) await config.keys(SYSTEM_WORKSPACE, INIT_STATE_TYPE)
except Exception as e: except Exception as e:
@ -226,26 +242,6 @@ class Processor(AsyncProcessor):
) )
return False return False
try:
resp = await flow.request(
FlowRequest(
operation="list-blueprints",
workspace=SYSTEM_WORKSPACE,
),
timeout=5,
)
if resp.error:
logger.info(
f"Gate: flow-svc error: "
f"{resp.error.type}: {resp.error.message}"
)
return False
except Exception as e:
logger.info(
f"Gate: flow-svc not ready ({type(e).__name__}: {e})"
)
return False
return True return True
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -271,7 +267,7 @@ class Processor(AsyncProcessor):
# Per-spec execution. # Per-spec execution.
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def _run_spec(self, spec, config, flow): async def _run_spec(self, spec, config):
"""Run a single initialiser spec. """Run a single initialiser spec.
Returns one of: Returns one of:
@ -298,7 +294,8 @@ class Processor(AsyncProcessor):
child_ctx = InitContext( child_ctx = InitContext(
logger=child_logger, logger=child_logger,
config=config, config=config,
flow=flow, make_flow_client=self._make_flow_client,
make_iam_client=self._make_iam_client,
) )
child_logger.info( child_logger.info(
@ -340,7 +337,7 @@ class Processor(AsyncProcessor):
sleep_for = STEADY_INTERVAL sleep_for = STEADY_INTERVAL
try: try:
config, flow = await self._open_clients() config = await self._open_clients()
except Exception as e: except Exception as e:
logger.info( logger.info(
f"Failed to open clients " f"Failed to open clients "
@ -358,11 +355,11 @@ class Processor(AsyncProcessor):
pre_results = {} pre_results = {}
for spec in pre_specs: for spec in pre_specs:
pre_results[spec.name] = await self._run_spec( pre_results[spec.name] = await self._run_spec(
spec, config, flow, spec, config,
) )
# Phase 2: gate. # Phase 2: gate.
gate_ok = await self._gate_ready(config, flow) gate_ok = await self._gate_ready(config)
# Phase 3: post-service initialisers, if gate passed. # Phase 3: post-service initialisers, if gate passed.
post_results = {} post_results = {}
@ -373,7 +370,7 @@ class Processor(AsyncProcessor):
] ]
for spec in post_specs: for spec in post_specs:
post_results[spec.name] = await self._run_spec( post_results[spec.name] = await self._run_spec(
spec, config, flow, spec, config,
) )
# Cadence selection. # Cadence selection.
@ -388,7 +385,6 @@ class Processor(AsyncProcessor):
finally: finally:
await self._safe_stop(config) await self._safe_stop(config)
await self._safe_stop(flow)
await asyncio.sleep(sleep_for) await asyncio.sleep(sleep_for)

View file

@ -49,53 +49,67 @@ class DefaultFlowStart(Initialiser):
async def run(self, ctx, old_flag, new_flag): async def run(self, ctx, old_flag, new_flag):
# Check whether the flow already exists. Belt-and-braces workspaces = await ctx.config.keys(
# beyond the flag gate: if an operator stops and restarts the "__workspaces__", "workspace",
# bootstrapper after the flow is already running, we don't
# want to blindly try to start it again.
list_resp = await ctx.flow.request(
FlowRequest(
operation="list-flows",
workspace=self.workspace,
),
timeout=10,
) )
if list_resp.error: if self.workspace not in workspaces:
raise RuntimeError( raise RuntimeError(
f"list-flows failed: " f"Workspace {self.workspace!r} does not exist yet"
f"{list_resp.error.type}: {list_resp.error.message}"
) )
if self.flow_id in (list_resp.flow_ids or []): flow = ctx.make_flow_client(self.workspace)
await flow.start()
try:
# Check whether the flow already exists. Belt-and-braces
# beyond the flag gate: if an operator stops and restarts the
# bootstrapper after the flow is already running, we don't
# want to blindly try to start it again.
list_resp = await flow.request(
FlowRequest(
operation="list-flows",
),
timeout=10,
)
if list_resp.error:
raise RuntimeError(
f"list-flows failed: "
f"{list_resp.error.type}: {list_resp.error.message}"
)
if self.flow_id in (list_resp.flow_ids or []):
ctx.logger.info(
f"Flow {self.flow_id!r} already running in workspace "
f"{self.workspace!r}; nothing to do"
)
return
ctx.logger.info( ctx.logger.info(
f"Flow {self.flow_id!r} already running in workspace " f"Starting flow {self.flow_id!r} "
f"{self.workspace!r}; nothing to do" f"(blueprint={self.blueprint!r}) "
) f"in workspace {self.workspace!r}"
return
ctx.logger.info(
f"Starting flow {self.flow_id!r} "
f"(blueprint={self.blueprint!r}) "
f"in workspace {self.workspace!r}"
)
resp = await ctx.flow.request(
FlowRequest(
operation="start-flow",
workspace=self.workspace,
flow_id=self.flow_id,
blueprint_name=self.blueprint,
description=self.description,
parameters=self.parameters,
),
timeout=30,
)
if resp.error:
raise RuntimeError(
f"start-flow failed: "
f"{resp.error.type}: {resp.error.message}"
) )
ctx.logger.info( resp = await flow.request(
f"Flow {self.flow_id!r} started" FlowRequest(
) operation="start-flow",
flow_id=self.flow_id,
blueprint_name=self.blueprint,
description=self.description,
parameters=self.parameters,
),
timeout=30,
)
if resp.error:
raise RuntimeError(
f"start-flow failed: "
f"{resp.error.type}: {resp.error.message}"
)
ctx.logger.info(
f"Flow {self.flow_id!r} started"
)
finally:
await flow.stop()

View file

@ -26,6 +26,8 @@ the next cycle once the prerequisite is satisfied.
import json import json
from trustgraph.schema import IamRequest, WorkspaceInput
from .. base import Initialiser from .. base import Initialiser
TEMPLATE_WORKSPACE = "__template__" TEMPLATE_WORKSPACE = "__template__"
@ -59,6 +61,8 @@ class WorkspaceInit(Initialiser):
self.overwrite = overwrite self.overwrite = overwrite
async def run(self, ctx, old_flag, new_flag): async def run(self, ctx, old_flag, new_flag):
await self._create_workspace(ctx)
if self.source == "seed-file": if self.source == "seed-file":
tree = self._load_seed_file() tree = self._load_seed_file()
else: else:
@ -105,6 +109,39 @@ class WorkspaceInit(Initialiser):
) )
return tree return tree
async def _create_workspace(self, ctx):
"""Register the workspace via the IAM create-workspace API."""
iam = ctx.make_iam_client()
await iam.start()
try:
resp = await iam.request(
IamRequest(
operation="create-workspace",
workspace_record=WorkspaceInput(
id=self.workspace,
name=self.workspace.title(),
enabled=True,
),
),
timeout=10,
)
if resp.error:
if resp.error.type == "duplicate":
ctx.logger.info(
f"Workspace {self.workspace!r} already exists in IAM"
)
else:
raise RuntimeError(
f"IAM create-workspace failed: "
f"{resp.error.type}: {resp.error.message}"
)
else:
ctx.logger.info(
f"Workspace {self.workspace!r} created via IAM"
)
finally:
await iam.stop()
async def _write_all(self, ctx, tree): async def _write_all(self, ctx, tree):
values = [] values = []
for type_name, entries in tree.items(): for type_name, entries in tree.items():
@ -112,6 +149,7 @@ class WorkspaceInit(Initialiser):
values.append((type_name, key, json.dumps(value))) values.append((type_name, key, json.dumps(value)))
if values: if values:
await ctx.config.put_many(self.workspace, values) await ctx.config.put_many(self.workspace, values)
ctx.logger.info( ctx.logger.info(
f"Workspace {self.workspace!r} populated with " f"Workspace {self.workspace!r} populated with "
f"{len(values)} entries" f"{len(values)} entries"
@ -132,6 +170,7 @@ class WorkspaceInit(Initialiser):
if values: if values:
await ctx.config.put_many(self.workspace, values) await ctx.config.put_many(self.workspace, values)
written += len(values) written += len(values)
ctx.logger.info( ctx.logger.info(
f"Workspace {self.workspace!r} upsert-missing: " f"Workspace {self.workspace!r} upsert-missing: "
f"{written} new entries" f"{written} new entries"

View file

@ -95,7 +95,7 @@ class Processor(ChunkingService):
logger.info(f"Chunking document {v.metadata.id}...") logger.info(f"Chunking document {v.metadata.id}...")
# Get text content (fetches from librarian if needed) # Get text content (fetches from librarian if needed)
text = await self.get_document_text(v, flow.workspace) text = await self.get_document_text(v, flow)
# Extract chunk parameters from flow (allows runtime override) # Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document( chunk_size, chunk_overlap = await self.chunk_document(
@ -141,10 +141,9 @@ class Processor(ChunkingService):
chunk_length = len(chunk.page_content) chunk_length = len(chunk.page_content)
# Save chunk to librarian as child document # Save chunk to librarian as child document
await self.librarian.save_child_document( await flow.librarian.save_child_document(
doc_id=chunk_doc_id, doc_id=chunk_doc_id,
parent_id=parent_doc_id, parent_id=parent_doc_id,
workspace=flow.workspace,
content=chunk_content, content=chunk_content,
document_type="chunk", document_type="chunk",
title=f"Chunk {chunk_index}", title=f"Chunk {chunk_index}",

View file

@ -92,7 +92,7 @@ class Processor(ChunkingService):
logger.info(f"Chunking document {v.metadata.id}...") logger.info(f"Chunking document {v.metadata.id}...")
# Get text content (fetches from librarian if needed) # Get text content (fetches from librarian if needed)
text = await self.get_document_text(v, flow.workspace) text = await self.get_document_text(v, flow)
# Extract chunk parameters from flow (allows runtime override) # Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document( chunk_size, chunk_overlap = await self.chunk_document(
@ -137,10 +137,9 @@ class Processor(ChunkingService):
chunk_length = len(chunk.page_content) chunk_length = len(chunk.page_content)
# Save chunk to librarian as child document # Save chunk to librarian as child document
await self.librarian.save_child_document( await flow.librarian.save_child_document(
doc_id=chunk_doc_id, doc_id=chunk_doc_id,
parent_id=parent_doc_id, parent_id=parent_doc_id,
workspace=flow.workspace,
content=chunk_content, content=chunk_content,
document_type="chunk", document_type="chunk",
title=f"Chunk {chunk_index}", title=f"Chunk {chunk_index}",

View file

@ -2,13 +2,17 @@
import logging import logging
from trustgraph.schema import ConfigResponse from trustgraph.schema import ConfigResponse
from trustgraph.schema import ConfigValue, Error from trustgraph.schema import ConfigValue, WorkspaceChanges, Error
from ... tables.config import ConfigTableStore from ... tables.config import ConfigTableStore
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
WORKSPACES_NAMESPACE = "__workspaces__"
WORKSPACE_TYPE = "workspace"
TEMPLATE_WORKSPACE = "__template__"
class Configuration: class Configuration:
def __init__(self, push, host, username, password, keyspace, def __init__(self, push, host, username, password, keyspace,
@ -27,9 +31,7 @@ class Configuration:
async def get_version(self): async def get_version(self):
return await self.table_store.get_version() return await self.table_store.get_version()
async def handle_get(self, v): async def handle_get(self, v, workspace):
workspace = v.workspace
values = [ values = [
ConfigValue( ConfigValue(
@ -47,18 +49,18 @@ class Configuration:
values = values, values = values,
) )
async def handle_list(self, v): async def handle_list(self, v, workspace):
return ConfigResponse( return ConfigResponse(
version = await self.get_version(), version = await self.get_version(),
directory = await self.table_store.get_keys( directory = await self.table_store.get_keys(
v.workspace, v.type workspace, v.type
), ),
) )
async def handle_getvalues(self, v): async def handle_getvalues(self, v, workspace):
vals = await self.table_store.get_values(v.workspace, v.type) vals = await self.table_store.get_values(workspace, v.type)
values = map( values = map(
lambda x: ConfigValue( lambda x: ConfigValue(
@ -94,9 +96,8 @@ class Configuration:
values = values, values = values,
) )
async def handle_delete(self, v): async def handle_delete(self, v, workspace):
workspace = v.workspace
types = list(set(k.type for k in v.keys)) types = list(set(k.type for k in v.keys))
for k in v.keys: for k in v.keys:
@ -104,14 +105,22 @@ class Configuration:
await self.inc_version() await self.inc_version()
await self.push(changes={t: [workspace] for t in types}) workspace_changes = None
if workspace == WORKSPACES_NAMESPACE and WORKSPACE_TYPE in types:
deleted = [k.key for k in v.keys if k.type == WORKSPACE_TYPE]
if deleted:
workspace_changes = WorkspaceChanges(deleted=deleted)
await self.push(
changes={t: [workspace] for t in types},
workspace_changes=workspace_changes,
)
return ConfigResponse( return ConfigResponse(
) )
async def handle_put(self, v): async def handle_put(self, v, workspace):
workspace = v.workspace
types = list(set(k.type for k in v.values)) types = list(set(k.type for k in v.values))
for k in v.values: for k in v.values:
@ -121,11 +130,49 @@ class Configuration:
await self.inc_version() await self.inc_version()
await self.push(changes={t: [workspace] for t in types}) workspace_changes = None
if workspace == WORKSPACES_NAMESPACE and WORKSPACE_TYPE in types:
created = [k.key for k in v.values if k.type == WORKSPACE_TYPE]
if created:
workspace_changes = WorkspaceChanges(created=created)
await self.push(
changes={t: [workspace] for t in types},
workspace_changes=workspace_changes,
)
return ConfigResponse( return ConfigResponse(
) )
async def provision_from_template(self, workspace):
"""Copy all config from __template__ into a new workspace,
skipping keys that already exist (upsert-missing)."""
template = await self.get_config(TEMPLATE_WORKSPACE)
if not template:
logger.info(
f"No template config to provision for {workspace}"
)
return 0
existing_types = await self.get_config(workspace)
written = 0
for type_name, entries in template.items():
existing_keys = set(existing_types.get(type_name, {}).keys())
for key, value in entries.items():
if key not in existing_keys:
await self.table_store.put_config(
workspace, type_name, key, value
)
written += 1
if written > 0:
await self.inc_version()
return written
async def get_config(self, workspace): async def get_config(self, workspace):
table = await self.table_store.get_all_for_workspace(workspace) table = await self.table_store.get_all_for_workspace(workspace)
@ -139,62 +186,87 @@ class Configuration:
return config return config
async def handle_config(self, v): async def handle_config(self, v, workspace):
config = await self.get_config(v.workspace) config = await self.get_config(workspace)
return ConfigResponse( return ConfigResponse(
version = await self.get_version(), version = await self.get_version(),
config = config, config = config,
) )
async def handle(self, msg): async def handle_workspace(self, msg, workspace):
"""Handle workspace-scoped config operations.
Workspace is provided by queue infrastructure."""
logger.debug( logger.debug(
f"Handling config message: {msg.operation} " f"Handling workspace config message: {msg.operation} "
f"workspace={msg.workspace}" f"workspace={workspace}"
) )
# getvalues-all-ws spans all workspaces, so no workspace
# required; everything else is workspace-scoped.
if msg.operation != "getvalues-all-ws" and not msg.workspace:
return ConfigResponse(
error=Error(
type = "bad-request",
message = "Workspace is required"
)
)
if msg.operation == "get": if msg.operation == "get":
resp = await self.handle_get(msg, workspace)
resp = await self.handle_get(msg)
elif msg.operation == "list": elif msg.operation == "list":
resp = await self.handle_list(msg, workspace)
resp = await self.handle_list(msg)
elif msg.operation == "getvalues": elif msg.operation == "getvalues":
resp = await self.handle_getvalues(msg, workspace)
resp = await self.handle_getvalues(msg)
elif msg.operation == "getvalues-all-ws":
resp = await self.handle_getvalues_all_ws(msg)
elif msg.operation == "delete": elif msg.operation == "delete":
resp = await self.handle_delete(msg, workspace)
resp = await self.handle_delete(msg)
elif msg.operation == "put": elif msg.operation == "put":
resp = await self.handle_put(msg, workspace)
resp = await self.handle_put(msg)
elif msg.operation == "config": elif msg.operation == "config":
resp = await self.handle_config(msg, workspace)
resp = await self.handle_config(msg)
else:
resp = ConfigResponse(
error=Error(
type = "bad-operation",
message = "Bad operation"
)
)
return resp
async def handle_system(self, msg):
"""Handle system-level config operations.
Workspace, when needed, comes from message body."""
logger.debug(
f"Handling system config message: {msg.operation} "
f"workspace={msg.workspace}"
)
if msg.operation == "getvalues-all-ws":
resp = await self.handle_getvalues_all_ws(msg)
elif msg.operation in ("get", "list", "getvalues", "delete",
"put", "config"):
if not msg.workspace:
return ConfigResponse(
error=Error(
type = "bad-request",
message = "Workspace is required"
)
)
handler = {
"get": self.handle_get,
"list": self.handle_list,
"getvalues": self.handle_getvalues,
"delete": self.handle_delete,
"put": self.handle_put,
"config": self.handle_config,
}[msg.operation]
resp = await handler(msg, msg.workspace)
else: else:
resp = ConfigResponse( resp = ConfigResponse(
error=Error( error=Error(
type = "bad-operation", type = "bad-operation",

View file

@ -1,20 +1,30 @@
""" """
Config service. Manages system global configuration state Config service. Manages system global configuration state.
Operates a dual-queue regime:
- System queue (config-request): handles cross-workspace operations like
getvalues-all-ws and bootstrapper put/delete on __workspaces__.
The gateway NEVER routes to this queue.
- Per-workspace queues (config-request:<workspace>): handles
workspace-scoped operations where workspace identity comes from
queue infrastructure, not message body.
""" """
import logging import logging
from functools import partial
from trustgraph.schema import Error from trustgraph.schema import Error
from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush
from trustgraph.schema import WorkspaceChanges
from trustgraph.schema import config_request_queue, config_response_queue from trustgraph.schema import config_request_queue, config_response_queue
from trustgraph.schema import config_push_queue from trustgraph.schema import config_push_queue
from trustgraph.base import AsyncProcessor, Consumer, Producer from trustgraph.base import AsyncProcessor, Consumer, Producer
from trustgraph.base.cassandra_config import add_cassandra_args, resolve_cassandra_config from trustgraph.base.cassandra_config import add_cassandra_args, resolve_cassandra_config
from . config import Configuration from . config import Configuration, WORKSPACES_NAMESPACE, WORKSPACE_TYPE
from ... base import ProcessorMetrics, ConsumerMetrics, ProducerMetrics from ... base import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
from ... base import Consumer, Producer from ... base import Consumer, Producer
@ -39,6 +49,11 @@ def is_reserved_workspace(workspace):
""" """
return workspace.startswith("_") return workspace.startswith("_")
def workspace_queue(base_queue, workspace):
return f"{base_queue}:{workspace}"
default_config_request_queue = config_request_queue default_config_request_queue = config_request_queue
default_config_response_queue = config_response_queue default_config_response_queue = config_response_queue
default_config_push_queue = config_push_queue default_config_push_queue = config_push_queue
@ -48,11 +63,11 @@ default_cassandra_host = "cassandra"
class Processor(AsyncProcessor): class Processor(AsyncProcessor):
def __init__(self, **params): def __init__(self, **params):
config_request_queue = params.get( config_request_queue = params.get(
"config_request_queue", default_config_request_queue "config_request_queue", default_config_request_queue
) )
config_response_queue = params.get( self.config_response_queue_base = params.get(
"config_response_queue", default_config_response_queue "config_response_queue", default_config_response_queue
) )
config_push_queue = params.get( config_push_queue = params.get(
@ -64,13 +79,13 @@ class Processor(AsyncProcessor):
cassandra_password = params.get("cassandra_password") cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback # Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password, password=cassandra_password,
default_keyspace="config" default_keyspace="config"
) )
# Store resolved configuration # Store resolved configuration
self.cassandra_host = hosts self.cassandra_host = hosts
self.cassandra_username = username self.cassandra_username = username
@ -99,23 +114,23 @@ class Processor(AsyncProcessor):
processor = self.id, flow = None, name = "config-push" processor = self.id, flow = None, name = "config-push"
) )
self.config_request_topic = config_request_queue self.config_request_queue_base = config_request_queue
self.config_request_subscriber = id self.config_request_subscriber = id
self.config_request_consumer = Consumer( self.system_consumer = Consumer(
taskgroup = self.taskgroup, taskgroup = self.taskgroup,
backend = self.pubsub, backend = self.pubsub,
flow = None, flow = None,
topic = config_request_queue, topic = config_request_queue,
subscriber = id, subscriber = id,
schema = ConfigRequest, schema = ConfigRequest,
handler = self.on_config_request, handler = self.on_system_config_request,
metrics = config_request_metrics, metrics = config_request_metrics,
) )
self.config_response_producer = Producer( self.config_response_producer = Producer(
backend = self.pubsub, backend = self.pubsub,
topic = config_response_queue, topic = self.config_response_queue_base,
schema = ConfigResponse, schema = ConfigResponse,
metrics = config_response_metrics, metrics = config_response_metrics,
) )
@ -132,23 +147,145 @@ class Processor(AsyncProcessor):
username = self.cassandra_username, username = self.cassandra_username,
password = self.cassandra_password, password = self.cassandra_password,
keyspace = keyspace, keyspace = keyspace,
replication_factor = replication_factor,
push = self.push push = self.push
) )
self.workspace_consumers = {}
self.register_workspace_handler(self._handle_workspace_changes)
logger.info("Config service initialized") logger.info("Config service initialized")
async def _discover_workspaces(self):
logger.info("Discovering workspaces from Cassandra...")
try:
workspaces = await self.config.table_store.get_keys(
WORKSPACES_NAMESPACE, WORKSPACE_TYPE
)
logger.info(f"Discovered workspaces: {workspaces}")
except Exception as e:
logger.error(
f"Workspace discovery failed: {e}", exc_info=True
)
return
for workspace_id in workspaces:
if workspace_id not in self.workspace_consumers:
await self._add_workspace_consumer(workspace_id)
async def _handle_workspace_changes(self, workspace_changes):
for workspace_id in workspace_changes.created:
if workspace_id not in self.workspace_consumers:
logger.info(f"Workspace created: {workspace_id}")
await self._add_workspace_consumer(workspace_id)
await self._provision_workspace(workspace_id)
for workspace_id in workspace_changes.deleted:
if workspace_id in self.workspace_consumers:
logger.info(f"Workspace deleted: {workspace_id}")
await self._remove_workspace_consumer(workspace_id)
async def _provision_workspace(self, workspace_id):
try:
written = await self.config.provision_from_template(
workspace_id
)
if written > 0:
logger.info(
f"Provisioned workspace {workspace_id} with "
f"{written} entries from template"
)
# Notify other services about the new config
types = {}
template = await self.config.get_config(workspace_id)
for t in template:
types[t] = [workspace_id]
await self.push(changes=types)
except Exception as e:
logger.error(
f"Failed to provision workspace {workspace_id}: {e}",
exc_info=True,
)
async def _add_workspace_consumer(self, workspace_id):
req_queue = workspace_queue(
self.config_request_queue_base, workspace_id,
)
resp_queue = workspace_queue(
self.config_response_queue_base, workspace_id,
)
await self.pubsub.ensure_topic(req_queue)
await self.pubsub.ensure_topic(resp_queue)
response_producer = Producer(
backend=self.pubsub,
topic=resp_queue,
schema=ConfigResponse,
metrics=ProducerMetrics(
processor=self.id, flow=None,
name=f"config-response-{workspace_id}",
),
)
consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=req_queue,
subscriber=self.id,
schema=ConfigRequest,
handler=partial(
self.on_workspace_config_request,
workspace=workspace_id,
),
metrics=ConsumerMetrics(
processor=self.id, flow=None,
name=f"config-request-{workspace_id}",
),
)
await response_producer.start()
await consumer.start()
self.workspace_consumers[workspace_id] = {
"consumer": consumer,
"response": response_producer,
}
logger.info(
f"Subscribed to workspace config queue: {workspace_id}"
)
async def _remove_workspace_consumer(self, workspace_id):
clients = self.workspace_consumers.pop(workspace_id, None)
if clients:
for client in clients.values():
await client.stop()
logger.info(
f"Unsubscribed from workspace config queue: {workspace_id}"
)
async def start(self): async def start(self):
await self.pubsub.ensure_topic(self.config_request_topic) await self.pubsub.ensure_topic(self.config_request_queue_base)
await self.config_response_producer.start()
await self.push() # Startup poke: empty types = everything await self.push() # Startup poke: empty types = everything
await self.config_request_consumer.start() await self.system_consumer.start()
async def push(self, changes=None): # Start the config push subscriber so we receive our own
# workspace change notifications.
await self.config_sub_task.start()
await self._discover_workspaces()
async def push(self, changes=None, workspace_changes=None):
# Suppress notifications from reserved workspaces (ids starting # Suppress notifications from reserved workspaces (ids starting
# with "_", e.g. "__template__"). Stored config is preserved; # with "_", e.g. "__template__") for regular config changes.
# only the broadcast is filtered. Keeps services oblivious to # The __workspaces__ namespace is handled separately via
# template / bootstrap state. # workspace_changes.
if changes: if changes:
filtered = {} filtered = {}
for type_name, workspaces in changes.items(): for type_name, workspaces in changes.items():
@ -165,16 +302,20 @@ class Processor(AsyncProcessor):
resp = ConfigPush( resp = ConfigPush(
version = version, version = version,
changes = changes or {}, changes = changes or {},
workspace_changes = workspace_changes,
) )
await self.config_push_producer.send(resp) await self.config_push_producer.send(resp)
logger.info( logger.info(
f"Pushed config poke version {version}, " f"Pushed config poke version {version}, "
f"changes={resp.changes}" f"changes={resp.changes}, "
f"workspace_changes={resp.workspace_changes}"
) )
async def on_config_request(self, msg, consumer, flow): async def on_workspace_config_request(
self, msg, consumer, flow, *, workspace
):
try: try:
@ -183,16 +324,51 @@ class Processor(AsyncProcessor):
# Sender-produced ID # Sender-produced ID
id = msg.properties()["id"] id = msg.properties()["id"]
logger.debug(f"Handling config request {id}...") logger.debug(
f"Handling workspace config request {id} "
f"workspace={workspace}..."
)
resp = await self.config.handle(v) producer = self.workspace_consumers[workspace]["response"]
resp = await self.config.handle_workspace(v, workspace)
await producer.send(
resp, properties={"id": id}
)
except Exception as e:
resp = ConfigResponse(
error=Error(
type = "config-error",
message = str(e),
),
)
await producer.send(
resp, properties={"id": id}
)
async def on_system_config_request(self, msg, consumer, flow):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.debug(f"Handling system config request {id}...")
resp = await self.config.handle_system(v)
await self.config_response_producer.send( await self.config_response_producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
except Exception as e: except Exception as e:
resp = ConfigResponse( resp = ConfigResponse(
error=Error( error=Error(
type = "config-error", type = "config-error",
@ -228,4 +404,3 @@ class Processor(AsyncProcessor):
def run(): def run():
Processor.launch(default_ident, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -29,12 +29,12 @@ class KnowledgeManager:
self.background_task = None self.background_task = None
self.flow_config = flow_config self.flow_config = flow_config
async def delete_kg_core(self, request, respond): async def delete_kg_core(self, request, respond, workspace):
logger.info("Deleting knowledge core...") logger.info("Deleting knowledge core...")
await self.table_store.delete_kg_core( await self.table_store.delete_kg_core(
request.workspace, request.id workspace, request.id
) )
await respond( await respond(
@ -47,7 +47,7 @@ class KnowledgeManager:
) )
) )
async def get_kg_core(self, request, respond): async def get_kg_core(self, request, respond, workspace):
logger.info("Getting knowledge core...") logger.info("Getting knowledge core...")
@ -62,9 +62,8 @@ class KnowledgeManager:
) )
) )
# Remove doc table row
await self.table_store.get_triples( await self.table_store.get_triples(
request.workspace, workspace,
request.id, request.id,
publish_triples, publish_triples,
) )
@ -80,9 +79,8 @@ class KnowledgeManager:
) )
) )
# Remove doc table row
await self.table_store.get_graph_embeddings( await self.table_store.get_graph_embeddings(
request.workspace, workspace,
request.id, request.id,
publish_ge, publish_ge,
) )
@ -99,9 +97,9 @@ class KnowledgeManager:
) )
) )
async def list_kg_cores(self, request, respond): async def list_kg_cores(self, request, respond, workspace):
ids = await self.table_store.list_kg_cores(request.workspace) ids = await self.table_store.list_kg_cores(workspace)
await respond( await respond(
KnowledgeResponse( KnowledgeResponse(
@ -113,9 +111,7 @@ class KnowledgeManager:
) )
) )
async def put_kg_core(self, request, respond): async def put_kg_core(self, request, respond, workspace):
workspace = request.workspace
if request.triples: if request.triples:
await self.table_store.add_triples(workspace, request.triples) await self.table_store.add_triples(workspace, request.triples)
@ -135,20 +131,18 @@ class KnowledgeManager:
) )
) )
async def load_kg_core(self, request, respond): async def load_kg_core(self, request, respond, workspace):
if self.background_task is None: if self.background_task is None:
self.background_task = asyncio.create_task( self.background_task = asyncio.create_task(
self.core_loader() self.core_loader()
) )
# Wait for it to start (yuck)
# await asyncio.sleep(0.5)
await self.loader_queue.put((request, respond)) await self.loader_queue.put((request, respond, workspace))
# Not sending a response, the loader thread can do that # Not sending a response, the loader thread can do that
async def unload_kg_core(self, request, respond): async def unload_kg_core(self, request, respond, workspace):
await respond( await respond(
KnowledgeResponse( KnowledgeResponse(
@ -169,7 +163,7 @@ class KnowledgeManager:
while True: while True:
logger.debug("Waiting for next load...") logger.debug("Waiting for next load...")
request, respond = await self.loader_queue.get() request, respond, workspace = await self.loader_queue.get()
logger.info(f"Loading knowledge: {request.id}") logger.info(f"Loading knowledge: {request.id}")
@ -181,7 +175,6 @@ class KnowledgeManager:
if request.flow is None: if request.flow is None:
raise RuntimeError("Flow ID must be specified") raise RuntimeError("Flow ID must be specified")
workspace = request.workspace
ws_flows = self.flow_config.flows.get(workspace, {}) ws_flows = self.flow_config.flows.get(workspace, {})
if request.flow not in ws_flows: if request.flow not in ws_flows:
raise RuntimeError( raise RuntimeError(
@ -263,9 +256,8 @@ class KnowledgeManager:
logger.debug("Publishing triples...") logger.debug("Publishing triples...")
# Remove doc table row
await self.table_store.get_triples( await self.table_store.get_triples(
request.workspace, workspace,
request.id, request.id,
publish_triples, publish_triples,
) )
@ -278,9 +270,8 @@ class KnowledgeManager:
logger.debug("Publishing graph embeddings...") logger.debug("Publishing graph embeddings...")
# Remove doc table row
await self.table_store.get_graph_embeddings( await self.table_store.get_graph_embeddings(
request.workspace, workspace,
request.id, request.id,
publish_ge, publish_ge,
) )

View file

@ -9,7 +9,7 @@ import base64
import json import json
import logging import logging
from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
from .. base import ConsumerMetrics, ProducerMetrics from .. base import ConsumerMetrics, ProducerMetrics
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
@ -33,17 +33,22 @@ default_knowledge_response_queue = knowledge_response_queue
default_cassandra_host = "cassandra" default_cassandra_host = "cassandra"
class Processor(AsyncProcessor):
def workspace_queue(base_queue, workspace):
return f"{base_queue}:{workspace}"
class Processor(WorkspaceProcessor):
def __init__(self, **params): def __init__(self, **params):
id = params.get("id") id = params.get("id")
knowledge_request_queue = params.get( self.knowledge_request_queue_base = params.get(
"knowledge_request_queue", default_knowledge_request_queue "knowledge_request_queue", default_knowledge_request_queue
) )
knowledge_response_queue = params.get( self.knowledge_response_queue_base = params.get(
"knowledge_response_queue", default_knowledge_response_queue "knowledge_response_queue", default_knowledge_response_queue
) )
@ -51,78 +56,106 @@ class Processor(AsyncProcessor):
cassandra_username = params.get("cassandra_username") cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password") cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
hosts, username, password, keyspace = resolve_cassandra_config(
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password, password=cassandra_password,
default_keyspace="knowledge" default_keyspace="knowledge"
) )
# Store resolved configuration
self.cassandra_host = hosts self.cassandra_host = hosts
self.cassandra_username = username self.cassandra_username = username
self.cassandra_password = password self.cassandra_password = password
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"knowledge_request_queue": knowledge_request_queue, "knowledge_request_queue": self.knowledge_request_queue_base,
"knowledge_response_queue": knowledge_response_queue, "knowledge_response_queue": self.knowledge_response_queue_base,
"cassandra_host": self.cassandra_host, "cassandra_host": self.cassandra_host,
"cassandra_username": self.cassandra_username, "cassandra_username": self.cassandra_username,
"cassandra_password": self.cassandra_password, "cassandra_password": self.cassandra_password,
} }
) )
knowledge_request_metrics = ConsumerMetrics(
processor = self.id, flow = None, name = "knowledge-request"
)
knowledge_response_metrics = ProducerMetrics(
processor = self.id, flow = None, name = "knowledge-response"
)
self.knowledge_request_topic = knowledge_request_queue
self.knowledge_request_subscriber = id
self.knowledge_request_consumer = Consumer(
taskgroup = self.taskgroup,
backend = self.pubsub,
flow = None,
topic = knowledge_request_queue,
subscriber = id,
schema = KnowledgeRequest,
handler = self.on_knowledge_request,
metrics = knowledge_request_metrics,
)
self.knowledge_response_producer = Producer(
backend = self.pubsub,
topic = knowledge_response_queue,
schema = KnowledgeResponse,
metrics = knowledge_response_metrics,
)
self.knowledge = KnowledgeManager( self.knowledge = KnowledgeManager(
cassandra_host = self.cassandra_host, cassandra_host = self.cassandra_host,
cassandra_username = self.cassandra_username, cassandra_username = self.cassandra_username,
cassandra_password = self.cassandra_password, cassandra_password = self.cassandra_password,
keyspace = keyspace, keyspace = keyspace,
flow_config = self, flow_config = self,
replication_factor = replication_factor,
) )
self.register_config_handler(self.on_knowledge_config, types=["flow"]) self.register_config_handler(self.on_knowledge_config, types=["flow"])
self.flows = {} self.flows = {}
self.workspace_consumers = {}
logger.info("Knowledge service initialized") logger.info("Knowledge service initialized")
async def on_workspace_created(self, workspace):
if workspace in self.workspace_consumers:
return
req_queue = workspace_queue(
self.knowledge_request_queue_base, workspace,
)
resp_queue = workspace_queue(
self.knowledge_response_queue_base, workspace,
)
await self.pubsub.ensure_topic(req_queue)
await self.pubsub.ensure_topic(resp_queue)
response_producer = Producer(
backend=self.pubsub,
topic=resp_queue,
schema=KnowledgeResponse,
metrics=ProducerMetrics(
processor=self.id, flow=None,
name=f"knowledge-response-{workspace}",
),
)
consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=req_queue,
subscriber=self.id,
schema=KnowledgeRequest,
handler=partial(
self.on_knowledge_request, workspace=workspace,
),
metrics=ConsumerMetrics(
processor=self.id, flow=None,
name=f"knowledge-request-{workspace}",
),
)
await response_producer.start()
await consumer.start()
self.workspace_consumers[workspace] = {
"consumer": consumer,
"response": response_producer,
}
logger.info(f"Subscribed to workspace queue: {workspace}")
async def on_workspace_deleted(self, workspace):
clients = self.workspace_consumers.pop(workspace, None)
if clients:
for client in clients.values():
await client.stop()
logger.info(f"Unsubscribed from workspace queue: {workspace}")
async def start(self): async def start(self):
await self.pubsub.ensure_topic(self.knowledge_request_topic)
await super(Processor, self).start() await super(Processor, self).start()
await self.knowledge_request_consumer.start()
await self.knowledge_response_producer.start()
async def on_knowledge_config(self, workspace, config, version): async def on_knowledge_config(self, workspace, config, version):
@ -140,7 +173,7 @@ class Processor(AsyncProcessor):
logger.debug(f"Flows for {workspace}: {self.flows[workspace]}") logger.debug(f"Flows for {workspace}: {self.flows[workspace]}")
async def process_request(self, v, id): async def process_request(self, v, id, workspace, producer):
if v.operation is None: if v.operation is None:
raise RequestError("Null operation") raise RequestError("Null operation")
@ -160,12 +193,12 @@ class Processor(AsyncProcessor):
raise RequestError(f"Invalid operation: {v.operation}") raise RequestError(f"Invalid operation: {v.operation}")
async def respond(x): async def respond(x):
await self.knowledge_response_producer.send( await producer.send(
x, { "id": id } x, { "id": id }
) )
return await impls[v.operation](v, respond) return await impls[v.operation](v, respond, workspace)
async def on_knowledge_request(self, msg, consumer, flow): async def on_knowledge_request(self, msg, consumer, flow, *, workspace):
v = msg.value() v = msg.value()
@ -175,11 +208,13 @@ class Processor(AsyncProcessor):
logger.info(f"Handling knowledge input {id}...") logger.info(f"Handling knowledge input {id}...")
producer = self.workspace_consumers[workspace]["response"]
try: try:
# We don't send a response back here, the processing # We don't send a response back here, the processing
# implementation sends whatever it needs to send. # implementation sends whatever it needs to send.
await self.process_request(v, id) await self.process_request(v, id, workspace, producer)
return return
@ -191,7 +226,7 @@ class Processor(AsyncProcessor):
) )
) )
await self.knowledge_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -204,7 +239,7 @@ class Processor(AsyncProcessor):
) )
) )
await self.knowledge_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -215,7 +250,7 @@ class Processor(AsyncProcessor):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
AsyncProcessor.add_args(parser) WorkspaceProcessor.add_args(parser)
parser.add_argument( parser.add_argument(
'--knowledge-request-queue', '--knowledge-request-queue',

View file

@ -16,9 +16,8 @@ import os
from mistralai import Mistral from mistralai import Mistral
from ... schema import Document, TextDocument, Metadata from ... schema import Document, TextDocument, Metadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec
from ... provenance import ( from ... provenance import (
document_uri, page_uri as make_page_uri, derived_entity_triples, document_uri, page_uri as make_page_uri, derived_entity_triples,
@ -36,9 +35,6 @@ COMPONENT_VERSION = "1.0.0"
default_ident = "document-decoder" default_ident = "document-decoder"
default_api_key = os.getenv("MISTRAL_TOKEN") default_api_key = os.getenv("MISTRAL_TOKEN")
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
pages_per_chunk = 5 pages_per_chunk = 5
def chunks(lst, n): def chunks(lst, n):
@ -98,9 +94,8 @@ class Processor(FlowProcessor):
) )
) )
# Librarian client self.register_specification(
self.librarian = LibrarianClient( LibrarianSpec()
id=id, backend=self.pubsub, taskgroup=self.taskgroup,
) )
if api_key is None: if api_key is None:
@ -113,10 +108,6 @@ class Processor(FlowProcessor):
logger.info("Mistral OCR processor initialized") logger.info("Mistral OCR processor initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian.start()
def ocr(self, blob): def ocr(self, blob):
""" """
Run Mistral OCR on a PDF blob, returning per-page markdown strings. Run Mistral OCR on a PDF blob, returning per-page markdown strings.
@ -198,9 +189,9 @@ class Processor(FlowProcessor):
# Check MIME type if fetching from librarian # Check MIME type if fetching from librarian
if v.document_id: if v.document_id:
doc_meta = await self.librarian.fetch_document_metadata( doc_meta = await flow.librarian.fetch_document_metadata(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error( logger.error(
@ -213,9 +204,9 @@ class Processor(FlowProcessor):
# Get PDF content - fetch from librarian or use inline data # Get PDF content - fetch from librarian or use inline data
if v.document_id: if v.document_id:
logger.info(f"Fetching document {v.document_id} from librarian...") logger.info(f"Fetching document {v.document_id} from librarian...")
content = await self.librarian.fetch_document_content( content = await flow.librarian.fetch_document_content(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
if isinstance(content, str): if isinstance(content, str):
content = content.encode('utf-8') content = content.encode('utf-8')
@ -240,10 +231,10 @@ class Processor(FlowProcessor):
page_content = markdown.encode("utf-8") page_content = markdown.encode("utf-8")
# Save page as child document in librarian # Save page as child document in librarian
await self.librarian.save_child_document( await flow.librarian.save_child_document(
doc_id=page_doc_id, doc_id=page_doc_id,
parent_id=source_doc_id, parent_id=source_doc_id,
workspace=flow.workspace,
content=page_content, content=page_content,
document_type="page", document_type="page",
title=f"Page {page_num}", title=f"Page {page_num}",
@ -297,18 +288,6 @@ class Processor(FlowProcessor):
help=f'Mistral API Key' help=f'Mistral API Key'
) )
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue (default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue (default: {default_librarian_response_queue})',
)
def run(): def run():
Processor.launch(default_ident, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -12,9 +12,8 @@ import tempfile
import base64 import base64
import logging import logging
from ... schema import Document, TextDocument, Metadata from ... schema import Document, TextDocument, Metadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec
PyPDFLoader = None PyPDFLoader = None
@ -32,9 +31,6 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder" default_ident = "document-decoder"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor): class Processor(FlowProcessor):
@ -70,17 +66,12 @@ class Processor(FlowProcessor):
) )
) )
# Librarian client self.register_specification(
self.librarian = LibrarianClient( LibrarianSpec()
id=id, backend=self.pubsub, taskgroup=self.taskgroup,
) )
logger.info("PDF decoder initialized") logger.info("PDF decoder initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian.start()
async def on_message(self, msg, consumer, flow): async def on_message(self, msg, consumer, flow):
logger.debug("PDF message received") logger.debug("PDF message received")
@ -91,9 +82,9 @@ class Processor(FlowProcessor):
# Check MIME type if fetching from librarian # Check MIME type if fetching from librarian
if v.document_id: if v.document_id:
doc_meta = await self.librarian.fetch_document_metadata( doc_meta = await flow.librarian.fetch_document_metadata(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error( logger.error(
@ -112,9 +103,9 @@ class Processor(FlowProcessor):
logger.info(f"Fetching document {v.document_id} from librarian...") logger.info(f"Fetching document {v.document_id} from librarian...")
fp.close() fp.close()
content = await self.librarian.fetch_document_content( content = await flow.librarian.fetch_document_content(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
# Content is base64 encoded # Content is base64 encoded
@ -154,10 +145,10 @@ class Processor(FlowProcessor):
page_content = page.page_content.encode("utf-8") page_content = page.page_content.encode("utf-8")
# Save page as child document in librarian # Save page as child document in librarian
await self.librarian.save_child_document( await flow.librarian.save_child_document(
doc_id=page_doc_id, doc_id=page_doc_id,
parent_id=source_doc_id, parent_id=source_doc_id,
workspace=flow.workspace,
content=page_content, content=page_content,
document_type="page", document_type="page",
title=f"Page {page_num}", title=f"Page {page_num}",
@ -210,18 +201,6 @@ class Processor(FlowProcessor):
def add_args(parser): def add_args(parser):
FlowProcessor.add_args(parser) FlowProcessor.add_args(parser)
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue (default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue (default: {default_librarian_response_queue})',
)
def run(): def run():
Processor.launch(default_ident, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -118,10 +118,10 @@ class FlowConfig:
return resolved return resolved
async def handle_list_blueprints(self, msg): async def handle_list_blueprints(self, msg, workspace):
names = list(await self.config.keys( names = list(await self.config.keys(
msg.workspace, "flow-blueprint" workspace, "flow-blueprint"
)) ))
return FlowResponse( return FlowResponse(
@ -129,19 +129,19 @@ class FlowConfig:
blueprint_names = names, blueprint_names = names,
) )
async def handle_get_blueprint(self, msg): async def handle_get_blueprint(self, msg, workspace):
return FlowResponse( return FlowResponse(
error = None, error = None,
blueprint_definition = await self.config.get( blueprint_definition = await self.config.get(
msg.workspace, "flow-blueprint", msg.blueprint_name workspace, "flow-blueprint", msg.blueprint_name
), ),
) )
async def handle_put_blueprint(self, msg): async def handle_put_blueprint(self, msg, workspace):
await self.config.put( await self.config.put(
msg.workspace, "flow-blueprint", workspace, "flow-blueprint",
msg.blueprint_name, msg.blueprint_definition msg.blueprint_name, msg.blueprint_definition
) )
@ -149,31 +149,31 @@ class FlowConfig:
error = None, error = None,
) )
async def handle_delete_blueprint(self, msg): async def handle_delete_blueprint(self, msg, workspace):
logger.debug(f"Flow config message: {msg}") logger.debug(f"Flow config message: {msg}")
await self.config.delete( await self.config.delete(
msg.workspace, "flow-blueprint", msg.blueprint_name workspace, "flow-blueprint", msg.blueprint_name
) )
return FlowResponse( return FlowResponse(
error = None, error = None,
) )
async def handle_list_flows(self, msg): async def handle_list_flows(self, msg, workspace):
names = list(await self.config.keys(msg.workspace, "flow")) names = list(await self.config.keys(workspace, "flow"))
return FlowResponse( return FlowResponse(
error = None, error = None,
flow_ids = names, flow_ids = names,
) )
async def handle_get_flow(self, msg): async def handle_get_flow(self, msg, workspace):
flow_data = await self.config.get( flow_data = await self.config.get(
msg.workspace, "flow", msg.flow_id workspace, "flow", msg.flow_id
) )
flow = json.loads(flow_data) flow = json.loads(flow_data)
@ -184,9 +184,7 @@ class FlowConfig:
parameters = flow.get("parameters", {}), parameters = flow.get("parameters", {}),
) )
async def handle_start_flow(self, msg): async def handle_start_flow(self, msg, workspace):
workspace = msg.workspace
if msg.blueprint_name is None: if msg.blueprint_name is None:
raise RuntimeError("No blueprint name") raise RuntimeError("No blueprint name")
@ -222,7 +220,7 @@ class FlowConfig:
logger.debug(f"Resolved parameters (with defaults): {parameters}") logger.debug(f"Resolved parameters (with defaults): {parameters}")
# Apply parameter substitution to template replacement function. # Apply parameter substitution to template replacement function.
# {workspace} is substituted from msg.workspace to isolate # {workspace} is substituted from workspace to isolate
# queue names across workspaces. # queue names across workspaces.
def repl_template_with_params(tmp): def repl_template_with_params(tmp):
@ -548,9 +546,7 @@ class FlowConfig:
f"attempts: {topic}" f"attempts: {topic}"
) )
async def handle_stop_flow(self, msg): async def handle_stop_flow(self, msg, workspace):
workspace = msg.workspace
if msg.flow_id is None: if msg.flow_id is None:
raise RuntimeError("No flow ID") raise RuntimeError("No flow ID")
@ -641,37 +637,29 @@ class FlowConfig:
error = None, error = None,
) )
async def handle(self, msg): async def handle(self, msg, workspace):
logger.debug( logger.debug(
f"Handling flow message: {msg.operation} " f"Handling flow message: {msg.operation} "
f"workspace={msg.workspace}" f"workspace={workspace}"
) )
if not msg.workspace:
return FlowResponse(
error=Error(
type="bad-request",
message="Workspace is required",
),
)
if msg.operation == "list-blueprints": if msg.operation == "list-blueprints":
resp = await self.handle_list_blueprints(msg) resp = await self.handle_list_blueprints(msg, workspace)
elif msg.operation == "get-blueprint": elif msg.operation == "get-blueprint":
resp = await self.handle_get_blueprint(msg) resp = await self.handle_get_blueprint(msg, workspace)
elif msg.operation == "put-blueprint": elif msg.operation == "put-blueprint":
resp = await self.handle_put_blueprint(msg) resp = await self.handle_put_blueprint(msg, workspace)
elif msg.operation == "delete-blueprint": elif msg.operation == "delete-blueprint":
resp = await self.handle_delete_blueprint(msg) resp = await self.handle_delete_blueprint(msg, workspace)
elif msg.operation == "list-flows": elif msg.operation == "list-flows":
resp = await self.handle_list_flows(msg) resp = await self.handle_list_flows(msg, workspace)
elif msg.operation == "get-flow": elif msg.operation == "get-flow":
resp = await self.handle_get_flow(msg) resp = await self.handle_get_flow(msg, workspace)
elif msg.operation == "start-flow": elif msg.operation == "start-flow":
resp = await self.handle_start_flow(msg) resp = await self.handle_start_flow(msg, workspace)
elif msg.operation == "stop-flow": elif msg.operation == "stop-flow":
resp = await self.handle_stop_flow(msg) resp = await self.handle_stop_flow(msg, workspace)
else: else:
resp = FlowResponse( resp = FlowResponse(

View file

@ -4,6 +4,7 @@ Flow service. Manages flow lifecycle — starting and stopping flows
by coordinating with the config service via pub/sub. by coordinating with the config service via pub/sub.
""" """
from functools import partial
import logging import logging
import uuid import uuid
@ -14,7 +15,7 @@ from trustgraph.schema import flow_request_queue, flow_response_queue
from trustgraph.schema import ConfigRequest, ConfigResponse from trustgraph.schema import ConfigRequest, ConfigResponse
from trustgraph.schema import config_request_queue, config_response_queue from trustgraph.schema import config_request_queue, config_response_queue
from trustgraph.base import AsyncProcessor, Consumer, Producer from trustgraph.base import WorkspaceProcessor, Consumer, Producer
from trustgraph.base import ConsumerMetrics, ProducerMetrics, SubscriberMetrics from trustgraph.base import ConsumerMetrics, ProducerMetrics, SubscriberMetrics
from trustgraph.base import ConfigClient from trustgraph.base import ConfigClient
@ -29,14 +30,18 @@ default_flow_request_queue = flow_request_queue
default_flow_response_queue = flow_response_queue default_flow_response_queue = flow_response_queue
class Processor(AsyncProcessor): def workspace_queue(base_queue, workspace):
return f"{base_queue}:{workspace}"
class Processor(WorkspaceProcessor):
def __init__(self, **params): def __init__(self, **params):
flow_request_queue = params.get( self.flow_request_queue_base = params.get(
"flow_request_queue", default_flow_request_queue "flow_request_queue", default_flow_request_queue
) )
flow_response_queue = params.get( self.flow_response_queue_base = params.get(
"flow_response_queue", default_flow_response_queue "flow_response_queue", default_flow_response_queue
) )
@ -49,34 +54,6 @@ class Processor(AsyncProcessor):
} }
) )
flow_request_metrics = ConsumerMetrics(
processor = self.id, flow = None, name = "flow-request"
)
flow_response_metrics = ProducerMetrics(
processor = self.id, flow = None, name = "flow-response"
)
self.flow_request_topic = flow_request_queue
self.flow_request_subscriber = id
self.flow_request_consumer = Consumer(
taskgroup = self.taskgroup,
backend = self.pubsub,
flow = None,
topic = flow_request_queue,
subscriber = id,
schema = FlowRequest,
handler = self.on_flow_request,
metrics = flow_request_metrics,
)
self.flow_response_producer = Producer(
backend = self.pubsub,
topic = flow_response_queue,
schema = FlowResponse,
metrics = flow_response_metrics,
)
config_req_metrics = ProducerMetrics( config_req_metrics = ProducerMetrics(
processor=self.id, flow=None, name="config-request", processor=self.id, flow=None, name="config-request",
) )
@ -84,13 +61,6 @@ class Processor(AsyncProcessor):
processor=self.id, flow=None, name="config-response", processor=self.id, flow=None, name="config-response",
) )
# Unique subscription suffix per process instance. Pulsar's
# exclusive subscriptions reject a second consumer on the same
# (topic, subscription-name) — so a deterministic name here
# collides with its own ghost when the supervisor restarts the
# process before Pulsar has timed out the previous session
# (ConsumerBusy). Matches the uuid convention used elsewhere
# (gateway/config/receiver.py, AsyncProcessor._create_config_client).
config_rr_id = str(uuid.uuid4()) config_rr_id = str(uuid.uuid4())
self.config_client = ConfigClient( self.config_client = ConfigClient(
backend=self.pubsub, backend=self.pubsub,
@ -106,21 +76,78 @@ class Processor(AsyncProcessor):
self.flow = FlowConfig(self.config_client, self.pubsub) self.flow = FlowConfig(self.config_client, self.pubsub)
self.workspace_consumers = {}
logger.info("Flow service initialized") logger.info("Flow service initialized")
async def on_workspace_created(self, workspace):
if workspace in self.workspace_consumers:
return
req_queue = workspace_queue(
self.flow_request_queue_base, workspace,
)
resp_queue = workspace_queue(
self.flow_response_queue_base, workspace,
)
await self.pubsub.ensure_topic(req_queue)
await self.pubsub.ensure_topic(resp_queue)
response_producer = Producer(
backend=self.pubsub,
topic=resp_queue,
schema=FlowResponse,
metrics=ProducerMetrics(
processor=self.id, flow=None,
name=f"flow-response-{workspace}",
),
)
consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=req_queue,
subscriber=self.id,
schema=FlowRequest,
handler=partial(
self.on_flow_request, workspace=workspace,
),
metrics=ConsumerMetrics(
processor=self.id, flow=None,
name=f"flow-request-{workspace}",
),
)
await response_producer.start()
await consumer.start()
self.workspace_consumers[workspace] = {
"consumer": consumer,
"response": response_producer,
}
logger.info(f"Subscribed to workspace queue: {workspace}")
async def on_workspace_deleted(self, workspace):
clients = self.workspace_consumers.pop(workspace, None)
if clients:
for client in clients.values():
await client.stop()
logger.info(f"Unsubscribed from workspace queue: {workspace}")
async def start(self): async def start(self):
await self.pubsub.ensure_topic(self.flow_request_topic) await super(Processor, self).start()
await self.config_client.start() await self.config_client.start()
# Discover workspaces with existing flow config and ensure
# their topics exist before we start accepting requests.
workspaces = await self.config_client.workspaces_for_type("flow") workspaces = await self.config_client.workspaces_for_type("flow")
await self.flow.ensure_existing_flow_topics(workspaces) await self.flow.ensure_existing_flow_topics(workspaces)
await self.flow_request_consumer.start() async def on_flow_request(self, msg, consumer, flow, *, workspace):
async def on_flow_request(self, msg, consumer, flow):
try: try:
@ -131,9 +158,11 @@ class Processor(AsyncProcessor):
logger.debug(f"Handling flow request {id}...") logger.debug(f"Handling flow request {id}...")
resp = await self.flow.handle(v) producer = self.workspace_consumers[workspace]["response"]
await self.flow_response_producer.send( resp = await self.flow.handle(v, workspace)
await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -148,14 +177,14 @@ class Processor(AsyncProcessor):
), ),
) )
await self.flow_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
AsyncProcessor.add_args(parser) WorkspaceProcessor.add_args(parser)
parser.add_argument( parser.add_argument(
'--flow-request-queue', '--flow-request-queue',

View file

@ -141,6 +141,12 @@ class IamAuth:
self._authz_cache: dict[str, tuple[bool, float]] = {} self._authz_cache: dict[str, tuple[bool, float]] = {}
self._authz_cache_lock = asyncio.Lock() self._authz_cache_lock = asyncio.Lock()
# Known workspaces, maintained by the config receiver.
# enforce_workspace checks this set to reject requests for
# non-existent workspaces before routing to a queue that
# has no consumer.
self.known_workspaces: set[str] = set()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Short-lived client helper. Mirrors the pattern used by the # Short-lived client helper. Mirrors the pattern used by the
# bootstrap framework and AsyncProcessor: a fresh uuid suffix per # bootstrap framework and AsyncProcessor: a fresh uuid suffix per

View file

@ -67,12 +67,22 @@ async def enforce(request, auth, capability):
return identity return identity
def workspace_not_found():
return web.HTTPNotFound(
text='{"error":"workspace not found"}',
content_type="application/json",
)
async def enforce_workspace(data, identity, auth, capability=None): async def enforce_workspace(data, identity, auth, capability=None):
"""Default-fill the workspace on a request body and (optionally) """Default-fill the workspace on a request body and (optionally)
authorise the caller for ``capability`` against that workspace. authorise the caller for ``capability`` against that workspace.
- Target workspace = ``data["workspace"]`` if supplied, else the - Target workspace = ``data["workspace"]`` if supplied, else the
caller's bound workspace. caller's bound workspace.
- Rejects the request if the resolved workspace is not in
``auth.known_workspaces`` (prevents routing to a queue with
no consumer).
- On success, ``data["workspace"]`` is overwritten with the - On success, ``data["workspace"]`` is overwritten with the
resolved value so downstream code sees a single canonical resolved value so downstream code sees a single canonical
address. address.
@ -92,6 +102,9 @@ async def enforce_workspace(data, identity, auth, capability=None):
target = requested or identity.workspace target = requested or identity.workspace
data["workspace"] = target data["workspace"] = target
if target not in auth.known_workspaces:
raise workspace_not_found()
if capability is not None: if capability is not None:
await auth.authorise( await auth.authorise(
identity, capability, {"workspace": target}, {}, identity, capability, {"workspace": target}, {},

View file

@ -24,9 +24,10 @@ logger.setLevel(logging.INFO)
class ConfigReceiver: class ConfigReceiver:
def __init__(self, backend): def __init__(self, backend, auth=None):
self.backend = backend self.backend = backend
self.auth = auth
self.flow_handlers = [] self.flow_handlers = []
@ -54,6 +55,15 @@ class ConfigReceiver:
) )
return return
# Track workspace lifecycle
if v.workspace_changes and self.auth:
for ws in (v.workspace_changes.created or []):
self.auth.known_workspaces.add(ws)
logger.info(f"Workspace registered: {ws}")
for ws in (v.workspace_changes.deleted or []):
self.auth.known_workspaces.discard(ws)
logger.info(f"Workspace deregistered: {ws}")
# Gateway cares about flow config — check if any flow # Gateway cares about flow config — check if any flow
# types changed in any workspace # types changed in any workspace
flow_workspaces = changes.get("flow", []) flow_workspaces = changes.get("flow", [])
@ -195,6 +205,33 @@ class ConfigReceiver:
try: try:
await client.start() await client.start()
# Discover all known workspaces
ws_resp = await client.request(
ConfigRequest(
operation="getvalues",
workspace="__workspaces__",
type="workspace",
),
timeout=10,
)
if ws_resp.error:
raise RuntimeError(
f"Workspace discovery error: "
f"{ws_resp.error.message}"
)
discovered = {
v.key for v in ws_resp.values if v.key
}
if self.auth:
self.auth.known_workspaces = discovered
logger.info(
f"Known workspaces: {discovered}"
)
# Discover workspaces that have any flow config # Discover workspaces that have any flow config
resp = await client.request( resp = await client.request(
ConfigRequest( ConfigRequest(

View file

@ -7,6 +7,12 @@ import logging
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from ... schema import flow_request_queue, flow_response_queue
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import knowledge_request_queue, knowledge_response_queue
from ... schema import collection_request_queue, collection_response_queue
from ... schema import config_request_queue, config_response_queue
from . config import ConfigRequestor from . config import ConfigRequestor
from . flow import FlowRequestor from . flow import FlowRequestor
from . iam import IamRequestor from . iam import IamRequestor
@ -70,15 +76,36 @@ request_response_dispatchers = {
"sparql": SparqlQueryRequestor, "sparql": SparqlQueryRequestor,
} }
global_dispatchers = { system_dispatchers = {
"iam": IamRequestor,
}
workspace_dispatchers = {
"config": ConfigRequestor, "config": ConfigRequestor,
"flow": FlowRequestor, "flow": FlowRequestor,
"iam": IamRequestor,
"librarian": LibrarianRequestor, "librarian": LibrarianRequestor,
"knowledge": KnowledgeRequestor, "knowledge": KnowledgeRequestor,
"collection-management": CollectionManagementRequestor, "collection-management": CollectionManagementRequestor,
} }
workspace_default_request_queues = {
"config": config_request_queue,
"flow": flow_request_queue,
"librarian": librarian_request_queue,
"knowledge": knowledge_request_queue,
"collection-management": collection_request_queue,
}
workspace_default_response_queues = {
"config": config_response_queue,
"flow": flow_response_queue,
"librarian": librarian_response_queue,
"knowledge": knowledge_response_queue,
"collection-management": collection_response_queue,
}
global_dispatchers = {**system_dispatchers, **workspace_dispatchers}
sender_dispatchers = { sender_dispatchers = {
"text-load": TextLoad, "text-load": TextLoad,
"document-load": DocumentLoad, "document-load": DocumentLoad,
@ -219,11 +246,24 @@ class DispatcherManager:
async def process_global_service(self, data, responder, params): async def process_global_service(self, data, responder, params):
kind = params.get("kind") kind = params.get("kind")
return await self.invoke_global_service(data, responder, kind) workspace = params.get("workspace")
if not workspace and isinstance(data, dict):
workspace = data.get("workspace")
return await self.invoke_global_service(
data, responder, kind, workspace=workspace,
)
async def invoke_global_service(self, data, responder, kind): async def invoke_global_service(self, data, responder, kind,
workspace=None):
key = (None, kind) if kind in workspace_dispatchers:
if not workspace:
raise RuntimeError(
f"Workspace is required for {kind}"
)
key = (workspace, kind)
else:
key = (None, kind)
if key not in self.dispatchers: if key not in self.dispatchers:
async with self.dispatcher_lock: async with self.dispatcher_lock:
@ -234,11 +274,26 @@ class DispatcherManager:
request_queue = self.queue_overrides[kind].get("request") request_queue = self.queue_overrides[kind].get("request")
response_queue = self.queue_overrides[kind].get("response") response_queue = self.queue_overrides[kind].get("response")
if kind in workspace_dispatchers and workspace:
base_req_queue = (
request_queue
or workspace_default_request_queues[kind]
)
request_queue = f"{base_req_queue}:{workspace}"
base_resp_queue = (
response_queue
or workspace_default_response_queues[kind]
)
response_queue = f"{base_resp_queue}:{workspace}"
consumer_name = f"{self.prefix}-{kind}-{workspace}"
else:
consumer_name = f"{self.prefix}-{kind}-request"
dispatcher = global_dispatchers[kind]( dispatcher = global_dispatchers[kind](
backend = self.backend, backend = self.backend,
timeout = 120, timeout = 120,
consumer = f"{self.prefix}-{kind}-request", consumer = consumer_name,
subscriber = f"{self.prefix}-{kind}-request", subscriber = consumer_name,
request_queue = request_queue, request_queue = request_queue,
response_queue = response_queue, response_queue = response_queue,
) )

View file

@ -190,6 +190,16 @@ class Mux:
await self.auth.authorise( await self.auth.authorise(
self.identity, op.capability, resource, parameters, self.identity, op.capability, resource, parameters,
) )
except _web.HTTPNotFound:
await self.ws.send_json({
"id": request_id,
"error": {
"message": "workspace not found",
"type": "workspace-not-found",
},
"complete": True,
})
return
except _web.HTTPForbidden: except _web.HTTPForbidden:
await self.ws.send_json({ await self.ws.send_json({
"id": request_id, "id": request_id,
@ -310,7 +320,7 @@ class Mux:
else: else:
await self.dispatcher_manager.invoke_global_service( await self.dispatcher_manager.invoke_global_service(
request, responder, svc request, responder, svc, workspace=workspace,
) )
except Exception as e: except Exception as e:

View file

@ -116,9 +116,6 @@ def serialize_document_metadata(message):
if message.metadata: if message.metadata:
ret["metadata"] = serialize_subgraph(message.metadata) ret["metadata"] = serialize_subgraph(message.metadata)
if message.workspace:
ret["workspace"] = message.workspace
if message.tags is not None: if message.tags is not None:
ret["tags"] = message.tags ret["tags"] = message.tags
@ -140,9 +137,6 @@ def serialize_processing_metadata(message):
if message.flow: if message.flow:
ret["flow"] = message.flow ret["flow"] = message.flow
if message.workspace:
ret["workspace"] = message.workspace
if message.collection: if message.collection:
ret["collection"] = message.collection ret["collection"] = message.collection
@ -160,7 +154,6 @@ def to_document_metadata(x):
title = x.get("title", None), title = x.get("title", None),
comments = x.get("comments", None), comments = x.get("comments", None),
metadata = to_subgraph(x["metadata"]), metadata = to_subgraph(x["metadata"]),
workspace = x.get("workspace", None),
tags = x.get("tags", None), tags = x.get("tags", None),
) )
@ -171,7 +164,6 @@ def to_processing_metadata(x):
document_id = x.get("document-id", None), document_id = x.get("document-id", None),
time = x.get("time", None), time = x.get("time", None),
flow = x.get("flow", None), flow = x.get("flow", None),
workspace = x.get("workspace", None),
collection = x.get("collection", None), collection = x.get("collection", None),
tags = x.get("tags", None), tags = x.get("tags", None),
) )

View file

@ -12,8 +12,8 @@ from . auth_endpoints import AuthEndpoints
from . iam_endpoint import IamEndpoint from . iam_endpoint import IamEndpoint
from . registry_endpoint import RegistryRoutedVariableEndpoint from . registry_endpoint import RegistryRoutedVariableEndpoint
from .. capabilities import PUBLIC, AUTHENTICATED, auth_failure from .. capabilities import PUBLIC, AUTHENTICATED, auth_failure, workspace_not_found
from .. registry import lookup as _registry_lookup, RequestContext from .. registry import lookup as _registry_lookup, RequestContext, ResourceLevel
from .. dispatch.manager import DispatcherManager from .. dispatch.manager import DispatcherManager
@ -77,6 +77,10 @@ class _RoutedVariableEndpoint:
identity, op.capability, resource, parameters, identity, op.capability, resource, parameters,
) )
ws = resource.get("workspace", "")
if ws and ws not in self.auth.known_workspaces:
raise workspace_not_found()
async def responder(x, fin): async def responder(x, fin):
pass pass
@ -140,6 +144,11 @@ class _RoutedSocketEndpoint:
await self.auth.authorise( await self.auth.authorise(
identity, op.capability, resource, parameters, identity, op.capability, resource, parameters,
) )
ws = resource.get("workspace", "")
if ws and ws not in self.auth.known_workspaces:
raise workspace_not_found()
except web.HTTPException as e: except web.HTTPException as e:
return e return e

View file

@ -20,9 +20,9 @@ import logging
from aiohttp import web from aiohttp import web
from .. capabilities import ( from .. capabilities import (
PUBLIC, AUTHENTICATED, auth_failure, PUBLIC, AUTHENTICATED, auth_failure, workspace_not_found,
) )
from .. registry import lookup, RequestContext from .. registry import lookup, RequestContext, ResourceLevel
logger = logging.getLogger("registry-endpoint") logger = logging.getLogger("registry-endpoint")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -107,6 +107,15 @@ class RegistryRoutedVariableEndpoint:
if "workspace" in resource: if "workspace" in resource:
body["workspace"] = resource["workspace"] body["workspace"] = resource["workspace"]
if (
op.resource_level in (
ResourceLevel.WORKSPACE, ResourceLevel.FLOW,
)
and resource.get("workspace")
not in self.auth.known_workspaces
):
raise workspace_not_found()
async def responder(x, fin): async def responder(x, fin):
pass pass

View file

@ -68,7 +68,7 @@ class Api:
id=config.get("id", "api-gateway"), id=config.get("id", "api-gateway"),
) )
self.config_receiver = ConfigReceiver(self.pubsub_backend) self.config_receiver = ConfigReceiver(self.pubsub_backend, auth=self.auth)
# Build queue overrides dictionary from CLI arguments # Build queue overrides dictionary from CLI arguments
queue_overrides = {} queue_overrides = {}

View file

@ -246,6 +246,7 @@ class IamService:
def __init__(self, host, username, password, keyspace, def __init__(self, host, username, password, keyspace,
bootstrap_mode, bootstrap_token=None, bootstrap_mode, bootstrap_token=None,
on_workspace_created=None, on_workspace_deleted=None,
replication_factor=1): replication_factor=1):
self.table_store = IamTableStore( self.table_store = IamTableStore(
host, username, password, keyspace, host, username, password, keyspace,
@ -269,6 +270,12 @@ class IamService:
self.bootstrap_mode = bootstrap_mode self.bootstrap_mode = bootstrap_mode
self.bootstrap_token = bootstrap_token self.bootstrap_token = bootstrap_token
# Callbacks for workspace lifecycle events. Called after the
# workspace is created/deleted in IAM's own store so that the
# processor can announce it via the config service.
self._on_workspace_created = on_workspace_created
self._on_workspace_deleted = on_workspace_deleted
self._signing_key = None self._signing_key = None
self._signing_key_lock = asyncio.Lock() self._signing_key_lock = asyncio.Lock()
@ -426,6 +433,9 @@ class IamService:
created=now, created=now,
) )
if self._on_workspace_created:
await self._on_workspace_created(DEFAULT_WORKSPACE)
admin_user_id = str(uuid.uuid4()) admin_user_id = str(uuid.uuid4())
admin_password = secrets.token_urlsafe(32) admin_password = secrets.token_urlsafe(32)
await self.table_store.put_user( await self.table_store.put_user(
@ -893,19 +903,21 @@ class IamService:
"workspace ids beginning with '_' are reserved", "workspace ids beginning with '_' are reserved",
) )
if self._on_workspace_created:
await self._on_workspace_created(v.workspace_record.id)
existing = await self.table_store.get_workspace( existing = await self.table_store.get_workspace(
v.workspace_record.id, v.workspace_record.id,
) )
if existing is not None: if existing is None:
return _err("duplicate", "workspace already exists") now = _now_dt()
await self.table_store.put_workspace(
id=v.workspace_record.id,
name=v.workspace_record.name or v.workspace_record.id,
enabled=v.workspace_record.enabled,
created=now,
)
now = _now_dt()
await self.table_store.put_workspace(
id=v.workspace_record.id,
name=v.workspace_record.name or v.workspace_record.id,
enabled=v.workspace_record.enabled,
created=now,
)
row = await self.table_store.get_workspace(v.workspace_record.id) row = await self.table_store.get_workspace(v.workspace_record.id)
return IamResponse(workspace=self._row_to_workspace_record(row)) return IamResponse(workspace=self._row_to_workspace_record(row))
@ -984,6 +996,9 @@ class IamService:
for kr in key_rows: for kr in key_rows:
await self.table_store.delete_api_key(kr[0]) await self.table_store.delete_api_key(kr[0])
if self._on_workspace_deleted:
await self._on_workspace_deleted(v.workspace_record.id)
return IamResponse() return IamResponse()
# ------------------------------------------------------------------ # ------------------------------------------------------------------

View file

@ -12,9 +12,13 @@ import os
from trustgraph.schema import Error from trustgraph.schema import Error
from trustgraph.schema import IamRequest, IamResponse from trustgraph.schema import IamRequest, IamResponse
from trustgraph.schema import iam_request_queue, iam_response_queue from trustgraph.schema import iam_request_queue, iam_response_queue
from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigValue
from trustgraph.schema import config_request_queue, config_response_queue
from trustgraph.base import AsyncProcessor, Consumer, Producer from trustgraph.base import AsyncProcessor, Consumer, Producer
from trustgraph.base import ConsumerMetrics, ProducerMetrics from trustgraph.base import ConsumerMetrics, ProducerMetrics
from trustgraph.base.metrics import SubscriberMetrics
from trustgraph.base.request_response_spec import RequestResponse
from trustgraph.base.cassandra_config import ( from trustgraph.base.cassandra_config import (
add_cassandra_args, resolve_cassandra_config, add_cassandra_args, resolve_cassandra_config,
) )
@ -92,7 +96,7 @@ class Processor(AsyncProcessor):
cassandra_username = params.get("cassandra_username") cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password") cassandra_password = params.get("cassandra_password")
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password, password=cassandra_password,
@ -145,8 +149,11 @@ class Processor(AsyncProcessor):
username=self.cassandra_username, username=self.cassandra_username,
password=self.cassandra_password, password=self.cassandra_password,
keyspace=keyspace, keyspace=keyspace,
replication_factor=replication_factor,
bootstrap_mode=self.bootstrap_mode, bootstrap_mode=self.bootstrap_mode,
bootstrap_token=self.bootstrap_token, bootstrap_token=self.bootstrap_token,
on_workspace_created=self._ensure_workspace_registered,
on_workspace_deleted=self._announce_workspace_deleted,
) )
logger.info( logger.info(
@ -160,6 +167,81 @@ class Processor(AsyncProcessor):
await self.iam.auto_bootstrap_if_token_mode() await self.iam.auto_bootstrap_if_token_mode()
await self.iam_request_consumer.start() await self.iam_request_consumer.start()
def _create_config_client(self):
import uuid
config_rr_id = str(uuid.uuid4())
config_req_metrics = ProducerMetrics(
processor=self.id, flow=None, name="config-request",
)
config_resp_metrics = SubscriberMetrics(
processor=self.id, flow=None, name="config-response",
)
return RequestResponse(
backend=self.pubsub,
subscription=f"{self.id}--config--{config_rr_id}",
consumer_name=self.id,
request_topic=config_request_queue,
request_schema=ConfigRequest,
request_metrics=config_req_metrics,
response_topic=config_response_queue,
response_schema=ConfigResponse,
response_metrics=config_resp_metrics,
)
async def _config_put(self, workspace, type, key, value):
client = self._create_config_client()
try:
await client.start()
await client.request(
ConfigRequest(
operation="put",
workspace=workspace,
values=[ConfigValue(type=type, key=key, value=value)],
),
timeout=10,
)
finally:
await client.stop()
async def _config_delete(self, workspace, type, key):
from trustgraph.schema import ConfigKey
client = self._create_config_client()
try:
await client.start()
await client.request(
ConfigRequest(
operation="delete",
workspace=workspace,
keys=[ConfigKey(type=type, key=key)],
),
timeout=10,
)
finally:
await client.stop()
async def _ensure_workspace_registered(self, workspace_id):
await self._config_put(
"__workspaces__", "workspace", workspace_id,
'{"enabled": true}',
)
logger.info(
f"Registered workspace in config: {workspace_id}"
)
async def _announce_workspace_deleted(self, workspace_id):
try:
await self._config_delete(
"__workspaces__", "workspace", workspace_id,
)
logger.info(
f"Announced workspace deletion: {workspace_id}"
)
except Exception as e:
logger.error(
f"Failed to announce workspace deletion "
f"{workspace_id}: {e}", exc_info=True,
)
async def on_iam_request(self, msg, consumer, flow): async def on_iam_request(self, msg, consumer, flow):
id = None id = None

View file

@ -151,21 +151,11 @@ class CollectionManager:
logger.error(f"Error ensuring collection exists: {e}") logger.error(f"Error ensuring collection exists: {e}")
raise e raise e
async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse: async def list_collections(self, request, workspace):
"""
List collections for a user from config service
Args:
request: Collection management request
Returns:
CollectionManagementResponse with list of collections
"""
try: try:
# Get all collections in this workspace from config service
config_request = ConfigRequest( config_request = ConfigRequest(
operation='getvalues', operation='getvalues',
workspace=request.workspace, workspace=workspace,
type='collection' type='collection'
) )
@ -210,18 +200,8 @@ class CollectionManager:
logger.error(f"Error listing collections: {e}") logger.error(f"Error listing collections: {e}")
raise RequestError(f"Failed to list collections: {str(e)}") raise RequestError(f"Failed to list collections: {str(e)}")
async def update_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse: async def update_collection(self, request, workspace):
"""
Update collection metadata via config service (creates if doesn't exist)
Args:
request: Collection management request
Returns:
CollectionManagementResponse with updated collection
"""
try: try:
# Create metadata from request
name = request.name if request.name else request.collection name = request.name if request.name else request.collection
description = request.description if request.description else "" description = request.description if request.description else ""
tags = list(request.tags) if request.tags else [] tags = list(request.tags) if request.tags else []
@ -233,10 +213,9 @@ class CollectionManager:
tags=tags tags=tags
) )
# Send put request to config service
config_request = ConfigRequest( config_request = ConfigRequest(
operation='put', operation='put',
workspace=request.workspace, workspace=workspace,
values=[ConfigValue( values=[ConfigValue(
type='collection', type='collection',
key=request.collection, key=request.collection,
@ -249,7 +228,7 @@ class CollectionManager:
if response.error: if response.error:
raise RuntimeError(f"Config update failed: {response.error.message}") raise RuntimeError(f"Config update failed: {response.error.message}")
logger.info(f"Collection {request.workspace}/{request.collection} updated in config service") logger.info(f"Collection {workspace}/{request.collection} updated in config service")
# Config service will trigger config push automatically # Config service will trigger config push automatically
# Storage services will receive update and create/update collections # Storage services will receive update and create/update collections
@ -264,23 +243,13 @@ class CollectionManager:
logger.error(f"Error updating collection: {e}") logger.error(f"Error updating collection: {e}")
raise RequestError(f"Failed to update collection: {str(e)}") raise RequestError(f"Failed to update collection: {str(e)}")
async def delete_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse: async def delete_collection(self, request, workspace):
"""
Delete collection via config service
Args:
request: Collection management request
Returns:
CollectionManagementResponse indicating success or failure
"""
try: try:
logger.info(f"Deleting collection {request.workspace}/{request.collection}") logger.info(f"Deleting collection {workspace}/{request.collection}")
# Send delete request to config service
config_request = ConfigRequest( config_request = ConfigRequest(
operation='delete', operation='delete',
workspace=request.workspace, workspace=workspace,
keys=[ConfigKey(type='collection', key=request.collection)] keys=[ConfigKey(type='collection', key=request.collection)]
) )
@ -289,7 +258,7 @@ class CollectionManager:
if response.error: if response.error:
raise RuntimeError(f"Config delete failed: {response.error.message}") raise RuntimeError(f"Config delete failed: {response.error.message}")
logger.info(f"Collection {request.workspace}/{request.collection} deleted from config service") logger.info(f"Collection {workspace}/{request.collection} deleted from config service")
# Config service will trigger config push automatically # Config service will trigger config push automatically
# Storage services will receive update and delete collections # Storage services will receive update and delete collections

View file

@ -44,13 +44,13 @@ class Librarian:
self.load_document = load_document self.load_document = load_document
self.min_chunk_size = min_chunk_size self.min_chunk_size = min_chunk_size
async def add_document(self, request): async def add_document(self, request, workspace):
if not request.document_metadata.kind: if not request.document_metadata.kind:
raise RequestError("Document kind (MIME type) is required") raise RequestError("Document kind (MIME type) is required")
if await self.table_store.document_exists( if await self.table_store.document_exists(
request.document_metadata.workspace, workspace,
request.document_metadata.id request.document_metadata.id
): ):
raise RuntimeError("Document already exists") raise RuntimeError("Document already exists")
@ -68,19 +68,19 @@ class Librarian:
logger.debug("Adding to table...") logger.debug("Adding to table...")
await self.table_store.add_document( await self.table_store.add_document(
request.document_metadata, object_id workspace, request.document_metadata, object_id
) )
logger.debug("Add complete") logger.debug("Add complete")
return LibrarianResponse() return LibrarianResponse()
async def remove_document(self, request): async def remove_document(self, request, workspace):
logger.debug("Removing document...") logger.debug("Removing document...")
if not await self.table_store.document_exists( if not await self.table_store.document_exists(
request.workspace, workspace,
request.document_id, request.document_id,
): ):
raise RuntimeError("Document does not exist") raise RuntimeError("Document does not exist")
@ -91,17 +91,17 @@ class Librarian:
logger.debug(f"Cascade deleting child document {child.id}") logger.debug(f"Cascade deleting child document {child.id}")
try: try:
child_object_id = await self.table_store.get_document_object_id( child_object_id = await self.table_store.get_document_object_id(
child.workspace, workspace,
child.id child.id
) )
await self.blob_store.remove(child_object_id) await self.blob_store.remove(child_object_id)
await self.table_store.remove_document(child.workspace, child.id) await self.table_store.remove_document(workspace, child.id)
except Exception as e: except Exception as e:
logger.warning(f"Failed to delete child document {child.id}: {e}") logger.warning(f"Failed to delete child document {child.id}: {e}")
# Now remove the parent document # Now remove the parent document
object_id = await self.table_store.get_document_object_id( object_id = await self.table_store.get_document_object_id(
request.workspace, workspace,
request.document_id request.document_id
) )
@ -110,7 +110,7 @@ class Librarian:
# Remove doc table row # Remove doc table row
await self.table_store.remove_document( await self.table_store.remove_document(
request.workspace, workspace,
request.document_id request.document_id
) )
@ -118,30 +118,30 @@ class Librarian:
return LibrarianResponse() return LibrarianResponse()
async def update_document(self, request): async def update_document(self, request, workspace):
logger.debug("Updating document...") logger.debug("Updating document...")
# You can't update the document ID, workspace or kind. # You can't update the document ID, workspace or kind.
if not await self.table_store.document_exists( if not await self.table_store.document_exists(
request.document_metadata.workspace, workspace,
request.document_metadata.id request.document_metadata.id
): ):
raise RuntimeError("Document does not exist") raise RuntimeError("Document does not exist")
await self.table_store.update_document(request.document_metadata) await self.table_store.update_document(workspace, request.document_metadata)
logger.debug("Update complete") logger.debug("Update complete")
return LibrarianResponse() return LibrarianResponse()
async def get_document_metadata(self, request): async def get_document_metadata(self, request, workspace):
logger.debug("Getting document metadata...") logger.debug("Getting document metadata...")
doc = await self.table_store.get_document( doc = await self.table_store.get_document(
request.workspace, workspace,
request.document_id request.document_id
) )
@ -153,12 +153,12 @@ class Librarian:
content = None, content = None,
) )
async def get_document_content(self, request): async def get_document_content(self, request, workspace):
logger.debug("Getting document content...") logger.debug("Getting document content...")
object_id = await self.table_store.get_document_object_id( object_id = await self.table_store.get_document_object_id(
request.workspace, workspace,
request.document_id request.document_id
) )
@ -174,7 +174,7 @@ class Librarian:
content = base64.b64encode(content), content = base64.b64encode(content),
) )
async def add_processing(self, request): async def add_processing(self, request, workspace):
logger.debug("Adding processing metadata...") logger.debug("Adding processing metadata...")
@ -182,18 +182,18 @@ class Librarian:
raise RuntimeError("Collection parameter is required") raise RuntimeError("Collection parameter is required")
if await self.table_store.processing_exists( if await self.table_store.processing_exists(
request.processing_metadata.workspace, workspace,
request.processing_metadata.id request.processing_metadata.id
): ):
raise RuntimeError("Processing already exists") raise RuntimeError("Processing already exists")
doc = await self.table_store.get_document( doc = await self.table_store.get_document(
request.processing_metadata.workspace, workspace,
request.processing_metadata.document_id request.processing_metadata.document_id
) )
object_id = await self.table_store.get_document_object_id( object_id = await self.table_store.get_document_object_id(
request.processing_metadata.workspace, workspace,
request.processing_metadata.document_id request.processing_metadata.document_id
) )
@ -205,7 +205,7 @@ class Librarian:
logger.debug("Adding processing to table...") logger.debug("Adding processing to table...")
await self.table_store.add_processing(request.processing_metadata) await self.table_store.add_processing(workspace, request.processing_metadata)
logger.debug("Invoking document processing...") logger.debug("Invoking document processing...")
@ -213,25 +213,26 @@ class Librarian:
document = doc, document = doc,
processing = request.processing_metadata, processing = request.processing_metadata,
content = content, content = content,
workspace = workspace,
) )
logger.debug("Add complete") logger.debug("Add complete")
return LibrarianResponse() return LibrarianResponse()
async def remove_processing(self, request): async def remove_processing(self, request, workspace):
logger.debug("Removing processing metadata...") logger.debug("Removing processing metadata...")
if not await self.table_store.processing_exists( if not await self.table_store.processing_exists(
request.workspace, workspace,
request.processing_id, request.processing_id,
): ):
raise RuntimeError("Processing object does not exist") raise RuntimeError("Processing object does not exist")
# Remove doc table row # Remove doc table row
await self.table_store.remove_processing( await self.table_store.remove_processing(
request.workspace, workspace,
request.processing_id request.processing_id
) )
@ -239,9 +240,9 @@ class Librarian:
return LibrarianResponse() return LibrarianResponse()
async def list_documents(self, request): async def list_documents(self, request, workspace):
docs = await self.table_store.list_documents(request.workspace) docs = await self.table_store.list_documents(workspace)
# Filter out child documents and answer documents by default # Filter out child documents and answer documents by default
include_children = getattr(request, 'include_children', False) include_children = getattr(request, 'include_children', False)
@ -256,9 +257,9 @@ class Librarian:
document_metadatas = docs, document_metadatas = docs,
) )
async def list_processing(self, request): async def list_processing(self, request, workspace):
procs = await self.table_store.list_processing(request.workspace) procs = await self.table_store.list_processing(workspace)
return LibrarianResponse( return LibrarianResponse(
processing_metadatas = procs, processing_metadatas = procs,
@ -266,7 +267,7 @@ class Librarian:
# Chunked upload operations # Chunked upload operations
async def begin_upload(self, request): async def begin_upload(self, request, workspace):
""" """
Initialize a chunked upload session. Initialize a chunked upload session.
@ -278,7 +279,7 @@ class Librarian:
raise RequestError("Document kind (MIME type) is required") raise RequestError("Document kind (MIME type) is required")
if await self.table_store.document_exists( if await self.table_store.document_exists(
request.document_metadata.workspace, workspace,
request.document_metadata.id request.document_metadata.id
): ):
raise RequestError("Document already exists") raise RequestError("Document already exists")
@ -314,14 +315,13 @@ class Librarian:
"kind": request.document_metadata.kind, "kind": request.document_metadata.kind,
"title": request.document_metadata.title, "title": request.document_metadata.title,
"comments": request.document_metadata.comments, "comments": request.document_metadata.comments,
"workspace": request.document_metadata.workspace,
"tags": request.document_metadata.tags, "tags": request.document_metadata.tags,
}) })
# Store session in Cassandra # Store session in Cassandra
await self.table_store.create_upload_session( await self.table_store.create_upload_session(
upload_id=upload_id, upload_id=upload_id,
workspace=request.document_metadata.workspace, workspace=workspace,
document_id=request.document_metadata.id, document_id=request.document_metadata.id,
document_metadata=doc_meta_json, document_metadata=doc_meta_json,
s3_upload_id=s3_upload_id, s3_upload_id=s3_upload_id,
@ -340,7 +340,7 @@ class Librarian:
total_chunks=total_chunks, total_chunks=total_chunks,
) )
async def upload_chunk(self, request): async def upload_chunk(self, request, workspace):
""" """
Upload a single chunk of a document. Upload a single chunk of a document.
@ -354,7 +354,7 @@ class Librarian:
raise RequestError("Upload session not found or expired") raise RequestError("Upload session not found or expired")
# Validate ownership # Validate ownership
if session["workspace"] != request.workspace: if session["workspace"] != workspace:
raise RequestError("Not authorized to upload to this session") raise RequestError("Not authorized to upload to this session")
# Validate chunk index # Validate chunk index
@ -407,7 +407,7 @@ class Librarian:
total_bytes=session["total_size"], total_bytes=session["total_size"],
) )
async def complete_upload(self, request): async def complete_upload(self, request, workspace):
""" """
Finalize a chunked upload and create the document. Finalize a chunked upload and create the document.
@ -421,7 +421,7 @@ class Librarian:
raise RequestError("Upload session not found or expired") raise RequestError("Upload session not found or expired")
# Validate ownership # Validate ownership
if session["workspace"] != request.workspace: if session["workspace"] != workspace:
raise RequestError("Not authorized to complete this upload") raise RequestError("Not authorized to complete this upload")
# Verify all chunks received # Verify all chunks received
@ -459,13 +459,13 @@ class Librarian:
kind=doc_meta_dict["kind"], kind=doc_meta_dict["kind"],
title=doc_meta_dict.get("title", ""), title=doc_meta_dict.get("title", ""),
comments=doc_meta_dict.get("comments", ""), comments=doc_meta_dict.get("comments", ""),
workspace=doc_meta_dict["workspace"],
tags=doc_meta_dict.get("tags", []), tags=doc_meta_dict.get("tags", []),
metadata=[], # Triples not supported in chunked upload yet metadata=[], # Triples not supported in chunked upload yet
) )
# Add document to table # Add document to table
await self.table_store.add_document(doc_metadata, session["object_id"]) workspace = session["workspace"]
await self.table_store.add_document(workspace, doc_metadata, session["object_id"])
# Delete upload session # Delete upload session
await self.table_store.delete_upload_session(request.upload_id) await self.table_store.delete_upload_session(request.upload_id)
@ -478,7 +478,7 @@ class Librarian:
object_id=str(session["object_id"]), object_id=str(session["object_id"]),
) )
async def abort_upload(self, request): async def abort_upload(self, request, workspace):
""" """
Cancel a chunked upload and clean up resources. Cancel a chunked upload and clean up resources.
""" """
@ -490,7 +490,7 @@ class Librarian:
raise RequestError("Upload session not found or expired") raise RequestError("Upload session not found or expired")
# Validate ownership # Validate ownership
if session["workspace"] != request.workspace: if session["workspace"] != workspace:
raise RequestError("Not authorized to abort this upload") raise RequestError("Not authorized to abort this upload")
# Abort S3 multipart upload # Abort S3 multipart upload
@ -506,7 +506,7 @@ class Librarian:
return LibrarianResponse(error=None) return LibrarianResponse(error=None)
async def get_upload_status(self, request): async def get_upload_status(self, request, workspace):
""" """
Get the status of an in-progress upload. Get the status of an in-progress upload.
""" """
@ -522,7 +522,7 @@ class Librarian:
) )
# Validate ownership # Validate ownership
if session["workspace"] != request.workspace: if session["workspace"] != workspace:
raise RequestError("Not authorized to view this upload") raise RequestError("Not authorized to view this upload")
chunks_received = session["chunks_received"] chunks_received = session["chunks_received"]
@ -548,13 +548,13 @@ class Librarian:
total_bytes=session["total_size"], total_bytes=session["total_size"],
) )
async def list_uploads(self, request): async def list_uploads(self, request, workspace):
""" """
List all in-progress uploads for a workspace. List all in-progress uploads for a workspace.
""" """
logger.debug(f"Listing uploads for workspace {request.workspace}") logger.debug(f"Listing uploads for workspace {workspace}")
sessions = await self.table_store.list_upload_sessions(request.workspace) sessions = await self.table_store.list_upload_sessions(workspace)
upload_sessions = [ upload_sessions = [
UploadSession( UploadSession(
@ -577,7 +577,7 @@ class Librarian:
# Child document operations # Child document operations
async def add_child_document(self, request): async def add_child_document(self, request, workspace):
""" """
Add a child document linked to a parent document. Add a child document linked to a parent document.
@ -593,7 +593,7 @@ class Librarian:
# Verify parent exists # Verify parent exists
if not await self.table_store.document_exists( if not await self.table_store.document_exists(
request.document_metadata.workspace, workspace,
request.document_metadata.parent_id request.document_metadata.parent_id
): ):
raise RequestError( raise RequestError(
@ -601,7 +601,7 @@ class Librarian:
) )
if await self.table_store.document_exists( if await self.table_store.document_exists(
request.document_metadata.workspace, workspace,
request.document_metadata.id request.document_metadata.id
): ):
raise RequestError("Document already exists") raise RequestError("Document already exists")
@ -624,7 +624,7 @@ class Librarian:
logger.debug("Adding to table...") logger.debug("Adding to table...")
await self.table_store.add_document( await self.table_store.add_document(
request.document_metadata, object_id workspace, request.document_metadata, object_id
) )
logger.debug("Add child document complete") logger.debug("Add child document complete")
@ -634,7 +634,7 @@ class Librarian:
document_id=request.document_metadata.id, document_id=request.document_metadata.id,
) )
async def list_children(self, request): async def list_children(self, request, workspace):
""" """
List all child documents for a given parent document. List all child documents for a given parent document.
""" """
@ -647,7 +647,7 @@ class Librarian:
document_metadatas=children, document_metadatas=children,
) )
async def stream_document(self, request): async def stream_document(self, request, workspace):
""" """
Stream document content in chunks. Stream document content in chunks.
@ -667,7 +667,7 @@ class Librarian:
) )
object_id = await self.table_store.get_document_object_id( object_id = await self.table_store.get_document_object_id(
request.workspace, workspace,
request.document_id request.document_id
) )
@ -699,4 +699,3 @@ class Librarian:
total_bytes=total_size, total_bytes=total_size,
is_final=is_last, is_final=is_last,
) )

View file

@ -10,7 +10,7 @@ import json
import logging import logging
from datetime import datetime from datetime import datetime
from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
from .. base import ConsumerMetrics, ProducerMetrics from .. base import ConsumerMetrics, ProducerMetrics
from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
@ -46,6 +46,9 @@ default_collection_response_queue = collection_response_queue
default_config_request_queue = config_request_queue default_config_request_queue = config_request_queue
default_config_response_queue = config_response_queue default_config_response_queue = config_response_queue
def workspace_queue(base_queue, workspace):
return f"{base_queue}:{workspace}"
default_object_store_endpoint = "ceph-rgw:7480" default_object_store_endpoint = "ceph-rgw:7480"
default_object_store_access_key = "object-user" default_object_store_access_key = "object-user"
default_object_store_secret_key = "object-password" default_object_store_secret_key = "object-password"
@ -56,27 +59,25 @@ default_min_chunk_size = 1 # No minimum by default (for Garage)
bucket_name = "library" bucket_name = "library"
class Processor(AsyncProcessor): class Processor(WorkspaceProcessor):
def __init__(self, **params): def __init__(self, **params):
id = params.get("id") id = params.get("id")
# self.running = True self.librarian_request_queue_base = params.get(
librarian_request_queue = params.get(
"librarian_request_queue", default_librarian_request_queue "librarian_request_queue", default_librarian_request_queue
) )
librarian_response_queue = params.get( self.librarian_response_queue_base = params.get(
"librarian_response_queue", default_librarian_response_queue "librarian_response_queue", default_librarian_response_queue
) )
collection_request_queue = params.get( self.collection_request_queue_base = params.get(
"collection_request_queue", default_collection_request_queue "collection_request_queue", default_collection_request_queue
) )
collection_response_queue = params.get( self.collection_response_queue_base = params.get(
"collection_response_queue", default_collection_response_queue "collection_response_queue", default_collection_response_queue
) )
@ -116,7 +117,7 @@ class Processor(AsyncProcessor):
cassandra_password = params.get("cassandra_password") cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback # Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password, password=cassandra_password,
@ -130,10 +131,10 @@ class Processor(AsyncProcessor):
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"librarian_request_queue": librarian_request_queue, "librarian_request_queue": self.librarian_request_queue_base,
"librarian_response_queue": librarian_response_queue, "librarian_response_queue": self.librarian_response_queue_base,
"collection_request_queue": collection_request_queue, "collection_request_queue": self.collection_request_queue_base,
"collection_response_queue": collection_response_queue, "collection_response_queue": self.collection_response_queue_base,
"object_store_endpoint": object_store_endpoint, "object_store_endpoint": object_store_endpoint,
"object_store_access_key": object_store_access_key, "object_store_access_key": object_store_access_key,
"cassandra_host": self.cassandra_host, "cassandra_host": self.cassandra_host,
@ -142,68 +143,6 @@ class Processor(AsyncProcessor):
} }
) )
librarian_request_metrics = ConsumerMetrics(
processor = self.id, flow = None, name = "librarian-request"
)
librarian_response_metrics = ProducerMetrics(
processor = self.id, flow = None, name = "librarian-response"
)
collection_request_metrics = ConsumerMetrics(
processor = self.id, flow = None, name = "collection-request"
)
collection_response_metrics = ProducerMetrics(
processor = self.id, flow = None, name = "collection-response"
)
storage_response_metrics = ConsumerMetrics(
processor = self.id, flow = None, name = "storage-response"
)
self.librarian_request_topic = librarian_request_queue
self.librarian_request_subscriber = id
self.librarian_request_consumer = Consumer(
taskgroup = self.taskgroup,
backend = self.pubsub,
flow = None,
topic = librarian_request_queue,
subscriber = id,
schema = LibrarianRequest,
handler = self.on_librarian_request,
metrics = librarian_request_metrics,
)
self.librarian_response_producer = Producer(
backend = self.pubsub,
topic = librarian_response_queue,
schema = LibrarianResponse,
metrics = librarian_response_metrics,
)
self.collection_request_topic = collection_request_queue
self.collection_request_subscriber = id
self.collection_request_consumer = Consumer(
taskgroup = self.taskgroup,
backend = self.pubsub,
flow = None,
topic = collection_request_queue,
subscriber = id,
schema = CollectionManagementRequest,
handler = self.on_collection_request,
metrics = collection_request_metrics,
)
self.collection_response_producer = Producer(
backend = self.pubsub,
topic = collection_response_queue,
schema = CollectionManagementResponse,
metrics = collection_response_metrics,
)
# Config service client for collection management # Config service client for collection management
config_request_metrics = ProducerMetrics( config_request_metrics = ProducerMetrics(
processor = id, flow = None, name = "config-request" processor = id, flow = None, name = "config-request"
@ -240,6 +179,7 @@ class Processor(AsyncProcessor):
object_store_secret_key = object_store_secret_key, object_store_secret_key = object_store_secret_key,
bucket_name = bucket_name, bucket_name = bucket_name,
keyspace = keyspace, keyspace = keyspace,
replication_factor = replication_factor,
load_document = self.load_document, load_document = self.load_document,
object_store_use_ssl = object_store_use_ssl, object_store_use_ssl = object_store_use_ssl,
object_store_region = object_store_region, object_store_region = object_store_region,
@ -259,17 +199,111 @@ class Processor(AsyncProcessor):
self.flows = {} self.flows = {}
# Per-workspace consumers, keyed by workspace id
self.workspace_consumers = {}
logger.info("Librarian service initialized") logger.info("Librarian service initialized")
async def on_workspace_created(self, workspace):
if workspace in self.workspace_consumers:
return
lib_req_queue = workspace_queue(
self.librarian_request_queue_base, workspace,
)
lib_resp_queue = workspace_queue(
self.librarian_response_queue_base, workspace,
)
col_req_queue = workspace_queue(
self.collection_request_queue_base, workspace,
)
col_resp_queue = workspace_queue(
self.collection_response_queue_base, workspace,
)
await self.pubsub.ensure_topic(lib_req_queue)
await self.pubsub.ensure_topic(lib_resp_queue)
await self.pubsub.ensure_topic(col_req_queue)
await self.pubsub.ensure_topic(col_resp_queue)
lib_response_producer = Producer(
backend=self.pubsub,
topic=lib_resp_queue,
schema=LibrarianResponse,
metrics=ProducerMetrics(
processor=self.id, flow=None,
name=f"librarian-response-{workspace}",
),
)
col_response_producer = Producer(
backend=self.pubsub,
topic=col_resp_queue,
schema=CollectionManagementResponse,
metrics=ProducerMetrics(
processor=self.id, flow=None,
name=f"collection-response-{workspace}",
),
)
lib_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=lib_req_queue,
subscriber=self.id,
schema=LibrarianRequest,
handler=partial(
self.on_librarian_request, workspace=workspace,
),
metrics=ConsumerMetrics(
processor=self.id, flow=None,
name=f"librarian-request-{workspace}",
),
)
col_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=col_req_queue,
subscriber=self.id,
schema=CollectionManagementRequest,
handler=partial(
self.on_collection_request, workspace=workspace,
),
metrics=ConsumerMetrics(
processor=self.id, flow=None,
name=f"collection-request-{workspace}",
),
)
await lib_response_producer.start()
await col_response_producer.start()
await lib_consumer.start()
await col_consumer.start()
self.workspace_consumers[workspace] = {
"librarian": lib_consumer,
"librarian-response": lib_response_producer,
"collection": col_consumer,
"collection-response": col_response_producer,
}
logger.info(f"Subscribed to workspace queues: {workspace}")
async def on_workspace_deleted(self, workspace):
clients = self.workspace_consumers.pop(workspace, None)
if clients:
for client in clients.values():
await client.stop()
logger.info(f"Unsubscribed from workspace queues: {workspace}")
async def start(self): async def start(self):
await self.pubsub.ensure_topic(self.librarian_request_topic)
await self.pubsub.ensure_topic(self.collection_request_topic)
await super(Processor, self).start() await super(Processor, self).start()
await self.librarian_request_consumer.start()
await self.librarian_response_producer.start()
await self.collection_request_consumer.start()
await self.collection_response_producer.start()
await self.config_request_producer.start() await self.config_request_producer.start()
await self.config_response_consumer.start() await self.config_response_consumer.start()
@ -360,13 +394,12 @@ class Processor(AsyncProcessor):
finally: finally:
await triples_pub.stop() await triples_pub.stop()
async def load_document(self, document, processing, content): async def load_document(self, document, processing, content, workspace):
logger.debug("Ready for document processing...") logger.debug("Ready for document processing...")
logger.debug(f"Document: {document}, processing: {processing}, content length: {len(content)}") logger.debug(f"Document: {document}, processing: {processing}, content length: {len(content)}")
workspace = processing.workspace
ws_flows = self.flows.get(workspace, {}) ws_flows = self.flows.get(workspace, {})
if processing.flow not in ws_flows: if processing.flow not in ws_flows:
raise RuntimeError( raise RuntimeError(
@ -426,20 +459,14 @@ class Processor(AsyncProcessor):
logger.debug("Document submitted") logger.debug("Document submitted")
async def add_processing_with_collection(self, request): async def add_processing_with_collection(self, request, workspace):
"""
Wrapper for add_processing that ensures collection exists
"""
# Ensure collection exists when processing is added
if hasattr(request, 'processing_metadata') and request.processing_metadata: if hasattr(request, 'processing_metadata') and request.processing_metadata:
workspace = request.processing_metadata.workspace
collection = request.processing_metadata.collection collection = request.processing_metadata.collection
await self.collection_manager.ensure_collection_exists(workspace, collection) await self.collection_manager.ensure_collection_exists(workspace, collection)
# Call the original add_processing method return await self.librarian.add_processing(request, workspace)
return await self.librarian.add_processing(request)
async def process_request(self, v): async def process_request(self, v, workspace):
if v.operation is None: if v.operation is None:
raise RequestError("Null operation") raise RequestError("Null operation")
@ -472,9 +499,9 @@ class Processor(AsyncProcessor):
if v.operation not in impls: if v.operation not in impls:
raise RequestError(f"Invalid operation: {v.operation}") raise RequestError(f"Invalid operation: {v.operation}")
return await impls[v.operation](v) return await impls[v.operation](v, workspace)
async def on_librarian_request(self, msg, consumer, flow): async def on_librarian_request(self, msg, consumer, flow, *, workspace):
v = msg.value() v = msg.value()
@ -484,20 +511,22 @@ class Processor(AsyncProcessor):
logger.info(f"Handling librarian input {id}...") logger.info(f"Handling librarian input {id}...")
producer = self.workspace_consumers[workspace]["librarian-response"]
try: try:
# Handle streaming operations specially # Handle streaming operations specially
if v.operation == "stream-document": if v.operation == "stream-document":
async for resp in self.librarian.stream_document(v): async for resp in self.librarian.stream_document(v, workspace):
await self.librarian_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
return return
# Non-streaming operations # Non-streaming operations
resp = await self.process_request(v) resp = await self.process_request(v, workspace)
await self.librarian_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -511,7 +540,7 @@ class Processor(AsyncProcessor):
), ),
) )
await self.librarian_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -524,7 +553,7 @@ class Processor(AsyncProcessor):
), ),
) )
await self.librarian_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -532,10 +561,7 @@ class Processor(AsyncProcessor):
logger.debug("Librarian input processing complete") logger.debug("Librarian input processing complete")
async def process_collection_request(self, v): async def process_collection_request(self, v, workspace):
"""
Process collection management requests
"""
if v.operation is None: if v.operation is None:
raise RequestError("Null operation") raise RequestError("Null operation")
@ -550,20 +576,19 @@ class Processor(AsyncProcessor):
if v.operation not in impls: if v.operation not in impls:
raise RequestError(f"Invalid collection operation: {v.operation}") raise RequestError(f"Invalid collection operation: {v.operation}")
return await impls[v.operation](v) return await impls[v.operation](v, workspace)
async def on_collection_request(self, msg, consumer, flow): async def on_collection_request(self, msg, consumer, flow, *, workspace):
"""
Handle collection management request messages
"""
v = msg.value() v = msg.value()
id = msg.properties().get("id", "unknown") id = msg.properties().get("id", "unknown")
logger.info(f"Handling collection request {id}...") logger.info(f"Handling collection request {id}...")
producer = self.workspace_consumers[workspace]["collection-response"]
try: try:
resp = await self.process_collection_request(v) resp = await self.process_collection_request(v, workspace)
await self.collection_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
except RequestError as e: except RequestError as e:
@ -574,7 +599,7 @@ class Processor(AsyncProcessor):
), ),
timestamp=datetime.now().isoformat() timestamp=datetime.now().isoformat()
) )
await self.collection_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
except Exception as e: except Exception as e:
@ -585,7 +610,7 @@ class Processor(AsyncProcessor):
), ),
timestamp=datetime.now().isoformat() timestamp=datetime.now().isoformat()
) )
await self.collection_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -594,7 +619,7 @@ class Processor(AsyncProcessor):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
AsyncProcessor.add_args(parser) WorkspaceProcessor.add_args(parser)
parser.add_argument( parser.add_argument(
'--librarian-request-queue', '--librarian-request-queue',

View file

@ -35,8 +35,8 @@ class Processor(LlmService):
temperature = params.get("temperature", default_temperature) temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output) max_output = params.get("max_output", default_max_output)
if api_key is None: if not api_key:
raise RuntimeError("OpenAI API key not specified") api_key = "not-set"
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {

View file

@ -47,7 +47,7 @@ class Processor(FlowProcessor):
cassandra_password = params.get("cassandra_password") cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback # Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, _ = resolve_cassandra_config(
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password password=cassandra_password

View file

@ -160,7 +160,7 @@ class Processor(TriplesQueryService):
cassandra_password = params.get("cassandra_password") cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback # Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, _ = resolve_cassandra_config(
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password password=cassandra_password

View file

@ -122,7 +122,7 @@ class Query:
for match in chunk_matches: for match in chunk_matches:
if match.chunk_id: if match.chunk_id:
try: try:
content = await self.rag.fetch_chunk(match.chunk_id, self.workspace) content = await self.rag.fetch_chunk(match.chunk_id)
docs.append(content) docs.append(content)
chunk_ids.append(match.chunk_id) chunk_ids.append(match.chunk_id)
except Exception as e: except Exception as e:

View file

@ -4,21 +4,16 @@ Simple RAG service, performs query using document RAG an LLM.
Input is query, output is response. Input is query, output is response.
""" """
import asyncio
import base64
import logging import logging
import uuid
from ... schema import DocumentRagQuery, DocumentRagResponse, Error from ... schema import DocumentRagQuery, DocumentRagResponse, Error
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import Triples, Metadata from ... schema import Triples, Metadata
from ... provenance import GRAPH_RETRIEVAL from ... provenance import GRAPH_RETRIEVAL
from . document_rag import DocumentRag from . document_rag import DocumentRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import DocumentEmbeddingsClientSpec from ... base import DocumentEmbeddingsClientSpec
from ... base import LibrarianClient from ... base import LibrarianSpec
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,58 +80,14 @@ class Processor(FlowProcessor):
) )
) )
# Librarian client self.register_specification(
self.librarian = LibrarianClient( LibrarianSpec()
id=id,
backend=self.pubsub,
taskgroup=self.taskgroup,
) )
async def start(self):
await super(Processor, self).start()
await self.librarian.start()
async def fetch_chunk_content(self, chunk_id, workspace, timeout=120):
"""Fetch chunk content from librarian. Chunks are small so
single request-response is fine."""
return await self.librarian.fetch_document_text(
document_id=chunk_id, workspace=workspace, timeout=timeout,
)
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""Save answer content to the librarian."""
doc_metadata = DocumentMetadata(
id=doc_id,
workspace=workspace,
kind="text/plain",
title=title or "DocumentRAG Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
workspace=workspace,
)
await self.librarian.request(request, timeout=timeout)
return doc_id
async def on_request(self, msg, consumer, flow): async def on_request(self, msg, consumer, flow):
try: try:
self.rag = DocumentRag(
embeddings_client = flow("embeddings-request"),
doc_embeddings_client = flow("document-embeddings-request"),
prompt_client = flow("prompt-request"),
fetch_chunk = self.fetch_chunk_content,
verbose=True,
)
v = msg.value() v = msg.value()
# Sender-produced ID # Sender-produced ID
@ -144,15 +95,25 @@ class Processor(FlowProcessor):
logger.info(f"Handling input {id}...") logger.info(f"Handling input {id}...")
async def fetch_chunk(chunk_id, timeout=120):
return await flow.librarian.fetch_document_text(
document_id=chunk_id, timeout=timeout,
)
self.rag = DocumentRag(
embeddings_client = flow("embeddings-request"),
doc_embeddings_client = flow("document-embeddings-request"),
prompt_client = flow("prompt-request"),
fetch_chunk = fetch_chunk,
verbose=True,
)
if v.doc_limit: if v.doc_limit:
doc_limit = v.doc_limit doc_limit = v.doc_limit
else: else:
doc_limit = self.doc_limit doc_limit = self.doc_limit
# Real-time explainability callback - emits triples and IDs as they're generated
# Triples are stored in the request's collection with a named graph (urn:graph:retrieval)
async def send_explainability(triples, explain_id): async def send_explainability(triples, explain_id):
# Send triples to explainability queue - stores in same collection with named graph
await flow("explainability").send(Triples( await flow("explainability").send(Triples(
metadata=Metadata( metadata=Metadata(
id=explain_id, id=explain_id,
@ -161,7 +122,6 @@ class Processor(FlowProcessor):
triples=triples, triples=triples,
)) ))
# Send explain data to response queue
await flow("response").send( await flow("response").send(
DocumentRagResponse( DocumentRagResponse(
response=None, response=None,
@ -173,13 +133,12 @@ class Processor(FlowProcessor):
properties={"id": id} properties={"id": id}
) )
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text): async def save_answer(doc_id, answer_text):
await self.save_answer_content( await flow.librarian.save_document(
doc_id=doc_id, doc_id=doc_id,
workspace=flow.workspace,
content=answer_text, content=answer_text,
title=f"DocumentRAG Answer: {v.query[:50]}...", title=f"DocumentRAG Answer: {v.query[:50]}...",
document_type="answer",
) )
# Check if streaming is requested # Check if streaming is requested

View file

@ -4,29 +4,22 @@ Simple RAG service, performs query using graph RAG an LLM.
Input is query, output is response. Input is query, output is response.
""" """
import asyncio
import base64
import logging import logging
import uuid
from ... schema import GraphRagQuery, GraphRagResponse, Error from ... schema import GraphRagQuery, GraphRagResponse, Error
from ... schema import Triples, Metadata from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... provenance import GRAPH_RETRIEVAL from ... provenance import GRAPH_RETRIEVAL
from . graph_rag import GraphRag from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics from ... base import LibrarianSpec
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
default_ident = "graph-rag" default_ident = "graph-rag"
default_concurrency = 1 default_concurrency = 1
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor): class Processor(FlowProcessor):
@ -117,115 +110,12 @@ class Processor(FlowProcessor):
) )
) )
# Librarian client for storing answer content self.register_specification(
librarian_request_q = params.get( LibrarianSpec()
"librarian_request_queue", default_librarian_request_queue
) )
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_librarian_requests = {}
logger.info("Graph RAG service initialized") logger.info("Graph RAG service initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
workspace: Workspace for isolation
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
workspace=workspace,
kind="text/plain",
title=title or "GraphRAG Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
workspace=workspace,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_librarian_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_request(self, msg, consumer, flow): async def on_request(self, msg, consumer, flow):
try: try:
@ -306,13 +196,12 @@ class Processor(FlowProcessor):
else: else:
edge_limit = self.default_edge_limit edge_limit = self.default_edge_limit
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text): async def save_answer(doc_id, answer_text):
await self.save_answer_content( await flow.librarian.save_document(
doc_id=doc_id, doc_id=doc_id,
workspace=flow.workspace,
content=answer_text, content=answer_text,
title=f"GraphRAG Answer: {v.query[:50]}...", title=f"GraphRAG Answer: {v.query[:50]}...",
document_type="answer",
) )
# Check if streaming is requested # Check if streaming is requested

View file

@ -47,7 +47,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
cassandra_password = params.get("cassandra_password") cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback # Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, _ = resolve_cassandra_config(
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password password=cassandra_password

View file

@ -125,7 +125,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
cassandra_password = params.get("cassandra_password") cassandra_password = params.get("cassandra_password")
# Resolve configuration with environment variable fallback # Resolve configuration with environment variable fallback
hosts, username, password, keyspace = resolve_cassandra_config( hosts, username, password, keyspace, _ = resolve_cassandra_config(
host=cassandra_host, host=cassandra_host,
username=cassandra_username, username=cassandra_username,
password=cassandra_password password=cassandra_password

View file

@ -313,7 +313,7 @@ class LibraryTableStore:
return bool(rows) return bool(rows)
async def add_document(self, document, object_id): async def add_document(self, workspace, document, object_id):
logger.info(f"Adding document {document.id} {object_id}") logger.info(f"Adding document {document.id} {object_id}")
@ -333,7 +333,7 @@ class LibraryTableStore:
self.cassandra, self.cassandra,
self.insert_document_stmt, self.insert_document_stmt,
( (
document.id, document.workspace, int(document.time * 1000), document.id, workspace, int(document.time * 1000),
document.kind, document.title, document.comments, document.kind, document.title, document.comments,
metadata, document.tags, object_id, metadata, document.tags, object_id,
parent_id, document_type parent_id, document_type
@ -345,7 +345,7 @@ class LibraryTableStore:
logger.debug("Add complete") logger.debug("Add complete")
async def update_document(self, document): async def update_document(self, workspace, document):
logger.info(f"Updating document {document.id}") logger.info(f"Updating document {document.id}")
@ -363,7 +363,7 @@ class LibraryTableStore:
( (
int(document.time * 1000), document.title, int(document.time * 1000), document.title,
document.comments, metadata, document.tags, document.comments, metadata, document.tags,
document.workspace, document.id workspace, document.id
), ),
) )
except Exception: except Exception:
@ -405,7 +405,6 @@ class LibraryTableStore:
lst = [ lst = [
DocumentMetadata( DocumentMetadata(
id = row[0], id = row[0],
workspace = workspace,
time = int(time.mktime(row[1].timetuple())), time = int(time.mktime(row[1].timetuple())),
kind = row[2], kind = row[2],
title = row[3], title = row[3],
@ -447,7 +446,6 @@ class LibraryTableStore:
lst = [ lst = [
DocumentMetadata( DocumentMetadata(
id = row[0], id = row[0],
workspace = row[1],
time = int(time.mktime(row[2].timetuple())), time = int(time.mktime(row[2].timetuple())),
kind = row[3], kind = row[3],
title = row[4], title = row[4],
@ -488,7 +486,6 @@ class LibraryTableStore:
for row in rows: for row in rows:
doc = DocumentMetadata( doc = DocumentMetadata(
id = id, id = id,
workspace = workspace,
time = int(time.mktime(row[0].timetuple())), time = int(time.mktime(row[0].timetuple())),
kind = row[1], kind = row[1],
title = row[2], title = row[2],
@ -541,7 +538,7 @@ class LibraryTableStore:
return bool(rows) return bool(rows)
async def add_processing(self, processing): async def add_processing(self, workspace, processing):
logger.info(f"Adding processing {processing.id}") logger.info(f"Adding processing {processing.id}")
@ -552,7 +549,7 @@ class LibraryTableStore:
( (
processing.id, processing.document_id, processing.id, processing.document_id,
int(processing.time * 1000), processing.flow, int(processing.time * 1000), processing.flow,
processing.workspace, processing.collection, workspace, processing.collection,
processing.tags processing.tags
), ),
) )
@ -598,7 +595,6 @@ class LibraryTableStore:
document_id = row[1], document_id = row[1],
time = int(time.mktime(row[2].timetuple())), time = int(time.mktime(row[2].timetuple())),
flow = row[3], flow = row[3],
workspace = workspace,
collection = row[4], collection = row[4],
tags = row[5] if row[5] else [], tags = row[5] if row[5] else [],
) )

View file

@ -13,9 +13,8 @@ import pytesseract
from pdf2image import convert_from_bytes from pdf2image import convert_from_bytes
from ... schema import Document, TextDocument, Metadata from ... schema import Document, TextDocument, Metadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec
from ... provenance import ( from ... provenance import (
document_uri, page_uri as make_page_uri, derived_entity_triples, document_uri, page_uri as make_page_uri, derived_entity_triples,
@ -31,9 +30,6 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder" default_ident = "document-decoder"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor): class Processor(FlowProcessor):
def __init__(self, **params): def __init__(self, **params):
@ -68,17 +64,12 @@ class Processor(FlowProcessor):
) )
) )
# Librarian client self.register_specification(
self.librarian = LibrarianClient( LibrarianSpec()
id=id, backend=self.pubsub, taskgroup=self.taskgroup,
) )
logger.info("PDF OCR processor initialized") logger.info("PDF OCR processor initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian.start()
async def on_message(self, msg, consumer, flow): async def on_message(self, msg, consumer, flow):
logger.info("PDF message received") logger.info("PDF message received")
@ -89,9 +80,8 @@ class Processor(FlowProcessor):
# Check MIME type if fetching from librarian # Check MIME type if fetching from librarian
if v.document_id: if v.document_id:
doc_meta = await self.librarian.fetch_document_metadata( doc_meta = await flow.librarian.fetch_document_metadata(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error( logger.error(
@ -104,9 +94,8 @@ class Processor(FlowProcessor):
# Get PDF content - fetch from librarian or use inline data # Get PDF content - fetch from librarian or use inline data
if v.document_id: if v.document_id:
logger.info(f"Fetching document {v.document_id} from librarian...") logger.info(f"Fetching document {v.document_id} from librarian...")
content = await self.librarian.fetch_document_content( content = await flow.librarian.fetch_document_content(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
if isinstance(content, str): if isinstance(content, str):
content = content.encode('utf-8') content = content.encode('utf-8')
@ -138,10 +127,9 @@ class Processor(FlowProcessor):
page_content = text.encode("utf-8") page_content = text.encode("utf-8")
# Save page as child document in librarian # Save page as child document in librarian
await self.librarian.save_child_document( await flow.librarian.save_child_document(
doc_id=page_doc_id, doc_id=page_doc_id,
parent_id=source_doc_id, parent_id=source_doc_id,
workspace=flow.workspace,
content=page_content, content=page_content,
document_type="page", document_type="page",
title=f"Page {page_num}", title=f"Page {page_num}",
@ -189,18 +177,6 @@ class Processor(FlowProcessor):
FlowProcessor.add_args(parser) FlowProcessor.add_args(parser)
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue (default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue (default: {default_librarian_response_queue})',
)
def run(): def run():
Processor.launch(default_ident, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -23,9 +23,8 @@ import os
from unstructured.partition.auto import partition from unstructured.partition.auto import partition
from ... schema import Document, TextDocument, Metadata from ... schema import Document, TextDocument, Metadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec
from ... provenance import ( from ... provenance import (
document_uri, page_uri as make_page_uri, document_uri, page_uri as make_page_uri,
@ -44,9 +43,6 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder" default_ident = "document-decoder"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
# Mime type to unstructured content_type mapping # Mime type to unstructured content_type mapping
# unstructured auto-detects most formats, but we pass the hint when available # unstructured auto-detects most formats, but we pass the hint when available
MIME_EXTENSIONS = { MIME_EXTENSIONS = {
@ -162,17 +158,12 @@ class Processor(FlowProcessor):
) )
) )
# Librarian client self.register_specification(
self.librarian = LibrarianClient( LibrarianSpec()
id=id, backend=self.pubsub, taskgroup=self.taskgroup,
) )
logger.info("Universal decoder initialized") logger.info("Universal decoder initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian.start()
def extract_elements(self, blob, mime_type=None): def extract_elements(self, blob, mime_type=None):
""" """
Extract elements from a document using unstructured. Extract elements from a document using unstructured.
@ -272,10 +263,9 @@ class Processor(FlowProcessor):
page_content = text.encode("utf-8") page_content = text.encode("utf-8")
# Save to librarian # Save to librarian
await self.librarian.save_child_document( await flow.librarian.save_child_document(
doc_id=doc_id, doc_id=doc_id,
parent_id=parent_doc_id, parent_id=parent_doc_id,
workspace=flow.workspace,
content=page_content, content=page_content,
document_type="page" if is_page else "section", document_type="page" if is_page else "section",
title=label, title=label,
@ -351,10 +341,9 @@ class Processor(FlowProcessor):
# Save to librarian # Save to librarian
if img_content: if img_content:
await self.librarian.save_child_document( await flow.librarian.save_child_document(
doc_id=img_uri, doc_id=img_uri,
parent_id=parent_doc_id, parent_id=parent_doc_id,
workspace=flow.workspace,
content=img_content, content=img_content,
document_type="image", document_type="image",
title=f"Image from page {page_number}" if page_number else "Image", title=f"Image from page {page_number}" if page_number else "Image",
@ -399,15 +388,13 @@ class Processor(FlowProcessor):
f"Fetching document {v.document_id} from librarian..." f"Fetching document {v.document_id} from librarian..."
) )
doc_meta = await self.librarian.fetch_document_metadata( doc_meta = await flow.librarian.fetch_document_metadata(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
mime_type = doc_meta.kind if doc_meta else None mime_type = doc_meta.kind if doc_meta else None
content = await self.librarian.fetch_document_content( content = await flow.librarian.fetch_document_content(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
if isinstance(content, str): if isinstance(content, str):
@ -571,19 +558,6 @@ class Processor(FlowProcessor):
help='Apply section strategy within pages too (default: false)', help='Apply section strategy within pages too (default: false)',
) )
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue '
f'(default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue '
f'(default: {default_librarian_response_queue})',
)
def run(): def run():

View file

@ -10,12 +10,13 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md" readme = "README.md"
requires-python = ">=3.8" requires-python = ">=3.8"
dependencies = [ dependencies = [
"trustgraph-base>=1.8,<1.9", "trustgraph-base>=2.4,<2.5",
"trustgraph-bedrock>=1.8,<1.9", "trustgraph-bedrock>=2.4,<2.5",
"trustgraph-cli>=1.8,<1.9", "trustgraph-cli>=2.4,<2.5",
"trustgraph-embeddings-hf>=1.8,<1.9", "trustgraph-embeddings-hf>=2.4,<2.5",
"trustgraph-flow>=1.8,<1.9", "trustgraph-flow>=2.4,<2.5",
"trustgraph-vertexai>=1.8,<1.9", "trustgraph-unstructured>=2.4,<2.5",
"trustgraph-vertexai>=2.4,<2.5",
] ]
classifiers = [ classifiers = [
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",