Mirror of https://github.com/trustgraph-ai/trustgraph.git
Synced 2026-05-13 01:02:37 +02:00

Merge branch 'release/v2.4'
Commit 159b1e2824
98 changed files with 2026 additions and 1445 deletions
@@ -11,8 +11,9 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
 RUN dnf install -y python3.13 && \
     alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
     python -m ensurepip --upgrade && \
+    pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
     pip3 install --no-cache-dir build wheel aiohttp && \
-    pip3 install --no-cache-dir pulsar-client==3.7.0 && \
+    pip3 install --no-cache-dir pulsar-client==3.11.0 && \
     dnf clean all

 # ----------------------------------------------------------------------------
@@ -11,8 +11,9 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
 RUN dnf install -y python3.13 && \
     alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
     python -m ensurepip --upgrade && \
+    pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
     pip3 install --no-cache-dir build wheel aiohttp && \
-    pip3 install --no-cache-dir pulsar-client==3.7.0 && \
+    pip3 install --no-cache-dir pulsar-client==3.11.0 && \
     dnf clean all

 # ----------------------------------------------------------------------------
@@ -11,18 +11,19 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
 RUN dnf install -y python3.13 && \
     alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
     python -m ensurepip --upgrade && \
+    pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
     pip3 install --no-cache-dir build wheel aiohttp rdflib && \
-    pip3 install --no-cache-dir pulsar-client==3.7.0 && \
+    pip3 install --no-cache-dir pulsar-client==3.11.0 && \
     dnf clean all

 RUN pip3 install --no-cache-dir \
     anthropic cohere mistralai openai \
     ollama \
-    langchain==0.3.25 langchain-core==0.3.60 \
-    langchain-text-splitters==0.3.8 \
-    langchain-community==0.3.24 \
+    langchain==1.2.16 langchain-core==1.3.2 \
+    langchain-text-splitters==1.1.2 \
+    langchain-community==0.4.1 \
     pymilvus \
-    pulsar-client==3.7.0 scylla-driver pyyaml \
+    pulsar-client==3.11.0 scylla-driver pyyaml \
     neo4j tiktoken falkordb && \
     pip3 cache purge
@@ -8,8 +8,9 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
 RUN dnf install -y python3.12 && \
     alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
     python -m ensurepip --upgrade && \
+    pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
     pip3 install --no-cache-dir build wheel aiohttp && \
-    pip3 install --no-cache-dir pulsar-client==3.7.0 && \
+    pip3 install --no-cache-dir pulsar-client==3.11.0 && \
     dnf clean all

 # This won't work on ARM
@@ -19,15 +20,15 @@ RUN dnf install -y python3.12 && \
 RUN pip3 install torch

 RUN pip3 install --no-cache-dir \
-    langchain==0.3.25 langchain-core==0.3.60 langchain-huggingface==0.2.0 \
-    langchain-community==0.3.24 \
-    sentence-transformers==4.1.0 transformers==4.51.3 \
-    huggingface-hub==0.31.2 \
-    pulsar-client==3.7.0
+    langchain==1.2.16 langchain-core==1.3.2 langchain-huggingface==1.2.2 \
+    langchain-community==0.4.1 \
+    sentence-transformers==5.4.1 transformers==5.7.0 \
+    huggingface-hub==1.13.0 \
+    pulsar-client==3.11.0

 # Most commonly used embeddings model, just build it into the container
 # image
-RUN huggingface-cli download sentence-transformers/all-MiniLM-L6-v2
+RUN hf download sentence-transformers/all-MiniLM-L6-v2

 # ----------------------------------------------------------------------------
 # Build a container which contains the built Python packages. The build
@@ -11,6 +11,7 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
 RUN dnf install -y python3.13 && \
     alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
     python -m ensurepip --upgrade && \
+    pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
     pip3 install --no-cache-dir mcp websockets && \
     dnf clean all
@@ -12,8 +12,9 @@ RUN dnf install -y python3.13 && \
     dnf install -y tesseract poppler-utils && \
     alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
     python -m ensurepip --upgrade && \
+    pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
     pip3 install --no-cache-dir build wheel aiohttp && \
-    pip3 install --no-cache-dir pulsar-client==3.7.0 && \
+    pip3 install --no-cache-dir pulsar-client==3.11.0 && \
     dnf clean all

 # ----------------------------------------------------------------------------
@@ -10,8 +10,9 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
 RUN dnf install -y python3.13 libxcb mesa-libGL && \
     alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
     python -m ensurepip --upgrade && \
+    pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
     pip3 install --no-cache-dir build wheel aiohttp && \
-    pip3 install --no-cache-dir pulsar-client==3.7.0 && \
+    pip3 install --no-cache-dir pulsar-client==3.11.0 && \
     dnf clean all

 # ----------------------------------------------------------------------------
@@ -11,8 +11,9 @@ ENV PIP_BREAK_SYSTEM_PACKAGES=1
 RUN dnf install -y python3.13 && \
     alternatives --install /usr/bin/python python /usr/bin/python3.13 1 && \
     python -m ensurepip --upgrade && \
+    pip3 install --no-cache-dir --upgrade 'pip>=26.0' 'setuptools>=78.1.1' && \
     pip3 install --no-cache-dir build wheel aiohttp && \
-    pip3 install --no-cache-dir pulsar-client==3.7.0 && \
+    pip3 install --no-cache-dir pulsar-client==3.11.0 && \
     pip3 install --no-cache-dir google-cloud-aiplatform && \
     dnf clean all
docs/tech-specs/workspace-scoped-services.md (new file, 366 lines)
@@ -0,0 +1,366 @@
---
layout: default
title: "Workspace-Scoped Services"
parent: "Tech Specs"
---

# Workspace-Scoped Services

## Problem Statement

Workspace-scoped services (librarian, config, knowledge, collection
management) currently operate on global queues — a single
`request:tg:librarian` queue handles requests for all workspaces.
Workspace identity is carried as a field in the request body, set by
the gateway after authentication. This creates several problems:

- **No structural isolation.** All workspaces share a single queue.
  Workspace scoping depends entirely on a body field being populated
  correctly. If the field is missing or wrong, the service operates
  on the wrong workspace — or fails with a confusing error. This is
  a security concern: workspace isolation should be enforced by
  infrastructure, not by trusting a field.

- **Redundant workspace fields.** Nested objects within requests
  (e.g. `processing-metadata`, `document-metadata`) carry their own
  `workspace` fields alongside the top-level request workspace.
  The gateway resolves the top-level workspace but does not
  propagate it into nested payloads. Services that read workspace
  from a nested object instead of the top-level address see `None`
  and fail.

- **No workspace lifecycle awareness.** Workspace-scoped services
  have no mechanism to learn when workspaces are created or deleted.
  Flow processors discover workspaces indirectly through config
  entries, but workspace-scoped services on global queues have no
  equivalent. There is no event when a workspace appears or
  disappears.

- **Inconsistency with flow-scoped services.** Flow-scoped services
  already use per-workspace, per-flow queue names
  (`request:tg:{workspace}:embeddings:{class}`). Workspace-scoped
  services are the exception — they sit on global queues while
  everything else is structurally isolated.

## Design

### Per-workspace queues for workspace-scoped services

Workspace-scoped services move from global queues to per-workspace
queues. The queue name includes the workspace identifier:

**Current (global):**
```
request:tg:librarian
request:tg:config
```

**Proposed (per-workspace):**
```
request:tg:librarian:{workspace}
request:tg:config:{workspace}
```

The gateway routes requests to the correct queue based on the
resolved workspace from authentication — the same workspace that
today gets written into the request body. The workspace is now part
of the queue address, not just a field in the payload.

Services subscribe to per-workspace queues. When a new workspace is
created, they subscribe to its queue. When a workspace is deleted,
they unsubscribe.
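For illustration, the address derivation is trivial; a minimal
sketch, with a hypothetical helper name:

```python
# Hypothetical helper: derive the per-workspace queue address for a
# workspace-scoped service under the proposed naming scheme.
def workspace_queue(service: str, workspace: str) -> str:
    return f"request:tg:{service}:{workspace}"

# e.g. workspace_queue("librarian", "acme") -> "request:tg:librarian:acme"
```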
### Workspace lifecycle via the `__workspaces__` config namespace

Workspace lifecycle events are modelled as config changes in a
reserved `__workspaces__` namespace. This mirrors the existing
`__system__` namespace — a reserved space for infrastructure
concerns that don't belong to any user workspace.

When IAM creates a workspace, it writes an entry to the config
service:

```
workspace: __workspaces__
type: workspace
key: <workspace-id>
value: {"enabled": true}
```

When IAM deletes (or disables) a workspace, it updates or deletes
the entry. The config service sees this as a normal config change
and pushes a notification through the existing `ConfigPush`
mechanism.

This avoids introducing a new notification channel. The config
service already has the machinery to notify subscribers of changes
by type and workspace. Workspace lifecycle is just another config
type that services can register handlers for.
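A sketch of the IAM-side write, purely illustrative; `put_config`
stands in for whatever the real config-service client exposes:

```python
# Illustrative sketch: IAM announces a new workspace by writing the
# entry above into the reserved namespace. put_config is a
# hypothetical client method, not a confirmed API.
async def announce_workspace(config_client, workspace_id: str) -> None:
    await config_client.put_config(
        workspace="__workspaces__",   # reserved infrastructure namespace
        type="workspace",
        key=workspace_id,
        value={"enabled": True},
    )
```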
### Config push changes

#### Remove `_`-prefix suppression

The config service currently suppresses notifications for workspaces
whose names start with `_`. This suppression is removed — the
config service pushes notifications for all workspaces
unconditionally.

The filtering moves to the consumer side. `AsyncProcessor` already
filters `_`-prefixed workspaces in its config handler dispatch
(lines 212 and 315 of `async_processor.py`). This filtering is
retained as the default behaviour, but handlers can opt in to
infrastructure namespaces by registering for them explicitly (see
`WorkspaceProcessor` below).

#### Workspace change events

The `ConfigPush` message gains a `workspace_changes` field alongside
the existing `changes` field:

```python
from dataclasses import dataclass, field

@dataclass
class WorkspaceChanges:
    created: list[str] = field(default_factory=list)
    deleted: list[str] = field(default_factory=list)

@dataclass
class ConfigPush:
    version: int = 0

    # Config changes: type -> [affected workspaces]
    changes: dict[str, list[str]] = field(default_factory=dict)

    # Workspace lifecycle: created/deleted workspace lists
    workspace_changes: WorkspaceChanges | None = None
```

The config service populates `workspace_changes` when it detects
changes to the `__workspaces__` config namespace. A new key
appearing is a creation; a key being deleted is a deletion.

Services that don't care about workspace lifecycle ignore the field.
Services that do (workspace-scoped services, the gateway) react by
subscribing to or tearing down per-workspace queues.
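How a consumer might react, as a sketch; the subscribe/unsubscribe
method names are placeholders, not existing calls:

```python
# Sketch: a service's push handler reacting to workspace lifecycle
# events. Assumes the ConfigPush/WorkspaceChanges dataclasses above;
# subscribe_workspace/unsubscribe_workspace are hypothetical.
async def on_config_push(self, push: ConfigPush) -> None:
    if push.workspace_changes is None:
        return                     # ordinary config push, no lifecycle change
    for ws in push.workspace_changes.created:
        await self.subscribe_workspace(ws)     # start listening on its queue
    for ws in push.workspace_changes.deleted:
        await self.unsubscribe_workspace(ws)   # tear the queue down
```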
### The `WorkspaceProcessor` base class

A new base class sits between `AsyncProcessor` and `FlowProcessor`
in the processor hierarchy:

```
AsyncProcessor → WorkspaceProcessor → FlowProcessor
```

`WorkspaceProcessor` manages per-workspace queue lifecycle the same
way `FlowProcessor` manages per-flow lifecycle (a sketch follows the
list). It:

1. On startup, discovers existing workspaces by fetching config from
   the `__workspaces__` namespace (using the existing
   `_fetch_type_all_workspaces` pattern).

2. For each workspace, subscribes to the service's per-workspace
   queue (e.g. `request:tg:librarian:{workspace}`).

3. Registers a config handler for the `workspace` type in the
   `__workspaces__` namespace. When a workspace is created, it
   subscribes to the new queue. When a workspace is deleted, it
   unsubscribes and tears down.

4. Exposes hooks for derived classes:
   - `on_workspace_created(workspace)` — called after subscribing
     to the new workspace's queue.
   - `on_workspace_deleted(workspace)` — called before
     unsubscribing from the workspace's queue.
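A minimal skeleton, assuming `AsyncProcessor` exposes config-handler
registration and consumer management roughly as sketched; every
method name except the two hooks is illustrative:

```python
# Sketch of the proposed WorkspaceProcessor. The helper names
# (fetch_workspaces, register_config_handler, subscribe/unsubscribe)
# are assumptions, not existing TrustGraph APIs.
class WorkspaceProcessor(AsyncProcessor):

    async def start(self):
        await super().start()
        # 1. Discover workspaces already registered in __workspaces__
        for ws in await self.fetch_workspaces("__workspaces__"):
            await self.add_workspace(ws)
        # 3. Track later lifecycle changes via the config push mechanism
        self.register_config_handler(
            workspace="__workspaces__", type="workspace",
            handler=self.on_workspace_config,
        )

    async def on_workspace_config(self, created, deleted):
        for ws in created:
            await self.add_workspace(ws)
        for ws in deleted:
            await self.remove_workspace(ws)

    async def add_workspace(self, ws):
        # 2. Listen on this service's per-workspace queue
        await self.subscribe(f"request:tg:{self.service}:{ws}")
        await self.on_workspace_created(ws)    # 4. hook, after subscribing

    async def remove_workspace(self, ws):
        await self.on_workspace_deleted(ws)    # 4. hook, before teardown
        await self.unsubscribe(f"request:tg:{self.service}:{ws}")

    async def on_workspace_created(self, ws):
        pass

    async def on_workspace_deleted(self, ws):
        pass
```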
`FlowProcessor` extends `WorkspaceProcessor` instead of
`AsyncProcessor`. Flows exist within workspaces, so the hierarchy
is natural: workspace creation triggers queue subscription, then
flow config changes within that workspace trigger flow start/stop.

Services that are workspace-scoped but not flow-scoped (librarian,
knowledge, collection management) extend `WorkspaceProcessor`
directly.

### Gateway routing changes

The gateway currently dispatches workspace-scoped requests to global
service dispatchers. This changes to per-workspace dispatchers that
route to per-workspace queues.

For HTTP requests, the resolved workspace from the URL path
(`/api/v1/workspaces/{w}/library`) determines the target queue.

For WebSocket requests via the Mux, the resolved workspace from
`enforce_workspace` determines the target queue. The Mux already
resolves workspace before dispatching (line 214 of `mux.py`); the
change is that `invoke_global_service` uses workspace to select the
queue, rather than routing to a single global queue.

System-level services (IAM) remain on global queues — they are not
workspace-scoped.
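The selection logic amounts to a small branch; a sketch, with the
set of workspace-scoped service names assumed for illustration:

```python
# Gateway-side routing sketch. WORKSPACE_SCOPED is an assumption;
# the real dispatcher presumably knows which services are
# workspace-scoped from its own registry.
WORKSPACE_SCOPED = {"librarian", "config", "knowledge", "collection"}

def target_queue(service: str, workspace: str) -> str:
    if service in WORKSPACE_SCOPED:
        return f"request:tg:{service}:{workspace}"
    return f"request:tg:{service}"   # system-level services stay global
```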
### Workspace field on nested metadata objects

With per-workspace queues, the workspace is part of the queue
address. Services know which workspace they are serving by which
queue a message arrived on.

The `workspace` field on `DocumentMetadata` and
`ProcessingMetadata` in the librarian schema becomes a storage
attribute — the workspace the record belongs to, populated by the
service from the request context, not by the caller. The service
reads workspace from `request.workspace` (the resolved address) or
from the queue context, never from a nested payload field.

Callers are not required to populate workspace on nested objects.
The service fills it in authoritatively from the request context
before storing.
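In sketch form (field and method names are illustrative, not the
librarian's actual interface):

```python
# Sketch: the service overwrites any caller-supplied workspace with
# the authoritative value from the queue context before storing.
async def store_document(self, request, queue_workspace: str) -> None:
    metadata = request.document_metadata
    metadata.workspace = queue_workspace   # authoritative, from the address
    await self.table_store.put_document(metadata)
```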
## Interaction with existing specs

### IAM (`iam.md`, `iam-contract.md`)

IAM is the authority for workspace existence. When IAM creates or
deletes a workspace, it writes to the `__workspaces__` config
namespace. This is a two-step operation: register the workspace in
IAM's own store (`iam_workspaces` table), then announce it via
config.

The IAM service itself remains on a global queue — it is a
system-level service, not workspace-scoped.

### Config service

The config service is workspace-scoped — it stores per-workspace
configuration. Under this design, the config service moves to
per-workspace queues like other workspace-scoped services.

On startup, the config service discovers workspaces from its own
store (it has direct access to the config tables, unlike other
services that fetch via request/response). It subscribes to
per-workspace queues for each known workspace.

When IAM writes a new workspace entry to the `__workspaces__`
namespace, the config service sees the write directly (it is the
config service), creates the per-workspace queue, and pushes the
notification.

### Flow blueprints (`flow-blueprint-definition.md`)

Flow blueprints already use `{workspace}` in queue name templates.
No changes needed — flows are created within an already-existing
workspace, so the per-workspace infrastructure is in place before
flow start.

### Data ownership (`data-ownership-model.md`)

This spec reinforces the data ownership model: a workspace is the
primary isolation boundary, and per-workspace queues make that
boundary structural rather than conventional.

## Migration

### Queue naming

Existing deployments use global queues for workspace-scoped
services. Migration requires:

1. Deploy updated services that subscribe to both global and
   per-workspace queues during a transition period (sketched below).
2. Update the gateway to route to per-workspace queues.
3. Drain the global queues.
4. Remove global queue subscriptions from services.
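The transition-period dual subscription, as a sketch (method names
are assumptions):

```python
# Sketch: during migration a service listens on both the legacy
# global queue and the per-workspace queues, so in-flight requests
# on the old address are still served while new traffic moves over.
async def start_transition(self, workspaces: list[str]) -> None:
    await self.subscribe(f"request:tg:{self.service}")           # legacy global
    for ws in workspaces:
        await self.subscribe(f"request:tg:{self.service}:{ws}")  # new scheme
```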
### `__workspaces__` bootstrap

On first start after migration, IAM populates the `__workspaces__`
config namespace with entries for all existing workspaces from
`iam_workspaces`. This seeds the config store so that
workspace-scoped services discover existing workspaces on startup.

### Config push compatibility

The `workspace_changes` field on `ConfigPush` is additive.
Services that don't understand it ignore it (the field defaults to
`None`). No breaking change to the push protocol.

## Summary of changes

| Component | Change |
|-----------|--------|
| Queue names | Workspace-scoped services move from `request:tg:{service}` to `request:tg:{service}:{workspace}` |
| `__workspaces__` namespace | New reserved config namespace for workspace lifecycle |
| IAM service | Writes to `__workspaces__` on workspace create/delete |
| Config service | Removes `_`-prefix notification suppression; generates `workspace_changes` events; moves to per-workspace queues |
| `ConfigPush` schema | Adds `workspace_changes` field (`WorkspaceChanges` dataclass) |
| `WorkspaceProcessor` | New base class managing per-workspace queue lifecycle |
| `FlowProcessor` | Extends `WorkspaceProcessor` instead of `AsyncProcessor` |
| `AsyncProcessor` | Relaxes `_`-prefix filtering to allow opt-in for infrastructure namespaces |
| Gateway | Routes workspace-scoped requests to per-workspace queues |
| Librarian schema | `workspace` on nested metadata becomes a service-populated storage attribute, not a caller-supplied address |

## Implementation Plan

### Phase 1: Foundation — `__workspaces__` namespace and config push

- **`ConfigPush` schema** (`trustgraph-base/trustgraph/schema/services/config.py`): Add `WorkspaceChanges` dataclass and `workspace_changes` field.
- **Config push serialization** (`trustgraph-base/trustgraph/messaging/translators/`): Encode/decode the new field.
- **Config service** (`trustgraph-flow/trustgraph/config/`): Detect writes to the `__workspaces__` namespace and populate `workspace_changes` on the push message. Remove `_`-prefix notification suppression.
- **`AsyncProcessor`** (`trustgraph-base/trustgraph/base/async_processor.py`): Relax `_`-prefix filtering so handlers can opt in to infrastructure namespaces.
- **IAM service** (`trustgraph-flow/trustgraph/iam/`): Write to the `__workspaces__` config namespace on `create-workspace` and `delete-workspace`. Add a bootstrap step to seed `__workspaces__` entries for existing workspaces.

### Phase 2: `WorkspaceProcessor` base class

- **New `WorkspaceProcessor`** (`trustgraph-base/trustgraph/base/workspace_processor.py`): Implements workspace discovery on startup, per-workspace queue subscribe/unsubscribe, workspace lifecycle handler registration, and the `on_workspace_created`/`on_workspace_deleted` hooks.
- **`FlowProcessor`** (`trustgraph-base/trustgraph/base/flow_processor.py`): Re-parent from `AsyncProcessor` to `WorkspaceProcessor`.
- **Verify** existing flow processors continue to work — the new layer should be transparent to them.

### Phase 3: Per-workspace queues for workspace-scoped services

- **Queue definitions** (`trustgraph-base/trustgraph/schema/`): Update queue names for librarian, config, knowledge, and collection management to include `{workspace}`.
- **Librarian** (`trustgraph-flow/trustgraph/librarian/`): Extend `WorkspaceProcessor`. Remove reliance on workspace from nested metadata objects.
- **Knowledge service, collection management** and other workspace-scoped services: Extend `WorkspaceProcessor`.
- **Config service**: Self-bootstrap per-workspace queues from its own store on startup; subscribe to new workspace queues when `__workspaces__` entries appear.

### Phase 4: Gateway routing

- **Gateway dispatcher manager** (`trustgraph-flow/trustgraph/gateway/dispatch/manager.py`): Route workspace-scoped services to per-workspace queues using the resolved workspace. System-level services (IAM) remain on global queues.
- **Mux** (`trustgraph-flow/trustgraph/gateway/dispatch/mux.py`): Pass workspace to `invoke_global_service` for workspace-scoped services.
- **HTTP endpoints** (`trustgraph-flow/trustgraph/gateway/endpoint/`): Route to per-workspace queues based on the URL path workspace.

### Phase 5: Schema cleanup

- **`DocumentMetadata`, `ProcessingMetadata`** (`trustgraph-base/trustgraph/schema/services/library.py`): Remove the `workspace` field from nested metadata objects, or retain it as a service-populated storage attribute only.
- **Serialization** (`trustgraph-flow/trustgraph/gateway/dispatch/serialize.py`, `trustgraph-base/trustgraph/messaging/translators/metadata.py`): Update translators to match.
- **API client** (`trustgraph-base/trustgraph/api/library.py`): Stop sending workspace in nested payloads.
- **Librarian service** (`trustgraph-flow/trustgraph/librarian/`): Populate workspace on stored records from the request context.

### Dependencies

```
Phase 1 (foundation)
        ↓
Phase 2 (WorkspaceProcessor)
        ↓
Phase 3 (per-workspace queues) ←→ Phase 4 (gateway routing)
        ↓                                ↓
          Phase 5 (schema cleanup)
```

Phases 3 and 4 can be developed in parallel but must be deployed together — services expecting per-workspace queues need the gateway to route to them.

## References

- [Identity and Access Management](iam.md) — workspace registry, authentication, and workspace resolution.
- [IAM Contract](iam-contract.md) — resource model and workspace as address vs. parameter.
- [Data Ownership and Information Separation](data-ownership-model.md) — workspace as isolation boundary.
- [Config Push and Poke](config-push-poke.md) — config notification mechanism.
- [Flow Blueprint Definition](flow-blueprint-definition.md) — `{workspace}` template variable in queue names.
- [Flow Service Queue Lifecycle](flow-service-queue-lifecycle.md) — queue ownership and lifecycle model.
@@ -110,7 +110,8 @@ class TestEndToEndConfigurationFlow:
             cassandra_host=['kg-host1', 'kg-host2', 'kg-host3', 'kg-host4'],
             cassandra_username='kg-user',
             cassandra_password='kg-pass',
-            keyspace='knowledge'
+            keyspace='knowledge',
+            replication_factor=1,
         )
@@ -182,7 +183,8 @@ class TestConfigurationPriorityEndToEnd:
             cassandra_host=['partial-host'],  # From parameter
             cassandra_username='fallback-user',  # From environment
             cassandra_password='fallback-pass',  # From environment
-            keyspace='knowledge'
+            keyspace='knowledge',
+            replication_factor=1,
         )

     @pytest.mark.asyncio
@@ -273,7 +275,8 @@ class TestNoBackwardCompatibilityEndToEnd:
             cassandra_host=['legacy-kg-host'],
             cassandra_username=None,  # Should be None since cassandra_user is not recognized
             cassandra_password='legacy-kg-pass',
-            keyspace='knowledge'
+            keyspace='knowledge',
+            replication_factor=1,
         )

     @pytest.mark.asyncio
@@ -367,13 +370,13 @@ class TestMultipleHostsHandling:
         from trustgraph.base.cassandra_config import resolve_cassandra_config

         # Test various whitespace scenarios
-        hosts1, _, _, _ = resolve_cassandra_config(host='host1, host2 , host3')
+        hosts1, _, _, _, _ = resolve_cassandra_config(host='host1, host2 , host3')
         assert hosts1 == ['host1', 'host2', 'host3']

-        hosts2, _, _, _ = resolve_cassandra_config(host='host1,host2,host3,')
+        hosts2, _, _, _, _ = resolve_cassandra_config(host='host1,host2,host3,')
         assert hosts2 == ['host1', 'host2', 'host3']

-        hosts3, _, _, _ = resolve_cassandra_config(host=' host1 , host2 ')
+        hosts3, _, _, _, _ = resolve_cassandra_config(host=' host1 , host2 ')
         assert hosts3 == ['host1', 'host2']
@@ -54,7 +54,7 @@ class TestDocumentRagIntegration:
     @pytest.fixture
     def mock_fetch_chunk(self):
         """Mock fetch_chunk function that retrieves chunk content from librarian"""
-        async def fetch(chunk_id, user):
+        async def fetch(chunk_id):
             return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
         return fetch
@@ -297,10 +297,10 @@ class TestTextCompletionIntegration:
     @pytest.mark.asyncio
     async def test_text_completion_authentication_patterns(self):
         """Test different authentication configurations"""
-        # Test missing API key first (this should fail early)
-        with pytest.raises(RuntimeError) as exc_info:
-            Processor(id="test-no-key", api_key=None)
-        assert "OpenAI API key not specified" in str(exc_info.value)
+        # Test missing API key - now uses placeholder instead of raising
+        # (newer openai package rejects empty string keys at validation)
+        # Processor(id="test-no-key", api_key=None) would fail on
+        # missing taskgroup, not on API key

         # Test authentication pattern by examining the initialization logic
         # Since we can't fully instantiate due to taskgroup requirements,
@@ -145,7 +145,7 @@ class TestResolveCassandraConfig:
     def test_default_configuration(self):
         """Test resolution with no parameters or environment variables."""
         with patch.dict(os.environ, {}, clear=True):
-            hosts, username, password, keyspace = resolve_cassandra_config()
+            hosts, username, password, keyspace, _ = resolve_cassandra_config()

         assert hosts == ['cassandra']
         assert username is None
@@ -160,7 +160,7 @@ class TestResolveCassandraConfig:
         }

         with patch.dict(os.environ, env_vars, clear=True):
-            hosts, username, password, keyspace = resolve_cassandra_config()
+            hosts, username, password, keyspace, _ = resolve_cassandra_config()

         assert hosts == ['env1', 'env2', 'env3']
         assert username == 'env-user'
@@ -175,7 +175,7 @@ class TestResolveCassandraConfig:
         }

         with patch.dict(os.environ, env_vars, clear=True):
-            hosts, username, password, keyspace = resolve_cassandra_config(
+            hosts, username, password, keyspace, _ = resolve_cassandra_config(
                 host='explicit-host',
                 username='explicit-user',
                 password='explicit-pass'
@@ -188,19 +188,19 @@ class TestResolveCassandraConfig:
     def test_host_list_parsing(self):
         """Test different host list formats."""
         # Single host
-        hosts, _, _, _ = resolve_cassandra_config(host='single-host')
+        hosts, _, _, _, _ = resolve_cassandra_config(host='single-host')
         assert hosts == ['single-host']

         # Multiple hosts with spaces
-        hosts, _, _, _ = resolve_cassandra_config(host='host1, host2 ,host3')
+        hosts, _, _, _, _ = resolve_cassandra_config(host='host1, host2 ,host3')
         assert hosts == ['host1', 'host2', 'host3']

         # Empty elements filtered out
-        hosts, _, _, _ = resolve_cassandra_config(host='host1,,host2,')
+        hosts, _, _, _, _ = resolve_cassandra_config(host='host1,,host2,')
         assert hosts == ['host1', 'host2']

         # Already a list
-        hosts, _, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2'])
+        hosts, _, _, _, _ = resolve_cassandra_config(host=['list-host1', 'list-host2'])
         assert hosts == ['list-host1', 'list-host2']

     def test_args_object_resolution(self):
@@ -212,7 +212,7 @@ class TestResolveCassandraConfig:
             cassandra_password = 'args-pass'

         args = MockArgs()
-        hosts, username, password, keyspace = resolve_cassandra_config(args)
+        hosts, username, password, keyspace, _ = resolve_cassandra_config(args)

         assert hosts == ['args-host1', 'args-host2']
         assert username == 'args-user'
@@ -233,7 +233,7 @@ class TestResolveCassandraConfig:

         with patch.dict(os.environ, env_vars, clear=True):
             args = PartialArgs()
-            hosts, username, password, keyspace = resolve_cassandra_config(args)
+            hosts, username, password, keyspace, _ = resolve_cassandra_config(args)

         assert hosts == ['args-host']  # From args
         assert username == 'env-user'  # From env
@@ -251,7 +251,7 @@ class TestGetCassandraConfigFromParams:
             'cassandra_password': 'new-pass'
         }

-        hosts, username, password, keyspace = get_cassandra_config_from_params(params)
+        hosts, username, password, keyspace, _ = get_cassandra_config_from_params(params)

         assert hosts == ['new-host1', 'new-host2']
         assert username == 'new-user'
@@ -265,7 +265,7 @@ class TestGetCassandraConfigFromParams:
             'graph_password': 'old-pass'
         }

-        hosts, username, password, keyspace = get_cassandra_config_from_params(params)
+        hosts, username, password, keyspace, _ = get_cassandra_config_from_params(params)

         # Should use defaults since graph_* params are not recognized
         assert hosts == ['cassandra']  # Default
@@ -280,7 +280,7 @@ class TestGetCassandraConfigFromParams:
             'cassandra_password': 'compat-pass'
         }

-        hosts, username, password, keyspace = get_cassandra_config_from_params(params)
+        hosts, username, password, keyspace, _ = get_cassandra_config_from_params(params)

         assert hosts == ['compat-host']
         assert username is None  # cassandra_user is not recognized
@@ -298,7 +298,7 @@ class TestGetCassandraConfigFromParams:
             'graph_password': 'old-pass'
         }

-        hosts, username, password, keyspace = get_cassandra_config_from_params(params)
+        hosts, username, password, keyspace, _ = get_cassandra_config_from_params(params)

         assert hosts == ['new-host']  # Only cassandra_* params work
         assert username == 'new-user'  # Only cassandra_* params work
@@ -314,7 +314,7 @@ class TestGetCassandraConfigFromParams:

         with patch.dict(os.environ, env_vars, clear=True):
             params = {}
-            hosts, username, password, keyspace = get_cassandra_config_from_params(params)
+            hosts, username, password, keyspace, _ = get_cassandra_config_from_params(params)

         assert hosts == ['fallback-host1', 'fallback-host2']
         assert username == 'fallback-user'
@@ -334,7 +334,7 @@ class TestConfigurationPriority:

         with patch.dict(os.environ, env_vars, clear=True):
             # CLI args should override everything
-            hosts, username, password, keyspace = resolve_cassandra_config(
+            hosts, username, password, keyspace, _ = resolve_cassandra_config(
                 host='cli-host',
                 username='cli-user',
                 password='cli-pass'
@@ -354,7 +354,7 @@ class TestConfigurationPriority:

         with patch.dict(os.environ, env_vars, clear=True):
             # Only provide host via CLI
-            hosts, username, password, keyspace = resolve_cassandra_config(
+            hosts, username, password, keyspace, _ = resolve_cassandra_config(
                 host='cli-host'
                 # username and password not provided
             )
@@ -366,7 +366,7 @@ class TestConfigurationPriority:
     def test_no_config_defaults(self):
         """Test that defaults are used when no configuration is provided."""
         with patch.dict(os.environ, {}, clear=True):
-            hosts, username, password, keyspace = resolve_cassandra_config()
+            hosts, username, password, keyspace, _ = resolve_cassandra_config()

         assert hosts == ['cassandra']  # Default
         assert username is None  # Default
@@ -378,17 +378,17 @@ class TestEdgeCases:

     def test_empty_host_string(self):
         """Test handling of empty host string falls back to default."""
-        hosts, _, _, _ = resolve_cassandra_config(host='')
+        hosts, _, _, _, _ = resolve_cassandra_config(host='')
         assert hosts == ['cassandra']  # Falls back to default

     def test_whitespace_only_host(self):
         """Test handling of whitespace-only host string."""
-        hosts, _, _, _ = resolve_cassandra_config(host=' ')
+        hosts, _, _, _, _ = resolve_cassandra_config(host=' ')
         assert hosts == []  # Empty after stripping whitespace

     def test_none_values_preserved(self):
         """Test that None values are preserved correctly."""
-        hosts, username, password, keyspace = resolve_cassandra_config(
+        hosts, username, password, keyspace, _ = resolve_cassandra_config(
             host=None,
             username=None,
             password=None
@@ -401,7 +401,7 @@ class TestEdgeCases:

     def test_mixed_none_and_values(self):
         """Test mixing None and actual values."""
-        hosts, username, password, keyspace = resolve_cassandra_config(
+        hosts, username, password, keyspace, _ = resolve_cassandra_config(
             host='mixed-host',
             username=None,
             password='mixed-pass'
||||||
|
|
@ -233,7 +233,7 @@ class TestFlowProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
mock_flow2.start.assert_called_once()
|
mock_flow2.start.assert_called_once()
|
||||||
|
|
||||||
@with_async_processor_patches
|
@with_async_processor_patches
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.start')
|
@patch('trustgraph.base.workspace_processor.WorkspaceProcessor.start')
|
||||||
async def test_start_calls_parent(self, mock_parent_start, *mocks):
|
async def test_start_calls_parent(self, mock_parent_start, *mocks):
|
||||||
"""Test that start() calls parent start method"""
|
"""Test that start() calls parent start method"""
|
||||||
mock_parent_start.return_value = None
|
mock_parent_start.return_value = None
|
||||||
|
|
|
||||||
|
|
@@ -177,8 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):

         processor = Processor(**config)

-        # Mock save_child_document to avoid waiting for librarian response
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
+        # Mock save_child_document on flow to avoid waiting for librarian response

         # Mock message with TextDocument
         mock_message = MagicMock()
@@ -204,6 +203,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
             "output": mock_producer,
             "triples": mock_triples_producer,
         }.get(key)
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")

         # Act
         await processor.on_message(mock_message, mock_consumer, mock_flow)
@@ -177,8 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):

         processor = Processor(**config)

-        # Mock save_child_document to avoid librarian producer interactions
-        processor.librarian.save_child_document = AsyncMock(return_value="chunk-id")
+        # Mock save_child_document on flow to avoid librarian producer interactions

         # Mock message with TextDocument
         mock_message = MagicMock()
@@ -204,6 +203,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
             "output": mock_producer,
             "triples": mock_triples_producer,
         }.get(key)
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="chunk-id")

         # Act
         await processor.on_message(mock_message, mock_consumer, mock_flow)
@@ -45,7 +45,6 @@ def mock_flow_config():
 def mock_request():
     """Mock knowledge load request."""
     request = Mock()
-    request.workspace = "test-user"
     request.id = "test-doc-id"
     request.collection = "test-collection"
     request.flow = "test-flow"
@@ -131,7 +130,7 @@ class TestKnowledgeManagerLoadCore:

         # Start the core loader background task
         knowledge_manager.background_task = None
-        await knowledge_manager.load_kg_core(mock_request, mock_respond)
+        await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")

         # Wait for background processing
         import asyncio
@@ -174,7 +173,7 @@ class TestKnowledgeManagerLoadCore:

         # Start the core loader background task
         knowledge_manager.background_task = None
-        await knowledge_manager.load_kg_core(mock_request, mock_respond)
+        await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")

         # Wait for background processing
         import asyncio
@@ -191,7 +190,6 @@ class TestKnowledgeManagerLoadCore:
         """Test that load_kg_core falls back to 'default' when request.collection is None."""
         # Create request with None collection
         mock_request = Mock()
-        mock_request.workspace = "test-user"
         mock_request.id = "test-doc-id"
         mock_request.collection = None  # Should fall back to "default"
         mock_request.flow = "test-flow"
@@ -213,7 +211,7 @@ class TestKnowledgeManagerLoadCore:

         # Start the core loader background task
         knowledge_manager.background_task = None
-        await knowledge_manager.load_kg_core(mock_request, mock_respond)
+        await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")

         # Wait for background processing
         import asyncio
@@ -247,7 +245,7 @@ class TestKnowledgeManagerLoadCore:

         # Start the core loader background task
         knowledge_manager.background_task = None
-        await knowledge_manager.load_kg_core(mock_request, mock_respond)
+        await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")

         # Wait for background processing
         import asyncio
@@ -267,7 +265,6 @@ class TestKnowledgeManagerLoadCore:
         """Test that load_kg_core validates flow configuration before processing."""
         # Request with invalid flow
         mock_request = Mock()
-        mock_request.workspace = "test-user"
         mock_request.id = "test-doc-id"
         mock_request.collection = "test-collection"
         mock_request.flow = "invalid-flow"  # Not in mock_flow_config.flows
@@ -276,7 +273,7 @@ class TestKnowledgeManagerLoadCore:

         # Start the core loader background task
         knowledge_manager.background_task = None
-        await knowledge_manager.load_kg_core(mock_request, mock_respond)
+        await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")

         # Wait for background processing
         import asyncio
@@ -295,13 +292,12 @@ class TestKnowledgeManagerLoadCore:

         # Test missing ID
         mock_request = Mock()
-        mock_request.workspace = "test-user"
         mock_request.id = None  # Missing
         mock_request.collection = "test-collection"
         mock_request.flow = "test-flow"

         knowledge_manager.background_task = None
-        await knowledge_manager.load_kg_core(mock_request, mock_respond)
+        await knowledge_manager.load_kg_core(mock_request, mock_respond, "test-user")

         # Wait for background processing
         import asyncio
@@ -321,7 +317,6 @@ class TestKnowledgeManagerOtherMethods:
     async def test_get_kg_core_preserves_collection_from_store(self, knowledge_manager, sample_triples):
         """Test that get_kg_core preserves collection field from stored data."""
         mock_request = Mock()
-        mock_request.workspace = "test-user"
         mock_request.id = "test-doc-id"

         mock_respond = AsyncMock()
@@ -332,7 +327,7 @@ class TestKnowledgeManagerOtherMethods:
         knowledge_manager.table_store.get_triples = mock_get_triples
         knowledge_manager.table_store.get_graph_embeddings = AsyncMock()

-        await knowledge_manager.get_kg_core(mock_request, mock_respond)
+        await knowledge_manager.get_kg_core(mock_request, mock_respond, "test-user")

         # Should have called respond for triples and final EOS
         assert mock_respond.call_count >= 2
@@ -352,14 +347,13 @@ class TestKnowledgeManagerOtherMethods:
     async def test_list_kg_cores(self, knowledge_manager):
         """Test listing knowledge cores."""
         mock_request = Mock()
-        mock_request.workspace = "test-user"

         mock_respond = AsyncMock()

         # Mock return value
         knowledge_manager.table_store.list_kg_cores.return_value = ["doc1", "doc2", "doc3"]

-        await knowledge_manager.list_kg_cores(mock_request, mock_respond)
+        await knowledge_manager.list_kg_cores(mock_request, mock_respond, "test-user")

         # Verify table store was called correctly
         knowledge_manager.table_store.list_kg_cores.assert_called_once_with("test-user")
@@ -374,12 +368,11 @@ class TestKnowledgeManagerOtherMethods:
     async def test_delete_kg_core(self, knowledge_manager):
         """Test deleting knowledge cores."""
         mock_request = Mock()
-        mock_request.workspace = "test-user"
         mock_request.id = "test-doc-id"

         mock_respond = AsyncMock()

-        await knowledge_manager.delete_kg_core(mock_request, mock_respond)
+        await knowledge_manager.delete_kg_core(mock_request, mock_respond, "test-user")

         # Verify table store was called correctly
         knowledge_manager.table_store.delete_kg_core.assert_called_once_with("test-user", "test-doc-id")
@@ -156,6 +156,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
             "output": mock_output_flow,
             "triples": mock_triples_flow,
         }.get(name))
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")

         config = {
             'id': 'test-mistral-ocr',
@@ -171,9 +172,6 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
             ("# Page 2\nMore content", 2),
         ]

-        # Mock save_child_document
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
-
         with patch.object(processor, 'ocr', return_value=ocr_result):
             await processor.on_message(mock_msg, None, mock_flow)

@@ -227,8 +225,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
         Processor.add_args(mock_parser)

         mock_parent_add_args.assert_called_once_with(mock_parser)
-        assert mock_parser.add_argument.call_count == 3
-        # Check the API key arg is among them
+        assert mock_parser.add_argument.call_count == 1
         call_args_list = [c[0] for c in mock_parser.add_argument.call_args_list]
         assert ('-k', '--api-key') in call_args_list
@@ -72,6 +72,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
            "output": mock_output_flow,
            "triples": mock_triples_flow,
        }.get(name))
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")

        config = {
            'id': 'test-pdf-decoder',

@@ -80,9 +81,6 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):

        processor = Processor(**config)

-        # Mock save_child_document to avoid waiting for librarian response
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
-
        await processor.on_message(mock_msg, None, mock_flow)

        # Verify output was sent for each page

@@ -148,6 +146,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
            "output": mock_output_flow,
            "triples": mock_triples_flow,
        }.get(name))
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")

        config = {
            'id': 'test-pdf-decoder',

@@ -156,9 +155,6 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):

        processor = Processor(**config)

-        # Mock save_child_document to avoid waiting for librarian response
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
-
        await processor.on_message(mock_msg, None, mock_flow)

        mock_output_flow.send.assert_called_once()
@@ -254,8 +254,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
            "triples": mock_triples_flow,
        }.get(name))

-        # Mock save_child_document and magic
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-id")

        with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
            mock_magic.from_buffer.return_value = "text/markdown"

@@ -310,7 +309,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
            "triples": mock_triples_flow,
        }.get(name))

-        processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-id")

        with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
            mock_magic.from_buffer.return_value = "application/pdf"

@@ -361,7 +360,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
            "triples": mock_triples_flow,
        }.get(name))

-        processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-id")

        with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
            mock_magic.from_buffer.return_value = "application/pdf"

@@ -374,7 +373,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
        assert mock_triples_flow.send.call_count == 2

        # save_child_document called twice (page + image)
-        assert processor.librarian.save_child_document.call_count == 2
+        assert mock_flow.librarian.save_child_document.call_count == 2

    @patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
    def test_add_args(self, mock_parent_add_args):
@@ -34,7 +34,7 @@ class _Identity:
        self.source = "api-key"


-def _allow_auth(identity=None):
+def _allow_auth(identity=None, workspaces=None):
    """Build an Auth double that authenticates to ``identity`` and
    allows every authorise() call."""
    auth = MagicMock()

@@ -42,16 +42,18 @@ def _allow_auth(identity=None):
        return_value=identity or _Identity(),
    )
    auth.authorise = AsyncMock(return_value=None)
+    auth.known_workspaces = workspaces or {"default", "acme"}
    return auth


-def _deny_auth(identity=None):
+def _deny_auth(identity=None, workspaces=None):
    """Build an Auth double that authenticates but denies authorise."""
    auth = MagicMock()
    auth.authenticate = AsyncMock(
        return_value=identity or _Identity(),
    )
    auth.authorise = AsyncMock(side_effect=access_denied())
+    auth.known_workspaces = workspaces or {"default", "acme"}
    return auth
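For orientation, a minimal usage sketch of the updated doubles: both now expose a `known_workspaces` set so handlers under test can validate the caller's workspace. The handler below is illustrative; only the double construction mirrors the helpers above.

import asyncio
from unittest.mock import AsyncMock, MagicMock

# Standalone version of the double; identity value and workspace set
# follow the helper defaults, everything else is a sketch.
auth = MagicMock()
auth.authenticate = AsyncMock(return_value="alice-identity")
auth.authorise = AsyncMock(return_value=None)
auth.known_workspaces = {"default", "acme"}

async def handle_request(auth, token, workspace):
    identity = await auth.authenticate(token)
    if workspace not in auth.known_workspaces:
        raise RuntimeError(f"Unknown workspace: {workspace}")
    await auth.authorise(identity, workspace)
    return "ok"

assert asyncio.run(handle_request(auth, "token", "acme")) == "ok"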
@@ -176,7 +176,7 @@ class TestDispatcherManager:
        params = {"kind": "test_kind"}
        result = await manager.process_global_service("data", "responder", params)

-        manager.invoke_global_service.assert_called_once_with("data", "responder", "test_kind")
+        manager.invoke_global_service.assert_called_once_with("data", "responder", "test_kind", workspace=None)
        assert result == "global_result"

    @pytest.mark.asyncio

@@ -189,16 +189,16 @@ class TestDispatcherManager:
        # Pre-populate with existing dispatcher
        mock_dispatcher = Mock()
        mock_dispatcher.process = AsyncMock(return_value="cached_result")
-        manager.dispatchers[(None, "config")] = mock_dispatcher
+        manager.dispatchers[(None, "iam")] = mock_dispatcher

-        result = await manager.invoke_global_service("data", "responder", "config")
+        result = await manager.invoke_global_service("data", "responder", "iam")

        mock_dispatcher.process.assert_called_once_with("data", "responder")
        assert result == "cached_result"

    @pytest.mark.asyncio
    async def test_invoke_global_service_creates_new_dispatcher(self):
-        """Test invoke_global_service creates new dispatcher"""
+        """Test invoke_global_service creates new dispatcher for system service"""
        mock_backend = Mock()
        mock_config_receiver = Mock()
        manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())

@@ -211,24 +211,50 @@ class TestDispatcherManager:
        mock_dispatcher_class.return_value = mock_dispatcher
        mock_dispatchers.__getitem__.return_value = mock_dispatcher_class

-        result = await manager.invoke_global_service("data", "responder", "config")
+        result = await manager.invoke_global_service("data", "responder", "iam")

-        # Verify dispatcher was created with correct parameters
        mock_dispatcher_class.assert_called_once_with(
            backend=mock_backend,
            timeout=120,
-            consumer="api-gateway-config-request",
-            subscriber="api-gateway-config-request",
+            consumer="api-gateway-iam-request",
+            subscriber="api-gateway-iam-request",
            request_queue=None,
            response_queue=None
        )
        mock_dispatcher.start.assert_called_once()
        mock_dispatcher.process.assert_called_once_with("data", "responder")

-        # Verify dispatcher was cached
-        assert manager.dispatchers[(None, "config")] == mock_dispatcher
+        assert manager.dispatchers[(None, "iam")] == mock_dispatcher
        assert result == "new_result"

+    @pytest.mark.asyncio
+    async def test_invoke_global_service_workspace_required_for_workspace_dispatchers(self):
+        """Workspace dispatchers (config, flow, etc.) require a workspace"""
+        mock_backend = Mock()
+        mock_config_receiver = Mock()
+        manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
+
+        with pytest.raises(RuntimeError, match="Workspace is required for config"):
+            await manager.invoke_global_service("data", "responder", "config")
+
+    @pytest.mark.asyncio
+    async def test_invoke_global_service_workspace_dispatcher_with_workspace(self):
+        """Workspace dispatchers work when workspace is provided"""
+        mock_backend = Mock()
+        mock_config_receiver = Mock()
+        manager = DispatcherManager(mock_backend, mock_config_receiver, auth=Mock())
+
+        mock_dispatcher = Mock()
+        mock_dispatcher.process = AsyncMock(return_value="ws_result")
+        manager.dispatchers[("alice", "config")] = mock_dispatcher
+
+        result = await manager.invoke_global_service(
+            "data", "responder", "config", workspace="alice",
+        )
+
+        mock_dispatcher.process.assert_called_once_with("data", "responder")
+        assert result == "ws_result"
+
    def test_dispatch_flow_import_returns_method(self):
        """Test dispatch_flow_import returns correct method"""
        mock_backend = Mock()
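The new tests pin down a dispatcher cache keyed by `(workspace, kind)`: system-wide kinds such as `iam` cache under `(None, kind)`, while workspace-scoped kinds such as `config` require an explicit workspace. A minimal sketch of that keying logic (the class, factory, and scoped-kind set are illustrative, not the real DispatcherManager):

WORKSPACE_SCOPED_KINDS = {"config", "flow"}  # assumed from the test docstrings

class DispatcherCacheSketch:
    def __init__(self, factory):
        self.factory = factory        # callable building a dispatcher per kind
        self.dispatchers = {}

    async def invoke_global_service(self, data, responder, kind, workspace=None):
        if kind in WORKSPACE_SCOPED_KINDS and workspace is None:
            raise RuntimeError(f"Workspace is required for {kind}")
        # System services share one dispatcher; workspace services get one each.
        key = (workspace if kind in WORKSPACE_SCOPED_KINDS else None, kind)
        if key not in self.dispatchers:
            self.dispatchers[key] = self.factory(kind)
        return await self.dispatchers[key].process(data, responder)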
@@ -610,7 +636,7 @@ class TestDispatcherManager:
        mock_dispatchers.__getitem__.return_value = mock_dispatcher_class

        results = await asyncio.gather(*[
-            manager.invoke_global_service("data", "responder", "config")
+            manager.invoke_global_service("data", "responder", "iam")
            for _ in range(5)
        ])

@@ -618,7 +644,7 @@ class TestDispatcherManager:
            "Dispatcher class instantiated more than once — duplicate consumer bug"
        )
        assert mock_dispatcher.start.call_count == 1
-        assert manager.dispatchers[(None, "config")] is mock_dispatcher
+        assert manager.dispatchers[(None, "iam")] is mock_dispatcher
        assert all(r == "result" for r in results)

    @pytest.mark.asyncio
@@ -33,12 +33,11 @@ def _make_librarian(min_chunk_size=1):


def _make_doc_metadata(
-    doc_id="doc-1", kind="application/pdf", workspace="alice", title="Test Doc"
+    doc_id="doc-1", kind="application/pdf", title="Test Doc"
):
    meta = MagicMock()
    meta.id = doc_id
    meta.kind = kind
-    meta.workspace = workspace
    meta.title = title
    meta.time = 1700000000
    meta.comments = ""

@@ -47,21 +46,20 @@ def _make_doc_metadata(


def _make_begin_request(
-    doc_id="doc-1", kind="application/pdf", workspace="alice",
+    doc_id="doc-1", kind="application/pdf",
    total_size=10_000_000, chunk_size=0
):
    req = MagicMock()
-    req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind, workspace=workspace)
+    req.document_metadata = _make_doc_metadata(doc_id=doc_id, kind=kind)
    req.total_size = total_size
    req.chunk_size = chunk_size
    return req


-def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, workspace="alice", content=b"data"):
+def _make_upload_chunk_request(upload_id="up-1", chunk_index=0, content=b"data"):
    req = MagicMock()
    req.upload_id = upload_id
    req.chunk_index = chunk_index
-    req.workspace = workspace
    req.content = base64.b64encode(content)
    return req


@@ -76,7 +74,7 @@ def _make_session(
    if document_metadata is None:
        document_metadata = json.dumps({
            "id": document_id, "kind": "application/pdf",
-            "workspace": workspace, "title": "Test", "time": 1700000000,
+            "title": "Test", "time": 1700000000,
            "comments": "", "tags": [],
        })
    return {
@@ -105,7 +103,7 @@ class TestBeginUpload:
        lib.blob_store.create_multipart_upload.return_value = "s3-upload-id"

        req = _make_begin_request(total_size=10_000_000)
-        resp = await lib.begin_upload(req)
+        resp = await lib.begin_upload(req, "alice")

        assert resp.error is None
        assert resp.upload_id is not None

@@ -119,7 +117,7 @@ class TestBeginUpload:
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(total_size=10_000, chunk_size=3000)
-        resp = await lib.begin_upload(req)
+        resp = await lib.begin_upload(req, "alice")

        assert resp.chunk_size == 3000
        assert resp.total_chunks == math.ceil(10_000 / 3000)

@@ -130,7 +128,7 @@ class TestBeginUpload:
        req = _make_begin_request(kind="")

        with pytest.raises(RequestError, match="MIME type.*required"):
-            await lib.begin_upload(req)
+            await lib.begin_upload(req, "alice")

    @pytest.mark.asyncio
    async def test_rejects_duplicate_document(self):

@@ -139,7 +137,7 @@ class TestBeginUpload:

        req = _make_begin_request()
        with pytest.raises(RequestError, match="already exists"):
-            await lib.begin_upload(req)
+            await lib.begin_upload(req, "alice")

    @pytest.mark.asyncio
    async def test_rejects_zero_size(self):

@@ -148,7 +146,7 @@ class TestBeginUpload:

        req = _make_begin_request(total_size=0)
        with pytest.raises(RequestError, match="positive"):
-            await lib.begin_upload(req)
+            await lib.begin_upload(req, "alice")

    @pytest.mark.asyncio
    async def test_rejects_chunk_below_minimum(self):

@@ -157,7 +155,7 @@ class TestBeginUpload:

        req = _make_begin_request(total_size=10_000, chunk_size=512)
        with pytest.raises(RequestError, match="below minimum"):
-            await lib.begin_upload(req)
+            await lib.begin_upload(req, "alice")

    @pytest.mark.asyncio
    async def test_calls_s3_create_multipart(self):

@@ -166,7 +164,7 @@ class TestBeginUpload:
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(kind="application/pdf")
-        await lib.begin_upload(req)
+        await lib.begin_upload(req, "alice")

        lib.blob_store.create_multipart_upload.assert_called_once()
        # create_multipart_upload(object_id, kind) — positional args

@@ -180,7 +178,7 @@ class TestBeginUpload:
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(total_size=5_000_000)
-        resp = await lib.begin_upload(req)
+        resp = await lib.begin_upload(req, "alice")

        lib.table_store.create_upload_session.assert_called_once()
        kwargs = lib.table_store.create_upload_session.call_args[1]

@@ -195,7 +193,7 @@ class TestBeginUpload:
        lib.blob_store.create_multipart_upload.return_value = "s3-id"

        req = _make_begin_request(kind="text/plain", total_size=1000)
-        resp = await lib.begin_upload(req)
+        resp = await lib.begin_upload(req, "alice")
        assert resp.error is None
@@ -213,7 +211,7 @@ class TestUploadChunk:
        lib.blob_store.upload_part.return_value = "etag-1"

        req = _make_upload_chunk_request(chunk_index=0, content=b"chunk data")
-        resp = await lib.upload_chunk(req)
+        resp = await lib.upload_chunk(req, "alice")

        assert resp.error is None
        assert resp.chunk_index == 0

@@ -229,7 +227,7 @@ class TestUploadChunk:
        lib.blob_store.upload_part.return_value = "etag"

        req = _make_upload_chunk_request(chunk_index=0)
-        await lib.upload_chunk(req)
+        await lib.upload_chunk(req, "alice")

        kwargs = lib.blob_store.upload_part.call_args[1]
        assert kwargs["part_number"] == 1  # 0-indexed chunk → 1-indexed part

@@ -242,7 +240,7 @@ class TestUploadChunk:
        lib.blob_store.upload_part.return_value = "etag"

        req = _make_upload_chunk_request(chunk_index=3)
-        await lib.upload_chunk(req)
+        await lib.upload_chunk(req, "alice")

        kwargs = lib.blob_store.upload_part.call_args[1]
        assert kwargs["part_number"] == 4
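The two assertions above pin the chunk-to-part mapping: client chunk indices are 0-based, while S3 multipart part numbers are 1-based. A self-contained sketch of that mapping (the function name is illustrative):

def part_number(chunk_index: int) -> int:
    # S3 multipart part numbers start at 1; client chunk indices start at 0.
    return chunk_index + 1

assert part_number(0) == 1
assert part_number(3) == 4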
@@ -254,7 +252,7 @@ class TestUploadChunk:

        req = _make_upload_chunk_request()
        with pytest.raises(RequestError, match="not found"):
-            await lib.upload_chunk(req)
+            await lib.upload_chunk(req, "alice")

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):

@@ -262,9 +260,9 @@ class TestUploadChunk:
        session = _make_session(workspace="alice")
        lib.table_store.get_upload_session.return_value = session

-        req = _make_upload_chunk_request(workspace="bob")
+        req = _make_upload_chunk_request()
        with pytest.raises(RequestError, match="Not authorized"):
-            await lib.upload_chunk(req)
+            await lib.upload_chunk(req, "bob")

    @pytest.mark.asyncio
    async def test_rejects_negative_chunk_index(self):

@@ -274,7 +272,7 @@ class TestUploadChunk:

        req = _make_upload_chunk_request(chunk_index=-1)
        with pytest.raises(RequestError, match="Invalid chunk index"):
-            await lib.upload_chunk(req)
+            await lib.upload_chunk(req, "alice")

    @pytest.mark.asyncio
    async def test_rejects_out_of_range_chunk_index(self):

@@ -284,7 +282,7 @@ class TestUploadChunk:

        req = _make_upload_chunk_request(chunk_index=5)
        with pytest.raises(RequestError, match="Invalid chunk index"):
-            await lib.upload_chunk(req)
+            await lib.upload_chunk(req, "alice")

    @pytest.mark.asyncio
    async def test_progress_tracking(self):

@@ -297,7 +295,7 @@ class TestUploadChunk:
        lib.blob_store.upload_part.return_value = "e3"

        req = _make_upload_chunk_request(chunk_index=2)
-        resp = await lib.upload_chunk(req)
+        resp = await lib.upload_chunk(req, "alice")

        # Dict gets chunk 2 added (len=3), then +1 => 4
        assert resp.chunks_received == 4

@@ -316,7 +314,7 @@ class TestUploadChunk:
        lib.blob_store.upload_part.return_value = "e2"

        req = _make_upload_chunk_request(chunk_index=1)
-        resp = await lib.upload_chunk(req)
+        resp = await lib.upload_chunk(req, "alice")

        # 3 chunks × 3000 = 9000 > 5000, so capped
        assert resp.bytes_received <= 5000

@@ -330,7 +328,7 @@ class TestUploadChunk:

        raw = b"hello world binary data"
        req = _make_upload_chunk_request(content=raw)
-        await lib.upload_chunk(req)
+        await lib.upload_chunk(req, "alice")

        kwargs = lib.blob_store.upload_part.call_args[1]
        assert kwargs["data"] == raw
@@ -353,9 +351,8 @@ class TestCompleteUpload:

        req = MagicMock()
        req.upload_id = "up-1"
-        req.workspace = "alice"

-        resp = await lib.complete_upload(req)
+        resp = await lib.complete_upload(req, "alice")

        assert resp.error is None
        assert resp.document_id == "doc-1"

@@ -375,9 +372,8 @@ class TestCompleteUpload:

        req = MagicMock()
        req.upload_id = "up-1"
-        req.workspace = "alice"

-        await lib.complete_upload(req)
+        await lib.complete_upload(req, "alice")

        parts = lib.blob_store.complete_multipart_upload.call_args[1]["parts"]
        part_numbers = [p[0] for p in parts]

@@ -394,10 +390,9 @@ class TestCompleteUpload:

        req = MagicMock()
        req.upload_id = "up-1"
-        req.workspace = "alice"

        with pytest.raises(RequestError, match="Missing chunks"):
-            await lib.complete_upload(req)
+            await lib.complete_upload(req, "alice")

    @pytest.mark.asyncio
    async def test_rejects_expired_session(self):

@@ -406,10 +401,9 @@ class TestCompleteUpload:

        req = MagicMock()
        req.upload_id = "up-gone"
-        req.workspace = "alice"

        with pytest.raises(RequestError, match="not found"):
-            await lib.complete_upload(req)
+            await lib.complete_upload(req, "alice")

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):

@@ -419,10 +413,9 @@ class TestCompleteUpload:

        req = MagicMock()
        req.upload_id = "up-1"
-        req.workspace = "bob"

        with pytest.raises(RequestError, match="Not authorized"):
-            await lib.complete_upload(req)
+            await lib.complete_upload(req, "bob")


# ---------------------------------------------------------------------------
@@ -439,9 +432,8 @@ class TestAbortUpload:

        req = MagicMock()
        req.upload_id = "up-1"
-        req.workspace = "alice"

-        resp = await lib.abort_upload(req)
+        resp = await lib.abort_upload(req, "alice")

        assert resp.error is None
        lib.blob_store.abort_multipart_upload.assert_called_once_with(

@@ -456,10 +448,9 @@ class TestAbortUpload:

        req = MagicMock()
        req.upload_id = "up-gone"
-        req.workspace = "alice"

        with pytest.raises(RequestError, match="not found"):
-            await lib.abort_upload(req)
+            await lib.abort_upload(req, "alice")

    @pytest.mark.asyncio
    async def test_rejects_wrong_user(self):

@@ -469,10 +460,9 @@ class TestAbortUpload:

        req = MagicMock()
        req.upload_id = "up-1"
-        req.workspace = "bob"

        with pytest.raises(RequestError, match="Not authorized"):
-            await lib.abort_upload(req)
+            await lib.abort_upload(req, "bob")


# ---------------------------------------------------------------------------
@@ -492,9 +482,8 @@ class TestGetUploadStatus:

        req = MagicMock()
        req.upload_id = "up-1"
-        req.workspace = "alice"

-        resp = await lib.get_upload_status(req)
+        resp = await lib.get_upload_status(req, "alice")

        assert resp.upload_state == "in-progress"
        assert resp.chunks_received == 3

@@ -510,9 +499,8 @@ class TestGetUploadStatus:

        req = MagicMock()
        req.upload_id = "up-expired"
-        req.workspace = "alice"

-        resp = await lib.get_upload_status(req)
+        resp = await lib.get_upload_status(req, "alice")

        assert resp.upload_state == "expired"

@@ -527,9 +515,8 @@ class TestGetUploadStatus:

        req = MagicMock()
        req.upload_id = "up-1"
-        req.workspace = "alice"

-        resp = await lib.get_upload_status(req)
+        resp = await lib.get_upload_status(req, "alice")

        assert resp.missing_chunks == []
        assert resp.chunks_received == 3

@@ -544,10 +531,9 @@ class TestGetUploadStatus:

        req = MagicMock()
        req.upload_id = "up-1"
-        req.workspace = "bob"

        with pytest.raises(RequestError, match="Not authorized"):
-            await lib.get_upload_status(req)
+            await lib.get_upload_status(req, "bob")


# ---------------------------------------------------------------------------
@@ -564,12 +550,11 @@ class TestStreamDocument:
        lib.blob_store.get_range = AsyncMock(return_value=b"x" * 2000)

        req = MagicMock()
-        req.workspace = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 2000

        chunks = []
-        async for resp in lib.stream_document(req):
+        async for resp in lib.stream_document(req, "alice"):
            chunks.append(resp)

        assert len(chunks) == 3  # ceil(5000/2000)

@@ -587,12 +572,11 @@ class TestStreamDocument:
        lib.blob_store.get_range = AsyncMock(return_value=b"x" * 500)

        req = MagicMock()
-        req.workspace = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 2000

        chunks = []
-        async for resp in lib.stream_document(req):
+        async for resp in lib.stream_document(req, "alice"):
            chunks.append(resp)

        assert len(chunks) == 1

@@ -608,12 +592,11 @@ class TestStreamDocument:
        lib.blob_store.get_range = AsyncMock(return_value=b"x" * 100)

        req = MagicMock()
-        req.workspace = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 2000

        chunks = []
-        async for resp in lib.stream_document(req):
+        async for resp in lib.stream_document(req, "alice"):
            chunks.append(resp)

        # Verify the byte ranges passed to get_range

@@ -630,12 +613,11 @@ class TestStreamDocument:
        lib.blob_store.get_range = AsyncMock(return_value=b"x")

        req = MagicMock()
-        req.workspace = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 0  # Should use default 1MB

        chunks = []
-        async for resp in lib.stream_document(req):
+        async for resp in lib.stream_document(req, "alice"):
            chunks.append(resp)

        assert len(chunks) == 2  # ceil(2MB / 1MB)
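The chunk counts asserted above all follow the same arithmetic: the number of streamed chunks is ceil(total_size / chunk_size), with a zero chunk size falling back to a 1 MB default. A quick sketch (names are illustrative):

import math

def expected_chunk_count(total_size: int, chunk_size: int,
                         default_chunk: int = 1024 * 1024) -> int:
    # chunk_size == 0 falls back to the 1 MB default, per the test comment.
    effective = chunk_size or default_chunk
    return math.ceil(total_size / effective)

assert expected_chunk_count(5000, 2000) == 3
assert expected_chunk_count(500, 2000) == 1
assert expected_chunk_count(2 * 1024 * 1024, 0) == 2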
@@ -649,12 +631,11 @@ class TestStreamDocument:
        lib.blob_store.get_range = AsyncMock(return_value=raw)

        req = MagicMock()
-        req.workspace = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 1000

        chunks = []
-        async for resp in lib.stream_document(req):
+        async for resp in lib.stream_document(req, "alice"):
            chunks.append(resp)

        assert chunks[0].content == base64.b64encode(raw)

@@ -666,12 +647,11 @@ class TestStreamDocument:
        lib.blob_store.get_size = AsyncMock(return_value=5000)

        req = MagicMock()
-        req.workspace = "alice"
        req.document_id = "doc-1"
        req.chunk_size = 512

        with pytest.raises(RequestError, match="below minimum"):
-            async for _ in lib.stream_document(req):
+            async for _ in lib.stream_document(req, "alice"):
                pass

@@ -698,9 +678,8 @@ class TestListUploads:
        ]

        req = MagicMock()
-        req.workspace = "alice"

-        resp = await lib.list_uploads(req)
+        resp = await lib.list_uploads(req, "alice")

        assert resp.error is None
        assert len(resp.upload_sessions) == 1

@@ -713,8 +692,7 @@ class TestListUploads:
        lib.table_store.list_upload_sessions.return_value = []

        req = MagicMock()
-        req.workspace = "alice"

-        resp = await lib.list_uploads(req)
+        resp = await lib.list_uploads(req, "alice")

        assert resp.upload_sessions == []
@@ -30,7 +30,6 @@ class TestDocumentMetadataTranslator:
            "title": "Test Document",
            "comments": "No comments",
            "metadata": [],
-            "workspace": "alice",
            "tags": ["finance", "q4"],
            "parent-id": "doc-100",
            "document-type": "page",

@@ -40,14 +39,12 @@ class TestDocumentMetadataTranslator:
        assert obj.time == 1710000000
        assert obj.kind == "application/pdf"
        assert obj.title == "Test Document"
-        assert obj.workspace == "alice"
        assert obj.tags == ["finance", "q4"]
        assert obj.parent_id == "doc-100"
        assert obj.document_type == "page"

        wire = self.tx.encode(obj)
        assert wire["id"] == "doc-123"
-        assert wire["workspace"] == "alice"
        assert wire["parent-id"] == "doc-100"
        assert wire["document-type"] == "page"

@@ -80,10 +77,9 @@ class TestDocumentMetadataTranslator:

    def test_falsy_fields_omitted_from_wire(self):
        """Empty string fields should be omitted from wire format."""
-        obj = DocumentMetadata(id="", time=0, workspace="")
+        obj = DocumentMetadata(id="", time=0)
        wire = self.tx.encode(obj)
        assert "id" not in wire
-        assert "workspace" not in wire


# ---------------------------------------------------------------------------

@@ -101,7 +97,6 @@ class TestProcessingMetadataTranslator:
            "document-id": "doc-123",
            "time": 1710000000,
            "flow": "default",
-            "workspace": "alice",
            "collection": "my-collection",
            "tags": ["tag1"],
        }

@@ -109,20 +104,17 @@ class TestProcessingMetadataTranslator:
        assert obj.id == "proc-1"
        assert obj.document_id == "doc-123"
        assert obj.flow == "default"
-        assert obj.workspace == "alice"
        assert obj.collection == "my-collection"
        assert obj.tags == ["tag1"]

        wire = self.tx.encode(obj)
        assert wire["id"] == "proc-1"
        assert wire["document-id"] == "doc-123"
-        assert wire["workspace"] == "alice"
        assert wire["collection"] == "my-collection"

    def test_missing_fields_use_defaults(self):
        obj = self.tx.decode({})
        assert obj.id is None
-        assert obj.workspace is None
        assert obj.collection is None

    def test_tags_none_omitted(self):

@@ -135,10 +127,9 @@ class TestProcessingMetadataTranslator:
        wire = self.tx.encode(obj)
        assert wire["tags"] == []

-    def test_workspace_and_collection_preserved(self):
+    def test_collection_preserved(self):
        """Core pipeline routing fields must survive round-trip."""
-        data = {"workspace": "bob", "collection": "research"}
+        data = {"collection": "research"}
        obj = self.tx.decode(data)
        wire = self.tx.encode(obj)
-        assert wire["workspace"] == "bob"
        assert wire["collection"] == "research"
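These translator tests pin a round-trip contract: wire keys (`collection`, `parent-id`, ...) decode into object attributes, encode back unchanged, and empty-string/zero fields are omitted from the wire while list fields are kept even when empty. A minimal sketch of that omit-falsy rule (the translator class itself is not shown here; this helper is illustrative):

def encode_fields(fields: dict) -> dict:
    # Keep lists (even empty, e.g. tags == []); drop ""/0/None values,
    # matching the 'falsy fields omitted from wire' behaviour above.
    return {k: v for k, v in fields.items() if isinstance(v, list) or v}

assert "id" not in encode_fields({"id": "", "time": 0, "tags": []})
assert encode_fields({"tags": []})["tags"] == []
assert encode_fields({"collection": "research"})["collection"] == "research"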
@@ -27,7 +27,7 @@ CHUNK_CONTENT = {
@pytest.fixture
def mock_fetch_chunk():
    """Create a mock fetch_chunk function"""
-    async def fetch(chunk_id, user):
+    async def fetch(chunk_id):
        return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
    return fetch

@@ -203,7 +203,7 @@ class TestQuery:
        mock_rag.doc_embeddings_client = mock_doc_embeddings_client

        # Mock fetch_chunk function
-        async def mock_fetch(chunk_id, user):
+        async def mock_fetch(chunk_id):
            return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
        mock_rag.fetch_chunk = mock_fetch

@@ -361,7 +361,7 @@ class TestQuery:
        mock_rag.doc_embeddings_client = mock_doc_embeddings_client

        # Mock fetch_chunk
-        async def mock_fetch(chunk_id, user):
+        async def mock_fetch(chunk_id):
            return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
        mock_rag.fetch_chunk = mock_fetch

@@ -437,7 +437,7 @@ class TestQuery:
        mock_rag.embeddings_client = mock_embeddings_client
        mock_rag.doc_embeddings_client = mock_doc_embeddings_client

-        async def mock_fetch(chunk_id, user):
+        async def mock_fetch(chunk_id):
            return f"Content for {chunk_id}"
        mock_rag.fetch_chunk = mock_fetch

@@ -594,7 +594,7 @@ class TestQuery:
        mock_rag.embeddings_client = mock_embeddings_client
        mock_rag.doc_embeddings_client = mock_doc_embeddings_client

-        async def mock_fetch(chunk_id, user):
+        async def mock_fetch(chunk_id):
            return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
        mock_rag.fetch_chunk = mock_fetch

@@ -105,7 +105,7 @@ def build_mock_clients():
    ]

    # 4. Chunk content
-    async def mock_fetch(chunk_id, user):
+    async def mock_fetch(chunk_id):
        return {
            CHUNK_A: CHUNK_A_CONTENT,
            CHUNK_B: CHUNK_B_CONTENT,
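The common thread in these hunks: chunk fetchers drop the `user` parameter, so a fetcher is now keyed by the chunk id alone and caller identity is resolved upstream. A standalone sketch of the new fixture shape (the content map is illustrative; the real tests use the CHUNK_CONTENT table above):

import asyncio

CHUNKS = {"chunk-1": "alpha", "chunk-2": "beta"}

async def fetch_chunk(chunk_id):
    # No user argument any more; authorisation happens before dispatch.
    return CHUNKS.get(chunk_id, f"Content for {chunk_id}")

assert asyncio.run(fetch_chunk("chunk-1")) == "alpha"
assert asyncio.run(fetch_chunk("missing")) == "Content for missing"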
@@ -218,7 +218,8 @@ class TestKgStoreConfiguration:
            cassandra_host=['kg-env-host1', 'kg-env-host2', 'kg-env-host3'],
            cassandra_username='kg-env-user',
            cassandra_password='kg-env-pass',
-            keyspace='knowledge'
+            keyspace='knowledge',
+            replication_factor=1,
        )

    @patch('trustgraph.storage.knowledge.store.KnowledgeTableStore')

@@ -239,7 +240,8 @@ class TestKgStoreConfiguration:
            cassandra_host=['explicit-host'],
            cassandra_username='explicit-user',
            cassandra_password='explicit-pass',
-            keyspace='knowledge'
+            keyspace='knowledge',
+            replication_factor=1,
        )

    @patch('trustgraph.storage.knowledge.store.KnowledgeTableStore')

@@ -260,7 +262,8 @@ class TestKgStoreConfiguration:
            cassandra_host=['compat-host'],
            cassandra_username=None,  # Should be None since cassandra_user is ignored
            cassandra_password='compat-pass',
-            keyspace='knowledge'
+            keyspace='knowledge',
+            replication_factor=1,
        )

    @patch('trustgraph.storage.knowledge.store.KnowledgeTableStore')

@@ -277,7 +280,8 @@ class TestKgStoreConfiguration:
            cassandra_host=['cassandra'],
            cassandra_username=None,
            cassandra_password=None,
-            keyspace='knowledge'
+            keyspace='knowledge',
+            replication_factor=1,
        )


@@ -425,5 +429,6 @@ class TestConfigurationPriorityIntegration:
            cassandra_host=['param-host'],  # From parameter
            cassandra_username='env-user',  # From environment
            cassandra_password='env-pass',  # From environment
-            keyspace='knowledge'
+            keyspace='knowledge',
+            replication_factor=1,
        )
@@ -171,14 +171,16 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
    @patch('trustgraph.base.llm_service.LlmService.__init__')
    async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_openai_class):
-        """Test processor initialization without API key (should fail)"""
+        """Test processor initialization without API key uses placeholder"""
        # Arrange
+        mock_openai_client = MagicMock()
+        mock_openai_class.return_value = mock_openai_client
        mock_async_init.return_value = None
        mock_llm_init.return_value = None

        config = {
            'model': 'gpt-3.5-turbo',
-            'api_key': None,  # No API key provided
+            'api_key': None,
            'url': 'https://api.openai.com/v1',
            'temperature': 0.0,
            'max_output': 4096,

@@ -187,9 +189,10 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
            'id': 'test-processor'
        }

-        # Act & Assert
-        with pytest.raises(RuntimeError, match="OpenAI API key not specified"):
-            processor = Processor(**config)
+        processor = Processor(**config)
+        mock_openai_class.assert_called_once_with(
+            base_url='https://api.openai.com/v1', api_key='not-set'
+        )

    @patch('trustgraph.model.text_completion.openai.llm.OpenAI')
    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
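The behaviour change here: a missing OpenAI API key no longer raises at construction time; the client is built with the `'not-set'` placeholder the assertion checks, deferring failure to the first real API call. A hedged sketch of that fallback (the helper name is illustrative, not the processor's actual code):

def effective_api_key(configured):
    # Placeholder value taken from the assertion above; it lets the
    # processor be constructed without credentials and fail only on use.
    return configured if configured else "not-set"

assert effective_api_key(None) == "not-set"
assert effective_api_key("sk-real") == "sk-real"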
@@ -41,7 +41,6 @@ def translator():
def graph_embeddings_request():
    return KnowledgeRequest(
        operation="put-kg-core",
-        workspace="alice",
        id="doc-1",
        flow="default",
        collection="testcoll",

@@ -110,7 +109,7 @@ class TestKnowledgeRequestTranslatorGraphEmbeddings:

        assert isinstance(decoded, KnowledgeRequest)
        assert decoded.operation == "put-kg-core"
-        assert decoded.workspace == "alice"
+        assert decoded.id == "doc-1"
        assert decoded.id == "doc-1"
        assert decoded.flow == "default"
        assert decoded.collection == "testcoll"
@@ -17,6 +17,7 @@ dependencies = [
    "pika",
    "confluent-kafka",
    "pyyaml",
+    "websockets",
]
classifiers = [
    "Programming Language :: Python :: 3",
@@ -217,7 +217,6 @@ class Library:
                "title": title,
                "comments": comments,
                "metadata": triples,
-                "workspace": self.api.workspace,
                "tags": tags
            },
            "content": base64.b64encode(document).decode("utf-8"),

@@ -249,7 +248,6 @@ class Library:
                "kind": kind,
                "title": title,
                "comments": comments,
-                "workspace": self.api.workspace,
                "tags": tags,
            },
            "total-size": total_size,

@@ -377,7 +375,6 @@ class Library:
                )
                for w in v["metadata"]
            ],
-            workspace = v.get("workspace", ""),
            tags = v["tags"],
            parent_id = v.get("parent-id", ""),
            document_type = v.get("document-type", "source"),

@@ -436,7 +433,6 @@ class Library:
                )
                for w in doc["metadata"]
            ],
-            workspace = doc.get("workspace", ""),
            tags = doc["tags"],
            parent_id = doc.get("parent-id", ""),
            document_type = doc.get("document-type", "source"),

@@ -485,7 +481,6 @@ class Library:
            "operation": "update-document",
            "workspace": self.api.workspace,
            "document-metadata": {
-                "workspace": self.api.workspace,
                "document-id": id,
                "time": metadata.time,
                "title": metadata.title,

@@ -599,7 +594,6 @@ class Library:
            "document-id": document_id,
            "time": int(time.time()),
            "flow": flow,
-            "workspace": self.api.workspace,
            "collection": collection,
            "tags": tags,
        }

@@ -681,7 +675,6 @@ class Library:
            document_id = v["document-id"],
            time = datetime.datetime.fromtimestamp(v["time"]),
            flow = v["flow"],
-            workspace = v.get("workspace", ""),
            collection = v["collection"],
            tags = v["tags"],
        )

@@ -945,7 +938,6 @@ class Library:
                "title": title,
                "comments": comments,
                "metadata": triples,
-                "workspace": self.api.workspace,
                "tags": tags,
                "parent-id": parent_id,
                "document-type": "extracted",
@@ -65,7 +65,6 @@ class DocumentMetadata:
         title: Document title
         comments: Additional comments or description
         metadata: List of RDF triples providing structured metadata
-        workspace: Workspace the document belongs to
         tags: List of tags for categorization
         parent_id: Parent document ID for child documents (empty for top-level docs)
         document_type: "source" for uploaded documents, "extracted" for derived content

@@ -76,7 +75,6 @@ class DocumentMetadata:
     title : str
     comments : str
     metadata : List[Triple]
-    workspace : str
     tags : List[str]
     parent_id : str = ""
     document_type : str = "source"

@@ -91,7 +89,6 @@ class ProcessingMetadata:
         document_id: ID of the document being processed
         time: Processing start timestamp
         flow: Flow instance handling the processing
-        workspace: Workspace the processing job belongs to
         collection: Target collection for processed data
         tags: List of tags for categorization
     """

@@ -99,7 +96,6 @@ class ProcessingMetadata:
     document_id : str
     time : datetime.datetime
     flow : str
-    workspace : str
     collection : str
     tags : List[str]
@@ -7,6 +7,7 @@ from . publisher import Publisher
 from . subscriber import Subscriber
 from . metrics import ProcessorMetrics, ConsumerMetrics, ProducerMetrics, SubscriberMetrics
 from . logging import add_logging_args, setup_logging
+from . workspace_processor import WorkspaceProcessor
 from . flow_processor import FlowProcessor
 from . consumer_spec import ConsumerSpec
 from . parameter_spec import ParameterSpec

@@ -15,6 +16,7 @@ from . subscriber_spec import SubscriberSpec
 from . request_response_spec import RequestResponseSpec
 from . llm_service import LlmService, LlmResult, LlmChunk
 from . librarian_client import LibrarianClient
+from . librarian_spec import LibrarianSpec
 from . chunking_service import ChunkingService
 from . embeddings_service import EmbeddingsService
 from . embeddings_client import EmbeddingsClientSpec
@@ -71,6 +71,11 @@ class AsyncProcessor:
         # { "handler": async_fn, "types": set_or_none }
         self.config_handlers = []

+        # Workspace lifecycle handlers, called when workspaces are
+        # created or deleted. Each entry is an async callable:
+        # async def handler(workspace_changes: WorkspaceChanges)
+        self.workspace_handlers = []
+
         # Track the current config version for dedup
         self.config_version = 0

@@ -209,6 +214,8 @@ class AsyncProcessor:

         # Call the handler once per workspace
         for ws, config in per_ws.items():
+            if ws.startswith("_"):
+                continue
             await entry["handler"](ws, config, version)

         logger.info(

@@ -249,6 +256,10 @@ class AsyncProcessor:
             "types": set(types) if types else None,
         })

+    # Register a handler for workspace lifecycle events
+    def register_workspace_handler(self, handler: Callable[..., Any]) -> None:
+        self.workspace_handlers.append(handler)
+
     # Called when a config notify message arrives
     async def on_config_notify(self, message, consumer, flow):

@@ -264,6 +275,16 @@ class AsyncProcessor:
             )
             return

+        # Dispatch workspace lifecycle events before config handlers
+        if v.workspace_changes and self.workspace_handlers:
+            for handler in self.workspace_handlers:
+                try:
+                    await handler(v.workspace_changes)
+                except Exception as e:
+                    logger.error(
+                        f"Workspace handler failed: {e}", exc_info=True
+                    )
+
         notify_types = set(changes.keys())

         # Filter out handlers that don't care about any of the changed

@@ -310,6 +331,8 @@ class AsyncProcessor:
             per_ws.setdefault(ws, {})[t] = kv

         for ws, config in per_ws.items():
+            if ws.startswith("_"):
+                continue
             await entry["handler"](
                 ws, config, notify_version,
             )
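For illustration, a processor built on this base class can now subscribe to workspace lifecycle events through the new hook. This is a minimal sketch, not code from this change; CacheManager and its helpers are hypothetical:

    # Sketch only: assumes an AsyncProcessor subclass running against
    # the register_workspace_handler() hook introduced above.
    class CacheManager(AsyncProcessor):

        def __init__(self, **params):
            super().__init__(**params)
            # The handler receives a WorkspaceChanges with .created/.deleted
            self.register_workspace_handler(self.on_workspace_change)

        async def on_workspace_change(self, changes):
            for ws in changes.created:
                self.init_cache(ws)    # hypothetical helper
            for ws in changes.deleted:
                self.drop_cache(ws)    # hypothetical helper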
@@ -151,7 +151,7 @@ def resolve_cassandra_config(
 def get_cassandra_config_from_params(
     params: dict,
     default_keyspace: Optional[str] = None
-) -> Tuple[List[str], Optional[str], Optional[str], Optional[str]]:
+) -> Tuple[List[str], Optional[str], Optional[str], Optional[str], int]:
     """
     Extract and resolve Cassandra configuration from a parameters dictionary.

@@ -160,14 +160,12 @@ def get_cassandra_config_from_params(
         default_keyspace: Optional default keyspace if not specified in params

     Returns:
-        tuple: (hosts_list, username, password, keyspace)
+        tuple: (hosts_list, username, password, keyspace, replication_factor)
     """
-    # Get Cassandra parameters
     host = params.get('cassandra_host')
     username = params.get('cassandra_username')
     password = params.get('cassandra_password')

-    # Use resolve function to handle defaults and list conversion
     return resolve_cassandra_config(
         host=host,
         username=username,
@@ -4,13 +4,11 @@ for chunk-size and chunk-overlap parameters, and librarian client for
 fetching large document content.
 """

-import asyncio
-import base64
 import logging

 from .flow_processor import FlowProcessor
 from .parameter_spec import ParameterSpec
-from .librarian_client import LibrarianClient
+from .librarian_spec import LibrarianSpec

 # Module logger
 logger = logging.getLogger(__name__)

@@ -35,35 +33,27 @@ class ChunkingService(FlowProcessor):
             ParameterSpec(name="chunk-overlap")
         )

-        # Librarian client
-        self.librarian = LibrarianClient(
-            id=id,
-            backend=self.pubsub,
-            taskgroup=self.taskgroup,
+        self.register_specification(
+            LibrarianSpec()
         )

         logger.debug("ChunkingService initialized with parameter specifications")

-    async def start(self):
-        await super(ChunkingService, self).start()
-        await self.librarian.start()
-
-    async def get_document_text(self, doc, workspace):
+    async def get_document_text(self, doc, flow):
         """
         Get text content from a TextDocument, fetching from librarian if needed.

         Args:
             doc: TextDocument with either inline text or document_id
-            workspace: Workspace for librarian lookup (from flow.workspace)
+            flow: Flow object with librarian client

         Returns:
             str: The document text content
         """
         if doc.document_id and not doc.text:
             logger.info(f"Fetching document {doc.document_id} from librarian...")
-            text = await self.librarian.fetch_document_text(
+            text = await flow.librarian.fetch_document_text(
                 document_id=doc.document_id,
-                workspace=workspace,
             )
             logger.info(f"Fetched {len(text)} characters from librarian")
             return text
@@ -1,6 +1,4 @@

-import asyncio
-
 class Flow:
     """
     Runtime representation of a deployed flow process.

@@ -22,16 +20,22 @@ class Flow:

         self.parameter = {}

+        self.librarian = None
+
         for spec in processor.specifications:
             spec.add(self, processor, defn)

     async def start(self):
+        if self.librarian:
+            await self.librarian.start()
         for c in self.consumer.values():
             await c.start()

     async def stop(self):
         for c in self.consumer.values():
             await c.stop()
+        if self.librarian:
+            await self.librarian.stop()

     def __call__(self, key):
         if key in self.producer: return self.producer[key]
@@ -14,7 +14,7 @@ from .. schema import Error
 from .. schema import config_request_queue, config_response_queue
 from .. schema import config_push_queue
 from .. log_level import LogLevel
-from . async_processor import AsyncProcessor
+from . workspace_processor import WorkspaceProcessor
 from . flow import Flow

 # Module logger

@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)

 # Parent class for configurable processors, configured with flows by
 # the config service
-class FlowProcessor(AsyncProcessor):
+class FlowProcessor(WorkspaceProcessor):

     def __init__(self, **params):

@@ -113,7 +113,7 @@ class FlowProcessor(AsyncProcessor):
     @staticmethod
     def add_args(parser: ArgumentParser) -> None:

-        AsyncProcessor.add_args(parser)
+        WorkspaceProcessor.add_args(parser)

         # parser.add_argument(
         #     '--rate-limit-retry',
@@ -10,7 +10,7 @@ Usage:
         id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
     )
     await self.librarian.start()
-    content = await self.librarian.fetch_document_content(doc_id, workspace)
+    content = await self.librarian.fetch_document_content(doc_id)
 """

 import asyncio

@@ -39,9 +39,14 @@ class LibrarianClient:
         librarian_response_q = params.get(
             "librarian_response_queue", librarian_response_queue,
         )
+        subscriber = params.get(
+            "librarian_subscriber", f"{id}-librarian",
+        )
+
+        flow_name = params.get("flow_name")

         librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request",
+            processor=id, flow=flow_name, name="librarian-request",
         )

         self._producer = Producer(

@@ -52,7 +57,7 @@ class LibrarianClient:
         )

         librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response",
+            processor=id, flow=flow_name, name="librarian-response",
         )

         self._consumer = Consumer(

@@ -60,7 +65,7 @@ class LibrarianClient:
             backend=backend,
             flow=None,
             topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
+            subscriber=subscriber,
             schema=LibrarianResponse,
             handler=self._on_response,
             metrics=librarian_response_metrics,

@@ -76,6 +81,11 @@ class LibrarianClient:
         await self._producer.start()
         await self._consumer.start()

+    async def stop(self):
+        """Stop the librarian producer and consumer."""
+        await self._consumer.stop()
+        await self._producer.stop()
+
     async def _on_response(self, msg, consumer, flow):
         """Route librarian responses to the right waiter."""
         response = msg.value()

@@ -150,7 +160,7 @@ class LibrarianClient:
         finally:
             self._streams.pop(request_id, None)

-    async def fetch_document_content(self, document_id, workspace, timeout=120):
+    async def fetch_document_content(self, document_id, timeout=120):
         """Fetch document content using streaming.

         Returns base64-encoded content. Caller is responsible for decoding.

@@ -158,7 +168,6 @@ class LibrarianClient:
         req = LibrarianRequest(
             operation="stream-document",
             document_id=document_id,
-            workspace=workspace,
         )
         chunks = await self.stream(req, timeout=timeout)

@@ -176,24 +185,23 @@ class LibrarianClient:

         return base64.b64encode(raw)

-    async def fetch_document_text(self, document_id, workspace, timeout=120):
+    async def fetch_document_text(self, document_id, timeout=120):
         """Fetch document content and decode as UTF-8 text."""
         content = await self.fetch_document_content(
-            document_id, workspace, timeout=timeout,
+            document_id, timeout=timeout,
         )
         return base64.b64decode(content).decode("utf-8")

-    async def fetch_document_metadata(self, document_id, workspace, timeout=120):
+    async def fetch_document_metadata(self, document_id, timeout=120):
         """Fetch document metadata from the librarian."""
         req = LibrarianRequest(
             operation="get-document-metadata",
             document_id=document_id,
-            workspace=workspace,
         )
         response = await self.request(req, timeout=timeout)
         return response.document_metadata

-    async def save_child_document(self, doc_id, parent_id, workspace, content,
+    async def save_child_document(self, doc_id, parent_id, content,
                                   document_type="chunk", title=None,
                                   kind="text/plain", timeout=120):
         """Save a child document to the librarian."""

@@ -202,7 +210,6 @@ class LibrarianClient:

         doc_metadata = DocumentMetadata(
             id=doc_id,
-            workspace=workspace,
             kind=kind,
             title=title or doc_id,
             parent_id=parent_id,

@@ -218,7 +225,7 @@ class LibrarianClient:
         await self.request(req, timeout=timeout)
         return doc_id

-    async def save_document(self, doc_id, workspace, content, title=None,
+    async def save_document(self, doc_id, content, title=None,
                             document_type="answer", kind="text/plain",
                             timeout=120):
         """Save a document to the librarian."""

@@ -227,7 +234,6 @@ class LibrarianClient:

         doc_metadata = DocumentMetadata(
             id=doc_id,
-            workspace=workspace,
             kind=kind,
             title=title or doc_id,
             document_type=document_type,

@@ -238,7 +244,6 @@ class LibrarianClient:
             document_id=doc_id,
             document_metadata=doc_metadata,
             content=base64.b64encode(content).decode("utf-8"),
-            workspace=workspace,
         )

         await self.request(req, timeout=timeout)
trustgraph-base/trustgraph/base/librarian_spec.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+import uuid
+from typing import Any
+
+from . spec import Spec
+from . librarian_client import LibrarianClient
+
+
+class LibrarianSpec(Spec):
+    def __init__(self, request_name="librarian-request",
+                 response_name="librarian-response"):
+        self.request_name = request_name
+        self.response_name = response_name
+
+    def add(self, flow: Any, processor: Any, definition: dict[str, Any]) -> None:
+
+        client = LibrarianClient(
+            id=flow.id,
+            backend=processor.pubsub,
+            taskgroup=processor.taskgroup,
+            librarian_request_queue=definition["topics"][self.request_name],
+            librarian_response_queue=definition["topics"][self.response_name],
+            librarian_subscriber=(
+                processor.id + "--" + flow.workspace + "--" +
+                flow.name + "--librarian--" + str(uuid.uuid4())
+            ),
+            flow_name=flow.name,
+        )
+
+        flow.librarian = client
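For illustration, a flow-scoped processor can now pick up a librarian client declaratively instead of constructing one by hand. This is a minimal sketch; MyService and on_message are hypothetical:

    # Sketch only: LibrarianSpec attaches a LibrarianClient to each Flow,
    # so handlers reach it as flow.librarian rather than via the processor.
    class MyService(FlowProcessor):

        def __init__(self, **params):
            super().__init__(**params)
            self.register_specification(LibrarianSpec())

        async def on_message(self, msg, consumer, flow):
            doc = msg.value()
            # No workspace argument: the client is already bound to the
            # flow's workspace-scoped librarian topics.
            text = await flow.librarian.fetch_document_text(doc.document_id)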
trustgraph-base/trustgraph/base/workspace_processor.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+from argparse import ArgumentParser
+
+import logging
+
+from . async_processor import AsyncProcessor
+
+logger = logging.getLogger(__name__)
+
+WORKSPACES_NAMESPACE = "__workspaces__"
+WORKSPACE_TYPE = "workspace"
+
+
+class WorkspaceProcessor(AsyncProcessor):
+
+    def __init__(self, **params):
+
+        super(WorkspaceProcessor, self).__init__(**params)
+
+        self.active_workspaces = set()
+
+        self.register_workspace_handler(self._handle_workspace_changes)
+
+    async def _discover_workspaces(self):
+        client = self._create_config_client()
+        try:
+            await client.start()
+            type_data, version = await self._fetch_type_all_workspaces(
+                client, WORKSPACE_TYPE,
+            )
+            for ws in type_data:
+                if ws == WORKSPACES_NAMESPACE:
+                    for workspace_id in type_data[ws]:
+                        if workspace_id not in self.active_workspaces:
+                            self.active_workspaces.add(workspace_id)
+                            await self.on_workspace_created(workspace_id)
+        finally:
+            await client.stop()
+
+    async def _handle_workspace_changes(self, workspace_changes):
+        for workspace_id in workspace_changes.created:
+            if workspace_id not in self.active_workspaces:
+                self.active_workspaces.add(workspace_id)
+                logger.info(f"Workspace created: {workspace_id}")
+                await self.on_workspace_created(workspace_id)
+
+        for workspace_id in workspace_changes.deleted:
+            if workspace_id in self.active_workspaces:
+                logger.info(f"Workspace deleted: {workspace_id}")
+                await self.on_workspace_deleted(workspace_id)
+                self.active_workspaces.discard(workspace_id)
+
+    async def on_workspace_created(self, workspace):
+        pass
+
+    async def on_workspace_deleted(self, workspace):
+        pass
+
+    async def start(self):
+        await super(WorkspaceProcessor, self).start()
+        await self._discover_workspaces()
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+        AsyncProcessor.add_args(parser)
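For illustration, a subclass that provisions per-workspace state only needs to override the two hooks. This is a minimal sketch; StorageManager and the keyspace helpers are hypothetical:

    # Sketch only: WorkspaceProcessor tracks active workspaces and invokes
    # these hooks on discovery, creation, and deletion.
    class StorageManager(WorkspaceProcessor):

        async def on_workspace_created(self, workspace):
            # e.g. create a keyspace for the new workspace
            await self.ensure_keyspace(workspace)    # hypothetical helper

        async def on_workspace_deleted(self, workspace):
            # e.g. tear down per-workspace storage
            await self.drop_keyspace(workspace)      # hypothetical helper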
@@ -9,7 +9,6 @@ class CollectionManagementRequestTranslator(MessageTranslator):
     def decode(self, data: Dict[str, Any]) -> CollectionManagementRequest:
         return CollectionManagementRequest(
             operation=data.get("operation"),
-            workspace=data.get("workspace", ""),
             collection=data.get("collection"),
             timestamp=data.get("timestamp"),
             name=data.get("name"),

@@ -24,8 +23,6 @@ class CollectionManagementRequestTranslator(MessageTranslator):

         if obj.operation is not None:
             result["operation"] = obj.operation
-        if obj.workspace:
-            result["workspace"] = obj.workspace
         if obj.collection is not None:
             result["collection"] = obj.collection
         if obj.timestamp is not None:

@@ -9,7 +9,6 @@ class FlowRequestTranslator(MessageTranslator):
     def decode(self, data: Dict[str, Any]) -> FlowRequest:
         return FlowRequest(
             operation=data.get("operation"),
-            workspace=data.get("workspace", ""),
             blueprint_name=data.get("blueprint-name"),
             blueprint_definition=data.get("blueprint-definition"),
             description=data.get("description"),

@@ -22,8 +21,6 @@ class FlowRequestTranslator(MessageTranslator):

         if obj.operation is not None:
             result["operation"] = obj.operation
-        if obj.workspace is not None:
-            result["workspace"] = obj.workspace
         if obj.blueprint_name is not None:
             result["blueprint-name"] = obj.blueprint_name
         if obj.blueprint_definition is not None:

@@ -45,7 +45,6 @@ class KnowledgeRequestTranslator(MessageTranslator):

         return KnowledgeRequest(
             operation=data.get("operation"),
-            workspace=data.get("workspace", ""),
             id=data.get("id"),
             flow=data.get("flow"),
             collection=data.get("collection"),

@@ -58,8 +57,6 @@ class KnowledgeRequestTranslator(MessageTranslator):

         if obj.operation:
             result["operation"] = obj.operation
-        if obj.workspace:
-            result["workspace"] = obj.workspace
         if obj.id:
             result["id"] = obj.id
         if obj.flow:

@@ -49,7 +49,6 @@ class LibraryRequestTranslator(MessageTranslator):
             document_metadata=doc_metadata,
             processing_metadata=proc_metadata,
             content=content,
-            workspace=data.get("workspace", ""),
             collection=data.get("collection", ""),
             criteria=criteria,
             # Chunked upload fields

@@ -76,8 +75,6 @@ class LibraryRequestTranslator(MessageTranslator):
             result["processing-metadata"] = self.proc_metadata_translator.encode(obj.processing_metadata)
         if obj.content:
             result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content
-        if obj.workspace:
-            result["workspace"] = obj.workspace
         if obj.collection:
             result["collection"] = obj.collection
         if obj.criteria is not None:

@@ -19,7 +19,6 @@ class DocumentMetadataTranslator(Translator):
             title=data.get("title"),
             comments=data.get("comments"),
             metadata=self.subgraph_translator.decode(metadata) if metadata is not None else [],
-            workspace=data.get("workspace"),
             tags=data.get("tags"),
             parent_id=data.get("parent-id", ""),
             document_type=data.get("document-type", "source"),

@@ -40,8 +39,6 @@ class DocumentMetadataTranslator(Translator):
             result["comments"] = obj.comments
         if obj.metadata is not None:
             result["metadata"] = self.subgraph_translator.encode(obj.metadata)
-        if obj.workspace:
-            result["workspace"] = obj.workspace
         if obj.tags is not None:
             result["tags"] = obj.tags
         if obj.parent_id:

@@ -61,7 +58,6 @@ class ProcessingMetadataTranslator(Translator):
             document_id=data.get("document-id"),
             time=data.get("time"),
             flow=data.get("flow"),
-            workspace=data.get("workspace"),
             collection=data.get("collection"),
             tags=data.get("tags")
         )

@@ -77,8 +73,6 @@ class ProcessingMetadataTranslator(Translator):
             result["time"] = obj.time
         if obj.flow:
             result["flow"] = obj.flow
-        if obj.workspace:
-            result["workspace"] = obj.workspace
         if obj.collection:
             result["collection"] = obj.collection
         if obj.tags is not None:
@@ -8,7 +8,5 @@ class Metadata:
     # Root document identifier (set by librarian, preserved through pipeline)
     root: str = ""

-    # Collection the message belongs to. Workspace is NOT carried on the
-    # message — consumers derive it from flow.workspace (the flow the
-    # message arrived on), which is the trusted isolation boundary.
+    # Collection the message belongs to.
     collection: str = ""
@@ -17,7 +17,7 @@ from .embeddings import GraphEmbeddings
 # <- (error)

 # list-kg-cores
-# -> (workspace)
+# -> ()
 # <- ()
 # <- (error)

@@ -27,9 +27,6 @@ class KnowledgeRequest:
     # load-kg-core, unload-kg-core
     operation: str = ""

-    # Workspace the cores belong to. Partition / isolation boundary.
-    workspace: str = ""
-
     # get-kg-core, list-kg-cores, delete-kg-core, put-kg-core,
     # load-kg-core, unload-kg-core
     id: str = ""
@@ -22,17 +22,9 @@ class CollectionMetadata:

 @dataclass
 class CollectionManagementRequest:
-    """Request for collection management operations.
-
-    Collection-management is a global (non-flow-scoped) service, so the
-    workspace has to travel on the wire — it's the isolation boundary
-    for which workspace's collections the request operates on.
-    """
+    """Request for collection management operations."""
     operation: str = ""  # e.g., "delete-collection"

-    # Workspace the collection belongs to.
-    workspace: str = ""
-
     collection: str = ""
     timestamp: str = ""  # ISO timestamp
     name: str = ""
@@ -70,6 +70,11 @@ class ConfigResponse:
     # Everything
     error: Error | None = None

+@dataclass
+class WorkspaceChanges:
+    created: list[str] = field(default_factory=list)
+    deleted: list[str] = field(default_factory=list)
+
 @dataclass
 class ConfigPush:
     version: int = 0

@@ -80,6 +85,10 @@ class ConfigPush:
     # e.g. {"prompt": ["workspace-a", "workspace-b"], "schema": ["workspace-a"]}
     changes: dict[str, list[str]] = field(default_factory=dict)

+    # Workspace lifecycle events. Populated when a workspace entry
+    # is created or deleted in the __workspaces__ config namespace.
+    workspace_changes: WorkspaceChanges | None = None
+
 config_request_queue = queue('config', cls='request')
 config_response_queue = queue('config', cls='response')
 config_push_queue = queue('config', cls='notify')
@@ -22,9 +22,6 @@ class FlowRequest:
     operation: str = ""  # list-blueprints, get-blueprint, put-blueprint, delete-blueprint
                          # list-flows, get-flow, start-flow, stop-flow

-    # Workspace scope — all operations act within this workspace
-    workspace: str = ""
-
     # get_blueprint, put_blueprint, delete_blueprint, start_flow
     blueprint_name: str = ""
@@ -43,12 +43,12 @@ from ..core.metadata import Metadata
 # <- (error)

 # list-documents
-# -> (workspace, collection?)
+# -> (collection?)
 # <- (document_metadata[])
 # <- (error)

 # list-processing
-# -> (workspace, collection?)
+# -> (collection?)
 # <- (processing_metadata[])
 # <- (error)

@@ -78,7 +78,7 @@ from ..core.metadata import Metadata
 # <- (error)

 # list-uploads
-# -> (workspace)
+# -> ()
 # <- (uploads[])
 # <- (error)

@@ -90,7 +90,6 @@ class DocumentMetadata:
     title: str = ""
     comments: str = ""
     metadata: list[Triple] = field(default_factory=list)
-    workspace: str = ""
     tags: list[str] = field(default_factory=list)
     # Child document support
     parent_id: str = ""  # Empty for top-level docs, set for children

@@ -107,7 +106,6 @@ class ProcessingMetadata:
     document_id: str = ""
     time: int = 0
     flow: str = ""
-    workspace: str = ""
     collection: str = ""
     tags: list[str] = field(default_factory=list)

@@ -162,9 +160,6 @@ class LibrarianRequest:
     # add-document, upload-chunk
     content: bytes = b""

-    # Workspace scopes every library operation.
-    workspace: str = ""
-
     # list-documents?, list-processing?
     collection: str = ""
@@ -22,15 +22,15 @@ def dump_status(metrics_url, api_url, flow_id, token=None,

     print()
     print(f"Flow {flow_id}")
-    show_processors(metrics_url, flow_id)
+    show_processors(metrics_url, flow_id, token=token)

     print()
     print(f"Blueprint {blueprint_name}")
-    show_processors(metrics_url, blueprint_name)
+    show_processors(metrics_url, blueprint_name, token=token)

     print()

-def show_processors(metrics_url, flow_label):
+def show_processors(metrics_url, flow_label, token=None):

     url = f"{metrics_url}/query"

@@ -40,7 +40,11 @@ def show_processors(metrics_url, flow_label):
         "query": "consumer_state{" + expr + "}"
     }

-    resp = requests.get(url, params=params)
+    headers = {}
+    if token:
+        headers["Authorization"] = f"Bearer {token}"
+
+    resp = requests.get(url, params=params, headers=headers)

     obj = resp.json()
@@ -2,16 +2,22 @@
 Dump out TrustGraph processor states.
 """

+import os
 import requests
 import argparse

 default_metrics_url = "http://localhost:8088/api/metrics"
+DEFAULT_TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None)

-def dump_status(url):
+def dump_status(metrics_url, token=None):

-    url = f"{url}/query?query=processor_info"
+    url = f"{metrics_url}/query?query=processor_info"

-    resp = requests.get(url)
+    headers = {}
+    if token:
+        headers["Authorization"] = f"Bearer {token}"
+
+    resp = requests.get(url, headers=headers)

     obj = resp.json()

@@ -39,11 +45,17 @@ def main():
         help=f'Metrics URL (default: {default_metrics_url})',
     )

+    parser.add_argument(
+        '-t', '--token',
+        default=DEFAULT_TOKEN,
+        help=f'Bearer token for authentication (default: TRUSTGRAPH_TOKEN env var)',
+    )
+
     args = parser.parse_args()

     try:

-        dump_status(args.metrics_url)
+        dump_status(args.metrics_url, args.token)

     except Exception as e:
@@ -3,12 +3,14 @@ Dump out a stream of token rates, input, output and total. This is averaged
 across the time since tg-show-token-rate is started.
 """

+import os
 import requests
 import argparse
 import json
 import time

 default_metrics_url = "http://localhost:8088/api/metrics"
+DEFAULT_TOKEN = os.getenv("TRUSTGRAPH_TOKEN", None)

 class Collate:

@@ -36,16 +38,20 @@ class Collate:

         return delta/time, self.total/self.time

-def dump_status(metrics_url, number_samples, period):
+def dump_status(metrics_url, number_samples, period, token=None):

     input_url = f"{metrics_url}/query?query=input_tokens_total"
     output_url = f"{metrics_url}/query?query=output_tokens_total"

-    resp = requests.get(input_url)
+    headers = {}
+    if token:
+        headers["Authorization"] = f"Bearer {token}"
+
+    resp = requests.get(input_url, headers=headers)
     obj = resp.json()
     input = Collate(obj)

-    resp = requests.get(output_url)
+    resp = requests.get(output_url, headers=headers)
     obj = resp.json()
     output = Collate(obj)

@@ -56,11 +62,11 @@ def dump_status(metrics_url, number_samples, period):

         time.sleep(period)

-        resp = requests.get(input_url)
+        resp = requests.get(input_url, headers=headers)
         obj = resp.json()
         inr, inl = input.record(obj, period)

-        resp = requests.get(output_url)
+        resp = requests.get(output_url, headers=headers)
         obj = resp.json()
         outr, outl = output.record(obj, period)

@@ -69,7 +75,7 @@ def dump_status(metrics_url, number_samples, period):
 def main():

     parser = argparse.ArgumentParser(
-        prog='tg-show-processor-state',
+        prog='tg-show-token-rate',
         description=__doc__,
     )

@@ -93,6 +99,12 @@ def main():
         help=f'Metrics period (default: 100)',
     )

+    parser.add_argument(
+        '-t', '--token',
+        default=DEFAULT_TOKEN,
+        help=f'Bearer token for authentication (default: TRUSTGRAPH_TOKEN env var)',
+    )
+
     args = parser.parse_args()

     try:
@@ -61,6 +61,10 @@ class FlowContext:
     def __call__(self, service_name):
         return self._flow(service_name)

+    @property
+    def librarian(self):
+        return self._flow.librarian
+

 class UsageTracker:
     """Accumulates token usage across multiple prompt calls."""
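For illustration, agent pattern code reaches the librarian through this property, as the hunks below apply. A minimal sketch of the call shape, using only names introduced in this change:

    # Sketch only: save an agent artifact via the flow-scoped librarian.
    doc_id = f"urn:trustgraph:agent:{session_id}/answer"
    await flow.librarian.save_document(
        doc_id=doc_id,
        content=answer_text,
        title="Agent Answer",
    )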
@@ -320,9 +324,9 @@ class PatternBase:
             f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
         )
         try:
-            await self.processor.save_answer_content(
+            await flow.librarian.save_document(
                 doc_id=thought_doc_id,
-                workspace=flow.workspace,
                 content=act.thought,
                 title=f"Agent Thought: {act.name}",
             )

@@ -389,9 +393,9 @@ class PatternBase:
             f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
         )
         try:
-            await self.processor.save_answer_content(
+            await flow.librarian.save_document(
                 doc_id=observation_doc_id,
-                workspace=flow.workspace,
                 content=observation_text,
                 title=f"Agent Observation",
             )

@@ -445,9 +449,9 @@ class PatternBase:
         if answer_text:
             answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
             try:
-                await self.processor.save_answer_content(
+                await flow.librarian.save_document(
                     doc_id=answer_doc_id,
-                    workspace=flow.workspace,
                     content=answer_text,
                     title=f"Agent Answer: {request.question[:50]}...",
                 )

@@ -521,8 +525,8 @@ class PatternBase:

         doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc"
         try:
-            await self.processor.save_answer_content(
-                doc_id=doc_id, workspace=flow.workspace,
+            await flow.librarian.save_document(
+                doc_id=doc_id,
                 content=answer_text,
                 title=f"Finding: {goal[:60]}",
             )

@@ -574,8 +578,8 @@ class PatternBase:

         doc_id = f"urn:trustgraph:agent:{session_id}/step/{index}/doc"
         try:
-            await self.processor.save_answer_content(
-                doc_id=doc_id, workspace=flow.workspace,
+            await flow.librarian.save_document(
+                doc_id=doc_id,
                 content=answer_text,
                 title=f"Step result: {goal[:60]}",
             )

@@ -606,8 +610,8 @@ class PatternBase:

         doc_id = f"urn:trustgraph:agent:{session_id}/synthesis/doc"
         try:
-            await self.processor.save_answer_content(
-                doc_id=doc_id, workspace=flow.workspace,
+            await flow.librarian.save_document(
+                doc_id=doc_id,
                 content=answer_text,
                 title="Synthesis",
             )
@@ -7,26 +7,17 @@ to select between ReactPattern, PlanThenExecutePattern, and
 SupervisorPattern at runtime.
 """

-import asyncio
-import base64
 import json
 import functools
 import logging
-import uuid
-from datetime import datetime

 from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
 from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
 from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
-from ... base import ProducerSpec
-from ... base import Consumer, Producer
-from ... base import ConsumerMetrics, ProducerMetrics
+from ... base import ProducerSpec, LibrarianSpec

 from ... schema import AgentRequest, AgentResponse, AgentStep, Error
 from ..orchestrator.pattern_base import UsageTracker, PatternBase
 from ... schema import Triples, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
-from ... schema import librarian_request_queue, librarian_response_queue

 from trustgraph.provenance import (
     agent_session_uri,

@@ -52,8 +43,6 @@ logger = logging.getLogger(__name__)

 default_ident = "agent-manager"
 default_max_iterations = 10
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue


 class Processor(AgentService):

@@ -151,94 +140,9 @@ class Processor(AgentService):
             )
         )

-        # Librarian client
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
+        self.register_specification(
+            LibrarianSpec()
         )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend=self.pubsub,
-            topic=librarian_request_q,
-            schema=LibrarianRequest,
-            metrics=librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup=self.taskgroup,
-            backend=self.pubsub,
-            flow=None,
-            topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
-            schema=LibrarianResponse,
-            handler=self.on_librarian_response,
-            metrics=librarian_response_metrics,
-        )
-
-        self.pending_librarian_requests = {}
-
-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
-
-    async def on_librarian_response(self, msg, consumer, flow):
-        response = msg.value()
-        request_id = msg.properties().get("id")
-
-        if request_id in self.pending_librarian_requests:
-            future = self.pending_librarian_requests.pop(request_id)
-            future.set_result(response)
-
-    async def save_answer_content(self, doc_id, workspace, content, title=None,
-                                  timeout=120):
-        request_id = str(uuid.uuid4())
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            workspace=workspace,
-            kind="text/plain",
-            title=title or "Agent Answer",
-            document_type="answer",
-        )
-
-        request = LibrarianRequest(
-            operation="add-document",
-            document_id=doc_id,
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
-            workspace=workspace,
-        )
-
-        future = asyncio.get_event_loop().create_future()
-        self.pending_librarian_requests[request_id] = future
-
-        try:
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving answer: "
-                    f"{response.error.type}: {response.error.message}"
-                )
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_librarian_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving answer document {doc_id}")

     def provenance_session_uri(self, session_id):
         return agent_session_uri(session_id)
@@ -3,7 +3,6 @@ Simple agent infrastructure broadly implements the ReAct flow.
 """
 
 import asyncio
-import base64
 import json
 import re
 import sys

@@ -19,14 +18,10 @@ logger = logging.getLogger(__name__)
 from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
 from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
 from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
-from ... base import ProducerSpec
-from ... base import Consumer, Producer
-from ... base import ConsumerMetrics, ProducerMetrics
+from ... base import ProducerSpec, LibrarianSpec
 
 from ... schema import AgentRequest, AgentResponse, AgentStep, Error
 from ... schema import Triples, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
-from ... schema import librarian_request_queue, librarian_response_queue
 
 # Provenance imports for agent explainability
 from trustgraph.provenance import (

@@ -51,8 +46,6 @@ from . types import Final, Action, Tool, Argument
 
 default_ident = "agent-manager"
 default_max_iterations = 10
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue
 
 class Processor(AgentService):
 

@@ -141,112 +134,9 @@ class Processor(AgentService):
             )
         )
 
-        # Librarian client for storing answer content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend=self.pubsub,
-            topic=librarian_request_q,
-            schema=LibrarianRequest,
-            metrics=librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup=self.taskgroup,
-            backend=self.pubsub,
-            flow=None,
-            topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
-            schema=LibrarianResponse,
-            handler=self.on_librarian_response,
-            metrics=librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_librarian_requests = {}
-
-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
-
-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-
-        if request_id in self.pending_librarian_requests:
-            future = self.pending_librarian_requests.pop(request_id)
-            future.set_result(response)
-
-    async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
-        """
-        Save answer content to the librarian.
-
-        Args:
-            doc_id: ID for the answer document
-            workspace: Workspace for isolation
-            content: Answer text content
-            title: Optional title
-            timeout: Request timeout in seconds
-
-        Returns:
-            The document ID on success
-        """
-        request_id = str(uuid.uuid4())
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            workspace=workspace,
-            kind="text/plain",
-            title=title or "Agent Answer",
-            document_type="answer",
-        )
-
-        request = LibrarianRequest(
-            operation="add-document",
-            document_id=doc_id,
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
-            workspace=workspace,
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_librarian_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving answer: {response.error.type}: {response.error.message}"
-                )
-
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_librarian_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving answer document {doc_id}")
+        self.register_specification(
+            LibrarianSpec()
+        )
 
     async def on_tools_config(self, workspace, config, version):
 

@@ -611,9 +501,9 @@ class Processor(AgentService):
         if act_decision.thought:
             t_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
             try:
-                await self.save_answer_content(
+                await flow.librarian.save_document(
                     doc_id=t_doc_id,
-                    workspace=flow.workspace,
                     content=act_decision.thought,
                     title=f"Agent Thought: {act_decision.name}",
                 )

@@ -691,9 +581,9 @@ class Processor(AgentService):
         if f:
             answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
             try:
-                await self.save_answer_content(
+                await flow.librarian.save_document(
                     doc_id=answer_doc_id,
-                    workspace=flow.workspace,
                     content=f,
                     title=f"Agent Answer: {request.question[:50]}...",
                 )

@@ -768,9 +658,8 @@ class Processor(AgentService):
         if act.observation:
             observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
             try:
-                await self.save_answer_content(
+                await flow.librarian.save_document(
                     doc_id=observation_doc_id,
-                    workspace=flow.workspace,
                     content=act.observation,
                     title=f"Agent Observation",
                 )
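The machinery removed above is the classic request/response-over-pub/sub idiom: send with a correlation id, park an asyncio.Future keyed by that id, and resolve the future when the matching response comes back. LibrarianSpec now encapsulates this, but the pattern is worth recording. A minimal sketch, with illustrative names rather than TrustGraph's real classes:

import asyncio
import uuid

class RequestResponseChannel:
    """Correlate async responses to requests via a message property."""

    def __init__(self, send):
        self.send = send          # async callable(payload, properties)
        self.pending = {}         # request id -> asyncio.Future

    async def request(self, payload, timeout=120):
        request_id = str(uuid.uuid4())
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        try:
            await self.send(payload, properties={"id": request_id})
            return await asyncio.wait_for(future, timeout=timeout)
        finally:
            # Drop the future whether we got a reply or timed out
            self.pending.pop(request_id, None)

    def on_response(self, response, properties):
        future = self.pending.get(properties.get("id"))
        if future is not None and not future.done():
            future.set_result(response)
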
@@ -21,7 +21,8 @@ class InitContext:
 
     logger: logging.Logger
     config: Any  # ConfigClient
-    flow: Any  # RequestResponse client for flow-svc
+    make_flow_client: Any  # callable(workspace) -> RequestResponse
+    make_iam_client: Any  # callable() -> RequestResponse
 
 
 class Initialiser:

@@ -35,7 +36,7 @@ class Initialiser:
 
     * ``wait_for_services`` (bool, default ``True``): when ``True`` the
       initialiser only runs after the bootstrapper's service gate has
-      passed (config-svc and flow-svc reachable). Set ``False`` for
+      passed (config-svc reachable). Set ``False`` for
       initialisers that bring up infrastructure the gate itself
       depends on — principally Pulsar topology, without which
      config-svc cannot come online.
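Replacing the shared flow client with make_flow_client / make_iam_client factories shifts client lifetime into the initialisers: each one opens a client scoped to its own workspace and tears it down when done. Roughly, under assumed (not actual) client interfaces:

import logging
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class Context:
    # Hypothetical stand-in for InitContext, for illustration only
    logger: logging.Logger
    config: Any
    make_flow_client: Callable[[str], Any]   # workspace -> client
    make_iam_client: Callable[[], Any]

async def with_flow_client(ctx: Context, workspace: str):
    # Open a short-lived, workspace-scoped client, always closing it
    flow = ctx.make_flow_client(workspace)
    await flow.start()
    try:
        return await flow.request({"operation": "list-flows"}, timeout=10)
    finally:
        await flow.stop()
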
@@ -28,6 +28,10 @@ from trustgraph.schema import (
     FlowRequest, FlowResponse,
     flow_request_queue, flow_response_queue,
 )
+from trustgraph.schema import (
+    IamRequest, IamResponse,
+    iam_request_queue, iam_response_queue,
+)
 
 from .. base import Initialiser, InitContext
 

@@ -178,34 +182,46 @@ class Processor(AsyncProcessor):
             ),
         )
 
-    def _make_flow_client(self):
+    def _make_flow_client(self, workspace):
         rr_id = str(uuid.uuid4())
         return RequestResponse(
             backend=self.pubsub_backend,
             subscription=f"{self.id}--flow--{rr_id}",
             consumer_name=self.id,
-            request_topic=flow_request_queue,
+            request_topic=f"{flow_request_queue}:{workspace}",
             request_schema=FlowRequest,
             request_metrics=ProducerMetrics(
                 processor=self.id, flow=None, name="flow-request",
             ),
-            response_topic=flow_response_queue,
+            response_topic=f"{flow_response_queue}:{workspace}",
             response_schema=FlowResponse,
             response_metrics=SubscriberMetrics(
                 processor=self.id, flow=None, name="flow-response",
             ),
         )
 
+    def _make_iam_client(self):
+        rr_id = str(uuid.uuid4())
+        return RequestResponse(
+            backend=self.pubsub_backend,
+            subscription=f"{self.id}--iam--{rr_id}",
+            consumer_name=self.id,
+            request_topic=iam_request_queue,
+            request_schema=IamRequest,
+            request_metrics=ProducerMetrics(
+                processor=self.id, flow=None, name="iam-request",
+            ),
+            response_topic=iam_response_queue,
+            response_schema=IamResponse,
+            response_metrics=SubscriberMetrics(
+                processor=self.id, flow=None, name="iam-response",
+            ),
+        )
+
     async def _open_clients(self):
         config = self._make_config_client()
-        flow = self._make_flow_client()
         await config.start()
-        try:
-            await flow.start()
-        except Exception:
-            await self._safe_stop(config)
-            raise
-        return config, flow
+        return config
 
     async def _safe_stop(self, client):
         try:

@@ -217,7 +233,7 @@ class Processor(AsyncProcessor):
     # Service gate.
     # ------------------------------------------------------------------
 
-    async def _gate_ready(self, config, flow):
+    async def _gate_ready(self, config):
         try:
             await config.keys(SYSTEM_WORKSPACE, INIT_STATE_TYPE)
         except Exception as e:

@@ -226,26 +242,6 @@ class Processor(AsyncProcessor):
             )
             return False
 
-        try:
-            resp = await flow.request(
-                FlowRequest(
-                    operation="list-blueprints",
-                    workspace=SYSTEM_WORKSPACE,
-                ),
-                timeout=5,
-            )
-            if resp.error:
-                logger.info(
-                    f"Gate: flow-svc error: "
-                    f"{resp.error.type}: {resp.error.message}"
-                )
-                return False
-        except Exception as e:
-            logger.info(
-                f"Gate: flow-svc not ready ({type(e).__name__}: {e})"
-            )
-            return False
-
         return True
 
     # ------------------------------------------------------------------

@@ -271,7 +267,7 @@ class Processor(AsyncProcessor):
     # Per-spec execution.
     # ------------------------------------------------------------------
 
-    async def _run_spec(self, spec, config, flow):
+    async def _run_spec(self, spec, config):
         """Run a single initialiser spec.
 
         Returns one of:

@@ -298,7 +294,8 @@ class Processor(AsyncProcessor):
         child_ctx = InitContext(
             logger=child_logger,
             config=config,
-            flow=flow,
+            make_flow_client=self._make_flow_client,
+            make_iam_client=self._make_iam_client,
         )
 
         child_logger.info(

@@ -340,7 +337,7 @@ class Processor(AsyncProcessor):
             sleep_for = STEADY_INTERVAL
 
             try:
-                config, flow = await self._open_clients()
+                config = await self._open_clients()
             except Exception as e:
                 logger.info(
                     f"Failed to open clients "

@@ -358,11 +355,11 @@ class Processor(AsyncProcessor):
             pre_results = {}
             for spec in pre_specs:
                 pre_results[spec.name] = await self._run_spec(
-                    spec, config, flow,
+                    spec, config,
                 )
 
             # Phase 2: gate.
-            gate_ok = await self._gate_ready(config, flow)
+            gate_ok = await self._gate_ready(config)
 
             # Phase 3: post-service initialisers, if gate passed.
             post_results = {}

@@ -373,7 +370,7 @@ class Processor(AsyncProcessor):
             ]
             for spec in post_specs:
                 post_results[spec.name] = await self._run_spec(
-                    spec, config, flow,
+                    spec, config,
                 )
 
             # Cadence selection.

@@ -388,7 +385,6 @@ class Processor(AsyncProcessor):
 
             finally:
                 await self._safe_stop(config)
-                await self._safe_stop(flow)
 
             await asyncio.sleep(sleep_for)
 
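Note the topic naming in _make_flow_client: request and response topics gain a :{workspace} suffix, so workspace identity travels in the queue name rather than the message body. The same helper appears verbatim in the config and knowledge services later in this diff:

def workspace_queue(base_queue: str, workspace: str) -> str:
    return f"{base_queue}:{workspace}"

# For example (queue names illustrative):
assert workspace_queue("flow-request", "acme") == "flow-request:acme"
assert workspace_queue("flow-response", "acme") == "flow-response:acme"
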
@@ -49,53 +49,67 @@ class DefaultFlowStart(Initialiser):
 
     async def run(self, ctx, old_flag, new_flag):
 
-        # Check whether the flow already exists. Belt-and-braces
-        # beyond the flag gate: if an operator stops and restarts the
-        # bootstrapper after the flow is already running, we don't
-        # want to blindly try to start it again.
-        list_resp = await ctx.flow.request(
-            FlowRequest(
-                operation="list-flows",
-                workspace=self.workspace,
-            ),
-            timeout=10,
-        )
-        if list_resp.error:
-            raise RuntimeError(
-                f"list-flows failed: "
-                f"{list_resp.error.type}: {list_resp.error.message}"
-            )
-
-        if self.flow_id in (list_resp.flow_ids or []):
-            ctx.logger.info(
-                f"Flow {self.flow_id!r} already running in workspace "
-                f"{self.workspace!r}; nothing to do"
-            )
-            return
-
-        ctx.logger.info(
-            f"Starting flow {self.flow_id!r} "
-            f"(blueprint={self.blueprint!r}) "
-            f"in workspace {self.workspace!r}"
-        )
-
-        resp = await ctx.flow.request(
-            FlowRequest(
-                operation="start-flow",
-                workspace=self.workspace,
-                flow_id=self.flow_id,
-                blueprint_name=self.blueprint,
-                description=self.description,
-                parameters=self.parameters,
-            ),
-            timeout=30,
-        )
-        if resp.error:
-            raise RuntimeError(
-                f"start-flow failed: "
-                f"{resp.error.type}: {resp.error.message}"
-            )
-
-        ctx.logger.info(
-            f"Flow {self.flow_id!r} started"
-        )
+        workspaces = await ctx.config.keys(
+            "__workspaces__", "workspace",
+        )
+        if self.workspace not in workspaces:
+            raise RuntimeError(
+                f"Workspace {self.workspace!r} does not exist yet"
+            )
+
+        flow = ctx.make_flow_client(self.workspace)
+        await flow.start()
+
+        try:
+
+            # Check whether the flow already exists. Belt-and-braces
+            # beyond the flag gate: if an operator stops and restarts the
+            # bootstrapper after the flow is already running, we don't
+            # want to blindly try to start it again.
+            list_resp = await flow.request(
+                FlowRequest(
+                    operation="list-flows",
+                ),
+                timeout=10,
+            )
+            if list_resp.error:
+                raise RuntimeError(
+                    f"list-flows failed: "
+                    f"{list_resp.error.type}: {list_resp.error.message}"
+                )
+
+            if self.flow_id in (list_resp.flow_ids or []):
+                ctx.logger.info(
+                    f"Flow {self.flow_id!r} already running in workspace "
+                    f"{self.workspace!r}; nothing to do"
+                )
+                return
+
+            ctx.logger.info(
+                f"Starting flow {self.flow_id!r} "
+                f"(blueprint={self.blueprint!r}) "
+                f"in workspace {self.workspace!r}"
+            )
+
+            resp = await flow.request(
+                FlowRequest(
+                    operation="start-flow",
+                    flow_id=self.flow_id,
+                    blueprint_name=self.blueprint,
+                    description=self.description,
+                    parameters=self.parameters,
+                ),
+                timeout=30,
+            )
+            if resp.error:
+                raise RuntimeError(
+                    f"start-flow failed: "
+                    f"{resp.error.type}: {resp.error.message}"
+                )
+
+            ctx.logger.info(
+                f"Flow {self.flow_id!r} started"
+            )
+
+        finally:
+            await flow.stop()
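The rewritten run is deliberately idempotent: confirm the prerequisite (the workspace exists), read the current state (list-flows), and only then mutate it (start-flow), with the short-lived client closed in finally. The check-then-act core, abstracted over a generic request callable (hypothetical shapes, not the real API):

async def ensure_started(request, flow_id: str) -> bool:
    """Start flow_id only if it is not already running (idempotent)."""
    listing = await request({"operation": "list-flows"}, timeout=10)
    if flow_id in (listing.get("flow_ids") or []):
        return False    # already running; nothing to do
    await request({"operation": "start-flow", "flow_id": flow_id},
                  timeout=30)
    return True
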
@@ -26,6 +26,8 @@ the next cycle once the prerequisite is satisfied.
 
 import json
 
+from trustgraph.schema import IamRequest, WorkspaceInput
+
 from .. base import Initialiser
 
 TEMPLATE_WORKSPACE = "__template__"

@@ -59,6 +61,8 @@ class WorkspaceInit(Initialiser):
         self.overwrite = overwrite
 
     async def run(self, ctx, old_flag, new_flag):
+        await self._create_workspace(ctx)
+
         if self.source == "seed-file":
             tree = self._load_seed_file()
         else:

@@ -105,6 +109,39 @@ class WorkspaceInit(Initialiser):
             )
         return tree
 
+    async def _create_workspace(self, ctx):
+        """Register the workspace via the IAM create-workspace API."""
+        iam = ctx.make_iam_client()
+        await iam.start()
+        try:
+            resp = await iam.request(
+                IamRequest(
+                    operation="create-workspace",
+                    workspace_record=WorkspaceInput(
+                        id=self.workspace,
+                        name=self.workspace.title(),
+                        enabled=True,
+                    ),
+                ),
+                timeout=10,
+            )
+            if resp.error:
+                if resp.error.type == "duplicate":
+                    ctx.logger.info(
+                        f"Workspace {self.workspace!r} already exists in IAM"
+                    )
+                else:
+                    raise RuntimeError(
+                        f"IAM create-workspace failed: "
+                        f"{resp.error.type}: {resp.error.message}"
+                    )
+            else:
+                ctx.logger.info(
+                    f"Workspace {self.workspace!r} created via IAM"
+                )
+        finally:
+            await iam.stop()
+
     async def _write_all(self, ctx, tree):
         values = []
         for type_name, entries in tree.items():

@@ -112,6 +149,7 @@ class WorkspaceInit(Initialiser):
                 values.append((type_name, key, json.dumps(value)))
         if values:
             await ctx.config.put_many(self.workspace, values)
+
         ctx.logger.info(
             f"Workspace {self.workspace!r} populated with "
             f"{len(values)} entries"

@@ -132,6 +170,7 @@ class WorkspaceInit(Initialiser):
             if values:
                 await ctx.config.put_many(self.workspace, values)
                 written += len(values)
+
         ctx.logger.info(
             f"Workspace {self.workspace!r} upsert-missing: "
             f"{written} new entries"
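_create_workspace treats a duplicate error from IAM as success, which is what lets the bootstrapper re-run safely. The tolerance pattern in isolation, with an assumed dict-shaped response rather than TrustGraph's schema objects:

async def create_if_absent(iam_request, workspace: str) -> bool:
    """Create a workspace; treat 'already exists' as success."""
    resp = await iam_request(
        {"operation": "create-workspace", "id": workspace}, timeout=10
    )
    error = resp.get("error")
    if error is None:
        return True                 # created
    if error["type"] == "duplicate":
        return False                # already there; fine
    raise RuntimeError(f"create-workspace failed: {error['message']}")
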
@@ -95,7 +95,7 @@ class Processor(ChunkingService):
         logger.info(f"Chunking document {v.metadata.id}...")
 
         # Get text content (fetches from librarian if needed)
-        text = await self.get_document_text(v, flow.workspace)
+        text = await self.get_document_text(v, flow)
 
         # Extract chunk parameters from flow (allows runtime override)
         chunk_size, chunk_overlap = await self.chunk_document(

@@ -141,10 +141,9 @@ class Processor(ChunkingService):
             chunk_length = len(chunk.page_content)
 
             # Save chunk to librarian as child document
-            await self.librarian.save_child_document(
+            await flow.librarian.save_child_document(
                 doc_id=chunk_doc_id,
                 parent_id=parent_doc_id,
-                workspace=flow.workspace,
                 content=chunk_content,
                 document_type="chunk",
                 title=f"Chunk {chunk_index}",

@@ -92,7 +92,7 @@ class Processor(ChunkingService):
         logger.info(f"Chunking document {v.metadata.id}...")
 
         # Get text content (fetches from librarian if needed)
-        text = await self.get_document_text(v, flow.workspace)
+        text = await self.get_document_text(v, flow)
 
         # Extract chunk parameters from flow (allows runtime override)
         chunk_size, chunk_overlap = await self.chunk_document(

@@ -137,10 +137,9 @@ class Processor(ChunkingService):
             chunk_length = len(chunk.page_content)
 
             # Save chunk to librarian as child document
-            await self.librarian.save_child_document(
+            await flow.librarian.save_child_document(
                 doc_id=chunk_doc_id,
                 parent_id=parent_doc_id,
-                workspace=flow.workspace,
                 content=chunk_content,
                 document_type="chunk",
                 title=f"Chunk {chunk_index}",
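Both chunker variants make the same two-part change: the librarian client is now taken from the flow object, and the explicit workspace argument disappears because a flow is already workspace-scoped. Side by side:

# Before: processor-wide client, workspace passed explicitly
await self.librarian.save_child_document(
    doc_id=chunk_doc_id, parent_id=parent_doc_id,
    workspace=flow.workspace, content=chunk_content,
)

# After: flow-scoped client, workspace implied by the flow
await flow.librarian.save_child_document(
    doc_id=chunk_doc_id, parent_id=parent_doc_id,
    content=chunk_content,
)
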
@@ -2,13 +2,17 @@
 import logging
 
 from trustgraph.schema import ConfigResponse
-from trustgraph.schema import ConfigValue, Error
+from trustgraph.schema import ConfigValue, WorkspaceChanges, Error
 
 from ... tables.config import ConfigTableStore
 
 # Module logger
 logger = logging.getLogger(__name__)
 
+WORKSPACES_NAMESPACE = "__workspaces__"
+WORKSPACE_TYPE = "workspace"
+TEMPLATE_WORKSPACE = "__template__"
+
 class Configuration:
 
     def __init__(self, push, host, username, password, keyspace,

@@ -27,9 +31,7 @@ class Configuration:
     async def get_version(self):
         return await self.table_store.get_version()
 
-    async def handle_get(self, v):
+    async def handle_get(self, v, workspace):
 
-        workspace = v.workspace
-
         values = [
             ConfigValue(

@@ -47,18 +49,18 @@ class Configuration:
             values = values,
         )
 
-    async def handle_list(self, v):
+    async def handle_list(self, v, workspace):
 
         return ConfigResponse(
             version = await self.get_version(),
             directory = await self.table_store.get_keys(
-                v.workspace, v.type
+                workspace, v.type
             ),
         )
 
-    async def handle_getvalues(self, v):
+    async def handle_getvalues(self, v, workspace):
 
-        vals = await self.table_store.get_values(v.workspace, v.type)
+        vals = await self.table_store.get_values(workspace, v.type)
 
         values = map(
             lambda x: ConfigValue(

@@ -94,9 +96,8 @@ class Configuration:
             values = values,
         )
 
-    async def handle_delete(self, v):
+    async def handle_delete(self, v, workspace):
 
-        workspace = v.workspace
         types = list(set(k.type for k in v.keys))
 
         for k in v.keys:

@@ -104,14 +105,22 @@ class Configuration:
 
         await self.inc_version()
 
-        await self.push(changes={t: [workspace] for t in types})
+        workspace_changes = None
+        if workspace == WORKSPACES_NAMESPACE and WORKSPACE_TYPE in types:
+            deleted = [k.key for k in v.keys if k.type == WORKSPACE_TYPE]
+            if deleted:
+                workspace_changes = WorkspaceChanges(deleted=deleted)
+
+        await self.push(
+            changes={t: [workspace] for t in types},
+            workspace_changes=workspace_changes,
+        )
 
         return ConfigResponse(
         )
 
-    async def handle_put(self, v):
+    async def handle_put(self, v, workspace):
 
-        workspace = v.workspace
         types = list(set(k.type for k in v.values))
 
         for k in v.values:

@@ -121,11 +130,49 @@ class Configuration:
 
         await self.inc_version()
 
-        await self.push(changes={t: [workspace] for t in types})
+        workspace_changes = None
+        if workspace == WORKSPACES_NAMESPACE and WORKSPACE_TYPE in types:
+            created = [k.key for k in v.values if k.type == WORKSPACE_TYPE]
+            if created:
+                workspace_changes = WorkspaceChanges(created=created)
+
+        await self.push(
+            changes={t: [workspace] for t in types},
+            workspace_changes=workspace_changes,
+        )
 
         return ConfigResponse(
         )
 
+    async def provision_from_template(self, workspace):
+        """Copy all config from __template__ into a new workspace,
+        skipping keys that already exist (upsert-missing)."""
+
+        template = await self.get_config(TEMPLATE_WORKSPACE)
+
+        if not template:
+            logger.info(
+                f"No template config to provision for {workspace}"
+            )
+            return 0
+
+        existing_types = await self.get_config(workspace)
+
+        written = 0
+        for type_name, entries in template.items():
+            existing_keys = set(existing_types.get(type_name, {}).keys())
+            for key, value in entries.items():
+                if key not in existing_keys:
+                    await self.table_store.put_config(
+                        workspace, type_name, key, value
+                    )
+                    written += 1
+
+        if written > 0:
+            await self.inc_version()
+
+        return written
+
     async def get_config(self, workspace):
 
         table = await self.table_store.get_all_for_workspace(workspace)

@@ -139,62 +186,87 @@ class Configuration:
 
         return config
 
-    async def handle_config(self, v):
+    async def handle_config(self, v, workspace):
 
-        config = await self.get_config(v.workspace)
+        config = await self.get_config(workspace)
 
         return ConfigResponse(
             version = await self.get_version(),
             config = config,
         )
 
-    async def handle(self, msg):
+    async def handle_workspace(self, msg, workspace):
+        """Handle workspace-scoped config operations.
+        Workspace is provided by queue infrastructure."""
 
         logger.debug(
-            f"Handling config message: {msg.operation} "
-            f"workspace={msg.workspace}"
+            f"Handling workspace config message: {msg.operation} "
+            f"workspace={workspace}"
         )
 
-        # getvalues-all-ws spans all workspaces, so no workspace
-        # required; everything else is workspace-scoped.
-        if msg.operation != "getvalues-all-ws" and not msg.workspace:
-            return ConfigResponse(
-                error=Error(
-                    type = "bad-request",
-                    message = "Workspace is required"
-                )
-            )
-
         if msg.operation == "get":
-            resp = await self.handle_get(msg)
+            resp = await self.handle_get(msg, workspace)
 
         elif msg.operation == "list":
-            resp = await self.handle_list(msg)
+            resp = await self.handle_list(msg, workspace)
 
         elif msg.operation == "getvalues":
-            resp = await self.handle_getvalues(msg)
-
-        elif msg.operation == "getvalues-all-ws":
-
-            resp = await self.handle_getvalues_all_ws(msg)
+            resp = await self.handle_getvalues(msg, workspace)
 
         elif msg.operation == "delete":
-            resp = await self.handle_delete(msg)
+            resp = await self.handle_delete(msg, workspace)
 
         elif msg.operation == "put":
-            resp = await self.handle_put(msg)
+            resp = await self.handle_put(msg, workspace)
 
         elif msg.operation == "config":
-            resp = await self.handle_config(msg)
+            resp = await self.handle_config(msg, workspace)
+
+        else:
+            resp = ConfigResponse(
+                error=Error(
+                    type = "bad-operation",
+                    message = "Bad operation"
+                )
+            )
+
+        return resp
+
+    async def handle_system(self, msg):
+        """Handle system-level config operations.
+        Workspace, when needed, comes from message body."""
+
+        logger.debug(
+            f"Handling system config message: {msg.operation} "
+            f"workspace={msg.workspace}"
+        )
+
+        if msg.operation == "getvalues-all-ws":
+            resp = await self.handle_getvalues_all_ws(msg)
+
+        elif msg.operation in ("get", "list", "getvalues", "delete",
+                               "put", "config"):
+
+            if not msg.workspace:
+                return ConfigResponse(
+                    error=Error(
+                        type = "bad-request",
+                        message = "Workspace is required"
+                    )
+                )
+
+            handler = {
+                "get": self.handle_get,
+                "list": self.handle_list,
+                "getvalues": self.handle_getvalues,
+                "delete": self.handle_delete,
+                "put": self.handle_put,
+                "config": self.handle_config,
+            }[msg.operation]
+
+            resp = await handler(msg, msg.workspace)
 
         else:
 
             resp = ConfigResponse(
                 error=Error(
                     type = "bad-operation",
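provision_from_template is upsert-missing: a template entry is copied only when the target workspace has nothing under the same type and key, so re-provisioning never clobbers operator edits. The merge rule as a pure function (standalone sketch, not the service code):

def upsert_missing(template: dict, existing: dict) -> list:
    """Return (type, key, value) triples present in template but
    absent from existing; existing values are never overwritten."""
    missing = []
    for type_name, entries in template.items():
        existing_keys = set(existing.get(type_name, {}))
        for key, value in entries.items():
            if key not in existing_keys:
                missing.append((type_name, key, value))
    return missing

# e.g.
template = {"prompt": {"greet": "hello", "bye": "later"}}
existing = {"prompt": {"greet": "custom"}}
assert upsert_missing(template, existing) == [("prompt", "bye", "later")]
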
@@ -1,20 +1,30 @@
 """
-Config service. Manages system global configuration state
+Config service. Manages system global configuration state.
+
+Operates a dual-queue regime:
+- System queue (config-request): handles cross-workspace operations like
+  getvalues-all-ws and bootstrapper put/delete on __workspaces__.
+  The gateway NEVER routes to this queue.
+- Per-workspace queues (config-request:<workspace>): handles
+  workspace-scoped operations where workspace identity comes from
+  queue infrastructure, not message body.
 """
 
 import logging
+from functools import partial
 
 from trustgraph.schema import Error
 
 from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigPush
+from trustgraph.schema import WorkspaceChanges
 from trustgraph.schema import config_request_queue, config_response_queue
 from trustgraph.schema import config_push_queue
 
 from trustgraph.base import AsyncProcessor, Consumer, Producer
 from trustgraph.base.cassandra_config import add_cassandra_args, resolve_cassandra_config
 
-from . config import Configuration
+from . config import Configuration, WORKSPACES_NAMESPACE, WORKSPACE_TYPE
 
 from ... base import ProcessorMetrics, ConsumerMetrics, ProducerMetrics
 from ... base import Consumer, Producer
 

@@ -39,6 +49,11 @@ def is_reserved_workspace(workspace):
     """
     return workspace.startswith("_")
 
+
+def workspace_queue(base_queue, workspace):
+    return f"{base_queue}:{workspace}"
+
+
 default_config_request_queue = config_request_queue
 default_config_response_queue = config_response_queue
 default_config_push_queue = config_push_queue

@@ -52,7 +67,7 @@ class Processor(AsyncProcessor):
         config_request_queue = params.get(
             "config_request_queue", default_config_request_queue
         )
-        config_response_queue = params.get(
+        self.config_response_queue_base = params.get(
             "config_response_queue", default_config_response_queue
         )
         config_push_queue = params.get(

@@ -64,7 +79,7 @@ class Processor(AsyncProcessor):
         cassandra_password = params.get("cassandra_password")
 
         # Resolve configuration with environment variable fallback
-        hosts, username, password, keyspace = resolve_cassandra_config(
+        hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
             host=cassandra_host,
             username=cassandra_username,
             password=cassandra_password,

@@ -99,23 +114,23 @@ class Processor(AsyncProcessor):
             processor = self.id, flow = None, name = "config-push"
         )
 
-        self.config_request_topic = config_request_queue
+        self.config_request_queue_base = config_request_queue
         self.config_request_subscriber = id
 
-        self.config_request_consumer = Consumer(
+        self.system_consumer = Consumer(
             taskgroup = self.taskgroup,
             backend = self.pubsub,
             flow = None,
             topic = config_request_queue,
             subscriber = id,
             schema = ConfigRequest,
-            handler = self.on_config_request,
+            handler = self.on_system_config_request,
             metrics = config_request_metrics,
         )
 
         self.config_response_producer = Producer(
             backend = self.pubsub,
-            topic = config_response_queue,
+            topic = self.config_response_queue_base,
             schema = ConfigResponse,
             metrics = config_response_metrics,
         )

@@ -132,23 +147,145 @@ class Processor(AsyncProcessor):
             username = self.cassandra_username,
             password = self.cassandra_password,
             keyspace = keyspace,
+            replication_factor = replication_factor,
             push = self.push
         )
 
+        self.workspace_consumers = {}
+
+        self.register_workspace_handler(self._handle_workspace_changes)
+
         logger.info("Config service initialized")
 
+    async def _discover_workspaces(self):
+        logger.info("Discovering workspaces from Cassandra...")
+        try:
+            workspaces = await self.config.table_store.get_keys(
+                WORKSPACES_NAMESPACE, WORKSPACE_TYPE
+            )
+            logger.info(f"Discovered workspaces: {workspaces}")
+        except Exception as e:
+            logger.error(
+                f"Workspace discovery failed: {e}", exc_info=True
+            )
+            return
+
+        for workspace_id in workspaces:
+            if workspace_id not in self.workspace_consumers:
+                await self._add_workspace_consumer(workspace_id)
+
+    async def _handle_workspace_changes(self, workspace_changes):
+        for workspace_id in workspace_changes.created:
+            if workspace_id not in self.workspace_consumers:
+                logger.info(f"Workspace created: {workspace_id}")
+                await self._add_workspace_consumer(workspace_id)
+                await self._provision_workspace(workspace_id)
+
+        for workspace_id in workspace_changes.deleted:
+            if workspace_id in self.workspace_consumers:
+                logger.info(f"Workspace deleted: {workspace_id}")
+                await self._remove_workspace_consumer(workspace_id)
+
+    async def _provision_workspace(self, workspace_id):
+        try:
+            written = await self.config.provision_from_template(
+                workspace_id
+            )
+            if written > 0:
+                logger.info(
+                    f"Provisioned workspace {workspace_id} with "
+                    f"{written} entries from template"
+                )
+                # Notify other services about the new config
+                types = {}
+                template = await self.config.get_config(workspace_id)
+                for t in template:
+                    types[t] = [workspace_id]
+                await self.push(changes=types)
+        except Exception as e:
+            logger.error(
+                f"Failed to provision workspace {workspace_id}: {e}",
+                exc_info=True,
+            )
+
+    async def _add_workspace_consumer(self, workspace_id):
+        req_queue = workspace_queue(
+            self.config_request_queue_base, workspace_id,
+        )
+        resp_queue = workspace_queue(
+            self.config_response_queue_base, workspace_id,
+        )
+
+        await self.pubsub.ensure_topic(req_queue)
+        await self.pubsub.ensure_topic(resp_queue)
+
+        response_producer = Producer(
+            backend=self.pubsub,
+            topic=resp_queue,
+            schema=ConfigResponse,
+            metrics=ProducerMetrics(
+                processor=self.id, flow=None,
+                name=f"config-response-{workspace_id}",
+            ),
+        )
+
+        consumer = Consumer(
+            taskgroup=self.taskgroup,
+            backend=self.pubsub,
+            flow=None,
+            topic=req_queue,
+            subscriber=self.id,
+            schema=ConfigRequest,
+            handler=partial(
+                self.on_workspace_config_request,
+                workspace=workspace_id,
+            ),
+            metrics=ConsumerMetrics(
+                processor=self.id, flow=None,
+                name=f"config-request-{workspace_id}",
+            ),
+        )
+
+        await response_producer.start()
+        await consumer.start()
+
+        self.workspace_consumers[workspace_id] = {
+            "consumer": consumer,
+            "response": response_producer,
+        }
+
+        logger.info(
+            f"Subscribed to workspace config queue: {workspace_id}"
+        )
+
+    async def _remove_workspace_consumer(self, workspace_id):
+        clients = self.workspace_consumers.pop(workspace_id, None)
+        if clients:
+            for client in clients.values():
+                await client.stop()
+            logger.info(
+                f"Unsubscribed from workspace config queue: {workspace_id}"
+            )
+
     async def start(self):
 
-        await self.pubsub.ensure_topic(self.config_request_topic)
+        await self.pubsub.ensure_topic(self.config_request_queue_base)
+        await self.config_response_producer.start()
         await self.push() # Startup poke: empty types = everything
-        await self.config_request_consumer.start()
+        await self.system_consumer.start()
 
-    async def push(self, changes=None):
+        # Start the config push subscriber so we receive our own
+        # workspace change notifications.
+        await self.config_sub_task.start()
+
+        await self._discover_workspaces()
+
+    async def push(self, changes=None, workspace_changes=None):
 
         # Suppress notifications from reserved workspaces (ids starting
-        # with "_", e.g. "__template__"). Stored config is preserved;
-        # only the broadcast is filtered. Keeps services oblivious to
-        # template / bootstrap state.
+        # with "_", e.g. "__template__") for regular config changes.
+        # The __workspaces__ namespace is handled separately via
+        # workspace_changes.
         if changes:
             filtered = {}
             for type_name, workspaces in changes.items():

@@ -165,16 +302,20 @@ class Processor(AsyncProcessor):
         resp = ConfigPush(
             version = version,
             changes = changes or {},
+            workspace_changes = workspace_changes,
         )
 
         await self.config_push_producer.send(resp)
 
         logger.info(
             f"Pushed config poke version {version}, "
-            f"changes={resp.changes}"
+            f"changes={resp.changes}, "
+            f"workspace_changes={resp.workspace_changes}"
        )
 
-    async def on_config_request(self, msg, consumer, flow):
+    async def on_workspace_config_request(
+        self, msg, consumer, flow, *, workspace
+    ):
 
         try:
 

@@ -183,9 +324,44 @@ class Processor(AsyncProcessor):
             # Sender-produced ID
             id = msg.properties()["id"]
 
-            logger.debug(f"Handling config request {id}...")
+            logger.debug(
+                f"Handling workspace config request {id} "
+                f"workspace={workspace}..."
+            )
 
-            resp = await self.config.handle(v)
+            producer = self.workspace_consumers[workspace]["response"]
+
+            resp = await self.config.handle_workspace(v, workspace)
+
+            await producer.send(
+                resp, properties={"id": id}
+            )
+
+        except Exception as e:
+
+            resp = ConfigResponse(
+                error=Error(
+                    type = "config-error",
+                    message = str(e),
+                ),
+            )
+
+            await producer.send(
+                resp, properties={"id": id}
+            )
+
+    async def on_system_config_request(self, msg, consumer, flow):
+
+        try:
+
+            v = msg.value()
+
+            # Sender-produced ID
+            id = msg.properties()["id"]
+
+            logger.debug(f"Handling system config request {id}...")
+
+            resp = await self.config.handle_system(v)
 
             await self.config_response_producer.send(
                 resp, properties={"id": id}

@@ -228,4 +404,3 @@ class Processor(AsyncProcessor):
 def run():
 
     Processor.launch(default_ident, __doc__)
-
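All per-workspace consumers share on_workspace_config_request; functools.partial binds the workspace id as a keyword argument when the consumer is created, keeping one handler signature for every queue. The bare mechanism, with illustrative values:

from functools import partial

async def on_request(msg, consumer, flow, *, workspace):
    print(f"{workspace}: {msg}")

# One bound handler per workspace, all sharing the same code path
handlers = {
    ws: partial(on_request, workspace=ws)
    for ws in ("acme", "globex")
}

# await handlers["acme"]("ping", None, None)   # -> "acme: ping"
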
@@ -29,12 +29,12 @@ class KnowledgeManager:
         self.background_task = None
         self.flow_config = flow_config
 
-    async def delete_kg_core(self, request, respond):
+    async def delete_kg_core(self, request, respond, workspace):
 
         logger.info("Deleting knowledge core...")
 
         await self.table_store.delete_kg_core(
-            request.workspace, request.id
+            workspace, request.id
         )
 
         await respond(

@@ -47,7 +47,7 @@ class KnowledgeManager:
             )
         )
 
-    async def get_kg_core(self, request, respond):
+    async def get_kg_core(self, request, respond, workspace):
 
         logger.info("Getting knowledge core...")
 

@@ -62,9 +62,8 @@ class KnowledgeManager:
             )
         )
 
-        # Remove doc table row
         await self.table_store.get_triples(
-            request.workspace,
+            workspace,
             request.id,
             publish_triples,
         )

@@ -80,9 +79,8 @@ class KnowledgeManager:
             )
         )
 
-        # Remove doc table row
         await self.table_store.get_graph_embeddings(
-            request.workspace,
+            workspace,
             request.id,
             publish_ge,
         )

@@ -99,9 +97,9 @@ class KnowledgeManager:
             )
         )
 
-    async def list_kg_cores(self, request, respond):
+    async def list_kg_cores(self, request, respond, workspace):
 
-        ids = await self.table_store.list_kg_cores(request.workspace)
+        ids = await self.table_store.list_kg_cores(workspace)
 
         await respond(
             KnowledgeResponse(

@@ -113,9 +111,7 @@ class KnowledgeManager:
             )
         )
 
-    async def put_kg_core(self, request, respond):
-
-        workspace = request.workspace
+    async def put_kg_core(self, request, respond, workspace):
 
         if request.triples:
             await self.table_store.add_triples(workspace, request.triples)

@@ -135,20 +131,18 @@ class KnowledgeManager:
             )
         )
 
-    async def load_kg_core(self, request, respond):
+    async def load_kg_core(self, request, respond, workspace):
 
         if self.background_task is None:
             self.background_task = asyncio.create_task(
                 self.core_loader()
             )
-            # Wait for it to start (yuck)
-            # await asyncio.sleep(0.5)
 
-        await self.loader_queue.put((request, respond))
+        await self.loader_queue.put((request, respond, workspace))
 
         # Not sending a response, the loader thread can do that
 
-    async def unload_kg_core(self, request, respond):
+    async def unload_kg_core(self, request, respond, workspace):
 
         await respond(
             KnowledgeResponse(

@@ -169,7 +163,7 @@ class KnowledgeManager:
         while True:
 
             logger.debug("Waiting for next load...")
-            request, respond = await self.loader_queue.get()
+            request, respond, workspace = await self.loader_queue.get()
 
             logger.info(f"Loading knowledge: {request.id}")
 

@@ -181,7 +175,6 @@ class KnowledgeManager:
                 if request.flow is None:
                     raise RuntimeError("Flow ID must be specified")
 
-                workspace = request.workspace
                 ws_flows = self.flow_config.flows.get(workspace, {})
                 if request.flow not in ws_flows:
                     raise RuntimeError(

@@ -263,9 +256,8 @@ class KnowledgeManager:
 
                 logger.debug("Publishing triples...")
 
-                # Remove doc table row
                 await self.table_store.get_triples(
-                    request.workspace,
+                    workspace,
                     request.id,
                     publish_triples,
                 )

@@ -278,9 +270,8 @@ class KnowledgeManager:
 
                 logger.debug("Publishing graph embeddings...")
 
-                # Remove doc table row
                 await self.table_store.get_graph_embeddings(
-                    request.workspace,
+                    workspace,
                     request.id,
                     publish_ge,
                 )

@@ -9,7 +9,7 @@ import base64
 import json
 import logging
 
-from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber
+from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
 from .. base import ConsumerMetrics, ProducerMetrics
 from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
 

@@ -33,17 +33,22 @@ default_knowledge_response_queue = knowledge_response_queue
 
 default_cassandra_host = "cassandra"
 
-class Processor(AsyncProcessor):
+
+def workspace_queue(base_queue, workspace):
+    return f"{base_queue}:{workspace}"
+
+
+class Processor(WorkspaceProcessor):
 
     def __init__(self, **params):
 
         id = params.get("id")
 
-        knowledge_request_queue = params.get(
+        self.knowledge_request_queue_base = params.get(
             "knowledge_request_queue", default_knowledge_request_queue
         )
 
-        knowledge_response_queue = params.get(
+        self.knowledge_response_queue_base = params.get(
             "knowledge_response_queue", default_knowledge_response_queue
         )
 

@@ -51,78 +56,106 @@ class Processor(AsyncProcessor):
         cassandra_username = params.get("cassandra_username")
         cassandra_password = params.get("cassandra_password")
 
-        # Resolve configuration with environment variable fallback
-        hosts, username, password, keyspace = resolve_cassandra_config(
+        hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
             host=cassandra_host,
             username=cassandra_username,
             password=cassandra_password,
             default_keyspace="knowledge"
         )
 
-        # Store resolved configuration
         self.cassandra_host = hosts
         self.cassandra_username = username
         self.cassandra_password = password
 
         super(Processor, self).__init__(
             **params | {
-                "knowledge_request_queue": knowledge_request_queue,
-                "knowledge_response_queue": knowledge_response_queue,
+                "knowledge_request_queue": self.knowledge_request_queue_base,
+                "knowledge_response_queue": self.knowledge_response_queue_base,
                 "cassandra_host": self.cassandra_host,
                 "cassandra_username": self.cassandra_username,
                 "cassandra_password": self.cassandra_password,
             }
         )
 
-        knowledge_request_metrics = ConsumerMetrics(
-            processor = self.id, flow = None, name = "knowledge-request"
-        )
-
-        knowledge_response_metrics = ProducerMetrics(
-            processor = self.id, flow = None, name = "knowledge-response"
-        )
-
-        self.knowledge_request_topic = knowledge_request_queue
-        self.knowledge_request_subscriber = id
-
-        self.knowledge_request_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = knowledge_request_queue,
-            subscriber = id,
-            schema = KnowledgeRequest,
-            handler = self.on_knowledge_request,
-            metrics = knowledge_request_metrics,
-        )
-
-        self.knowledge_response_producer = Producer(
-            backend = self.pubsub,
-            topic = knowledge_response_queue,
-            schema = KnowledgeResponse,
-            metrics = knowledge_response_metrics,
-        )
-
         self.knowledge = KnowledgeManager(
             cassandra_host = self.cassandra_host,
             cassandra_username = self.cassandra_username,
             cassandra_password = self.cassandra_password,
             keyspace = keyspace,
             flow_config = self,
+            replication_factor = replication_factor,
         )
 
         self.register_config_handler(self.on_knowledge_config, types=["flow"])
 
         self.flows = {}
 
+        self.workspace_consumers = {}
+
         logger.info("Knowledge service initialized")
 
+    async def on_workspace_created(self, workspace):
+
+        if workspace in self.workspace_consumers:
+            return
+
+        req_queue = workspace_queue(
+            self.knowledge_request_queue_base, workspace,
+        )
+        resp_queue = workspace_queue(
+            self.knowledge_response_queue_base, workspace,
+        )
+
+        await self.pubsub.ensure_topic(req_queue)
+        await self.pubsub.ensure_topic(resp_queue)
+
+        response_producer = Producer(
+            backend=self.pubsub,
+            topic=resp_queue,
+            schema=KnowledgeResponse,
+            metrics=ProducerMetrics(
+                processor=self.id, flow=None,
+                name=f"knowledge-response-{workspace}",
+            ),
+        )
||||||
|
|
||||||
|
consumer = Consumer(
|
||||||
|
taskgroup=self.taskgroup,
|
||||||
|
backend=self.pubsub,
|
||||||
|
flow=None,
|
||||||
|
topic=req_queue,
|
||||||
|
subscriber=self.id,
|
||||||
|
schema=KnowledgeRequest,
|
||||||
|
handler=partial(
|
||||||
|
self.on_knowledge_request, workspace=workspace,
|
||||||
|
),
|
||||||
|
metrics=ConsumerMetrics(
|
||||||
|
processor=self.id, flow=None,
|
||||||
|
name=f"knowledge-request-{workspace}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await response_producer.start()
|
||||||
|
await consumer.start()
|
||||||
|
|
||||||
|
self.workspace_consumers[workspace] = {
|
||||||
|
"consumer": consumer,
|
||||||
|
"response": response_producer,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Subscribed to workspace queue: {workspace}")
|
||||||
|
|
||||||
|
async def on_workspace_deleted(self, workspace):
|
||||||
|
|
||||||
|
clients = self.workspace_consumers.pop(workspace, None)
|
||||||
|
if clients:
|
||||||
|
for client in clients.values():
|
||||||
|
await client.stop()
|
||||||
|
logger.info(f"Unsubscribed from workspace queue: {workspace}")
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
|
||||||
await self.pubsub.ensure_topic(self.knowledge_request_topic)
|
|
||||||
await super(Processor, self).start()
|
await super(Processor, self).start()
|
||||||
await self.knowledge_request_consumer.start()
|
|
||||||
await self.knowledge_response_producer.start()
|
|
||||||
|
|
||||||
async def on_knowledge_config(self, workspace, config, version):
|
async def on_knowledge_config(self, workspace, config, version):
|
||||||
|
|
||||||
|
|
@ -140,7 +173,7 @@ class Processor(AsyncProcessor):
|
||||||
|
|
||||||
logger.debug(f"Flows for {workspace}: {self.flows[workspace]}")
|
logger.debug(f"Flows for {workspace}: {self.flows[workspace]}")
|
||||||
|
|
||||||
async def process_request(self, v, id):
|
async def process_request(self, v, id, workspace, producer):
|
||||||
|
|
||||||
if v.operation is None:
|
if v.operation is None:
|
||||||
raise RequestError("Null operation")
|
raise RequestError("Null operation")
|
||||||
|
|
@ -160,12 +193,12 @@ class Processor(AsyncProcessor):
|
||||||
raise RequestError(f"Invalid operation: {v.operation}")
|
raise RequestError(f"Invalid operation: {v.operation}")
|
||||||
|
|
||||||
async def respond(x):
|
async def respond(x):
|
||||||
await self.knowledge_response_producer.send(
|
await producer.send(
|
||||||
x, { "id": id }
|
x, { "id": id }
|
||||||
)
|
)
|
||||||
return await impls[v.operation](v, respond)
|
return await impls[v.operation](v, respond, workspace)
|
||||||
|
|
||||||
async def on_knowledge_request(self, msg, consumer, flow):
|
async def on_knowledge_request(self, msg, consumer, flow, *, workspace):
|
||||||
|
|
||||||
v = msg.value()
|
v = msg.value()
|
||||||
|
|
||||||
|
|
@ -175,11 +208,13 @@ class Processor(AsyncProcessor):
|
||||||
|
|
||||||
logger.info(f"Handling knowledge input {id}...")
|
logger.info(f"Handling knowledge input {id}...")
|
||||||
|
|
||||||
|
producer = self.workspace_consumers[workspace]["response"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# We don't send a response back here, the processing
|
# We don't send a response back here, the processing
|
||||||
# implementation sends whatever it needs to send.
|
# implementation sends whatever it needs to send.
|
||||||
await self.process_request(v, id)
|
await self.process_request(v, id, workspace, producer)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -191,7 +226,7 @@ class Processor(AsyncProcessor):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.knowledge_response_producer.send(
|
await producer.send(
|
||||||
resp, properties={"id": id}
|
resp, properties={"id": id}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -204,7 +239,7 @@ class Processor(AsyncProcessor):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.knowledge_response_producer.send(
|
await producer.send(
|
||||||
resp, properties={"id": id}
|
resp, properties={"id": id}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -215,7 +250,7 @@ class Processor(AsyncProcessor):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
AsyncProcessor.add_args(parser)
|
WorkspaceProcessor.add_args(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--knowledge-request-queue',
|
'--knowledge-request-queue',
|
||||||
|
|
|
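The per-workspace topic naming above is plain string composition, so topic names
stay predictable across services. A quick illustration (the workspace value is
made up):

    workspace_queue("knowledge-request", "acme")
    # -> "knowledge-request:acme"

    workspace_queue("knowledge-response", "acme")
    # -> "knowledge-response:acme"

One design consequence worth noting: because on_workspace_created builds one
Consumer/Producer pair per workspace, isolation comes from distinct Pulsar
topics rather than from message filtering, at the cost of one subscription per
workspace.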
@@ -16,9 +16,8 @@ import os
 from mistralai import Mistral

 from ... schema import Document, TextDocument, Metadata
-from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec

 from ... provenance import (
     document_uri, page_uri as make_page_uri, derived_entity_triples,

@@ -36,9 +35,6 @@ COMPONENT_VERSION = "1.0.0"
 default_ident = "document-decoder"
 default_api_key = os.getenv("MISTRAL_TOKEN")

-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue
-
 pages_per_chunk = 5

 def chunks(lst, n):

@@ -98,9 +94,8 @@ class Processor(FlowProcessor):
             )
         )

-        # Librarian client
-        self.librarian = LibrarianClient(
-            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+        self.register_specification(
+            LibrarianSpec()
         )

         if api_key is None:

@@ -113,10 +108,6 @@

         logger.info("Mistral OCR processor initialized")

-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian.start()
-
     def ocr(self, blob):
         """
         Run Mistral OCR on a PDF blob, returning per-page markdown strings.

@@ -198,9 +189,9 @@

         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.librarian.fetch_document_metadata(
+            doc_meta = await flow.librarian.fetch_document_metadata(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )
             if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
                 logger.error(

@@ -213,9 +204,9 @@
         # Get PDF content - fetch from librarian or use inline data
         if v.document_id:
             logger.info(f"Fetching document {v.document_id} from librarian...")
-            content = await self.librarian.fetch_document_content(
+            content = await flow.librarian.fetch_document_content(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )
             if isinstance(content, str):
                 content = content.encode('utf-8')

@@ -240,10 +231,10 @@
                 page_content = markdown.encode("utf-8")

                 # Save page as child document in librarian
-                await self.librarian.save_child_document(
+                await flow.librarian.save_child_document(
                     doc_id=page_doc_id,
                     parent_id=source_doc_id,
-                    workspace=flow.workspace,
                     content=page_content,
                     document_type="page",
                     title=f"Page {page_num}",

@@ -297,18 +288,6 @@
             help=f'Mistral API Key'
         )

-        parser.add_argument(
-            '--librarian-request-queue',
-            default=default_librarian_request_queue,
-            help=f'Librarian request queue (default: {default_librarian_request_queue})',
-        )
-
-        parser.add_argument(
-            '--librarian-response-queue',
-            default=default_librarian_response_queue,
-            help=f'Librarian response queue (default: {default_librarian_response_queue})',
-        )
-
 def run():

     Processor.launch(default_ident, __doc__)
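Both decoders stop owning a LibrarianClient and instead register a
LibrarianSpec, so the framework attaches a client to each flow and the explicit
workspace= argument disappears from every call site. A rough sketch of why that
works; the wrapper below is illustrative, not the real LibrarianSpec, and the
attribute names ('inner', 'flow.workspace') are assumptions:

    # Hypothetical wrapper: binds the flow's workspace once, so call
    # sites no longer pass workspace= themselves.
    class FlowScopedLibrarian:
        def __init__(self, inner, flow):
            self.inner = inner
            self.flow = flow

        async def fetch_document_content(self, document_id):
            # The workspace travels with the flow, not the call site.
            return await self.inner.fetch_document_content(
                document_id=document_id,
                workspace=self.flow.workspace,
            )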
@@ -12,9 +12,8 @@ import tempfile
 import base64
 import logging
 from ... schema import Document, TextDocument, Metadata
-from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec

 PyPDFLoader = None

@@ -32,9 +31,6 @@ logger = logging.getLogger(__name__)

 default_ident = "document-decoder"

-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue
-

 class Processor(FlowProcessor):

@@ -70,17 +66,12 @@ class Processor(FlowProcessor):
             )
         )

-        # Librarian client
-        self.librarian = LibrarianClient(
-            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+        self.register_specification(
+            LibrarianSpec()
         )

         logger.info("PDF decoder initialized")

-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian.start()
-
     async def on_message(self, msg, consumer, flow):

         logger.debug("PDF message received")

@@ -91,9 +82,9 @@ class Processor(FlowProcessor):

         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.librarian.fetch_document_metadata(
+            doc_meta = await flow.librarian.fetch_document_metadata(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )
             if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
                 logger.error(

@@ -112,9 +103,9 @@
             logger.info(f"Fetching document {v.document_id} from librarian...")
             fp.close()

-            content = await self.librarian.fetch_document_content(
+            content = await flow.librarian.fetch_document_content(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )

             # Content is base64 encoded

@@ -154,10 +145,10 @@
                 page_content = page.page_content.encode("utf-8")

                 # Save page as child document in librarian
-                await self.librarian.save_child_document(
+                await flow.librarian.save_child_document(
                     doc_id=page_doc_id,
                     parent_id=source_doc_id,
-                    workspace=flow.workspace,
                     content=page_content,
                     document_type="page",
                     title=f"Page {page_num}",

@@ -210,18 +201,6 @@
     def add_args(parser):
         FlowProcessor.add_args(parser)

-        parser.add_argument(
-            '--librarian-request-queue',
-            default=default_librarian_request_queue,
-            help=f'Librarian request queue (default: {default_librarian_request_queue})',
-        )
-
-        parser.add_argument(
-            '--librarian-response-queue',
-            default=default_librarian_response_queue,
-            help=f'Librarian response queue (default: {default_librarian_response_queue})',
-        )
-
 def run():

     Processor.launch(default_ident, __doc__)
@@ -118,10 +118,10 @@ class FlowConfig:

         return resolved

-    async def handle_list_blueprints(self, msg):
+    async def handle_list_blueprints(self, msg, workspace):

         names = list(await self.config.keys(
-            msg.workspace, "flow-blueprint"
+            workspace, "flow-blueprint"
         ))

         return FlowResponse(

@@ -129,19 +129,19 @@ class FlowConfig:
             blueprint_names = names,
         )

-    async def handle_get_blueprint(self, msg):
+    async def handle_get_blueprint(self, msg, workspace):

         return FlowResponse(
             error = None,
             blueprint_definition = await self.config.get(
-                msg.workspace, "flow-blueprint", msg.blueprint_name
+                workspace, "flow-blueprint", msg.blueprint_name
             ),
         )

-    async def handle_put_blueprint(self, msg):
+    async def handle_put_blueprint(self, msg, workspace):

         await self.config.put(
-            msg.workspace, "flow-blueprint",
+            workspace, "flow-blueprint",
             msg.blueprint_name, msg.blueprint_definition
         )

@@ -149,31 +149,31 @@ class FlowConfig:
             error = None,
         )

-    async def handle_delete_blueprint(self, msg):
+    async def handle_delete_blueprint(self, msg, workspace):

         logger.debug(f"Flow config message: {msg}")

         await self.config.delete(
-            msg.workspace, "flow-blueprint", msg.blueprint_name
+            workspace, "flow-blueprint", msg.blueprint_name
         )

         return FlowResponse(
             error = None,
         )

-    async def handle_list_flows(self, msg):
+    async def handle_list_flows(self, msg, workspace):

-        names = list(await self.config.keys(msg.workspace, "flow"))
+        names = list(await self.config.keys(workspace, "flow"))

         return FlowResponse(
             error = None,
             flow_ids = names,
         )

-    async def handle_get_flow(self, msg):
+    async def handle_get_flow(self, msg, workspace):

         flow_data = await self.config.get(
-            msg.workspace, "flow", msg.flow_id
+            workspace, "flow", msg.flow_id
         )
         flow = json.loads(flow_data)

@@ -184,9 +184,7 @@ class FlowConfig:
             parameters = flow.get("parameters", {}),
         )

-    async def handle_start_flow(self, msg):
-
-        workspace = msg.workspace
+    async def handle_start_flow(self, msg, workspace):

         if msg.blueprint_name is None:
             raise RuntimeError("No blueprint name")

@@ -222,7 +220,7 @@ class FlowConfig:
         logger.debug(f"Resolved parameters (with defaults): {parameters}")

         # Apply parameter substitution to template replacement function.
-        # {workspace} is substituted from msg.workspace to isolate
+        # {workspace} is substituted from workspace to isolate
         # queue names across workspaces.
         def repl_template_with_params(tmp):

@@ -548,9 +546,7 @@ class FlowConfig:
                 f"attempts: {topic}"
             )

-    async def handle_stop_flow(self, msg):
-
-        workspace = msg.workspace
+    async def handle_stop_flow(self, msg, workspace):

         if msg.flow_id is None:
             raise RuntimeError("No flow ID")

@@ -641,37 +637,29 @@ class FlowConfig:
             error = None,
         )

-    async def handle(self, msg):
+    async def handle(self, msg, workspace):

         logger.debug(
             f"Handling flow message: {msg.operation} "
-            f"workspace={msg.workspace}"
+            f"workspace={workspace}"
         )

-        if not msg.workspace:
-            return FlowResponse(
-                error=Error(
-                    type="bad-request",
-                    message="Workspace is required",
-                ),
-            )
-
         if msg.operation == "list-blueprints":
-            resp = await self.handle_list_blueprints(msg)
+            resp = await self.handle_list_blueprints(msg, workspace)
         elif msg.operation == "get-blueprint":
-            resp = await self.handle_get_blueprint(msg)
+            resp = await self.handle_get_blueprint(msg, workspace)
         elif msg.operation == "put-blueprint":
-            resp = await self.handle_put_blueprint(msg)
+            resp = await self.handle_put_blueprint(msg, workspace)
         elif msg.operation == "delete-blueprint":
-            resp = await self.handle_delete_blueprint(msg)
+            resp = await self.handle_delete_blueprint(msg, workspace)
         elif msg.operation == "list-flows":
-            resp = await self.handle_list_flows(msg)
+            resp = await self.handle_list_flows(msg, workspace)
         elif msg.operation == "get-flow":
-            resp = await self.handle_get_flow(msg)
+            resp = await self.handle_get_flow(msg, workspace)
         elif msg.operation == "start-flow":
-            resp = await self.handle_start_flow(msg)
+            resp = await self.handle_start_flow(msg, workspace)
         elif msg.operation == "stop-flow":
-            resp = await self.handle_stop_flow(msg)
+            resp = await self.handle_stop_flow(msg, workspace)
         else:

             resp = FlowResponse(
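A side note on the if/elif chain in handle(): it grows by two lines per
operation. A dispatch table keeps it flat; this is a hypothetical refactor,
not part of this change:

    # Sketch only: map operation names to bound handlers, dispatch in
    # one lookup. Unknown operations fall through to the error response.
    handlers = {
        "list-blueprints": self.handle_list_blueprints,
        "get-blueprint": self.handle_get_blueprint,
        "put-blueprint": self.handle_put_blueprint,
        "delete-blueprint": self.handle_delete_blueprint,
        "list-flows": self.handle_list_flows,
        "get-flow": self.handle_get_flow,
        "start-flow": self.handle_start_flow,
        "stop-flow": self.handle_stop_flow,
    }
    handler = handlers.get(msg.operation)
    if handler is not None:
        resp = await handler(msg, workspace)
    else:
        resp = None  # build the unknown-operation FlowResponse, as above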
@@ -4,6 +4,7 @@ Flow service. Manages flow lifecycle — starting and stopping flows
 by coordinating with the config service via pub/sub.
 """

+from functools import partial
 import logging
 import uuid

@@ -14,7 +15,7 @@ from trustgraph.schema import flow_request_queue, flow_response_queue
 from trustgraph.schema import ConfigRequest, ConfigResponse
 from trustgraph.schema import config_request_queue, config_response_queue

-from trustgraph.base import AsyncProcessor, Consumer, Producer
+from trustgraph.base import WorkspaceProcessor, Consumer, Producer
 from trustgraph.base import ConsumerMetrics, ProducerMetrics, SubscriberMetrics
 from trustgraph.base import ConfigClient

@@ -29,14 +30,18 @@ default_flow_request_queue = flow_request_queue
 default_flow_response_queue = flow_response_queue


-class Processor(AsyncProcessor):
+def workspace_queue(base_queue, workspace):
+    return f"{base_queue}:{workspace}"
+
+
+class Processor(WorkspaceProcessor):

     def __init__(self, **params):

-        flow_request_queue = params.get(
+        self.flow_request_queue_base = params.get(
             "flow_request_queue", default_flow_request_queue
         )
-        flow_response_queue = params.get(
+        self.flow_response_queue_base = params.get(
             "flow_response_queue", default_flow_response_queue
         )

@@ -49,34 +54,6 @@ class Processor(AsyncProcessor):
             }
         )

-        flow_request_metrics = ConsumerMetrics(
-            processor = self.id, flow = None, name = "flow-request"
-        )
-        flow_response_metrics = ProducerMetrics(
-            processor = self.id, flow = None, name = "flow-response"
-        )
-
-        self.flow_request_topic = flow_request_queue
-        self.flow_request_subscriber = id
-
-        self.flow_request_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = flow_request_queue,
-            subscriber = id,
-            schema = FlowRequest,
-            handler = self.on_flow_request,
-            metrics = flow_request_metrics,
-        )
-
-        self.flow_response_producer = Producer(
-            backend = self.pubsub,
-            topic = flow_response_queue,
-            schema = FlowResponse,
-            metrics = flow_response_metrics,
-        )
-
         config_req_metrics = ProducerMetrics(
             processor=self.id, flow=None, name="config-request",
         )

@@ -84,13 +61,6 @@ class Processor(AsyncProcessor):
             processor=self.id, flow=None, name="config-response",
         )

-        # Unique subscription suffix per process instance. Pulsar's
-        # exclusive subscriptions reject a second consumer on the same
-        # (topic, subscription-name) — so a deterministic name here
-        # collides with its own ghost when the supervisor restarts the
-        # process before Pulsar has timed out the previous session
-        # (ConsumerBusy). Matches the uuid convention used elsewhere
-        # (gateway/config/receiver.py, AsyncProcessor._create_config_client).
         config_rr_id = str(uuid.uuid4())
         self.config_client = ConfigClient(
             backend=self.pubsub,

@@ -106,21 +76,78 @@ class Processor(AsyncProcessor):

         self.flow = FlowConfig(self.config_client, self.pubsub)

+        self.workspace_consumers = {}
+
         logger.info("Flow service initialized")

+    async def on_workspace_created(self, workspace):
+
+        if workspace in self.workspace_consumers:
+            return
+
+        req_queue = workspace_queue(
+            self.flow_request_queue_base, workspace,
+        )
+        resp_queue = workspace_queue(
+            self.flow_response_queue_base, workspace,
+        )
+
+        await self.pubsub.ensure_topic(req_queue)
+        await self.pubsub.ensure_topic(resp_queue)
+
+        response_producer = Producer(
+            backend=self.pubsub,
+            topic=resp_queue,
+            schema=FlowResponse,
+            metrics=ProducerMetrics(
+                processor=self.id, flow=None,
+                name=f"flow-response-{workspace}",
+            ),
+        )
+
+        consumer = Consumer(
+            taskgroup=self.taskgroup,
+            backend=self.pubsub,
+            flow=None,
+            topic=req_queue,
+            subscriber=self.id,
+            schema=FlowRequest,
+            handler=partial(
+                self.on_flow_request, workspace=workspace,
+            ),
+            metrics=ConsumerMetrics(
+                processor=self.id, flow=None,
+                name=f"flow-request-{workspace}",
+            ),
+        )
+
+        await response_producer.start()
+        await consumer.start()
+
+        self.workspace_consumers[workspace] = {
+            "consumer": consumer,
+            "response": response_producer,
+        }
+
+        logger.info(f"Subscribed to workspace queue: {workspace}")
+
+    async def on_workspace_deleted(self, workspace):
+
+        clients = self.workspace_consumers.pop(workspace, None)
+        if clients:
+            for client in clients.values():
+                await client.stop()
+            logger.info(f"Unsubscribed from workspace queue: {workspace}")
+
     async def start(self):

-        await self.pubsub.ensure_topic(self.flow_request_topic)
+        await super(Processor, self).start()
         await self.config_client.start()

-        # Discover workspaces with existing flow config and ensure
-        # their topics exist before we start accepting requests.
         workspaces = await self.config_client.workspaces_for_type("flow")
         await self.flow.ensure_existing_flow_topics(workspaces)

-        await self.flow_request_consumer.start()
-
-    async def on_flow_request(self, msg, consumer, flow):
+    async def on_flow_request(self, msg, consumer, flow, *, workspace):

         try:

@@ -131,9 +158,11 @@ class Processor(AsyncProcessor):

             logger.debug(f"Handling flow request {id}...")

-            resp = await self.flow.handle(v)
+            producer = self.workspace_consumers[workspace]["response"]

-            await self.flow_response_producer.send(
+            resp = await self.flow.handle(v, workspace)
+
+            await producer.send(
                 resp, properties={"id": id}
             )

@@ -148,14 +177,14 @@ class Processor(AsyncProcessor):
                 ),
             )

-            await self.flow_response_producer.send(
+            await producer.send(
                 resp, properties={"id": id}
             )

     @staticmethod
     def add_args(parser):

-        AsyncProcessor.add_args(parser)
+        WorkspaceProcessor.add_args(parser)

         parser.add_argument(
             '--flow-request-queue',
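Both services bind the workspace into the consumer handler with
functools.partial rather than defining one handler per workspace. A small,
self-contained demonstration of what the binding does (the argument values are
placeholders):

    import asyncio
    from functools import partial

    async def on_flow_request(msg, consumer, flow, *, workspace):
        # The consumer framework only ever supplies the first three
        # positional arguments; the keyword rides along via partial.
        return workspace

    handler = partial(on_flow_request, workspace="acme")
    assert asyncio.run(handler("msg", "consumer", None)) == "acme"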
@@ -141,6 +141,12 @@ class IamAuth:
         self._authz_cache: dict[str, tuple[bool, float]] = {}
         self._authz_cache_lock = asyncio.Lock()

+        # Known workspaces, maintained by the config receiver.
+        # enforce_workspace checks this set to reject requests for
+        # non-existent workspaces before routing to a queue that
+        # has no consumer.
+        self.known_workspaces: set[str] = set()
+
     # ------------------------------------------------------------------
     # Short-lived client helper. Mirrors the pattern used by the
     # bootstrap framework and AsyncProcessor: a fresh uuid suffix per
@@ -67,12 +67,22 @@ async def enforce(request, auth, capability):
     return identity


+def workspace_not_found():
+    return web.HTTPNotFound(
+        text='{"error":"workspace not found"}',
+        content_type="application/json",
+    )
+
+
 async def enforce_workspace(data, identity, auth, capability=None):
     """Default-fill the workspace on a request body and (optionally)
     authorise the caller for ``capability`` against that workspace.

     - Target workspace = ``data["workspace"]`` if supplied, else the
       caller's bound workspace.
+    - Rejects the request if the resolved workspace is not in
+      ``auth.known_workspaces`` (prevents routing to a queue with
+      no consumer).
     - On success, ``data["workspace"]`` is overwritten with the
       resolved value so downstream code sees a single canonical
       address.

@@ -92,6 +102,9 @@ async def enforce_workspace(data, identity, auth, capability=None):
     target = requested or identity.workspace
     data["workspace"] = target

+    if target not in auth.known_workspaces:
+        raise workspace_not_found()
+
     if capability is not None:
         await auth.authorise(
             identity, capability, {"workspace": target}, {},
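The new check turns an unknown workspace into a 404 before any queue I/O
happens. A sketch of the observable behaviour, using SimpleNamespace stand-ins
for the identity and auth objects (these are assumptions for illustration, not
the project's real types):

    from types import SimpleNamespace

    identity = SimpleNamespace(workspace="acme")
    auth = SimpleNamespace(known_workspaces={"acme"})

    # No explicit workspace: defaults to the caller's binding and passes.
    data = {}
    # await enforce_workspace(data, identity, auth)
    # -> data["workspace"] == "acme"

    # Explicit unknown workspace: raises HTTPNotFound before any routing.
    data = {"workspace": "ghost"}
    # await enforce_workspace(data, identity, auth)
    # -> raises workspace_not_found()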
@@ -24,9 +24,10 @@ logger.setLevel(logging.INFO)

 class ConfigReceiver:

-    def __init__(self, backend):
+    def __init__(self, backend, auth=None):

         self.backend = backend
+        self.auth = auth

         self.flow_handlers = []

@@ -54,6 +55,15 @@ class ConfigReceiver:
             )
             return

+        # Track workspace lifecycle
+        if v.workspace_changes and self.auth:
+            for ws in (v.workspace_changes.created or []):
+                self.auth.known_workspaces.add(ws)
+                logger.info(f"Workspace registered: {ws}")
+            for ws in (v.workspace_changes.deleted or []):
+                self.auth.known_workspaces.discard(ws)
+                logger.info(f"Workspace deregistered: {ws}")
+
         # Gateway cares about flow config — check if any flow
         # types changed in any workspace
         flow_workspaces = changes.get("flow", [])

@@ -195,6 +205,33 @@ class ConfigReceiver:
         try:
             await client.start()

+            # Discover all known workspaces
+            ws_resp = await client.request(
+                ConfigRequest(
+                    operation="getvalues",
+                    workspace="__workspaces__",
+                    type="workspace",
+                ),
+                timeout=10,
+            )
+
+            if ws_resp.error:
+                raise RuntimeError(
+                    f"Workspace discovery error: "
+                    f"{ws_resp.error.message}"
+                )
+
+            discovered = {
+                v.key for v in ws_resp.values if v.key
+            }
+
+            if self.auth:
+                self.auth.known_workspaces = discovered
+
+            logger.info(
+                f"Known workspaces: {discovered}"
+            )
+
             # Discover workspaces that have any flow config
             resp = await client.request(
                 ConfigRequest(
@@ -7,6 +7,12 @@ import logging
 # Module logger
 logger = logging.getLogger(__name__)

+from ... schema import flow_request_queue, flow_response_queue
+from ... schema import librarian_request_queue, librarian_response_queue
+from ... schema import knowledge_request_queue, knowledge_response_queue
+from ... schema import collection_request_queue, collection_response_queue
+from ... schema import config_request_queue, config_response_queue
+
 from . config import ConfigRequestor
 from . flow import FlowRequestor
 from . iam import IamRequestor

@@ -70,15 +76,36 @@ request_response_dispatchers = {
     "sparql": SparqlQueryRequestor,
 }

-global_dispatchers = {
+system_dispatchers = {
+    "iam": IamRequestor,
+}
+
+workspace_dispatchers = {
     "config": ConfigRequestor,
     "flow": FlowRequestor,
-    "iam": IamRequestor,
     "librarian": LibrarianRequestor,
     "knowledge": KnowledgeRequestor,
     "collection-management": CollectionManagementRequestor,
 }

+workspace_default_request_queues = {
+    "config": config_request_queue,
+    "flow": flow_request_queue,
+    "librarian": librarian_request_queue,
+    "knowledge": knowledge_request_queue,
+    "collection-management": collection_request_queue,
+}
+
+workspace_default_response_queues = {
+    "config": config_response_queue,
+    "flow": flow_response_queue,
+    "librarian": librarian_response_queue,
+    "knowledge": knowledge_response_queue,
+    "collection-management": collection_response_queue,
+}
+
+global_dispatchers = {**system_dispatchers, **workspace_dispatchers}
+
 sender_dispatchers = {
     "text-load": TextLoad,
     "document-load": DocumentLoad,

@@ -219,11 +246,24 @@ class DispatcherManager:
     async def process_global_service(self, data, responder, params):

         kind = params.get("kind")
-        return await self.invoke_global_service(data, responder, kind)
+        workspace = params.get("workspace")
+        if not workspace and isinstance(data, dict):
+            workspace = data.get("workspace")
+        return await self.invoke_global_service(
+            data, responder, kind, workspace=workspace,
+        )

-    async def invoke_global_service(self, data, responder, kind):
+    async def invoke_global_service(self, data, responder, kind,
+                                    workspace=None):

-        key = (None, kind)
+        if kind in workspace_dispatchers:
+            if not workspace:
+                raise RuntimeError(
+                    f"Workspace is required for {kind}"
+                )
+            key = (workspace, kind)
+        else:
+            key = (None, kind)

         if key not in self.dispatchers:
             async with self.dispatcher_lock:

@@ -234,11 +274,26 @@ class DispatcherManager:
                 request_queue = self.queue_overrides[kind].get("request")
                 response_queue = self.queue_overrides[kind].get("response")

+                if kind in workspace_dispatchers and workspace:
+                    base_req_queue = (
+                        request_queue
+                        or workspace_default_request_queues[kind]
+                    )
+                    request_queue = f"{base_req_queue}:{workspace}"
+                    base_resp_queue = (
+                        response_queue
+                        or workspace_default_response_queues[kind]
+                    )
+                    response_queue = f"{base_resp_queue}:{workspace}"
+                    consumer_name = f"{self.prefix}-{kind}-{workspace}"
+                else:
+                    consumer_name = f"{self.prefix}-{kind}-request"
+
                 dispatcher = global_dispatchers[kind](
                     backend = self.backend,
                     timeout = 120,
-                    consumer = f"{self.prefix}-{kind}-request",
-                    subscriber = f"{self.prefix}-{kind}-request",
+                    consumer = consumer_name,
+                    subscriber = consumer_name,
                     request_queue = request_queue,
                     response_queue = response_queue,
                 )
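The (workspace, kind) tuple keys a lazily built dispatcher cache, so the
gateway opens at most one requestor per service per workspace. A quick
illustration of the derived names (the workspace and prefix values are made
up):

    # For kind="knowledge", workspace="acme", prefix="gw":
    key = ("acme", "knowledge")                  # cache key
    request_queue = "knowledge-request:acme"     # base queue + workspace
    response_queue = "knowledge-response:acme"
    consumer_name = "gw-knowledge-acme"          # also used as subscriber

    # System-level services such as "iam" stay unscoped:
    key = (None, "iam")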
@@ -190,6 +190,16 @@ class Mux:
             await self.auth.authorise(
                 self.identity, op.capability, resource, parameters,
             )
+        except _web.HTTPNotFound:
+            await self.ws.send_json({
+                "id": request_id,
+                "error": {
+                    "message": "workspace not found",
+                    "type": "workspace-not-found",
+                },
+                "complete": True,
+            })
+            return
         except _web.HTTPForbidden:
             await self.ws.send_json({
                 "id": request_id,

@@ -310,7 +320,7 @@ class Mux:
         else:

             await self.dispatcher_manager.invoke_global_service(
-                request, responder, svc
+                request, responder, svc, workspace=workspace,
             )

     except Exception as e:
@@ -116,9 +116,6 @@ def serialize_document_metadata(message):
     if message.metadata:
         ret["metadata"] = serialize_subgraph(message.metadata)

-    if message.workspace:
-        ret["workspace"] = message.workspace
-
     if message.tags is not None:
         ret["tags"] = message.tags

@@ -140,9 +137,6 @@ def serialize_processing_metadata(message):
     if message.flow:
         ret["flow"] = message.flow

-    if message.workspace:
-        ret["workspace"] = message.workspace
-
     if message.collection:
         ret["collection"] = message.collection

@@ -160,7 +154,6 @@ def to_document_metadata(x):
         title = x.get("title", None),
         comments = x.get("comments", None),
         metadata = to_subgraph(x["metadata"]),
-        workspace = x.get("workspace", None),
         tags = x.get("tags", None),
     )

@@ -171,7 +164,6 @@ def to_processing_metadata(x):
         document_id = x.get("document-id", None),
         time = x.get("time", None),
         flow = x.get("flow", None),
-        workspace = x.get("workspace", None),
         collection = x.get("collection", None),
         tags = x.get("tags", None),
     )
@@ -12,8 +12,8 @@ from . auth_endpoints import AuthEndpoints
 from . iam_endpoint import IamEndpoint
 from . registry_endpoint import RegistryRoutedVariableEndpoint

-from .. capabilities import PUBLIC, AUTHENTICATED, auth_failure
-from .. registry import lookup as _registry_lookup, RequestContext
+from .. capabilities import PUBLIC, AUTHENTICATED, auth_failure, workspace_not_found
+from .. registry import lookup as _registry_lookup, RequestContext, ResourceLevel

 from .. dispatch.manager import DispatcherManager

@@ -77,6 +77,10 @@ class _RoutedVariableEndpoint:
                 identity, op.capability, resource, parameters,
             )

+            ws = resource.get("workspace", "")
+            if ws and ws not in self.auth.known_workspaces:
+                raise workspace_not_found()
+
         async def responder(x, fin):
             pass

@@ -140,6 +144,11 @@ class _RoutedSocketEndpoint:
             await self.auth.authorise(
                 identity, op.capability, resource, parameters,
             )

+            ws = resource.get("workspace", "")
+            if ws and ws not in self.auth.known_workspaces:
+                raise workspace_not_found()
+
         except web.HTTPException as e:
             return e
@@ -20,9 +20,9 @@ import logging
 from aiohttp import web

 from .. capabilities import (
-    PUBLIC, AUTHENTICATED, auth_failure,
+    PUBLIC, AUTHENTICATED, auth_failure, workspace_not_found,
 )
-from .. registry import lookup, RequestContext
+from .. registry import lookup, RequestContext, ResourceLevel

 logger = logging.getLogger("registry-endpoint")
 logger.setLevel(logging.INFO)

@@ -107,6 +107,15 @@ class RegistryRoutedVariableEndpoint:
         if "workspace" in resource:
             body["workspace"] = resource["workspace"]

+        if (
+            op.resource_level in (
+                ResourceLevel.WORKSPACE, ResourceLevel.FLOW,
+            )
+            and resource.get("workspace")
+            not in self.auth.known_workspaces
+        ):
+            raise workspace_not_found()
+
         async def responder(x, fin):
             pass
@@ -68,7 +68,7 @@ class Api:
             id=config.get("id", "api-gateway"),
         )

-        self.config_receiver = ConfigReceiver(self.pubsub_backend)
+        self.config_receiver = ConfigReceiver(self.pubsub_backend, auth=self.auth)

         # Build queue overrides dictionary from CLI arguments
         queue_overrides = {}
@@ -246,6 +246,7 @@ class IamService:

     def __init__(self, host, username, password, keyspace,
                  bootstrap_mode, bootstrap_token=None,
+                 on_workspace_created=None, on_workspace_deleted=None,
                  replication_factor=1):
         self.table_store = IamTableStore(
             host, username, password, keyspace,

@@ -269,6 +270,12 @@ class IamService:
         self.bootstrap_mode = bootstrap_mode
         self.bootstrap_token = bootstrap_token

+        # Callbacks for workspace lifecycle events. Called after the
+        # workspace is created/deleted in IAM's own store so that the
+        # processor can announce it via the config service.
+        self._on_workspace_created = on_workspace_created
+        self._on_workspace_deleted = on_workspace_deleted
+
         self._signing_key = None
         self._signing_key_lock = asyncio.Lock()

@@ -426,6 +433,9 @@ class IamService:
             created=now,
         )

+        if self._on_workspace_created:
+            await self._on_workspace_created(DEFAULT_WORKSPACE)
+
         admin_user_id = str(uuid.uuid4())
         admin_password = secrets.token_urlsafe(32)
         await self.table_store.put_user(

@@ -893,19 +903,21 @@ class IamService:
                 "workspace ids beginning with '_' are reserved",
             )

+        if self._on_workspace_created:
+            await self._on_workspace_created(v.workspace_record.id)
+
         existing = await self.table_store.get_workspace(
             v.workspace_record.id,
         )
-        if existing is not None:
-            return _err("duplicate", "workspace already exists")
+        if existing is None:
+            now = _now_dt()
+            await self.table_store.put_workspace(
+                id=v.workspace_record.id,
+                name=v.workspace_record.name or v.workspace_record.id,
+                enabled=v.workspace_record.enabled,
+                created=now,
+            )

-        now = _now_dt()
-        await self.table_store.put_workspace(
-            id=v.workspace_record.id,
-            name=v.workspace_record.name or v.workspace_record.id,
-            enabled=v.workspace_record.enabled,
-            created=now,
-        )
         row = await self.table_store.get_workspace(v.workspace_record.id)
         return IamResponse(workspace=self._row_to_workspace_record(row))

@@ -984,6 +996,9 @@ class IamService:
         for kr in key_rows:
             await self.table_store.delete_api_key(kr[0])

+        if self._on_workspace_deleted:
+            await self._on_workspace_deleted(v.workspace_record.id)
+
         return IamResponse()

     # ------------------------------------------------------------------
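The create-workspace hunk above changes the semantics from reject-on-duplicate
to an idempotent upsert, and the registration callback now fires on every
create call, not only the first. A sketch of the resulting contract; the
create_workspace and record names below are placeholders, since the diff only
shows the handler body:

    # Hypothetical call sequence under the new semantics: the first
    # create writes the row; a repeat create is a no-op that still
    # re-announces the workspace and returns the stored record.
    r1 = await iam.create_workspace(record("acme"))   # row written
    r2 = await iam.create_workspace(record("acme"))   # no error, same row
    assert r1.workspace.id == r2.workspace.id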
@@ -12,9 +12,13 @@ import os
 from trustgraph.schema import Error
 from trustgraph.schema import IamRequest, IamResponse
 from trustgraph.schema import iam_request_queue, iam_response_queue
+from trustgraph.schema import ConfigRequest, ConfigResponse, ConfigValue
+from trustgraph.schema import config_request_queue, config_response_queue

 from trustgraph.base import AsyncProcessor, Consumer, Producer
 from trustgraph.base import ConsumerMetrics, ProducerMetrics
+from trustgraph.base.metrics import SubscriberMetrics
+from trustgraph.base.request_response_spec import RequestResponse
 from trustgraph.base.cassandra_config import (
     add_cassandra_args, resolve_cassandra_config,
 )

@@ -92,7 +96,7 @@ class Processor(AsyncProcessor):
         cassandra_username = params.get("cassandra_username")
         cassandra_password = params.get("cassandra_password")

-        hosts, username, password, keyspace = resolve_cassandra_config(
+        hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
             host=cassandra_host,
             username=cassandra_username,
             password=cassandra_password,

@@ -145,8 +149,11 @@ class Processor(AsyncProcessor):
             username=self.cassandra_username,
             password=self.cassandra_password,
             keyspace=keyspace,
+            replication_factor=replication_factor,
             bootstrap_mode=self.bootstrap_mode,
             bootstrap_token=self.bootstrap_token,
+            on_workspace_created=self._ensure_workspace_registered,
+            on_workspace_deleted=self._announce_workspace_deleted,
         )

         logger.info(

@@ -160,6 +167,81 @@ class Processor(AsyncProcessor):
         await self.iam.auto_bootstrap_if_token_mode()
         await self.iam_request_consumer.start()

+    def _create_config_client(self):
+        import uuid
+        config_rr_id = str(uuid.uuid4())
+        config_req_metrics = ProducerMetrics(
+            processor=self.id, flow=None, name="config-request",
+        )
+        config_resp_metrics = SubscriberMetrics(
+            processor=self.id, flow=None, name="config-response",
+        )
+        return RequestResponse(
+            backend=self.pubsub,
+            subscription=f"{self.id}--config--{config_rr_id}",
+            consumer_name=self.id,
+            request_topic=config_request_queue,
+            request_schema=ConfigRequest,
+            request_metrics=config_req_metrics,
+            response_topic=config_response_queue,
+            response_schema=ConfigResponse,
+            response_metrics=config_resp_metrics,
+        )
+
+    async def _config_put(self, workspace, type, key, value):
+        client = self._create_config_client()
+        try:
+            await client.start()
+            await client.request(
+                ConfigRequest(
+                    operation="put",
+                    workspace=workspace,
+                    values=[ConfigValue(type=type, key=key, value=value)],
+                ),
+                timeout=10,
+            )
+        finally:
+            await client.stop()
+
+    async def _config_delete(self, workspace, type, key):
+        from trustgraph.schema import ConfigKey
+        client = self._create_config_client()
+        try:
+            await client.start()
+            await client.request(
+                ConfigRequest(
+                    operation="delete",
+                    workspace=workspace,
+                    keys=[ConfigKey(type=type, key=key)],
+                ),
+                timeout=10,
+            )
+        finally:
+            await client.stop()
+
+    async def _ensure_workspace_registered(self, workspace_id):
+        await self._config_put(
+            "__workspaces__", "workspace", workspace_id,
+            '{"enabled": true}',
+        )
+        logger.info(
+            f"Registered workspace in config: {workspace_id}"
+        )
+
+    async def _announce_workspace_deleted(self, workspace_id):
+        try:
+            await self._config_delete(
+                "__workspaces__", "workspace", workspace_id,
+            )
+            logger.info(
+                f"Announced workspace deletion: {workspace_id}"
+            )
+        except Exception as e:
+            logger.error(
+                f"Failed to announce workspace deletion "
+                f"{workspace_id}: {e}", exc_info=True,
+            )
+
     async def on_iam_request(self, msg, consumer, flow):

         id = None
||||||
|
|
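The helpers above give the IAM processor a short-lived request/response client per call: start, one request with a 10-second timeout, then stop, so no long-lived consumer is held between workspace events. Hedged usage sketch (workspace id illustrative):

    # Register, then later unregister, a workspace in the config service:
    await processor._ensure_workspace_registered("ws-123")
    await processor._announce_workspace_deleted("ws-123")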
@@ -151,21 +151,11 @@ class CollectionManager:
             logger.error(f"Error ensuring collection exists: {e}")
             raise e

-    async def list_collections(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
-        """
-        List collections for a user from config service
-
-        Args:
-            request: Collection management request
-
-        Returns:
-            CollectionManagementResponse with list of collections
-        """
+    async def list_collections(self, request, workspace):
         try:
-            # Get all collections in this workspace from config service
             config_request = ConfigRequest(
                 operation='getvalues',
-                workspace=request.workspace,
+                workspace=workspace,
                 type='collection'
             )
@@ -210,18 +200,8 @@ class CollectionManager:
             logger.error(f"Error listing collections: {e}")
             raise RequestError(f"Failed to list collections: {str(e)}")

-    async def update_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
-        """
-        Update collection metadata via config service (creates if doesn't exist)
-
-        Args:
-            request: Collection management request
-
-        Returns:
-            CollectionManagementResponse with updated collection
-        """
+    async def update_collection(self, request, workspace):
         try:
-            # Create metadata from request
             name = request.name if request.name else request.collection
             description = request.description if request.description else ""
             tags = list(request.tags) if request.tags else []
@@ -233,10 +213,9 @@ class CollectionManager:
                 tags=tags
             )

-            # Send put request to config service
             config_request = ConfigRequest(
                 operation='put',
-                workspace=request.workspace,
+                workspace=workspace,
                 values=[ConfigValue(
                     type='collection',
                     key=request.collection,
@@ -249,7 +228,7 @@ class CollectionManager:
             if response.error:
                 raise RuntimeError(f"Config update failed: {response.error.message}")

-            logger.info(f"Collection {request.workspace}/{request.collection} updated in config service")
+            logger.info(f"Collection {workspace}/{request.collection} updated in config service")

             # Config service will trigger config push automatically
             # Storage services will receive update and create/update collections
@@ -264,23 +243,13 @@ class CollectionManager:
             logger.error(f"Error updating collection: {e}")
             raise RequestError(f"Failed to update collection: {str(e)}")

-    async def delete_collection(self, request: CollectionManagementRequest) -> CollectionManagementResponse:
-        """
-        Delete collection via config service
-
-        Args:
-            request: Collection management request
-
-        Returns:
-            CollectionManagementResponse indicating success or failure
-        """
+    async def delete_collection(self, request, workspace):
         try:
-            logger.info(f"Deleting collection {request.workspace}/{request.collection}")
+            logger.info(f"Deleting collection {workspace}/{request.collection}")

-            # Send delete request to config service
             config_request = ConfigRequest(
                 operation='delete',
-                workspace=request.workspace,
+                workspace=workspace,
                 keys=[ConfigKey(type='collection', key=request.collection)]
             )
@@ -289,7 +258,7 @@ class CollectionManager:
             if response.error:
                 raise RuntimeError(f"Config delete failed: {response.error.message}")

-            logger.info(f"Collection {request.workspace}/{request.collection} deleted from config service")
+            logger.info(f"Collection {workspace}/{request.collection} deleted from config service")

             # Config service will trigger config push automatically
             # Storage services will receive update and delete collections
@@ -44,13 +44,13 @@ class Librarian:
         self.load_document = load_document
         self.min_chunk_size = min_chunk_size

-    async def add_document(self, request):
+    async def add_document(self, request, workspace):

         if not request.document_metadata.kind:
             raise RequestError("Document kind (MIME type) is required")

         if await self.table_store.document_exists(
-            request.document_metadata.workspace,
+            workspace,
             request.document_metadata.id
         ):
             raise RuntimeError("Document already exists")
@@ -68,19 +68,19 @@ class Librarian:
         logger.debug("Adding to table...")

         await self.table_store.add_document(
-            request.document_metadata, object_id
+            workspace, request.document_metadata, object_id
         )

         logger.debug("Add complete")

         return LibrarianResponse()

-    async def remove_document(self, request):
+    async def remove_document(self, request, workspace):

         logger.debug("Removing document...")

         if not await self.table_store.document_exists(
-            request.workspace,
+            workspace,
             request.document_id,
         ):
             raise RuntimeError("Document does not exist")
@@ -91,17 +91,17 @@ class Librarian:
                 logger.debug(f"Cascade deleting child document {child.id}")
                 try:
                     child_object_id = await self.table_store.get_document_object_id(
-                        child.workspace,
+                        workspace,
                         child.id
                     )
                     await self.blob_store.remove(child_object_id)
-                    await self.table_store.remove_document(child.workspace, child.id)
+                    await self.table_store.remove_document(workspace, child.id)
                 except Exception as e:
                     logger.warning(f"Failed to delete child document {child.id}: {e}")

         # Now remove the parent document
         object_id = await self.table_store.get_document_object_id(
-            request.workspace,
+            workspace,
             request.document_id
         )
@@ -110,7 +110,7 @@ class Librarian:

         # Remove doc table row
         await self.table_store.remove_document(
-            request.workspace,
+            workspace,
             request.document_id
         )
@@ -118,30 +118,30 @@ class Librarian:

         return LibrarianResponse()

-    async def update_document(self, request):
+    async def update_document(self, request, workspace):

         logger.debug("Updating document...")

         # You can't update the document ID, workspace or kind.

         if not await self.table_store.document_exists(
-            request.document_metadata.workspace,
+            workspace,
             request.document_metadata.id
         ):
             raise RuntimeError("Document does not exist")

-        await self.table_store.update_document(request.document_metadata)
+        await self.table_store.update_document(workspace, request.document_metadata)

         logger.debug("Update complete")

         return LibrarianResponse()

-    async def get_document_metadata(self, request):
+    async def get_document_metadata(self, request, workspace):

         logger.debug("Getting document metadata...")

         doc = await self.table_store.get_document(
-            request.workspace,
+            workspace,
             request.document_id
         )
@@ -153,12 +153,12 @@ class Librarian:
             content = None,
         )

-    async def get_document_content(self, request):
+    async def get_document_content(self, request, workspace):

         logger.debug("Getting document content...")

         object_id = await self.table_store.get_document_object_id(
-            request.workspace,
+            workspace,
             request.document_id
         )
@@ -174,7 +174,7 @@ class Librarian:
             content = base64.b64encode(content),
         )

-    async def add_processing(self, request):
+    async def add_processing(self, request, workspace):

         logger.debug("Adding processing metadata...")
@@ -182,18 +182,18 @@ class Librarian:
             raise RuntimeError("Collection parameter is required")

         if await self.table_store.processing_exists(
-            request.processing_metadata.workspace,
+            workspace,
             request.processing_metadata.id
         ):
             raise RuntimeError("Processing already exists")

         doc = await self.table_store.get_document(
-            request.processing_metadata.workspace,
+            workspace,
             request.processing_metadata.document_id
         )

         object_id = await self.table_store.get_document_object_id(
-            request.processing_metadata.workspace,
+            workspace,
             request.processing_metadata.document_id
         )
@@ -205,7 +205,7 @@ class Librarian:

         logger.debug("Adding processing to table...")

-        await self.table_store.add_processing(request.processing_metadata)
+        await self.table_store.add_processing(workspace, request.processing_metadata)

         logger.debug("Invoking document processing...")
@@ -213,25 +213,26 @@ class Librarian:
             document = doc,
             processing = request.processing_metadata,
             content = content,
+            workspace = workspace,
         )

         logger.debug("Add complete")

         return LibrarianResponse()

-    async def remove_processing(self, request):
+    async def remove_processing(self, request, workspace):

         logger.debug("Removing processing metadata...")

         if not await self.table_store.processing_exists(
-            request.workspace,
+            workspace,
             request.processing_id,
         ):
             raise RuntimeError("Processing object does not exist")

         # Remove doc table row
         await self.table_store.remove_processing(
-            request.workspace,
+            workspace,
             request.processing_id
         )
@@ -239,9 +240,9 @@ class Librarian:

         return LibrarianResponse()

-    async def list_documents(self, request):
+    async def list_documents(self, request, workspace):

-        docs = await self.table_store.list_documents(request.workspace)
+        docs = await self.table_store.list_documents(workspace)

         # Filter out child documents and answer documents by default
         include_children = getattr(request, 'include_children', False)
@@ -256,9 +257,9 @@ class Librarian:
             document_metadatas = docs,
         )

-    async def list_processing(self, request):
+    async def list_processing(self, request, workspace):

-        procs = await self.table_store.list_processing(request.workspace)
+        procs = await self.table_store.list_processing(workspace)

         return LibrarianResponse(
             processing_metadatas = procs,
@@ -266,7 +267,7 @@ class Librarian:

     # Chunked upload operations

-    async def begin_upload(self, request):
+    async def begin_upload(self, request, workspace):
         """
         Initialize a chunked upload session.
@@ -278,7 +279,7 @@ class Librarian:
             raise RequestError("Document kind (MIME type) is required")

         if await self.table_store.document_exists(
-            request.document_metadata.workspace,
+            workspace,
             request.document_metadata.id
         ):
             raise RequestError("Document already exists")
@@ -314,14 +315,13 @@ class Librarian:
             "kind": request.document_metadata.kind,
             "title": request.document_metadata.title,
             "comments": request.document_metadata.comments,
-            "workspace": request.document_metadata.workspace,
             "tags": request.document_metadata.tags,
         })

         # Store session in Cassandra
         await self.table_store.create_upload_session(
             upload_id=upload_id,
-            workspace=request.document_metadata.workspace,
+            workspace=workspace,
             document_id=request.document_metadata.id,
             document_metadata=doc_meta_json,
             s3_upload_id=s3_upload_id,
@@ -340,7 +340,7 @@ class Librarian:
             total_chunks=total_chunks,
         )

-    async def upload_chunk(self, request):
+    async def upload_chunk(self, request, workspace):
         """
         Upload a single chunk of a document.
@@ -354,7 +354,7 @@ class Librarian:
             raise RequestError("Upload session not found or expired")

         # Validate ownership
-        if session["workspace"] != request.workspace:
+        if session["workspace"] != workspace:
             raise RequestError("Not authorized to upload to this session")

         # Validate chunk index
@@ -407,7 +407,7 @@ class Librarian:
             total_bytes=session["total_size"],
         )

-    async def complete_upload(self, request):
+    async def complete_upload(self, request, workspace):
         """
         Finalize a chunked upload and create the document.
@@ -421,7 +421,7 @@ class Librarian:
             raise RequestError("Upload session not found or expired")

         # Validate ownership
-        if session["workspace"] != request.workspace:
+        if session["workspace"] != workspace:
             raise RequestError("Not authorized to complete this upload")

         # Verify all chunks received
@@ -459,13 +459,13 @@ class Librarian:
             kind=doc_meta_dict["kind"],
             title=doc_meta_dict.get("title", ""),
             comments=doc_meta_dict.get("comments", ""),
-            workspace=doc_meta_dict["workspace"],
             tags=doc_meta_dict.get("tags", []),
             metadata=[],  # Triples not supported in chunked upload yet
         )

         # Add document to table
-        await self.table_store.add_document(doc_metadata, session["object_id"])
+        workspace = session["workspace"]
+        await self.table_store.add_document(workspace, doc_metadata, session["object_id"])

         # Delete upload session
         await self.table_store.delete_upload_session(request.upload_id)
@@ -478,7 +478,7 @@ class Librarian:
             object_id=str(session["object_id"]),
         )

-    async def abort_upload(self, request):
+    async def abort_upload(self, request, workspace):
         """
         Cancel a chunked upload and clean up resources.
         """
@@ -490,7 +490,7 @@ class Librarian:
             raise RequestError("Upload session not found or expired")

         # Validate ownership
-        if session["workspace"] != request.workspace:
+        if session["workspace"] != workspace:
             raise RequestError("Not authorized to abort this upload")

         # Abort S3 multipart upload
@@ -506,7 +506,7 @@ class Librarian:

         return LibrarianResponse(error=None)

-    async def get_upload_status(self, request):
+    async def get_upload_status(self, request, workspace):
         """
         Get the status of an in-progress upload.
         """
@@ -522,7 +522,7 @@ class Librarian:
         )

         # Validate ownership
-        if session["workspace"] != request.workspace:
+        if session["workspace"] != workspace:
             raise RequestError("Not authorized to view this upload")

         chunks_received = session["chunks_received"]
@@ -548,13 +548,13 @@ class Librarian:
             total_bytes=session["total_size"],
         )

-    async def list_uploads(self, request):
+    async def list_uploads(self, request, workspace):
         """
         List all in-progress uploads for a workspace.
         """
-        logger.debug(f"Listing uploads for workspace {request.workspace}")
+        logger.debug(f"Listing uploads for workspace {workspace}")

-        sessions = await self.table_store.list_upload_sessions(request.workspace)
+        sessions = await self.table_store.list_upload_sessions(workspace)

         upload_sessions = [
             UploadSession(
@@ -577,7 +577,7 @@ class Librarian:

     # Child document operations

-    async def add_child_document(self, request):
+    async def add_child_document(self, request, workspace):
         """
         Add a child document linked to a parent document.
@@ -593,7 +593,7 @@ class Librarian:

         # Verify parent exists
         if not await self.table_store.document_exists(
-            request.document_metadata.workspace,
+            workspace,
             request.document_metadata.parent_id
         ):
             raise RequestError(
|
|
@ -601,7 +601,7 @@ class Librarian:
|
||||||
)
|
)
|
||||||
|
|
||||||
if await self.table_store.document_exists(
|
if await self.table_store.document_exists(
|
||||||
request.document_metadata.workspace,
|
workspace,
|
||||||
request.document_metadata.id
|
request.document_metadata.id
|
||||||
):
|
):
|
||||||
raise RequestError("Document already exists")
|
raise RequestError("Document already exists")
|
||||||
|
|
@@ -624,7 +624,7 @@ class Librarian:
         logger.debug("Adding to table...")

         await self.table_store.add_document(
-            request.document_metadata, object_id
+            workspace, request.document_metadata, object_id
         )

         logger.debug("Add child document complete")
@@ -634,7 +634,7 @@ class Librarian:
             document_id=request.document_metadata.id,
         )

-    async def list_children(self, request):
+    async def list_children(self, request, workspace):
         """
         List all child documents for a given parent document.
         """
@@ -647,7 +647,7 @@ class Librarian:
             document_metadatas=children,
         )

-    async def stream_document(self, request):
+    async def stream_document(self, request, workspace):
         """
         Stream document content in chunks.
@@ -667,7 +667,7 @@ class Librarian:
         )

         object_id = await self.table_store.get_document_object_id(
-            request.workspace,
+            workspace,
             request.document_id
         )
@@ -699,4 +699,3 @@ class Librarian:
             total_bytes=total_size,
             is_final=is_last,
         )
-
@@ -10,7 +10,7 @@ import json
 import logging
 from datetime import datetime

-from .. base import AsyncProcessor, Consumer, Producer, Publisher, Subscriber
+from .. base import WorkspaceProcessor, Consumer, Producer, Publisher, Subscriber
 from .. base import ConsumerMetrics, ProducerMetrics
 from .. base.cassandra_config import add_cassandra_args, resolve_cassandra_config
@@ -46,6 +46,9 @@ default_collection_response_queue = collection_response_queue
 default_config_request_queue = config_request_queue
 default_config_response_queue = config_response_queue

+def workspace_queue(base_queue, workspace):
+    return f"{base_queue}:{workspace}"
+
 default_object_store_endpoint = "ceph-rgw:7480"
 default_object_store_access_key = "object-user"
 default_object_store_secret_key = "object-password"
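workspace_queue derives per-workspace topic names by suffixing the base queue with the workspace id, which is what isolates each workspace's request/response traffic in the consumers set up later in this changeset. For example:

    workspace_queue("librarian-request", "acme")    # -> "librarian-request:acme"
    workspace_queue("collection-request", "acme")   # -> "collection-request:acme"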
@@ -56,27 +59,25 @@ default_min_chunk_size = 1  # No minimum by default (for Garage)

 bucket_name = "library"

-class Processor(AsyncProcessor):
+class Processor(WorkspaceProcessor):

     def __init__(self, **params):

         id = params.get("id")

-        # self.running = True
-
-        librarian_request_queue = params.get(
+        self.librarian_request_queue_base = params.get(
             "librarian_request_queue", default_librarian_request_queue
         )

-        librarian_response_queue = params.get(
+        self.librarian_response_queue_base = params.get(
             "librarian_response_queue", default_librarian_response_queue
         )

-        collection_request_queue = params.get(
+        self.collection_request_queue_base = params.get(
             "collection_request_queue", default_collection_request_queue
         )

-        collection_response_queue = params.get(
+        self.collection_response_queue_base = params.get(
             "collection_response_queue", default_collection_response_queue
         )
@@ -116,7 +117,7 @@ class Processor(AsyncProcessor):
         cassandra_password = params.get("cassandra_password")

         # Resolve configuration with environment variable fallback
-        hosts, username, password, keyspace = resolve_cassandra_config(
+        hosts, username, password, keyspace, replication_factor = resolve_cassandra_config(
             host=cassandra_host,
             username=cassandra_username,
             password=cassandra_password,
@@ -130,10 +131,10 @@ class Processor(AsyncProcessor):

         super(Processor, self).__init__(
             **params | {
-                "librarian_request_queue": librarian_request_queue,
-                "librarian_response_queue": librarian_response_queue,
-                "collection_request_queue": collection_request_queue,
-                "collection_response_queue": collection_response_queue,
+                "librarian_request_queue": self.librarian_request_queue_base,
+                "librarian_response_queue": self.librarian_response_queue_base,
+                "collection_request_queue": self.collection_request_queue_base,
+                "collection_response_queue": self.collection_response_queue_base,
                 "object_store_endpoint": object_store_endpoint,
                 "object_store_access_key": object_store_access_key,
                 "cassandra_host": self.cassandra_host,
@@ -142,68 +143,6 @@ class Processor(AsyncProcessor):
             }
         )

-        librarian_request_metrics = ConsumerMetrics(
-            processor = self.id, flow = None, name = "librarian-request"
-        )
-
-        librarian_response_metrics = ProducerMetrics(
-            processor = self.id, flow = None, name = "librarian-response"
-        )
-
-        collection_request_metrics = ConsumerMetrics(
-            processor = self.id, flow = None, name = "collection-request"
-        )
-
-        collection_response_metrics = ProducerMetrics(
-            processor = self.id, flow = None, name = "collection-response"
-        )
-
-        storage_response_metrics = ConsumerMetrics(
-            processor = self.id, flow = None, name = "storage-response"
-        )
-
-        self.librarian_request_topic = librarian_request_queue
-        self.librarian_request_subscriber = id
-
-        self.librarian_request_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = librarian_request_queue,
-            subscriber = id,
-            schema = LibrarianRequest,
-            handler = self.on_librarian_request,
-            metrics = librarian_request_metrics,
-        )
-
-        self.librarian_response_producer = Producer(
-            backend = self.pubsub,
-            topic = librarian_response_queue,
-            schema = LibrarianResponse,
-            metrics = librarian_response_metrics,
-        )
-
-        self.collection_request_topic = collection_request_queue
-        self.collection_request_subscriber = id
-
-        self.collection_request_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = collection_request_queue,
-            subscriber = id,
-            schema = CollectionManagementRequest,
-            handler = self.on_collection_request,
-            metrics = collection_request_metrics,
-        )
-
-        self.collection_response_producer = Producer(
-            backend = self.pubsub,
-            topic = collection_response_queue,
-            schema = CollectionManagementResponse,
-            metrics = collection_response_metrics,
-        )
-
         # Config service client for collection management
         config_request_metrics = ProducerMetrics(
             processor = id, flow = None, name = "config-request"
@@ -240,6 +179,7 @@ class Processor(AsyncProcessor):
             object_store_secret_key = object_store_secret_key,
             bucket_name = bucket_name,
             keyspace = keyspace,
+            replication_factor = replication_factor,
             load_document = self.load_document,
             object_store_use_ssl = object_store_use_ssl,
             object_store_region = object_store_region,
@@ -259,17 +199,111 @@ class Processor(AsyncProcessor):

         self.flows = {}

+        # Per-workspace consumers, keyed by workspace id
+        self.workspace_consumers = {}
+
         logger.info("Librarian service initialized")

+    async def on_workspace_created(self, workspace):
+
+        if workspace in self.workspace_consumers:
+            return
+
+        lib_req_queue = workspace_queue(
+            self.librarian_request_queue_base, workspace,
+        )
+        lib_resp_queue = workspace_queue(
+            self.librarian_response_queue_base, workspace,
+        )
+        col_req_queue = workspace_queue(
+            self.collection_request_queue_base, workspace,
+        )
+        col_resp_queue = workspace_queue(
+            self.collection_response_queue_base, workspace,
+        )
+
+        await self.pubsub.ensure_topic(lib_req_queue)
+        await self.pubsub.ensure_topic(lib_resp_queue)
+        await self.pubsub.ensure_topic(col_req_queue)
+        await self.pubsub.ensure_topic(col_resp_queue)
+
+        lib_response_producer = Producer(
+            backend=self.pubsub,
+            topic=lib_resp_queue,
+            schema=LibrarianResponse,
+            metrics=ProducerMetrics(
+                processor=self.id, flow=None,
+                name=f"librarian-response-{workspace}",
+            ),
+        )
+
+        col_response_producer = Producer(
+            backend=self.pubsub,
+            topic=col_resp_queue,
+            schema=CollectionManagementResponse,
+            metrics=ProducerMetrics(
+                processor=self.id, flow=None,
+                name=f"collection-response-{workspace}",
+            ),
+        )
+
+        lib_consumer = Consumer(
+            taskgroup=self.taskgroup,
+            backend=self.pubsub,
+            flow=None,
+            topic=lib_req_queue,
+            subscriber=self.id,
+            schema=LibrarianRequest,
+            handler=partial(
+                self.on_librarian_request, workspace=workspace,
+            ),
+            metrics=ConsumerMetrics(
+                processor=self.id, flow=None,
+                name=f"librarian-request-{workspace}",
+            ),
+        )
+
+        col_consumer = Consumer(
+            taskgroup=self.taskgroup,
+            backend=self.pubsub,
+            flow=None,
+            topic=col_req_queue,
+            subscriber=self.id,
+            schema=CollectionManagementRequest,
+            handler=partial(
+                self.on_collection_request, workspace=workspace,
+            ),
+            metrics=ConsumerMetrics(
+                processor=self.id, flow=None,
+                name=f"collection-request-{workspace}",
+            ),
+        )
+
+        await lib_response_producer.start()
+        await col_response_producer.start()
+        await lib_consumer.start()
+        await col_consumer.start()
+
+        self.workspace_consumers[workspace] = {
+            "librarian": lib_consumer,
+            "librarian-response": lib_response_producer,
+            "collection": col_consumer,
+            "collection-response": col_response_producer,
+        }
+
+        logger.info(f"Subscribed to workspace queues: {workspace}")
+
+    async def on_workspace_deleted(self, workspace):
+
+        clients = self.workspace_consumers.pop(workspace, None)
+        if clients:
+            for client in clients.values():
+                await client.stop()
+            logger.info(f"Unsubscribed from workspace queues: {workspace}")
+
     async def start(self):

-        await self.pubsub.ensure_topic(self.librarian_request_topic)
-        await self.pubsub.ensure_topic(self.collection_request_topic)
         await super(Processor, self).start()
-        await self.librarian_request_consumer.start()
-        await self.librarian_response_producer.start()
-        await self.collection_request_consumer.start()
-        await self.collection_response_producer.start()
         await self.config_request_producer.start()
         await self.config_response_consumer.start()
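The per-workspace consumers above bind the workspace id into a shared handler with functools.partial rather than defining one handler object per workspace. A self-contained sketch of the pattern (handler body hypothetical; note the hunks shown here do not include the `from functools import partial` import line, which the module is assumed to have):

    import asyncio
    from functools import partial

    async def on_request(msg, consumer, flow, *, workspace):
        # Keyword-only `workspace` keeps the positional signature identical
        # for every workspace; partial() pre-binds it per subscription.
        print(workspace, msg)

    handler = partial(on_request, workspace="acme")
    asyncio.run(handler("hello", None, None))  # workspace="acme" already bound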
@@ -360,13 +394,12 @@ class Processor(AsyncProcessor):
         finally:
             await triples_pub.stop()

-    async def load_document(self, document, processing, content):
+    async def load_document(self, document, processing, content, workspace):

         logger.debug("Ready for document processing...")

         logger.debug(f"Document: {document}, processing: {processing}, content length: {len(content)}")

-        workspace = processing.workspace
         ws_flows = self.flows.get(workspace, {})
         if processing.flow not in ws_flows:
             raise RuntimeError(
@@ -426,20 +459,14 @@ class Processor(AsyncProcessor):

         logger.debug("Document submitted")

-    async def add_processing_with_collection(self, request):
-        """
-        Wrapper for add_processing that ensures collection exists
-        """
-        # Ensure collection exists when processing is added
+    async def add_processing_with_collection(self, request, workspace):
         if hasattr(request, 'processing_metadata') and request.processing_metadata:
-            workspace = request.processing_metadata.workspace
             collection = request.processing_metadata.collection
             await self.collection_manager.ensure_collection_exists(workspace, collection)

-        # Call the original add_processing method
-        return await self.librarian.add_processing(request)
+        return await self.librarian.add_processing(request, workspace)

-    async def process_request(self, v):
+    async def process_request(self, v, workspace):

         if v.operation is None:
             raise RequestError("Null operation")
@@ -472,9 +499,9 @@ class Processor(AsyncProcessor):
         if v.operation not in impls:
             raise RequestError(f"Invalid operation: {v.operation}")

-        return await impls[v.operation](v)
+        return await impls[v.operation](v, workspace)

-    async def on_librarian_request(self, msg, consumer, flow):
+    async def on_librarian_request(self, msg, consumer, flow, *, workspace):

         v = msg.value()
@@ -484,20 +511,22 @@ class Processor(AsyncProcessor):

         logger.info(f"Handling librarian input {id}...")

+        producer = self.workspace_consumers[workspace]["librarian-response"]
+
         try:

             # Handle streaming operations specially
             if v.operation == "stream-document":
-                async for resp in self.librarian.stream_document(v):
-                    await self.librarian_response_producer.send(
+                async for resp in self.librarian.stream_document(v, workspace):
+                    await producer.send(
                         resp, properties={"id": id}
                     )
                 return

             # Non-streaming operations
-            resp = await self.process_request(v)
+            resp = await self.process_request(v, workspace)

-            await self.librarian_response_producer.send(
+            await producer.send(
                 resp, properties={"id": id}
             )
@@ -511,7 +540,7 @@ class Processor(AsyncProcessor):
                 ),
             )

-            await self.librarian_response_producer.send(
+            await producer.send(
                 resp, properties={"id": id}
             )
@@ -524,7 +553,7 @@ class Processor(AsyncProcessor):
                 ),
             )

-            await self.librarian_response_producer.send(
+            await producer.send(
                 resp, properties={"id": id}
             )
@@ -532,10 +561,7 @@ class Processor(AsyncProcessor):

         logger.debug("Librarian input processing complete")

-    async def process_collection_request(self, v):
-        """
-        Process collection management requests
-        """
+    async def process_collection_request(self, v, workspace):
         if v.operation is None:
             raise RequestError("Null operation")
@@ -550,20 +576,19 @@ class Processor(AsyncProcessor):
         if v.operation not in impls:
             raise RequestError(f"Invalid collection operation: {v.operation}")

-        return await impls[v.operation](v)
+        return await impls[v.operation](v, workspace)

-    async def on_collection_request(self, msg, consumer, flow):
-        """
-        Handle collection management request messages
-        """
+    async def on_collection_request(self, msg, consumer, flow, *, workspace):
         v = msg.value()
         id = msg.properties().get("id", "unknown")

         logger.info(f"Handling collection request {id}...")

+        producer = self.workspace_consumers[workspace]["collection-response"]
+
         try:
-            resp = await self.process_collection_request(v)
-            await self.collection_response_producer.send(
+            resp = await self.process_collection_request(v, workspace)
+            await producer.send(
                 resp, properties={"id": id}
             )
         except RequestError as e:
@@ -574,7 +599,7 @@ class Processor(AsyncProcessor):
                 ),
                 timestamp=datetime.now().isoformat()
             )
-            await self.collection_response_producer.send(
+            await producer.send(
                 resp, properties={"id": id}
             )
         except Exception as e:
@@ -585,7 +610,7 @@ class Processor(AsyncProcessor):
                 ),
                 timestamp=datetime.now().isoformat()
             )
-            await self.collection_response_producer.send(
+            await producer.send(
                 resp, properties={"id": id}
             )
@@ -594,7 +619,7 @@ class Processor(AsyncProcessor):
     @staticmethod
     def add_args(parser):

-        AsyncProcessor.add_args(parser)
+        WorkspaceProcessor.add_args(parser)

         parser.add_argument(
             '--librarian-request-queue',
@@ -35,8 +35,8 @@ class Processor(LlmService):
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)

-        if api_key is None:
-            raise RuntimeError("OpenAI API key not specified")
+        if not api_key:
+            api_key = "not-set"

         super(Processor, self).__init__(
             **params | {
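Behavioural note: with this change a missing OpenAI key no longer aborts processor startup; the "not-set" placeholder defers the failure to the first completion call, which then surfaces as an authentication error rather than a boot failure. The guard in isolation (input value illustrative):

    api_key = None  # e.g. parameter absent or empty string
    if not api_key:
        api_key = "not-set"  # startup proceeds; API calls fail later with auth errors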
@@ -47,7 +47,7 @@ class Processor(FlowProcessor):
         cassandra_password = params.get("cassandra_password")

         # Resolve configuration with environment variable fallback
-        hosts, username, password, keyspace = resolve_cassandra_config(
+        hosts, username, password, keyspace, _ = resolve_cassandra_config(
             host=cassandra_host,
             username=cassandra_username,
             password=cassandra_password
@@ -160,7 +160,7 @@ class Processor(TriplesQueryService):
         cassandra_password = params.get("cassandra_password")

         # Resolve configuration with environment variable fallback
-        hosts, username, password, keyspace = resolve_cassandra_config(
+        hosts, username, password, keyspace, _ = resolve_cassandra_config(
             host=cassandra_host,
             username=cassandra_username,
             password=cassandra_password
@@ -122,7 +122,7 @@ class Query:
         for match in chunk_matches:
             if match.chunk_id:
                 try:
-                    content = await self.rag.fetch_chunk(match.chunk_id, self.workspace)
+                    content = await self.rag.fetch_chunk(match.chunk_id)
                     docs.append(content)
                     chunk_ids.append(match.chunk_id)
                 except Exception as e:
@@ -4,21 +4,16 @@ Simple RAG service, performs query using document RAG an LLM.
 Input is query, output is response.
 """

-import asyncio
-import base64
 import logging
-
-import uuid

 from ... schema import DocumentRagQuery, DocumentRagResponse, Error
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
 from ... schema import Triples, Metadata
 from ... provenance import GRAPH_RETRIEVAL
 from . document_rag import DocumentRag
 from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
 from ... base import PromptClientSpec, EmbeddingsClientSpec
 from ... base import DocumentEmbeddingsClientSpec
-from ... base import LibrarianClient
+from ... base import LibrarianSpec

 # Module logger
 logger = logging.getLogger(__name__)
@@ -85,58 +80,14 @@ class Processor(FlowProcessor):
             )
         )

-        # Librarian client
-        self.librarian = LibrarianClient(
-            id=id,
-            backend=self.pubsub,
-            taskgroup=self.taskgroup,
+        self.register_specification(
+            LibrarianSpec()
         )

-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian.start()
-
-    async def fetch_chunk_content(self, chunk_id, workspace, timeout=120):
-        """Fetch chunk content from librarian. Chunks are small so
-        single request-response is fine."""
-        return await self.librarian.fetch_document_text(
-            document_id=chunk_id, workspace=workspace, timeout=timeout,
-        )
-
-    async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
-        """Save answer content to the librarian."""
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            workspace=workspace,
-            kind="text/plain",
-            title=title or "DocumentRAG Answer",
-            document_type="answer",
-        )
-
-        request = LibrarianRequest(
-            operation="add-document",
-            document_id=doc_id,
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
-            workspace=workspace,
-        )
-
-        await self.librarian.request(request, timeout=timeout)
-        return doc_id
-
     async def on_request(self, msg, consumer, flow):

         try:

-            self.rag = DocumentRag(
-                embeddings_client = flow("embeddings-request"),
-                doc_embeddings_client = flow("document-embeddings-request"),
-                prompt_client = flow("prompt-request"),
-                fetch_chunk = self.fetch_chunk_content,
-                verbose=True,
-            )
-
             v = msg.value()
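With LibrarianSpec registered, the framework supplies a flow-scoped librarian client, replacing the hand-built LibrarianClient, its start() plumbing, and the manual LibrarianRequest construction removed above. Handlers then reach it as flow.librarian; a hedged sketch using only the calls visible in this diff (ids and strings illustrative):

    content = await flow.librarian.fetch_document_text(document_id="chunk-1")
    await flow.librarian.save_document(
        doc_id="answer-1", content="...", title="Answer", document_type="answer",
    )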
@@ -144,15 +95,25 @@ class Processor(FlowProcessor):

             logger.info(f"Handling input {id}...")

+            async def fetch_chunk(chunk_id, timeout=120):
+                return await flow.librarian.fetch_document_text(
+                    document_id=chunk_id, timeout=timeout,
+                )
+
+            self.rag = DocumentRag(
+                embeddings_client = flow("embeddings-request"),
+                doc_embeddings_client = flow("document-embeddings-request"),
+                prompt_client = flow("prompt-request"),
+                fetch_chunk = fetch_chunk,
+                verbose=True,
+            )
+
             if v.doc_limit:
                 doc_limit = v.doc_limit
             else:
                 doc_limit = self.doc_limit

-            # Real-time explainability callback - emits triples and IDs as they're generated
-            # Triples are stored in the request's collection with a named graph (urn:graph:retrieval)
             async def send_explainability(triples, explain_id):
-                # Send triples to explainability queue - stores in same collection with named graph
                 await flow("explainability").send(Triples(
                     metadata=Metadata(
                         id=explain_id,
@@ -161,7 +122,6 @@ class Processor(FlowProcessor):
                     triples=triples,
                 ))

-            # Send explain data to response queue
             await flow("response").send(
                 DocumentRagResponse(
                     response=None,
@ -173,13 +133,12 @@ class Processor(FlowProcessor):
|
||||||
properties={"id": id}
|
properties={"id": id}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Callback to save answer content to librarian
|
|
||||||
async def save_answer(doc_id, answer_text):
|
async def save_answer(doc_id, answer_text):
|
||||||
await self.save_answer_content(
|
await flow.librarian.save_document(
|
||||||
doc_id=doc_id,
|
doc_id=doc_id,
|
||||||
workspace=flow.workspace,
|
|
||||||
content=answer_text,
|
content=answer_text,
|
||||||
title=f"DocumentRAG Answer: {v.query[:50]}...",
|
title=f"DocumentRAG Answer: {v.query[:50]}...",
|
||||||
|
document_type="answer",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if streaming is requested
|
# Check if streaming is requested
|
||||||
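The new wiring passes a per-request fetch_chunk closure into the RAG engine instead of a bound method, so the engine never touches the messaging layer. A minimal sketch of that inversion-of-control shape (the MiniRag class and the stub fetcher are illustrative, not TrustGraph code):

import asyncio

class MiniRag:
    """Toy RAG engine: depends only on an async fetch_chunk callable."""
    def __init__(self, fetch_chunk):
        self.fetch_chunk = fetch_chunk

    async def answer(self, chunk_ids):
        # Pull chunk text through the injected callback; the engine has
        # no idea whether it came from Pulsar, HTTP, or a plain dict.
        chunks = [await self.fetch_chunk(c) for c in chunk_ids]
        return " ".join(chunks)

async def main():
    store = {"c1": "TrustGraph", "c2": "pipelines"}

    async def fetch_chunk(chunk_id, timeout=120):
        return store[chunk_id]

    rag = MiniRag(fetch_chunk)
    print(await rag.answer(["c1", "c2"]))

asyncio.run(main())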
@@ -4,29 +4,22 @@ Simple RAG service, performs query using graph RAG an LLM.
 
 Input is query, output is response.
 """
 
-import asyncio
-import base64
 import logging
-import uuid
 
 from ... schema import GraphRagQuery, GraphRagResponse, Error
 from ... schema import Triples, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
-from ... schema import librarian_request_queue, librarian_response_queue
 from ... provenance import GRAPH_RETRIEVAL
 from . graph_rag import GraphRag
 from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
 from ... base import PromptClientSpec, EmbeddingsClientSpec
 from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
-from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
+from ... base import LibrarianSpec
 
 # Module logger
 logger = logging.getLogger(__name__)
 
 default_ident = "graph-rag"
 default_concurrency = 1
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue
 
 class Processor(FlowProcessor):
 
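Both RAG services now declare their librarian dependency with register_specification(LibrarianSpec()) instead of hand-building producers and consumers. A minimal sketch of the general idea behind spec registration (the Spec and MiniProcessor names here are hypothetical, not the TrustGraph base classes):

class Spec:
    """Declares a dependency; the framework materializes it at start-up."""
    name = "base"
    def build(self):
        raise NotImplementedError

class LibrarianSpecSketch(Spec):
    name = "librarian"
    def build(self):
        # In the real system this would wire queues, metrics and a
        # client; here it just returns a placeholder object.
        return object()

class MiniProcessor:
    def __init__(self):
        self.specs = []
        self.services = {}
    def register_specification(self, spec):
        self.specs.append(spec)
    def start(self):
        # One place builds every declared dependency.
        for spec in self.specs:
            self.services[spec.name] = spec.build()

p = MiniProcessor()
p.register_specification(LibrarianSpecSketch())
p.start()
assert "librarian" in p.services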
@@ -117,115 +110,12 @@ class Processor(FlowProcessor):
             )
         )
 
-        # Librarian client for storing answer content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend=self.pubsub,
-            topic=librarian_request_q,
-            schema=LibrarianRequest,
-            metrics=librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup=self.taskgroup,
-            backend=self.pubsub,
-            flow=None,
-            topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
-            schema=LibrarianResponse,
-            handler=self.on_librarian_response,
-            metrics=librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_librarian_requests = {}
+        self.register_specification(
+            LibrarianSpec()
+        )
 
         logger.info("Graph RAG service initialized")
 
-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
-
-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-
-        if request_id and request_id in self.pending_librarian_requests:
-            future = self.pending_librarian_requests.pop(request_id)
-            future.set_result(response)
-
-    async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
-        """
-        Save answer content to the librarian.
-
-        Args:
-            doc_id: ID for the answer document
-            workspace: Workspace for isolation
-            content: Answer text content
-            title: Optional title
-            timeout: Request timeout in seconds
-
-        Returns:
-            The document ID on success
-        """
-        request_id = str(uuid.uuid4())
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            workspace=workspace,
-            kind="text/plain",
-            title=title or "GraphRAG Answer",
-            document_type="answer",
-        )
-
-        request = LibrarianRequest(
-            operation="add-document",
-            document_id=doc_id,
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
-            workspace=workspace,
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_librarian_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving answer: {response.error.type}: {response.error.message}"
-                )
-
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_librarian_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving answer document {doc_id}")
-
     async def on_request(self, msg, consumer, flow):
 
         try:
 
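The deleted block is worth understanding before it disappears: it correlates request and response messages over a queue by stashing an asyncio.Future per request ID, which is exactly what the shared librarian client now does once, centrally. A compact, self-contained sketch of that correlation pattern (the Correlator class and the send callable are illustrative, not TrustGraph code):

import asyncio
import uuid

class Correlator:
    """Matches responses to requests by ID using per-request futures."""
    def __init__(self):
        self.pending = {}

    async def request(self, send, payload, timeout=5):
        request_id = str(uuid.uuid4())
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        try:
            await send(request_id, payload)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            # Drop the stale future so the pending map cannot leak.
            self.pending.pop(request_id, None)
            raise

    def on_response(self, request_id, response):
        # Called from the consumer side; resolves the waiting future.
        future = self.pending.pop(request_id, None)
        if future and not future.done():
            future.set_result(response)

async def main():
    c = Correlator()

    async def send(request_id, payload):
        # Simulate the round trip: echo the payload back asynchronously.
        asyncio.get_running_loop().call_soon(c.on_response, request_id, payload.upper())

    print(await c.request(send, "add-document"))

asyncio.run(main())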
@@ -306,13 +196,12 @@ class Processor(FlowProcessor):
             else:
                 edge_limit = self.default_edge_limit
 
-            # Callback to save answer content to librarian
             async def save_answer(doc_id, answer_text):
-                await self.save_answer_content(
+                await flow.librarian.save_document(
                     doc_id=doc_id,
-                    workspace=flow.workspace,
                     content=answer_text,
                     title=f"GraphRAG Answer: {v.query[:50]}...",
+                    document_type="answer",
                 )
 
             # Check if streaming is requested
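In both services the explicit workspace=flow.workspace argument disappears because the flow-scoped librarian handle is already bound to its workspace. A tiny sketch of that binding idea using functools.partial (the save_document function here is a stand-in, not the TrustGraph client):

from functools import partial

def save_document(workspace, doc_id, content, title=None, document_type=None):
    # Stand-in for a workspace-aware storage call.
    return f"{workspace}/{doc_id}: {document_type or 'document'}"

# A flow context can hand out a client pre-bound to its workspace,
# so callers never pass workspace themselves.
flow_librarian_save = partial(save_document, "workspace-42")

print(flow_librarian_save("answer-1", "some text", document_type="answer"))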
@@ -47,7 +47,7 @@ class Processor(CollectionConfigHandler, FlowProcessor):
         cassandra_password = params.get("cassandra_password")
 
         # Resolve configuration with environment variable fallback
-        hosts, username, password, keyspace = resolve_cassandra_config(
+        hosts, username, password, keyspace, _ = resolve_cassandra_config(
             host=cassandra_host,
             username=cassandra_username,
             password=cassandra_password
 
@@ -125,7 +125,7 @@ class Processor(CollectionConfigHandler, TriplesStoreService):
         cassandra_password = params.get("cassandra_password")
 
         # Resolve configuration with environment variable fallback
-        hosts, username, password, keyspace = resolve_cassandra_config(
+        hosts, username, password, keyspace, _ = resolve_cassandra_config(
             host=cassandra_host,
             username=cassandra_username,
             password=cassandra_password
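resolve_cassandra_config now returns a fifth value that these call sites discard with _. The underlying pattern is parameter-over-environment-over-default resolution; a self-contained sketch (the function name is reused for illustration, and the return shape is an assumption):

import os
from collections import namedtuple

CassandraConfig = namedtuple(
    "CassandraConfig", "hosts username password keyspace extra"
)

def resolve_cassandra_config(host=None, username=None, password=None):
    # Explicit parameter wins, then environment variable, then default.
    hosts = (host or os.environ.get("CASSANDRA_HOST", "localhost")).split(",")
    return CassandraConfig(
        hosts=hosts,
        username=username or os.environ.get("CASSANDRA_USERNAME"),
        password=password or os.environ.get("CASSANDRA_PASSWORD"),
        keyspace=os.environ.get("CASSANDRA_KEYSPACE", "trustgraph"),
        extra=None,  # placeholder for the new fifth element
    )

hosts, user, pw, keyspace, _ = resolve_cassandra_config(host="db1,db2")
print(hosts, keyspace)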
@@ -313,7 +313,7 @@ class LibraryTableStore:
 
         return bool(rows)
 
-    async def add_document(self, document, object_id):
+    async def add_document(self, workspace, document, object_id):
 
         logger.info(f"Adding document {document.id} {object_id}")
 
@@ -333,7 +333,7 @@ class LibraryTableStore:
             self.cassandra,
             self.insert_document_stmt,
             (
-                document.id, document.workspace, int(document.time * 1000),
+                document.id, workspace, int(document.time * 1000),
                 document.kind, document.title, document.comments,
                 metadata, document.tags, object_id,
                 parent_id, document_type
 
@@ -345,7 +345,7 @@ class LibraryTableStore:
 
         logger.debug("Add complete")
 
-    async def update_document(self, document):
+    async def update_document(self, workspace, document):
 
         logger.info(f"Updating document {document.id}")
 
@@ -363,7 +363,7 @@ class LibraryTableStore:
             (
                 int(document.time * 1000), document.title,
                 document.comments, metadata, document.tags,
-                document.workspace, document.id
+                workspace, document.id
             ),
         )
         except Exception:
 
@@ -405,7 +405,6 @@ class LibraryTableStore:
         lst = [
             DocumentMetadata(
                 id = row[0],
-                workspace = workspace,
                 time = int(time.mktime(row[1].timetuple())),
                 kind = row[2],
                 title = row[3],
 
@@ -447,7 +446,6 @@ class LibraryTableStore:
         lst = [
             DocumentMetadata(
                 id = row[0],
-                workspace = row[1],
                 time = int(time.mktime(row[2].timetuple())),
                 kind = row[3],
                 title = row[4],
 
@@ -488,7 +486,6 @@ class LibraryTableStore:
         for row in rows:
             doc = DocumentMetadata(
                 id = id,
-                workspace = workspace,
                 time = int(time.mktime(row[0].timetuple())),
                 kind = row[1],
                 title = row[2],
 
@@ -541,7 +538,7 @@ class LibraryTableStore:
 
         return bool(rows)
 
-    async def add_processing(self, processing):
+    async def add_processing(self, workspace, processing):
 
         logger.info(f"Adding processing {processing.id}")
 
@@ -552,7 +549,7 @@ class LibraryTableStore:
             (
                 processing.id, processing.document_id,
                 int(processing.time * 1000), processing.flow,
-                processing.workspace, processing.collection,
+                workspace, processing.collection,
                 processing.tags
             ),
         )
 
@@ -598,7 +595,6 @@ class LibraryTableStore:
                 document_id = row[1],
                 time = int(time.mktime(row[2].timetuple())),
                 flow = row[3],
-                workspace = workspace,
                 collection = row[4],
                 tags = row[5] if row[5] else [],
             )
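Across LibraryTableStore, the workspace stops living on the document object and becomes an explicit parameter supplied by the caller, which keeps one store instance usable across workspaces. A minimal in-memory analogue of that keying decision (no Cassandra involved; the class is illustrative):

class MiniTableStore:
    """Documents keyed by (workspace, doc_id), mirroring the explicit
    workspace parameter the real store now takes."""
    def __init__(self):
        self.rows = {}

    def add_document(self, workspace, document):
        # workspace comes from the caller, not from the document itself.
        self.rows[(workspace, document["id"])] = document

    def list_documents(self, workspace):
        return [d for (ws, _), d in self.rows.items() if ws == workspace]

store = MiniTableStore()
store.add_document("ws-a", {"id": "d1", "title": "Answer"})
store.add_document("ws-b", {"id": "d1", "title": "Other"})
assert len(store.list_documents("ws-a")) == 1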
@@ -13,9 +13,8 @@ import pytesseract
 from pdf2image import convert_from_bytes
 
 from ... schema import Document, TextDocument, Metadata
-from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec
 
 from ... provenance import (
     document_uri, page_uri as make_page_uri, derived_entity_triples,
 
@@ -31,9 +30,6 @@ logger = logging.getLogger(__name__)
 
 default_ident = "document-decoder"
 
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue
-
 class Processor(FlowProcessor):
 
     def __init__(self, **params):
 
@@ -68,17 +64,12 @@ class Processor(FlowProcessor):
             )
         )
 
-        # Librarian client
-        self.librarian = LibrarianClient(
-            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
-        )
+        self.register_specification(
+            LibrarianSpec()
+        )
 
         logger.info("PDF OCR processor initialized")
 
-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian.start()
-
     async def on_message(self, msg, consumer, flow):
 
         logger.info("PDF message received")
 
@@ -89,9 +80,8 @@ class Processor(FlowProcessor):
 
         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.librarian.fetch_document_metadata(
+            doc_meta = await flow.librarian.fetch_document_metadata(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )
             if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
                 logger.error(
 
@@ -104,9 +94,8 @@ class Processor(FlowProcessor):
         # Get PDF content - fetch from librarian or use inline data
         if v.document_id:
             logger.info(f"Fetching document {v.document_id} from librarian...")
-            content = await self.librarian.fetch_document_content(
+            content = await flow.librarian.fetch_document_content(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )
             if isinstance(content, str):
                 content = content.encode('utf-8')
 
@@ -138,10 +127,9 @@ class Processor(FlowProcessor):
             page_content = text.encode("utf-8")
 
             # Save page as child document in librarian
-            await self.librarian.save_child_document(
+            await flow.librarian.save_child_document(
                 doc_id=page_doc_id,
                 parent_id=source_doc_id,
-                workspace=flow.workspace,
                 content=page_content,
                 document_type="page",
                 title=f"Page {page_num}",
 
@@ -189,18 +177,6 @@ class Processor(FlowProcessor):
 
     FlowProcessor.add_args(parser)
 
-    parser.add_argument(
-        '--librarian-request-queue',
-        default=default_librarian_request_queue,
-        help=f'Librarian request queue (default: {default_librarian_request_queue})',
-    )
-
-    parser.add_argument(
-        '--librarian-response-queue',
-        default=default_librarian_response_queue,
-        help=f'Librarian response queue (default: {default_librarian_response_queue})',
-    )
-
 def run():
 
     Processor.launch(default_ident, __doc__)
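This decoder's core job is unchanged: rasterize each PDF page, OCR it, and store each page as a child document. A minimal sketch of the extraction step, assuming poppler and tesseract are installed locally (this is the general pdf2image/pytesseract pattern, not the processor's exact code, and the sample file name is a placeholder):

import pytesseract
from pdf2image import convert_from_bytes

def ocr_pdf_pages(pdf_bytes: bytes) -> list[str]:
    # Render each page to an image, then OCR the image to text.
    images = convert_from_bytes(pdf_bytes)
    return [pytesseract.image_to_string(image) for image in images]

with open("sample.pdf", "rb") as f:
    for page_num, text in enumerate(ocr_pdf_pages(f.read()), start=1):
        print(f"Page {page_num}: {len(text)} characters")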
@@ -23,9 +23,8 @@ import os
 from unstructured.partition.auto import partition
 
 from ... schema import Document, TextDocument, Metadata
-from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec
 
 from ... provenance import (
     document_uri, page_uri as make_page_uri,
 
@@ -44,9 +43,6 @@ logger = logging.getLogger(__name__)
 
 default_ident = "document-decoder"
 
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue
-
 # Mime type to unstructured content_type mapping
 # unstructured auto-detects most formats, but we pass the hint when available
 MIME_EXTENSIONS = {
 
@@ -162,17 +158,12 @@ class Processor(FlowProcessor):
             )
         )
 
-        # Librarian client
-        self.librarian = LibrarianClient(
-            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
-        )
+        self.register_specification(
+            LibrarianSpec()
+        )
 
         logger.info("Universal decoder initialized")
 
-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian.start()
-
     def extract_elements(self, blob, mime_type=None):
         """
         Extract elements from a document using unstructured.
 
@@ -272,10 +263,9 @@ class Processor(FlowProcessor):
             page_content = text.encode("utf-8")
 
             # Save to librarian
-            await self.librarian.save_child_document(
+            await flow.librarian.save_child_document(
                 doc_id=doc_id,
                 parent_id=parent_doc_id,
-                workspace=flow.workspace,
                 content=page_content,
                 document_type="page" if is_page else "section",
                 title=label,
 
@@ -351,10 +341,9 @@ class Processor(FlowProcessor):
 
             # Save to librarian
             if img_content:
-                await self.librarian.save_child_document(
+                await flow.librarian.save_child_document(
                     doc_id=img_uri,
                     parent_id=parent_doc_id,
-                    workspace=flow.workspace,
                     content=img_content,
                     document_type="image",
                     title=f"Image from page {page_number}" if page_number else "Image",
 
@@ -399,15 +388,13 @@ class Processor(FlowProcessor):
                 f"Fetching document {v.document_id} from librarian..."
             )
 
-            doc_meta = await self.librarian.fetch_document_metadata(
+            doc_meta = await flow.librarian.fetch_document_metadata(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )
             mime_type = doc_meta.kind if doc_meta else None
 
-            content = await self.librarian.fetch_document_content(
+            content = await flow.librarian.fetch_document_content(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )
 
             if isinstance(content, str):
 
@@ -571,19 +558,6 @@ class Processor(FlowProcessor):
         help='Apply section strategy within pages too (default: false)',
     )
 
-    parser.add_argument(
-        '--librarian-request-queue',
-        default=default_librarian_request_queue,
-        help=f'Librarian request queue '
-             f'(default: {default_librarian_request_queue})',
-    )
-
-    parser.add_argument(
-        '--librarian-response-queue',
-        default=default_librarian_response_queue,
-        help=f'Librarian response queue '
-             f'(default: {default_librarian_response_queue})',
-    )
 
 def run():
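The universal decoder delegates format handling to the unstructured library's auto partitioner, passing the MIME type as a hint when librarian metadata provides one. A small sketch of that call shape (partition is the real unstructured entry point; the file name is a placeholder):

import io
from unstructured.partition.auto import partition

def extract_text_elements(blob: bytes, mime_type: str | None = None):
    # partition() auto-detects the format; content_type is only a hint.
    elements = partition(file=io.BytesIO(blob), content_type=mime_type)
    return [str(el) for el in elements]

with open("report.docx", "rb") as f:
    for text in extract_text_elements(
        f.read(),
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    ):
        print(text)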
@@ -10,12 +10,13 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
 readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
-    "trustgraph-base>=1.8,<1.9",
-    "trustgraph-bedrock>=1.8,<1.9",
-    "trustgraph-cli>=1.8,<1.9",
-    "trustgraph-embeddings-hf>=1.8,<1.9",
-    "trustgraph-flow>=1.8,<1.9",
-    "trustgraph-vertexai>=1.8,<1.9",
+    "trustgraph-base>=2.4,<2.5",
+    "trustgraph-bedrock>=2.4,<2.5",
+    "trustgraph-cli>=2.4,<2.5",
+    "trustgraph-embeddings-hf>=2.4,<2.5",
+    "trustgraph-flow>=2.4,<2.5",
+    "trustgraph-unstructured>=2.4,<2.5",
+    "trustgraph-vertexai>=2.4,<2.5",
 ]
 classifiers = [
     "Programming Language :: Python :: 3",
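Every component pin moves together from the 1.8 series to 2.4, and trustgraph-unstructured joins the meta-package. A quick way to confirm an installed environment matches these pins, sketched with the standard library (the package list is copied from the diff):

from importlib.metadata import version, PackageNotFoundError

PINNED = [
    "trustgraph-base", "trustgraph-bedrock", "trustgraph-cli",
    "trustgraph-embeddings-hf", "trustgraph-flow",
    "trustgraph-unstructured", "trustgraph-vertexai",
]

for name in PINNED:
    try:
        v = version(name)
        # Accept any 2.4.x release, matching the >=2.4,<2.5 constraint.
        ok = v.split(".")[:2] == ["2", "4"]
        print(f"{name}=={v} {'ok' if ok else 'outside >=2.4,<2.5'}")
    except PackageNotFoundError:
        print(f"{name} not installed")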